diff --git a/src/http.rs b/src/http.rs index b94cbb2..b034b6d 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,17 +1,502 @@ use std::net::TcpStream; use openssl::ssl::SslStream; +use std::io::{Write, Read}; +use std::fmt::Formatter; pub enum Stream { Tcp(TcpStream), Ssl(SslStream), } -impl Stream {} - -pub fn connection_handler(client: Stream) { - +pub struct HttpStream { + stream: Stream, + request_num: u32, } -fn request_handler(client: Stream) { - +pub enum Method { + GET, + POST, + PUT, + HEAD, + TRACE, + CONNECT, + DELETE, + OPTIONS, + Custom(String), +} + +impl Method { + pub fn from_str(v: &str) -> Method { + match v { + "GET" => Method::GET, + "POST" => Method::POST, + "PUT" => Method::PUT, + "HEAD" => Method::HEAD, + "TRACE" => Method::TRACE, + "CONNECT" => Method::CONNECT, + "DELETE" => Method::DELETE, + "OPTIONS" => Method::OPTIONS, + _ => Method::Custom(String::from(v)), + } + } + + pub fn to_str(&self) -> &str { + match self { + Method::GET => "GET", + Method::POST => "POST", + Method::PUT => "PUT", + Method::HEAD => "HEAD", + Method::TRACE => "TRACE", + Method::CONNECT => "CONNECT", + Method::DELETE => "DELETE", + Method::OPTIONS => "OPTIONS", + Method::Custom(v) => v, + } + } +} + +impl std::fmt::Display for Method { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.to_str()) + } +} + +pub struct HeaderField { + name: String, + value: String, +} + +impl std::fmt::Display for HeaderField { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "[{}: {}]", self.name, self.value) + } +} + +pub struct Request { + method: Method, + uri: String, + header_fields: Vec +} + +pub struct Response { + status_code: u16, + status_message: String, + header_fields: Vec +} + +impl Stream { + pub fn read(&mut self, buf: &mut [u8]) -> Result { + match self { + Stream::Tcp(stream) => stream.read(buf), + Stream::Ssl(stream) => stream.read(buf), + } + } + + pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), std::io::Error> { + match self { + Stream::Tcp(stream) => stream.read_exact(buf), + Stream::Ssl(stream) => stream.read_exact(buf), + } + } + + pub fn peek(&mut self, buf: &mut [u8]) -> Result { + match self { + Stream::Tcp(stream) => stream.peek(buf), + Stream::Ssl(_stream) => todo!("Not implemented"), + } + } + + pub fn write(&mut self, buf: &mut [u8]) -> Result { + match self { + Stream::Tcp(stream) => stream.write(buf), + Stream::Ssl(stream) => stream.write(buf), + } + } +} + +pub mod handler { + pub fn connection_handler(client: super::Stream) { + let mut client = super::HttpStream { + stream: client, + request_num: 0, + }; + + while client.request_num < 200 { + request_handler(&mut client); + client.request_num += 1; + } + } + + fn request_handler(client: &mut super::HttpStream) { + let req = super::parser::parse_request(&mut client.stream).unwrap(); + println!("{} {}", req.method, req.uri); + } +} + +pub mod parser { + use crate::http; + use std::os::linux::raw::stat; + + pub fn parse_request(stream: &mut http::Stream) -> Result { + let mut buf = [0; 4096]; + let size = stream.peek(&mut buf).unwrap(); + + let mut parser = Parser::new_request_parser(&buf[..size]); + let header_size = parser.parse().unwrap(); + + + + let mut header_fields = Vec::new(); + for (name, value) in parser.headers { + header_fields.push(http::HeaderField { + name: String::from(name), + value: String::from(value), + }); + } + + let request = http::Request { + method: http::Method::from_str(parser.method.unwrap()), + uri: String::from(parser.uri.unwrap()), + header_fields, + }; + + stream.read_exact(&mut buf[..header_size]).unwrap(); + + Ok(request) + } + + pub fn parse_response(stream: &mut http::Stream) -> Result { + let mut buf = [0; 4096]; + let size = stream.peek(&mut buf).unwrap(); + + let mut parser = Parser::new_request_parser(&buf[..size]); + let header_size = parser.parse().unwrap(); + + let status_code = parser.status_code.unwrap(); + let status_code = match status_code.parse() { + Ok(v) => v, + Err(e) => return Err(format!("{}", e)), + }; + + let mut header_fields = Vec::new(); + for (name, value) in parser.headers { + header_fields.push(http::HeaderField { + name: String::from(name), + value: String::from(value), + }); + } + + let response = http::Response { + status_code, + status_message: String::from(parser.status_message.unwrap()), + header_fields, + }; + + stream.read_exact(&mut buf[..header_size]).unwrap(); + + Ok(response) + } + + #[derive(Copy, Clone)] + enum State<'a> { + Method, + Uri, + Http(&'a State<'a>), + HttpVersion(&'a State<'a>), + StatusCode, + StatusMessage, + HeaderName, + HeaderValue, + Finish, + CRLF(&'a State<'a>), + Error, + } + + struct Parser<'a> { + state: State<'a>, + buf: &'a [u8], + str_start: usize, + header_size: usize, + method: Option<&'a str>, + uri: Option<&'a str>, + http_version: Option<&'a str>, + status_code: Option<&'a str>, + status_message: Option<&'a str>, + headers: Vec<(&'a str, &'a str)>, + } + + impl Parser<'_> { + fn new_request_parser(buf: &[u8]) -> Parser { + Parser { + state: State::Method, + buf, + str_start: 0, + header_size: 0, + method: None, + uri: None, + http_version: None, + status_code: None, + status_message: None, + headers: Vec::new(), + } + } + + fn new_response_parser(buf: &[u8]) -> Parser { + Parser { + state: State::Http(&State::StatusCode), + buf, + str_start: 0, + header_size: 0, + method: None, + uri: None, + http_version: None, + status_code: None, + status_message: None, + headers: Vec::new(), + } + } + + fn parse(&mut self) -> Result { + for char in self.buf { + self.next(*char); + match self.state { + State::Finish => return Ok(self.header_size), + State::Error => return Err(format!("invalid character at position {}", self.header_size - 1)), + _ => {}, + } + } + return Err(String::from("input too short")); + } + + fn next(&mut self, char: u8) { + self.header_size += 1; + let get_str = || { + std::str::from_utf8(&self.buf[self.str_start..self.header_size - 1]).unwrap() + }; + self.state = match &self.state { + State::Error => State::Error, + State::Finish => State::Error, + State::Method => { + match char { + 0x41..=0x5A => State::Method, + 0x20 => { + self.method = Some(get_str()); + self.str_start = self.header_size; + State::Uri + }, + _ => State::Error, + } + }, + State::Uri => { + match char { + 0x21..=0x7E => State::Uri, + 0x20 => { + self.uri = Some(get_str()); + self.str_start = self.header_size; + State::Http(&State::HeaderName) + }, + _ => State::Error, + } + }, + State::Http(next) => { + match char { + 0x48 | 0x54 | 0x50 => State::Http(next), + 0x2F => { + let http = get_str(); + self.str_start = self.header_size; + if http != "HTTP" { + State::Error + } else { + State::HttpVersion(next) + } + }, + _ => State::Error, + } + }, + State::HttpVersion(next) => { + match char { + 0x30..=0x39 | 0x2E => State::HttpVersion(next), + 0x0D => { + match next { + State::HeaderName => { + self.http_version = Some(get_str()); + State::CRLF(next) + }, + _ => State::Error, + } + }, + 0x20 => { + match next { + State::StatusCode => { + self.http_version = Some(get_str()); + self.str_start = self.header_size; + State::StatusCode + }, + _ => State::Error, + } + } + _ => State::Error, + } + }, + State::StatusCode => { + match char { + 0x30..=0x39 => State::StatusCode, + 0x20 => { + self.status_code = Some(get_str()); + self.str_start = self.header_size; + State::StatusMessage + }, + _ => State::Error, + } + }, + State::StatusMessage => { + match char { + 0x20..=0x7E => State::StatusMessage, + 0x0D => { + self.status_message = Some(get_str()); + State::CRLF(&State::HeaderName) + }, + _ => State::Error, + } + }, + State::HeaderName => { + match char { + 0x0D => { + if self.header_size == self.str_start + 1 { + State::CRLF(&State::Finish) + } else { + State::Error + } + }, + 0x3A => { + let header_name = get_str(); + self.headers.push((header_name, "")); + self.str_start = self.header_size; + State::HeaderValue + }, + 0x00..=0x1F | 0x7F | 0x80..=0xFF | + 0x20 | 0x28 | 0x29 | 0x2C | 0x2F | + 0x3A..=0x40 | 0x5B..=0x5D | 0x7B | 0x7D => State::Error, + _ => State::HeaderName, + } + }, + State::HeaderValue => { + match char { + 0x20..=0x7E | 0x09 => State::HeaderValue, + 0x0D => { + self.headers.last_mut().unwrap().1 = get_str().trim(); + State::CRLF(&State::HeaderName) + }, + _ => State::Error, + } + } + State::CRLF(next) => { + match char { + 0x0A => { + self.str_start = self.header_size; + *next.clone() + }, + _ => State::Error, + } + }, + } + } + } + + #[cfg(test)] + mod tests { + use std::panic::panic_any; + + #[test] + fn simple_request() { + let request: &str = "GET /index.html HTTP/1.1\r\n\ + Host: www.example.com\r\n\ + \r\n"; + + let mut parser = super::Parser::new_request_parser(request.as_bytes()); + let size = parser.parse().unwrap(); + + assert_eq!(51, size); + assert_eq!("GET", parser.method.unwrap()); + assert_eq!("/index.html", parser.uri.unwrap()); + assert_eq!("1.1", parser.http_version.unwrap()); + assert_eq!(None, parser.status_code); + assert_eq!(None, parser.status_message); + + assert_eq!(1, parser.headers.len()); + assert_eq!(("Host", "www.example.com"), parser.headers[0]); + } + + #[test] + fn complex_request() { + let request: &str = "POST /upload/file.txt HTTP/1.3\r\n\ + Host: www.example.com \r\n\ + Content-Length: 13 \r\n\ + User-Agent: Mozilla/5.0 (X11; Linux x86_64) \r\n\ + \r\n\ + username=test"; + + let mut parser = super::Parser::new_request_parser(request.as_bytes()); + let size = parser.parse().unwrap(); + + assert_eq!(129, size); + assert_eq!("POST", parser.method.unwrap()); + assert_eq!("/upload/file.txt", parser.uri.unwrap()); + assert_eq!("1.3", parser.http_version.unwrap()); + assert_eq!(None, parser.status_code); + assert_eq!(None, parser.status_message); + + assert_eq!(3, parser.headers.len()); + assert_eq!(("Host", "www.example.com"), parser.headers[0]); + assert_eq!(("Content-Length", "13"), parser.headers[1]); + assert_eq!(("User-Agent", "Mozilla/5.0 (X11; Linux x86_64)"), parser.headers[2]); + + assert_eq!("username=test", &request[size..]); + } + + #[test] + fn invalid_request_1() { + let request: &str = "GET /files/größe.txt HTTP/1.1\r\n\r\n"; + let mut parser = super::Parser::new_request_parser(request.as_bytes()); + match parser.parse() { + Ok(v) => panic!("should fail"), + Err(e) => assert_eq!("invalid character at position 13", e), + } + } + + #[test] + fn invalid_request_2() { + let request: &str = "GET /index.html HTT"; + let mut parser = super::Parser::new_request_parser(request.as_bytes()); + match parser.parse() { + Ok(v) => panic!("should fail"), + Err(e) => assert_eq!("input too short", e), + } + } + + #[test] + fn simple_response() { + let response: &str = "HTTP/1.1 200 OK\r\n\ + Content-Length: 12\r\n\ + Content-Type: text/plain; charset=us-ascii\r\n\ + \r\n\ + Hello world!"; + + let mut parser = super::Parser::new_response_parser(response.as_bytes()); + let size = parser.parse().unwrap(); + + assert_eq!(83, size); + assert_eq!("200", parser.status_code.unwrap()); + assert_eq!("OK", parser.status_message.unwrap()); + assert_eq!("1.1", parser.http_version.unwrap()); + assert_eq!(None, parser.method); + assert_eq!(None, parser.uri); + + assert_eq!(2, parser.headers.len()); + assert_eq!(("Content-Length", "12"), parser.headers[0]); + assert_eq!(("Content-Type", "text/plain; charset=us-ascii"), parser.headers[1]); + + assert_eq!("Hello world!", &response[size..]); + } + } } diff --git a/src/main.rs b/src/main.rs index 1c129da..78e1680 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,7 +21,7 @@ fn main() { for stream in tcp_socket.incoming() { pool_mutex_ref.lock().unwrap().execute(|| { let stream = stream.unwrap(); - http::connection_handler(http::Stream::Tcp(stream)); + http::handler::connection_handler(http::Stream::Tcp(stream)); }); } })); @@ -41,7 +41,7 @@ fn main() { pool_mutex_ref.lock().unwrap().execute(move || { let stream = stream.unwrap(); let stream = acceptor.accept(stream).unwrap(); - http::connection_handler(http::Stream::Ssl(stream)); + http::handler::connection_handler(http::Stream::Ssl(stream)); }); } })); @@ -59,6 +59,6 @@ fn main() { })); for thread in threads { - thread.join(); + thread.join().unwrap(); } }