diff --git a/src/http/handler.rs b/src/http/handler.rs index 3f4a692..61647be 100644 --- a/src/http/handler.rs +++ b/src/http/handler.rs @@ -45,6 +45,9 @@ fn request_handler(client: &mut super::HttpStream) { } else if req.uri.starts_with("/_usimp/") { res.header.add_field("Cache-Control", "no-store"); res.header.add_field("Access-Control-Allow-Origin", "*"); + res.header.add_field("Access-Control-Allow-Methods", "POST, OPTIONS"); + res.header.add_field("Access-Control-Allow-Headers", "Content-Type, From-Domain, To-Domain, Authorization"); + res.header.add_field("Access-Control-Max-Age", "3600"); if req.uri.eq("/_usimp/websocket") { return websocket::connection_handler(client, &req, res); @@ -61,7 +64,14 @@ fn request_handler(client: &mut super::HttpStream) { error = Some(Error::new(Kind::NotImplementedError, Class::ServerError)) }, [endpoint] => match req.method { - Method::POST => return endpoint_handler(client, &req, res, endpoint), + Method::POST => { + return endpoint_handler(client, &req, res, endpoint) + }, + Method::OPTIONS => { + res.status(204); + client.respond(&mut res); + return + } _ => { res.status(405); res.header.add_field("Allow", "POST"); @@ -177,6 +187,7 @@ fn endpoint_handler( let mut authorization = None; if let Some(auth) = req.header.find_field("Authorization") { + // TODO check usimp prefix in Authorization authorization = Some(auth.split(" ").skip(1).collect()); } diff --git a/src/main.rs b/src/main.rs index b5db678..9fa6533 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,5 @@ use std::net::{SocketAddr, TcpListener, UdpSocket}; -use std::sync::Arc; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; use std::thread; use ansi_term::Color; @@ -12,6 +11,7 @@ use std::time::Duration; mod database; mod error; mod http; +mod subscription; mod udp; mod usimp; mod websocket; @@ -59,6 +59,8 @@ fn main() { }, ]; + subscription::init(); + // Note: rust's stdout is line buffered! eprint!("Initializing database connection pool..."); if let Err(error) = database::init() { diff --git a/src/subscription.rs b/src/subscription.rs new file mode 100644 index 0000000..aabb951 --- /dev/null +++ b/src/subscription.rs @@ -0,0 +1,34 @@ +use std::sync::{Arc, Mutex, mpsc}; +use serde::{Deserialize, Serialize}; +use serde_json; + +static mut SUBSCRIPTIONS: Option>>>> = None; + +#[derive(Clone, Serialize, Deserialize)] +pub struct Event{ + pub data: serde_json::Value, +} + +pub fn init() { + unsafe { + SUBSCRIPTIONS = Some(Arc::new(Mutex::new(Vec::new()))); + } +} + +pub fn subscribe() -> mpsc::Receiver { + let (rx, tx) = mpsc::channel(); + unsafe { SUBSCRIPTIONS.as_ref().unwrap().lock().unwrap().push(rx); } + tx +} + +pub fn unsubscribe(rx: mpsc::Receiver) { + // TODO implement unsubscribe +} + +pub fn notify(event: Event) { + for sender in unsafe { SUBSCRIPTIONS.as_ref().unwrap().lock().unwrap().clone() } { + sender.send(event.clone()); + } +} + + diff --git a/src/usimp/mod.rs b/src/usimp/mod.rs index d7f4b0e..e0eebcc 100644 --- a/src/usimp/mod.rs +++ b/src/usimp/mod.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use serde_json; +use crate::subscription; use crate::database; use crate::error::*; use crypto::digest::Digest; @@ -18,8 +19,9 @@ pub fn endpoint(envelope: Envelope) -> Result { // TODO domain_check match envelope.endpoint.as_str() { "echo" => Ok(serde_json::to_value(echo(serde_json::from_value(envelope.data)?)?)?), - "authorize" => Ok(serde_json::to_value(authorize(serde_json::from_value(envelope.data)?)?)?), - "notify" => Ok(serde_json::to_value(notify(serde_json::from_value(envelope.data)?)?)?), + "authenticate" => Ok(serde_json::to_value(authenticate(serde_json::from_value(envelope.data)?)?)?), + "subscribe" => Ok(serde_json::to_value(subscribe(serde_json::from_value(envelope.data)?)?)?), + "send_event" => Ok(serde_json::to_value(send_event(serde_json::from_value(envelope.data)?)?)?), _ => return Err(Error::new(Kind::InvalidEndpointError, Class::ClientError)), } } @@ -64,24 +66,24 @@ pub fn echo(input: EchoInput) -> Result { } #[derive(Serialize, Deserialize)] -pub struct AuthorizeInput { +pub struct AuthenticateInput { r#type: String, name: String, password: String, } #[derive(Serialize, Deserialize)] -pub struct AuthorizeOutput { +pub struct AuthenticateOutput { token: String, } -pub fn authorize(input: AuthorizeInput) -> Result { +pub fn authenticate(input: AuthenticateInput) -> Result { let backend = database::client()?; let mut token; match backend { database::Client::Postgres(mut client) => { - let res = client.query("SELECT account_id FROM accounts WHERE name = $1", &[&input.name])?; + let res = client.query("SELECT account_id FROM accounts WHERE account_name = $1", &[&input.name])?; if res.len() == 0 { return Err(Error::new(Kind::AuthenticationError, Class::ClientError)); } @@ -89,7 +91,7 @@ pub fn authorize(input: AuthorizeInput) -> Result { } } - Ok(AuthorizeOutput { token }) + Ok(AuthenticateOutput { token }) } #[derive(Serialize, Deserialize)] @@ -110,23 +112,30 @@ pub fn send_event(input: SendEventInput) -> Result { match backend { database::Client::Postgres(mut client) => { - client.execute("INSERT INTO events (event_id, room_id, data) VALUES ($1, $2, $3)", &[&event_id, &input.room_id, &data])?; + client.execute("INSERT INTO events (event_id, room_id, data) VALUES ($1, $2, to_jsonb($3::text))", &[&event_id, &input.room_id, &data])?; } } + subscription::notify(subscription::Event { + data: input.data + }); + Ok(SendEventOutput { event_id }) } #[derive(Serialize, Deserialize)] -pub struct NotifyInput { +pub struct SubscribeInput { } #[derive(Serialize, Deserialize)] -pub struct NotifyOutput { - +pub struct SubscribeOutput { + event: subscription::Event, } -pub fn notify(input: NotifyInput) -> Result { - Ok(NotifyOutput {}) +pub fn subscribe(input: SubscribeInput) -> Result { + let rx = subscription::subscribe(); + let event = rx.recv().unwrap(); + subscription::unsubscribe(rx); + Ok(SubscribeOutput { event }) }