bit cleaner websocket implementation
This commit is contained in:
14
src/error.rs
14
src/error.rs
@ -11,6 +11,7 @@ pub enum Kind {
|
|||||||
WebSocketError,
|
WebSocketError,
|
||||||
NotImplementedError,
|
NotImplementedError,
|
||||||
UsimpProtocolError,
|
UsimpProtocolError,
|
||||||
|
Utf8DecodeError,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug)]
|
#[derive(Copy, Clone, Debug)]
|
||||||
@ -59,6 +60,7 @@ impl Error {
|
|||||||
Kind::WebSocketError => "WebSocket protocol error",
|
Kind::WebSocketError => "WebSocket protocol error",
|
||||||
Kind::NotImplementedError => "Not yet implemented",
|
Kind::NotImplementedError => "Not yet implemented",
|
||||||
Kind::UsimpProtocolError => "USIMP protocol error",
|
Kind::UsimpProtocolError => "USIMP protocol error",
|
||||||
|
Kind::Utf8DecodeError => "Unable to decode UTF-8 data",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -88,6 +90,7 @@ impl fmt::Display for Error {
|
|||||||
Kind::WebSocketError => "websocket protocol error",
|
Kind::WebSocketError => "websocket protocol error",
|
||||||
Kind::NotImplementedError => "not yet implemented",
|
Kind::NotImplementedError => "not yet implemented",
|
||||||
Kind::UsimpProtocolError => "usimp protocol error",
|
Kind::UsimpProtocolError => "usimp protocol error",
|
||||||
|
Kind::Utf8DecodeError => "unable to decode utf-8 data",
|
||||||
}
|
}
|
||||||
.to_string();
|
.to_string();
|
||||||
if let Some(desc) = &self.desc {
|
if let Some(desc) = &self.desc {
|
||||||
@ -143,3 +146,14 @@ impl From<r2d2_postgres::postgres::Error> for Error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<std::string::FromUtf8Error> for Error {
|
||||||
|
fn from(error: std::string::FromUtf8Error) -> Self {
|
||||||
|
Error {
|
||||||
|
kind: Kind::Utf8DecodeError,
|
||||||
|
msg: Some("Unable to decode UTF-8 data".to_string()),
|
||||||
|
desc: Some(error.to_string()),
|
||||||
|
class: Class::ClientError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,80 +1,136 @@
|
|||||||
use crate::error::*;
|
use crate::error::*;
|
||||||
use crate::http;
|
use crate::http;
|
||||||
|
|
||||||
use crate::websocket::FrameHeader;
|
use crate::websocket::*;
|
||||||
use base64;
|
use base64;
|
||||||
use crypto;
|
use crypto;
|
||||||
use crypto::digest::Digest;
|
use crypto::digest::Digest;
|
||||||
|
|
||||||
pub fn connection_handler(
|
pub fn recv_message(client: &mut http::HttpStream) -> Result<Message, Error> {
|
||||||
|
let mut msg: Vec<u8> = Vec::new();
|
||||||
|
let mut msg_type = 0;
|
||||||
|
loop {
|
||||||
|
let header = FrameHeader::from(&mut client.stream)?;
|
||||||
|
|
||||||
|
if msg_type != 0 && header.opcode != 0 {
|
||||||
|
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||||
|
.set_desc("continuation frame expected".to_string()));
|
||||||
|
} else if header.opcode >= 8 && (!header.fin || header.payload_len >= 126) {
|
||||||
|
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||||
|
.set_desc("invalid control frame".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
match header.opcode {
|
||||||
|
0 => {}, // cont
|
||||||
|
1 => {}, // text
|
||||||
|
2 => // binary
|
||||||
|
return Err(Error::new(Kind::UsimpProtocolError, Class::ClientError)
|
||||||
|
.set_desc("binary frames must not be sent on a usimp connection".to_string())),
|
||||||
|
8 => {}, // close
|
||||||
|
9 => {}, // ping
|
||||||
|
10 => {}, // pong
|
||||||
|
_ => return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||||
|
.set_desc("invalid opcode".to_string())),
|
||||||
|
}
|
||||||
|
|
||||||
|
msg_type = header.opcode;
|
||||||
|
|
||||||
|
// FIXME check payload len and total len
|
||||||
|
|
||||||
|
let mut buf = vec![0u8; header.payload_len() as usize];
|
||||||
|
client.stream.read_exact(&mut buf)?;
|
||||||
|
|
||||||
|
if header.mask {
|
||||||
|
let key: [u8; 4] = [
|
||||||
|
(header.masking_key.unwrap() >> 24) as u8,
|
||||||
|
((header.masking_key.unwrap() >> 16) & 0xFF) as u8,
|
||||||
|
((header.masking_key.unwrap() >> 8) & 0xFF) as u8,
|
||||||
|
(header.masking_key.unwrap() & 0xFF) as u8,
|
||||||
|
];
|
||||||
|
for (pos, byte) in buf.iter_mut().enumerate() {
|
||||||
|
*byte ^= key[pos & 3]; // = pos % 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.append(&mut buf);
|
||||||
|
|
||||||
|
if header.fin {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match msg_type {
|
||||||
|
1 => {Ok(Message::TextMessage(TextMessage {
|
||||||
|
data: String::from_utf8(msg)?
|
||||||
|
}))},
|
||||||
|
8 => {
|
||||||
|
let mut code = None;
|
||||||
|
let mut reason = None;
|
||||||
|
|
||||||
|
if msg.len() >= 2 {
|
||||||
|
code = Some(((msg[0] as u16) << 8) | (msg[1] as u16));
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.len() > 2 {
|
||||||
|
reason = Some(String::from_utf8(msg[2..].to_vec())?);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Message::CloseMessage(CloseMessage {
|
||||||
|
code,
|
||||||
|
reason
|
||||||
|
}))
|
||||||
|
},
|
||||||
|
9 => {Ok(Message::PingMessage(PingMessage {
|
||||||
|
data: String::from_utf8(msg)?
|
||||||
|
}))},
|
||||||
|
10 => {Ok(Message::PongMessage(PongMessage {
|
||||||
|
data: String::from_utf8(msg)?
|
||||||
|
}))},
|
||||||
|
_ => panic!("invalid msg_type for websocket")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn handshake(
|
||||||
client: &mut http::HttpStream,
|
client: &mut http::HttpStream,
|
||||||
req: &http::Request,
|
req: &http::Request,
|
||||||
mut res: http::Response,
|
res: &mut http::Response,
|
||||||
) {
|
) -> Result<(), Error> {
|
||||||
if let http::Method::GET = req.method {
|
if let http::Method::GET = req.method {
|
||||||
} else {
|
} else {
|
||||||
res.status(405);
|
res.status(405);
|
||||||
res.header.add_field("Allow", "GET");
|
res.header.add_field("Allow", "GET");
|
||||||
return http::error_handler(
|
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||||
client,
|
.set_desc("method not allowed".to_string()));
|
||||||
res,
|
|
||||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
|
||||||
.set_desc("method not allowed".to_string()),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(_) = req.header.find_field("Connection") {
|
if let Some(_) = req.header.find_field("Connection") {
|
||||||
if !req.header.field_has_value("Connection", "upgrade") {
|
if !req.header.field_has_value("Connection", "upgrade") {
|
||||||
return http::error_handler(
|
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||||
client,
|
.set_desc("invalid value for header field 'Connection'".to_string()));
|
||||||
res,
|
|
||||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
|
||||||
.set_desc("invalid value for header field 'Connection'".to_string()),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return http::error_handler(
|
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||||
client,
|
.set_desc("unable to find header field 'Connection'".to_string()));
|
||||||
res,
|
|
||||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
|
||||||
.set_desc("unable to find header field 'Connection'".to_string()),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(upgrade) = req.header.find_field("Upgrade") {
|
if let Some(upgrade) = req.header.find_field("Upgrade") {
|
||||||
if !upgrade.eq_ignore_ascii_case("websocket") {
|
if !upgrade.eq_ignore_ascii_case("websocket") {
|
||||||
return http::error_handler(
|
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||||
client,
|
.set_desc("invalid value for header field 'Upgrade'".to_string()));
|
||||||
res,
|
|
||||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
|
||||||
.set_desc("invalid value for header field 'Upgrade'".to_string()),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return http::error_handler(
|
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||||
client,
|
.set_desc("unable to find header field 'Upgrade'".to_string()));
|
||||||
res,
|
|
||||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
|
||||||
.set_desc("unable to find header field 'Upgrade'".to_string()),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(version) = req.header.find_field("Sec-WebSocket-Version") {
|
if let Some(version) = req.header.find_field("Sec-WebSocket-Version") {
|
||||||
if !version.eq("13") {
|
if !version.eq("13") {
|
||||||
return http::error_handler(
|
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||||
client,
|
.set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()));
|
||||||
res,
|
|
||||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
|
||||||
.set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return http::error_handler(
|
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||||
client,
|
.set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()));
|
||||||
res,
|
|
||||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
|
||||||
.set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(key) = req.header.find_field("Sec-WebSocket-Key") {
|
if let Some(key) = req.header.find_field("Sec-WebSocket-Key") {
|
||||||
@ -87,12 +143,8 @@ pub fn connection_handler(
|
|||||||
let key = base64::encode(result);
|
let key = base64::encode(result);
|
||||||
res.header.add_field("Sec-WebSocket-Accept", key.as_str());
|
res.header.add_field("Sec-WebSocket-Accept", key.as_str());
|
||||||
} else {
|
} else {
|
||||||
return http::error_handler(
|
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||||
client,
|
.set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()));
|
||||||
res,
|
|
||||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
|
||||||
.set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
client.server_keep_alive = false;
|
client.server_keep_alive = false;
|
||||||
@ -100,35 +152,36 @@ pub fn connection_handler(
|
|||||||
res.header.add_field("Upgrade", "websocket");
|
res.header.add_field("Upgrade", "websocket");
|
||||||
|
|
||||||
res.status(101);
|
res.status(101);
|
||||||
res.send(&mut client.stream).unwrap();
|
res.send(&mut client.stream)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn connection_handler(
|
||||||
|
client: &mut http::HttpStream,
|
||||||
|
req: &http::Request,
|
||||||
|
mut res: http::Response,
|
||||||
|
) {
|
||||||
|
if let Err(error) = handshake(client, req, &mut res) {
|
||||||
|
return http::error_handler(client, res, error);
|
||||||
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let header = FrameHeader::from(&mut client.stream).unwrap();
|
let msg = recv_message(client).unwrap();
|
||||||
if header.mask {
|
match msg {
|
||||||
println!("Mask: {:X}", header.masking_key.unwrap());
|
Message::TextMessage(msg) => {
|
||||||
}
|
println!("Data: {}", msg.data);
|
||||||
|
},
|
||||||
|
Message::CloseMessage(msg) => {
|
||||||
|
|
||||||
let mut buf = [0u8; 8192];
|
return
|
||||||
client
|
}
|
||||||
.stream
|
Message::PingMessage(msg) => {
|
||||||
.read_exact(&mut buf[..header.payload_len() as usize])
|
// TODO send pong
|
||||||
.unwrap();
|
},
|
||||||
|
Message::PongMessage(msg) => {
|
||||||
if header.mask {
|
// TODO something
|
||||||
let key: [u8; 4] = [
|
|
||||||
(header.masking_key.unwrap() >> 24) as u8,
|
|
||||||
((header.masking_key.unwrap() >> 16) & 0xFF) as u8,
|
|
||||||
((header.masking_key.unwrap() >> 8) & 0xFF) as u8,
|
|
||||||
(header.masking_key.unwrap() & 0xFF) as u8,
|
|
||||||
];
|
|
||||||
for (pos, byte) in buf.iter_mut().enumerate() {
|
|
||||||
*byte ^= key[pos % 4];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!(
|
|
||||||
"Msg: {}",
|
|
||||||
String::from_utf8_lossy(&buf[..header.payload_len() as usize])
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -21,6 +21,30 @@ pub struct FrameHeader {
|
|||||||
masking_key: Option<u32>,
|
masking_key: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub enum Message {
|
||||||
|
PingMessage(PingMessage),
|
||||||
|
PongMessage(PongMessage),
|
||||||
|
CloseMessage(CloseMessage),
|
||||||
|
TextMessage(TextMessage),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PingMessage {
|
||||||
|
data: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PongMessage {
|
||||||
|
data: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CloseMessage {
|
||||||
|
code: Option<u16>,
|
||||||
|
reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TextMessage {
|
||||||
|
data: String,
|
||||||
|
}
|
||||||
|
|
||||||
impl FrameHeader {
|
impl FrameHeader {
|
||||||
pub fn from(socket: &mut http::Stream) -> Result<Self, Error> {
|
pub fn from(socket: &mut http::Stream) -> Result<Self, Error> {
|
||||||
let mut data = [0u8; 2];
|
let mut data = [0u8; 2];
|
||||||
|
Reference in New Issue
Block a user