cargo fmt

This commit is contained in:
2021-06-05 14:29:09 +02:00
parent 473b553662
commit 01e9b9c6ae
11 changed files with 341 additions and 169 deletions

View File

@ -15,7 +15,9 @@ static mut POOL: Option<Pool> = None;
pub async fn init() -> Result<(), Error> { pub async fn init() -> Result<(), Error> {
let manager = PostgresConnectionManager::new( let manager = PostgresConnectionManager::new(
"host=localhost user=postgres dbname=locutus".parse().unwrap(), "host=localhost user=postgres dbname=locutus"
.parse()
.unwrap(),
NoTls, NoTls,
); );
@ -24,7 +26,8 @@ pub async fn init() -> Result<(), Error> {
.min_idle(Some(2)) .min_idle(Some(2))
.connection_timeout(Duration::from_secs(4)) .connection_timeout(Duration::from_secs(4))
.max_lifetime(Some(Duration::from_secs(3600))) .max_lifetime(Some(Duration::from_secs(3600)))
.build(manager).await?; .build(manager)
.await?;
unsafe { unsafe {
POOL = Some(Pool::Postgres(pool)); POOL = Some(Pool::Postgres(pool));

View File

@ -1,8 +1,8 @@
use crate::usimp::{InputEnvelope, OutputEnvelope, Event}; use crate::usimp::{Event, InputEnvelope, OutputEnvelope};
use serde_json::{Value, Map};
use bb8_postgres::tokio_postgres;
use bb8_postgres; use bb8_postgres;
use bb8_postgres::tokio_postgres;
use serde_json::{Map, Value};
#[derive(Debug)] #[derive(Debug)]
pub struct Error { pub struct Error {
@ -32,7 +32,12 @@ pub enum ErrorKind {
} }
impl InputEnvelope { impl InputEnvelope {
pub fn new_error(&self, kind: ErrorKind, class: ErrorClass, msg: Option<String>) -> OutputEnvelope { pub fn new_error(
&self,
kind: ErrorKind,
class: ErrorClass,
msg: Option<String>,
) -> OutputEnvelope {
OutputEnvelope { OutputEnvelope {
request_nr: self.request_nr, request_nr: self.request_nr,
error: Some(Error::new(kind, class, msg)), error: Some(Error::new(kind, class, msg)),
@ -56,7 +61,7 @@ impl Error {
class, class,
msg, msg,
desc: None, desc: None,
} };
} }
pub fn msg(&mut self, msg: String) { pub fn msg(&mut self, msg: String) {

View File

@ -1,34 +1,62 @@
use hyper::{Request, Response, Body, StatusCode, header, Method, body};
use serde_json::{Value, Map};
use crate::websocket;
use crate::usimp::*;
use crate::error::*; use crate::error::*;
use crate::usimp; use crate::usimp;
use crate::usimp::*;
use crate::websocket;
use hyper::{body, header, Body, Method, Request, Response, StatusCode};
use serde_json::{Map, Value};
use std::str::FromStr; use std::str::FromStr;
async fn endpoint_handler(req: &mut Request<Body>, endpoint: String) -> Result<Option<OutputEnvelope>, Error> { async fn endpoint_handler(
req: &mut Request<Body>,
endpoint: String,
) -> Result<Option<OutputEnvelope>, Error> {
if req.method() == Method::OPTIONS { if req.method() == Method::OPTIONS {
return Ok(None) return Ok(None);
} else if req.method() != Method::POST { } else if req.method() != Method::POST {
return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) return Err(Error::new(
ErrorKind::UsimpError,
ErrorClass::ClientProtocolError,
None,
));
} }
let to_domain; let to_domain;
if let Some(val) = req.headers().get(header::HeaderName::from_str("To-Domain").unwrap()) { if let Some(val) = req
.headers()
.get(header::HeaderName::from_str("To-Domain").unwrap())
{
to_domain = val.to_str().unwrap().to_string() to_domain = val.to_str().unwrap().to_string()
} else { } else {
return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) return Err(Error::new(
ErrorKind::UsimpError,
ErrorClass::ClientProtocolError,
None,
));
} }
if let Some(val) = req.headers().get(header::CONTENT_TYPE) { if let Some(val) = req.headers().get(header::CONTENT_TYPE) {
let parts: Vec<String> = val.to_str()?.split(';').map(|v| v.trim().to_ascii_lowercase()).collect(); let parts: Vec<String> = val
.to_str()?
.split(';')
.map(|v| v.trim().to_ascii_lowercase())
.collect();
let p: Vec<&str> = parts.iter().map(|v| v.as_str()).collect(); let p: Vec<&str> = parts.iter().map(|v| v.as_str()).collect();
match p[0..1] { match p[0..1] {
["application/json"] => {}, ["application/json"] => {}
_ => return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) _ => {
return Err(Error::new(
ErrorKind::UsimpError,
ErrorClass::ClientProtocolError,
None,
))
}
} }
} else { } else {
return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) return Err(Error::new(
ErrorKind::UsimpError,
ErrorClass::ClientProtocolError,
None,
));
} }
let data = serde_json::from_slice(&body::to_bytes(req.body_mut()).await?)?; let data = serde_json::from_slice(&body::to_bytes(req.body_mut()).await?)?;
@ -36,7 +64,10 @@ async fn endpoint_handler(req: &mut Request<Body>, endpoint: String) -> Result<O
let input = InputEnvelope { let input = InputEnvelope {
endpoint, endpoint,
to_domain, to_domain,
from_domain: match req.headers().get(header::HeaderName::from_str("From-Domain").unwrap()) { from_domain: match req
.headers()
.get(header::HeaderName::from_str("From-Domain").unwrap())
{
Some(val) => Some(val.to_str()?.to_string()), Some(val) => Some(val.to_str()?.to_string()),
None => None, None => None,
}, },
@ -47,9 +78,13 @@ async fn endpoint_handler(req: &mut Request<Body>, endpoint: String) -> Result<O
if val.starts_with("usimp ") { if val.starts_with("usimp ") {
Some(val[6..].to_string()) Some(val[6..].to_string())
} else { } else {
return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) return Err(Error::new(
ErrorKind::UsimpError,
ErrorClass::ClientProtocolError,
None,
));
}
} }
},
None => None, None => None,
}, },
data, data,
@ -65,7 +100,10 @@ pub async fn handler(mut req: Request<Body>) -> Result<Response<Body>, hyper::Er
println!("{} {}", req.method(), req.uri()); println!("{} {}", req.method(), req.uri());
let val: Result<Response<Body>, Error> = match &parts[..] { let val: Result<Response<Body>, Error> = match &parts[..] {
[""] => Ok(res.status(StatusCode::OK).body(Body::from("Hello World")).unwrap()), [""] => Ok(res
.status(StatusCode::OK)
.body(Body::from("Hello World"))
.unwrap()),
["_usimp"] | ["_usimp", ..] => { ["_usimp"] | ["_usimp", ..] => {
res = res res = res
.header(header::SERVER, "Locutus") .header(header::SERVER, "Locutus")
@ -75,23 +113,29 @@ pub async fn handler(mut req: Request<Body>) -> Result<Response<Body>, hyper::Er
let output = match &parts[1..] { let output = match &parts[1..] {
["websocket"] => { ["websocket"] => {
res = res res = res.header(header::ACCESS_CONTROL_ALLOW_METHODS, "GET");
.header(header::ACCESS_CONTROL_ALLOW_METHODS, "GET");
let (r, val) = websocket::handler(req, res).await; let (r, val) = websocket::handler(req, res).await;
res = r; res = r;
match val { match val {
Some(val) => Ok(Some(val)), Some(val) => Ok(Some(val)),
None => return Ok(res.body(Body::empty()).unwrap()), None => return Ok(res.body(Body::empty()).unwrap()),
} }
}, }
[endpoint] => { [endpoint] => {
res = res res = res
.header(header::ACCESS_CONTROL_ALLOW_METHODS, "POST, OPTIONS") .header(header::ACCESS_CONTROL_ALLOW_METHODS, "POST, OPTIONS")
.header(header::ACCESS_CONTROL_ALLOW_HEADERS, "Content-Type, From-Domain, To-Domain, Authorization"); .header(
header::ACCESS_CONTROL_ALLOW_HEADERS,
"Content-Type, From-Domain, To-Domain, Authorization",
);
let endpoint = endpoint.to_string(); let endpoint = endpoint.to_string();
endpoint_handler(&mut req, endpoint).await endpoint_handler(&mut req, endpoint).await
}, }
_ => Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)), _ => Err(Error::new(
ErrorKind::UsimpError,
ErrorClass::ClientProtocolError,
None,
)),
}; };
let output = match output { let output = match output {
@ -107,15 +151,17 @@ pub async fn handler(mut req: Request<Body>) -> Result<Response<Body>, hyper::Er
Some(error) => { Some(error) => {
res = match error.class { res = match error.class {
ErrorClass::ClientProtocolError => res.status(StatusCode::BAD_REQUEST), ErrorClass::ClientProtocolError => res.status(StatusCode::BAD_REQUEST),
ErrorClass::ServerError => res.status(StatusCode::INTERNAL_SERVER_ERROR), ErrorClass::ServerError => {
res.status(StatusCode::INTERNAL_SERVER_ERROR)
}
_ => res.status(StatusCode::OK), _ => res.status(StatusCode::OK),
}; };
data["status"] = Value::from("error"); data["status"] = Value::from("error");
data["error"] = Value::from(error); data["error"] = Value::from(error);
}, }
None => { None => {
data["status"] = Value::from("success"); data["status"] = Value::from("success");
}, }
} }
data["request_nr"] = match output.request_nr { data["request_nr"] = match output.request_nr {
@ -123,14 +169,19 @@ pub async fn handler(mut req: Request<Body>) -> Result<Response<Body>, hyper::Er
None => Value::Null, None => Value::Null,
}; };
data["data"] = output.data; data["data"] = output.data;
return Ok(res.body(Body::from(serde_json::to_string(&data).unwrap() + "\r\n")).unwrap()) return Ok(res
.body(Body::from(serde_json::to_string(&data).unwrap() + "\r\n"))
.unwrap());
} else { } else {
res = res.status(StatusCode::NO_CONTENT); res = res.status(StatusCode::NO_CONTENT);
} }
return Ok(res.body(Body::empty()).unwrap()) return Ok(res.body(Body::empty()).unwrap());
}, }
_ => Ok(res.status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()), _ => Ok(res
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
.unwrap()),
}; };
match val { match val {

View File

@ -4,21 +4,30 @@ use std::net;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin; use std::pin::Pin;
use error::*;
use ansi_term::{Color, Style}; use ansi_term::{Color, Style};
use error::*;
use futures_util::{future::TryFutureExt, stream::Stream}; use futures_util::{future::TryFutureExt, stream::Stream};
use hyper::Server;
use hyper::server::conn::AddrStream; use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn}; use hyper::service::{make_service_fn, service_fn};
use hyper::Server;
mod http;
mod websocket;
mod usimp;
mod database; mod database;
mod error; mod error;
mod http;
mod usimp;
mod websocket;
struct HyperAcceptor<'a> { struct HyperAcceptor<'a> {
acceptor: Pin<Box<dyn Stream<Item = Result<tokio_rustls::server::TlsStream<tokio::net::TcpStream>, std::io::Error>> + 'a>>, acceptor: Pin<
Box<
dyn Stream<
Item = Result<
tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
std::io::Error,
>,
> + 'a,
>,
>,
} }
impl hyper::server::accept::Accept for HyperAcceptor<'_> { impl hyper::server::accept::Accept for HyperAcceptor<'_> {
@ -38,7 +47,8 @@ fn load_certs(filename: &str) -> std::io::Result<Vec<rustls::Certificate>> {
.map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?;
let mut reader = std::io::BufReader::new(certfile); let mut reader = std::io::BufReader::new(certfile);
rustls::internal::pemfile::certs(&mut reader).map_err(|_| error("failed to load certificate".into())) rustls::internal::pemfile::certs(&mut reader)
.map_err(|_| error("failed to load certificate".into()))
} }
fn load_private_key(filename: &str) -> std::io::Result<rustls::PrivateKey> { fn load_private_key(filename: &str) -> std::io::Result<rustls::PrivateKey> {
@ -65,7 +75,10 @@ async fn main() -> Result<(), Error> {
database::init().await?; database::init().await?;
usimp::subscription::init(); usimp::subscription::init();
let server1 = Server::bind(&SocketAddr::from(([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], 8080))); let server1 = Server::bind(&SocketAddr::from((
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
8080,
)));
let service = make_service_fn(|_: &AddrStream| async { let service = make_service_fn(|_: &AddrStream| async {
Ok::<_, hyper::Error>(service_fn(http::handler)) Ok::<_, hyper::Error>(service_fn(http::handler))
}); });
@ -100,9 +113,7 @@ async fn main() -> Result<(), Error> {
acceptor: Box::pin(incoming_tls_stream), acceptor: Box::pin(incoming_tls_stream),
}); });
let service = make_service_fn(|_| async { let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(http::handler)) });
Ok::<_, hyper::Error>(service_fn(http::handler))
});
let srv2 = server2.serve(service); let srv2 = server2.serve(service);
println!("{}", Color::Green.paint("Ready")); println!("{}", Color::Green.paint("Ready"));

View File

@ -1,10 +1,10 @@
use crate::database;
use crate::usimp; use crate::usimp;
use crate::usimp::*; use crate::usimp::*;
use crate::database;
use serde_json::{Value, from_value, to_value};
use serde::{Serialize, Deserialize};
use rand::Rng; use rand::Rng;
use serde::{Deserialize, Serialize};
use serde_json::{from_value, to_value, Value};
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
struct Input { struct Input {
@ -19,7 +19,9 @@ struct Output {
} }
pub async fn handle(input: &InputEnvelope, session: Option<Session>) -> Result<Value, Error> { pub async fn handle(input: &InputEnvelope, session: Option<Session>) -> Result<Value, Error> {
Ok(to_value(authenticate(from_value(input.data.clone())?, session).await?)?) Ok(to_value(
authenticate(from_value(input.data.clone())?, session).await?,
)?)
} }
async fn authenticate(input: Input, _session: Option<Session>) -> Result<Output, Error> { async fn authenticate(input: Input, _session: Option<Session>) -> Result<Output, Error> {
@ -28,14 +30,20 @@ async fn authenticate(input: Input, _session: Option<Session>) -> Result<Output,
let session_id; let session_id;
match backend { match backend {
database::Client::Postgres(client) => { database::Client::Postgres(client) => {
let res = client.query( let res = client
.query(
"SELECT account_id, domain_id \ "SELECT account_id, domain_id \
FROM accounts \ FROM accounts \
WHERE account_name = $1", WHERE account_name = $1",
&[&input.name] &[&input.name],
).await?; )
.await?;
if res.len() == 0 { if res.len() == 0 {
return Err(Error::new(ErrorKind::AuthenticationError, ErrorClass::ClientError, None)); return Err(Error::new(
ErrorKind::AuthenticationError,
ErrorClass::ClientError,
None,
));
} }
let row = &res[0]; let row = &res[0];
let account_id: String = row.get(0); let account_id: String = row.get(0);
@ -43,7 +51,11 @@ async fn authenticate(input: Input, _session: Option<Session>) -> Result<Output,
// TODO password check // TODO password check
if !input.password.eq("MichaelScott") { if !input.password.eq("MichaelScott") {
return Err(Error::new(ErrorKind::AuthenticationError, ErrorClass::ClientError, None)); return Err(Error::new(
ErrorKind::AuthenticationError,
ErrorClass::ClientError,
None,
));
} }
session_id = usimp::get_id(&[domain_id.as_str(), account_id.as_str()]); session_id = usimp::get_id(&[domain_id.as_str(), account_id.as_str()]);
@ -53,17 +65,16 @@ async fn authenticate(input: Input, _session: Option<Session>) -> Result<Output,
.map(char::from) .map(char::from)
.collect(); .collect();
client.execute( client
.execute(
"INSERT INTO sessions (account_id, session_nr, session_id, session_token) \ "INSERT INTO sessions (account_id, session_nr, session_id, session_token) \
VALUES ($1, COALESCE((SELECT MAX(session_nr) + 1 \ VALUES ($1, COALESCE((SELECT MAX(session_nr) + 1 \
FROM sessions \ FROM sessions \
WHERE account_id = $1), 1), $2, $3);", WHERE account_id = $1), 1), $2, $3);",
&[&account_id, &session_id, &token], &[&account_id, &session_id, &token],
).await?; )
.await?;
} }
} }
Ok(Output { Ok(Output { session_id, token })
session_id,
token,
})
} }

View File

@ -1,15 +1,22 @@
mod ping;
mod authenticate; mod authenticate;
mod subscribe;
mod new_event; mod new_event;
mod ping;
mod subscribe;
use crate::usimp::*; use crate::usimp::*;
use tokio::sync::mpsc; use tokio::sync::mpsc;
pub async fn endpoint(input: &InputEnvelope, tx: Option<mpsc::Sender<OutputEnvelope>>) -> Result<OutputEnvelope, Error> { pub async fn endpoint(
input: &InputEnvelope,
tx: Option<mpsc::Sender<OutputEnvelope>>,
) -> Result<OutputEnvelope, Error> {
if input.from_domain != None { if input.from_domain != None {
// TODO // TODO
return Err(Error::new(ErrorKind::NotImplemented, ErrorClass::ServerError, None)); return Err(Error::new(
ErrorKind::NotImplemented,
ErrorClass::ServerError,
None,
));
} }
let session; let session;
if let Some(token) = &input.token { if let Some(token) = &input.token {
@ -24,8 +31,10 @@ pub async fn endpoint(input: &InputEnvelope, tx: Option<mpsc::Sender<OutputEnvel
"authenticate" => input.respond(authenticate::handle(&input, session).await?), "authenticate" => input.respond(authenticate::handle(&input, session).await?),
"subscribe" => input.respond(subscribe::handle(&input, session, tx).await?), "subscribe" => input.respond(subscribe::handle(&input, session, tx).await?),
"new_event" => input.respond(new_event::handle(&input, session).await?), "new_event" => input.respond(new_event::handle(&input, session).await?),
_ => input.new_error(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, Some("Invalid endpoint".to_string())), _ => input.new_error(
ErrorKind::UsimpError,
ErrorClass::ClientProtocolError,
Some("Invalid endpoint".to_string()),
),
}) })
} }

View File

@ -1,8 +1,8 @@
use crate::usimp::*;
use crate::usimp::subscription; use crate::usimp::subscription;
use crate::usimp::*;
use serde_json::{Value, from_value, to_value}; use serde::{Deserialize, Serialize};
use serde::{Serialize, Deserialize}; use serde_json::{from_value, to_value, Value};
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
struct Input { struct Input {
@ -11,11 +11,12 @@ struct Input {
} }
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
struct Output { struct Output {}
}
pub async fn handle(input: &InputEnvelope, session: Option<Session>) -> Result<Value, Error> { pub async fn handle(input: &InputEnvelope, session: Option<Session>) -> Result<Value, Error> {
Ok(to_value(new_event(from_value(input.data.clone())?, session).await?)?) Ok(to_value(
new_event(from_value(input.data.clone())?, session).await?,
)?)
} }
async fn new_event(input: Input, session: Option<Session>) -> Result<Output, Error> { async fn new_event(input: Input, session: Option<Session>) -> Result<Output, Error> {

View File

@ -1,50 +1,67 @@
use crate::usimp::*;
use crate::usimp::subscription; use crate::usimp::subscription;
use crate::usimp::*;
use serde_json::{Value, from_value, to_value}; use serde::{Deserialize, Serialize};
use serde::{Serialize, Deserialize}; use serde_json::{from_value, to_value, Value};
use tokio::sync::mpsc; use tokio::sync::mpsc;
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
struct Input { struct Input {}
}
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
struct Output { struct Output {
event: Option<Event>, event: Option<Event>,
} }
pub async fn handle(input: &InputEnvelope, session: Option<Session>, tx: Option<mpsc::Sender<OutputEnvelope>>) -> Result<Value, Error> { pub async fn handle(
Ok(to_value(subscribe(from_value(input.data.clone())?, session, input.request_nr, tx).await?)?) input: &InputEnvelope,
session: Option<Session>,
tx: Option<mpsc::Sender<OutputEnvelope>>,
) -> Result<Value, Error> {
Ok(to_value(
subscribe(
from_value(input.data.clone())?,
session,
input.request_nr,
tx,
)
.await?,
)?)
} }
async fn subscribe(_input: Input, session: Option<Session>, req_nr: Option<u64>, tx: Option<mpsc::Sender<OutputEnvelope>>) -> Result<Output, Error> { async fn subscribe(
_input: Input,
session: Option<Session>,
req_nr: Option<u64>,
tx: Option<mpsc::Sender<OutputEnvelope>>,
) -> Result<Output, Error> {
let account = get_account(&session)?; let account = get_account(&session)?;
let mut rx = subscription::subscribe_account(account).await; let mut rx = subscription::subscribe_account(account).await;
match tx { match tx {
Some(tx) => { Some(tx) => {
tokio::spawn(async move { tokio::spawn(async move {
while let Some(event) = rx.recv().await { while let Some(event) = rx.recv().await {
let _res = tx.send(OutputEnvelope { let _res = tx
.send(OutputEnvelope {
error: None, error: None,
request_nr: req_nr, request_nr: req_nr,
data: to_value(event).unwrap(), data: to_value(event).unwrap(),
}).await; })
.await;
} }
}); });
Ok(Output { Ok(Output { event: None })
event: None,
})
} }
None => { None => {
if let Some(event) = rx.recv().await { if let Some(event) = rx.recv().await {
Ok(Output { Ok(Output { event: Some(event) })
event: Some(event),
})
} else { } else {
Err(Error::new(ErrorKind::SubscriptionError, ErrorClass::ServerError, None)) Err(Error::new(
ErrorKind::SubscriptionError,
ErrorClass::ServerError,
None,
))
} }
} }
} }
} }

View File

@ -3,13 +3,13 @@ pub mod subscription;
pub use handler::endpoint; pub use handler::endpoint;
use crate::error::{Error, ErrorClass, ErrorKind};
use crate::database; use crate::database;
use serde_json::Value; use crate::error::{Error, ErrorClass, ErrorKind};
use serde::{Serialize, Deserialize};
use crypto::sha2::Sha256;
use crypto::digest::Digest;
use base64_url; use base64_url;
use crypto::digest::Digest;
use crypto::sha2::Sha256;
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct InputEnvelope { pub struct InputEnvelope {
@ -60,9 +60,21 @@ pub fn get_account(session: &Option<Session>) -> Result<&Account, Error> {
match session { match session {
Some(session) => match &session.account { Some(session) => match &session.account {
Some(account) => Ok(&account), Some(account) => Ok(&account),
None => return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) None => {
return Err(Error::new(
ErrorKind::UsimpError,
ErrorClass::ClientProtocolError,
None,
))
}
}, },
None => return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) None => {
return Err(Error::new(
ErrorKind::UsimpError,
ErrorClass::ClientProtocolError,
None,
))
}
} }
} }
@ -82,14 +94,20 @@ impl Session {
let session; let session;
match backend { match backend {
database::Client::Postgres(client) => { database::Client::Postgres(client) => {
let res = client.query( let res = client
.query(
"SELECT session_id, session_nr, a.account_id, account_name, domain_id \ "SELECT session_id, session_nr, a.account_id, account_name, domain_id \
FROM accounts a JOIN sessions s ON a.account_id = s.account_id \ FROM accounts a JOIN sessions s ON a.account_id = s.account_id \
WHERE session_token = $1;", WHERE session_token = $1;",
&[&token] &[&token],
).await?; )
.await?;
if res.len() == 0 { if res.len() == 0 {
return Err(Error::new(ErrorKind::InvalidSessionError, ErrorClass::ClientError, None)); return Err(Error::new(
ErrorKind::InvalidSessionError,
ErrorClass::ClientError,
None,
));
} }
let row = &res[0]; let row = &res[0];
session = Session { session = Session {

View File

@ -1,8 +1,8 @@
use crate::usimp::*;
use crate::database; use crate::database;
use tokio::sync::{mpsc, Mutex}; use crate::usimp::*;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
static mut ROOMS: Option<Arc<Mutex<HashMap<String, Vec<mpsc::Sender<Event>>>>>> = None; static mut ROOMS: Option<Arc<Mutex<HashMap<String, Vec<mpsc::Sender<Event>>>>>> = None;
static mut ACCOUNTS: Option<Arc<Mutex<HashMap<String, Vec<mpsc::Sender<Event>>>>>> = None; static mut ACCOUNTS: Option<Arc<Mutex<HashMap<String, Vec<mpsc::Sender<Event>>>>>> = None;
@ -21,10 +21,10 @@ pub async fn subscribe_account(account: &Account) -> mpsc::Receiver<Event> {
match acc.get_mut(account.id.as_str()) { match acc.get_mut(account.id.as_str()) {
Some(vec) => { Some(vec) => {
vec.push(tx); vec.push(tx);
}, }
None => { None => {
acc.insert(account.id.clone(), vec!{tx}); acc.insert(account.id.clone(), vec![tx]);
}, }
} }
} }
rx rx
@ -34,12 +34,14 @@ pub async fn push(room_id: &str, event: Event) -> Result<(), Error> {
let backend = database::client().await?; let backend = database::client().await?;
let accounts = match backend { let accounts = match backend {
database::Client::Postgres(client) => { database::Client::Postgres(client) => {
let res = client.query( let res = client
.query(
"SELECT account_id \ "SELECT account_id \
FROM members \ FROM members \
WHERE room_id = $1;", WHERE room_id = $1;",
&[&room_id] &[&room_id],
).await?; )
.await?;
let mut acc: Vec<String> = Vec::new(); let mut acc: Vec<String> = Vec::new();
for row in res { for row in res {
acc.push(row.get(0)); acc.push(row.get(0));

View File

@ -1,17 +1,20 @@
use hyper::{Request, Body, StatusCode, header};
use crate::usimp::*;
use crate::usimp;
use crate::error::*; use crate::error::*;
use hyper_tungstenite::{WebSocketStream, tungstenite::protocol::Role}; use crate::usimp;
use futures_util::StreamExt; use crate::usimp::*;
use hyper_tungstenite::tungstenite::{handshake, Message};
use hyper_tungstenite::hyper::upgrade::Upgraded;
use futures::stream::{SplitSink, SplitStream}; use futures::stream::{SplitSink, SplitStream};
use tokio::sync::mpsc;
use serde_json::{Value, Map};
use futures_util::SinkExt; use futures_util::SinkExt;
use futures_util::StreamExt;
use hyper::{header, Body, Request, StatusCode};
use hyper_tungstenite::hyper::upgrade::Upgraded;
use hyper_tungstenite::tungstenite::{handshake, Message};
use hyper_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use serde_json::{Map, Value};
use tokio::sync::mpsc;
async fn sender(mut sink: SplitSink<WebSocketStream<Upgraded>, Message>, mut rx: mpsc::Receiver<OutputEnvelope>) { async fn sender(
mut sink: SplitSink<WebSocketStream<Upgraded>, Message>,
mut rx: mpsc::Receiver<OutputEnvelope>,
) {
while let Some(msg) = rx.recv().await { while let Some(msg) = rx.recv().await {
let mut envelope = Value::Object(Map::new()); let mut envelope = Value::Object(Map::new());
envelope["data"] = msg.data; envelope["data"] = msg.data;
@ -23,19 +26,22 @@ async fn sender(mut sink: SplitSink<WebSocketStream<Upgraded>, Message>, mut rx:
Some(error) => { Some(error) => {
envelope["status"] = Value::from("error"); envelope["status"] = Value::from("error");
envelope["error"] = Value::from(error); envelope["error"] = Value::from(error);
}, }
None => { None => {
envelope["status"] = Value::from("success"); envelope["status"] = Value::from("success");
} }
} }
if let Err(error) = sink.send(Message::Text(envelope.to_string())).await { if let Err(error) = sink.send(Message::Text(envelope.to_string())).await {
eprintln!("{:?}", error); eprintln!("{:?}", error);
break break;
} }
} }
} }
async fn receiver(mut stream: SplitStream<WebSocketStream<Upgraded>>, tx: mpsc::Sender<OutputEnvelope>) { async fn receiver(
mut stream: SplitStream<WebSocketStream<Upgraded>>,
tx: mpsc::Sender<OutputEnvelope>,
) {
while let Some(res) = stream.next().await { while let Some(res) = stream.next().await {
match res { match res {
Ok(msg) => { Ok(msg) => {
@ -45,53 +51,91 @@ async fn receiver(mut stream: SplitStream<WebSocketStream<Upgraded>>, tx: mpsc::
Err(error) => input.error(error), Err(error) => input.error(error),
}; };
let _res = tx.send(output).await; let _res = tx.send(output).await;
}, }
Err(error) => println!("{:?}", error), Err(error) => println!("{:?}", error),
} }
} }
} }
pub async fn handler(req: Request<Body>, res: hyper::http::response::Builder) -> (hyper::http::response::Builder, Option<OutputEnvelope>) { pub async fn handler(
req: Request<Body>,
res: hyper::http::response::Builder,
) -> (hyper::http::response::Builder, Option<OutputEnvelope>) {
match req.headers().get(header::UPGRADE) { match req.headers().get(header::UPGRADE) {
Some(val) if val == header::HeaderValue::from_str("websocket").unwrap() => {}, Some(val) if val == header::HeaderValue::from_str("websocket").unwrap() => {}
_ => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::WebSocketError, ErrorClass::ClientProtocolError, None)))), _ => {
return (
res,
Some(OutputEnvelope::from(Error::new(
ErrorKind::WebSocketError,
ErrorClass::ClientProtocolError,
None,
))),
)
}
} }
let key = match req.headers().get(header::SEC_WEBSOCKET_KEY) { let key = match req.headers().get(header::SEC_WEBSOCKET_KEY) {
Some(key) => key, Some(key) => key,
None => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::WebSocketError, ErrorClass::ClientProtocolError, None)))) None => {
return (
res,
Some(OutputEnvelope::from(Error::new(
ErrorKind::WebSocketError,
ErrorClass::ClientProtocolError,
None,
))),
)
}
}; };
let key = handshake::derive_accept_key(key.as_bytes()); let key = handshake::derive_accept_key(key.as_bytes());
match req.headers().get(header::SEC_WEBSOCKET_PROTOCOL) { match req.headers().get(header::SEC_WEBSOCKET_PROTOCOL) {
Some(val) if val == header::HeaderValue::from_str("usimp").unwrap() => {} Some(val) if val == header::HeaderValue::from_str("usimp").unwrap() => {}
_ => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)))), _ => {
return (
res,
Some(OutputEnvelope::from(Error::new(
ErrorKind::UsimpError,
ErrorClass::ClientProtocolError,
None,
))),
)
}
} }
tokio::spawn(async move { tokio::spawn(async move {
match hyper::upgrade::on(req).await { match hyper::upgrade::on(req).await {
Ok(upgraded) => { Ok(upgraded) => {
let ws_stream = WebSocketStream::from_raw_socket( let ws_stream =
upgraded, WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
Role::Server,
None,
).await;
let (tx, rx) = mpsc::channel::<OutputEnvelope>(64); let (tx, rx) = mpsc::channel::<OutputEnvelope>(64);
let (sink, stream) = ws_stream.split(); let (sink, stream) = ws_stream.split();
tokio::spawn(async move { tokio::spawn(async move { sender(sink, rx).await });
sender(sink, rx).await
});
receiver(stream, tx).await receiver(stream, tx).await
} }
Err(error) => eprintln!("Unable to upgrade: {}", error) Err(error) => eprintln!("Unable to upgrade: {}", error),
} }
}); });
(res (
.status(StatusCode::SWITCHING_PROTOCOLS) res.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, header::HeaderValue::from_str("Upgrade").unwrap()) .header(
.header(header::UPGRADE, header::HeaderValue::from_str("websocket").unwrap()) header::CONNECTION,
.header(header::SEC_WEBSOCKET_ACCEPT, header::HeaderValue::from_str(key.as_str()).unwrap()) header::HeaderValue::from_str("Upgrade").unwrap(),
.header(header::SEC_WEBSOCKET_PROTOCOL, header::HeaderValue::from_str("usimp").unwrap()), )
None) .header(
header::UPGRADE,
header::HeaderValue::from_str("websocket").unwrap(),
)
.header(
header::SEC_WEBSOCKET_ACCEPT,
header::HeaderValue::from_str(key.as_str()).unwrap(),
)
.header(
header::SEC_WEBSOCKET_PROTOCOL,
header::HeaderValue::from_str("usimp").unwrap(),
),
None,
)
} }