diff --git a/Cargo.toml b/Cargo.toml index 8861a68..56521e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,5 @@ flate2 = "1.0.0" r2d2 = "0.8.9" r2d2_postgres = "0.18.0" ansi_term = "0.12" +sha1 = "0.6.0" +base64 = "0.13.0" diff --git a/src/error.rs b/src/error.rs index 2d5e181..d522938 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,6 +8,7 @@ pub enum Kind { DatabaseError, HttpRequestParseError, IoError, + WebSocketError, } #[derive(Copy, Clone, Debug)] @@ -53,6 +54,7 @@ impl Error { Kind::DatabaseError => "Database error", Kind::HttpRequestParseError => "Unable to parse http request", Kind::IoError => "IO error", + Kind::WebSocketError => "WebSocket protocol error", }, } } @@ -79,6 +81,7 @@ impl fmt::Display for Error { Kind::DatabaseError => "database error", Kind::HttpRequestParseError => "unable to parse http request", Kind::IoError => "io error", + Kind::WebSocketError => "websocket protocol error", } .to_string(); if let Some(desc) = &self.desc { diff --git a/src/http/handler.rs b/src/http/handler.rs index 4bf6585..3c55c41 100644 --- a/src/http/handler.rs +++ b/src/http/handler.rs @@ -45,14 +45,19 @@ fn request_handler(client: &mut super::HttpStream) { res.status(404); } else if req.uri.eq("/") { res.status(200); - } else if req.uri.eq("/_usimp/websocket") { - return websocket::connection_handler(client, &req); } else if req.uri.starts_with("/_usimp/") { + res.add_header("Cache-Control", "no-store"); + res.add_header("Access-Control-Allow-Origin", "*"); + + if req.uri.eq("/_usimp/websocket") { + return websocket::connection_handler(client, &req, res); + } + let parts: Vec<&str> = req.uri.split('/').collect(); match parts[2..] { ["entity", entity] => res.status(501), [endpoint] => match req.method { - Method::POST => return endpoint_handler(client, &req, &mut res, endpoint), + Method::POST => return endpoint_handler(client, &req, res, endpoint), _ => { res.status(405); res.add_header("Allow", "POST"); @@ -76,55 +81,53 @@ fn request_handler(client: &mut super::HttpStream) { } } - if let Err(e) = res.send(&mut client.stream) { + if let Err(e) = res.send_default(&mut client.stream) { println!("Unable to send: {}", e); client.server_keep_alive = false; } client.server_keep_alive = false; } +pub fn error_handler(client: &mut super::HttpStream, mut res: super::Response, error: Error) { + println!("{}", error.to_string()); + res.status(match &error.class() { + Class::ClientError => 400, + Class::ServerError => 500, + }); + res.error_info(error.to_string()); + + let mut obj = serde_json::Value::Object(serde_json::Map::new()); + obj["status"] = serde_json::Value::String("error".to_string()); + obj["message"] = serde_json::Value::String(error.to_string()); + let buf = obj.to_string() + "\r\n"; + + let length = buf.as_bytes().len(); + res.add_header("Content-Length", length.to_string().as_str()); + res.add_header("Content-Type", "application/json; charset=utf-8"); + + if let Err(e) = res.send(&mut client.stream) { + println!("Unable to send: {}", e); + client.server_keep_alive = false; + } + + client.stream.write_all(buf.as_bytes()).unwrap(); +} + fn endpoint_handler( client: &mut super::HttpStream, req: &super::Request, - res: &mut super::Response, + mut res: super::Response, endpoint: &str, ) { - res.add_header("Cache-Control", "no-store"); - res.add_header("Access-Control-Allow-Origin", "*"); - - let mut error = |error: Error, client: &mut super::HttpStream| { - println!("{}", error.to_string()); - res.status(match &error.class() { - Class::ClientError => 400, - Class::ServerError => 500, - }); - res.error_info(error.to_string()); - - let mut obj = serde_json::Value::Object(serde_json::Map::new()); - obj["status"] = serde_json::Value::String("error".to_string()); - obj["message"] = serde_json::Value::String(error.to_string()); - let buf = obj.to_string() + "\r\n"; - - let length = buf.as_bytes().len(); - res.add_header("Content-Length", length.to_string().as_str()); - res.add_header("Content-Type", "application/json; charset=utf-8"); - - if let Err(e) = res.send(&mut client.stream) { - println!("Unable to send: {}", e); - client.server_keep_alive = false; - } - - client.stream.write_all(buf.as_bytes()).unwrap(); - }; - let length = req.find_header("Content-Length"); let length: usize = match match length { Some(length) => length, None => { - return error( + return error_handler( + client, + res, Error::new(Kind::HttpRequestParseError, Class::ClientError) .set_desc("field 'Content-Length' missing".to_string()), - client, ) } } @@ -132,11 +135,12 @@ fn endpoint_handler( { Ok(length) => length, Err(e) => { - return error( + return error_handler( + client, + res, Error::new(Kind::HttpRequestParseError, Class::ClientError).set_desc( format!("unable to parse field 'Content-Length': {}", &e).to_string(), ), - client, ) } }; @@ -147,12 +151,12 @@ fn endpoint_handler( // TODO decompress let input = match serde_json::from_slice(&buf[..length]) { Ok(val) => val, - Err(e) => return error(e.into(), client), + Err(e) => return error_handler(client, res, e.into()), }; let buf = match usimp::endpoint(endpoint, input) { Ok(output) => output.to_string() + "\r\n", - Err(e) => return error(e, client), + Err(e) => return error_handler(client, res, e), }; // TODO compress diff --git a/src/http/mod.rs b/src/http/mod.rs index 347d52a..91393ce 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -240,6 +240,19 @@ impl Response { } pub fn send(&mut self, stream: &mut Stream) -> Result<(), std::io::Error> { + let mut header = format!( + "HTTP/{} {:03} {}\r\n", + self.version, self.status.code, self.status.message + ); + for header_field in &self.header.fields { + header.push_str(format!("{}: {}\r\n", header_field.name, header_field.value).as_str()); + } + header.push_str("\r\n"); + stream.write_all(header.as_bytes())?; + Ok(()) + } + + pub fn send_default(&mut self, stream: &mut Stream) -> Result<(), std::io::Error> { let mut buf = None; if let None = self.find_header("Content-Length") { let new_buf = self.format_default_response(); @@ -251,16 +264,8 @@ impl Response { buf = Some(new_buf); } - let mut header = format!( - "HTTP/{} {:03} {}\r\n", - self.version, self.status.code, self.status.message - ); - for header_field in &self.header.fields { - header.push_str(format!("{}: {}\r\n", header_field.name, header_field.value).as_str()); - } - header.push_str("\r\n"); + self.send(stream); - stream.write_all(header.as_bytes())?; if let Some(buf) = buf { stream.write_all(buf.as_bytes())?; } diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index 8fc7d65..edc88c7 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -1,8 +1,15 @@ +use crate::error::*; use crate::http; -pub fn connection_handler(client: &mut http::HttpStream, req: &http::Request) { +use base64; +use sha1; + +pub fn connection_handler( + client: &mut http::HttpStream, + req: &http::Request, + mut res: http::Response, +) { client.server_keep_alive = false; - let mut res = http::Response::new(); if let http::Method::GET = req.method { } else { @@ -12,7 +19,79 @@ pub fn connection_handler(client: &mut http::HttpStream, req: &http::Request) { return; } + if let Some(connection) = req.find_header("Connection") { + if !connection.eq_ignore_ascii_case("upgrade") { + return http::error_handler( + client, + res, + Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("invalid value for header field 'Connection'".to_string()), + ); + } + } else { + return http::error_handler( + client, + res, + Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("unable to find header field 'Connection'".to_string()), + ); + } + + if let Some(upgrade) = req.find_header("Upgrade") { + if !upgrade.eq_ignore_ascii_case("websocket") { + return http::error_handler( + client, + res, + Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("invalid value for header field 'Upgrade'".to_string()), + ); + } + } else { + return http::error_handler( + client, + res, + Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("unable to find header field 'Upgrade'".to_string()), + ); + } + + if let Some(version) = req.find_header("Sec-WebSocket-Version") { + if !version.eq("13") { + return http::error_handler( + client, + res, + Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()), + ); + } + } else { + return http::error_handler( + client, + res, + Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()), + ); + } + + if let Some(key) = req.find_header("Sec-WebSocket-Key") { + let mut hasher = sha1::Sha1::new(); + hasher.update(key.as_bytes()); + hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11".as_bytes()); + let key = base64::encode(hasher.digest().bytes()); + res.add_header("Sec-WebSocket-Accept", key.as_str()); + } else { + return http::error_handler( + client, + res, + Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()), + ); + } + + res.add_header("Connection", "Upgrade"); + res.add_header("Upgrade", "websocket"); + // TODO implement websocket - res.status(501); + res.status(101); res.send(&mut client.stream).unwrap(); }