Refactored websocket

This commit is contained in:
2021-05-21 22:15:59 +02:00
parent 843f11459b
commit 80aaed0cae
5 changed files with 144 additions and 103 deletions

View File

@ -16,5 +16,5 @@ flate2 = "1.0.0"
r2d2 = "0.8.9" r2d2 = "0.8.9"
r2d2_postgres = "0.18.0" r2d2_postgres = "0.18.0"
ansi_term = "0.12" ansi_term = "0.12"
sha1 = "0.6.0" rust-crypto = "^0.2"
base64 = "0.13.0" base64 = "0.13.0"

View File

@ -9,6 +9,8 @@ pub enum Kind {
HttpRequestParseError, HttpRequestParseError,
IoError, IoError,
WebSocketError, WebSocketError,
NotImplementedError,
UsimpProtocolError,
} }
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
@ -55,6 +57,8 @@ impl Error {
Kind::HttpRequestParseError => "Unable to parse http request", Kind::HttpRequestParseError => "Unable to parse http request",
Kind::IoError => "IO error", Kind::IoError => "IO error",
Kind::WebSocketError => "WebSocket protocol error", Kind::WebSocketError => "WebSocket protocol error",
Kind::NotImplementedError => "Not yet implemented",
Kind::UsimpProtocolError => "USIMP protocol error",
}, },
} }
} }
@ -82,6 +86,8 @@ impl fmt::Display for Error {
Kind::HttpRequestParseError => "unable to parse http request", Kind::HttpRequestParseError => "unable to parse http request",
Kind::IoError => "io error", Kind::IoError => "io error",
Kind::WebSocketError => "websocket protocol error", Kind::WebSocketError => "websocket protocol error",
Kind::NotImplementedError => "not yet implemented",
Kind::UsimpProtocolError => "usimp protocol error",
} }
.to_string(); .to_string();
if let Some(desc) = &self.desc { if let Some(desc) = &self.desc {

View File

@ -53,18 +53,30 @@ fn request_handler(client: &mut super::HttpStream) {
return websocket::connection_handler(client, &req, res); return websocket::connection_handler(client, &req, res);
} }
let mut error = None;
let parts: Vec<&str> = req.uri.split('/').collect(); let parts: Vec<&str> = req.uri.split('/').collect();
match parts[2..] { match parts[2..] {
["entity", entity] => res.status(501), ["entity", entity] => {
res.status(501);
error = Some(Error::new(Kind::NotImplementedError, Class::ServerError))
}
[endpoint] => match req.method { [endpoint] => match req.method {
Method::POST => return endpoint_handler(client, &req, res, endpoint), Method::POST => return endpoint_handler(client, &req, res, endpoint),
_ => { _ => {
res.status(405); res.status(405);
res.add_header("Allow", "POST"); res.add_header("Allow", "POST");
error = Some(Error::new(Kind::UsimpProtocolError, Class::ClientError))
} }
}, },
_ => res.status(400), _ => error = Some(Error::new(Kind::InvalidEndpointError, Class::ClientError)),
} }
if let Some(error) = error {
error_handler(client, res, error);
}
return;
} else { } else {
res.status(404); res.status(404);
} }
@ -90,10 +102,18 @@ fn request_handler(client: &mut super::HttpStream) {
pub fn error_handler(client: &mut super::HttpStream, mut res: super::Response, error: Error) { pub fn error_handler(client: &mut super::HttpStream, mut res: super::Response, error: Error) {
println!("{}", error.to_string()); println!("{}", error.to_string());
res.status(match &error.class() { match &error.class() {
Class::ClientError => 400, Class::ClientError => {
Class::ServerError => 500, if res.status.code < 400 || res.status.code >= 499 {
}); res.status(400)
}
}
Class::ServerError => {
if res.status.code < 500 || res.status.code > 599 {
res.status(500)
}
}
}
res.error_info(error.to_string()); res.error_info(error.to_string());
let mut obj = serde_json::Value::Object(serde_json::Map::new()); let mut obj = serde_json::Value::Object(serde_json::Map::new());

109
src/websocket/handler.rs Normal file
View File

@ -0,0 +1,109 @@
use crate::error::*;
use crate::http;
use base64;
use crypto;
use crypto::digest::Digest;
pub fn connection_handler(
client: &mut http::HttpStream,
req: &http::Request,
mut res: http::Response,
) {
client.server_keep_alive = false;
if let http::Method::GET = req.method {
} else {
res.status(405);
res.add_header("Allow", "GET");
res.send(&mut client.stream).unwrap();
return;
}
if let Some(connection) = req.find_header("Connection") {
if !connection.eq_ignore_ascii_case("upgrade") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Connection'".to_string()),
);
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Connection'".to_string()),
);
}
if let Some(upgrade) = req.find_header("Upgrade") {
if !upgrade.eq_ignore_ascii_case("websocket") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Upgrade'".to_string()),
);
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Upgrade'".to_string()),
);
}
if let Some(version) = req.find_header("Sec-WebSocket-Version") {
if !version.eq("13") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()),
);
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()),
);
}
if let Some(key) = req.find_header("Sec-WebSocket-Key") {
let mut hasher = crypto::sha1::Sha1::new();
hasher.input_str(key);
hasher.input_str("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
let mut result = [0u8; 160 / 8];
hasher.result(&mut result);
let key = base64::encode(result);
res.add_header("Sec-WebSocket-Accept", key.as_str());
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()),
);
}
res.add_header("Connection", "Upgrade");
res.add_header("Upgrade", "websocket");
res.status(101);
res.send(&mut client.stream).unwrap();
loop {
let mut buf = [0u8; 8192];
let res = client.stream.read(&mut buf).unwrap();
if res == 0 {
break;
}
println!("Msg: {}/{}", res, String::from_utf8_lossy(&buf));
}
}

View File

@ -1,97 +1,3 @@
use crate::error::*; mod handler;
use crate::http;
use base64; pub use handler::*;
use sha1;
pub fn connection_handler(
client: &mut http::HttpStream,
req: &http::Request,
mut res: http::Response,
) {
client.server_keep_alive = false;
if let http::Method::GET = req.method {
} else {
res.status(405);
res.add_header("Allow", "GET");
res.send(&mut client.stream).unwrap();
return;
}
if let Some(connection) = req.find_header("Connection") {
if !connection.eq_ignore_ascii_case("upgrade") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Connection'".to_string()),
);
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Connection'".to_string()),
);
}
if let Some(upgrade) = req.find_header("Upgrade") {
if !upgrade.eq_ignore_ascii_case("websocket") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Upgrade'".to_string()),
);
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Upgrade'".to_string()),
);
}
if let Some(version) = req.find_header("Sec-WebSocket-Version") {
if !version.eq("13") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()),
);
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()),
);
}
if let Some(key) = req.find_header("Sec-WebSocket-Key") {
let mut hasher = sha1::Sha1::new();
hasher.update(key.as_bytes());
hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11".as_bytes());
let key = base64::encode(hasher.digest().bytes());
res.add_header("Sec-WebSocket-Accept", key.as_str());
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()),
);
}
res.add_header("Connection", "Upgrade");
res.add_header("Upgrade", "websocket");
// TODO implement websocket
res.status(101);
res.send(&mut client.stream).unwrap();
}