diff --git a/src/error.rs b/src/error.rs index 69a6ba8..6f1777e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,7 @@ pub enum Kind { WebSocketError, NotImplementedError, UsimpProtocolError, + Utf8DecodeError, } #[derive(Copy, Clone, Debug)] @@ -59,6 +60,7 @@ impl Error { Kind::WebSocketError => "WebSocket protocol error", Kind::NotImplementedError => "Not yet implemented", Kind::UsimpProtocolError => "USIMP protocol error", + Kind::Utf8DecodeError => "Unable to decode UTF-8 data", }, } } @@ -88,6 +90,7 @@ impl fmt::Display for Error { Kind::WebSocketError => "websocket protocol error", Kind::NotImplementedError => "not yet implemented", Kind::UsimpProtocolError => "usimp protocol error", + Kind::Utf8DecodeError => "unable to decode utf-8 data", } .to_string(); if let Some(desc) = &self.desc { @@ -143,3 +146,14 @@ impl From for Error { } } } + +impl From for Error { + fn from(error: std::string::FromUtf8Error) -> Self { + Error { + kind: Kind::Utf8DecodeError, + msg: Some("Unable to decode UTF-8 data".to_string()), + desc: Some(error.to_string()), + class: Class::ClientError, + } + } +} diff --git a/src/websocket/handler.rs b/src/websocket/handler.rs index 57b1fd5..4d5c7d9 100644 --- a/src/websocket/handler.rs +++ b/src/websocket/handler.rs @@ -1,80 +1,136 @@ use crate::error::*; use crate::http; -use crate::websocket::FrameHeader; +use crate::websocket::*; use base64; use crypto; use crypto::digest::Digest; -pub fn connection_handler( +pub fn recv_message(client: &mut http::HttpStream) -> Result { + let mut msg: Vec = Vec::new(); + let mut msg_type = 0; + loop { + let header = FrameHeader::from(&mut client.stream)?; + + if msg_type != 0 && header.opcode != 0 { + return Err(Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("continuation frame expected".to_string())); + } else if header.opcode >= 8 && (!header.fin || header.payload_len >= 126) { + return Err(Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("invalid control frame".to_string())); + } + + match header.opcode { + 0 => {}, // cont + 1 => {}, // text + 2 => // binary + return Err(Error::new(Kind::UsimpProtocolError, Class::ClientError) + .set_desc("binary frames must not be sent on a usimp connection".to_string())), + 8 => {}, // close + 9 => {}, // ping + 10 => {}, // pong + _ => return Err(Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("invalid opcode".to_string())), + } + + msg_type = header.opcode; + + // FIXME check payload len and total len + + let mut buf = vec![0u8; header.payload_len() as usize]; + client.stream.read_exact(&mut buf)?; + + if header.mask { + let key: [u8; 4] = [ + (header.masking_key.unwrap() >> 24) as u8, + ((header.masking_key.unwrap() >> 16) & 0xFF) as u8, + ((header.masking_key.unwrap() >> 8) & 0xFF) as u8, + (header.masking_key.unwrap() & 0xFF) as u8, + ]; + for (pos, byte) in buf.iter_mut().enumerate() { + *byte ^= key[pos & 3]; // = pos % 4 + } + } + + msg.append(&mut buf); + + if header.fin { + break + } + } + + match msg_type { + 1 => {Ok(Message::TextMessage(TextMessage { + data: String::from_utf8(msg)? + }))}, + 8 => { + let mut code = None; + let mut reason = None; + + if msg.len() >= 2 { + code = Some(((msg[0] as u16) << 8) | (msg[1] as u16)); + } + + if msg.len() > 2 { + reason = Some(String::from_utf8(msg[2..].to_vec())?); + } + + Ok(Message::CloseMessage(CloseMessage { + code, + reason + })) + }, + 9 => {Ok(Message::PingMessage(PingMessage { + data: String::from_utf8(msg)? + }))}, + 10 => {Ok(Message::PongMessage(PongMessage { + data: String::from_utf8(msg)? + }))}, + _ => panic!("invalid msg_type for websocket") + } +} + +pub fn handshake( client: &mut http::HttpStream, req: &http::Request, - mut res: http::Response, -) { + res: &mut http::Response, +) -> Result<(), Error> { if let http::Method::GET = req.method { } else { res.status(405); res.header.add_field("Allow", "GET"); - return http::error_handler( - client, - res, - Error::new(Kind::WebSocketError, Class::ClientError) - .set_desc("method not allowed".to_string()), - ); + return Err(Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("method not allowed".to_string())); } if let Some(_) = req.header.find_field("Connection") { if !req.header.field_has_value("Connection", "upgrade") { - return http::error_handler( - client, - res, - Error::new(Kind::WebSocketError, Class::ClientError) - .set_desc("invalid value for header field 'Connection'".to_string()), - ); + return Err(Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("invalid value for header field 'Connection'".to_string())); } } else { - return http::error_handler( - client, - res, - Error::new(Kind::WebSocketError, Class::ClientError) - .set_desc("unable to find header field 'Connection'".to_string()), - ); + return Err(Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("unable to find header field 'Connection'".to_string())); } if let Some(upgrade) = req.header.find_field("Upgrade") { if !upgrade.eq_ignore_ascii_case("websocket") { - return http::error_handler( - client, - res, - Error::new(Kind::WebSocketError, Class::ClientError) - .set_desc("invalid value for header field 'Upgrade'".to_string()), - ); + return Err(Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("invalid value for header field 'Upgrade'".to_string())); } } else { - return http::error_handler( - client, - res, - Error::new(Kind::WebSocketError, Class::ClientError) - .set_desc("unable to find header field 'Upgrade'".to_string()), - ); + return Err(Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("unable to find header field 'Upgrade'".to_string())); } if let Some(version) = req.header.find_field("Sec-WebSocket-Version") { if !version.eq("13") { - return http::error_handler( - client, - res, - Error::new(Kind::WebSocketError, Class::ClientError) - .set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()), - ); + return Err(Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string())); } } else { - return http::error_handler( - client, - res, - Error::new(Kind::WebSocketError, Class::ClientError) - .set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()), - ); + return Err(Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string())); } if let Some(key) = req.header.find_field("Sec-WebSocket-Key") { @@ -87,12 +143,8 @@ pub fn connection_handler( let key = base64::encode(result); res.header.add_field("Sec-WebSocket-Accept", key.as_str()); } else { - return http::error_handler( - client, - res, - Error::new(Kind::WebSocketError, Class::ClientError) - .set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()), - ); + return Err(Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string())); } client.server_keep_alive = false; @@ -100,35 +152,36 @@ pub fn connection_handler( res.header.add_field("Upgrade", "websocket"); res.status(101); - res.send(&mut client.stream).unwrap(); + res.send(&mut client.stream)?; + + Ok(()) +} + +pub fn connection_handler( + client: &mut http::HttpStream, + req: &http::Request, + mut res: http::Response, +) { + if let Err(error) = handshake(client, req, &mut res) { + return http::error_handler(client, res, error); + } loop { - let header = FrameHeader::from(&mut client.stream).unwrap(); - if header.mask { - println!("Mask: {:X}", header.masking_key.unwrap()); - } + let msg = recv_message(client).unwrap(); + match msg { + Message::TextMessage(msg) => { + println!("Data: {}", msg.data); + }, + Message::CloseMessage(msg) => { - let mut buf = [0u8; 8192]; - client - .stream - .read_exact(&mut buf[..header.payload_len() as usize]) - .unwrap(); - - if header.mask { - let key: [u8; 4] = [ - (header.masking_key.unwrap() >> 24) as u8, - ((header.masking_key.unwrap() >> 16) & 0xFF) as u8, - ((header.masking_key.unwrap() >> 8) & 0xFF) as u8, - (header.masking_key.unwrap() & 0xFF) as u8, - ]; - for (pos, byte) in buf.iter_mut().enumerate() { - *byte ^= key[pos % 4]; + return + } + Message::PingMessage(msg) => { + // TODO send pong + }, + Message::PongMessage(msg) => { + // TODO something } } - - println!( - "Msg: {}", - String::from_utf8_lossy(&buf[..header.payload_len() as usize]) - ); } } diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index f024b16..5fadad0 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -21,6 +21,30 @@ pub struct FrameHeader { masking_key: Option, } +pub enum Message { + PingMessage(PingMessage), + PongMessage(PongMessage), + CloseMessage(CloseMessage), + TextMessage(TextMessage), +} + +pub struct PingMessage { + data: String, +} + +pub struct PongMessage { + data: String, +} + +pub struct CloseMessage { + code: Option, + reason: Option, +} + +pub struct TextMessage { + data: String, +} + impl FrameHeader { pub fn from(socket: &mut http::Stream) -> Result { let mut data = [0u8; 2];