diff --git a/src/database.rs b/src/database.rs index 907c5a3..1b50b5a 100644 --- a/src/database.rs +++ b/src/database.rs @@ -15,7 +15,9 @@ static mut POOL: Option = None; pub async fn init() -> Result<(), Error> { let manager = PostgresConnectionManager::new( - "host=localhost user=postgres dbname=locutus".parse().unwrap(), + "host=localhost user=postgres dbname=locutus" + .parse() + .unwrap(), NoTls, ); @@ -24,7 +26,8 @@ pub async fn init() -> Result<(), Error> { .min_idle(Some(2)) .connection_timeout(Duration::from_secs(4)) .max_lifetime(Some(Duration::from_secs(3600))) - .build(manager).await?; + .build(manager) + .await?; unsafe { POOL = Some(Pool::Postgres(pool)); diff --git a/src/error.rs b/src/error.rs index 57405f6..c551044 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,8 +1,8 @@ -use crate::usimp::{InputEnvelope, OutputEnvelope, Event}; +use crate::usimp::{Event, InputEnvelope, OutputEnvelope}; -use serde_json::{Value, Map}; -use bb8_postgres::tokio_postgres; use bb8_postgres; +use bb8_postgres::tokio_postgres; +use serde_json::{Map, Value}; #[derive(Debug)] pub struct Error { @@ -32,7 +32,12 @@ pub enum ErrorKind { } impl InputEnvelope { - pub fn new_error(&self, kind: ErrorKind, class: ErrorClass, msg: Option) -> OutputEnvelope { + pub fn new_error( + &self, + kind: ErrorKind, + class: ErrorClass, + msg: Option, + ) -> OutputEnvelope { OutputEnvelope { request_nr: self.request_nr, error: Some(Error::new(kind, class, msg)), @@ -56,7 +61,7 @@ impl Error { class, msg, desc: None, - } + }; } pub fn msg(&mut self, msg: String) { diff --git a/src/http.rs b/src/http.rs index f5d4499..336e72d 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,34 +1,62 @@ -use hyper::{Request, Response, Body, StatusCode, header, Method, body}; -use serde_json::{Value, Map}; -use crate::websocket; -use crate::usimp::*; use crate::error::*; use crate::usimp; +use crate::usimp::*; +use crate::websocket; +use hyper::{body, header, Body, Method, Request, Response, StatusCode}; +use serde_json::{Map, Value}; use std::str::FromStr; -async fn endpoint_handler(req: &mut Request, endpoint: String) -> Result, Error> { +async fn endpoint_handler( + req: &mut Request, + endpoint: String, +) -> Result, Error> { if req.method() == Method::OPTIONS { - return Ok(None) + return Ok(None); } else if req.method() != Method::POST { - return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + return Err(Error::new( + ErrorKind::UsimpError, + ErrorClass::ClientProtocolError, + None, + )); } let to_domain; - if let Some(val) = req.headers().get(header::HeaderName::from_str("To-Domain").unwrap()) { + if let Some(val) = req + .headers() + .get(header::HeaderName::from_str("To-Domain").unwrap()) + { to_domain = val.to_str().unwrap().to_string() } else { - return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + return Err(Error::new( + ErrorKind::UsimpError, + ErrorClass::ClientProtocolError, + None, + )); } if let Some(val) = req.headers().get(header::CONTENT_TYPE) { - let parts: Vec = val.to_str()?.split(';').map(|v| v.trim().to_ascii_lowercase()).collect(); + let parts: Vec = val + .to_str()? + .split(';') + .map(|v| v.trim().to_ascii_lowercase()) + .collect(); let p: Vec<&str> = parts.iter().map(|v| v.as_str()).collect(); match p[0..1] { - ["application/json"] => {}, - _ => return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + ["application/json"] => {} + _ => { + return Err(Error::new( + ErrorKind::UsimpError, + ErrorClass::ClientProtocolError, + None, + )) + } } } else { - return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + return Err(Error::new( + ErrorKind::UsimpError, + ErrorClass::ClientProtocolError, + None, + )); } let data = serde_json::from_slice(&body::to_bytes(req.body_mut()).await?)?; @@ -36,7 +64,10 @@ async fn endpoint_handler(req: &mut Request, endpoint: String) -> Result Some(val.to_str()?.to_string()), None => None, }, @@ -47,9 +78,13 @@ async fn endpoint_handler(req: &mut Request, endpoint: String) -> Result None, }, data, @@ -65,7 +100,10 @@ pub async fn handler(mut req: Request) -> Result, hyper::Er println!("{} {}", req.method(), req.uri()); let val: Result, Error> = match &parts[..] { - [""] => Ok(res.status(StatusCode::OK).body(Body::from("Hello World")).unwrap()), + [""] => Ok(res + .status(StatusCode::OK) + .body(Body::from("Hello World")) + .unwrap()), ["_usimp"] | ["_usimp", ..] => { res = res .header(header::SERVER, "Locutus") @@ -75,23 +113,29 @@ pub async fn handler(mut req: Request) -> Result, hyper::Er let output = match &parts[1..] { ["websocket"] => { - res = res - .header(header::ACCESS_CONTROL_ALLOW_METHODS, "GET"); + res = res.header(header::ACCESS_CONTROL_ALLOW_METHODS, "GET"); let (r, val) = websocket::handler(req, res).await; res = r; match val { Some(val) => Ok(Some(val)), None => return Ok(res.body(Body::empty()).unwrap()), } - }, + } [endpoint] => { res = res .header(header::ACCESS_CONTROL_ALLOW_METHODS, "POST, OPTIONS") - .header(header::ACCESS_CONTROL_ALLOW_HEADERS, "Content-Type, From-Domain, To-Domain, Authorization"); + .header( + header::ACCESS_CONTROL_ALLOW_HEADERS, + "Content-Type, From-Domain, To-Domain, Authorization", + ); let endpoint = endpoint.to_string(); endpoint_handler(&mut req, endpoint).await - }, - _ => Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)), + } + _ => Err(Error::new( + ErrorKind::UsimpError, + ErrorClass::ClientProtocolError, + None, + )), }; let output = match output { @@ -107,15 +151,17 @@ pub async fn handler(mut req: Request) -> Result, hyper::Er Some(error) => { res = match error.class { ErrorClass::ClientProtocolError => res.status(StatusCode::BAD_REQUEST), - ErrorClass::ServerError => res.status(StatusCode::INTERNAL_SERVER_ERROR), + ErrorClass::ServerError => { + res.status(StatusCode::INTERNAL_SERVER_ERROR) + } _ => res.status(StatusCode::OK), }; data["status"] = Value::from("error"); data["error"] = Value::from(error); - }, + } None => { data["status"] = Value::from("success"); - }, + } } data["request_nr"] = match output.request_nr { @@ -123,14 +169,19 @@ pub async fn handler(mut req: Request) -> Result, hyper::Er None => Value::Null, }; data["data"] = output.data; - return Ok(res.body(Body::from(serde_json::to_string(&data).unwrap() + "\r\n")).unwrap()) + return Ok(res + .body(Body::from(serde_json::to_string(&data).unwrap() + "\r\n")) + .unwrap()); } else { res = res.status(StatusCode::NO_CONTENT); } - return Ok(res.body(Body::empty()).unwrap()) - }, - _ => Ok(res.status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()), + return Ok(res.body(Body::empty()).unwrap()); + } + _ => Ok(res + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap()), }; match val { diff --git a/src/main.rs b/src/main.rs index 41ada7e..699420b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,21 +4,30 @@ use std::net; use std::net::SocketAddr; use std::pin::Pin; -use error::*; use ansi_term::{Color, Style}; +use error::*; use futures_util::{future::TryFutureExt, stream::Stream}; -use hyper::Server; use hyper::server::conn::AddrStream; use hyper::service::{make_service_fn, service_fn}; +use hyper::Server; -mod http; -mod websocket; -mod usimp; mod database; mod error; +mod http; +mod usimp; +mod websocket; struct HyperAcceptor<'a> { - acceptor: Pin, std::io::Error>> + 'a>>, + acceptor: Pin< + Box< + dyn Stream< + Item = Result< + tokio_rustls::server::TlsStream, + std::io::Error, + >, + > + 'a, + >, + >, } impl hyper::server::accept::Accept for HyperAcceptor<'_> { @@ -38,7 +47,8 @@ fn load_certs(filename: &str) -> std::io::Result> { .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; let mut reader = std::io::BufReader::new(certfile); - rustls::internal::pemfile::certs(&mut reader).map_err(|_| error("failed to load certificate".into())) + rustls::internal::pemfile::certs(&mut reader) + .map_err(|_| error("failed to load certificate".into())) } fn load_private_key(filename: &str) -> std::io::Result { @@ -65,7 +75,10 @@ async fn main() -> Result<(), Error> { database::init().await?; usimp::subscription::init(); - let server1 = Server::bind(&SocketAddr::from(([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], 8080))); + let server1 = Server::bind(&SocketAddr::from(( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + 8080, + ))); let service = make_service_fn(|_: &AddrStream| async { Ok::<_, hyper::Error>(service_fn(http::handler)) }); @@ -100,9 +113,7 @@ async fn main() -> Result<(), Error> { acceptor: Box::pin(incoming_tls_stream), }); - let service = make_service_fn(|_| async { - Ok::<_, hyper::Error>(service_fn(http::handler)) - }); + let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(http::handler)) }); let srv2 = server2.serve(service); println!("{}", Color::Green.paint("Ready")); diff --git a/src/usimp/handler/authenticate.rs b/src/usimp/handler/authenticate.rs index af64ed2..4405652 100644 --- a/src/usimp/handler/authenticate.rs +++ b/src/usimp/handler/authenticate.rs @@ -1,10 +1,10 @@ +use crate::database; use crate::usimp; use crate::usimp::*; -use crate::database; -use serde_json::{Value, from_value, to_value}; -use serde::{Serialize, Deserialize}; use rand::Rng; +use serde::{Deserialize, Serialize}; +use serde_json::{from_value, to_value, Value}; #[derive(Serialize, Deserialize, Clone)] struct Input { @@ -19,7 +19,9 @@ struct Output { } pub async fn handle(input: &InputEnvelope, session: Option) -> Result { - Ok(to_value(authenticate(from_value(input.data.clone())?, session).await?)?) + Ok(to_value( + authenticate(from_value(input.data.clone())?, session).await?, + )?) } async fn authenticate(input: Input, _session: Option) -> Result { @@ -28,14 +30,20 @@ async fn authenticate(input: Input, _session: Option) -> Result { - let res = client.query( - "SELECT account_id, domain_id \ - FROM accounts \ - WHERE account_name = $1", - &[&input.name] - ).await?; + let res = client + .query( + "SELECT account_id, domain_id \ + FROM accounts \ + WHERE account_name = $1", + &[&input.name], + ) + .await?; if res.len() == 0 { - return Err(Error::new(ErrorKind::AuthenticationError, ErrorClass::ClientError, None)); + return Err(Error::new( + ErrorKind::AuthenticationError, + ErrorClass::ClientError, + None, + )); } let row = &res[0]; let account_id: String = row.get(0); @@ -43,7 +51,11 @@ async fn authenticate(input: Input, _session: Option) -> Result) -> Result>) -> Result { +pub async fn endpoint( + input: &InputEnvelope, + tx: Option>, +) -> Result { if input.from_domain != None { // TODO - return Err(Error::new(ErrorKind::NotImplemented, ErrorClass::ServerError, None)); + return Err(Error::new( + ErrorKind::NotImplemented, + ErrorClass::ServerError, + None, + )); } let session; if let Some(token) = &input.token { @@ -24,8 +31,10 @@ pub async fn endpoint(input: &InputEnvelope, tx: Option input.respond(authenticate::handle(&input, session).await?), "subscribe" => input.respond(subscribe::handle(&input, session, tx).await?), "new_event" => input.respond(new_event::handle(&input, session).await?), - _ => input.new_error(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, Some("Invalid endpoint".to_string())), + _ => input.new_error( + ErrorKind::UsimpError, + ErrorClass::ClientProtocolError, + Some("Invalid endpoint".to_string()), + ), }) } - - diff --git a/src/usimp/handler/new_event.rs b/src/usimp/handler/new_event.rs index 6dd46cf..a41b51d 100644 --- a/src/usimp/handler/new_event.rs +++ b/src/usimp/handler/new_event.rs @@ -1,8 +1,8 @@ -use crate::usimp::*; use crate::usimp::subscription; +use crate::usimp::*; -use serde_json::{Value, from_value, to_value}; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; +use serde_json::{from_value, to_value, Value}; #[derive(Serialize, Deserialize, Clone)] struct Input { @@ -11,11 +11,12 @@ struct Input { } #[derive(Serialize, Deserialize, Clone)] -struct Output { -} +struct Output {} pub async fn handle(input: &InputEnvelope, session: Option) -> Result { - Ok(to_value(new_event(from_value(input.data.clone())?, session).await?)?) + Ok(to_value( + new_event(from_value(input.data.clone())?, session).await?, + )?) } async fn new_event(input: Input, session: Option) -> Result { diff --git a/src/usimp/handler/subscribe.rs b/src/usimp/handler/subscribe.rs index 6f76654..45e2497 100644 --- a/src/usimp/handler/subscribe.rs +++ b/src/usimp/handler/subscribe.rs @@ -1,50 +1,67 @@ -use crate::usimp::*; use crate::usimp::subscription; +use crate::usimp::*; -use serde_json::{Value, from_value, to_value}; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; +use serde_json::{from_value, to_value, Value}; use tokio::sync::mpsc; #[derive(Serialize, Deserialize, Clone)] -struct Input { -} +struct Input {} #[derive(Serialize, Deserialize, Clone)] struct Output { event: Option, } -pub async fn handle(input: &InputEnvelope, session: Option, tx: Option>) -> Result { - Ok(to_value(subscribe(from_value(input.data.clone())?, session, input.request_nr, tx).await?)?) +pub async fn handle( + input: &InputEnvelope, + session: Option, + tx: Option>, +) -> Result { + Ok(to_value( + subscribe( + from_value(input.data.clone())?, + session, + input.request_nr, + tx, + ) + .await?, + )?) } -async fn subscribe(_input: Input, session: Option, req_nr: Option, tx: Option>) -> Result { +async fn subscribe( + _input: Input, + session: Option, + req_nr: Option, + tx: Option>, +) -> Result { let account = get_account(&session)?; let mut rx = subscription::subscribe_account(account).await; match tx { Some(tx) => { tokio::spawn(async move { while let Some(event) = rx.recv().await { - let _res = tx.send(OutputEnvelope { - error: None, - request_nr: req_nr, - data: to_value(event).unwrap(), - }).await; + let _res = tx + .send(OutputEnvelope { + error: None, + request_nr: req_nr, + data: to_value(event).unwrap(), + }) + .await; } }); - Ok(Output { - event: None, - }) + Ok(Output { event: None }) } None => { if let Some(event) = rx.recv().await { - Ok(Output { - event: Some(event), - }) + Ok(Output { event: Some(event) }) } else { - Err(Error::new(ErrorKind::SubscriptionError, ErrorClass::ServerError, None)) + Err(Error::new( + ErrorKind::SubscriptionError, + ErrorClass::ServerError, + None, + )) } } } } - diff --git a/src/usimp/mod.rs b/src/usimp/mod.rs index 18f7c25..4c770f4 100644 --- a/src/usimp/mod.rs +++ b/src/usimp/mod.rs @@ -3,13 +3,13 @@ pub mod subscription; pub use handler::endpoint; -use crate::error::{Error, ErrorClass, ErrorKind}; use crate::database; -use serde_json::Value; -use serde::{Serialize, Deserialize}; -use crypto::sha2::Sha256; -use crypto::digest::Digest; +use crate::error::{Error, ErrorClass, ErrorKind}; use base64_url; +use crypto::digest::Digest; +use crypto::sha2::Sha256; +use serde::{Deserialize, Serialize}; +use serde_json::Value; #[derive(Serialize, Deserialize)] pub struct InputEnvelope { @@ -60,9 +60,21 @@ pub fn get_account(session: &Option) -> Result<&Account, Error> { match session { Some(session) => match &session.account { Some(account) => Ok(&account), - None => return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + None => { + return Err(Error::new( + ErrorKind::UsimpError, + ErrorClass::ClientProtocolError, + None, + )) + } }, - None => return Err(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)) + None => { + return Err(Error::new( + ErrorKind::UsimpError, + ErrorClass::ClientProtocolError, + None, + )) + } } } @@ -82,14 +94,20 @@ impl Session { let session; match backend { database::Client::Postgres(client) => { - let res = client.query( - "SELECT session_id, session_nr, a.account_id, account_name, domain_id \ + let res = client + .query( + "SELECT session_id, session_nr, a.account_id, account_name, domain_id \ FROM accounts a JOIN sessions s ON a.account_id = s.account_id \ WHERE session_token = $1;", - &[&token] - ).await?; + &[&token], + ) + .await?; if res.len() == 0 { - return Err(Error::new(ErrorKind::InvalidSessionError, ErrorClass::ClientError, None)); + return Err(Error::new( + ErrorKind::InvalidSessionError, + ErrorClass::ClientError, + None, + )); } let row = &res[0]; session = Session { diff --git a/src/usimp/subscription.rs b/src/usimp/subscription.rs index 3e65e63..eaeae2f 100644 --- a/src/usimp/subscription.rs +++ b/src/usimp/subscription.rs @@ -1,8 +1,8 @@ -use crate::usimp::*; use crate::database; -use tokio::sync::{mpsc, Mutex}; +use crate::usimp::*; use std::collections::HashMap; use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; static mut ROOMS: Option>>>>> = None; static mut ACCOUNTS: Option>>>>> = None; @@ -21,10 +21,10 @@ pub async fn subscribe_account(account: &Account) -> mpsc::Receiver { match acc.get_mut(account.id.as_str()) { Some(vec) => { vec.push(tx); - }, + } None => { - acc.insert(account.id.clone(), vec!{tx}); - }, + acc.insert(account.id.clone(), vec![tx]); + } } } rx @@ -34,12 +34,14 @@ pub async fn push(room_id: &str, event: Event) -> Result<(), Error> { let backend = database::client().await?; let accounts = match backend { database::Client::Postgres(client) => { - let res = client.query( - "SELECT account_id \ + let res = client + .query( + "SELECT account_id \ FROM members \ WHERE room_id = $1;", - &[&room_id] - ).await?; + &[&room_id], + ) + .await?; let mut acc: Vec = Vec::new(); for row in res { acc.push(row.get(0)); @@ -48,13 +50,13 @@ pub async fn push(room_id: &str, event: Event) -> Result<(), Error> { } }; - unsafe { - let mut rooms = ROOMS.as_ref().unwrap().lock().await; - if let Some(rooms) = rooms.get_mut(room_id) { - for tx in rooms { - let _res = tx.send(event.clone()).await; - } - } + unsafe { + let mut rooms = ROOMS.as_ref().unwrap().lock().await; + if let Some(rooms) = rooms.get_mut(room_id) { + for tx in rooms { + let _res = tx.send(event.clone()).await; + } + } } for account in accounts { diff --git a/src/websocket.rs b/src/websocket.rs index 5b7d6e6..173ac02 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -1,41 +1,47 @@ -use hyper::{Request, Body, StatusCode, header}; -use crate::usimp::*; -use crate::usimp; use crate::error::*; -use hyper_tungstenite::{WebSocketStream, tungstenite::protocol::Role}; -use futures_util::StreamExt; -use hyper_tungstenite::tungstenite::{handshake, Message}; -use hyper_tungstenite::hyper::upgrade::Upgraded; +use crate::usimp; +use crate::usimp::*; use futures::stream::{SplitSink, SplitStream}; -use tokio::sync::mpsc; -use serde_json::{Value, Map}; use futures_util::SinkExt; +use futures_util::StreamExt; +use hyper::{header, Body, Request, StatusCode}; +use hyper_tungstenite::hyper::upgrade::Upgraded; +use hyper_tungstenite::tungstenite::{handshake, Message}; +use hyper_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; +use serde_json::{Map, Value}; +use tokio::sync::mpsc; -async fn sender(mut sink: SplitSink, Message>, mut rx: mpsc::Receiver) { +async fn sender( + mut sink: SplitSink, Message>, + mut rx: mpsc::Receiver, +) { while let Some(msg) = rx.recv().await { let mut envelope = Value::Object(Map::new()); envelope["data"] = msg.data; envelope["request_nr"] = match msg.request_nr { - Some(nr) => Value::from(nr), + Some(nr) => Value::from(nr), None => Value::Null, }; match msg.error { Some(error) => { envelope["status"] = Value::from("error"); envelope["error"] = Value::from(error); - }, + } None => { envelope["status"] = Value::from("success"); } } if let Err(error) = sink.send(Message::Text(envelope.to_string())).await { eprintln!("{:?}", error); - break + break; } } } -async fn receiver(mut stream: SplitStream>, tx: mpsc::Sender) { +async fn receiver( + mut stream: SplitStream>, + tx: mpsc::Sender, +) { while let Some(res) = stream.next().await { match res { Ok(msg) => { @@ -45,53 +51,91 @@ async fn receiver(mut stream: SplitStream>, tx: mpsc:: Err(error) => input.error(error), }; let _res = tx.send(output).await; - }, + } Err(error) => println!("{:?}", error), } } } -pub async fn handler(req: Request, res: hyper::http::response::Builder) -> (hyper::http::response::Builder, Option) { +pub async fn handler( + req: Request, + res: hyper::http::response::Builder, +) -> (hyper::http::response::Builder, Option) { match req.headers().get(header::UPGRADE) { - Some(val) if val == header::HeaderValue::from_str("websocket").unwrap() => {}, - _ => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::WebSocketError, ErrorClass::ClientProtocolError, None)))), + Some(val) if val == header::HeaderValue::from_str("websocket").unwrap() => {} + _ => { + return ( + res, + Some(OutputEnvelope::from(Error::new( + ErrorKind::WebSocketError, + ErrorClass::ClientProtocolError, + None, + ))), + ) + } } let key = match req.headers().get(header::SEC_WEBSOCKET_KEY) { Some(key) => key, - None => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::WebSocketError, ErrorClass::ClientProtocolError, None)))) + None => { + return ( + res, + Some(OutputEnvelope::from(Error::new( + ErrorKind::WebSocketError, + ErrorClass::ClientProtocolError, + None, + ))), + ) + } }; let key = handshake::derive_accept_key(key.as_bytes()); match req.headers().get(header::SEC_WEBSOCKET_PROTOCOL) { Some(val) if val == header::HeaderValue::from_str("usimp").unwrap() => {} - _ => return (res, Some(OutputEnvelope::from(Error::new(ErrorKind::UsimpError, ErrorClass::ClientProtocolError, None)))), + _ => { + return ( + res, + Some(OutputEnvelope::from(Error::new( + ErrorKind::UsimpError, + ErrorClass::ClientProtocolError, + None, + ))), + ) + } } tokio::spawn(async move { match hyper::upgrade::on(req).await { Ok(upgraded) => { - let ws_stream = WebSocketStream::from_raw_socket( - upgraded, - Role::Server, - None, - ).await; + let ws_stream = + WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await; let (tx, rx) = mpsc::channel::(64); let (sink, stream) = ws_stream.split(); - tokio::spawn(async move { - sender(sink, rx).await - }); + tokio::spawn(async move { sender(sink, rx).await }); receiver(stream, tx).await } - Err(error) => eprintln!("Unable to upgrade: {}", error) + Err(error) => eprintln!("Unable to upgrade: {}", error), } }); - (res - .status(StatusCode::SWITCHING_PROTOCOLS) - .header(header::CONNECTION, header::HeaderValue::from_str("Upgrade").unwrap()) - .header(header::UPGRADE, header::HeaderValue::from_str("websocket").unwrap()) - .header(header::SEC_WEBSOCKET_ACCEPT, header::HeaderValue::from_str(key.as_str()).unwrap()) - .header(header::SEC_WEBSOCKET_PROTOCOL, header::HeaderValue::from_str("usimp").unwrap()), - None) + ( + res.status(StatusCode::SWITCHING_PROTOCOLS) + .header( + header::CONNECTION, + header::HeaderValue::from_str("Upgrade").unwrap(), + ) + .header( + header::UPGRADE, + header::HeaderValue::from_str("websocket").unwrap(), + ) + .header( + header::SEC_WEBSOCKET_ACCEPT, + header::HeaderValue::from_str(key.as_str()).unwrap(), + ) + .header( + header::SEC_WEBSOCKET_PROTOCOL, + header::HeaderValue::from_str("usimp").unwrap(), + ), + None, + ) }