diff --git a/Cargo.toml b/Cargo.toml index 4cbef2d..0ef9b28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,14 +7,20 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -rusty_pool = "0.6.0" +hyper = { version = "0.14", features = ["full"] } +tokio = { version = "1", features = ["full"] } +futures = "0.3" +tokio-tls = "0.3.1" +rustls = "0.19.1" +tokio-rustls = "0.22.0" +futures-util = "0.3.15" +async-stream = "0.3.2" +hyper-tungstenite = "0.3.2" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.64" -openssl = {version = "0.10", features = ["vendored"]} chrono = "0.4" -flate2 = "1.0.0" -r2d2 = "0.8.9" -r2d2_postgres = "0.18.0" +bb8 = "0.7.0" +bb8-postgres = "0.7.0" ansi_term = "0.12" rust-crypto = "^0.2" base64 = "0.13.0" diff --git a/src/database.rs b/src/database.rs index adbefcf..87ac9f9 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,42 +1,42 @@ -use crate::error::Error; -use r2d2_postgres::postgres::NoTls; -use r2d2_postgres::PostgresConnectionManager; +use crate::error::*; +use bb8_postgres::tokio_postgres::NoTls; +use bb8_postgres::PostgresConnectionManager; use std::ops::Deref; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, MutexGuard}; use std::time::Duration; pub enum Pool { - Postgres(r2d2::Pool>), + Postgres(bb8::Pool>), } -pub enum Client { - Postgres(r2d2::PooledConnection>), +pub enum Client<'a> { + Postgres(bb8::PooledConnection<'a, PostgresConnectionManager>), } -static mut POOL: Option>> = None; +static mut POOL: Option = None; -pub fn init() -> Result<(), Error> { +pub async fn init() -> Result<(), Error> { let manager = PostgresConnectionManager::new( "host=localhost user=postgres dbname=locutus".parse().unwrap(), NoTls, ); - let pool = r2d2::Pool::builder() + let pool = bb8::Pool::builder() .max_size(64) .min_idle(Some(2)) .connection_timeout(Duration::from_secs(4)) .max_lifetime(Some(Duration::from_secs(3600))) - .build(manager)?; + .build(manager).await?; unsafe { - POOL = Some(Arc::new(Mutex::new(Pool::Postgres(pool)))); + POOL = Some(Pool::Postgres(pool)); } Ok(()) } -pub fn client() -> Result { - match unsafe { POOL.as_ref().unwrap().clone().lock().unwrap().deref() } { - Pool::Postgres(pool) => Ok(Client::Postgres(pool.get()?)), +pub async fn client() -> Result, Error> { + match unsafe { POOL.as_ref().unwrap().clone() } { + Pool::Postgres(pool) => Ok(Client::Postgres(pool.get().await?)), } } diff --git a/src/error.rs b/src/error.rs index 218a7b9..9e346c2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,120 +1,107 @@ -use std::fmt; +use crate::usimp::{InputEnvelope, OutputEnvelope}; -#[derive(Copy, Clone, Debug)] -pub enum Kind { - InvalidEndpointError, - JsonParseError, - DatabaseConnectionError, - DatabaseError, - HttpRequestParseError, - IoError, - WebSocketError, - NotImplementedError, - UsimpProtocolError, - Utf8DecodeError, - AuthenticationError, - InvalidSessionError, +use serde_json::{Value, Map}; +use bb8_postgres::tokio_postgres; +use bb8_postgres; + +#[derive(Debug)] +pub struct Error { + pub kind: ErrorKind, + pub class: ErrorClass, + pub msg: Option, + pub desc: Option, } -#[derive(Copy, Clone, Debug)] -pub enum Class { +#[derive(Debug)] +pub enum ErrorClass { ClientProtocolError, ClientError, ServerError, } #[derive(Debug)] -pub struct Error { - kind: Kind, - msg: Option, - desc: Option, - class: Class, +pub enum ErrorKind { + NotImplemented, + UsimpError, + WebSocketError, + DatabaseError, +} + +impl InputEnvelope { + pub fn new_error(&self, kind: ErrorKind, class: ErrorClass, msg: Option) -> OutputEnvelope { + OutputEnvelope { + request_nr: self.request_nr, + error: Some(Error::new(kind, class, msg)), + data: Value::Null, + } + } + + pub fn error(&self, error: Error) -> OutputEnvelope { + OutputEnvelope { + request_nr: self.request_nr, + error: Some(error), + data: Value::Null, + } + } } impl Error { - pub fn new(kind: Kind, class: Class) -> Self { - Error { + pub fn new(kind: ErrorKind, class: ErrorClass, msg: Option) -> Self { + return Error { kind, - msg: None, - desc: None, class, + msg, + desc: None, } } - pub fn class(&self) -> &Class { - &self.class - } - - pub fn set_msg(mut self, msg: String) -> Self { + pub fn msg(&mut self, msg: String) { self.msg = Some(msg); - self } - pub fn msg(&self) -> &str { - match &self.msg { - Some(msg) => msg.as_str(), - None => match self.kind { - Kind::InvalidEndpointError => "Invalid endpoint", - Kind::JsonParseError => "Unable to parse JSON data", - Kind::DatabaseConnectionError => "Unable to connect to database", - Kind::DatabaseError => "Database error", - Kind::HttpRequestParseError => "Unable to parse http request", - Kind::IoError => "IO error", - Kind::WebSocketError => "WebSocket protocol error", - Kind::NotImplementedError => "Not yet implemented", - Kind::UsimpProtocolError => "USIMP protocol error", - Kind::Utf8DecodeError => "Unable to decode UTF-8 data", - Kind::AuthenticationError => "Unable to authenticate", - Kind::InvalidSessionError => "Invalid session", - }, - } - } - - pub fn set_desc(mut self, desc: String) -> Self { - self.desc = Some(desc); - self - } - - pub fn desc(&self) -> Option<&str> { - match &self.desc { - Some(desc) => Some(desc.as_str()), - None => None, + pub fn code(&self) -> &str { + match self.kind { + ErrorKind::NotImplemented => "NOT_IMPLEMENTED", + ErrorKind::UsimpError => "USIMP_ERROR", + ErrorKind::WebSocketError => "WEBSOCKET_ERROR", + ErrorKind::DatabaseError => "BACKEND_ERROR", } } } -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut error = match self.kind { - Kind::InvalidEndpointError => "invalid endpoint", - Kind::JsonParseError => "unable to parse json data", - Kind::DatabaseConnectionError => "unable to connect to database", - Kind::DatabaseError => "database error", - Kind::HttpRequestParseError => "unable to parse http request", - Kind::IoError => "io error", - Kind::WebSocketError => "websocket protocol error", - Kind::NotImplementedError => "not yet implemented", - Kind::UsimpProtocolError => "usimp protocol error", - Kind::Utf8DecodeError => "unable to decode utf-8 data", - Kind::AuthenticationError => "unable to authenticate", - Kind::InvalidSessionError => "invalid session", +impl From for OutputEnvelope { + fn from(error: Error) -> Self { + OutputEnvelope { + error: Some(error), + data: Value::Null, + request_nr: None, } - .to_string(); - if let Some(desc) = &self.desc { - error += ": "; - error += desc; - } - write!(f, "{}", error) } } -impl From for Error { - fn from(error: std::io::Error) -> Self { +impl From for Value { + fn from(error: Error) -> Self { + let mut obj = Value::Object(Map::new()); + obj["code"] = Value::from(error.code()); + obj["message"] = match error.msg { + Some(msg) => Value::from(msg), + None => Value::Null, + }; + obj["description"] = match error.desc { + Some(desc) => Value::from(desc), + None => Value::Null, + }; + obj + } +} + +impl From for Error { + fn from(error: hyper::header::ToStrError) -> Self { Error { - kind: Kind::IoError, - msg: Some(error.to_string()), + kind: ErrorKind::UsimpError, + class: ErrorClass::ClientProtocolError, + msg: None, desc: Some(error.to_string()), - class: Class::ClientProtocolError, } } } @@ -122,45 +109,43 @@ impl From for Error { impl From for Error { fn from(error: serde_json::Error) -> Self { Error { - kind: Kind::JsonParseError, - msg: Some("Unable to parse JSON data".to_string()), + kind: ErrorKind::UsimpError, + class: ErrorClass::ClientProtocolError, + msg: None, desc: Some(error.to_string()), - class: Class::ClientProtocolError, } } } -impl From for Error { - fn from(error: r2d2::Error) -> Self { +impl From for Error { + fn from(error: hyper::Error) -> Self { Error { - kind: Kind::DatabaseConnectionError, - msg: Some("Unable to connect to database".to_string()), + kind: ErrorKind::UsimpError, + class: ErrorClass::ClientProtocolError, + msg: None, desc: Some(error.to_string()), - class: Class::ServerError, } } } -impl From for Error { - fn from(error: r2d2_postgres::postgres::Error) -> Self { - // format: "db error: ERROR ..." - let msg = error.to_string().split(":").skip(1).collect::(); +impl From for Error { + fn from(error: tokio_postgres::Error) -> Self { Error { - kind: Kind::DatabaseError, - msg: Some("Database error".to_string()), - desc: Some(msg.trim().to_string()), - class: Class::ServerError, + kind: ErrorKind::DatabaseError, + class: ErrorClass::ServerError, + msg: None, + desc: Some(error.to_string()), } } } -impl From for Error { - fn from(error: std::string::FromUtf8Error) -> Self { +impl From> for Error { + fn from(error: bb8_postgres::bb8::RunError) -> Self { Error { - kind: Kind::Utf8DecodeError, - msg: Some("Unable to decode UTF-8 data".to_string()), + kind: ErrorKind::DatabaseError, + class: ErrorClass::ServerError, + msg: None, desc: Some(error.to_string()), - class: Class::ClientProtocolError, } } } diff --git a/src/http.rs b/src/http.rs new file mode 100644 index 0000000..a420c8a --- /dev/null +++ b/src/http.rs @@ -0,0 +1,142 @@ +use hyper::{Request, Response, Body, StatusCode, header, Method, body}; +use serde_json::{Value, Map}; +use crate::websocket; +use crate::usimp::*; +use crate::error::*; +use crate::usimp; +use std::str::FromStr; + +async fn endpoint_handler(req: &mut Request, endpoint: String) -> Result, Error> { + if req.method() == Method::OPTIONS { + return Ok(None) + } else if req.method() != Method::POST { + return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + } + + let to_domain; + if let Some(val) = req.headers().get(header::HeaderName::from_str("To-Domain").unwrap()) { + to_domain = val.to_str().unwrap().to_string() + } else { + return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + } + + if let Some(val) = req.headers().get(header::CONTENT_TYPE) { + let parts: Vec = val.to_str()?.split(';').map(|v| v.trim().to_ascii_lowercase()).collect(); + let p: Vec<&str> = parts.iter().map(|v| v.as_str()).collect(); + match p[0..1] { + ["application/json"] => {}, + _ => return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + } + } else { + return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + } + + let data = serde_json::from_slice(&body::to_bytes(req.body_mut()).await?)?; + + let input = InputEnvelope { + endpoint, + to_domain, + from_domain: match req.headers().get(header::HeaderName::from_str("From-Domain").unwrap()) { + Some(val) => Some(val.to_str()?.to_string()), + None => None, + }, + request_nr: None, + token: match req.headers().get(header::AUTHORIZATION) { + Some(val) => { + let val = val.to_str()?; + if val.starts_with("usimp ") { + Some(val[6..].to_string()) + } else { + return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + } + }, + None => None, + }, + data, + }; + + Ok(Some(usimp::endpoint(&input).await?)) +} + +pub async fn handler(mut req: Request) -> Result, hyper::Error> { + let mut res = Response::builder(); + let parts: Vec<&str> = req.uri().path().split('/').skip(1).collect(); + + println!("{} {}", req.method(), req.uri()); + + let val: Result, Error> = match &parts[..] { + [""] => Ok(res.status(StatusCode::OK).body(Body::from("Hello World")).unwrap()), + ["_usimp"] | ["_usimp", ..] => { + res = res + .header(header::SERVER, "Locutus") + .header(header::CACHE_CONTROL, "no-store") + .header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*") + .header(header::ACCESS_CONTROL_MAX_AGE, 3600); + + let output = match &parts[1..] { + ["websocket"] => { + res = res + .header(header::ACCESS_CONTROL_ALLOW_METHODS, "GET"); + let (r, val) = websocket::handler(req, res).await; + res = r; + match val { + Some(val) => Ok(Some(val)), + None => return Ok(res.body(Body::empty()).unwrap()), + } + }, + [endpoint] => { + res = res + .header(header::ACCESS_CONTROL_ALLOW_METHODS, "POST, OPTIONS") + .header(header::ACCESS_CONTROL_ALLOW_HEADERS, "Content-Type, From-Domain, To-Domain, Authorization"); + let endpoint = endpoint.to_string(); + endpoint_handler(&mut req, endpoint).await + }, + _ => Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)), + }; + + let output = match output { + Ok(Some(val)) => Some(val), + Ok(None) => None, + Err(error) => Some(OutputEnvelope::from(error)), + }; + + if let Some(output) = output { + let mut data = Value::Object(Map::new()); + + match output.error { + Some(error) => { + res = match error.class { + ErrorClass::ClientProtocolError => res.status(StatusCode::BAD_REQUEST), + ErrorClass::ServerError => res.status(StatusCode::INTERNAL_SERVER_ERROR), + _ => res.status(StatusCode::OK), + }; + data["status"] = Value::from("error"); + data["error"] = Value::from(error); + }, + None => { + data["status"] = Value::from("success"); + }, + } + + data["request_nr"] = match output.request_nr { + Some(nr) => Value::from(nr), + None => Value::Null, + }; + data["data"] = output.data; + return Ok(res.body(Body::from(serde_json::to_string(&data).unwrap() + "\r\n")).unwrap()) + } else { + res = res.status(StatusCode::NO_CONTENT); + } + + return Ok(res.body(Body::empty()).unwrap()) + }, + _ => Ok(res.status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()), + }; + + match val { + Ok(val) => Ok(val), + Err(error) => { + todo!("help") + } + } +} diff --git a/src/http/consts.rs b/src/http/consts.rs deleted file mode 100644 index 5728c1f..0000000 --- a/src/http/consts.rs +++ /dev/null @@ -1,154 +0,0 @@ -use super::StatusClass; -use super::StatusClass::*; - -pub static HTTP_STATUSES: [(u16, StatusClass, &str, &str); 41] = [ - (100, Informational, "Continue", - "The client SHOULD continue with its request."), - (101, Informational, "Switching Protocols", - "The server understands and is willing to comply with the clients request, via the Upgrade message header field, for a change in the application protocol being used on this connection."), - - (200, Success, "OK", - "The request has succeeded."), - (201, Success, "Created", - "The request has been fulfilled and resulted in a new resource being created."), - (202, Success, "Accepted", - "The request has been accepted for processing, but the processing has not been completed."), - (203, Success, "Non-Authoritative Information", - "The returned meta information in the entity-header is not the definitive set as available from the origin server, but is gathered from a local or a third-party copy."), - (204, Success, "No Content", - "The server has fulfilled the request but does not need to return an entity-body, and might want to return updated meta information."), - (205, Success, "Reset Content", - "The server has fulfilled the request and the user agent SHOULD reset the document view which caused the request to be sent."), - (206, Success, "Partial Content", - "The server has fulfilled the partial GET request for the resource."), - - (300, Redirection, "Multiple Choices", - "The requested resource corresponds to any one of a set of representations, each with its own specific location, and agent-driven negotiation information is being provided so that the user (or user agent) can select a preferred representation and redirect its request to that location."), - (301, Redirection, "Moved Permanently", - "The requested resource has been assigned a new permanent URI and any future references to this resource SHOULD use one of the returned URIs."), - (302, Redirection, "Found", - "The requested resource resides temporarily under a different URI."), - (303, Redirection, "See Other", - "The response to the request can be found under a different URI and SHOULD be retrieved using a GET method on that resource."), - (304, Success, "Not Modified", - "The request has been fulfilled and the requested resource has not been modified."), - (305, Redirection, "Use Proxy", - "The requested resource MUST be accessed through the proxy given by the Location field."), - (307, Redirection, "Temporary Redirect", - "The requested resource resides temporarily under a different URI."), - (308, Redirection, "Permanent Redirect", - "The requested resource has been assigned a new permanent URI and any future references to this resource ought to use one of the enclosed URIs."), - - (400, ClientError, "Bad Request", - "The request could not be understood by the server due to malformed syntax."), - (401, ClientError, "Unauthorized", - "The request requires user authentication."), - (402, ClientError, "Payment Required", - ""), - (403, ClientError, "Forbidden", - "The server understood the request, but is refusing to fulfill it."), - (404, ClientError, "Not Found", - "The server has not found anything matching the Request-URI."), - (405, ClientError, "Method Not Allowed", - "The method specified in the Request-Line is not allowed for the resource identified by the Request-URI."), - (406, ClientError, "Not Acceptable", - "The resource identified by the request is only capable of generating response entities which have content characteristics not acceptable according to the accept headers sent in the request."), - (407, ClientError, "Proxy Authentication Required", - "The request requires user authentication on the proxy."), - (408, ClientError, "Request Timeout", - "The client did not produce a request within the time that the server was prepared to wait."), - (409, ClientError, "Conflict", - "The request could not be completed due to a conflict with the current state of the resource."), - (410, ClientError, "Gone", - "The requested resource is no longer available at the server and no forwarding address is known."), - (411, ClientError, "Length Required", - "The server refuses to accept the request without a defined Content-Length."), - (412, ClientError, "Precondition Failed", - "The precondition given in one or more of the request-header fields evaluated to false when it was tested on the server."), - (413, ClientError, "Request Entity Too Large", - "The server is refusing to process a request because the request entity is larger than the server is willing or able to process."), - (414, ClientError, "Request-URI Too Long", - "The server is refusing to service the request because the Request-URI is longer than the server is willing to interpret."), - (415, ClientError, "Unsupported Media Type", - "The server is refusing to service the request because the entity of the request is in a format not supported by the requested resource for the requested method."), - (416, ClientError, "Range Not Satisfiable", - "None of the ranges in the requests Range header field overlap the current extent of the selected resource or that the set of ranges requested has been rejected due to invalid ranges or an excessive request of small or overlapping ranges."), - (417, ClientError, "Expectation Failed", - "The expectation given in an Expect request-header field could not be met by this server, or, if the server is a proxy, the server has unambiguous evidence that the request could not be met by the next-hop server."), - - (500, ServerError, "Internal Server Error", - "The server encountered an unexpected condition which prevented it from fulfilling the request." ), - (501, ServerError, "Not Implemented", - "The server does not support the functionality required to fulfill the request."), - (502, ServerError, "Bad Gateway", - "The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."), - (503, ServerError, "Service Unavailable", - "The server is currently unable to handle the request due to a temporary overloading or maintenance of the server."), - (504, ServerError, "Gateway Timeout", - "The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI or some other auxiliary server it needed to access in attempting to complete the request."), - (505, ServerError, "HTTP Version Not Supported", - "The server does not support, or refuses to support, the HTTP protocol version that was used in the request message."), -]; - -pub static DEFAULT_DOCUMENT: &str = "\ - \n\ - \n\ - \n\ - \t{status_code} {status_message} - Locutus - {hostname}\n\ - \t\n\ - \t\n\ - \t\n\ - \t\n\ - \t\n\ - \t\n\ - \t\n\ - \n\ - \n\ - \t
\n\ - \t\t
\n\ - {doc}\ - \t\t\t
{hostname} - {server_str}
\n\ - \t\t
\n\ - \t
\n\ - \n\ - \n"; - -pub static ERROR_DOCUMENT: &str = "\ - \t\t\t

{code}

\n\ - \t\t\t

{message} :(

\n\ - \t\t\t

{desc}

\n\ - \t\t\t

{info}

\n"; - -pub static WARNING_DOCUMENT: &str = "\ - \t\t\t

{code}

\n\ - \t\t\t

{message} :o

\n\ - \t\t\t

{desc}

\n\ - \t\t\t

{info}

\n"; - -pub static SUCCESS_DOCUMENT: &str = "\ - \t\t\t

{code}

\n\ - \t\t\t

{message} :)

\n\ - \t\t\t

{desc}

\n\ - \t\t\t

{info}

\n"; - -pub static INFO_DOCUMENT: &str = "\ - \t\t\t

{code}

\n\ - \t\t\t

{message} :)

\n\ - \t\t\t

{desc}

\n\ - \t\t\t

{info}

\n"; diff --git a/src/http/handler.rs b/src/http/handler.rs deleted file mode 100644 index 8b9bba8..0000000 --- a/src/http/handler.rs +++ /dev/null @@ -1,239 +0,0 @@ -use super::Method; -use crate::error::*; -use crate::usimp; -use crate::websocket; -use serde_json; -use crate::usimp::Envelope; - -pub fn connection_handler(client: super::Stream) { - let mut client = super::HttpStream { - stream: client, - request_num: 0, - client_keep_alive: true, - server_keep_alive: true, - }; - - while client.request_num < super::REQUESTS_PER_CONNECTION - && client.client_keep_alive - && client.server_keep_alive - { - request_handler(&mut client); - client.request_num += 1; - } -} - -fn request_handler(client: &mut super::HttpStream) { - let mut res = super::Response::new(); - - match super::parser::parse_request(&mut client.stream) { - Ok(Some(req)) => { - println!("{} {}", req.method, req.uri); - - client.client_keep_alive = - client.client_keep_alive && req.header.field_has_value("Connection", "keep-alive"); - - if !req.uri.starts_with("/") - || req.uri.contains("/./") - || req.uri.contains("/../") - || req.uri.ends_with("/..") - { - res.status(400); - } else if req.uri.contains("/.") { - res.status(404); - } else if req.uri.eq("/") { - res.status(200); - } else if req.uri.starts_with("/_usimp/") { - res.header.add_field("Cache-Control", "no-store"); - res.header.add_field("Access-Control-Allow-Origin", "*"); - res.header.add_field("Access-Control-Allow-Methods", "POST, OPTIONS"); - res.header.add_field("Access-Control-Allow-Headers", "Content-Type, From-Domain, To-Domain, Authorization"); - res.header.add_field("Access-Control-Max-Age", "3600"); - - if req.uri.eq("/_usimp/websocket") { - return websocket::connection_handler(client, &req, res); - } - - // TODO check Content-Type == application/json - - let mut error = None; - let parts: Vec<&str> = req.uri.split('/').collect(); - - match parts[2..] { - ["entity", entity] => { - res.status(501); - error = Some(Error::new(Kind::NotImplementedError, Class::ServerError)) - }, - [endpoint] => match req.method { - Method::POST => { - return endpoint_handler(client, &req, res, endpoint) - }, - Method::OPTIONS => { - res.status(204); - client.respond(&mut res); - return - } - _ => { - res.status(405); - res.header.add_field("Allow", "POST"); - error = Some(Error::new(Kind::UsimpProtocolError, Class::ClientProtocolError)) - } - }, - _ => error = Some(Error::new(Kind::InvalidEndpointError, Class::ClientProtocolError)), - } - - if let Some(error) = error { - error_handler(client, res, error); - } - - return; - } else { - res.status(404); - } - } - Ok(None) => { - client.client_keep_alive = false; - return; - } - Err(e) => { - res.status(400); - res.error_info(format!("{}", &e)); - println!("{}", &e); - client.server_keep_alive = false; - } - } - - if let Err(e) = client.respond_default(&mut res) { - println!("Unable to send: {}", e); - client.server_keep_alive = false; - } - client.server_keep_alive = false; -} - -pub fn error_handler(client: &mut super::HttpStream, mut res: super::Response, error: Error) { - println!("{}", error.to_string()); - match &error.class() { - Class::ClientProtocolError => { - if res.status.code < 400 || res.status.code >= 499 { - res.status(400) - } - }, - Class::ClientError => { - if res.status.code < 200 || res.status.code >= 299 { - res.status(200) - } - } - Class::ServerError => { - if res.status.code < 500 || res.status.code > 599 { - res.status(500) - } - } - } - res.error_info(error.to_string()); - - let mut obj = serde_json::Value::Object(serde_json::Map::new()); - obj["status"] = serde_json::Value::String("error".to_string()); - obj["message"] = serde_json::Value::String(error.to_string()); - obj["data"] = serde_json::Value::Null; - let buf = obj.to_string() + "\r\n"; - - let length = buf.as_bytes().len(); - res.header - .add_field("Content-Length", length.to_string().as_str()); - res.header - .add_field("Content-Type", "application/json; charset=utf-8"); - - if let Err(e) = client.respond(&mut res) { - println!("Unable to send: {}", e); - client.server_keep_alive = false; - } - - client.stream.write_all(buf.as_bytes()).unwrap(); -} - -fn endpoint_handler( - client: &mut super::HttpStream, - req: &super::Request, - mut res: super::Response, - endpoint: &str, -) { - let length = req.header.find_field("Content-Length"); - let length: usize = match match length { - Some(length) => length, - None => { - return error_handler( - client, - res, - Error::new(Kind::HttpRequestParseError, Class::ClientProtocolError) - .set_desc("field 'Content-Length' missing".to_string()), - ) - } - } - .parse() - { - Ok(length) => length, - Err(e) => { - return error_handler( - client, - res, - Error::new(Kind::HttpRequestParseError, Class::ClientProtocolError).set_desc( - format!("unable to parse field 'Content-Length': {}", &e).to_string(), - ), - ) - } - }; - - let mut buf = [0; 8192]; - client.stream.read_exact(&mut buf[..length]).unwrap(); - - // TODO decompress - let data = match serde_json::from_slice(&buf[..length]) { - Ok(val) => val, - Err(e) => return error_handler(client, res, e.into()), - }; - - let mut authorization = None; - if let Some(auth) = req.header.find_field("Authorization") { - // TODO check usimp prefix in Authorization - authorization = Some(auth.split(" ").skip(1).collect()); - } - - let mut from_domain = None; - if let Some(from) = req.header.find_field("From-Domain") { - from_domain = Some(from.to_string()); - } - - let mut to_domain; - if let Some(to) = req.header.find_field("To-Domain") { - to_domain = to.to_string(); - } else { - return error_handler( - client, - res, - Error::new(Kind::UsimpProtocolError, Class::ClientProtocolError) - .set_desc("Unable to find field 'To-Domain'".to_string()) - ); - } - - let input = Envelope { - endpoint: endpoint.to_string(), - from_domain, - to_domain, - token: authorization, - data, - }; - let buf = match usimp::endpoint(input) { - Ok(output) => output.to_string() + "\r\n", - Err(e) => return error_handler(client, res, e), - }; - - // TODO compress - let length = buf.as_bytes().len(); - res.header - .add_field("Content-Length", length.to_string().as_str()); - res.header - .add_field("Content-Type", "application/json; charset=utf-8"); - - res.status(200); - client.respond(&mut res).unwrap(); - client.stream.write_all(buf.as_bytes()).unwrap(); -} diff --git a/src/http/mod.rs b/src/http/mod.rs deleted file mode 100644 index 2911c50..0000000 --- a/src/http/mod.rs +++ /dev/null @@ -1,394 +0,0 @@ -mod consts; -mod handler; -mod parser; - -use openssl::ssl::SslStream; -use std::fmt::Formatter; -use std::io::{Read, Write}; -use std::net::TcpStream; - -pub use handler::*; - -static REQUESTS_PER_CONNECTION: u32 = 200; - -pub enum Stream { - Tcp(TcpStream), - Ssl(SslStream), -} - -pub struct HttpStream { - pub stream: Stream, - pub request_num: u32, - pub client_keep_alive: bool, - pub server_keep_alive: bool, -} - -pub enum Method { - GET, - POST, - PUT, - HEAD, - TRACE, - CONNECT, - DELETE, - OPTIONS, - Custom(String), -} - -#[derive(Copy, Clone)] -pub enum StatusClass { - Informational, - Success, - Redirection, - ClientError, - ServerError, -} - -#[derive(Clone)] -pub struct Status { - code: u16, - message: String, - desc: &'static str, - class: StatusClass, - info: Option, -} - -pub struct HeaderField { - name: String, - value: String, -} - -pub struct Header { - fields: Vec, -} - -pub struct Request { - version: String, - pub method: Method, - pub uri: String, - pub header: Header, -} - -pub struct Response { - version: String, - status: Status, - pub header: Header, -} - -impl Method { - pub fn from_str(v: &str) -> Method { - match v { - "GET" => Method::GET, - "POST" => Method::POST, - "PUT" => Method::PUT, - "HEAD" => Method::HEAD, - "TRACE" => Method::TRACE, - "CONNECT" => Method::CONNECT, - "DELETE" => Method::DELETE, - "OPTIONS" => Method::OPTIONS, - _ => Method::Custom(String::from(v)), - } - } - - pub fn to_str(&self) -> &str { - match self { - Method::GET => "GET", - Method::POST => "POST", - Method::PUT => "PUT", - Method::HEAD => "HEAD", - Method::TRACE => "TRACE", - Method::CONNECT => "CONNECT", - Method::DELETE => "DELETE", - Method::OPTIONS => "OPTIONS", - Method::Custom(v) => v, - } - } -} - -impl std::fmt::Display for Method { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.to_str()) - } -} - -impl StatusClass { - pub fn from_code(status_code: u16) -> StatusClass { - for (code, class, _msg, _desc) in &consts::HTTP_STATUSES { - if *code == status_code { - return class.clone(); - } - } - match status_code { - 100..=199 => StatusClass::Informational, - 200..=299 => StatusClass::Success, - 300..=399 => StatusClass::Redirection, - 400..=499 => StatusClass::ClientError, - 500..=599 => StatusClass::ServerError, - _ => panic!("invalid status code"), - } - } -} - -impl Status { - pub fn from_code(status_code: u16) -> Option { - for (code, class, msg, desc) in &consts::HTTP_STATUSES { - if *code == status_code { - return Some(Status { - code: status_code, - message: msg.to_string(), - desc, - class: class.clone(), - info: None, - }); - } - } - None - } - - pub fn new_custom(status_code: u16, message: &str) -> Status { - if status_code < 100 || status_code > 599 { - panic!("invalid status code"); - } - if let Some(status) = Status::from_code(status_code) { - Status { - code: status_code, - message: message.to_string(), - desc: status.desc, - class: status.class, - info: None, - } - } else { - Status { - code: status_code, - message: message.to_string(), - desc: "", - class: StatusClass::from_code(status_code), - info: None, - } - } - } -} - -impl std::fmt::Display for HeaderField { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "[{}: {}]", self.name, self.value) - } -} - -impl Header { - pub fn new() -> Self { - Header { fields: Vec::new() } - } - - pub fn from(fields: Vec) -> Self { - Header { fields } - } - - pub fn find_field(&self, field_name: &str) -> Option<&str> { - let field_name = field_name.to_lowercase(); - for field in &self.fields { - if field.name.to_lowercase().eq(field_name.as_str()) { - return Some(field.value.as_str()); - } - } - return None; - } - - pub fn add_field(&mut self, name: &str, value: &str) { - self.fields.push(HeaderField { - name: String::from(name), - value: String::from(value), - }) - } - - pub fn field_has_value(&self, field_name: &str, value: &str) -> bool { - if let Some(field) = self.find_field(field_name) { - let value = value.to_lowercase(); - field - .to_lowercase() - .split(",") - .any(|mut s| s.trim().eq(value.as_str())) - } else { - false - } - } -} - -impl Response { - pub fn new() -> Response { - let mut res = Response { - version: "1.1".to_string(), - status: Status::from_code(200).unwrap(), - header: Header::new(), - }; - res.header.add_field("Server", "Locutus"); - res.header.add_field( - "Date", - chrono::Utc::now() - .format("%a, %d %b %Y %H:%M:%S GMT") - .to_string() - .as_str(), - ); - res - } - - pub fn status(&mut self, status_code: u16) { - self.status = Status::from_code(status_code).unwrap() - } - - pub fn error_info(&mut self, info: String) { - self.status.info = Some(info); - } - - pub fn send(&mut self, stream: &mut Stream) -> Result<(), std::io::Error> { - let mut header = format!( - "HTTP/{} {:03} {}\r\n", - self.version, self.status.code, self.status.message - ); - for header_field in &self.header.fields { - header.push_str(format!("{}: {}\r\n", header_field.name, header_field.value).as_str()); - } - header.push_str("\r\n"); - stream.write_all(header.as_bytes())?; - Ok(()) - } - - pub fn send_default(&mut self, stream: &mut Stream) -> Result<(), std::io::Error> { - let mut buf = None; - if let None = self.header.find_field("Content-Length") { - let new_buf = self.format_default_response(); - self.header.add_field( - "Content-Length", - new_buf.as_bytes().len().to_string().as_str(), - ); - self.header - .add_field("Content-Type", "text/html; charset=utf-8"); - buf = Some(new_buf); - } - - self.send(stream)?; - - if let Some(buf) = buf { - stream.write_all(buf.as_bytes())?; - } - Ok(()) - } - - fn format_default_response(&self) -> String { - let (doc, color_name, color) = match self.status.class { - StatusClass::Informational => (consts::INFO_DOCUMENT, "info", "#606060"), - StatusClass::Success => (consts::SUCCESS_DOCUMENT, "success", "#008000"), - StatusClass::Redirection => (consts::WARNING_DOCUMENT, "warning", "#E0C000"), - StatusClass::ClientError => (consts::ERROR_DOCUMENT, "error", "#C00000"), - StatusClass::ServerError => (consts::ERROR_DOCUMENT, "error", "#C00000"), - }; - - consts::DEFAULT_DOCUMENT - .replace("{status_code}", self.status.code.to_string().as_str()) - .replace("{status_message}", self.status.message.as_str()) - .replace("{hostname}", "localhost") // TODO hostname - .replace("{theme_color}", color) - .replace("{color_name}", color_name) - .replace("{server_str}", "Locutus server") // TODO server string - .replace( - "{doc}", - doc.replace("{code}", self.status.code.to_string().as_str()) - .replace("{message}", self.status.message.as_str()) - .replace("{desc}", self.status.desc) - .replace( - "{info}", - self.status.info.as_ref().unwrap_or(&String::new()).as_str(), - ) - .as_str(), - ) - .replace("{{", "{") - .replace("}}", "}") - } -} - -impl HttpStream { - pub fn respond(&mut self, res: &mut Response) -> Result<(), std::io::Error> { - self.keep_alive(res); - res.send(&mut self.stream) - } - - pub fn respond_default(&mut self, res: &mut Response) -> Result<(), std::io::Error> { - self.keep_alive(res); - res.send_default(&mut self.stream) - } - - fn keep_alive(&mut self, res: &mut Response) { - if self.client_keep_alive && self.server_keep_alive { - res.header.add_field("Connection", "keep-alive"); - res.header.add_field("Keep-Alive", "timeout=3600, max=200"); - } - } -} - -impl Stream { - pub fn read(&mut self, buf: &mut [u8]) -> Result { - match self { - Stream::Tcp(stream) => stream.read(buf), - Stream::Ssl(stream) => loop { - match stream.ssl_read(buf) { - Ok(n) => return Ok(n), - Err(ref e) if e.code() == openssl::ssl::ErrorCode::ZERO_RETURN => return Ok(0), - Err(ref e) - if e.code() == openssl::ssl::ErrorCode::SYSCALL - && e.io_error().is_none() => - { - return Ok(0); - } - Err(ref e) - if e.code() == openssl::ssl::ErrorCode::WANT_READ - && e.io_error().is_none() => {} - Err(e) => { - return Err(e.into_io_error().unwrap_or_else(|e| { - std::io::Error::new(std::io::ErrorKind::Other, e) - })); - } - } - }, - } - } - - pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), std::io::Error> { - match self { - Stream::Tcp(stream) => stream.read_exact(buf), - Stream::Ssl(stream) => stream.read_exact(buf), - } - } - - pub fn peek(&mut self, buf: &mut [u8]) -> Result { - match self { - Stream::Tcp(stream) => stream.peek(buf), - Stream::Ssl(_stream) => todo!("Not implemented in rust-openssl"), - } - } - - pub fn write(&mut self, buf: &[u8]) -> Result { - match self { - Stream::Tcp(stream) => stream.write(buf), - Stream::Ssl(stream) => loop { - match stream.ssl_write(buf) { - Ok(n) => return Ok(n), - Err(ref e) - if e.code() == openssl::ssl::ErrorCode::WANT_READ - && e.io_error().is_none() => {} - Err(e) => { - return Err(e.into_io_error().unwrap_or_else(|e| { - std::io::Error::new(std::io::ErrorKind::Other, e) - })); - } - } - }, - } - } - - pub fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { - match self { - Stream::Tcp(stream) => stream.write_all(buf), - Stream::Ssl(stream) => stream.write_all(buf), - } - } -} diff --git a/src/http/parser.rs b/src/http/parser.rs deleted file mode 100644 index 168347c..0000000 --- a/src/http/parser.rs +++ /dev/null @@ -1,374 +0,0 @@ -use crate::error::*; -use crate::http; -use crate::http::Status; - -pub fn parse_request(stream: &mut http::Stream) -> Result, Error> { - let mut buf = [0; 4096]; - let size = stream.peek(&mut buf)?; - if size == 0 { - return Ok(None); - } - - let mut parser = Parser::new_request_parser(&buf[..size]); - let header_size = parser.parse()?; - - let mut header_fields = Vec::new(); - for (name, value) in parser.headers { - header_fields.push(http::HeaderField { - name: String::from(name), - value: String::from(value), - }); - } - - let request = http::Request { - version: String::from(parser.http_version.unwrap()), - method: http::Method::from_str(parser.method.unwrap()), - uri: String::from(parser.uri.unwrap()), - header: http::Header::from(header_fields), - }; - - stream.read_exact(&mut buf[..header_size])?; - - Ok(Some(request)) -} - -pub fn parse_response(stream: &mut http::Stream) -> Result { - let mut buf = [0; 4096]; - let size = stream.peek(&mut buf)?; - - let mut parser = Parser::new_request_parser(&buf[..size]); - let header_size = parser.parse()?; - - let status_code = parser.status_code.unwrap(); - let status_code = match status_code.parse::() { - Ok(v) => v, - Err(error) => { - return Err(Error::new(Kind::HttpRequestParseError, Class::ClientProtocolError) - .set_desc(error.to_string())) - } - }; - - let mut header_fields = Vec::new(); - for (name, value) in parser.headers { - header_fields.push(http::HeaderField { - name: String::from(name), - value: String::from(value), - }); - } - - let response = http::Response { - version: String::from(parser.http_version.unwrap()), - status: Status::new_custom(status_code, parser.status_message.unwrap()), - header: http::Header::from(header_fields), - }; - - stream.read_exact(&mut buf[..header_size])?; - - Ok(response) -} - -#[derive(Copy, Clone)] -enum State<'a> { - Method, - Uri, - Http(&'a State<'a>), - HttpVersion(&'a State<'a>), - StatusCode, - StatusMessage, - HeaderName, - HeaderValue, - Finish, - CRLF(&'a State<'a>), - Error, -} - -struct Parser<'a> { - state: State<'a>, - buf: &'a [u8], - str_start: usize, - header_size: usize, - method: Option<&'a str>, - uri: Option<&'a str>, - http_version: Option<&'a str>, - status_code: Option<&'a str>, - status_message: Option<&'a str>, - headers: Vec<(&'a str, &'a str)>, -} - -impl Parser<'_> { - fn new_request_parser(buf: &[u8]) -> Parser { - Parser { - state: State::Method, - buf, - str_start: 0, - header_size: 0, - method: None, - uri: None, - http_version: None, - status_code: None, - status_message: None, - headers: Vec::new(), - } - } - - fn new_response_parser(buf: &[u8]) -> Parser { - Parser { - state: State::Http(&State::StatusCode), - buf, - str_start: 0, - header_size: 0, - method: None, - uri: None, - http_version: None, - status_code: None, - status_message: None, - headers: Vec::new(), - } - } - - fn parse(&mut self) -> Result { - for char in self.buf { - self.next(*char); - match self.state { - State::Finish => return Ok(self.header_size), - State::Error => { - return Err(Error::new(Kind::HttpRequestParseError, Class::ClientProtocolError) - .set_desc(format!( - "invalid character at position {}", - self.header_size - 1 - ))) - } - _ => {} - } - } - return Err(Error::new(Kind::HttpRequestParseError, Class::ClientProtocolError) - .set_desc("input too short".to_string())); - } - - fn next(&mut self, char: u8) { - self.header_size += 1; - let get_str = - || std::str::from_utf8(&self.buf[self.str_start..self.header_size - 1]).unwrap(); - self.state = match &self.state { - State::Error => State::Error, - State::Finish => State::Error, - State::Method => match char { - 0x41..=0x5A => State::Method, - 0x20 => { - self.method = Some(get_str()); - self.str_start = self.header_size; - State::Uri - } - _ => State::Error, - }, - State::Uri => match char { - 0x21..=0x7E => State::Uri, - 0x20 => { - self.uri = Some(get_str()); - self.str_start = self.header_size; - State::Http(&State::HeaderName) - } - _ => State::Error, - }, - State::Http(next) => match char { - 0x48 | 0x54 | 0x50 => State::Http(next), - 0x2F => { - let http = get_str(); - self.str_start = self.header_size; - if http != "HTTP" { - State::Error - } else { - State::HttpVersion(next) - } - } - _ => State::Error, - }, - State::HttpVersion(next) => match char { - 0x30..=0x39 | 0x2E => State::HttpVersion(next), - 0x0D => match next { - State::HeaderName => { - self.http_version = Some(get_str()); - State::CRLF(next) - } - _ => State::Error, - }, - 0x20 => match next { - State::StatusCode => { - self.http_version = Some(get_str()); - self.str_start = self.header_size; - State::StatusCode - } - _ => State::Error, - }, - _ => State::Error, - }, - State::StatusCode => match char { - 0x30..=0x39 => State::StatusCode, - 0x20 => { - self.status_code = Some(get_str()); - self.str_start = self.header_size; - State::StatusMessage - } - _ => State::Error, - }, - State::StatusMessage => match char { - 0x20..=0x7E => State::StatusMessage, - 0x0D => { - self.status_message = Some(get_str()); - State::CRLF(&State::HeaderName) - } - _ => State::Error, - }, - State::HeaderName => match char { - 0x0D => { - if self.header_size == self.str_start + 1 { - State::CRLF(&State::Finish) - } else { - State::Error - } - } - 0x3A => { - let header_name = get_str(); - self.headers.push((header_name, "")); - self.str_start = self.header_size; - State::HeaderValue - } - 0x00..=0x1F - | 0x7F - | 0x80..=0xFF - | 0x20 - | 0x28 - | 0x29 - | 0x2C - | 0x2F - | 0x3A..=0x40 - | 0x5B..=0x5D - | 0x7B - | 0x7D => State::Error, - _ => State::HeaderName, - }, - State::HeaderValue => match char { - 0x20..=0x7E | 0x09 => State::HeaderValue, - 0x0D => { - self.headers.last_mut().unwrap().1 = get_str().trim(); - State::CRLF(&State::HeaderName) - } - _ => State::Error, - }, - State::CRLF(next) => match char { - 0x0A => { - self.str_start = self.header_size; - *next.clone() - } - _ => State::Error, - }, - } - } -} - -#[cfg(test)] -mod tests { - #[test] - fn simple_request() { - let request: &str = "GET /index.html HTTP/1.1\r\n\ - Host: www.example.com\r\n\ - \r\n"; - - let mut parser = super::Parser::new_request_parser(request.as_bytes()); - let size = parser.parse().unwrap(); - - assert_eq!(51, size); - assert_eq!("GET", parser.method.unwrap()); - assert_eq!("/index.html", parser.uri.unwrap()); - assert_eq!("1.1", parser.http_version.unwrap()); - assert_eq!(None, parser.status_code); - assert_eq!(None, parser.status_message); - - assert_eq!(1, parser.headers.len()); - assert_eq!(("Host", "www.example.com"), parser.headers[0]); - } - - #[test] - fn complex_request() { - let request: &str = "POST /upload/file.txt HTTP/1.3\r\n\ - Host: www.example.com \r\n\ - Content-Length: 13 \r\n\ - User-Agent: Mozilla/5.0 (X11; Linux x86_64) \r\n\ - \r\n\ - username=test"; - - let mut parser = super::Parser::new_request_parser(request.as_bytes()); - let size = parser.parse().unwrap(); - - assert_eq!(129, size); - assert_eq!("POST", parser.method.unwrap()); - assert_eq!("/upload/file.txt", parser.uri.unwrap()); - assert_eq!("1.3", parser.http_version.unwrap()); - assert_eq!(None, parser.status_code); - assert_eq!(None, parser.status_message); - - assert_eq!(3, parser.headers.len()); - assert_eq!(("Host", "www.example.com"), parser.headers[0]); - assert_eq!(("Content-Length", "13"), parser.headers[1]); - assert_eq!( - ("User-Agent", "Mozilla/5.0 (X11; Linux x86_64)"), - parser.headers[2] - ); - - assert_eq!("username=test", &request[size..]); - } - - #[test] - fn invalid_request_1() { - let request: &str = "GET /files/größe.txt HTTP/1.1\r\n\r\n"; - let mut parser = super::Parser::new_request_parser(request.as_bytes()); - match parser.parse() { - Ok(_v) => panic!("should fail"), - Err(e) => assert_eq!( - "unable to parse http request: invalid character at position 13", - e.to_string() - ), - } - } - - #[test] - fn invalid_request_2() { - let request: &str = "GET /index.html HTT"; - let mut parser = super::Parser::new_request_parser(request.as_bytes()); - match parser.parse() { - Ok(_v) => panic!("should fail"), - Err(e) => assert_eq!( - "unable to parse http request: input too short", - e.to_string() - ), - } - } - - #[test] - fn simple_response() { - let response: &str = "HTTP/1.1 200 OK\r\n\ - Content-Length: 12\r\n\ - Content-Type: text/plain; charset=us-ascii\r\n\ - \r\n\ - Hello world!"; - - let mut parser = super::Parser::new_response_parser(response.as_bytes()); - let size = parser.parse().unwrap(); - - assert_eq!(83, size); - assert_eq!("200", parser.status_code.unwrap()); - assert_eq!("OK", parser.status_message.unwrap()); - assert_eq!("1.1", parser.http_version.unwrap()); - assert_eq!(None, parser.method); - assert_eq!(None, parser.uri); - - assert_eq!(2, parser.headers.len()); - assert_eq!(("Content-Length", "12"), parser.headers[0]); - assert_eq!( - ("Content-Type", "text/plain; charset=us-ascii"), - parser.headers[1] - ); - - assert_eq!("Hello world!", &response[size..]); - } -} diff --git a/src/main.rs b/src/main.rs index 9fa6533..a6ac261 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,149 +1,112 @@ -use std::net::{SocketAddr, TcpListener, UdpSocket}; -use std::sync::{Arc, Mutex}; -use std::thread; +use std::fmt; +use std::io::Read; +use std::net; +use std::net::SocketAddr; +use std::pin::Pin; -use ansi_term::Color; -use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; -use rusty_pool; -use std::fmt::Formatter; -use std::time::Duration; +use error::*; +use ansi_term::{Color, Style}; +use futures_util::{future::TryFutureExt, stream::Stream}; +use hyper::Server; +use hyper::server::conn::AddrStream; +use hyper::service::{make_service_fn, service_fn}; +mod http; +mod websocket; +mod usimp; mod database; mod error; -mod http; -mod subscription; -mod udp; -mod usimp; -mod websocket; -enum SocketType { - Http, - Https, - Udp, +struct HyperAcceptor<'a> { + acceptor: Pin, std::io::Error>> + 'a>>, } -impl std::fmt::Display for SocketType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - SocketType::Http => "http+ws", - SocketType::Https => "https+wss", - SocketType::Udp => "udp", - } - ) +impl hyper::server::accept::Accept for HyperAcceptor<'_> { + type Conn = tokio_rustls::server::TlsStream; + type Error = std::io::Error; + + fn poll_accept( + mut self: Pin<&mut Self>, + cx: &mut core::task::Context, + ) -> core::task::Poll>> { + Pin::new(&mut self.acceptor).poll_next(cx) } } -struct SocketConfig { - address: SocketAddr, - socket_type: SocketType, +fn load_certs(filename: &str) -> std::io::Result> { + let certfile = std::fs::File::open(filename) + .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; + let mut reader = std::io::BufReader::new(certfile); + + rustls::internal::pemfile::certs(&mut reader).map_err(|_| error("failed to load certificate".into())) } -fn main() { +fn load_private_key(filename: &str) -> std::io::Result { + let keyfile = std::fs::File::open(filename) + .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; + let mut reader = std::io::BufReader::new(keyfile); + + let keys = rustls::internal::pemfile::rsa_private_keys(&mut reader) + .map_err(|_| error("failed to load private key".into()))?; + if keys.len() < 1 { + return Err(error("expected a single private key".into())); + } + Ok(keys[0].clone()) +} + +fn error(err: String) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::Other, err) +} + +#[tokio::main] +async fn main() -> Result<(), Error> { println!("Locutus server"); - let socket_configs: Vec = vec![ - SocketConfig { - address: "[::]:8080".parse().unwrap(), - socket_type: SocketType::Http, - }, - SocketConfig { - address: "[::]:8443".parse().unwrap(), - socket_type: SocketType::Https, - }, - SocketConfig { - address: "[::]:3126".parse().unwrap(), - socket_type: SocketType::Udp, - }, - ]; + database::init().await?; - subscription::init(); + let server1 = Server::bind(&SocketAddr::from(([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], 8080))); + let service = make_service_fn(|_: &AddrStream| async { + Ok::<_, hyper::Error>(service_fn(http::handler)) + }); + let srv1 = server1.serve(service); - // Note: rust's stdout is line buffered! - eprint!("Initializing database connection pool..."); - if let Err(error) = database::init() { - eprintln!("\n{}", Color::Red.bold().paint(error.to_string())); - std::process::exit(1); - } - eprintln!(" {}", Color::Green.paint("success")); + let tls_cfg = { + let certs = load_certs("/home/lorenz/Certificates/priv/fullchain.pem").unwrap(); + let key = load_private_key("/home/lorenz/Certificates/priv/privkey.pem").unwrap(); + let mut cfg = rustls::ServerConfig::new(rustls::NoClientAuth::new()); + cfg.set_single_cert(certs, key).unwrap(); + cfg.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]); + std::sync::Arc::new(cfg) + }; - let thread_pool = rusty_pool::Builder::new() - .core_size(4) - .max_size(1024) - .keep_alive(Duration::from_secs(60 * 60)) - .build(); - let thread_pool_mutex = Arc::new(Mutex::new(thread_pool)); + let acceptor = tokio_rustls::TlsAcceptor::from(tls_cfg); + let tcp = Box::pin(tokio::net::TcpListener::bind("[::]:8443").await.unwrap()); + let incoming_tls_stream = async_stream::stream! { + loop { + let (socket, _) = tcp.accept().await.unwrap(); + let stream = acceptor.accept(socket).map_err(|e| { + println!("[!] Voluntary server halt due to client-connection error..."); + // Errors could be handled here, instead of server aborting. + //Ok(None) + //println!("{:?}", e); + error(format!("TLS Error: {:?}", e)) + }); + yield stream.await; + } + }; - let mut threads = Vec::new(); + let server2 = Server::builder(HyperAcceptor { + acceptor: Box::pin(incoming_tls_stream), + }); - for socket_config in socket_configs { - let thread_pool_mutex = thread_pool_mutex.clone(); - - eprintln!( - "Creating listening thread for {} ({})", - ansi_term::Style::new() - .bold() - .paint(socket_config.address.to_string()), - socket_config.socket_type - ); - - threads.push(match socket_config.socket_type { - SocketType::Http => thread::spawn(move || { - let mut tcp_socket = TcpListener::bind(socket_config.address).unwrap(); - - for stream in tcp_socket.incoming() { - thread_pool_mutex.lock().unwrap().execute(|| { - let stream = stream.unwrap(); - http::connection_handler(http::Stream::Tcp(stream)); - }); - } - }), - SocketType::Https => thread::spawn(move || { - let mut ssl_socket = TcpListener::bind(socket_config.address).unwrap(); - - let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - acceptor - .set_certificate_chain_file("/home/lorenz/Certificates/chakotay.pem") - .unwrap(); - acceptor - .set_private_key_file( - "/home/lorenz/Certificates/priv/chakotay.key", - SslFiletype::PEM, - ) - .unwrap(); - acceptor.check_private_key().unwrap(); - let acceptor = Arc::new(acceptor.build()); - - for stream in ssl_socket.incoming() { - let acceptor = acceptor.clone(); - thread_pool_mutex.lock().unwrap().execute(move || { - let stream = stream.unwrap(); - let stream = acceptor.accept(stream).unwrap(); - http::connection_handler(http::Stream::Ssl(stream)); - }); - } - }), - SocketType::Udp => thread::spawn(move || { - let mut udp_socket = UdpSocket::bind(socket_config.address).unwrap(); - let mut buf = [0; 65_536]; - - loop { - let (size, addr) = udp_socket.recv_from(&mut buf).unwrap(); - let req = udp::Request::new(&udp_socket, addr, size, &buf); - thread_pool_mutex - .lock() - .unwrap() - .execute(|| udp::handler(req)); - } - }), - }); - } + let service = make_service_fn(|_| async { + Ok::<_, hyper::Error>(service_fn(http::handler)) + }); + let srv2 = server2.serve(service); println!("{}", Color::Green.paint("Ready")); - for thread in threads { - thread.join().unwrap(); - } + let (_res1, _res2) = futures::future::join(srv1, srv2).await; + + Ok(()) } diff --git a/src/subscription.rs b/src/subscription.rs deleted file mode 100644 index aabb951..0000000 --- a/src/subscription.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::sync::{Arc, Mutex, mpsc}; -use serde::{Deserialize, Serialize}; -use serde_json; - -static mut SUBSCRIPTIONS: Option>>>> = None; - -#[derive(Clone, Serialize, Deserialize)] -pub struct Event{ - pub data: serde_json::Value, -} - -pub fn init() { - unsafe { - SUBSCRIPTIONS = Some(Arc::new(Mutex::new(Vec::new()))); - } -} - -pub fn subscribe() -> mpsc::Receiver { - let (rx, tx) = mpsc::channel(); - unsafe { SUBSCRIPTIONS.as_ref().unwrap().lock().unwrap().push(rx); } - tx -} - -pub fn unsubscribe(rx: mpsc::Receiver) { - // TODO implement unsubscribe -} - -pub fn notify(event: Event) { - for sender in unsafe { SUBSCRIPTIONS.as_ref().unwrap().lock().unwrap().clone() } { - sender.send(event.clone()); - } -} - - diff --git a/src/udp/mod.rs b/src/udp/mod.rs deleted file mode 100644 index 5117744..0000000 --- a/src/udp/mod.rs +++ /dev/null @@ -1,28 +0,0 @@ -use std::net::{SocketAddr, UdpSocket}; - -pub struct Request { - socket: UdpSocket, - address: SocketAddr, - size: usize, - buf: [u8; 65_536], -} - -impl Request { - pub fn new( - socket: &UdpSocket, - address: SocketAddr, - size: usize, - buf: &[u8; 65_536], - ) -> Request { - Request { - socket: socket.try_clone().unwrap(), - address, - size, - buf: buf.clone(), - } - } -} - -pub fn handler(request: Request) { - // TODO handle UDP requests -} diff --git a/src/usimp/handler/authenticate.rs b/src/usimp/handler/authenticate.rs new file mode 100644 index 0000000..5a8ccb8 --- /dev/null +++ b/src/usimp/handler/authenticate.rs @@ -0,0 +1,34 @@ +use crate::usimp::*; +use crate::database; + +use serde_json::{Value, from_value, to_value}; +use serde::{Serialize, Deserialize}; +use std::ops::Deref; + +#[derive(Serialize, Deserialize, Clone)] +struct Input { + name: String, + password: String, +} + +#[derive(Serialize, Deserialize, Clone)] +struct Output { + session: String, + token: String, +} + +pub async fn handle(input: &InputEnvelope, session: &Session) -> Result { + Ok(to_value(authenticate(from_value(input.data.clone())?).await?)?) +} + +async fn authenticate(input: Input) -> Result { + match database::client().await? { + database::Client::Postgres(client) => { + client.execute("SELECT * FROM asdf;", &[]).await?; + } + } + Ok(Output { + session: "".to_string(), + token: "".to_string(), + }) +} diff --git a/src/usimp/handler/mod.rs b/src/usimp/handler/mod.rs new file mode 100644 index 0000000..d2f0008 --- /dev/null +++ b/src/usimp/handler/mod.rs @@ -0,0 +1,20 @@ +mod ping; +mod authenticate; + +use crate::usimp::*; + +pub async fn endpoint(input: &InputEnvelope) -> Result { + println!("Endpoint: {}", input.endpoint); + let session= Session { + account: None, + id: "".to_string(), + nr: 0, + }; + Ok(match input.endpoint.as_str() { + "ping" => input.respond(ping::handle(&input, &session).await?), + "authenticate" => input.respond(authenticate::handle(&input, &session).await?), + _ => input.new_error(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, Some("Invalid endpoint".to_string())), + }) +} + + diff --git a/src/usimp/handler/ping.rs b/src/usimp/handler/ping.rs new file mode 100644 index 0000000..324da36 --- /dev/null +++ b/src/usimp/handler/ping.rs @@ -0,0 +1,11 @@ +use crate::usimp::*; + +use serde_json::Value; + +pub async fn handle(input: &InputEnvelope, session: &Session) -> Result { + ping(&input.data).await +} + +async fn ping(input: &Value) -> Result { + Ok(input.clone()) +} diff --git a/src/usimp/mod.rs b/src/usimp/mod.rs index 60e34d8..61e0e0f 100644 --- a/src/usimp/mod.rs +++ b/src/usimp/mod.rs @@ -1,25 +1,25 @@ -use serde::{Deserialize, Serialize}; -use serde_json; +mod handler; -use crate::subscription; -use crate::database; -use crate::error::*; -use crypto::digest::Digest; -use rand; -use rand::Rng; +pub use handler::endpoint; -pub struct Envelope { +use serde_json::Value; +use crate::error::{Error, ErrorClass, ErrorKind}; +use serde::{Serialize, Deserialize}; + +#[derive(Serialize, Deserialize)] +pub struct InputEnvelope { pub endpoint: String, pub from_domain: Option, pub to_domain: String, pub token: Option, - pub data: serde_json::Value, + pub request_nr: Option, + pub data: Value, } -pub struct Account { - id: String, - name: String, - domain: String, +pub struct OutputEnvelope { + pub error: Option, + pub request_nr: Option, + pub data: Value, } pub struct Session { @@ -28,214 +28,22 @@ pub struct Session { account: Option, } +pub struct Account { + +} + +impl InputEnvelope { + pub fn respond(&self, data: Value) -> OutputEnvelope { + OutputEnvelope { + error: None, + request_nr: self.request_nr, + data, + } + } +} + impl Session { - pub fn from_token(token: &str) -> Result { - let backend = database::client()?; - - let mut session; - match backend { - database::Client::Postgres(mut client) => { - let res = client.query( - "SELECT session_id, session_nr, a.account_id, account_name, domain_id \ - FROM accounts a JOIN sessions s ON a.account_id = s.account_id \ - WHERE session_token = $1", - &[&token] - )?; - - if res.len() == 0 { - return Err(Error::new(Kind::InvalidSessionError, Class::ClientError)); - } - - session = Session { - id: res[0].get(0), - nr: res[0].get(1), - account: Some(Account { - id: res[0].get(2), - name: res[0].get(3), - domain: res[0].get(4), - }), - }; - } - } - - Ok(session) + pub async fn from_token(token: &str) -> Self { + todo!("session") } } - -pub fn endpoint(envelope: Envelope) -> Result { - if envelope.from_domain != None { - // TODO - return Err(Error::new(Kind::NotImplementedError, Class::ServerError)); - } - - let mut session = None; - if let Some(token) = &envelope.token { - session = Some(Session::from_token(token)?); - } - - let out = match envelope.endpoint.as_str() { - "echo" => serde_json::to_value(echo(session, serde_json::from_value(envelope.data)?)?)?, - "authenticate" => serde_json::to_value(authenticate(session, serde_json::from_value(envelope.data)?)?)?, - "subscribe" => serde_json::to_value(subscribe(session, serde_json::from_value(envelope.data)?)?)?, - "send_event" => serde_json::to_value(send_event(session, serde_json::from_value(envelope.data)?)?)?, - _ => return Err(Error::new(Kind::InvalidEndpointError, Class::ClientProtocolError)), - }; - - let mut envelope = serde_json::Value::Object(serde_json::Map::new()); - envelope["status"] = serde_json::Value::String("success".to_string()); - envelope["message"] = serde_json::Value::Null; - envelope["data"] = out; - - Ok(envelope) -} - -pub fn get_id(input: &str) -> String { - let mut hasher = crypto::sha2::Sha256::new(); - hasher.input_str(chrono::Utc::now().timestamp_millis().to_string().as_str()); - hasher.input_str(" "); - hasher.input_str(input); - - let mut result = [0u8; 32]; - hasher.result(&mut result); - base64_url::encode(&result) -} - -#[derive(Serialize, Deserialize)] -pub struct EchoInput { - message: String, -} - -#[derive(Serialize, Deserialize)] -pub struct EchoOutput { - message: String, - database: Option, -} - -pub fn echo(session: Option, input: EchoInput) -> Result { - let backend = database::client()?; - let mut output = EchoOutput { - message: input.message, - database: None, - }; - match backend { - database::Client::Postgres(mut client) => { - let res = client.query("SELECT * FROM test", &[])?; - for row in res { - output.database = Some(row.get(0)); - } - } - } - Ok(output) -} - -#[derive(Serialize, Deserialize)] -pub struct AuthenticateInput { - r#type: String, - name: String, - password: String, -} - -#[derive(Serialize, Deserialize)] -pub struct AuthenticateOutput { - token: String, -} - -pub fn authenticate(session: Option, input: AuthenticateInput) -> Result { - let backend = database::client()?; - - let mut token: String; - match backend { - database::Client::Postgres(mut client) => { - let res = client.query( - "SELECT account_id FROM accounts WHERE account_name = $1", - &[&input.name] - )?; - if res.len() == 0 { - return Err(Error::new(Kind::AuthenticationError, Class::ClientError)); - } - let account_id: String = res[0].get(0); - - // TODO password check - if !input.password.eq("MichaelScott") { - return Err(Error::new(Kind::AuthenticationError, Class::ClientError)); - } - - token = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(256) - .map(char::from) - .collect(); - - let session_id: String = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(43) - .map(char::from) - .collect(); - - client.execute( - "INSERT INTO sessions (account_id, session_nr, session_id, session_token) \ - VALUES ($1, COALESCE((SELECT MAX(session_nr) + 1 FROM sessions WHERE account_id = $1), 1), $2, $3)", - &[&account_id, &session_id, &token], - )?; - } - } - - Ok(AuthenticateOutput { token }) -} - -#[derive(Serialize, Deserialize)] -pub struct SendEventInput { - room_id: String, - data: serde_json::Value, -} - -#[derive(Serialize, Deserialize)] -pub struct SendEventOutput { - event_id: String, -} - -pub fn send_event(session: Option, input: SendEventInput) -> Result { - let backend = database::client()?; - let event_id = get_id("hermann"); // TODO fix id generation - let data = serde_json::to_string(&input.data)?; - let session = session.unwrap(); - - match backend { - database::Client::Postgres(mut client) => { - - let res = client.query( - "SELECT member_id FROM members \ - WHERE (room_id, account_id) = ($1, $2)", - &[&input.room_id, &session.account.unwrap().id])?; - let member_id: String = res[0].get(0); - - client.execute( - "INSERT INTO events (event_id, room_id, from_member_id, from_session_id, data) \ - VALUES ($1, $2, $3, $4, to_jsonb($5::text))", - &[&event_id, &input.room_id, &member_id, &session.id, &data])?; - } - } - - subscription::notify(subscription::Event { - data: input.data - }); - - Ok(SendEventOutput { event_id }) -} - -#[derive(Serialize, Deserialize)] -pub struct SubscribeInput { - -} - -#[derive(Serialize, Deserialize)] -pub struct SubscribeOutput { - event: subscription::Event, -} - -pub fn subscribe(session: Option, input: SubscribeInput) -> Result { - let rx = subscription::subscribe(); - let event = rx.recv().unwrap(); - subscription::unsubscribe(rx); - Ok(SubscribeOutput { event }) -} diff --git a/src/websocket.rs b/src/websocket.rs new file mode 100644 index 0000000..425f0d6 --- /dev/null +++ b/src/websocket.rs @@ -0,0 +1,97 @@ +use hyper::{Request, Response, Body, StatusCode, header}; +use crate::usimp::*; +use crate::usimp; +use crate::error::*; +use hyper_tungstenite::{WebSocketStream, tungstenite::protocol::Role}; +use futures_util::StreamExt; +use hyper_tungstenite::tungstenite::{handshake, Message}; +use hyper_tungstenite::hyper::upgrade::Upgraded; +use futures::stream::{SplitSink, SplitStream}; +use tokio::sync::mpsc; +use serde_json::{Value, Map}; +use futures_util::SinkExt; + +async fn sender(mut sink: SplitSink, Message>, mut rx: mpsc::Receiver) { + while let Some(msg) = rx.recv().await { + let mut envelope = Value::Object(Map::new()); + envelope["data"] = msg.data; + envelope["request_nr"] = match msg.request_nr { + Some(nr) => Value::from(nr), + None => Value::Null, + }; + match msg.error { + Some(error) => { + envelope["status"] = Value::from("error"); + envelope["error"] = Value::from(error); + }, + None => { + envelope["status"] = Value::from("success"); + } + } + if let Err(error) = sink.send(Message::Text(envelope.to_string())).await { + eprintln!("{:?}", error); + break + } + } +} + +async fn receiver(mut stream: SplitStream>, tx: mpsc::Sender) { + while let Some(res) = stream.next().await { + match res { + Ok(msg) => { + let input: InputEnvelope = serde_json::from_slice(&msg.into_data()[..]).unwrap(); + let output = match usimp::endpoint(&input).await { + Ok(output) => output, + Err(error) => input.error(error), + }; + tx.send(output).await; + }, + Err(error) => println!("{:?}", error), + } + } +} + +pub async fn handler(req: Request, res: hyper::http::response::Builder) -> (hyper::http::response::Builder, Option) { + match req.headers().get(header::UPGRADE) { + Some(val) if val == header::HeaderValue::from_str("websocket").unwrap() => {}, + _ => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::WebSocketError, ErrorClass::ClientProtocolError, None)))), + } + + let key = match req.headers().get(header::SEC_WEBSOCKET_KEY) { + Some(key) => key, + None => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::WebSocketError, ErrorClass::ClientProtocolError, None)))) + }; + let key = handshake::derive_accept_key(key.as_bytes()); + + match req.headers().get(header::SEC_WEBSOCKET_PROTOCOL) { + Some(val) if val == header::HeaderValue::from_str("usimp").unwrap() => {} + _ => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)))), + } + + tokio::spawn(async move { + match hyper::upgrade::on(req).await { + Ok(upgraded) => { + let ws_stream = WebSocketStream::from_raw_socket( + upgraded, + Role::Server, + None, + ).await; + let (tx, rx) = mpsc::channel::(16); + let (sink, stream) = ws_stream.split(); + tokio::spawn(async move { + sender(sink, rx).await + }); + receiver(stream, tx).await + } + Err(error) => eprintln!("Unable to upgrade: {}", error) + } + }); + + (res + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(header::CONNECTION, header::HeaderValue::from_str("Upgrade").unwrap()) + .header(header::UPGRADE, header::HeaderValue::from_str("websocket").unwrap()) + .header(header::SEC_WEBSOCKET_ACCEPT, header::HeaderValue::from_str(key.as_str()).unwrap()) + .header(header::SEC_WEBSOCKET_PROTOCOL, header::HeaderValue::from_str("usimp").unwrap()), + None) +} diff --git a/src/websocket/handler.rs b/src/websocket/handler.rs deleted file mode 100644 index 574fd9d..0000000 --- a/src/websocket/handler.rs +++ /dev/null @@ -1,202 +0,0 @@ -use crate::error::*; -use crate::http; - -use crate::usimp; -use crate::websocket::*; -use base64; -use crypto; -use crypto::digest::Digest; - -pub fn recv_message(client: &mut http::HttpStream) -> Result { - let mut msg: Vec = Vec::new(); - let mut msg_type = 0; - loop { - let header = FrameHeader::from(&mut client.stream)?; - - // FIXME control frames may show up in a fragmented stream - if msg_type != 0 && header.opcode != 0 { - return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError) - .set_desc("continuation frame expected".to_string())); - } else if header.opcode >= 8 && (!header.fin || header.payload_len >= 126) { - return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError) - .set_desc("invalid control frame".to_string())); - } - - match header.opcode { - 0 => {}, // cont - 1 => {}, // text - 2 => // binary - return Err(Error::new(Kind::UsimpProtocolError, Class::ClientProtocolError) - .set_desc("binary frames must not be sent on a usimp connection".to_string())), - 8 => {}, // close - 9 => {}, // ping - 10 => {}, // pong - _ => return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError) - .set_desc("invalid opcode".to_string())), - } - - msg_type = header.opcode; - - // FIXME check payload len and total len - - let mut buf = vec![0u8; header.payload_len() as usize]; - client.stream.read_exact(&mut buf)?; - - if header.mask { - let key: [u8; 4] = [ - (header.masking_key.unwrap() >> 24) as u8, - ((header.masking_key.unwrap() >> 16) & 0xFF) as u8, - ((header.masking_key.unwrap() >> 8) & 0xFF) as u8, - (header.masking_key.unwrap() & 0xFF) as u8, - ]; - for (pos, byte) in buf.iter_mut().enumerate() { - *byte ^= key[pos & 3]; // = pos % 4 - } - } - - msg.append(&mut buf); - - if header.fin { - break - } - } - - match msg_type { - 1 => {Ok(Message::TextMessage(TextMessage { - data: String::from_utf8(msg)? - }))}, - 8 => { - let mut code = None; - let mut reason = None; - - if msg.len() >= 2 { - code = Some(((msg[0] as u16) << 8) | (msg[1] as u16)); - } - - if msg.len() > 2 { - reason = Some(String::from_utf8(msg[2..].to_vec())?); - } - - Ok(Message::CloseMessage(CloseMessage { - code, - reason - })) - }, - 9 => {Ok(Message::PingMessage(PingMessage { - data: String::from_utf8(msg)? - }))}, - 10 => {Ok(Message::PongMessage(PongMessage { - data: String::from_utf8(msg)? - }))}, - _ => panic!("invalid msg_type for websocket") - } -} - -pub fn handshake( - client: &mut http::HttpStream, - req: &http::Request, - res: &mut http::Response, -) -> Result<(), Error> { - if let http::Method::GET = req.method { - } else { - res.status(405); - res.header.add_field("Allow", "GET"); - return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError) - .set_desc("method not allowed".to_string())); - } - - if let Some(_) = req.header.find_field("Connection") { - if !req.header.field_has_value("Connection", "upgrade") { - return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError) - .set_desc("invalid value for header field 'Connection'".to_string())); - } - } else { - return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError) - .set_desc("unable to find header field 'Connection'".to_string())); - } - - if let Some(upgrade) = req.header.find_field("Upgrade") { - if !upgrade.eq_ignore_ascii_case("websocket") { - return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError) - .set_desc("invalid value for header field 'Upgrade'".to_string())); - } - } else { - return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError) - .set_desc("unable to find header field 'Upgrade'".to_string())); - } - - if let Some(version) = req.header.find_field("Sec-WebSocket-Version") { - if !version.eq("13") { - return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError) - .set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string())); - } - } else { - return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError) - .set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string())); - } - - if let Some(key) = req.header.find_field("Sec-WebSocket-Key") { - let mut hasher = crypto::sha1::Sha1::new(); - hasher.input_str(key); - hasher.input_str("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - - let mut result = [0u8; 160 / 8]; - hasher.result(&mut result); - let key = base64::encode(result); - res.header.add_field("Sec-WebSocket-Accept", key.as_str()); - } else { - return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError) - .set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string())); - } - - client.server_keep_alive = false; - res.header.add_field("Connection", "Upgrade"); - res.header.add_field("Upgrade", "websocket"); - - res.status(101); - res.send(&mut client.stream)?; - - Ok(()) -} - -pub fn error_handler(error: Error) { - // TODO send error response frame -} - -pub fn connection_handler( - client: &mut http::HttpStream, - req: &http::Request, - mut res: http::Response, -) { - if let Err(error) = handshake(client, req, &mut res) { - return http::error_handler(client, res, error); - } - - loop { - match recv_message(client) { - Ok(msg) => { - match msg { - Message::TextMessage(msg) => { - // TODO threads? - let req: RequestEnvelope = serde_json::from_str(msg.data.as_str()).unwrap(); - println!("Endpoint: {}, ReqNo: {}, Data: {}", req.endpoint, req.request_nr, req.data); - //let a = usimp::endpoint(req.endpoint.as_str(), req.data).unwrap(); - }, - Message::CloseMessage(msg) => { - println!("Received close frame: {}: {}", msg.code.unwrap_or(0), msg.reason.unwrap_or("-".to_string())); - return - } - Message::PingMessage(msg) => { - // TODO send pong - }, - Message::PongMessage(msg) => { - // TODO something - } - } - }, - Err(error) => { - error_handler(error); - } - } - } -} diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs deleted file mode 100644 index 0954b1c..0000000 --- a/src/websocket/mod.rs +++ /dev/null @@ -1,106 +0,0 @@ -mod handler; - -use serde::{Deserialize, Serialize}; -use crate::error::Error; -use crate::http; -pub use handler::*; - -pub struct WebSocketStream { - stream: http::HttpStream, - compression: bool, -} - -pub struct FrameHeader { - fin: bool, - rsv1: bool, - rsv2: bool, - rsv3: bool, - opcode: u8, - mask: bool, - payload_len: u8, - ex_payload_len: Option, - masking_key: Option, -} - -pub enum Message { - PingMessage(PingMessage), - PongMessage(PongMessage), - CloseMessage(CloseMessage), - TextMessage(TextMessage), -} - -pub struct PingMessage { - data: String, -} - -pub struct PongMessage { - data: String, -} - -pub struct CloseMessage { - code: Option, - reason: Option, -} - -pub struct TextMessage { - data: String, -} - -#[derive(Serialize, Deserialize)] -pub struct RequestEnvelope { - endpoint: String, - request_nr: u64, - data: serde_json::Value, -} - -#[derive(Serialize, Deserialize)] -pub struct ResponseEnvelope { - to_request_nr: u64, - status: String, - message: Option, - data: serde_json::Value, -} - -impl FrameHeader { - pub fn from(socket: &mut http::Stream) -> Result { - let mut data = [0u8; 2]; - socket.read_exact(&mut data)?; - - let mut header = FrameHeader { - fin: data[0] & 0x80 != 0, - rsv1: data[0] & 0x40 != 0, - rsv2: data[0] & 0x20 != 0, - rsv3: data[0] & 0x10 != 0, - opcode: data[0] & 0x0F, - mask: data[1] & 0x80 != 0, - payload_len: data[1] & 0x7F, - ex_payload_len: None, - masking_key: None, - }; - - if header.payload_len == 126 { - let mut data = [0u8; 2]; - socket.read_exact(&mut data)?; - header.ex_payload_len = Some(u16::from_be_bytes(data) as u64); - } else if header.payload_len == 127 { - let mut data = [0u8; 8]; - socket.read_exact(&mut data)?; - header.ex_payload_len = Some(u64::from_be_bytes(data)); - } - - if header.mask { - let mut data = [0u8; 4]; - socket.read_exact(&mut data)?; - header.masking_key = Some(u32::from_be_bytes(data)); - } - - Ok(header) - } - - pub fn payload_len(&self) -> u64 { - match self.ex_payload_len { - Some(val) => val, - None => self.payload_len as u64, - } - } -}