Fix WebSocket connection closes

This commit is contained in:
2022-08-27 17:39:19 +02:00
parent f3c940a96c
commit d94daa2c93
3 changed files with 107 additions and 36 deletions

View File

@ -5,10 +5,11 @@ mod subscribe;
use crate::usimp::*; use crate::usimp::*;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::websocket::WebSocketEnvelope;
pub async fn endpoint( pub async fn endpoint(
input: &InputEnvelope, input: &InputEnvelope,
tx: Option<mpsc::Sender<OutputEnvelope>>, tx: Option<mpsc::Sender<WebSocketEnvelope>>,
) -> Result<OutputEnvelope, Error> { ) -> Result<OutputEnvelope, Error> {
if input.from_domain != None { if input.from_domain != None {
// TODO // TODO

View File

@ -4,6 +4,7 @@ use crate::usimp::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{from_value, to_value, Value}; use serde_json::{from_value, to_value, Value};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::websocket::WebSocketEnvelope;
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
struct Input {} struct Input {}
@ -16,7 +17,7 @@ struct Output {
pub async fn handle( pub async fn handle(
input: &InputEnvelope, input: &InputEnvelope,
session: Option<Session>, session: Option<Session>,
tx: Option<mpsc::Sender<OutputEnvelope>>, tx: Option<mpsc::Sender<WebSocketEnvelope>>,
) -> Result<Value, Error> { ) -> Result<Value, Error> {
Ok(to_value( Ok(to_value(
subscribe( subscribe(
@ -33,7 +34,7 @@ async fn subscribe(
_input: Input, _input: Input,
session: Option<Session>, session: Option<Session>,
req_nr: Option<u64>, req_nr: Option<u64>,
tx: Option<mpsc::Sender<OutputEnvelope>>, tx: Option<mpsc::Sender<WebSocketEnvelope>>,
) -> Result<Output, Error> { ) -> Result<Output, Error> {
let account = get_account(&session)?; let account = get_account(&session)?;
let mut rx = subscription::subscribe_account(account).await; let mut rx = subscription::subscribe_account(account).await;
@ -46,7 +47,7 @@ async fn subscribe(
error: None, error: None,
request_nr: req_nr, request_nr: req_nr,
data: serde_json::json![{"events": [event]}], data: serde_json::json![{"events": [event]}],
}) }.into())
.await; .await;
} }
}); });

View File

@ -11,46 +11,114 @@ use hyper_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use serde_json::{Map, Value}; use serde_json::{Map, Value};
use tokio::sync::mpsc; use tokio::sync::mpsc;
pub enum WebSocketEnvelope {
Close,
Ping(Vec<u8>),
Pong(Vec<u8>),
Text(OutputEnvelope),
}
impl From<OutputEnvelope> for WebSocketEnvelope {
fn from(envelope: OutputEnvelope) -> Self {
WebSocketEnvelope::Text(envelope)
}
}
async fn sender( async fn sender(
mut sink: SplitSink<WebSocketStream<Upgraded>, Message>, mut sink: SplitSink<WebSocketStream<Upgraded>, Message>,
mut rx: mpsc::Receiver<OutputEnvelope>, mut rx: mpsc::Receiver<WebSocketEnvelope>,
) { ) {
while let Some(msg) = rx.recv().await { while let Some(msg) = rx.recv().await {
let mut envelope = Value::Object(Map::new()); match msg {
envelope["data"] = msg.data; WebSocketEnvelope::Ping(data) => {
envelope["request_nr"] = match msg.request_nr { if let Err(error) = sink.send(Message::Ping(data)).await {
Some(nr) => Value::from(nr), eprintln!("{:?}", error);
None => Value::Null, return;
}; }
match msg.error {
Some(error) => {
envelope["status"] = Value::from("error");
envelope["error"] = Value::from(error);
} }
None => { WebSocketEnvelope::Pong(data) => {
envelope["status"] = Value::from("success"); if let Err(error) = sink.send(Message::Pong(data)).await {
eprintln!("{:?}", error);
return;
}
}
WebSocketEnvelope::Close => {
if let Err(error) = sink.send(Message::Close(None)).await {
eprintln!("{:?}", error);
return;
}
break;
}
WebSocketEnvelope::Text(msg) => {
let mut envelope = Value::Object(Map::new());
envelope["data"] = msg.data;
envelope["request_nr"] = match msg.request_nr {
Some(nr) => Value::from(nr),
None => Value::Null,
};
match msg.error {
Some(error) => {
envelope["status"] = Value::from("error");
envelope["error"] = Value::from(error);
}
None => {
envelope["status"] = Value::from("success");
}
}
if let Err(error) = sink.send(Message::Text(envelope.to_string())).await {
eprintln!("{:?}", error);
return;
}
} }
}
if let Err(error) = sink.send(Message::Text(envelope.to_string())).await {
eprintln!("{:?}", error);
break;
} }
} }
} }
async fn receiver( async fn receiver(
mut stream: SplitStream<WebSocketStream<Upgraded>>, mut stream: SplitStream<WebSocketStream<Upgraded>>,
tx: mpsc::Sender<OutputEnvelope>, tx: mpsc::Sender<WebSocketEnvelope>,
) { ) {
while let Some(res) = stream.next().await { while let Some(res) = stream.next().await {
match res { match res {
Ok(msg) => { Ok(msg) => {
let input: InputEnvelope = serde_json::from_slice(&msg.into_data()[..]).unwrap(); let _res;
let output = match usimp::endpoint(&input, Some(tx.clone())).await { if msg.is_ping() {
Ok(output) => output, _res = tx.send(WebSocketEnvelope::Pong(msg.into_data())).await;
Err(error) => input.error(error), } else if msg.is_pong() {
}; // Ignore
let _res = tx.send(output).await; } else if msg.is_close() {
_res = tx.send(WebSocketEnvelope::Close).await;
break;
} else if msg.is_binary() {
_res = tx.send(WebSocketEnvelope::Text(OutputEnvelope {
error: Some(Error {
kind: ErrorKind::WebSocketError,
class: ErrorClass::ClientProtocolError,
msg: Some("Binary frames are not allowed".to_string()),
desc: None,
}),
request_nr: None,
data: Value::Null,
})).await;
} else if msg.is_text() {
let input: InputEnvelope = serde_json::from_slice(&msg.into_data()[..]).unwrap();
let output = match usimp::endpoint(&input, Some(tx.clone())).await {
Ok(output) => output,
Err(error) => input.error(error),
};
_res = tx.send(WebSocketEnvelope::Text(output)).await;
} else {
_res = tx.send(WebSocketEnvelope::Text(OutputEnvelope {
error: Some(Error {
kind: ErrorKind::WebSocketError,
class: ErrorClass::ClientProtocolError,
msg: Some("Unknown frame opcode".to_string()),
desc: None,
}),
request_nr: None,
data: Value::Null,
})).await;
}
} }
Err(error) => println!("{:?}", error), Err(error) => println!("{:?}", error),
} }
@ -71,7 +139,7 @@ pub async fn handler(
ErrorClass::ClientProtocolError, ErrorClass::ClientProtocolError,
None, None,
))), ))),
) );
} }
} }
@ -85,7 +153,7 @@ pub async fn handler(
ErrorClass::ClientProtocolError, ErrorClass::ClientProtocolError,
None, None,
))), ))),
) );
} }
}; };
let key = handshake::derive_accept_key(key.as_bytes()); let key = handshake::derive_accept_key(key.as_bytes());
@ -100,19 +168,20 @@ pub async fn handler(
ErrorClass::ClientProtocolError, ErrorClass::ClientProtocolError,
None, None,
))), ))),
) );
} }
} }
tokio::spawn(async move { tokio::spawn(async move {
match hyper::upgrade::on(req).await { match hyper::upgrade::on(req).await {
Ok(upgraded) => { Ok(upgraded) => {
let ws_stream = let ws_stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await; let (tx, rx) = mpsc::channel::<WebSocketEnvelope>(64);
let (tx, rx) = mpsc::channel::<OutputEnvelope>(64);
let (sink, stream) = ws_stream.split(); let (sink, stream) = ws_stream.split();
tokio::spawn(async move { sender(sink, rx).await });
receiver(stream, tx).await let other = tokio::spawn(async move { sender(sink, rx).await });
receiver(stream, tx).await;
other.await.unwrap();
} }
Err(error) => eprintln!("Unable to upgrade: {}", error), Err(error) => eprintln!("Unable to upgrade: {}", error),
} }