bit cleaner websocket implementation
This commit is contained in:
14
src/error.rs
14
src/error.rs
@ -11,6 +11,7 @@ pub enum Kind {
|
||||
WebSocketError,
|
||||
NotImplementedError,
|
||||
UsimpProtocolError,
|
||||
Utf8DecodeError,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
@ -59,6 +60,7 @@ impl Error {
|
||||
Kind::WebSocketError => "WebSocket protocol error",
|
||||
Kind::NotImplementedError => "Not yet implemented",
|
||||
Kind::UsimpProtocolError => "USIMP protocol error",
|
||||
Kind::Utf8DecodeError => "Unable to decode UTF-8 data",
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -88,6 +90,7 @@ impl fmt::Display for Error {
|
||||
Kind::WebSocketError => "websocket protocol error",
|
||||
Kind::NotImplementedError => "not yet implemented",
|
||||
Kind::UsimpProtocolError => "usimp protocol error",
|
||||
Kind::Utf8DecodeError => "unable to decode utf-8 data",
|
||||
}
|
||||
.to_string();
|
||||
if let Some(desc) = &self.desc {
|
||||
@ -143,3 +146,14 @@ impl From<r2d2_postgres::postgres::Error> for Error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::string::FromUtf8Error> for Error {
|
||||
fn from(error: std::string::FromUtf8Error) -> Self {
|
||||
Error {
|
||||
kind: Kind::Utf8DecodeError,
|
||||
msg: Some("Unable to decode UTF-8 data".to_string()),
|
||||
desc: Some(error.to_string()),
|
||||
class: Class::ClientError,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,80 +1,136 @@
|
||||
use crate::error::*;
|
||||
use crate::http;
|
||||
|
||||
use crate::websocket::FrameHeader;
|
||||
use crate::websocket::*;
|
||||
use base64;
|
||||
use crypto;
|
||||
use crypto::digest::Digest;
|
||||
|
||||
pub fn connection_handler(
|
||||
pub fn recv_message(client: &mut http::HttpStream) -> Result<Message, Error> {
|
||||
let mut msg: Vec<u8> = Vec::new();
|
||||
let mut msg_type = 0;
|
||||
loop {
|
||||
let header = FrameHeader::from(&mut client.stream)?;
|
||||
|
||||
if msg_type != 0 && header.opcode != 0 {
|
||||
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("continuation frame expected".to_string()));
|
||||
} else if header.opcode >= 8 && (!header.fin || header.payload_len >= 126) {
|
||||
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("invalid control frame".to_string()));
|
||||
}
|
||||
|
||||
match header.opcode {
|
||||
0 => {}, // cont
|
||||
1 => {}, // text
|
||||
2 => // binary
|
||||
return Err(Error::new(Kind::UsimpProtocolError, Class::ClientError)
|
||||
.set_desc("binary frames must not be sent on a usimp connection".to_string())),
|
||||
8 => {}, // close
|
||||
9 => {}, // ping
|
||||
10 => {}, // pong
|
||||
_ => return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("invalid opcode".to_string())),
|
||||
}
|
||||
|
||||
msg_type = header.opcode;
|
||||
|
||||
// FIXME check payload len and total len
|
||||
|
||||
let mut buf = vec![0u8; header.payload_len() as usize];
|
||||
client.stream.read_exact(&mut buf)?;
|
||||
|
||||
if header.mask {
|
||||
let key: [u8; 4] = [
|
||||
(header.masking_key.unwrap() >> 24) as u8,
|
||||
((header.masking_key.unwrap() >> 16) & 0xFF) as u8,
|
||||
((header.masking_key.unwrap() >> 8) & 0xFF) as u8,
|
||||
(header.masking_key.unwrap() & 0xFF) as u8,
|
||||
];
|
||||
for (pos, byte) in buf.iter_mut().enumerate() {
|
||||
*byte ^= key[pos & 3]; // = pos % 4
|
||||
}
|
||||
}
|
||||
|
||||
msg.append(&mut buf);
|
||||
|
||||
if header.fin {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
match msg_type {
|
||||
1 => {Ok(Message::TextMessage(TextMessage {
|
||||
data: String::from_utf8(msg)?
|
||||
}))},
|
||||
8 => {
|
||||
let mut code = None;
|
||||
let mut reason = None;
|
||||
|
||||
if msg.len() >= 2 {
|
||||
code = Some(((msg[0] as u16) << 8) | (msg[1] as u16));
|
||||
}
|
||||
|
||||
if msg.len() > 2 {
|
||||
reason = Some(String::from_utf8(msg[2..].to_vec())?);
|
||||
}
|
||||
|
||||
Ok(Message::CloseMessage(CloseMessage {
|
||||
code,
|
||||
reason
|
||||
}))
|
||||
},
|
||||
9 => {Ok(Message::PingMessage(PingMessage {
|
||||
data: String::from_utf8(msg)?
|
||||
}))},
|
||||
10 => {Ok(Message::PongMessage(PongMessage {
|
||||
data: String::from_utf8(msg)?
|
||||
}))},
|
||||
_ => panic!("invalid msg_type for websocket")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handshake(
|
||||
client: &mut http::HttpStream,
|
||||
req: &http::Request,
|
||||
mut res: http::Response,
|
||||
) {
|
||||
res: &mut http::Response,
|
||||
) -> Result<(), Error> {
|
||||
if let http::Method::GET = req.method {
|
||||
} else {
|
||||
res.status(405);
|
||||
res.header.add_field("Allow", "GET");
|
||||
return http::error_handler(
|
||||
client,
|
||||
res,
|
||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("method not allowed".to_string()),
|
||||
);
|
||||
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("method not allowed".to_string()));
|
||||
}
|
||||
|
||||
if let Some(_) = req.header.find_field("Connection") {
|
||||
if !req.header.field_has_value("Connection", "upgrade") {
|
||||
return http::error_handler(
|
||||
client,
|
||||
res,
|
||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("invalid value for header field 'Connection'".to_string()),
|
||||
);
|
||||
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("invalid value for header field 'Connection'".to_string()));
|
||||
}
|
||||
} else {
|
||||
return http::error_handler(
|
||||
client,
|
||||
res,
|
||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("unable to find header field 'Connection'".to_string()),
|
||||
);
|
||||
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("unable to find header field 'Connection'".to_string()));
|
||||
}
|
||||
|
||||
if let Some(upgrade) = req.header.find_field("Upgrade") {
|
||||
if !upgrade.eq_ignore_ascii_case("websocket") {
|
||||
return http::error_handler(
|
||||
client,
|
||||
res,
|
||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("invalid value for header field 'Upgrade'".to_string()),
|
||||
);
|
||||
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("invalid value for header field 'Upgrade'".to_string()));
|
||||
}
|
||||
} else {
|
||||
return http::error_handler(
|
||||
client,
|
||||
res,
|
||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("unable to find header field 'Upgrade'".to_string()),
|
||||
);
|
||||
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("unable to find header field 'Upgrade'".to_string()));
|
||||
}
|
||||
|
||||
if let Some(version) = req.header.find_field("Sec-WebSocket-Version") {
|
||||
if !version.eq("13") {
|
||||
return http::error_handler(
|
||||
client,
|
||||
res,
|
||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()),
|
||||
);
|
||||
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("invalid value for header field 'Sec-WebSocket-Key'".to_string()));
|
||||
}
|
||||
} else {
|
||||
return http::error_handler(
|
||||
client,
|
||||
res,
|
||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()),
|
||||
);
|
||||
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("unable to find header field 'Sec-WebSocket-Version'".to_string()));
|
||||
}
|
||||
|
||||
if let Some(key) = req.header.find_field("Sec-WebSocket-Key") {
|
||||
@ -87,12 +143,8 @@ pub fn connection_handler(
|
||||
let key = base64::encode(result);
|
||||
res.header.add_field("Sec-WebSocket-Accept", key.as_str());
|
||||
} else {
|
||||
return http::error_handler(
|
||||
client,
|
||||
res,
|
||||
Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()),
|
||||
);
|
||||
return Err(Error::new(Kind::WebSocketError, Class::ClientError)
|
||||
.set_desc("unable to find header field 'Sec-WebSocket-Key'".to_string()));
|
||||
}
|
||||
|
||||
client.server_keep_alive = false;
|
||||
@ -100,35 +152,36 @@ pub fn connection_handler(
|
||||
res.header.add_field("Upgrade", "websocket");
|
||||
|
||||
res.status(101);
|
||||
res.send(&mut client.stream).unwrap();
|
||||
res.send(&mut client.stream)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn connection_handler(
|
||||
client: &mut http::HttpStream,
|
||||
req: &http::Request,
|
||||
mut res: http::Response,
|
||||
) {
|
||||
if let Err(error) = handshake(client, req, &mut res) {
|
||||
return http::error_handler(client, res, error);
|
||||
}
|
||||
|
||||
loop {
|
||||
let header = FrameHeader::from(&mut client.stream).unwrap();
|
||||
if header.mask {
|
||||
println!("Mask: {:X}", header.masking_key.unwrap());
|
||||
}
|
||||
let msg = recv_message(client).unwrap();
|
||||
match msg {
|
||||
Message::TextMessage(msg) => {
|
||||
println!("Data: {}", msg.data);
|
||||
},
|
||||
Message::CloseMessage(msg) => {
|
||||
|
||||
let mut buf = [0u8; 8192];
|
||||
client
|
||||
.stream
|
||||
.read_exact(&mut buf[..header.payload_len() as usize])
|
||||
.unwrap();
|
||||
|
||||
if header.mask {
|
||||
let key: [u8; 4] = [
|
||||
(header.masking_key.unwrap() >> 24) as u8,
|
||||
((header.masking_key.unwrap() >> 16) & 0xFF) as u8,
|
||||
((header.masking_key.unwrap() >> 8) & 0xFF) as u8,
|
||||
(header.masking_key.unwrap() & 0xFF) as u8,
|
||||
];
|
||||
for (pos, byte) in buf.iter_mut().enumerate() {
|
||||
*byte ^= key[pos % 4];
|
||||
return
|
||||
}
|
||||
Message::PingMessage(msg) => {
|
||||
// TODO send pong
|
||||
},
|
||||
Message::PongMessage(msg) => {
|
||||
// TODO something
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"Msg: {}",
|
||||
String::from_utf8_lossy(&buf[..header.payload_len() as usize])
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -21,6 +21,30 @@ pub struct FrameHeader {
|
||||
masking_key: Option<u32>,
|
||||
}
|
||||
|
||||
pub enum Message {
|
||||
PingMessage(PingMessage),
|
||||
PongMessage(PongMessage),
|
||||
CloseMessage(CloseMessage),
|
||||
TextMessage(TextMessage),
|
||||
}
|
||||
|
||||
pub struct PingMessage {
|
||||
data: String,
|
||||
}
|
||||
|
||||
pub struct PongMessage {
|
||||
data: String,
|
||||
}
|
||||
|
||||
pub struct CloseMessage {
|
||||
code: Option<u16>,
|
||||
reason: Option<String>,
|
||||
}
|
||||
|
||||
pub struct TextMessage {
|
||||
data: String,
|
||||
}
|
||||
|
||||
impl FrameHeader {
|
||||
pub fn from(socket: &mut http::Stream) -> Result<Self, Error> {
|
||||
let mut data = [0u8; 2];
|
||||
|
Reference in New Issue
Block a user