WebSockets and tokio working

This commit is contained in:
2021-06-04 15:22:55 +02:00
parent 1427443caf
commit a96cbdc059
18 changed files with 542 additions and 2007 deletions

View File

@ -7,14 +7,20 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
rusty_pool = "0.6.0" hyper = { version = "0.14", features = ["full"] }
tokio = { version = "1", features = ["full"] }
futures = "0.3"
tokio-tls = "0.3.1"
rustls = "0.19.1"
tokio-rustls = "0.22.0"
futures-util = "0.3.15"
async-stream = "0.3.2"
hyper-tungstenite = "0.3.2"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.64" serde_json = "1.0.64"
openssl = {version = "0.10", features = ["vendored"]}
chrono = "0.4" chrono = "0.4"
flate2 = "1.0.0" bb8 = "0.7.0"
r2d2 = "0.8.9" bb8-postgres = "0.7.0"
r2d2_postgres = "0.18.0"
ansi_term = "0.12" ansi_term = "0.12"
rust-crypto = "^0.2" rust-crypto = "^0.2"
base64 = "0.13.0" base64 = "0.13.0"

View File

@ -1,42 +1,42 @@
use crate::error::Error; use crate::error::*;
use r2d2_postgres::postgres::NoTls; use bb8_postgres::tokio_postgres::NoTls;
use r2d2_postgres::PostgresConnectionManager; use bb8_postgres::PostgresConnectionManager;
use std::ops::Deref; use std::ops::Deref;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex, MutexGuard};
use std::time::Duration; use std::time::Duration;
pub enum Pool { pub enum Pool {
Postgres(r2d2::Pool<PostgresConnectionManager<NoTls>>), Postgres(bb8::Pool<PostgresConnectionManager<NoTls>>),
} }
pub enum Client { pub enum Client<'a> {
Postgres(r2d2::PooledConnection<PostgresConnectionManager<NoTls>>), Postgres(bb8::PooledConnection<'a, PostgresConnectionManager<NoTls>>),
} }
static mut POOL: Option<Arc<Mutex<Pool>>> = None; static mut POOL: Option<Pool> = None;
pub fn init() -> Result<(), Error> { pub async fn init() -> Result<(), Error> {
let manager = PostgresConnectionManager::new( let manager = PostgresConnectionManager::new(
"host=localhost user=postgres dbname=locutus".parse().unwrap(), "host=localhost user=postgres dbname=locutus".parse().unwrap(),
NoTls, NoTls,
); );
let pool = r2d2::Pool::builder() let pool = bb8::Pool::builder()
.max_size(64) .max_size(64)
.min_idle(Some(2)) .min_idle(Some(2))
.connection_timeout(Duration::from_secs(4)) .connection_timeout(Duration::from_secs(4))
.max_lifetime(Some(Duration::from_secs(3600))) .max_lifetime(Some(Duration::from_secs(3600)))
.build(manager)?; .build(manager).await?;
unsafe { unsafe {
POOL = Some(Arc::new(Mutex::new(Pool::Postgres(pool)))); POOL = Some(Pool::Postgres(pool));
} }
Ok(()) Ok(())
} }
pub fn client() -> Result<Client, Error> { pub async fn client() -> Result<Client<'static>, Error> {
match unsafe { POOL.as_ref().unwrap().clone().lock().unwrap().deref() } { match unsafe { POOL.as_ref().unwrap().clone() } {
Pool::Postgres(pool) => Ok(Client::Postgres(pool.get()?)), Pool::Postgres(pool) => Ok(Client::Postgres(pool.get().await?)),
} }
} }

View File

@ -1,120 +1,107 @@
use std::fmt; use crate::usimp::{InputEnvelope, OutputEnvelope};
#[derive(Copy, Clone, Debug)] use serde_json::{Value, Map};
pub enum Kind { use bb8_postgres::tokio_postgres;
InvalidEndpointError, use bb8_postgres;
JsonParseError,
DatabaseConnectionError, #[derive(Debug)]
DatabaseError, pub struct Error {
HttpRequestParseError, pub kind: ErrorKind,
IoError, pub class: ErrorClass,
WebSocketError, pub msg: Option<String>,
NotImplementedError, pub desc: Option<String>,
UsimpProtocolError,
Utf8DecodeError,
AuthenticationError,
InvalidSessionError,
} }
#[derive(Copy, Clone, Debug)] #[derive(Debug)]
pub enum Class { pub enum ErrorClass {
ClientProtocolError, ClientProtocolError,
ClientError, ClientError,
ServerError, ServerError,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct Error { pub enum ErrorKind {
kind: Kind, NotImplemented,
msg: Option<String>, UsimpError,
desc: Option<String>, WebSocketError,
class: Class, DatabaseError,
}
impl InputEnvelope {
pub fn new_error(&self, kind: ErrorKind, class: ErrorClass, msg: Option<String>) -> OutputEnvelope {
OutputEnvelope {
request_nr: self.request_nr,
error: Some(Error::new(kind, class, msg)),
data: Value::Null,
}
}
pub fn error(&self, error: Error) -> OutputEnvelope {
OutputEnvelope {
request_nr: self.request_nr,
error: Some(error),
data: Value::Null,
}
}
} }
impl Error { impl Error {
pub fn new(kind: Kind, class: Class) -> Self { pub fn new(kind: ErrorKind, class: ErrorClass, msg: Option<String>) -> Self {
Error { return Error {
kind, kind,
msg: None,
desc: None,
class, class,
msg,
desc: None,
} }
} }
pub fn class(&self) -> &Class { pub fn msg(&mut self, msg: String) {
&self.class
}
pub fn set_msg(mut self, msg: String) -> Self {
self.msg = Some(msg); self.msg = Some(msg);
self
} }
pub fn msg(&self) -> &str { pub fn code(&self) -> &str {
match &self.msg { match self.kind {
Some(msg) => msg.as_str(), ErrorKind::NotImplemented => "NOT_IMPLEMENTED",
None => match self.kind { ErrorKind::UsimpError => "USIMP_ERROR",
Kind::InvalidEndpointError => "Invalid endpoint", ErrorKind::WebSocketError => "WEBSOCKET_ERROR",
Kind::JsonParseError => "Unable to parse JSON data", ErrorKind::DatabaseError => "BACKEND_ERROR",
Kind::DatabaseConnectionError => "Unable to connect to database",
Kind::DatabaseError => "Database error",
Kind::HttpRequestParseError => "Unable to parse http request",
Kind::IoError => "IO error",
Kind::WebSocketError => "WebSocket protocol error",
Kind::NotImplementedError => "Not yet implemented",
Kind::UsimpProtocolError => "USIMP protocol error",
Kind::Utf8DecodeError => "Unable to decode UTF-8 data",
Kind::AuthenticationError => "Unable to authenticate",
Kind::InvalidSessionError => "Invalid session",
},
}
}
pub fn set_desc(mut self, desc: String) -> Self {
self.desc = Some(desc);
self
}
pub fn desc(&self) -> Option<&str> {
match &self.desc {
Some(desc) => Some(desc.as_str()),
None => None,
} }
} }
} }
impl fmt::Display for Error { impl From<Error> for OutputEnvelope {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn from(error: Error) -> Self {
let mut error = match self.kind { OutputEnvelope {
Kind::InvalidEndpointError => "invalid endpoint", error: Some(error),
Kind::JsonParseError => "unable to parse json data", data: Value::Null,
Kind::DatabaseConnectionError => "unable to connect to database", request_nr: None,
Kind::DatabaseError => "database error",
Kind::HttpRequestParseError => "unable to parse http request",
Kind::IoError => "io error",
Kind::WebSocketError => "websocket protocol error",
Kind::NotImplementedError => "not yet implemented",
Kind::UsimpProtocolError => "usimp protocol error",
Kind::Utf8DecodeError => "unable to decode utf-8 data",
Kind::AuthenticationError => "unable to authenticate",
Kind::InvalidSessionError => "invalid session",
} }
.to_string();
if let Some(desc) = &self.desc {
error += ": ";
error += desc;
}
write!(f, "{}", error)
} }
} }
impl From<std::io::Error> for Error { impl From<Error> for Value {
fn from(error: std::io::Error) -> Self { fn from(error: Error) -> Self {
let mut obj = Value::Object(Map::new());
obj["code"] = Value::from(error.code());
obj["message"] = match error.msg {
Some(msg) => Value::from(msg),
None => Value::Null,
};
obj["description"] = match error.desc {
Some(desc) => Value::from(desc),
None => Value::Null,
};
obj
}
}
impl From<hyper::header::ToStrError> for Error {
fn from(error: hyper::header::ToStrError) -> Self {
Error { Error {
kind: Kind::IoError, kind: ErrorKind::UsimpError,
msg: Some(error.to_string()), class: ErrorClass::ClientProtocolError,
msg: None,
desc: Some(error.to_string()), desc: Some(error.to_string()),
class: Class::ClientProtocolError,
} }
} }
} }
@ -122,45 +109,43 @@ impl From<std::io::Error> for Error {
impl From<serde_json::Error> for Error { impl From<serde_json::Error> for Error {
fn from(error: serde_json::Error) -> Self { fn from(error: serde_json::Error) -> Self {
Error { Error {
kind: Kind::JsonParseError, kind: ErrorKind::UsimpError,
msg: Some("Unable to parse JSON data".to_string()), class: ErrorClass::ClientProtocolError,
msg: None,
desc: Some(error.to_string()), desc: Some(error.to_string()),
class: Class::ClientProtocolError,
} }
} }
} }
impl From<r2d2::Error> for Error { impl From<hyper::Error> for Error {
fn from(error: r2d2::Error) -> Self { fn from(error: hyper::Error) -> Self {
Error { Error {
kind: Kind::DatabaseConnectionError, kind: ErrorKind::UsimpError,
msg: Some("Unable to connect to database".to_string()), class: ErrorClass::ClientProtocolError,
msg: None,
desc: Some(error.to_string()), desc: Some(error.to_string()),
class: Class::ServerError,
} }
} }
} }
impl From<r2d2_postgres::postgres::Error> for Error { impl From<tokio_postgres::Error> for Error {
fn from(error: r2d2_postgres::postgres::Error) -> Self { fn from(error: tokio_postgres::Error) -> Self {
// format: "db error: ERROR ..."
let msg = error.to_string().split(":").skip(1).collect::<String>();
Error { Error {
kind: Kind::DatabaseError, kind: ErrorKind::DatabaseError,
msg: Some("Database error".to_string()), class: ErrorClass::ServerError,
desc: Some(msg.trim().to_string()), msg: None,
class: Class::ServerError, desc: Some(error.to_string()),
} }
} }
} }
impl From<std::string::FromUtf8Error> for Error { impl From<bb8_postgres::bb8::RunError<tokio_postgres::Error>> for Error {
fn from(error: std::string::FromUtf8Error) -> Self { fn from(error: bb8_postgres::bb8::RunError<tokio_postgres::Error>) -> Self {
Error { Error {
kind: Kind::Utf8DecodeError, kind: ErrorKind::DatabaseError,
msg: Some("Unable to decode UTF-8 data".to_string()), class: ErrorClass::ServerError,
msg: None,
desc: Some(error.to_string()), desc: Some(error.to_string()),
class: Class::ClientProtocolError,
} }
} }
} }

142
src/http.rs Normal file
View File

@ -0,0 +1,142 @@
use hyper::{Request, Response, Body, StatusCode, header, Method, body};
use serde_json::{Value, Map};
use crate::websocket;
use crate::usimp::*;
use crate::error::*;
use crate::usimp;
use std::str::FromStr;
async fn endpoint_handler(req: &mut Request<Body>, endpoint: String) -> Result<Option<OutputEnvelope>, Error> {
if req.method() == Method::OPTIONS {
return Ok(None)
} else if req.method() != Method::POST {
return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None))
}
let to_domain;
if let Some(val) = req.headers().get(header::HeaderName::from_str("To-Domain").unwrap()) {
to_domain = val.to_str().unwrap().to_string()
} else {
return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None))
}
if let Some(val) = req.headers().get(header::CONTENT_TYPE) {
let parts: Vec<String> = val.to_str()?.split(';').map(|v| v.trim().to_ascii_lowercase()).collect();
let p: Vec<&str> = parts.iter().map(|v| v.as_str()).collect();
match p[0..1] {
["application/json"] => {},
_ => return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None))
}
} else {
return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None))
}
let data = serde_json::from_slice(&body::to_bytes(req.body_mut()).await?)?;
let input = InputEnvelope {
endpoint,
to_domain,
from_domain: match req.headers().get(header::HeaderName::from_str("From-Domain").unwrap()) {
Some(val) => Some(val.to_str()?.to_string()),
None => None,
},
request_nr: None,
token: match req.headers().get(header::AUTHORIZATION) {
Some(val) => {
let val = val.to_str()?;
if val.starts_with("usimp ") {
Some(val[6..].to_string())
} else {
return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None))
}
},
None => None,
},
data,
};
Ok(Some(usimp::endpoint(&input).await?))
}
pub async fn handler(mut req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
let mut res = Response::builder();
let parts: Vec<&str> = req.uri().path().split('/').skip(1).collect();
println!("{} {}", req.method(), req.uri());
let val: Result<Response<Body>, Error> = match &parts[..] {
[""] => Ok(res.status(StatusCode::OK).body(Body::from("Hello World")).unwrap()),
["_usimp"] | ["_usimp", ..] => {
res = res
.header(header::SERVER, "Locutus")
.header(header::CACHE_CONTROL, "no-store")
.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")
.header(header::ACCESS_CONTROL_MAX_AGE, 3600);
let output = match &parts[1..] {
["websocket"] => {
res = res
.header(header::ACCESS_CONTROL_ALLOW_METHODS, "GET");
let (r, val) = websocket::handler(req, res).await;
res = r;
match val {
Some(val) => Ok(Some(val)),
None => return Ok(res.body(Body::empty()).unwrap()),
}
},
[endpoint] => {
res = res
.header(header::ACCESS_CONTROL_ALLOW_METHODS, "POST, OPTIONS")
.header(header::ACCESS_CONTROL_ALLOW_HEADERS, "Content-Type, From-Domain, To-Domain, Authorization");
let endpoint = endpoint.to_string();
endpoint_handler(&mut req, endpoint).await
},
_ => Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)),
};
let output = match output {
Ok(Some(val)) => Some(val),
Ok(None) => None,
Err(error) => Some(OutputEnvelope::from(error)),
};
if let Some(output) = output {
let mut data = Value::Object(Map::new());
match output.error {
Some(error) => {
res = match error.class {
ErrorClass::ClientProtocolError => res.status(StatusCode::BAD_REQUEST),
ErrorClass::ServerError => res.status(StatusCode::INTERNAL_SERVER_ERROR),
_ => res.status(StatusCode::OK),
};
data["status"] = Value::from("error");
data["error"] = Value::from(error);
},
None => {
data["status"] = Value::from("success");
},
}
data["request_nr"] = match output.request_nr {
Some(nr) => Value::from(nr),
None => Value::Null,
};
data["data"] = output.data;
return Ok(res.body(Body::from(serde_json::to_string(&data).unwrap() + "\r\n")).unwrap())
} else {
res = res.status(StatusCode::NO_CONTENT);
}
return Ok(res.body(Body::empty()).unwrap())
},
_ => Ok(res.status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()),
};
match val {
Ok(val) => Ok(val),
Err(error) => {
todo!("help")
}
}
}

View File

@ -1,154 +0,0 @@
use super::StatusClass;
use super::StatusClass::*;
pub static HTTP_STATUSES: [(u16, StatusClass, &str, &str); 41] = [
(100, Informational, "Continue",
"The client SHOULD continue with its request."),
(101, Informational, "Switching Protocols",
"The server understands and is willing to comply with the clients request, via the Upgrade message header field, for a change in the application protocol being used on this connection."),
(200, Success, "OK",
"The request has succeeded."),
(201, Success, "Created",
"The request has been fulfilled and resulted in a new resource being created."),
(202, Success, "Accepted",
"The request has been accepted for processing, but the processing has not been completed."),
(203, Success, "Non-Authoritative Information",
"The returned meta information in the entity-header is not the definitive set as available from the origin server, but is gathered from a local or a third-party copy."),
(204, Success, "No Content",
"The server has fulfilled the request but does not need to return an entity-body, and might want to return updated meta information."),
(205, Success, "Reset Content",
"The server has fulfilled the request and the user agent SHOULD reset the document view which caused the request to be sent."),
(206, Success, "Partial Content",
"The server has fulfilled the partial GET request for the resource."),
(300, Redirection, "Multiple Choices",
"The requested resource corresponds to any one of a set of representations, each with its own specific location, and agent-driven negotiation information is being provided so that the user (or user agent) can select a preferred representation and redirect its request to that location."),
(301, Redirection, "Moved Permanently",
"The requested resource has been assigned a new permanent URI and any future references to this resource SHOULD use one of the returned URIs."),
(302, Redirection, "Found",
"The requested resource resides temporarily under a different URI."),
(303, Redirection, "See Other",
"The response to the request can be found under a different URI and SHOULD be retrieved using a GET method on that resource."),
(304, Success, "Not Modified",
"The request has been fulfilled and the requested resource has not been modified."),
(305, Redirection, "Use Proxy",
"The requested resource MUST be accessed through the proxy given by the Location field."),
(307, Redirection, "Temporary Redirect",
"The requested resource resides temporarily under a different URI."),
(308, Redirection, "Permanent Redirect",
"The requested resource has been assigned a new permanent URI and any future references to this resource ought to use one of the enclosed URIs."),
(400, ClientError, "Bad Request",
"The request could not be understood by the server due to malformed syntax."),
(401, ClientError, "Unauthorized",
"The request requires user authentication."),
(402, ClientError, "Payment Required",
""),
(403, ClientError, "Forbidden",
"The server understood the request, but is refusing to fulfill it."),
(404, ClientError, "Not Found",
"The server has not found anything matching the Request-URI."),
(405, ClientError, "Method Not Allowed",
"The method specified in the Request-Line is not allowed for the resource identified by the Request-URI."),
(406, ClientError, "Not Acceptable",
"The resource identified by the request is only capable of generating response entities which have content characteristics not acceptable according to the accept headers sent in the request."),
(407, ClientError, "Proxy Authentication Required",
"The request requires user authentication on the proxy."),
(408, ClientError, "Request Timeout",
"The client did not produce a request within the time that the server was prepared to wait."),
(409, ClientError, "Conflict",
"The request could not be completed due to a conflict with the current state of the resource."),
(410, ClientError, "Gone",
"The requested resource is no longer available at the server and no forwarding address is known."),
(411, ClientError, "Length Required",
"The server refuses to accept the request without a defined Content-Length."),
(412, ClientError, "Precondition Failed",
"The precondition given in one or more of the request-header fields evaluated to false when it was tested on the server."),
(413, ClientError, "Request Entity Too Large",
"The server is refusing to process a request because the request entity is larger than the server is willing or able to process."),
(414, ClientError, "Request-URI Too Long",
"The server is refusing to service the request because the Request-URI is longer than the server is willing to interpret."),
(415, ClientError, "Unsupported Media Type",
"The server is refusing to service the request because the entity of the request is in a format not supported by the requested resource for the requested method."),
(416, ClientError, "Range Not Satisfiable",
"None of the ranges in the requests Range header field overlap the current extent of the selected resource or that the set of ranges requested has been rejected due to invalid ranges or an excessive request of small or overlapping ranges."),
(417, ClientError, "Expectation Failed",
"The expectation given in an Expect request-header field could not be met by this server, or, if the server is a proxy, the server has unambiguous evidence that the request could not be met by the next-hop server."),
(500, ServerError, "Internal Server Error",
"The server encountered an unexpected condition which prevented it from fulfilling the request." ),
(501, ServerError, "Not Implemented",
"The server does not support the functionality required to fulfill the request."),
(502, ServerError, "Bad Gateway",
"The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."),
(503, ServerError, "Service Unavailable",
"The server is currently unable to handle the request due to a temporary overloading or maintenance of the server."),
(504, ServerError, "Gateway Timeout",
"The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI or some other auxiliary server it needed to access in attempting to complete the request."),
(505, ServerError, "HTTP Version Not Supported",
"The server does not support, or refuses to support, the HTTP protocol version that was used in the request message."),
];
pub static DEFAULT_DOCUMENT: &str = "\
<!DOCTYPE html>\n\
<html lang=\"en\">\n\
<head>\n\
\t<title>{status_code} {status_message} - Locutus - {hostname}</title>\n\
\t<meta charset=\"UTF-8\"/>\n\
\t<meta name=\"theme-color\" content=\"{theme_color}\"/>\n\
\t<meta name=\"color-scheme\" content=\"light dark\"/>\n\
\t<meta name=\"apple-mobile-web-app-status-bar-style\" content=\"black-translucent\"/>\n\
\t<meta name=\"viewport\" content=\"width=device-width,initial-scale=1.0\"/>\n\
\t<link rel=\"shortcut icon\" type=\"image/x-icon\" href=\"/favicon.ico\"/>\n\
\t<style>\n\
\t\thtml{{font-family:\"Arial\",sans-serif;--error:#C00000;--warning:#E0C000;--success:#008000;--info:#606060;--color:var(--{color_name});}}\n\
\t\tbody{{background-color:#F0F0F0;margin:0;}}\n\
\t\tmain{{max-width:650px;margin:2em auto;}}\n\
\t\tsection{{margin:1em;background-color:#FFFFFF;border: 1px solid var(--color);border-radius:4px;padding:1em;}}\n\
\t\th1,h2,h3,h4,h5,h6,h7{{text-align:center;color:var(--color);font-weight:normal;}}\n\
\t\th1{{font-size:3em;margin:0.125em 0 0.125em 0;}}\n\
\t\th2{{font-size:1.5em;margin:0.25em 0 1em 0;}}\n\
\t\tp{{text-align:center;font-size:0.875em;}}\n\
\t\tdiv.footer{{color:#808080;font-size:0.75em;text-align:center;margin:2em 0 0.5em 0;}}\n\
\t\tdiv.footer a{{color:#808080;}}\n\
\t\t@media(prefers-color-scheme:dark){{\n\
\t\t\thtml{{color:#FFFFFF;}}\n\
\t\t\tbody{{background-color:#101010;}}\n\
\t\t\tsection{{background-color:#181818;}}\n\
\t\t}}\n\
\t</style>\n\
</head>\n\
<body>\n\
\t<main>\n\
\t\t<section>\n\
{doc}\
\t\t\t<div class=\"footer\"><a href=\"https://{hostname}/\">{hostname}</a> - {server_str}</div>\n\
\t\t</section>\n\
\t</main>\n\
</body>\n\
</html>\n";
pub static ERROR_DOCUMENT: &str = "\
\t\t\t<h1>{code}</h1>\n\
\t\t\t<h2>{message} :&#xFEFF;(</h2>\n\
\t\t\t<p>{desc}</p>\n\
\t\t\t<p>{info}</p>\n";
pub static WARNING_DOCUMENT: &str = "\
\t\t\t<h1>{code}</h1>\n\
\t\t\t<h2>{message} :&#xFEFF;o</h2>\n\
\t\t\t<p>{desc}</p>\n\
\t\t\t<p>{info}</p>\n";
pub static SUCCESS_DOCUMENT: &str = "\
\t\t\t<h1>{code}</h1>\n\
\t\t\t<h2>{message} :&#xFEFF;)</h2>\n\
\t\t\t<p>{desc}</p>\n\
\t\t\t<p>{info}</p>\n";
pub static INFO_DOCUMENT: &str = "\
\t\t\t<h1>{code}</h1>\n\
\t\t\t<h2>{message} :&#xFEFF;)</h2>\n\
\t\t\t<p>{desc}</p>\n\
\t\t\t<p>{info}</p>\n";

View File

@ -1,239 +0,0 @@
use super::Method;
use crate::error::*;
use crate::usimp;
use crate::websocket;
use serde_json;
use crate::usimp::Envelope;
pub fn connection_handler(client: super::Stream) {
let mut client = super::HttpStream {
stream: client,
request_num: 0,
client_keep_alive: true,
server_keep_alive: true,
};
while client.request_num < super::REQUESTS_PER_CONNECTION
&& client.client_keep_alive
&& client.server_keep_alive
{
request_handler(&mut client);
client.request_num += 1;
}
}
fn request_handler(client: &mut super::HttpStream) {
let mut res = super::Response::new();
match super::parser::parse_request(&mut client.stream) {
Ok(Some(req)) => {
println!("{} {}", req.method, req.uri);
client.client_keep_alive =
client.client_keep_alive && req.header.field_has_value("Connection", "keep-alive");
if !req.uri.starts_with("/")
|| req.uri.contains("/./")
|| req.uri.contains("/../")
|| req.uri.ends_with("/..")
{
res.status(400);
} else if req.uri.contains("/.") {
res.status(404);
} else if req.uri.eq("/") {
res.status(200);
} else if req.uri.starts_with("/_usimp/") {
res.header.add_field("Cache-Control", "no-store");
res.header.add_field("Access-Control-Allow-Origin", "*");
res.header.add_field("Access-Control-Allow-Methods", "POST, OPTIONS");
res.header.add_field("Access-Control-Allow-Headers", "Content-Type, From-Domain, To-Domain, Authorization");
res.header.add_field("Access-Control-Max-Age", "3600");
if req.uri.eq("/_usimp/websocket") {
return websocket::connection_handler(client, &req, res);
}
// TODO check Content-Type == application/json
let mut error = None;
let parts: Vec<&str> = req.uri.split('/').collect();
match parts[2..] {
["entity", entity] => {
res.status(501);
error = Some(Error::new(Kind::NotImplementedError, Class::ServerError))
},
[endpoint] => match req.method {
Method::POST => {
return endpoint_handler(client, &req, res, endpoint)
},
Method::OPTIONS => {
res.status(204);
client.respond(&mut res);
return
}
_ => {
res.status(405);
res.header.add_field("Allow", "POST");
error = Some(Error::new(Kind::UsimpProtocolError, Class::ClientProtocolError))
}
},
_ => error = Some(Error::new(Kind::InvalidEndpointError, Class::ClientProtocolError)),
}
if let Some(error) = error {
error_handler(client, res, error);
}
return;
} else {
res.status(404);
}
}
Ok(None) => {
client.client_keep_alive = false;
return;
}
Err(e) => {
res.status(400);
res.error_info(format!("{}", &e));
println!("{}", &e);
client.server_keep_alive = false;
}
}
if let Err(e) = client.respond_default(&mut res) {
println!("Unable to send: {}", e);
client.server_keep_alive = false;
}
client.server_keep_alive = false;
}
pub fn error_handler(client: &mut super::HttpStream, mut res: super::Response, error: Error) {
println!("{}", error.to_string());
match &error.class() {
Class::ClientProtocolError => {
if res.status.code < 400 || res.status.code >= 499 {
res.status(400)
}
},
Class::ClientError => {
if res.status.code < 200 || res.status.code >= 299 {
res.status(200)
}
}
Class::ServerError => {
if res.status.code < 500 || res.status.code > 599 {
res.status(500)
}
}
}
res.error_info(error.to_string());
let mut obj = serde_json::Value::Object(serde_json::Map::new());
obj["status"] = serde_json::Value::String("error".to_string());
obj["message"] = serde_json::Value::String(error.to_string());
obj["data"] = serde_json::Value::Null;
let buf = obj.to_string() + "\r\n";
let length = buf.as_bytes().len();
res.header
.add_field("Content-Length", length.to_string().as_str());
res.header
.add_field("Content-Type", "application/json; charset=utf-8");
if let Err(e) = client.respond(&mut res) {
println!("Unable to send: {}", e);
client.server_keep_alive = false;
}
client.stream.write_all(buf.as_bytes()).unwrap();
}
fn endpoint_handler(
client: &mut super::HttpStream,
req: &super::Request,
mut res: super::Response,
endpoint: &str,
) {
let length = req.header.find_field("Content-Length");
let length: usize = match match length {
Some(length) => length,
None => {
return error_handler(
client,
res,
Error::new(Kind::HttpRequestParseError, Class::ClientProtocolError)
.set_desc("field 'Content-Length' missing".to_string()),
)
}
}
.parse()
{
Ok(length) => length,
Err(e) => {
return error_handler(
client,
res,
Error::new(Kind::HttpRequestParseError, Class::ClientProtocolError).set_desc(
format!("unable to parse field 'Content-Length': {}", &e).to_string(),
),
)
}
};
let mut buf = [0; 8192];
client.stream.read_exact(&mut buf[..length]).unwrap();
// TODO decompress
let data = match serde_json::from_slice(&buf[..length]) {
Ok(val) => val,
Err(e) => return error_handler(client, res, e.into()),
};
let mut authorization = None;
if let Some(auth) = req.header.find_field("Authorization") {
// TODO check usimp prefix in Authorization
authorization = Some(auth.split(" ").skip(1).collect());
}
let mut from_domain = None;
if let Some(from) = req.header.find_field("From-Domain") {
from_domain = Some(from.to_string());
}
let mut to_domain;
if let Some(to) = req.header.find_field("To-Domain") {
to_domain = to.to_string();
} else {
return error_handler(
client,
res,
Error::new(Kind::UsimpProtocolError, Class::ClientProtocolError)
.set_desc("Unable to find field 'To-Domain'".to_string())
);
}
let input = Envelope {
endpoint: endpoint.to_string(),
from_domain,
to_domain,
token: authorization,
data,
};
let buf = match usimp::endpoint(input) {
Ok(output) => output.to_string() + "\r\n",
Err(e) => return error_handler(client, res, e),
};
// TODO compress
let length = buf.as_bytes().len();
res.header
.add_field("Content-Length", length.to_string().as_str());
res.header
.add_field("Content-Type", "application/json; charset=utf-8");
res.status(200);
client.respond(&mut res).unwrap();
client.stream.write_all(buf.as_bytes()).unwrap();
}

View File

@ -1,394 +0,0 @@
mod consts;
mod handler;
mod parser;
use openssl::ssl::SslStream;
use std::fmt::Formatter;
use std::io::{Read, Write};
use std::net::TcpStream;
pub use handler::*;
static REQUESTS_PER_CONNECTION: u32 = 200;
pub enum Stream {
Tcp(TcpStream),
Ssl(SslStream<TcpStream>),
}
pub struct HttpStream {
pub stream: Stream,
pub request_num: u32,
pub client_keep_alive: bool,
pub server_keep_alive: bool,
}
pub enum Method {
GET,
POST,
PUT,
HEAD,
TRACE,
CONNECT,
DELETE,
OPTIONS,
Custom(String),
}
#[derive(Copy, Clone)]
pub enum StatusClass {
Informational,
Success,
Redirection,
ClientError,
ServerError,
}
#[derive(Clone)]
pub struct Status {
code: u16,
message: String,
desc: &'static str,
class: StatusClass,
info: Option<String>,
}
pub struct HeaderField {
name: String,
value: String,
}
pub struct Header {
fields: Vec<HeaderField>,
}
pub struct Request {
version: String,
pub method: Method,
pub uri: String,
pub header: Header,
}
pub struct Response {
version: String,
status: Status,
pub header: Header,
}
impl Method {
pub fn from_str(v: &str) -> Method {
match v {
"GET" => Method::GET,
"POST" => Method::POST,
"PUT" => Method::PUT,
"HEAD" => Method::HEAD,
"TRACE" => Method::TRACE,
"CONNECT" => Method::CONNECT,
"DELETE" => Method::DELETE,
"OPTIONS" => Method::OPTIONS,
_ => Method::Custom(String::from(v)),
}
}
pub fn to_str(&self) -> &str {
match self {
Method::GET => "GET",
Method::POST => "POST",
Method::PUT => "PUT",
Method::HEAD => "HEAD",
Method::TRACE => "TRACE",
Method::CONNECT => "CONNECT",
Method::DELETE => "DELETE",
Method::OPTIONS => "OPTIONS",
Method::Custom(v) => v,
}
}
}
impl std::fmt::Display for Method {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.to_str())
}
}
impl StatusClass {
pub fn from_code(status_code: u16) -> StatusClass {
for (code, class, _msg, _desc) in &consts::HTTP_STATUSES {
if *code == status_code {
return class.clone();
}
}
match status_code {
100..=199 => StatusClass::Informational,
200..=299 => StatusClass::Success,
300..=399 => StatusClass::Redirection,
400..=499 => StatusClass::ClientError,
500..=599 => StatusClass::ServerError,
_ => panic!("invalid status code"),
}
}
}
impl Status {
pub fn from_code(status_code: u16) -> Option<Status> {
for (code, class, msg, desc) in &consts::HTTP_STATUSES {
if *code == status_code {
return Some(Status {
code: status_code,
message: msg.to_string(),
desc,
class: class.clone(),
info: None,
});
}
}
None
}
pub fn new_custom(status_code: u16, message: &str) -> Status {
if status_code < 100 || status_code > 599 {
panic!("invalid status code");
}
if let Some(status) = Status::from_code(status_code) {
Status {
code: status_code,
message: message.to_string(),
desc: status.desc,
class: status.class,
info: None,
}
} else {
Status {
code: status_code,
message: message.to_string(),
desc: "",
class: StatusClass::from_code(status_code),
info: None,
}
}
}
}
impl std::fmt::Display for HeaderField {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "[{}: {}]", self.name, self.value)
}
}
impl Header {
pub fn new() -> Self {
Header { fields: Vec::new() }
}
pub fn from(fields: Vec<HeaderField>) -> Self {
Header { fields }
}
pub fn find_field(&self, field_name: &str) -> Option<&str> {
let field_name = field_name.to_lowercase();
for field in &self.fields {
if field.name.to_lowercase().eq(field_name.as_str()) {
return Some(field.value.as_str());
}
}
return None;
}
pub fn add_field(&mut self, name: &str, value: &str) {
self.fields.push(HeaderField {
name: String::from(name),
value: String::from(value),
})
}
pub fn field_has_value(&self, field_name: &str, value: &str) -> bool {
if let Some(field) = self.find_field(field_name) {
let value = value.to_lowercase();
field
.to_lowercase()
.split(",")
.any(|mut s| s.trim().eq(value.as_str()))
} else {
false
}
}
}
impl Response {
pub fn new() -> Response {
let mut res = Response {
version: "1.1".to_string(),
status: Status::from_code(200).unwrap(),
header: Header::new(),
};
res.header.add_field("Server", "Locutus");
res.header.add_field(
"Date",
chrono::Utc::now()
.format("%a, %d %b %Y %H:%M:%S GMT")
.to_string()
.as_str(),
);
res
}
pub fn status(&mut self, status_code: u16) {
self.status = Status::from_code(status_code).unwrap()
}
pub fn error_info(&mut self, info: String) {
self.status.info = Some(info);
}
pub fn send(&mut self, stream: &mut Stream) -> Result<(), std::io::Error> {
let mut header = format!(
"HTTP/{} {:03} {}\r\n",
self.version, self.status.code, self.status.message
);
for header_field in &self.header.fields {
header.push_str(format!("{}: {}\r\n", header_field.name, header_field.value).as_str());
}
header.push_str("\r\n");
stream.write_all(header.as_bytes())?;
Ok(())
}
pub fn send_default(&mut self, stream: &mut Stream) -> Result<(), std::io::Error> {
let mut buf = None;
if let None = self.header.find_field("Content-Length") {
let new_buf = self.format_default_response();
self.header.add_field(
"Content-Length",
new_buf.as_bytes().len().to_string().as_str(),
);
self.header
.add_field("Content-Type", "text/html; charset=utf-8");
buf = Some(new_buf);
}
self.send(stream)?;
if let Some(buf) = buf {
stream.write_all(buf.as_bytes())?;
}
Ok(())
}
fn format_default_response(&self) -> String {
let (doc, color_name, color) = match self.status.class {
StatusClass::Informational => (consts::INFO_DOCUMENT, "info", "#606060"),
StatusClass::Success => (consts::SUCCESS_DOCUMENT, "success", "#008000"),
StatusClass::Redirection => (consts::WARNING_DOCUMENT, "warning", "#E0C000"),
StatusClass::ClientError => (consts::ERROR_DOCUMENT, "error", "#C00000"),
StatusClass::ServerError => (consts::ERROR_DOCUMENT, "error", "#C00000"),
};
consts::DEFAULT_DOCUMENT
.replace("{status_code}", self.status.code.to_string().as_str())
.replace("{status_message}", self.status.message.as_str())
.replace("{hostname}", "localhost") // TODO hostname
.replace("{theme_color}", color)
.replace("{color_name}", color_name)
.replace("{server_str}", "Locutus server") // TODO server string
.replace(
"{doc}",
doc.replace("{code}", self.status.code.to_string().as_str())
.replace("{message}", self.status.message.as_str())
.replace("{desc}", self.status.desc)
.replace(
"{info}",
self.status.info.as_ref().unwrap_or(&String::new()).as_str(),
)
.as_str(),
)
.replace("{{", "{")
.replace("}}", "}")
}
}
impl HttpStream {
pub fn respond(&mut self, res: &mut Response) -> Result<(), std::io::Error> {
self.keep_alive(res);
res.send(&mut self.stream)
}
pub fn respond_default(&mut self, res: &mut Response) -> Result<(), std::io::Error> {
self.keep_alive(res);
res.send_default(&mut self.stream)
}
fn keep_alive(&mut self, res: &mut Response) {
if self.client_keep_alive && self.server_keep_alive {
res.header.add_field("Connection", "keep-alive");
res.header.add_field("Keep-Alive", "timeout=3600, max=200");
}
}
}
impl Stream {
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
match self {
Stream::Tcp(stream) => stream.read(buf),
Stream::Ssl(stream) => loop {
match stream.ssl_read(buf) {
Ok(n) => return Ok(n),
Err(ref e) if e.code() == openssl::ssl::ErrorCode::ZERO_RETURN => return Ok(0),
Err(ref e)
if e.code() == openssl::ssl::ErrorCode::SYSCALL
&& e.io_error().is_none() =>
{
return Ok(0);
}
Err(ref e)
if e.code() == openssl::ssl::ErrorCode::WANT_READ
&& e.io_error().is_none() => {}
Err(e) => {
return Err(e.into_io_error().unwrap_or_else(|e| {
std::io::Error::new(std::io::ErrorKind::Other, e)
}));
}
}
},
}
}
pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), std::io::Error> {
match self {
Stream::Tcp(stream) => stream.read_exact(buf),
Stream::Ssl(stream) => stream.read_exact(buf),
}
}
pub fn peek(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
match self {
Stream::Tcp(stream) => stream.peek(buf),
Stream::Ssl(_stream) => todo!("Not implemented in rust-openssl"),
}
}
pub fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
match self {
Stream::Tcp(stream) => stream.write(buf),
Stream::Ssl(stream) => loop {
match stream.ssl_write(buf) {
Ok(n) => return Ok(n),
Err(ref e)
if e.code() == openssl::ssl::ErrorCode::WANT_READ
&& e.io_error().is_none() => {}
Err(e) => {
return Err(e.into_io_error().unwrap_or_else(|e| {
std::io::Error::new(std::io::ErrorKind::Other, e)
}));
}
}
},
}
}
pub fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> {
match self {
Stream::Tcp(stream) => stream.write_all(buf),
Stream::Ssl(stream) => stream.write_all(buf),
}
}
}

View File

@ -1,374 +0,0 @@
use crate::error::*;
use crate::http;
use crate::http::Status;
pub fn parse_request(stream: &mut http::Stream) -> Result<Option<http::Request>, Error> {
let mut buf = [0; 4096];
let size = stream.peek(&mut buf)?;
if size == 0 {
return Ok(None);
}
let mut parser = Parser::new_request_parser(&buf[..size]);
let header_size = parser.parse()?;
let mut header_fields = Vec::new();
for (name, value) in parser.headers {
header_fields.push(http::HeaderField {
name: String::from(name),
value: String::from(value),
});
}
let request = http::Request {
version: String::from(parser.http_version.unwrap()),
method: http::Method::from_str(parser.method.unwrap()),
uri: String::from(parser.uri.unwrap()),
header: http::Header::from(header_fields),
};
stream.read_exact(&mut buf[..header_size])?;
Ok(Some(request))
}
pub fn parse_response(stream: &mut http::Stream) -> Result<http::Response, Error> {
let mut buf = [0; 4096];
let size = stream.peek(&mut buf)?;
let mut parser = Parser::new_request_parser(&buf[..size]);
let header_size = parser.parse()?;
let status_code = parser.status_code.unwrap();
let status_code = match status_code.parse::<u16>() {
Ok(v) => v,
Err(error) => {
return Err(Error::new(Kind::HttpRequestParseError, Class::ClientProtocolError)
.set_desc(error.to_string()))
}
};
let mut header_fields = Vec::new();
for (name, value) in parser.headers {
header_fields.push(http::HeaderField {
name: String::from(name),
value: String::from(value),
});
}
let response = http::Response {
version: String::from(parser.http_version.unwrap()),
status: Status::new_custom(status_code, parser.status_message.unwrap()),
header: http::Header::from(header_fields),
};
stream.read_exact(&mut buf[..header_size])?;
Ok(response)
}
#[derive(Copy, Clone)]
enum State<'a> {
Method,
Uri,
Http(&'a State<'a>),
HttpVersion(&'a State<'a>),
StatusCode,
StatusMessage,
HeaderName,
HeaderValue,
Finish,
CRLF(&'a State<'a>),
Error,
}
struct Parser<'a> {
state: State<'a>,
buf: &'a [u8],
str_start: usize,
header_size: usize,
method: Option<&'a str>,
uri: Option<&'a str>,
http_version: Option<&'a str>,
status_code: Option<&'a str>,
status_message: Option<&'a str>,
headers: Vec<(&'a str, &'a str)>,
}
impl Parser<'_> {
fn new_request_parser(buf: &[u8]) -> Parser {
Parser {
state: State::Method,
buf,
str_start: 0,
header_size: 0,
method: None,
uri: None,
http_version: None,
status_code: None,
status_message: None,
headers: Vec::new(),
}
}
fn new_response_parser(buf: &[u8]) -> Parser {
Parser {
state: State::Http(&State::StatusCode),
buf,
str_start: 0,
header_size: 0,
method: None,
uri: None,
http_version: None,
status_code: None,
status_message: None,
headers: Vec::new(),
}
}
fn parse(&mut self) -> Result<usize, Error> {
for char in self.buf {
self.next(*char);
match self.state {
State::Finish => return Ok(self.header_size),
State::Error => {
return Err(Error::new(Kind::HttpRequestParseError, Class::ClientProtocolError)
.set_desc(format!(
"invalid character at position {}",
self.header_size - 1
)))
}
_ => {}
}
}
return Err(Error::new(Kind::HttpRequestParseError, Class::ClientProtocolError)
.set_desc("input too short".to_string()));
}
fn next(&mut self, char: u8) {
self.header_size += 1;
let get_str =
|| std::str::from_utf8(&self.buf[self.str_start..self.header_size - 1]).unwrap();
self.state = match &self.state {
State::Error => State::Error,
State::Finish => State::Error,
State::Method => match char {
0x41..=0x5A => State::Method,
0x20 => {
self.method = Some(get_str());
self.str_start = self.header_size;
State::Uri
}
_ => State::Error,
},
State::Uri => match char {
0x21..=0x7E => State::Uri,
0x20 => {
self.uri = Some(get_str());
self.str_start = self.header_size;
State::Http(&State::HeaderName)
}
_ => State::Error,
},
State::Http(next) => match char {
0x48 | 0x54 | 0x50 => State::Http(next),
0x2F => {
let http = get_str();
self.str_start = self.header_size;
if http != "HTTP" {
State::Error
} else {
State::HttpVersion(next)
}
}
_ => State::Error,
},
State::HttpVersion(next) => match char {
0x30..=0x39 | 0x2E => State::HttpVersion(next),
0x0D => match next {
State::HeaderName => {
self.http_version = Some(get_str());
State::CRLF(next)
}
_ => State::Error,
},
0x20 => match next {
State::StatusCode => {
self.http_version = Some(get_str());
self.str_start = self.header_size;
State::StatusCode
}
_ => State::Error,
},
_ => State::Error,
},
State::StatusCode => match char {
0x30..=0x39 => State::StatusCode,
0x20 => {
self.status_code = Some(get_str());
self.str_start = self.header_size;
State::StatusMessage
}
_ => State::Error,
},
State::StatusMessage => match char {
0x20..=0x7E => State::StatusMessage,
0x0D => {
self.status_message = Some(get_str());
State::CRLF(&State::HeaderName)
}
_ => State::Error,
},
State::HeaderName => match char {
0x0D => {
if self.header_size == self.str_start + 1 {
State::CRLF(&State::Finish)
} else {
State::Error
}
}
0x3A => {
let header_name = get_str();
self.headers.push((header_name, ""));
self.str_start = self.header_size;
State::HeaderValue
}
0x00..=0x1F
| 0x7F
| 0x80..=0xFF
| 0x20
| 0x28
| 0x29
| 0x2C
| 0x2F
| 0x3A..=0x40
| 0x5B..=0x5D
| 0x7B
| 0x7D => State::Error,
_ => State::HeaderName,
},
State::HeaderValue => match char {
0x20..=0x7E | 0x09 => State::HeaderValue,
0x0D => {
self.headers.last_mut().unwrap().1 = get_str().trim();
State::CRLF(&State::HeaderName)
}
_ => State::Error,
},
State::CRLF(next) => match char {
0x0A => {
self.str_start = self.header_size;
*next.clone()
}
_ => State::Error,
},
}
}
}
#[cfg(test)]
mod tests {
#[test]
fn simple_request() {
let request: &str = "GET /index.html HTTP/1.1\r\n\
Host: www.example.com\r\n\
\r\n";
let mut parser = super::Parser::new_request_parser(request.as_bytes());
let size = parser.parse().unwrap();
assert_eq!(51, size);
assert_eq!("GET", parser.method.unwrap());
assert_eq!("/index.html", parser.uri.unwrap());
assert_eq!("1.1", parser.http_version.unwrap());
assert_eq!(None, parser.status_code);
assert_eq!(None, parser.status_message);
assert_eq!(1, parser.headers.len());
assert_eq!(("Host", "www.example.com"), parser.headers[0]);
}
#[test]
fn complex_request() {
let request: &str = "POST /upload/file.txt HTTP/1.3\r\n\
Host: www.example.com \r\n\
Content-Length: 13 \r\n\
User-Agent: Mozilla/5.0 (X11; Linux x86_64) \r\n\
\r\n\
username=test";
let mut parser = super::Parser::new_request_parser(request.as_bytes());
let size = parser.parse().unwrap();
assert_eq!(129, size);
assert_eq!("POST", parser.method.unwrap());
assert_eq!("/upload/file.txt", parser.uri.unwrap());
assert_eq!("1.3", parser.http_version.unwrap());
assert_eq!(None, parser.status_code);
assert_eq!(None, parser.status_message);
assert_eq!(3, parser.headers.len());
assert_eq!(("Host", "www.example.com"), parser.headers[0]);
assert_eq!(("Content-Length", "13"), parser.headers[1]);
assert_eq!(
("User-Agent", "Mozilla/5.0 (X11; Linux x86_64)"),
parser.headers[2]
);
assert_eq!("username=test", &request[size..]);
}
#[test]
fn invalid_request_1() {
let request: &str = "GET /files/größe.txt HTTP/1.1\r\n\r\n";
let mut parser = super::Parser::new_request_parser(request.as_bytes());
match parser.parse() {
Ok(_v) => panic!("should fail"),
Err(e) => assert_eq!(
"unable to parse http request: invalid character at position 13",
e.to_string()
),
}
}
#[test]
fn invalid_request_2() {
let request: &str = "GET /index.html HTT";
let mut parser = super::Parser::new_request_parser(request.as_bytes());
match parser.parse() {
Ok(_v) => panic!("should fail"),
Err(e) => assert_eq!(
"unable to parse http request: input too short",
e.to_string()
),
}
}
#[test]
fn simple_response() {
let response: &str = "HTTP/1.1 200 OK\r\n\
Content-Length: 12\r\n\
Content-Type: text/plain; charset=us-ascii\r\n\
\r\n\
Hello world!";
let mut parser = super::Parser::new_response_parser(response.as_bytes());
let size = parser.parse().unwrap();
assert_eq!(83, size);
assert_eq!("200", parser.status_code.unwrap());
assert_eq!("OK", parser.status_message.unwrap());
assert_eq!("1.1", parser.http_version.unwrap());
assert_eq!(None, parser.method);
assert_eq!(None, parser.uri);
assert_eq!(2, parser.headers.len());
assert_eq!(("Content-Length", "12"), parser.headers[0]);
assert_eq!(
("Content-Type", "text/plain; charset=us-ascii"),
parser.headers[1]
);
assert_eq!("Hello world!", &response[size..]);
}
}

View File

@ -1,149 +1,112 @@
use std::net::{SocketAddr, TcpListener, UdpSocket}; use std::fmt;
use std::sync::{Arc, Mutex}; use std::io::Read;
use std::thread; use std::net;
use std::net::SocketAddr;
use std::pin::Pin;
use ansi_term::Color; use error::*;
use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; use ansi_term::{Color, Style};
use rusty_pool; use futures_util::{future::TryFutureExt, stream::Stream};
use std::fmt::Formatter; use hyper::Server;
use std::time::Duration; use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
mod http;
mod websocket;
mod usimp;
mod database; mod database;
mod error; mod error;
mod http;
mod subscription;
mod udp;
mod usimp;
mod websocket;
enum SocketType { struct HyperAcceptor<'a> {
Http, acceptor: Pin<Box<dyn Stream<Item = Result<tokio_rustls::server::TlsStream<tokio::net::TcpStream>, std::io::Error>> + 'a>>,
Https,
Udp,
} }
impl std::fmt::Display for SocketType { impl hyper::server::accept::Accept for HyperAcceptor<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { type Conn = tokio_rustls::server::TlsStream<tokio::net::TcpStream>;
write!( type Error = std::io::Error;
f,
"{}", fn poll_accept(
match self { mut self: Pin<&mut Self>,
SocketType::Http => "http+ws", cx: &mut core::task::Context,
SocketType::Https => "https+wss", ) -> core::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
SocketType::Udp => "udp", Pin::new(&mut self.acceptor).poll_next(cx)
}
)
} }
} }
struct SocketConfig { fn load_certs(filename: &str) -> std::io::Result<Vec<rustls::Certificate>> {
address: SocketAddr, let certfile = std::fs::File::open(filename)
socket_type: SocketType, .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?;
let mut reader = std::io::BufReader::new(certfile);
rustls::internal::pemfile::certs(&mut reader).map_err(|_| error("failed to load certificate".into()))
} }
fn main() { fn load_private_key(filename: &str) -> std::io::Result<rustls::PrivateKey> {
let keyfile = std::fs::File::open(filename)
.map_err(|e| error(format!("failed to open {}: {}", filename, e)))?;
let mut reader = std::io::BufReader::new(keyfile);
let keys = rustls::internal::pemfile::rsa_private_keys(&mut reader)
.map_err(|_| error("failed to load private key".into()))?;
if keys.len() < 1 {
return Err(error("expected a single private key".into()));
}
Ok(keys[0].clone())
}
fn error(err: String) -> std::io::Error {
std::io::Error::new(std::io::ErrorKind::Other, err)
}
#[tokio::main]
async fn main() -> Result<(), Error> {
println!("Locutus server"); println!("Locutus server");
let socket_configs: Vec<SocketConfig> = vec![ database::init().await?;
SocketConfig {
address: "[::]:8080".parse().unwrap(),
socket_type: SocketType::Http,
},
SocketConfig {
address: "[::]:8443".parse().unwrap(),
socket_type: SocketType::Https,
},
SocketConfig {
address: "[::]:3126".parse().unwrap(),
socket_type: SocketType::Udp,
},
];
subscription::init(); let server1 = Server::bind(&SocketAddr::from(([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], 8080)));
let service = make_service_fn(|_: &AddrStream| async {
// Note: rust's stdout is line buffered! Ok::<_, hyper::Error>(service_fn(http::handler))
eprint!("Initializing database connection pool...");
if let Err(error) = database::init() {
eprintln!("\n{}", Color::Red.bold().paint(error.to_string()));
std::process::exit(1);
}
eprintln!(" {}", Color::Green.paint("success"));
let thread_pool = rusty_pool::Builder::new()
.core_size(4)
.max_size(1024)
.keep_alive(Duration::from_secs(60 * 60))
.build();
let thread_pool_mutex = Arc::new(Mutex::new(thread_pool));
let mut threads = Vec::new();
for socket_config in socket_configs {
let thread_pool_mutex = thread_pool_mutex.clone();
eprintln!(
"Creating listening thread for {} ({})",
ansi_term::Style::new()
.bold()
.paint(socket_config.address.to_string()),
socket_config.socket_type
);
threads.push(match socket_config.socket_type {
SocketType::Http => thread::spawn(move || {
let mut tcp_socket = TcpListener::bind(socket_config.address).unwrap();
for stream in tcp_socket.incoming() {
thread_pool_mutex.lock().unwrap().execute(|| {
let stream = stream.unwrap();
http::connection_handler(http::Stream::Tcp(stream));
}); });
} let srv1 = server1.serve(service);
}),
SocketType::Https => thread::spawn(move || {
let mut ssl_socket = TcpListener::bind(socket_config.address).unwrap();
let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); let tls_cfg = {
acceptor let certs = load_certs("/home/lorenz/Certificates/priv/fullchain.pem").unwrap();
.set_certificate_chain_file("/home/lorenz/Certificates/chakotay.pem") let key = load_private_key("/home/lorenz/Certificates/priv/privkey.pem").unwrap();
.unwrap(); let mut cfg = rustls::ServerConfig::new(rustls::NoClientAuth::new());
acceptor cfg.set_single_cert(certs, key).unwrap();
.set_private_key_file( cfg.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
"/home/lorenz/Certificates/priv/chakotay.key", std::sync::Arc::new(cfg)
SslFiletype::PEM, };
)
.unwrap();
acceptor.check_private_key().unwrap();
let acceptor = Arc::new(acceptor.build());
for stream in ssl_socket.incoming() {
let acceptor = acceptor.clone();
thread_pool_mutex.lock().unwrap().execute(move || {
let stream = stream.unwrap();
let stream = acceptor.accept(stream).unwrap();
http::connection_handler(http::Stream::Ssl(stream));
});
}
}),
SocketType::Udp => thread::spawn(move || {
let mut udp_socket = UdpSocket::bind(socket_config.address).unwrap();
let mut buf = [0; 65_536];
let acceptor = tokio_rustls::TlsAcceptor::from(tls_cfg);
let tcp = Box::pin(tokio::net::TcpListener::bind("[::]:8443").await.unwrap());
let incoming_tls_stream = async_stream::stream! {
loop { loop {
let (size, addr) = udp_socket.recv_from(&mut buf).unwrap(); let (socket, _) = tcp.accept().await.unwrap();
let req = udp::Request::new(&udp_socket, addr, size, &buf); let stream = acceptor.accept(socket).map_err(|e| {
thread_pool_mutex println!("[!] Voluntary server halt due to client-connection error...");
.lock() // Errors could be handled here, instead of server aborting.
.unwrap() //Ok(None)
.execute(|| udp::handler(req)); //println!("{:?}", e);
} error(format!("TLS Error: {:?}", e))
}),
}); });
yield stream.await;
} }
};
let server2 = Server::builder(HyperAcceptor {
acceptor: Box::pin(incoming_tls_stream),
});
let service = make_service_fn(|_| async {
Ok::<_, hyper::Error>(service_fn(http::handler))
});
let srv2 = server2.serve(service);
println!("{}", Color::Green.paint("Ready")); println!("{}", Color::Green.paint("Ready"));
for thread in threads { let (_res1, _res2) = futures::future::join(srv1, srv2).await;
thread.join().unwrap();
} Ok(())
} }

View File

@ -1,34 +0,0 @@
use std::sync::{Arc, Mutex, mpsc};
use serde::{Deserialize, Serialize};
use serde_json;
static mut SUBSCRIPTIONS: Option<Arc<Mutex<Vec<mpsc::Sender<Event>>>>> = None;
#[derive(Clone, Serialize, Deserialize)]
pub struct Event{
pub data: serde_json::Value,
}
pub fn init() {
unsafe {
SUBSCRIPTIONS = Some(Arc::new(Mutex::new(Vec::new())));
}
}
pub fn subscribe() -> mpsc::Receiver<Event> {
let (rx, tx) = mpsc::channel();
unsafe { SUBSCRIPTIONS.as_ref().unwrap().lock().unwrap().push(rx); }
tx
}
pub fn unsubscribe(rx: mpsc::Receiver<Event>) {
// TODO implement unsubscribe
}
pub fn notify(event: Event) {
for sender in unsafe { SUBSCRIPTIONS.as_ref().unwrap().lock().unwrap().clone() } {
sender.send(event.clone());
}
}

View File

@ -1,28 +0,0 @@
use std::net::{SocketAddr, UdpSocket};
pub struct Request {
socket: UdpSocket,
address: SocketAddr,
size: usize,
buf: [u8; 65_536],
}
impl Request {
pub fn new(
socket: &UdpSocket,
address: SocketAddr,
size: usize,
buf: &[u8; 65_536],
) -> Request {
Request {
socket: socket.try_clone().unwrap(),
address,
size,
buf: buf.clone(),
}
}
}
pub fn handler(request: Request) {
// TODO handle UDP requests
}

View File

@ -0,0 +1,34 @@
use crate::usimp::*;
use crate::database;
use serde_json::{Value, from_value, to_value};
use serde::{Serialize, Deserialize};
use std::ops::Deref;
#[derive(Serialize, Deserialize, Clone)]
struct Input {
name: String,
password: String,
}
#[derive(Serialize, Deserialize, Clone)]
struct Output {
session: String,
token: String,
}
pub async fn handle(input: &InputEnvelope, session: &Session) -> Result<Value, Error> {
Ok(to_value(authenticate(from_value(input.data.clone())?).await?)?)
}
async fn authenticate(input: Input) -> Result<Output, Error> {
match database::client().await? {
database::Client::Postgres(client) => {
client.execute("SELECT * FROM asdf;", &[]).await?;
}
}
Ok(Output {
session: "".to_string(),
token: "".to_string(),
})
}

20
src/usimp/handler/mod.rs Normal file
View File

@ -0,0 +1,20 @@
mod ping;
mod authenticate;
use crate::usimp::*;
pub async fn endpoint(input: &InputEnvelope) -> Result<OutputEnvelope, Error> {
println!("Endpoint: {}", input.endpoint);
let session= Session {
account: None,
id: "".to_string(),
nr: 0,
};
Ok(match input.endpoint.as_str() {
"ping" => input.respond(ping::handle(&input, &session).await?),
"authenticate" => input.respond(authenticate::handle(&input, &session).await?),
_ => input.new_error(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, Some("Invalid endpoint".to_string())),
})
}

11
src/usimp/handler/ping.rs Normal file
View File

@ -0,0 +1,11 @@
use crate::usimp::*;
use serde_json::Value;
pub async fn handle(input: &InputEnvelope, session: &Session) -> Result<Value, Error> {
ping(&input.data).await
}
async fn ping(input: &Value) -> Result<Value, Error> {
Ok(input.clone())
}

View File

@ -1,25 +1,25 @@
use serde::{Deserialize, Serialize}; mod handler;
use serde_json;
use crate::subscription; pub use handler::endpoint;
use crate::database;
use crate::error::*;
use crypto::digest::Digest;
use rand;
use rand::Rng;
pub struct Envelope { use serde_json::Value;
use crate::error::{Error, ErrorClass, ErrorKind};
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize)]
pub struct InputEnvelope {
pub endpoint: String, pub endpoint: String,
pub from_domain: Option<String>, pub from_domain: Option<String>,
pub to_domain: String, pub to_domain: String,
pub token: Option<String>, pub token: Option<String>,
pub data: serde_json::Value, pub request_nr: Option<u64>,
pub data: Value,
} }
pub struct Account { pub struct OutputEnvelope {
id: String, pub error: Option<Error>,
name: String, pub request_nr: Option<u64>,
domain: String, pub data: Value,
} }
pub struct Session { pub struct Session {
@ -28,214 +28,22 @@ pub struct Session {
account: Option<Account>, account: Option<Account>,
} }
pub struct Account {
}
impl InputEnvelope {
pub fn respond(&self, data: Value) -> OutputEnvelope {
OutputEnvelope {
error: None,
request_nr: self.request_nr,
data,
}
}
}
impl Session { impl Session {
pub fn from_token(token: &str) -> Result<Self, Error> { pub async fn from_token(token: &str) -> Self {
let backend = database::client()?; todo!("session")
let mut session;
match backend {
database::Client::Postgres(mut client) => {
let res = client.query(
"SELECT session_id, session_nr, a.account_id, account_name, domain_id \
FROM accounts a JOIN sessions s ON a.account_id = s.account_id \
WHERE session_token = $1",
&[&token]
)?;
if res.len() == 0 {
return Err(Error::new(Kind::InvalidSessionError, Class::ClientError));
}
session = Session {
id: res[0].get(0),
nr: res[0].get(1),
account: Some(Account {
id: res[0].get(2),
name: res[0].get(3),
domain: res[0].get(4),
}),
};
}
}
Ok(session)
} }
} }
pub fn endpoint(envelope: Envelope) -> Result<serde_json::Value, Error> {
if envelope.from_domain != None {
// TODO
return Err(Error::new(Kind::NotImplementedError, Class::ServerError));
}
let mut session = None;
if let Some(token) = &envelope.token {
session = Some(Session::from_token(token)?);
}
let out = match envelope.endpoint.as_str() {
"echo" => serde_json::to_value(echo(session, serde_json::from_value(envelope.data)?)?)?,
"authenticate" => serde_json::to_value(authenticate(session, serde_json::from_value(envelope.data)?)?)?,
"subscribe" => serde_json::to_value(subscribe(session, serde_json::from_value(envelope.data)?)?)?,
"send_event" => serde_json::to_value(send_event(session, serde_json::from_value(envelope.data)?)?)?,
_ => return Err(Error::new(Kind::InvalidEndpointError, Class::ClientProtocolError)),
};
let mut envelope = serde_json::Value::Object(serde_json::Map::new());
envelope["status"] = serde_json::Value::String("success".to_string());
envelope["message"] = serde_json::Value::Null;
envelope["data"] = out;
Ok(envelope)
}
pub fn get_id(input: &str) -> String {
let mut hasher = crypto::sha2::Sha256::new();
hasher.input_str(chrono::Utc::now().timestamp_millis().to_string().as_str());
hasher.input_str(" ");
hasher.input_str(input);
let mut result = [0u8; 32];
hasher.result(&mut result);
base64_url::encode(&result)
}
#[derive(Serialize, Deserialize)]
pub struct EchoInput {
message: String,
}
#[derive(Serialize, Deserialize)]
pub struct EchoOutput {
message: String,
database: Option<i32>,
}
pub fn echo(session: Option<Session>, input: EchoInput) -> Result<EchoOutput, Error> {
let backend = database::client()?;
let mut output = EchoOutput {
message: input.message,
database: None,
};
match backend {
database::Client::Postgres(mut client) => {
let res = client.query("SELECT * FROM test", &[])?;
for row in res {
output.database = Some(row.get(0));
}
}
}
Ok(output)
}
#[derive(Serialize, Deserialize)]
pub struct AuthenticateInput {
r#type: String,
name: String,
password: String,
}
#[derive(Serialize, Deserialize)]
pub struct AuthenticateOutput {
token: String,
}
pub fn authenticate(session: Option<Session>, input: AuthenticateInput) -> Result<AuthenticateOutput, Error> {
let backend = database::client()?;
let mut token: String;
match backend {
database::Client::Postgres(mut client) => {
let res = client.query(
"SELECT account_id FROM accounts WHERE account_name = $1",
&[&input.name]
)?;
if res.len() == 0 {
return Err(Error::new(Kind::AuthenticationError, Class::ClientError));
}
let account_id: String = res[0].get(0);
// TODO password check
if !input.password.eq("MichaelScott") {
return Err(Error::new(Kind::AuthenticationError, Class::ClientError));
}
token = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(256)
.map(char::from)
.collect();
let session_id: String = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(43)
.map(char::from)
.collect();
client.execute(
"INSERT INTO sessions (account_id, session_nr, session_id, session_token) \
VALUES ($1, COALESCE((SELECT MAX(session_nr) + 1 FROM sessions WHERE account_id = $1), 1), $2, $3)",
&[&account_id, &session_id, &token],
)?;
}
}
Ok(AuthenticateOutput { token })
}
#[derive(Serialize, Deserialize)]
pub struct SendEventInput {
room_id: String,
data: serde_json::Value,
}
#[derive(Serialize, Deserialize)]
pub struct SendEventOutput {
event_id: String,
}
pub fn send_event(session: Option<Session>, input: SendEventInput) -> Result<SendEventOutput, Error> {
let backend = database::client()?;
let event_id = get_id("hermann"); // TODO fix id generation
let data = serde_json::to_string(&input.data)?;
let session = session.unwrap();
match backend {
database::Client::Postgres(mut client) => {
let res = client.query(
"SELECT member_id FROM members \
WHERE (room_id, account_id) = ($1, $2)",
&[&input.room_id, &session.account.unwrap().id])?;
let member_id: String = res[0].get(0);
client.execute(
"INSERT INTO events (event_id, room_id, from_member_id, from_session_id, data) \
VALUES ($1, $2, $3, $4, to_jsonb($5::text))",
&[&event_id, &input.room_id, &member_id, &session.id, &data])?;
}
}
subscription::notify(subscription::Event {
data: input.data
});
Ok(SendEventOutput { event_id })
}
#[derive(Serialize, Deserialize)]
pub struct SubscribeInput {
}
#[derive(Serialize, Deserialize)]
pub struct SubscribeOutput {
event: subscription::Event,
}
pub fn subscribe(session: Option<Session>, input: SubscribeInput) -> Result<SubscribeOutput, Error> {
let rx = subscription::subscribe();
let event = rx.recv().unwrap();
subscription::unsubscribe(rx);
Ok(SubscribeOutput { event })
}

97
src/websocket.rs Normal file
View File

@ -0,0 +1,97 @@
use hyper::{Request, Response, Body, StatusCode, header};
use crate::usimp::*;
use crate::usimp;
use crate::error::*;
use hyper_tungstenite::{WebSocketStream, tungstenite::protocol::Role};
use futures_util::StreamExt;
use hyper_tungstenite::tungstenite::{handshake, Message};
use hyper_tungstenite::hyper::upgrade::Upgraded;
use futures::stream::{SplitSink, SplitStream};
use tokio::sync::mpsc;
use serde_json::{Value, Map};
use futures_util::SinkExt;
async fn sender(mut sink: SplitSink<WebSocketStream<Upgraded>, Message>, mut rx: mpsc::Receiver<OutputEnvelope>) {
while let Some(msg) = rx.recv().await {
let mut envelope = Value::Object(Map::new());
envelope["data"] = msg.data;
envelope["request_nr"] = match msg.request_nr {
Some(nr) => Value::from(nr),
None => Value::Null,
};
match msg.error {
Some(error) => {
envelope["status"] = Value::from("error");
envelope["error"] = Value::from(error);
},
None => {
envelope["status"] = Value::from("success");
}
}
if let Err(error) = sink.send(Message::Text(envelope.to_string())).await {
eprintln!("{:?}", error);
break
}
}
}
async fn receiver(mut stream: SplitStream<WebSocketStream<Upgraded>>, tx: mpsc::Sender<OutputEnvelope>) {
while let Some(res) = stream.next().await {
match res {
Ok(msg) => {
let input: InputEnvelope = serde_json::from_slice(&msg.into_data()[..]).unwrap();
let output = match usimp::endpoint(&input).await {
Ok(output) => output,
Err(error) => input.error(error),
};
tx.send(output).await;
},
Err(error) => println!("{:?}", error),
}
}
}
pub async fn handler(req: Request<Body>, res: hyper::http::response::Builder) -> (hyper::http::response::Builder, Option<OutputEnvelope>) {
match req.headers().get(header::UPGRADE) {
Some(val) if val == header::HeaderValue::from_str("websocket").unwrap() => {},
_ => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::WebSocketError, ErrorClass::ClientProtocolError, None)))),
}
let key = match req.headers().get(header::SEC_WEBSOCKET_KEY) {
Some(key) => key,
None => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::WebSocketError, ErrorClass::ClientProtocolError, None))))
};
let key = handshake::derive_accept_key(key.as_bytes());
match req.headers().get(header::SEC_WEBSOCKET_PROTOCOL) {
Some(val) if val == header::HeaderValue::from_str("usimp").unwrap() => {}
_ => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)))),
}
tokio::spawn(async move {
match hyper::upgrade::on(req).await {
Ok(upgraded) => {
let ws_stream = WebSocketStream::from_raw_socket(
upgraded,
Role::Server,
None,
).await;
let (tx, rx) = mpsc::channel::<OutputEnvelope>(16);
let (sink, stream) = ws_stream.split();
tokio::spawn(async move {
sender(sink, rx).await
});
receiver(stream, tx).await
}
Err(error) => eprintln!("Unable to upgrade: {}", error)
}
});
(res
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, header::HeaderValue::from_str("Upgrade").unwrap())
.header(header::UPGRADE, header::HeaderValue::from_str("websocket").unwrap())
.header(header::SEC_WEBSOCKET_ACCEPT, header::HeaderValue::from_str(key.as_str()).unwrap())
.header(header::SEC_WEBSOCKET_PROTOCOL, header::HeaderValue::from_str("usimp").unwrap()),
None)
}

View File

@ -1,202 +0,0 @@
use crate::error::*;
use crate::http;
use crate::usimp;
use crate::websocket::*;
use base64;
use crypto;
use crypto::digest::Digest;
pub fn recv_message(client: &mut http::HttpStream) -> Result<Message, Error> {
let mut msg: Vec<u8> = Vec::new();
let mut msg_type = 0;
loop {
let header = FrameHeader::from(&mut client.stream)?;
// FIXME control frames may show up in a fragmented stream
if msg_type != 0 && header.opcode != 0 {
return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError)
.set_desc("continuation frame expected".to_string()));
} else if header.opcode >= 8 && (!header.fin || header.payload_len >= 126) {
return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError)
.set_desc("invalid control frame".to_string()));
}
match header.opcode {
0 => {}, // cont
1 => {}, // text
2 => // binary
return Err(Error::new(Kind::UsimpProtocolError, Class::ClientProtocolError)
.set_desc("binary frames must not be sent on a usimp connection".to_string())),
8 => {}, // close
9 => {}, // ping
10 => {}, // pong
_ => return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError)
.set_desc("invalid opcode".to_string())),
}
msg_type = header.opcode;
// FIXME check payload len and total len
let mut buf = vec![0u8; header.payload_len() as usize];
client.stream.read_exact(&mut buf)?;
if header.mask {
let key: [u8; 4] = [
(header.masking_key.unwrap() >> 24) as u8,
((header.masking_key.unwrap() >> 16) & 0xFF) as u8,
((header.masking_key.unwrap() >> 8) & 0xFF) as u8,
(header.masking_key.unwrap() & 0xFF) as u8,
];
for (pos, byte) in buf.iter_mut().enumerate() {
*byte ^= key[pos & 3]; // = pos % 4
}
}
msg.append(&mut buf);
if header.fin {
break
}
}
match msg_type {
1 => {Ok(Message::TextMessage(TextMessage {
data: String::from_utf8(msg)?
}))},
8 => {
let mut code = None;
let mut reason = None;
if msg.len() >= 2 {
code = Some(((msg[0] as u16) << 8) | (msg[1] as u16));
}
if msg.len() > 2 {
reason = Some(String::from_utf8(msg[2..].to_vec())?);
}
Ok(Message::CloseMessage(CloseMessage {
code,
reason
}))
},
9 => {Ok(Message::PingMessage(PingMessage {
data: String::from_utf8(msg)?
}))},
10 => {Ok(Message::PongMessage(PongMessage {
data: String::from_utf8(msg)?
}))},
_ => panic!("invalid msg_type for websocket")
}
}
pub fn handshake(
client: &mut http::HttpStream,
req: &http::Request,
res: &mut http::Response,
) -> Result<(), Error> {
if let http::Method::GET = req.method {
} else {
res.status(405);
res.header.add_field("Allow", "GET");
return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError)
.set_desc("method not allowed".to_string()));
}
if let Some(_) = req.header.find_field("Connection") {
if !req.header.field_has_value("Connection", "upgrade") {
return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError)
.set_desc("invalid value for header field 'Connection'".to_string()));
}
} else {
return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError)
.set_desc("unable to find header field 'Connection'".to_string()));
}
if let Some(upgrade) = req.header.find_field("Upgrade") {
if !upgrade.eq_ignore_ascii_case("websocket") {
return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError)
.set_desc("invalid value for header field 'Upgrade'".to_string()));
}
} else {
return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError)
.set_desc("unable to find header field 'Upgrade'".to_string()));
}
if let Some(version) = req.header.find_field("Sec-WebSocket-Version") {
if !version.eq("13") {
return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError)
.set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()));
}
} else {
return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError)
.set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()));
}
if let Some(key) = req.header.find_field("Sec-WebSocket-Key") {
let mut hasher = crypto::sha1::Sha1::new();
hasher.input_str(key);
hasher.input_str("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
let mut result = [0u8; 160 / 8];
hasher.result(&mut result);
let key = base64::encode(result);
res.header.add_field("Sec-WebSocket-Accept", key.as_str());
} else {
return Err(Error::new(Kind::WebSocketError, Class::ClientProtocolError)
.set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()));
}
client.server_keep_alive = false;
res.header.add_field("Connection", "Upgrade");
res.header.add_field("Upgrade", "websocket");
res.status(101);
res.send(&mut client.stream)?;
Ok(())
}
pub fn error_handler(error: Error) {
// TODO send error response frame
}
pub fn connection_handler(
client: &mut http::HttpStream,
req: &http::Request,
mut res: http::Response,
) {
if let Err(error) = handshake(client, req, &mut res) {
return http::error_handler(client, res, error);
}
loop {
match recv_message(client) {
Ok(msg) => {
match msg {
Message::TextMessage(msg) => {
// TODO threads?
let req: RequestEnvelope = serde_json::from_str(msg.data.as_str()).unwrap();
println!("Endpoint: {}, ReqNo: {}, Data: {}", req.endpoint, req.request_nr, req.data);
//let a = usimp::endpoint(req.endpoint.as_str(), req.data).unwrap();
},
Message::CloseMessage(msg) => {
println!("Received close frame: {}: {}", msg.code.unwrap_or(0), msg.reason.unwrap_or("-".to_string()));
return
}
Message::PingMessage(msg) => {
// TODO send pong
},
Message::PongMessage(msg) => {
// TODO something
}
}
},
Err(error) => {
error_handler(error);
}
}
}
}

View File

@ -1,106 +0,0 @@
mod handler;
use serde::{Deserialize, Serialize};
use crate::error::Error;
use crate::http;
pub use handler::*;
pub struct WebSocketStream {
stream: http::HttpStream,
compression: bool,
}
pub struct FrameHeader {
fin: bool,
rsv1: bool,
rsv2: bool,
rsv3: bool,
opcode: u8,
mask: bool,
payload_len: u8,
ex_payload_len: Option<u64>,
masking_key: Option<u32>,
}
pub enum Message {
PingMessage(PingMessage),
PongMessage(PongMessage),
CloseMessage(CloseMessage),
TextMessage(TextMessage),
}
pub struct PingMessage {
data: String,
}
pub struct PongMessage {
data: String,
}
pub struct CloseMessage {
code: Option<u16>,
reason: Option<String>,
}
pub struct TextMessage {
data: String,
}
#[derive(Serialize, Deserialize)]
pub struct RequestEnvelope {
endpoint: String,
request_nr: u64,
data: serde_json::Value,
}
#[derive(Serialize, Deserialize)]
pub struct ResponseEnvelope {
to_request_nr: u64,
status: String,
message: Option<String>,
data: serde_json::Value,
}
impl FrameHeader {
pub fn from(socket: &mut http::Stream) -> Result<Self, Error> {
let mut data = [0u8; 2];
socket.read_exact(&mut data)?;
let mut header = FrameHeader {
fin: data[0] & 0x80 != 0,
rsv1: data[0] & 0x40 != 0,
rsv2: data[0] & 0x20 != 0,
rsv3: data[0] & 0x10 != 0,
opcode: data[0] & 0x0F,
mask: data[1] & 0x80 != 0,
payload_len: data[1] & 0x7F,
ex_payload_len: None,
masking_key: None,
};
if header.payload_len == 126 {
let mut data = [0u8; 2];
socket.read_exact(&mut data)?;
header.ex_payload_len = Some(u16::from_be_bytes(data) as u64);
} else if header.payload_len == 127 {
let mut data = [0u8; 8];
socket.read_exact(&mut data)?;
header.ex_payload_len = Some(u64::from_be_bytes(data));
}
if header.mask {
let mut data = [0u8; 4];
socket.read_exact(&mut data)?;
header.masking_key = Some(u32::from_be_bytes(data));
}
Ok(header)
}
pub fn payload_len(&self) -> u64 {
match self.ex_payload_len {
Some(val) => val,
None => self.payload_len as u64,
}
}
}