Implemented WebSocket handshake

This commit is contained in:
2021-05-20 21:49:42 +02:00
parent 2f67518546
commit 843f11459b
5 changed files with 144 additions and 51 deletions

View File

@ -16,3 +16,5 @@ flate2 = "1.0.0"
r2d2 = "0.8.9"
r2d2_postgres = "0.18.0"
ansi_term = "0.12"
sha1 = "0.6.0"
base64 = "0.13.0"

View File

@ -8,6 +8,7 @@ pub enum Kind {
DatabaseError,
HttpRequestParseError,
IoError,
WebSocketError,
}
#[derive(Copy, Clone, Debug)]
@ -53,6 +54,7 @@ impl Error {
Kind::DatabaseError => "Database error",
Kind::HttpRequestParseError => "Unable to parse http request",
Kind::IoError => "IO error",
Kind::WebSocketError => "WebSocket protocol error",
},
}
}
@ -79,6 +81,7 @@ impl fmt::Display for Error {
Kind::DatabaseError => "database error",
Kind::HttpRequestParseError => "unable to parse http request",
Kind::IoError => "io error",
Kind::WebSocketError => "websocket protocol error",
}
.to_string();
if let Some(desc) = &self.desc {

View File

@ -45,14 +45,19 @@ fn request_handler(client: &mut super::HttpStream) {
res.status(404);
} else if req.uri.eq("/") {
res.status(200);
} else if req.uri.eq("/_usimp/websocket") {
return websocket::connection_handler(client, &req);
} else if req.uri.starts_with("/_usimp/") {
res.add_header("Cache-Control", "no-store");
res.add_header("Access-Control-Allow-Origin", "*");
if req.uri.eq("/_usimp/websocket") {
return websocket::connection_handler(client, &req, res);
}
let parts: Vec<&str> = req.uri.split('/').collect();
match parts[2..] {
["entity", entity] => res.status(501),
[endpoint] => match req.method {
Method::POST => return endpoint_handler(client, &req, &mut res, endpoint),
Method::POST => return endpoint_handler(client, &req, res, endpoint),
_ => {
res.status(405);
res.add_header("Allow", "POST");
@ -76,55 +81,53 @@ fn request_handler(client: &mut super::HttpStream) {
}
}
if let Err(e) = res.send(&mut client.stream) {
if let Err(e) = res.send_default(&mut client.stream) {
println!("Unable to send: {}", e);
client.server_keep_alive = false;
}
client.server_keep_alive = false;
}
pub fn error_handler(client: &mut super::HttpStream, mut res: super::Response, error: Error) {
println!("{}", error.to_string());
res.status(match &error.class() {
Class::ClientError => 400,
Class::ServerError => 500,
});
res.error_info(error.to_string());
let mut obj = serde_json::Value::Object(serde_json::Map::new());
obj["status"] = serde_json::Value::String("error".to_string());
obj["message"] = serde_json::Value::String(error.to_string());
let buf = obj.to_string() + "\r\n";
let length = buf.as_bytes().len();
res.add_header("Content-Length", length.to_string().as_str());
res.add_header("Content-Type", "application/json; charset=utf-8");
if let Err(e) = res.send(&mut client.stream) {
println!("Unable to send: {}", e);
client.server_keep_alive = false;
}
client.stream.write_all(buf.as_bytes()).unwrap();
}
fn endpoint_handler(
client: &mut super::HttpStream,
req: &super::Request,
res: &mut super::Response,
mut res: super::Response,
endpoint: &str,
) {
res.add_header("Cache-Control", "no-store");
res.add_header("Access-Control-Allow-Origin", "*");
let mut error = |error: Error, client: &mut super::HttpStream| {
println!("{}", error.to_string());
res.status(match &error.class() {
Class::ClientError => 400,
Class::ServerError => 500,
});
res.error_info(error.to_string());
let mut obj = serde_json::Value::Object(serde_json::Map::new());
obj["status"] = serde_json::Value::String("error".to_string());
obj["message"] = serde_json::Value::String(error.to_string());
let buf = obj.to_string() + "\r\n";
let length = buf.as_bytes().len();
res.add_header("Content-Length", length.to_string().as_str());
res.add_header("Content-Type", "application/json; charset=utf-8");
if let Err(e) = res.send(&mut client.stream) {
println!("Unable to send: {}", e);
client.server_keep_alive = false;
}
client.stream.write_all(buf.as_bytes()).unwrap();
};
let length = req.find_header("Content-Length");
let length: usize = match match length {
Some(length) => length,
None => {
return error(
return error_handler(
client,
res,
Error::new(Kind::HttpRequestParseError, Class::ClientError)
.set_desc("field 'Content-Length' missing".to_string()),
client,
)
}
}
@ -132,11 +135,12 @@ fn endpoint_handler(
{
Ok(length) => length,
Err(e) => {
return error(
return error_handler(
client,
res,
Error::new(Kind::HttpRequestParseError, Class::ClientError).set_desc(
format!("unable to parse field 'Content-Length': {}", &e).to_string(),
),
client,
)
}
};
@ -147,12 +151,12 @@ fn endpoint_handler(
// TODO decompress
let input = match serde_json::from_slice(&buf[..length]) {
Ok(val) => val,
Err(e) => return error(e.into(), client),
Err(e) => return error_handler(client, res, e.into()),
};
let buf = match usimp::endpoint(endpoint, input) {
Ok(output) => output.to_string() + "\r\n",
Err(e) => return error(e, client),
Err(e) => return error_handler(client, res, e),
};
// TODO compress

View File

@ -240,6 +240,19 @@ impl Response {
}
pub fn send(&mut self, stream: &mut Stream) -> Result<(), std::io::Error> {
let mut header = format!(
"HTTP/{} {:03} {}\r\n",
self.version, self.status.code, self.status.message
);
for header_field in &self.header.fields {
header.push_str(format!("{}: {}\r\n", header_field.name, header_field.value).as_str());
}
header.push_str("\r\n");
stream.write_all(header.as_bytes())?;
Ok(())
}
pub fn send_default(&mut self, stream: &mut Stream) -> Result<(), std::io::Error> {
let mut buf = None;
if let None = self.find_header("Content-Length") {
let new_buf = self.format_default_response();
@ -251,16 +264,8 @@ impl Response {
buf = Some(new_buf);
}
let mut header = format!(
"HTTP/{} {:03} {}\r\n",
self.version, self.status.code, self.status.message
);
for header_field in &self.header.fields {
header.push_str(format!("{}: {}\r\n", header_field.name, header_field.value).as_str());
}
header.push_str("\r\n");
self.send(stream);
stream.write_all(header.as_bytes())?;
if let Some(buf) = buf {
stream.write_all(buf.as_bytes())?;
}

View File

@ -1,8 +1,15 @@
use crate::error::*;
use crate::http;
pub fn connection_handler(client: &mut http::HttpStream, req: &http::Request) {
use base64;
use sha1;
pub fn connection_handler(
client: &mut http::HttpStream,
req: &http::Request,
mut res: http::Response,
) {
client.server_keep_alive = false;
let mut res = http::Response::new();
if let http::Method::GET = req.method {
} else {
@ -12,7 +19,79 @@ pub fn connection_handler(client: &mut http::HttpStream, req: &http::Request) {
return;
}
if let Some(connection) = req.find_header("Connection") {
if !connection.eq_ignore_ascii_case("upgrade") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Connection'".to_string()),
);
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Connection'".to_string()),
);
}
if let Some(upgrade) = req.find_header("Upgrade") {
if !upgrade.eq_ignore_ascii_case("websocket") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Upgrade'".to_string()),
);
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Upgrade'".to_string()),
);
}
if let Some(version) = req.find_header("Sec-WebSocket-Version") {
if !version.eq("13") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()),
);
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()),
);
}
if let Some(key) = req.find_header("Sec-WebSocket-Key") {
let mut hasher = sha1::Sha1::new();
hasher.update(key.as_bytes());
hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11".as_bytes());
let key = base64::encode(hasher.digest().bytes());
res.add_header("Sec-WebSocket-Accept", key.as_str());
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()),
);
}
res.add_header("Connection", "Upgrade");
res.add_header("Upgrade", "websocket");
// TODO implement websocket
res.status(501);
res.status(101);
res.send(&mut client.stream).unwrap();
}