Refactored project

This commit is contained in:
2021-05-15 21:20:02 +02:00
parent 4ce3569458
commit 19004d8cf1
5 changed files with 585 additions and 390 deletions

View File

@ -1,8 +1,16 @@
mod consts;
mod parser;
mod handler;
use std::net::TcpStream;
use openssl::ssl::SslStream;
use std::io::{Write, Read};
use std::fmt::Formatter;
pub use handler::*;
static REQUESTS_PER_CONNECTION: u32 = 200;
pub enum Stream {
Tcp(TcpStream),
Ssl(SslStream<TcpStream>),
@ -61,6 +69,67 @@ impl std::fmt::Display for Method {
}
}
#[derive(Copy, Clone)]
pub enum StatusClass {
Informational,
Success,
Redirection,
ClientError,
ServerError,
}
impl StatusClass {
pub fn from_code(status_code: u16) -> StatusClass {
for (code, class, _msg, _desc) in &consts::HTTP_STATUSES {
if *code == status_code {
return class.clone();
}
}
match status_code {
100..=199 => StatusClass::Informational,
200..=299 => StatusClass::Success,
300..=399 => StatusClass::Redirection,
400..=499 => StatusClass::ClientError,
500..=599 => StatusClass::ServerError,
_ => panic!("invalid status code")
}
}
}
#[derive(Clone)]
pub struct Status {
code: u16,
message: String,
class: StatusClass,
}
impl Status {
pub fn from_code(status_code: u16) -> Status {
for (code, class, msg, _desc) in &consts::HTTP_STATUSES {
if *code == status_code {
return Status {
code: status_code,
message: msg.to_string(),
class: class.clone(),
}
}
}
panic!("invalid status code");
}
pub fn new_custom(status_code: u16, message: &str) -> Status {
if status_code < 100 || status_code > 599 {
panic!("invalid status code");
}
Status {
code: status_code,
message: message.to_string(),
class: StatusClass::from_code(status_code),
}
}
}
pub struct HeaderField {
name: String,
value: String,
@ -73,17 +142,44 @@ impl std::fmt::Display for HeaderField {
}
pub struct Request {
version: String,
method: Method,
uri: String,
header_fields: Vec<HeaderField>
}
pub struct Response {
status_code: u16,
status_message: String,
version: String,
status: Status,
header_fields: Vec<HeaderField>
}
impl Response {
fn new() -> Response {
Response {
version: "1.1".to_string(),
status: Status::from_code(200),
header_fields: Vec::new(),
}
}
fn add_header(&mut self, name: &str, value: &str) {
self.header_fields.push(HeaderField {
name: String::from(name),
value: String::from(value),
});
}
fn send(&self, stream: &mut Stream) -> Result<(), std::io::Error> {
let mut header = format!("HTTP/{} {:03} {}\r\n", self.version, self.status.code, self.status.message);
for header_field in &self.header_fields {
header.push_str(format!("{}: {}\r\n", header_field.name, header_field.value).as_str());
}
header.push_str("\r\n");
stream.write_all(header.as_bytes())
}
}
impl Stream {
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
match self {
@ -102,401 +198,21 @@ impl Stream {
pub fn peek(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
match self {
Stream::Tcp(stream) => stream.peek(buf),
Stream::Ssl(_stream) => todo!("Not implemented"),
Stream::Ssl(_stream) => todo!("Not implemented in rust-openssl"),
}
}
pub fn write(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
pub fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
match self {
Stream::Tcp(stream) => stream.write(buf),
Stream::Ssl(stream) => stream.write(buf),
}
}
}
pub mod handler {
pub fn connection_handler(client: super::Stream) {
let mut client = super::HttpStream {
stream: client,
request_num: 0,
};
while client.request_num < 200 {
request_handler(&mut client);
client.request_num += 1;
}
}
fn request_handler(client: &mut super::HttpStream) {
let req = super::parser::parse_request(&mut client.stream).unwrap();
println!("{} {}", req.method, req.uri);
}
}
pub mod parser {
use crate::http;
use std::os::linux::raw::stat;
pub fn parse_request(stream: &mut http::Stream) -> Result<http::Request, String> {
let mut buf = [0; 4096];
let size = stream.peek(&mut buf).unwrap();
let mut parser = Parser::new_request_parser(&buf[..size]);
let header_size = parser.parse().unwrap();
let mut header_fields = Vec::new();
for (name, value) in parser.headers {
header_fields.push(http::HeaderField {
name: String::from(name),
value: String::from(value),
});
}
let request = http::Request {
method: http::Method::from_str(parser.method.unwrap()),
uri: String::from(parser.uri.unwrap()),
header_fields,
};
stream.read_exact(&mut buf[..header_size]).unwrap();
Ok(request)
}
pub fn parse_response(stream: &mut http::Stream) -> Result<http::Response, String> {
let mut buf = [0; 4096];
let size = stream.peek(&mut buf).unwrap();
let mut parser = Parser::new_request_parser(&buf[..size]);
let header_size = parser.parse().unwrap();
let status_code = parser.status_code.unwrap();
let status_code = match status_code.parse() {
Ok(v) => v,
Err(e) => return Err(format!("{}", e)),
};
let mut header_fields = Vec::new();
for (name, value) in parser.headers {
header_fields.push(http::HeaderField {
name: String::from(name),
value: String::from(value),
});
}
let response = http::Response {
status_code,
status_message: String::from(parser.status_message.unwrap()),
header_fields,
};
stream.read_exact(&mut buf[..header_size]).unwrap();
Ok(response)
}
#[derive(Copy, Clone)]
enum State<'a> {
Method,
Uri,
Http(&'a State<'a>),
HttpVersion(&'a State<'a>),
StatusCode,
StatusMessage,
HeaderName,
HeaderValue,
Finish,
CRLF(&'a State<'a>),
Error,
}
struct Parser<'a> {
state: State<'a>,
buf: &'a [u8],
str_start: usize,
header_size: usize,
method: Option<&'a str>,
uri: Option<&'a str>,
http_version: Option<&'a str>,
status_code: Option<&'a str>,
status_message: Option<&'a str>,
headers: Vec<(&'a str, &'a str)>,
}
impl Parser<'_> {
fn new_request_parser(buf: &[u8]) -> Parser {
Parser {
state: State::Method,
buf,
str_start: 0,
header_size: 0,
method: None,
uri: None,
http_version: None,
status_code: None,
status_message: None,
headers: Vec::new(),
}
}
fn new_response_parser(buf: &[u8]) -> Parser {
Parser {
state: State::Http(&State::StatusCode),
buf,
str_start: 0,
header_size: 0,
method: None,
uri: None,
http_version: None,
status_code: None,
status_message: None,
headers: Vec::new(),
}
}
fn parse(&mut self) -> Result<usize, String> {
for char in self.buf {
self.next(*char);
match self.state {
State::Finish => return Ok(self.header_size),
State::Error => return Err(format!("invalid character at position {}", self.header_size - 1)),
_ => {},
}
}
return Err(String::from("input too short"));
}
fn next(&mut self, char: u8) {
self.header_size += 1;
let get_str = || {
std::str::from_utf8(&self.buf[self.str_start..self.header_size - 1]).unwrap()
};
self.state = match &self.state {
State::Error => State::Error,
State::Finish => State::Error,
State::Method => {
match char {
0x41..=0x5A => State::Method,
0x20 => {
self.method = Some(get_str());
self.str_start = self.header_size;
State::Uri
},
_ => State::Error,
}
},
State::Uri => {
match char {
0x21..=0x7E => State::Uri,
0x20 => {
self.uri = Some(get_str());
self.str_start = self.header_size;
State::Http(&State::HeaderName)
},
_ => State::Error,
}
},
State::Http(next) => {
match char {
0x48 | 0x54 | 0x50 => State::Http(next),
0x2F => {
let http = get_str();
self.str_start = self.header_size;
if http != "HTTP" {
State::Error
} else {
State::HttpVersion(next)
}
},
_ => State::Error,
}
},
State::HttpVersion(next) => {
match char {
0x30..=0x39 | 0x2E => State::HttpVersion(next),
0x0D => {
match next {
State::HeaderName => {
self.http_version = Some(get_str());
State::CRLF(next)
},
_ => State::Error,
}
},
0x20 => {
match next {
State::StatusCode => {
self.http_version = Some(get_str());
self.str_start = self.header_size;
State::StatusCode
},
_ => State::Error,
}
}
_ => State::Error,
}
},
State::StatusCode => {
match char {
0x30..=0x39 => State::StatusCode,
0x20 => {
self.status_code = Some(get_str());
self.str_start = self.header_size;
State::StatusMessage
},
_ => State::Error,
}
},
State::StatusMessage => {
match char {
0x20..=0x7E => State::StatusMessage,
0x0D => {
self.status_message = Some(get_str());
State::CRLF(&State::HeaderName)
},
_ => State::Error,
}
},
State::HeaderName => {
match char {
0x0D => {
if self.header_size == self.str_start + 1 {
State::CRLF(&State::Finish)
} else {
State::Error
}
},
0x3A => {
let header_name = get_str();
self.headers.push((header_name, ""));
self.str_start = self.header_size;
State::HeaderValue
},
0x00..=0x1F | 0x7F | 0x80..=0xFF |
0x20 | 0x28 | 0x29 | 0x2C | 0x2F |
0x3A..=0x40 | 0x5B..=0x5D | 0x7B | 0x7D => State::Error,
_ => State::HeaderName,
}
},
State::HeaderValue => {
match char {
0x20..=0x7E | 0x09 => State::HeaderValue,
0x0D => {
self.headers.last_mut().unwrap().1 = get_str().trim();
State::CRLF(&State::HeaderName)
},
_ => State::Error,
}
}
State::CRLF(next) => {
match char {
0x0A => {
self.str_start = self.header_size;
*next.clone()
},
_ => State::Error,
}
},
}
}
}
#[cfg(test)]
mod tests {
use std::panic::panic_any;
#[test]
fn simple_request() {
let request: &str = "GET /index.html HTTP/1.1\r\n\
Host: www.example.com\r\n\
\r\n";
let mut parser = super::Parser::new_request_parser(request.as_bytes());
let size = parser.parse().unwrap();
assert_eq!(51, size);
assert_eq!("GET", parser.method.unwrap());
assert_eq!("/index.html", parser.uri.unwrap());
assert_eq!("1.1", parser.http_version.unwrap());
assert_eq!(None, parser.status_code);
assert_eq!(None, parser.status_message);
assert_eq!(1, parser.headers.len());
assert_eq!(("Host", "www.example.com"), parser.headers[0]);
}
#[test]
fn complex_request() {
let request: &str = "POST /upload/file.txt HTTP/1.3\r\n\
Host: www.example.com \r\n\
Content-Length: 13 \r\n\
User-Agent: Mozilla/5.0 (X11; Linux x86_64) \r\n\
\r\n\
username=test";
let mut parser = super::Parser::new_request_parser(request.as_bytes());
let size = parser.parse().unwrap();
assert_eq!(129, size);
assert_eq!("POST", parser.method.unwrap());
assert_eq!("/upload/file.txt", parser.uri.unwrap());
assert_eq!("1.3", parser.http_version.unwrap());
assert_eq!(None, parser.status_code);
assert_eq!(None, parser.status_message);
assert_eq!(3, parser.headers.len());
assert_eq!(("Host", "www.example.com"), parser.headers[0]);
assert_eq!(("Content-Length", "13"), parser.headers[1]);
assert_eq!(("User-Agent", "Mozilla/5.0 (X11; Linux x86_64)"), parser.headers[2]);
assert_eq!("username=test", &request[size..]);
}
#[test]
fn invalid_request_1() {
let request: &str = "GET /files/größe.txt HTTP/1.1\r\n\r\n";
let mut parser = super::Parser::new_request_parser(request.as_bytes());
match parser.parse() {
Ok(v) => panic!("should fail"),
Err(e) => assert_eq!("invalid character at position 13", e),
}
}
#[test]
fn invalid_request_2() {
let request: &str = "GET /index.html HTT";
let mut parser = super::Parser::new_request_parser(request.as_bytes());
match parser.parse() {
Ok(v) => panic!("should fail"),
Err(e) => assert_eq!("input too short", e),
}
}
#[test]
fn simple_response() {
let response: &str = "HTTP/1.1 200 OK\r\n\
Content-Length: 12\r\n\
Content-Type: text/plain; charset=us-ascii\r\n\
\r\n\
Hello world!";
let mut parser = super::Parser::new_response_parser(response.as_bytes());
let size = parser.parse().unwrap();
assert_eq!(83, size);
assert_eq!("200", parser.status_code.unwrap());
assert_eq!("OK", parser.status_message.unwrap());
assert_eq!("1.1", parser.http_version.unwrap());
assert_eq!(None, parser.method);
assert_eq!(None, parser.uri);
assert_eq!(2, parser.headers.len());
assert_eq!(("Content-Length", "12"), parser.headers[0]);
assert_eq!(("Content-Type", "text/plain; charset=us-ascii"), parser.headers[1]);
assert_eq!("Hello world!", &response[size..]);
pub fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> {
match self {
Stream::Tcp(stream) => stream.write_all(buf),
Stream::Ssl(stream) => stream.write_all(buf),
}
}
}