diff --git a/Cargo.toml b/Cargo.toml index 9cbc32b..4cbef2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,4 @@ ansi_term = "0.12" rust-crypto = "^0.2" base64 = "0.13.0" base64-url = "1.4.10" +rand = "0.8.3" diff --git a/db/00.create.sql b/db/00.create.sql index 23833f2..d5428d0 100644 --- a/db/00.create.sql +++ b/db/00.create.sql @@ -89,6 +89,7 @@ CREATE TABLE members CREATE TABLE sessions ( + session_id CHAR(43) NOT NULL, account_id CHAR(43) NOT NULL, session_nr INTEGER NOT NULL DEFAULT 1, session_token VARCHAR(256) NOT NULL, @@ -100,8 +101,9 @@ CREATE TABLE sessions last_used TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT (now() at time zone 'utc'), last_used_tz INTEGER DEFAULT NULL, - CONSTRAINT pk_sessions PRIMARY KEY (account_id, session_nr), - CONSTRAINT sk_sessions UNIQUE (session_token), + CONSTRAINT pk_sessions PRIMARY KEY (session_id), + CONSTRAINT sk_sessions_1 UNIQUE (account_id, session_nr), + CONSTRAINT sk_sessions_2 UNIQUE (session_token), CONSTRAINT fk_sessions_accounts FOREIGN KEY (account_id) REFERENCES accounts (account_id) ); diff --git a/src/error.rs b/src/error.rs index 46cb62f..218a7b9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,6 +13,7 @@ pub enum Kind { UsimpProtocolError, Utf8DecodeError, AuthenticationError, + InvalidSessionError, } #[derive(Copy, Clone, Debug)] @@ -64,6 +65,7 @@ impl Error { Kind::UsimpProtocolError => "USIMP protocol error", Kind::Utf8DecodeError => "Unable to decode UTF-8 data", Kind::AuthenticationError => "Unable to authenticate", + Kind::InvalidSessionError => "Invalid session", }, } } @@ -95,6 +97,7 @@ impl fmt::Display for Error { Kind::UsimpProtocolError => "usimp protocol error", Kind::Utf8DecodeError => "unable to decode utf-8 data", Kind::AuthenticationError => "unable to authenticate", + Kind::InvalidSessionError => "invalid session", } .to_string(); if let Some(desc) = &self.desc { diff --git a/src/http/handler.rs b/src/http/handler.rs index 3c88f8d..8b9bba8 100644 --- a/src/http/handler.rs +++ b/src/http/handler.rs @@ -197,16 +197,9 @@ fn endpoint_handler( authorization = Some(auth.split(" ").skip(1).collect()); } - let mut from_domain; + let mut from_domain = None; if let Some(from) = req.header.find_field("From-Domain") { - from_domain = from.to_string(); - } else { - return error_handler( - client, - res, - Error::new(Kind::UsimpProtocolError, Class::ClientProtocolError) - .set_desc("Unable to find field 'From-Domain'".to_string()) - ); + from_domain = Some(from.to_string()); } let mut to_domain; @@ -225,7 +218,7 @@ fn endpoint_handler( endpoint: endpoint.to_string(), from_domain, to_domain, - authorization, + token: authorization, data, }; let buf = match usimp::endpoint(input) { diff --git a/src/usimp/mod.rs b/src/usimp/mod.rs index e4b2752..60e34d8 100644 --- a/src/usimp/mod.rs +++ b/src/usimp/mod.rs @@ -5,23 +5,79 @@ use crate::subscription; use crate::database; use crate::error::*; use crypto::digest::Digest; +use rand; +use rand::Rng; pub struct Envelope { pub endpoint: String, - pub from_domain: String, + pub from_domain: Option, pub to_domain: String, - pub authorization: Option, + pub token: Option, pub data: serde_json::Value, } +pub struct Account { + id: String, + name: String, + domain: String, +} + +pub struct Session { + id: String, + nr: i32, + account: Option, +} + +impl Session { + pub fn from_token(token: &str) -> Result { + let backend = database::client()?; + + let mut session; + match backend { + database::Client::Postgres(mut client) => { + let res = client.query( + "SELECT session_id, session_nr, a.account_id, account_name, domain_id \ + FROM accounts a JOIN sessions s ON a.account_id = s.account_id \ + WHERE session_token = $1", + &[&token] + )?; + + if res.len() == 0 { + return Err(Error::new(Kind::InvalidSessionError, Class::ClientError)); + } + + session = Session { + id: res[0].get(0), + nr: res[0].get(1), + account: Some(Account { + id: res[0].get(2), + name: res[0].get(3), + domain: res[0].get(4), + }), + }; + } + } + + Ok(session) + } +} + pub fn endpoint(envelope: Envelope) -> Result { - // TODO check authorization - // TODO check from/to domain + if envelope.from_domain != None { + // TODO + return Err(Error::new(Kind::NotImplementedError, Class::ServerError)); + } + + let mut session = None; + if let Some(token) = &envelope.token { + session = Some(Session::from_token(token)?); + } + let out = match envelope.endpoint.as_str() { - "echo" => serde_json::to_value(echo(serde_json::from_value(envelope.data)?)?)?, - "authenticate" => serde_json::to_value(authenticate(serde_json::from_value(envelope.data)?)?)?, - "subscribe" => serde_json::to_value(subscribe(serde_json::from_value(envelope.data)?)?)?, - "send_event" => serde_json::to_value(send_event(serde_json::from_value(envelope.data)?)?)?, + "echo" => serde_json::to_value(echo(session, serde_json::from_value(envelope.data)?)?)?, + "authenticate" => serde_json::to_value(authenticate(session, serde_json::from_value(envelope.data)?)?)?, + "subscribe" => serde_json::to_value(subscribe(session, serde_json::from_value(envelope.data)?)?)?, + "send_event" => serde_json::to_value(send_event(session, serde_json::from_value(envelope.data)?)?)?, _ => return Err(Error::new(Kind::InvalidEndpointError, Class::ClientProtocolError)), }; @@ -55,7 +111,7 @@ pub struct EchoOutput { database: Option, } -pub fn echo(input: EchoInput) -> Result { +pub fn echo(session: Option, input: EchoInput) -> Result { let backend = database::client()?; let mut output = EchoOutput { message: input.message, @@ -84,17 +140,43 @@ pub struct AuthenticateOutput { token: String, } -pub fn authenticate(input: AuthenticateInput) -> Result { +pub fn authenticate(session: Option, input: AuthenticateInput) -> Result { let backend = database::client()?; - let mut token; + let mut token: String; match backend { database::Client::Postgres(mut client) => { - let res = client.query("SELECT account_id FROM accounts WHERE account_name = $1", &[&input.name])?; + let res = client.query( + "SELECT account_id FROM accounts WHERE account_name = $1", + &[&input.name] + )?; if res.len() == 0 { return Err(Error::new(Kind::AuthenticationError, Class::ClientError)); } - token = res.get(0).unwrap().get(0); + let account_id: String = res[0].get(0); + + // TODO password check + if !input.password.eq("MichaelScott") { + return Err(Error::new(Kind::AuthenticationError, Class::ClientError)); + } + + token = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(256) + .map(char::from) + .collect(); + + let session_id: String = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(43) + .map(char::from) + .collect(); + + client.execute( + "INSERT INTO sessions (account_id, session_nr, session_id, session_token) \ + VALUES ($1, COALESCE((SELECT MAX(session_nr) + 1 FROM sessions WHERE account_id = $1), 1), $2, $3)", + &[&account_id, &session_id, &token], + )?; } } @@ -112,14 +194,25 @@ pub struct SendEventOutput { event_id: String, } -pub fn send_event(input: SendEventInput) -> Result { +pub fn send_event(session: Option, input: SendEventInput) -> Result { let backend = database::client()?; let event_id = get_id("hermann"); // TODO fix id generation let data = serde_json::to_string(&input.data)?; + let session = session.unwrap(); match backend { database::Client::Postgres(mut client) => { - client.execute("INSERT INTO events (event_id, room_id, data) VALUES ($1, $2, to_jsonb($3::text))", &[&event_id, &input.room_id, &data])?; + + let res = client.query( + "SELECT member_id FROM members \ + WHERE (room_id, account_id) = ($1, $2)", + &[&input.room_id, &session.account.unwrap().id])?; + let member_id: String = res[0].get(0); + + client.execute( + "INSERT INTO events (event_id, room_id, from_member_id, from_session_id, data) \ + VALUES ($1, $2, $3, $4, to_jsonb($5::text))", + &[&event_id, &input.room_id, &member_id, &session.id, &data])?; } } @@ -140,7 +233,7 @@ pub struct SubscribeOutput { event: subscription::Event, } -pub fn subscribe(input: SubscribeInput) -> Result { +pub fn subscribe(session: Option, input: SubscribeInput) -> Result { let rx = subscription::subscribe(); let event = rx.recv().unwrap(); subscription::unsubscribe(rx);