From d9672bdac0aa0288bddd8fbbfc72e0ac16fbd9b2 Mon Sep 17 00:00:00 2001 From: Lorenz Stechauner Date: Sat, 27 Aug 2022 22:13:19 +0200 Subject: [PATCH] Add unsubscribe endpoint --- src/usimp/handler/mod.rs | 2 + src/usimp/handler/new_event.rs | 4 +- src/usimp/handler/subscribe.rs | 12 +++- src/usimp/handler/unsubscribe.rs | 45 +++++++++++++++ src/usimp/mod.rs | 26 +++++---- src/usimp/subscription.rs | 94 +++++++++++++++++++++++--------- 6 files changed, 143 insertions(+), 40 deletions(-) create mode 100644 src/usimp/handler/unsubscribe.rs diff --git a/src/usimp/handler/mod.rs b/src/usimp/handler/mod.rs index 2f67148..fb423e5 100644 --- a/src/usimp/handler/mod.rs +++ b/src/usimp/handler/mod.rs @@ -2,6 +2,7 @@ mod authenticate; mod new_event; mod ping; mod subscribe; +mod unsubscribe; use crate::usimp::*; use tokio::sync::mpsc; @@ -31,6 +32,7 @@ pub async fn endpoint( "ping" => input.respond(ping::handle(&input, session).await?, None), "authenticate" => input.respond(authenticate::handle(&input, session).await?, None), "subscribe" => input.respond(subscribe::handle(&input, session, tx).await?, Some(OutputAction::Subscribe)), + "unsubscribe" => input.respond(unsubscribe::handle(&input, session).await?, Some(OutputAction::Unsubscribe)), "new_event" => input.respond(new_event::handle(&input, session).await?, None), _ => input.new_error( ErrorKind::UsimpError, diff --git a/src/usimp/handler/new_event.rs b/src/usimp/handler/new_event.rs index 4c8d4ae..175458f 100644 --- a/src/usimp/handler/new_event.rs +++ b/src/usimp/handler/new_event.rs @@ -22,14 +22,14 @@ pub async fn handle(input: &InputEnvelope, session: Option) -> Result) -> Result { - let _account = get_account(&session)?; + let _account = get_account_opt(&session)?; let mut uuids = vec![]; // TODO check permissions for mut event in input.events { let uuid = Uuid::new_v4(); event.id = Some(uuid); uuids.push(uuid); - subscription::push(&input.room_id, event).await?; + subscription::push_room(&input.room_id, event).await?; } Ok(Output {events: uuids}) } diff --git a/src/usimp/handler/subscribe.rs b/src/usimp/handler/subscribe.rs index 159d3dc..4f43c6d 100644 --- a/src/usimp/handler/subscribe.rs +++ b/src/usimp/handler/subscribe.rs @@ -35,8 +35,16 @@ async fn subscribe( req_nr: Option, tx: Option>, ) -> Result { - let account = get_account(&session)?; - let mut rx = subscription::subscribe_account(account).await; + let session = match session { + Some(s) => s, + None => return Err(Error::new( + ErrorKind::SubscriptionError, + ErrorClass::ClientError, + None, + )), + }; + + let mut rx = subscription::subscribe_account(&session, req_nr).await?; match tx { Some(tx) => { tokio::spawn(async move { diff --git a/src/usimp/handler/unsubscribe.rs b/src/usimp/handler/unsubscribe.rs new file mode 100644 index 0000000..f31fdf6 --- /dev/null +++ b/src/usimp/handler/unsubscribe.rs @@ -0,0 +1,45 @@ +use crate::usimp::subscription; +use crate::usimp::*; + +use serde::{Deserialize, Serialize}; +use serde_json::{from_value, to_value, Value}; +use tokio::sync::mpsc; +use crate::websocket::WebSocketEnvelope; + +#[derive(Serialize, Deserialize, Clone)] +struct Input {} + +#[derive(Serialize, Deserialize, Clone)] +struct Output {} + +pub async fn handle( + input: &InputEnvelope, + session: Option, +) -> Result { + Ok(to_value( + unsubscribe( + from_value(input.data.clone())?, + session, + input.request_nr, + ).await?, + )?) +} + +async fn unsubscribe( + _input: Input, + session: Option, + req_nr: Option, +) -> Result { + let session = match session { + Some(s) => s, + None => return Err(Error::new( + ErrorKind::SubscriptionError, + ErrorClass::ClientError, + None, + )), + }; + + subscription::unsubscribe_account(&session, req_nr).await?; + + Ok(Output {}) +} diff --git a/src/usimp/mod.rs b/src/usimp/mod.rs index c2f4270..9e67ab9 100644 --- a/src/usimp/mod.rs +++ b/src/usimp/mod.rs @@ -93,18 +93,22 @@ pub struct Session { account: Option, } -pub fn get_account(session: &Option) -> Result<&Account, Error> { +pub fn get_account_opt(session: &Option) -> Result<&Account, Error> { match session { - Some(session) => match &session.account { - Some(account) => Ok(&account), - None => { - return Err(Error::new( - ErrorKind::UsimpError, - ErrorClass::ClientProtocolError, - None, - )) - } - }, + Some(session) => get_account(session), + None => { + return Err(Error::new( + ErrorKind::UsimpError, + ErrorClass::ClientProtocolError, + None, + )) + } + } +} + +pub fn get_account(session: &Session) -> Result<&Account, Error> { + match &session.account { + Some(account) => Ok(&account), None => { return Err(Error::new( ErrorKind::UsimpError, diff --git a/src/usimp/subscription.rs b/src/usimp/subscription.rs index 08b1458..5c573d5 100644 --- a/src/usimp/subscription.rs +++ b/src/usimp/subscription.rs @@ -1,36 +1,80 @@ use crate::database; use crate::usimp::*; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; +use tokio::sync::mpsc::Sender; -static mut ROOMS: Option>>>>> = None; -static mut ACCOUNTS: Option>>>>> = None; +#[derive(Clone, Eq, Hash, PartialEq)] +struct Subscription { + session: Uuid, + req_nr: Option, +} + +static mut SUBSCRIPTIONS: Option>>>> = None; +static mut ACCOUNTS: Option>>>> = None; pub fn init() { unsafe { - ROOMS = Some(Arc::new(Mutex::new(HashMap::new()))); + SUBSCRIPTIONS = Some(Arc::new(Mutex::new(HashMap::new()))); ACCOUNTS = Some(Arc::new(Mutex::new(HashMap::new()))); } } -pub async fn subscribe_account(account: &Account) -> mpsc::Receiver { +pub async fn subscribe_account(session: &Session, req_nr: Option) -> Result, Error> { + let account = get_account(session)?; + let sub = Subscription {session: session.id, req_nr}; + let (tx, rx) = mpsc::channel::(64); unsafe { - let mut acc = ACCOUNTS.as_ref().unwrap().lock().await; - match acc.get_mut(&account.id) { - Some(vec) => { - vec.push(tx); + let mut subs = SUBSCRIPTIONS.as_ref().unwrap().lock().await; + let mut accs = ACCOUNTS.as_ref().unwrap().lock().await; + match accs.get_mut(&account.id) { + Some(set) => { + set.insert(sub.clone()); } None => { - acc.insert(account.id, vec![tx]); + let mut set = HashSet::new(); + set.insert(sub.clone()); + accs.insert(account.id, set); + } + } + match subs.get_mut(&sub) { + Some(_) => return Err(Error::new( + ErrorKind::SubscriptionError, + ErrorClass::ClientError, + None + )), + None => { + subs.insert(sub.clone(), tx); } } } - rx + + Ok(rx) } -pub async fn push(room_id: &Uuid, event: Event) -> Result<(), Error> { +pub async fn unsubscribe_account(session: &Session, req_nr: Option) -> Result<(), Error> { + let account = get_account(session)?; + let sub = Subscription {session: session.id, req_nr}; + + unsafe { + let mut subs = SUBSCRIPTIONS.as_ref().unwrap().lock().await; + let mut accs = ACCOUNTS.as_ref().unwrap().lock().await; + match accs.get_mut(&account.id) { + Some(set) => { + set.remove(&sub); + } + None => {}, + } + subs.remove(&sub); + } + + Ok(()) +} + +pub async fn push_room(room_id: &Uuid, event: Event) -> Result<(), Error> { let backend = database::client().await?; let accounts = match backend { database::Client::Postgres(client) => { @@ -50,26 +94,26 @@ pub async fn push(room_id: &Uuid, event: Event) -> Result<(), Error> { } }; - let mut rooms = unsafe { - let mut rooms = ROOMS.as_ref().unwrap().lock().await; - if let Some(rooms) = rooms.get_mut(room_id) { - rooms.clone() - } else { - Vec::new() - } - }; + let mut room: Vec> = Vec::new(); for account in accounts { unsafe { - let mut accounts = ACCOUNTS.as_ref().unwrap().lock().await; - if let Some(acc) = accounts.get_mut(&account) { - let mut acc = acc.clone(); - rooms.append(&mut acc); + let subs = SUBSCRIPTIONS.as_ref().unwrap().lock().await; + let accs = ACCOUNTS.as_ref().unwrap().lock().await; + if let Some(acc_subs) = accs.get(&account) { + for sub in acc_subs { + match subs.get(sub) { + None => {} + Some(tx) => { + room.push(tx.clone()); + } + } + } } } } - for tx in rooms { + for tx in room { let _res = tx.send(event.clone()).await; }