Compare commits

...

3 Commits

Author SHA1 Message Date
62a7cc66bc Allow event ids from clients 2022-08-27 22:33:40 +02:00
d9672bdac0 Add unsubscribe endpoint 2022-08-27 22:13:19 +02:00
97ae71d553 Add action to output envelope 2022-08-27 18:19:29 +02:00
9 changed files with 230 additions and 110 deletions

View File

@@ -41,6 +41,7 @@ impl InputEnvelope {
OutputEnvelope { OutputEnvelope {
request_nr: self.request_nr, request_nr: self.request_nr,
error: Some(Error::new(kind, class, msg)), error: Some(Error::new(kind, class, msg)),
action: None,
data: Value::Null, data: Value::Null,
} }
} }
@@ -49,6 +50,7 @@ impl InputEnvelope {
OutputEnvelope { OutputEnvelope {
request_nr: self.request_nr, request_nr: self.request_nr,
error: Some(error), error: Some(error),
action: None,
data: Value::Null, data: Value::Null,
} }
} }
@@ -85,8 +87,9 @@ impl From<Error> for OutputEnvelope {
fn from(error: Error) -> Self { fn from(error: Error) -> Self {
OutputEnvelope { OutputEnvelope {
error: Some(error), error: Some(error),
data: Value::Null,
request_nr: None, request_nr: None,
action: None,
data: Value::Null,
} }
} }
} }

View File

@@ -3,7 +3,7 @@ use crate::usimp;
use crate::usimp::*; use crate::usimp::*;
use crate::websocket; use crate::websocket;
use hyper::{body, header, Body, Method, Request, Response, StatusCode}; use hyper::{body, header, Body, Method, Request, Response, StatusCode};
use serde_json::{Map, Value}; use serde_json::Value;
use std::str::FromStr; use std::str::FromStr;
use uuid::Uuid; use uuid::Uuid;
@@ -146,32 +146,19 @@ pub async fn handler(mut req: Request<Body>) -> Result<Response<Body>, hyper::Er
}; };
if let Some(output) = output { if let Some(output) = output {
let mut data = Value::Object(Map::new()); if let Some(ref error) = output.error {
res = match error.class {
match output.error { ErrorClass::ClientProtocolError => res.status(StatusCode::BAD_REQUEST),
Some(error) => { ErrorClass::ServerError => {
res = match error.class { res.status(StatusCode::INTERNAL_SERVER_ERROR)
ErrorClass::ClientProtocolError => res.status(StatusCode::BAD_REQUEST), }
ErrorClass::ServerError => { _ => res.status(StatusCode::OK),
res.status(StatusCode::INTERNAL_SERVER_ERROR) };
}
_ => res.status(StatusCode::OK),
};
data["status"] = Value::from("error");
data["error"] = Value::from(error);
}
None => {
data["status"] = Value::from("success");
}
} }
data["request_nr"] = match output.request_nr { let data: Value = output.into();
Some(nr) => Value::from(nr),
None => Value::Null,
};
data["data"] = output.data;
return Ok(res return Ok(res
.body(Body::from(serde_json::to_string(&data).unwrap() + "\r\n")) .body(Body::from(data.to_string() + "\r\n"))
.unwrap()); .unwrap());
} else { } else {
res = res.status(StatusCode::NO_CONTENT); res = res.status(StatusCode::NO_CONTENT);

View File

@@ -2,6 +2,7 @@ mod authenticate;
mod new_event; mod new_event;
mod ping; mod ping;
mod subscribe; mod subscribe;
mod unsubscribe;
use crate::usimp::*; use crate::usimp::*;
use tokio::sync::mpsc; use tokio::sync::mpsc;
@@ -28,10 +29,11 @@ pub async fn endpoint(
println!("Endpoint: {}", input.endpoint); println!("Endpoint: {}", input.endpoint);
Ok(match input.endpoint.as_str() { Ok(match input.endpoint.as_str() {
"ping" => input.respond(ping::handle(&input, session).await?), "ping" => input.respond(ping::handle(&input, session).await?, None),
"authenticate" => input.respond(authenticate::handle(&input, session).await?), "authenticate" => input.respond(authenticate::handle(&input, session).await?, None),
"subscribe" => input.respond(subscribe::handle(&input, session, tx).await?), "subscribe" => input.respond(subscribe::handle(&input, session, tx).await?, Some(OutputAction::Subscribe)),
"new_event" => input.respond(new_event::handle(&input, session).await?), "unsubscribe" => input.respond(unsubscribe::handle(&input, session).await?, Some(OutputAction::Unsubscribe)),
"new_event" => input.respond(new_event::handle(&input, session).await?, None),
_ => input.new_error( _ => input.new_error(
ErrorKind::UsimpError, ErrorKind::UsimpError,
ErrorClass::ClientProtocolError, ErrorClass::ClientProtocolError,

View File

@@ -22,14 +22,17 @@ pub async fn handle(input: &InputEnvelope, session: Option<Session>) -> Result<V
} }
async fn new_event(input: Input, session: Option<Session>) -> Result<Output, Error> { async fn new_event(input: Input, session: Option<Session>) -> Result<Output, Error> {
let _account = get_account(&session)?; let _account = get_account_opt(&session)?;
let mut uuids = vec![]; let mut uuids = vec![];
// TODO check permissions // TODO check permissions
for mut event in input.events { for mut event in input.events {
let uuid = Uuid::new_v4(); let uuid = match event.id {
Some(id) => id,
None => Uuid::new_v4(),
};
event.id = Some(uuid); event.id = Some(uuid);
uuids.push(uuid); uuids.push(uuid);
subscription::push(&input.room_id, event).await?; subscription::push_room(&input.room_id, event).await?;
} }
Ok(Output {events: uuids}) Ok(Output {events: uuids})
} }

View File

@@ -25,8 +25,7 @@ pub async fn handle(
session, session,
input.request_nr, input.request_nr,
tx, tx,
) ).await?,
.await?,
)?) )?)
} }
@@ -36,8 +35,16 @@ async fn subscribe(
req_nr: Option<u64>, req_nr: Option<u64>,
tx: Option<mpsc::Sender<WebSocketEnvelope>>, tx: Option<mpsc::Sender<WebSocketEnvelope>>,
) -> Result<Output, Error> { ) -> Result<Output, Error> {
let account = get_account(&session)?; let session = match session {
let mut rx = subscription::subscribe_account(account).await; Some(s) => s,
None => return Err(Error::new(
ErrorKind::SubscriptionError,
ErrorClass::ClientError,
None,
)),
};
let mut rx = subscription::subscribe_account(&session, req_nr).await?;
match tx { match tx {
Some(tx) => { Some(tx) => {
tokio::spawn(async move { tokio::spawn(async move {
@@ -46,6 +53,7 @@ async fn subscribe(
.send(OutputEnvelope { .send(OutputEnvelope {
error: None, error: None,
request_nr: req_nr, request_nr: req_nr,
action: Some(OutputAction::Push),
data: serde_json::json![{"events": [event]}], data: serde_json::json![{"events": [event]}],
}.into()) }.into())
.await; .await;

View File

@@ -0,0 +1,45 @@
use crate::usimp::subscription;
use crate::usimp::*;
use serde::{Deserialize, Serialize};
use serde_json::{from_value, to_value, Value};
use tokio::sync::mpsc;
use crate::websocket::WebSocketEnvelope;
#[derive(Serialize, Deserialize, Clone)]
struct Input {}
#[derive(Serialize, Deserialize, Clone)]
struct Output {}
pub async fn handle(
input: &InputEnvelope,
session: Option<Session>,
) -> Result<Value, Error> {
Ok(to_value(
unsubscribe(
from_value(input.data.clone())?,
session,
input.request_nr,
).await?,
)?)
}
async fn unsubscribe(
_input: Input,
session: Option<Session>,
req_nr: Option<u64>,
) -> Result<Output, Error> {
let session = match session {
Some(s) => s,
None => return Err(Error::new(
ErrorKind::SubscriptionError,
ErrorClass::ClientError,
None,
)),
};
subscription::unsubscribe_account(&session, req_nr).await?;
Ok(Output {})
}

View File

@@ -9,9 +9,25 @@ use base64_url;
use crypto::digest::Digest; use crypto::digest::Digest;
use crypto::sha2::Sha256; use crypto::sha2::Sha256;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::{Map, Value};
use uuid::Uuid; use uuid::Uuid;
pub enum OutputAction {
Subscribe,
Unsubscribe,
Push,
}
impl From<OutputAction> for Value {
fn from(action: OutputAction) -> Self {
Value::from(match action {
OutputAction::Subscribe => "subscribe",
OutputAction::Unsubscribe => "unsubscribe",
OutputAction::Push => "push",
})
}
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct InputEnvelope { pub struct InputEnvelope {
pub endpoint: String, pub endpoint: String,
@@ -25,9 +41,40 @@ pub struct InputEnvelope {
pub struct OutputEnvelope { pub struct OutputEnvelope {
pub error: Option<Error>, pub error: Option<Error>,
pub request_nr: Option<u64>, pub request_nr: Option<u64>,
pub action: Option<OutputAction>,
pub data: Value, pub data: Value,
} }
impl From<OutputEnvelope> for Value {
fn from(msg: OutputEnvelope) -> Self {
let mut envelope = Value::Object(Map::new());
envelope["request_nr"] = match msg.request_nr {
Some(nr) => Value::from(nr),
None => Value::Null,
};
match msg.error {
Some(error) => {
envelope["status"] = Value::from("error");
envelope["error"] = Value::from(error);
}
None => {
envelope["status"] = Value::from("success");
}
}
envelope["action"] = match msg.action {
Some(a) => Value::from(a),
None => Value::Null,
};
envelope["data"] = msg.data;
envelope
}
}
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub struct Event { pub struct Event {
data: Value, data: Value,
@@ -46,18 +93,22 @@ pub struct Session {
account: Option<Account>, account: Option<Account>,
} }
pub fn get_account(session: &Option<Session>) -> Result<&Account, Error> { pub fn get_account_opt(session: &Option<Session>) -> Result<&Account, Error> {
match session { match session {
Some(session) => match &session.account { Some(session) => get_account(session),
Some(account) => Ok(&account), None => {
None => { return Err(Error::new(
return Err(Error::new( ErrorKind::UsimpError,
ErrorKind::UsimpError, ErrorClass::ClientProtocolError,
ErrorClass::ClientProtocolError, None,
None, ))
)) }
} }
}, }
pub fn get_account(session: &Session) -> Result<&Account, Error> {
match &session.account {
Some(account) => Ok(&account),
None => { None => {
return Err(Error::new( return Err(Error::new(
ErrorKind::UsimpError, ErrorKind::UsimpError,
@@ -69,10 +120,11 @@ pub fn get_account(session: &Option<Session>) -> Result<&Account, Error> {
} }
impl InputEnvelope { impl InputEnvelope {
pub fn respond(&self, data: Value) -> OutputEnvelope { pub fn respond(&self, data: Value, action: Option<OutputAction>) -> OutputEnvelope {
OutputEnvelope { OutputEnvelope {
error: None, error: None,
request_nr: self.request_nr, request_nr: self.request_nr,
action,
data, data,
} }
} }

View File

@@ -1,36 +1,80 @@
use crate::database; use crate::database;
use crate::usimp::*; use crate::usimp::*;
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::ops::Deref;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{mpsc, Mutex};
use tokio::sync::mpsc::Sender;
static mut ROOMS: Option<Arc<Mutex<HashMap<Uuid, Vec<mpsc::Sender<Event>>>>>> = None; #[derive(Clone, Eq, Hash, PartialEq)]
static mut ACCOUNTS: Option<Arc<Mutex<HashMap<Uuid, Vec<mpsc::Sender<Event>>>>>> = None; struct Subscription {
session: Uuid,
req_nr: Option<u64>,
}
static mut SUBSCRIPTIONS: Option<Arc<Mutex<HashMap<Subscription, mpsc::Sender<Event>>>>> = None;
static mut ACCOUNTS: Option<Arc<Mutex<HashMap<Uuid, HashSet<Subscription>>>>> = None;
pub fn init() { pub fn init() {
unsafe { unsafe {
ROOMS = Some(Arc::new(Mutex::new(HashMap::new()))); SUBSCRIPTIONS = Some(Arc::new(Mutex::new(HashMap::new())));
ACCOUNTS = Some(Arc::new(Mutex::new(HashMap::new()))); ACCOUNTS = Some(Arc::new(Mutex::new(HashMap::new())));
} }
} }
pub async fn subscribe_account(account: &Account) -> mpsc::Receiver<Event> { pub async fn subscribe_account(session: &Session, req_nr: Option<u64>) -> Result<mpsc::Receiver<Event>, Error> {
let account = get_account(session)?;
let sub = Subscription {session: session.id, req_nr};
let (tx, rx) = mpsc::channel::<Event>(64); let (tx, rx) = mpsc::channel::<Event>(64);
unsafe { unsafe {
let mut acc = ACCOUNTS.as_ref().unwrap().lock().await; let mut subs = SUBSCRIPTIONS.as_ref().unwrap().lock().await;
match acc.get_mut(&account.id) { let mut accs = ACCOUNTS.as_ref().unwrap().lock().await;
Some(vec) => { match accs.get_mut(&account.id) {
vec.push(tx); Some(set) => {
set.insert(sub.clone());
} }
None => { None => {
acc.insert(account.id, vec![tx]); let mut set = HashSet::new();
set.insert(sub.clone());
accs.insert(account.id, set);
}
}
match subs.get_mut(&sub) {
Some(_) => return Err(Error::new(
ErrorKind::SubscriptionError,
ErrorClass::ClientError,
None
)),
None => {
subs.insert(sub.clone(), tx);
} }
} }
} }
rx
Ok(rx)
} }
pub async fn push(room_id: &Uuid, event: Event) -> Result<(), Error> { pub async fn unsubscribe_account(session: &Session, req_nr: Option<u64>) -> Result<(), Error> {
let account = get_account(session)?;
let sub = Subscription {session: session.id, req_nr};
unsafe {
let mut subs = SUBSCRIPTIONS.as_ref().unwrap().lock().await;
let mut accs = ACCOUNTS.as_ref().unwrap().lock().await;
match accs.get_mut(&account.id) {
Some(set) => {
set.remove(&sub);
}
None => {},
}
subs.remove(&sub);
}
Ok(())
}
pub async fn push_room(room_id: &Uuid, event: Event) -> Result<(), Error> {
let backend = database::client().await?; let backend = database::client().await?;
let accounts = match backend { let accounts = match backend {
database::Client::Postgres(client) => { database::Client::Postgres(client) => {
@@ -50,26 +94,26 @@ pub async fn push(room_id: &Uuid, event: Event) -> Result<(), Error> {
} }
}; };
let mut rooms = unsafe { let mut room: Vec<mpsc::Sender<Event>> = Vec::new();
let mut rooms = ROOMS.as_ref().unwrap().lock().await;
if let Some(rooms) = rooms.get_mut(room_id) {
rooms.clone()
} else {
Vec::new()
}
};
for account in accounts { for account in accounts {
unsafe { unsafe {
let mut accounts = ACCOUNTS.as_ref().unwrap().lock().await; let subs = SUBSCRIPTIONS.as_ref().unwrap().lock().await;
if let Some(acc) = accounts.get_mut(&account) { let accs = ACCOUNTS.as_ref().unwrap().lock().await;
let mut acc = acc.clone(); if let Some(acc_subs) = accs.get(&account) {
rooms.append(&mut acc); for sub in acc_subs {
match subs.get(sub) {
None => {}
Some(tx) => {
room.push(tx.clone());
}
}
}
} }
} }
} }
for tx in rooms { for tx in room {
let _res = tx.send(event.clone()).await; let _res = tx.send(event.clone()).await;
} }

View File

@@ -8,7 +8,7 @@ use hyper::{header, Body, Request, StatusCode};
use hyper_tungstenite::hyper::upgrade::Upgraded; use hyper_tungstenite::hyper::upgrade::Upgraded;
use hyper_tungstenite::tungstenite::{handshake, Message}; use hyper_tungstenite::tungstenite::{handshake, Message};
use hyper_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use hyper_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use serde_json::{Map, Value}; use serde_json::Value;
use tokio::sync::mpsc; use tokio::sync::mpsc;
pub enum WebSocketEnvelope { pub enum WebSocketEnvelope {
@@ -50,21 +50,7 @@ async fn sender(
break; break;
} }
WebSocketEnvelope::Text(msg) => { WebSocketEnvelope::Text(msg) => {
let mut envelope = Value::Object(Map::new()); let envelope: Value = msg.into();
envelope["data"] = msg.data;
envelope["request_nr"] = match msg.request_nr {
Some(nr) => Value::from(nr),
None => Value::Null,
};
match msg.error {
Some(error) => {
envelope["status"] = Value::from("error");
envelope["error"] = Value::from(error);
}
None => {
envelope["status"] = Value::from("success");
}
}
if let Err(error) = sink.send(Message::Text(envelope.to_string())).await { if let Err(error) = sink.send(Message::Text(envelope.to_string())).await {
eprintln!("{:?}", error); eprintln!("{:?}", error);
return; return;
@@ -90,16 +76,11 @@ async fn receiver(
_res = tx.send(WebSocketEnvelope::Close).await; _res = tx.send(WebSocketEnvelope::Close).await;
break; break;
} else if msg.is_binary() { } else if msg.is_binary() {
_res = tx.send(WebSocketEnvelope::Text(OutputEnvelope { _res = tx.send(WebSocketEnvelope::Text(OutputEnvelope::from(Error::new(
error: Some(Error { ErrorKind::WebSocketError,
kind: ErrorKind::WebSocketError, ErrorClass::ClientProtocolError,
class: ErrorClass::ClientProtocolError, Some("Binary frames are not allowed".to_string())
msg: Some("Binary frames are not allowed".to_string()), )))).await;
desc: None,
}),
request_nr: None,
data: Value::Null,
})).await;
} else if msg.is_text() { } else if msg.is_text() {
let input: InputEnvelope = serde_json::from_slice(&msg.into_data()[..]).unwrap(); let input: InputEnvelope = serde_json::from_slice(&msg.into_data()[..]).unwrap();
let output = match usimp::endpoint(&input, Some(tx.clone())).await { let output = match usimp::endpoint(&input, Some(tx.clone())).await {
@@ -108,16 +89,11 @@ async fn receiver(
}; };
_res = tx.send(WebSocketEnvelope::Text(output)).await; _res = tx.send(WebSocketEnvelope::Text(output)).await;
} else { } else {
_res = tx.send(WebSocketEnvelope::Text(OutputEnvelope { _res = tx.send(WebSocketEnvelope::Text(OutputEnvelope::from(Error::new(
error: Some(Error { ErrorKind::WebSocketError,
kind: ErrorKind::WebSocketError, ErrorClass::ClientProtocolError,
class: ErrorClass::ClientProtocolError, Some("Unknown frame opcode".to_string())
msg: Some("Unknown frame opcode".to_string()), )))).await;
desc: None,
}),
request_nr: None,
data: Value::Null,
})).await;
} }
} }
Err(error) => println!("{:?}", error), Err(error) => println!("{:?}", error),