mod consts; mod handler; mod parser; use openssl::ssl::SslStream; use std::fmt::Formatter; use std::io::{Read, Write}; use std::net::TcpStream; pub use handler::*; static REQUESTS_PER_CONNECTION: u32 = 200; pub enum Stream { Tcp(TcpStream), Ssl(SslStream), } pub enum Method { GET, POST, PUT, HEAD, TRACE, CONNECT, DELETE, OPTIONS, Custom(String), } impl Method { pub fn from_str(v: &str) -> Method { match v { "GET" => Method::GET, "POST" => Method::POST, "PUT" => Method::PUT, "HEAD" => Method::HEAD, "TRACE" => Method::TRACE, "CONNECT" => Method::CONNECT, "DELETE" => Method::DELETE, "OPTIONS" => Method::OPTIONS, _ => Method::Custom(String::from(v)), } } pub fn to_str(&self) -> &str { match self { Method::GET => "GET", Method::POST => "POST", Method::PUT => "PUT", Method::HEAD => "HEAD", Method::TRACE => "TRACE", Method::CONNECT => "CONNECT", Method::DELETE => "DELETE", Method::OPTIONS => "OPTIONS", Method::Custom(v) => v, } } } impl std::fmt::Display for Method { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.to_str()) } } #[derive(Copy, Clone)] pub enum StatusClass { Informational, Success, Redirection, ClientError, ServerError, } impl StatusClass { pub fn from_code(status_code: u16) -> StatusClass { for (code, class, _msg, _desc) in &consts::HTTP_STATUSES { if *code == status_code { return class.clone(); } } match status_code { 100..=199 => StatusClass::Informational, 200..=299 => StatusClass::Success, 300..=399 => StatusClass::Redirection, 400..=499 => StatusClass::ClientError, 500..=599 => StatusClass::ServerError, _ => panic!("invalid status code"), } } } #[derive(Clone)] pub struct Status { code: u16, message: String, desc: &'static str, class: StatusClass, } impl Status { pub fn from_code(status_code: u16) -> Status { for (code, class, msg, desc) in &consts::HTTP_STATUSES { if *code == status_code { return Status { code: status_code, message: msg.to_string(), desc, class: class.clone(), }; } } panic!("invalid status code"); } pub fn new_custom(status_code: u16, message: &str) -> Status { if status_code < 100 || status_code > 599 { panic!("invalid status code"); } let status = Status::from_code(status_code); Status { code: status_code, message: message.to_string(), desc: status.desc, class: status.class, } } } pub struct HeaderField { name: String, value: String, } impl std::fmt::Display for HeaderField { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "[{}: {}]", self.name, self.value) } } pub struct Request { version: String, pub method: Method, pub uri: String, header_fields: Vec, } pub struct Response { version: String, status: Status, header_fields: Vec, } impl Response { pub fn new() -> Response { Response { version: "1.1".to_string(), status: Status::from_code(200), header_fields: Vec::new(), } } pub fn status(&mut self, status_code: u16) { self.status = Status::from_code(status_code) } pub fn add_header(&mut self, name: &str, value: &str) { self.header_fields.push(HeaderField { name: String::from(name), value: String::from(value), }); } pub fn find_header(&self, header_name: &str) -> Option { for field in &self.header_fields { if field .name .to_lowercase() .eq(header_name.to_ascii_lowercase().as_str()) { return Some(field.value.clone()); } } return None; } pub fn send(&mut self, stream: &mut Stream) -> Result<(), std::io::Error> { self.add_header("Server", "Locutus"); self.add_header( "Date", chrono::Utc::now() .format("%a, %d %b %Y %H:%M:%S GMT") .to_string() .as_str(), ); let mut buf = None; if let None = self.find_header("Content-Length") { let (doc, color_name, color) = match self.status.class { StatusClass::Informational => (consts::INFO_DOCUMENT, "info", "#606060"), StatusClass::Success => (consts::SUCCESS_DOCUMENT, "success", "#008000"), StatusClass::Redirection => (consts::WARNING_DOCUMENT, "warning", "#E0C000"), StatusClass::ClientError => (consts::ERROR_DOCUMENT, "error", "#C00000"), StatusClass::ServerError => (consts::ERROR_DOCUMENT, "error", "#C00000"), }; let new_buf = consts::DEFAULT_DOCUMENT .replace("{status_code}", self.status.code.to_string().as_str()) .replace("{status_message}", self.status.message.as_str()) .replace("{hostname}", "localhost") // TODO hostname .replace("{theme_color}", color) .replace("{color_name}", color_name) .replace("{server_str}", "Locutus server") // TODO server string .replace( "{doc}", doc.replace("{code}", self.status.code.to_string().as_str()) .replace("{message}", self.status.message.as_str()) .replace("{desc}", self.status.desc) .replace("{info}", "") // TODO info string .as_str(), ) .replace("{{", "{") .replace("}}", "}"); self.add_header("Content-Length", new_buf.len().to_string().as_str()); self.add_header("Content-Type", "text/html; charset=utf-8"); buf = Some(new_buf); } let mut header = format!( "HTTP/{} {:03} {}\r\n", self.version, self.status.code, self.status.message ); for header_field in &self.header_fields { header.push_str(format!("{}: {}\r\n", header_field.name, header_field.value).as_str()); } header.push_str("\r\n"); stream.write_all(header.as_bytes())?; if let Some(buf) = buf { stream.write_all(buf.as_bytes()); } Ok(()) } } impl Stream { pub fn read(&mut self, buf: &mut [u8]) -> Result { match self { Stream::Tcp(stream) => stream.read(buf), Stream::Ssl(stream) => stream.read(buf), } } pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), std::io::Error> { match self { Stream::Tcp(stream) => stream.read_exact(buf), Stream::Ssl(stream) => stream.read_exact(buf), } } pub fn peek(&mut self, buf: &mut [u8]) -> Result { match self { Stream::Tcp(stream) => stream.peek(buf), Stream::Ssl(_stream) => todo!("Not implemented in rust-openssl"), } } pub fn write(&mut self, buf: &[u8]) -> Result { match self { Stream::Tcp(stream) => stream.write(buf), Stream::Ssl(stream) => stream.write(buf), } } pub fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { match self { Stream::Tcp(stream) => stream.write_all(buf), Stream::Ssl(stream) => stream.write_all(buf), } } }