From d94daa2c9376f5727fc484548fa8f2c4b7b527bd Mon Sep 17 00:00:00 2001 From: Lorenz Stechauner Date: Sat, 27 Aug 2022 17:39:19 +0200 Subject: [PATCH] Fix WebSocket connection closes --- src/usimp/handler/mod.rs | 3 +- src/usimp/handler/subscribe.rs | 7 +- src/websocket.rs | 133 +++++++++++++++++++++++++-------- 3 files changed, 107 insertions(+), 36 deletions(-) diff --git a/src/usimp/handler/mod.rs b/src/usimp/handler/mod.rs index 0c98485..33dbfa1 100644 --- a/src/usimp/handler/mod.rs +++ b/src/usimp/handler/mod.rs @@ -5,10 +5,11 @@ mod subscribe; use crate::usimp::*; use tokio::sync::mpsc; +use crate::websocket::WebSocketEnvelope; pub async fn endpoint( input: &InputEnvelope, - tx: Option>, + tx: Option>, ) -> Result { if input.from_domain != None { // TODO diff --git a/src/usimp/handler/subscribe.rs b/src/usimp/handler/subscribe.rs index 3a4cc07..2927dac 100644 --- a/src/usimp/handler/subscribe.rs +++ b/src/usimp/handler/subscribe.rs @@ -4,6 +4,7 @@ use crate::usimp::*; use serde::{Deserialize, Serialize}; use serde_json::{from_value, to_value, Value}; use tokio::sync::mpsc; +use crate::websocket::WebSocketEnvelope; #[derive(Serialize, Deserialize, Clone)] struct Input {} @@ -16,7 +17,7 @@ struct Output { pub async fn handle( input: &InputEnvelope, session: Option, - tx: Option>, + tx: Option>, ) -> Result { Ok(to_value( subscribe( @@ -33,7 +34,7 @@ async fn subscribe( _input: Input, session: Option, req_nr: Option, - tx: Option>, + tx: Option>, ) -> Result { let account = get_account(&session)?; let mut rx = subscription::subscribe_account(account).await; @@ -46,7 +47,7 @@ async fn subscribe( error: None, request_nr: req_nr, data: serde_json::json![{"events": [event]}], - }) + }.into()) .await; } }); diff --git a/src/websocket.rs b/src/websocket.rs index 173ac02..88e1037 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -11,46 +11,114 @@ use hyper_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use serde_json::{Map, Value}; use tokio::sync::mpsc; +pub enum WebSocketEnvelope { + Close, + Ping(Vec), + Pong(Vec), + Text(OutputEnvelope), +} + +impl From for WebSocketEnvelope { + fn from(envelope: OutputEnvelope) -> Self { + WebSocketEnvelope::Text(envelope) + } +} + async fn sender( mut sink: SplitSink, Message>, - mut rx: mpsc::Receiver, + mut rx: mpsc::Receiver, ) { while let Some(msg) = rx.recv().await { - let mut envelope = Value::Object(Map::new()); - envelope["data"] = msg.data; - envelope["request_nr"] = match msg.request_nr { - Some(nr) => Value::from(nr), - None => Value::Null, - }; - match msg.error { - Some(error) => { - envelope["status"] = Value::from("error"); - envelope["error"] = Value::from(error); + match msg { + WebSocketEnvelope::Ping(data) => { + if let Err(error) = sink.send(Message::Ping(data)).await { + eprintln!("{:?}", error); + return; + } } - None => { - envelope["status"] = Value::from("success"); + WebSocketEnvelope::Pong(data) => { + if let Err(error) = sink.send(Message::Pong(data)).await { + eprintln!("{:?}", error); + return; + } + } + WebSocketEnvelope::Close => { + if let Err(error) = sink.send(Message::Close(None)).await { + eprintln!("{:?}", error); + return; + } + break; + } + WebSocketEnvelope::Text(msg) => { + let mut envelope = Value::Object(Map::new()); + envelope["data"] = msg.data; + envelope["request_nr"] = match msg.request_nr { + Some(nr) => Value::from(nr), + None => Value::Null, + }; + match msg.error { + Some(error) => { + envelope["status"] = Value::from("error"); + envelope["error"] = Value::from(error); + } + None => { + envelope["status"] = Value::from("success"); + } + } + if let Err(error) = sink.send(Message::Text(envelope.to_string())).await { + eprintln!("{:?}", error); + return; + } } - } - if let Err(error) = sink.send(Message::Text(envelope.to_string())).await { - eprintln!("{:?}", error); - break; } } } async fn receiver( mut stream: SplitStream>, - tx: mpsc::Sender, + tx: mpsc::Sender, ) { while let Some(res) = stream.next().await { match res { Ok(msg) => { - let input: InputEnvelope = serde_json::from_slice(&msg.into_data()[..]).unwrap(); - let output = match usimp::endpoint(&input, Some(tx.clone())).await { - Ok(output) => output, - Err(error) => input.error(error), - }; - let _res = tx.send(output).await; + let _res; + if msg.is_ping() { + _res = tx.send(WebSocketEnvelope::Pong(msg.into_data())).await; + } else if msg.is_pong() { + // Ignore + } else if msg.is_close() { + _res = tx.send(WebSocketEnvelope::Close).await; + break; + } else if msg.is_binary() { + _res = tx.send(WebSocketEnvelope::Text(OutputEnvelope { + error: Some(Error { + kind: ErrorKind::WebSocketError, + class: ErrorClass::ClientProtocolError, + msg: Some("Binary frames are not allowed".to_string()), + desc: None, + }), + request_nr: None, + data: Value::Null, + })).await; + } else if msg.is_text() { + let input: InputEnvelope = serde_json::from_slice(&msg.into_data()[..]).unwrap(); + let output = match usimp::endpoint(&input, Some(tx.clone())).await { + Ok(output) => output, + Err(error) => input.error(error), + }; + _res = tx.send(WebSocketEnvelope::Text(output)).await; + } else { + _res = tx.send(WebSocketEnvelope::Text(OutputEnvelope { + error: Some(Error { + kind: ErrorKind::WebSocketError, + class: ErrorClass::ClientProtocolError, + msg: Some("Unknown frame opcode".to_string()), + desc: None, + }), + request_nr: None, + data: Value::Null, + })).await; + } } Err(error) => println!("{:?}", error), } @@ -71,7 +139,7 @@ pub async fn handler( ErrorClass::ClientProtocolError, None, ))), - ) + ); } } @@ -85,7 +153,7 @@ pub async fn handler( ErrorClass::ClientProtocolError, None, ))), - ) + ); } }; let key = handshake::derive_accept_key(key.as_bytes()); @@ -100,19 +168,20 @@ pub async fn handler( ErrorClass::ClientProtocolError, None, ))), - ) + ); } } tokio::spawn(async move { match hyper::upgrade::on(req).await { Ok(upgraded) => { - let ws_stream = - WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await; - let (tx, rx) = mpsc::channel::(64); + let ws_stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await; + let (tx, rx) = mpsc::channel::(64); let (sink, stream) = ws_stream.split(); - tokio::spawn(async move { sender(sink, rx).await }); - receiver(stream, tx).await + + let other = tokio::spawn(async move { sender(sink, rx).await }); + receiver(stream, tx).await; + other.await.unwrap(); } Err(error) => eprintln!("Unable to upgrade: {}", error), }