From 473b553662d1e217873de94a2132e3d72418c8c9 Mon Sep 17 00:00:00 2001 From: Lorenz Stechauner Date: Sat, 5 Jun 2021 14:17:23 +0200 Subject: [PATCH] Subscriptions working --- src/database.rs | 2 - src/error.rs | 22 +++++++++- src/http.rs | 4 +- src/main.rs | 1 + src/usimp/handler/authenticate.rs | 53 +++++++++++++++++++---- src/usimp/handler/mod.rs | 27 ++++++++---- src/usimp/handler/new_event.rs | 28 ++++++++++++ src/usimp/handler/ping.rs | 6 +-- src/usimp/handler/subscribe.rs | 50 +++++++++++++++++++++ src/usimp/mod.rs | 67 ++++++++++++++++++++++++++-- src/usimp/subscription.rs | 72 +++++++++++++++++++++++++++++++ src/websocket.rs | 8 ++-- 12 files changed, 306 insertions(+), 34 deletions(-) create mode 100644 src/usimp/handler/new_event.rs create mode 100644 src/usimp/handler/subscribe.rs create mode 100644 src/usimp/subscription.rs diff --git a/src/database.rs b/src/database.rs index 87ac9f9..907c5a3 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,8 +1,6 @@ use crate::error::*; use bb8_postgres::tokio_postgres::NoTls; use bb8_postgres::PostgresConnectionManager; -use std::ops::Deref; -use std::sync::{Arc, Mutex, MutexGuard}; use std::time::Duration; pub enum Pool { diff --git a/src/error.rs b/src/error.rs index 9e346c2..57405f6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -use crate::usimp::{InputEnvelope, OutputEnvelope}; +use crate::usimp::{InputEnvelope, OutputEnvelope, Event}; use serde_json::{Value, Map}; use bb8_postgres::tokio_postgres; @@ -25,6 +25,10 @@ pub enum ErrorKind { UsimpError, WebSocketError, DatabaseError, + InvalidSessionError, + AuthenticationError, + SubscriptionError, + InternalError, } impl InputEnvelope { @@ -64,7 +68,10 @@ impl Error { ErrorKind::NotImplemented => "NOT_IMPLEMENTED", ErrorKind::UsimpError => "USIMP_ERROR", ErrorKind::WebSocketError => "WEBSOCKET_ERROR", - ErrorKind::DatabaseError => "BACKEND_ERROR", + ErrorKind::DatabaseError | ErrorKind::InternalError => "SERVER_ERROR", + ErrorKind::InvalidSessionError => "INVALID_SESSION_ERROR", + ErrorKind::AuthenticationError => "AUTHENTICATION_ERROR", + ErrorKind::SubscriptionError => "SUBSCRIPTION_ERROR", } } } @@ -149,3 +156,14 @@ impl From> for Error { } } } + +impl From> for Error { + fn from(error: tokio::sync::mpsc::error::SendError) -> Self { + Error { + kind: ErrorKind::InternalError, + class: ErrorClass::ServerError, + msg: None, + desc: Some(error.to_string()), + } + } +} diff --git a/src/http.rs b/src/http.rs index a420c8a..f5d4499 100644 --- a/src/http.rs +++ b/src/http.rs @@ -55,7 +55,7 @@ async fn endpoint_handler(req: &mut Request, endpoint: String) -> Result) -> Result, hyper::Error> { @@ -135,7 +135,7 @@ pub async fn handler(mut req: Request) -> Result, hyper::Er match val { Ok(val) => Ok(val), - Err(error) => { + Err(_error) => { todo!("help") } } diff --git a/src/main.rs b/src/main.rs index a6ac261..41ada7e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -63,6 +63,7 @@ async fn main() -> Result<(), Error> { println!("Locutus server"); database::init().await?; + usimp::subscription::init(); let server1 = Server::bind(&SocketAddr::from(([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], 8080))); let service = make_service_fn(|_: &AddrStream| async { diff --git a/src/usimp/handler/authenticate.rs b/src/usimp/handler/authenticate.rs index 5a8ccb8..af64ed2 100644 --- a/src/usimp/handler/authenticate.rs +++ b/src/usimp/handler/authenticate.rs @@ -1,9 +1,10 @@ +use crate::usimp; use crate::usimp::*; use crate::database; use serde_json::{Value, from_value, to_value}; use serde::{Serialize, Deserialize}; -use std::ops::Deref; +use rand::Rng; #[derive(Serialize, Deserialize, Clone)] struct Input { @@ -13,22 +14,56 @@ struct Input { #[derive(Serialize, Deserialize, Clone)] struct Output { - session: String, + session_id: String, token: String, } -pub async fn handle(input: &InputEnvelope, session: &Session) -> Result { - Ok(to_value(authenticate(from_value(input.data.clone())?).await?)?) +pub async fn handle(input: &InputEnvelope, session: Option) -> Result { + Ok(to_value(authenticate(from_value(input.data.clone())?, session).await?)?) } -async fn authenticate(input: Input) -> Result { - match database::client().await? { +async fn authenticate(input: Input, _session: Option) -> Result { + let backend = database::client().await?; + let token; + let session_id; + match backend { database::Client::Postgres(client) => { - client.execute("SELECT * FROM asdf;", &[]).await?; + let res = client.query( + "SELECT account_id, domain_id \ + FROM accounts \ + WHERE account_name = $1", + &[&input.name] + ).await?; + if res.len() == 0 { + return Err(Error::new(ErrorKind::AuthenticationError, ErrorClass::ClientError, None)); + } + let row = &res[0]; + let account_id: String = row.get(0); + let domain_id: String = row.get(1); + + // TODO password check + if !input.password.eq("MichaelScott") { + return Err(Error::new(ErrorKind::AuthenticationError, ErrorClass::ClientError, None)); + } + + session_id = usimp::get_id(&[domain_id.as_str(), account_id.as_str()]); + token = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(256) + .map(char::from) + .collect(); + + client.execute( + "INSERT INTO sessions (account_id, session_nr, session_id, session_token) \ + VALUES ($1, COALESCE((SELECT MAX(session_nr) + 1 \ + FROM sessions \ + WHERE account_id = $1), 1), $2, $3);", + &[&account_id, &session_id, &token], + ).await?; } } Ok(Output { - session: "".to_string(), - token: "".to_string(), + session_id, + token, }) } diff --git a/src/usimp/handler/mod.rs b/src/usimp/handler/mod.rs index d2f0008..0043586 100644 --- a/src/usimp/handler/mod.rs +++ b/src/usimp/handler/mod.rs @@ -1,18 +1,29 @@ mod ping; mod authenticate; +mod subscribe; +mod new_event; use crate::usimp::*; +use tokio::sync::mpsc; + +pub async fn endpoint(input: &InputEnvelope, tx: Option>) -> Result { + if input.from_domain != None { + // TODO + return Err(Error::new(ErrorKind::NotImplemented, ErrorClass::ServerError, None)); + } + let session; + if let Some(token) = &input.token { + session = Some(Session::from_token(token).await?); + } else { + session = None; + } -pub async fn endpoint(input: &InputEnvelope) -> Result { println!("Endpoint: {}", input.endpoint); - let session= Session { - account: None, - id: "".to_string(), - nr: 0, - }; Ok(match input.endpoint.as_str() { - "ping" => input.respond(ping::handle(&input, &session).await?), - "authenticate" => input.respond(authenticate::handle(&input, &session).await?), + "ping" => input.respond(ping::handle(&input, session).await?), + "authenticate" => input.respond(authenticate::handle(&input, session).await?), + "subscribe" => input.respond(subscribe::handle(&input, session, tx).await?), + "new_event" => input.respond(new_event::handle(&input, session).await?), _ => input.new_error(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, Some("Invalid endpoint".to_string())), }) } diff --git a/src/usimp/handler/new_event.rs b/src/usimp/handler/new_event.rs new file mode 100644 index 0000000..6dd46cf --- /dev/null +++ b/src/usimp/handler/new_event.rs @@ -0,0 +1,28 @@ +use crate::usimp::*; +use crate::usimp::subscription; + +use serde_json::{Value, from_value, to_value}; +use serde::{Serialize, Deserialize}; + +#[derive(Serialize, Deserialize, Clone)] +struct Input { + room_id: String, + events: Vec, +} + +#[derive(Serialize, Deserialize, Clone)] +struct Output { +} + +pub async fn handle(input: &InputEnvelope, session: Option) -> Result { + Ok(to_value(new_event(from_value(input.data.clone())?, session).await?)?) +} + +async fn new_event(input: Input, session: Option) -> Result { + let _account = get_account(&session)?; + // TODO check permissions + for event in input.events { + subscription::push(input.room_id.as_str(), event).await?; + } + Ok(Output {}) +} diff --git a/src/usimp/handler/ping.rs b/src/usimp/handler/ping.rs index 324da36..76c543f 100644 --- a/src/usimp/handler/ping.rs +++ b/src/usimp/handler/ping.rs @@ -2,10 +2,10 @@ use crate::usimp::*; use serde_json::Value; -pub async fn handle(input: &InputEnvelope, session: &Session) -> Result { - ping(&input.data).await +pub async fn handle(input: &InputEnvelope, session: Option) -> Result { + ping(input.data.clone(), session).await } -async fn ping(input: &Value) -> Result { +async fn ping(input: Value, _session: Option) -> Result { Ok(input.clone()) } diff --git a/src/usimp/handler/subscribe.rs b/src/usimp/handler/subscribe.rs new file mode 100644 index 0000000..6f76654 --- /dev/null +++ b/src/usimp/handler/subscribe.rs @@ -0,0 +1,50 @@ +use crate::usimp::*; +use crate::usimp::subscription; + +use serde_json::{Value, from_value, to_value}; +use serde::{Serialize, Deserialize}; +use tokio::sync::mpsc; + +#[derive(Serialize, Deserialize, Clone)] +struct Input { +} + +#[derive(Serialize, Deserialize, Clone)] +struct Output { + event: Option, +} + +pub async fn handle(input: &InputEnvelope, session: Option, tx: Option>) -> Result { + Ok(to_value(subscribe(from_value(input.data.clone())?, session, input.request_nr, tx).await?)?) +} + +async fn subscribe(_input: Input, session: Option, req_nr: Option, tx: Option>) -> Result { + let account = get_account(&session)?; + let mut rx = subscription::subscribe_account(account).await; + match tx { + Some(tx) => { + tokio::spawn(async move { + while let Some(event) = rx.recv().await { + let _res = tx.send(OutputEnvelope { + error: None, + request_nr: req_nr, + data: to_value(event).unwrap(), + }).await; + } + }); + Ok(Output { + event: None, + }) + } + None => { + if let Some(event) = rx.recv().await { + Ok(Output { + event: Some(event), + }) + } else { + Err(Error::new(ErrorKind::SubscriptionError, ErrorClass::ServerError, None)) + } + } + } +} + diff --git a/src/usimp/mod.rs b/src/usimp/mod.rs index 61e0e0f..18f7c25 100644 --- a/src/usimp/mod.rs +++ b/src/usimp/mod.rs @@ -1,10 +1,15 @@ mod handler; +pub mod subscription; pub use handler::endpoint; -use serde_json::Value; use crate::error::{Error, ErrorClass, ErrorKind}; +use crate::database; +use serde_json::Value; use serde::{Serialize, Deserialize}; +use crypto::sha2::Sha256; +use crypto::digest::Digest; +use base64_url; #[derive(Serialize, Deserialize)] pub struct InputEnvelope { @@ -22,14 +27,43 @@ pub struct OutputEnvelope { pub data: Value, } +#[derive(Clone, Serialize, Deserialize)] +pub struct Event { + data: Value, +} + +pub struct Account { + id: String, + name: String, + domain: String, +} + pub struct Session { id: String, nr: i32, account: Option, } -pub struct Account { +pub fn get_id(input: &[&str]) -> String { + let mut hasher = Sha256::new(); + hasher.input_str(chrono::Utc::now().timestamp_millis().to_string().as_str()); + for part in input { + hasher.input_str(" "); + hasher.input_str(part); + } + let mut result = [0u8; 32]; + hasher.result(&mut result); + base64_url::encode(&result) +} +pub fn get_account(session: &Option) -> Result<&Account, Error> { + match session { + Some(session) => match &session.account { + Some(account) => Ok(&account), + None => return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + }, + None => return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + } } impl InputEnvelope { @@ -43,7 +77,32 @@ impl InputEnvelope { } impl Session { - pub async fn from_token(token: &str) -> Self { - todo!("session") + pub async fn from_token(token: &str) -> Result { + let backend = database::client().await?; + let session; + match backend { + database::Client::Postgres(client) => { + let res = client.query( + "SELECT session_id, session_nr, a.account_id, account_name, domain_id \ + FROM accounts a JOIN sessions s ON a.account_id = s.account_id \ + WHERE session_token = $1;", + &[&token] + ).await?; + if res.len() == 0 { + return Err(Error::new(ErrorKind::InvalidSessionError, ErrorClass::ClientError, None)); + } + let row = &res[0]; + session = Session { + id: row.get(0), + nr: row.get(1), + account: Some(Account { + id: row.get(2), + name: row.get(3), + domain: row.get(4), + }), + }; + } + } + Ok(session) } } diff --git a/src/usimp/subscription.rs b/src/usimp/subscription.rs new file mode 100644 index 0000000..3e65e63 --- /dev/null +++ b/src/usimp/subscription.rs @@ -0,0 +1,72 @@ +use crate::usimp::*; +use crate::database; +use tokio::sync::{mpsc, Mutex}; +use std::collections::HashMap; +use std::sync::Arc; + +static mut ROOMS: Option>>>>> = None; +static mut ACCOUNTS: Option>>>>> = None; + +pub fn init() { + unsafe { + ROOMS = Some(Arc::new(Mutex::new(HashMap::new()))); + ACCOUNTS = Some(Arc::new(Mutex::new(HashMap::new()))); + } +} + +pub async fn subscribe_account(account: &Account) -> mpsc::Receiver { + let (tx, rx) = mpsc::channel::(64); + unsafe { + let mut acc = ACCOUNTS.as_ref().unwrap().lock().await; + match acc.get_mut(account.id.as_str()) { + Some(vec) => { + vec.push(tx); + }, + None => { + acc.insert(account.id.clone(), vec!{tx}); + }, + } + } + rx +} + +pub async fn push(room_id: &str, event: Event) -> Result<(), Error> { + let backend = database::client().await?; + let accounts = match backend { + database::Client::Postgres(client) => { + let res = client.query( + "SELECT account_id \ + FROM members \ + WHERE room_id = $1;", + &[&room_id] + ).await?; + let mut acc: Vec = Vec::new(); + for row in res { + acc.push(row.get(0)); + } + acc + } + }; + + unsafe { + let mut rooms = ROOMS.as_ref().unwrap().lock().await; + if let Some(rooms) = rooms.get_mut(room_id) { + for tx in rooms { + let _res = tx.send(event.clone()).await; + } + } + } + + for account in accounts { + unsafe { + let mut accounts = ACCOUNTS.as_ref().unwrap().lock().await; + if let Some(acc) = accounts.get_mut(account.as_str()) { + for tx in acc { + let _res = tx.send(event.clone()).await; + } + } + } + } + + Ok(()) +} diff --git a/src/websocket.rs b/src/websocket.rs index 425f0d6..5b7d6e6 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -1,4 +1,4 @@ -use hyper::{Request, Response, Body, StatusCode, header}; +use hyper::{Request, Body, StatusCode, header}; use crate::usimp::*; use crate::usimp; use crate::error::*; @@ -40,11 +40,11 @@ async fn receiver(mut stream: SplitStream>, tx: mpsc:: match res { Ok(msg) => { let input: InputEnvelope = serde_json::from_slice(&msg.into_data()[..]).unwrap(); - let output = match usimp::endpoint(&input).await { + let output = match usimp::endpoint(&input, Some(tx.clone())).await { Ok(output) => output, Err(error) => input.error(error), }; - tx.send(output).await; + let _res = tx.send(output).await; }, Err(error) => println!("{:?}", error), } @@ -76,7 +76,7 @@ pub async fn handler(req: Request, res: hyper::http::response::Builder) -> Role::Server, None, ).await; - let (tx, rx) = mpsc::channel::(16); + let (tx, rx) = mpsc::channel::(64); let (sink, stream) = ws_stream.split(); tokio::spawn(async move { sender(sink, rx).await