bit cleaner websocket implementation

This commit is contained in:
2021-05-22 15:51:24 +02:00
parent 3867435a2d
commit 0e9caa496f
3 changed files with 168 additions and 77 deletions

View File

@ -11,6 +11,7 @@ pub enum Kind {
WebSocketError,
NotImplementedError,
UsimpProtocolError,
Utf8DecodeError,
}
#[derive(Copy, Clone, Debug)]
@ -59,6 +60,7 @@ impl Error {
Kind::WebSocketError => "WebSocket protocol error",
Kind::NotImplementedError => "Not yet implemented",
Kind::UsimpProtocolError => "USIMP protocol error",
Kind::Utf8DecodeError => "Unable to decode UTF-8 data",
},
}
}
@ -88,6 +90,7 @@ impl fmt::Display for Error {
Kind::WebSocketError => "websocket protocol error",
Kind::NotImplementedError => "not yet implemented",
Kind::UsimpProtocolError => "usimp protocol error",
Kind::Utf8DecodeError => "unable to decode utf-8 data",
}
.to_string();
if let Some(desc) = &self.desc {
@ -143,3 +146,14 @@ impl From<r2d2_postgres::postgres::Error> for Error {
}
}
}
impl From<std::string::FromUtf8Error> for Error {
fn from(error: std::string::FromUtf8Error) -> Self {
Error {
kind: Kind::Utf8DecodeError,
msg: Some("Unable to decode UTF-8 data".to_string()),
desc: Some(error.to_string()),
class: Class::ClientError,
}
}
}

View File

@ -1,80 +1,136 @@
use crate::error::*;
use crate::http;
use crate::websocket::FrameHeader;
use crate::websocket::*;
use base64;
use crypto;
use crypto::digest::Digest;
pub fn connection_handler(
pub fn recv_message(client: &mut http::HttpStream) -> Result<Message, Error> {
let mut msg: Vec<u8> = Vec::new();
let mut msg_type = 0;
loop {
let header = FrameHeader::from(&mut client.stream)?;
if msg_type != 0 && header.opcode != 0 {
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("continuation frame expected".to_string()));
} else if header.opcode >= 8 && (!header.fin || header.payload_len >= 126) {
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid control frame".to_string()));
}
match header.opcode {
0 => {}, // cont
1 => {}, // text
2 => // binary
return Err(Error::new(Kind::UsimpProtocolError, Class::ClientError)
.set_desc("binary frames must not be sent on a usimp connection".to_string())),
8 => {}, // close
9 => {}, // ping
10 => {}, // pong
_ => return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid opcode".to_string())),
}
msg_type = header.opcode;
// FIXME check payload len and total len
let mut buf = vec![0u8; header.payload_len() as usize];
client.stream.read_exact(&mut buf)?;
if header.mask {
let key: [u8; 4] = [
(header.masking_key.unwrap() >> 24) as u8,
((header.masking_key.unwrap() >> 16) & 0xFF) as u8,
((header.masking_key.unwrap() >> 8) & 0xFF) as u8,
(header.masking_key.unwrap() & 0xFF) as u8,
];
for (pos, byte) in buf.iter_mut().enumerate() {
*byte ^= key[pos & 3]; // = pos % 4
}
}
msg.append(&mut buf);
if header.fin {
break
}
}
match msg_type {
1 => {Ok(Message::TextMessage(TextMessage {
data: String::from_utf8(msg)?
}))},
8 => {
let mut code = None;
let mut reason = None;
if msg.len() >= 2 {
code = Some(((msg[0] as u16) << 8) | (msg[1] as u16));
}
if msg.len() > 2 {
reason = Some(String::from_utf8(msg[2..].to_vec())?);
}
Ok(Message::CloseMessage(CloseMessage {
code,
reason
}))
},
9 => {Ok(Message::PingMessage(PingMessage {
data: String::from_utf8(msg)?
}))},
10 => {Ok(Message::PongMessage(PongMessage {
data: String::from_utf8(msg)?
}))},
_ => panic!("invalid msg_type for websocket")
}
}
pub fn handshake(
client: &mut http::HttpStream,
req: &http::Request,
mut res: http::Response,
) {
res: &mut http::Response,
) -> Result<(), Error> {
if let http::Method::GET = req.method {
} else {
res.status(405);
res.header.add_field("Allow", "GET");
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("method not allowed".to_string()),
);
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("method not allowed".to_string()));
}
if let Some(_) = req.header.find_field("Connection") {
if !req.header.field_has_value("Connection", "upgrade") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Connection'".to_string()),
);
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Connection'".to_string()));
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Connection'".to_string()),
);
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Connection'".to_string()));
}
if let Some(upgrade) = req.header.find_field("Upgrade") {
if !upgrade.eq_ignore_ascii_case("websocket") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Upgrade'".to_string()),
);
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Upgrade'".to_string()));
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Upgrade'".to_string()),
);
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Upgrade'".to_string()));
}
if let Some(version) = req.header.find_field("Sec-WebSocket-Version") {
if !version.eq("13") {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()),
);
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()));
}
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()),
);
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()));
}
if let Some(key) = req.header.find_field("Sec-WebSocket-Key") {
@ -87,12 +143,8 @@ pub fn connection_handler(
let key = base64::encode(result);
res.header.add_field("Sec-WebSocket-Accept", key.as_str());
} else {
return http::error_handler(
client,
res,
Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()),
);
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
.set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()));
}
client.server_keep_alive = false;
@ -100,35 +152,36 @@ pub fn connection_handler(
res.header.add_field("Upgrade", "websocket");
res.status(101);
res.send(&mut client.stream).unwrap();
res.send(&mut client.stream)?;
Ok(())
}
pub fn connection_handler(
client: &mut http::HttpStream,
req: &http::Request,
mut res: http::Response,
) {
if let Err(error) = handshake(client, req, &mut res) {
return http::error_handler(client, res, error);
}
loop {
let header = FrameHeader::from(&mut client.stream).unwrap();
if header.mask {
println!("Mask: {:X}", header.masking_key.unwrap());
}
let msg = recv_message(client).unwrap();
match msg {
Message::TextMessage(msg) => {
println!("Data: {}", msg.data);
},
Message::CloseMessage(msg) => {
let mut buf = [0u8; 8192];
client
.stream
.read_exact(&mut buf[..header.payload_len() as usize])
.unwrap();
if header.mask {
let key: [u8; 4] = [
(header.masking_key.unwrap() >> 24) as u8,
((header.masking_key.unwrap() >> 16) & 0xFF) as u8,
((header.masking_key.unwrap() >> 8) & 0xFF) as u8,
(header.masking_key.unwrap() & 0xFF) as u8,
];
for (pos, byte) in buf.iter_mut().enumerate() {
*byte ^= key[pos % 4];
return
}
Message::PingMessage(msg) => {
// TODO send pong
},
Message::PongMessage(msg) => {
// TODO something
}
}
println!(
"Msg: {}",
String::from_utf8_lossy(&buf[..header.payload_len() as usize])
);
}
}

View File

@ -21,6 +21,30 @@ pub struct FrameHeader {
masking_key: Option<u32>,
}
pub enum Message {
PingMessage(PingMessage),
PongMessage(PongMessage),
CloseMessage(CloseMessage),
TextMessage(TextMessage),
}
pub struct PingMessage {
data: String,
}
pub struct PongMessage {
data: String,
}
pub struct CloseMessage {
code: Option<u16>,
reason: Option<String>,
}
pub struct TextMessage {
data: String,
}
impl FrameHeader {
pub fn from(socket: &mut http::Stream) -> Result<Self, Error> {
let mut data = [0u8; 2];