bit cleaner websocket implementation

This commit is contained in:
2021-05-22 15:51:24 +02:00
parent 3867435a2d
commit 0e9caa496f
3 changed files with 168 additions and 77 deletions

View File

@ -11,6 +11,7 @@ pub enum Kind {
WebSocketError, WebSocketError,
NotImplementedError, NotImplementedError,
UsimpProtocolError, UsimpProtocolError,
Utf8DecodeError,
} }
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
@ -59,6 +60,7 @@ impl Error {
Kind::WebSocketError => "WebSocket protocol error", Kind::WebSocketError => "WebSocket protocol error",
Kind::NotImplementedError => "Not yet implemented", Kind::NotImplementedError => "Not yet implemented",
Kind::UsimpProtocolError => "USIMP protocol error", Kind::UsimpProtocolError => "USIMP protocol error",
Kind::Utf8DecodeError => "Unable to decode UTF-8 data",
}, },
} }
} }
@ -88,6 +90,7 @@ impl fmt::Display for Error {
Kind::WebSocketError => "websocket protocol error", Kind::WebSocketError => "websocket protocol error",
Kind::NotImplementedError => "not yet implemented", Kind::NotImplementedError => "not yet implemented",
Kind::UsimpProtocolError => "usimp protocol error", Kind::UsimpProtocolError => "usimp protocol error",
Kind::Utf8DecodeError => "unable to decode utf-8 data",
} }
.to_string(); .to_string();
if let Some(desc) = &self.desc { if let Some(desc) = &self.desc {
@ -143,3 +146,14 @@ impl From<r2d2_postgres::postgres::Error> for Error {
} }
} }
} }
impl From<std::string::FromUtf8Error> for Error {
fn from(error: std::string::FromUtf8Error) -> Self {
Error {
kind: Kind::Utf8DecodeError,
msg: Some("Unable to decode UTF-8 data".to_string()),
desc: Some(error.to_string()),
class: Class::ClientError,
}
}
}

View File

@ -1,80 +1,136 @@
use crate::error::*; use crate::error::*;
use crate::http; use crate::http;
use crate::websocket::FrameHeader; use crate::websocket::*;
use base64; use base64;
use crypto; use crypto;
use crypto::digest::Digest; use crypto::digest::Digest;
pub fn connection_handler( pub fn recv_message(client: &mut http::HttpStream) -> Result<Message, Error> {
let mut msg: Vec<u8> = Vec::new();
let mut msg_type = 0;
loop {
let header = FrameHeader::from(&mut client.stream)?;
if msg_type != 0 && header.opcode != 0 {
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("continuation frame expected".to_string()));
} else if header.opcode >= 8 && (!header.fin || header.payload_len >= 126) {
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid control frame".to_string()));
}
match header.opcode {
0 => {}, // cont
1 => {}, // text
2 => // binary
return Err(Error::new(Kind::UsimpProtocolError, Class::ClientError)
.set_desc("binary frames must not be sent on a usimp connection".to_string())),
8 => {}, // close
9 => {}, // ping
10 => {}, // pong
_ => return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid opcode".to_string())),
}
msg_type = header.opcode;
// FIXME check payload len and total len
let mut buf = vec![0u8; header.payload_len() as usize];
client.stream.read_exact(&mut buf)?;
if header.mask {
let key: [u8; 4] = [
(header.masking_key.unwrap() >> 24) as u8,
((header.masking_key.unwrap() >> 16) & 0xFF) as u8,
((header.masking_key.unwrap() >> 8) & 0xFF) as u8,
(header.masking_key.unwrap() & 0xFF) as u8,
];
for (pos, byte) in buf.iter_mut().enumerate() {
*byte ^= key[pos & 3]; // = pos % 4
}
}
msg.append(&mut buf);
if header.fin {
break
}
}
match msg_type {
1 => {Ok(Message::TextMessage(TextMessage {
data: String::from_utf8(msg)?
}))},
8 => {
let mut code = None;
let mut reason = None;
if msg.len() >= 2 {
code = Some(((msg[0] as u16) << 8) | (msg[1] as u16));
}
if msg.len() > 2 {
reason = Some(String::from_utf8(msg[2..].to_vec())?);
}
Ok(Message::CloseMessage(CloseMessage {
code,
reason
}))
},
9 => {Ok(Message::PingMessage(PingMessage {
data: String::from_utf8(msg)?
}))},
10 => {Ok(Message::PongMessage(PongMessage {
data: String::from_utf8(msg)?
}))},
_ => panic!("invalid msg_type for websocket")
}
}
pub fn handshake(
client: &mut http::HttpStream, client: &mut http::HttpStream,
req: &http::Request, req: &http::Request,
mut res: http::Response, res: &mut http::Response,
) { ) -> Result<(), Error> {
if let http::Method::GET = req.method { if let http::Method::GET = req.method {
} else { } else {
res.status(405); res.status(405);
res.header.add_field("Allow", "GET"); res.header.add_field("Allow", "GET");
return http::error_handler( return Err(Error::new(Kind::WebSocketError, Class::ClientError)
client, .set_desc("method not allowed".to_string()));
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("method not allowed".to_string()),
);
} }
if let Some(_) = req.header.find_field("Connection") { if let Some(_) = req.header.find_field("Connection") {
if !req.header.field_has_value("Connection", "upgrade") { if !req.header.field_has_value("Connection", "upgrade") {
return http::error_handler( return Err(Error::new(Kind::WebSocketError, Class::ClientError)
client, .set_desc("invalid value for header field 'Connection'".to_string()));
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Connection'".to_string()),
);
} }
} else { } else {
return http::error_handler( return Err(Error::new(Kind::WebSocketError, Class::ClientError)
client, .set_desc("unable to find header field 'Connection'".to_string()));
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Connection'".to_string()),
);
} }
if let Some(upgrade) = req.header.find_field("Upgrade") { if let Some(upgrade) = req.header.find_field("Upgrade") {
if !upgrade.eq_ignore_ascii_case("websocket") { if !upgrade.eq_ignore_ascii_case("websocket") {
return http::error_handler( return Err(Error::new(Kind::WebSocketError, Class::ClientError)
client, .set_desc("invalid value for header field 'Upgrade'".to_string()));
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Upgrade'".to_string()),
);
} }
} else { } else {
return http::error_handler( return Err(Error::new(Kind::WebSocketError, Class::ClientError)
client, .set_desc("unable to find header field 'Upgrade'".to_string()));
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Upgrade'".to_string()),
);
} }
if let Some(version) = req.header.find_field("Sec-WebSocket-Version") { if let Some(version) = req.header.find_field("Sec-WebSocket-Version") {
if !version.eq("13") { if !version.eq("13") {
return http::error_handler( return Err(Error::new(Kind::WebSocketError, Class::ClientError)
client, .set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()));
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()),
);
} }
} else { } else {
return http::error_handler( return Err(Error::new(Kind::WebSocketError, Class::ClientError)
client, .set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()));
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()),
);
} }
if let Some(key) = req.header.find_field("Sec-WebSocket-Key") { if let Some(key) = req.header.find_field("Sec-WebSocket-Key") {
@ -87,12 +143,8 @@ pub fn connection_handler(
let key = base64::encode(result); let key = base64::encode(result);
res.header.add_field("Sec-WebSocket-Accept", key.as_str()); res.header.add_field("Sec-WebSocket-Accept", key.as_str());
} else { } else {
return http::error_handler( return Err(Error::new(Kind::WebSocketError, Class::ClientError)
client, .set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()));
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()),
);
} }
client.server_keep_alive = false; client.server_keep_alive = false;
@ -100,35 +152,36 @@ pub fn connection_handler(
res.header.add_field("Upgrade", "websocket"); res.header.add_field("Upgrade", "websocket");
res.status(101); res.status(101);
res.send(&mut client.stream).unwrap(); res.send(&mut client.stream)?;
Ok(())
}
pub fn connection_handler(
client: &mut http::HttpStream,
req: &http::Request,
mut res: http::Response,
) {
if let Err(error) = handshake(client, req, &mut res) {
return http::error_handler(client, res, error);
}
loop { loop {
let header = FrameHeader::from(&mut client.stream).unwrap(); let msg = recv_message(client).unwrap();
if header.mask { match msg {
println!("Mask: {:X}", header.masking_key.unwrap()); Message::TextMessage(msg) => {
println!("Data: {}", msg.data);
},
Message::CloseMessage(msg) => {
return
} }
Message::PingMessage(msg) => {
let mut buf = [0u8; 8192]; // TODO send pong
client },
.stream Message::PongMessage(msg) => {
.read_exact(&mut buf[..header.payload_len() as usize]) // TODO something
.unwrap();
if header.mask {
let key: [u8; 4] = [
(header.masking_key.unwrap() >> 24) as u8,
((header.masking_key.unwrap() >> 16) & 0xFF) as u8,
((header.masking_key.unwrap() >> 8) & 0xFF) as u8,
(header.masking_key.unwrap() & 0xFF) as u8,
];
for (pos, byte) in buf.iter_mut().enumerate() {
*byte ^= key[pos % 4];
} }
} }
println!(
"Msg: {}",
String::from_utf8_lossy(&buf[..header.payload_len() as usize])
);
} }
} }

View File

@ -21,6 +21,30 @@ pub struct FrameHeader {
masking_key: Option<u32>, masking_key: Option<u32>,
} }
pub enum Message {
PingMessage(PingMessage),
PongMessage(PongMessage),
CloseMessage(CloseMessage),
TextMessage(TextMessage),
}
pub struct PingMessage {
data: String,
}
pub struct PongMessage {
data: String,
}
pub struct CloseMessage {
code: Option<u16>,
reason: Option<String>,
}
pub struct TextMessage {
data: String,
}
impl FrameHeader { impl FrameHeader {
pub fn from(socket: &mut http::Stream) -> Result<Self, Error> { pub fn from(socket: &mut http::Stream) -> Result<Self, Error> {
let mut data = [0u8; 2]; let mut data = [0u8; 2];