diff --git a/src/http/handler.rs b/src/http/handler.rs index 4e808df..c6b7526 100644 --- a/src/http/handler.rs +++ b/src/http/handler.rs @@ -24,6 +24,10 @@ pub fn connection_handler(client: super::Stream) { fn request_handler(client: &mut super::HttpStream) { let mut res = super::Response::new(); + if let Some(conn) = res.find_header("Connection") { + client.client_keep_alive = client.client_keep_alive && conn.eq_ignore_ascii_case("keep-alive"); + } + match super::parser::parse_request(&mut client.stream) { Ok(Some(req)) => { println!("{} {}", req.method, req.uri); @@ -53,7 +57,7 @@ fn request_handler(client: &mut super::HttpStream) { ["entity", entity] => { res.status(501); error = Some(Error::new(Kind::NotImplementedError, Class::ServerError)) - } + }, [endpoint] => match req.method { Method::POST => return endpoint_handler(client, &req, res, endpoint), _ => { @@ -86,7 +90,7 @@ fn request_handler(client: &mut super::HttpStream) { } } - if let Err(e) = res.send_default(&mut client.stream) { + if let Err(e) = client.respond_default(&mut res) { println!("Unable to send: {}", e); client.server_keep_alive = false; } @@ -118,7 +122,7 @@ pub fn error_handler(client: &mut super::HttpStream, mut res: super::Response, e res.add_header("Content-Length", length.to_string().as_str()); res.add_header("Content-Type", "application/json; charset=utf-8"); - if let Err(e) = res.send(&mut client.stream) { + if let Err(e) = client.respond(&mut res) { println!("Unable to send: {}", e); client.server_keep_alive = false; } @@ -178,6 +182,6 @@ fn endpoint_handler( res.add_header("Content-Type", "application/json; charset=utf-8"); res.status(200); - res.send(&mut client.stream).unwrap(); + client.respond(&mut res).unwrap(); client.stream.write_all(buf.as_bytes()).unwrap(); } diff --git a/src/http/mod.rs b/src/http/mod.rs index 088f760..b081cab 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -311,6 +311,25 @@ impl Response { } } +impl HttpStream { + pub fn respond(&mut self, res: &mut Response) -> Result<(), std::io::Error> { + self.keep_alive(res); + res.send(&mut self.stream) + } + + pub fn respond_default(&mut self, res: &mut Response) -> Result<(), std::io::Error> { + self.keep_alive(res); + res.send_default(&mut self.stream) + } + + fn keep_alive(&mut self, res: &mut Response) { + if self.client_keep_alive && self.server_keep_alive { + res.add_header("Connection", "keep-alive"); + res.add_header("Keep-Alive", "timeout=3600, max=200"); + } + } +} + impl Stream { pub fn read(&mut self, buf: &mut [u8]) -> Result { match self { diff --git a/src/websocket/handler.rs b/src/websocket/handler.rs index b69adeb..5ced870 100644 --- a/src/websocket/handler.rs +++ b/src/websocket/handler.rs @@ -11,14 +11,16 @@ pub fn connection_handler( req: &http::Request, mut res: http::Response, ) { - client.server_keep_alive = false; - if let http::Method::GET = req.method { } else { res.status(405); res.add_header("Allow", "GET"); - res.send(&mut client.stream).unwrap(); - return; + return http::error_handler( + client, + res, + Error::new(Kind::WebSocketError, Class::ClientError) + .set_desc("method not allowed".to_string()), + ); } if let Some(connection) = req.find_header("Connection") { @@ -93,6 +95,7 @@ pub fn connection_handler( ); } + client.server_keep_alive = false; res.add_header("Connection", "Upgrade"); res.add_header("Upgrade", "websocket");