diff --git a/Cargo.toml b/Cargo.toml index e3450b1..8861a68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,3 +15,4 @@ chrono = "0.4" flate2 = "1.0.0" r2d2 = "0.8.9" r2d2_postgres = "0.18.0" +ansi_term = "0.12" diff --git a/src/database.rs b/src/database.rs index 4544705..5328b97 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,7 +1,9 @@ +use crate::error::Error; use r2d2_postgres::postgres::NoTls; use r2d2_postgres::PostgresConnectionManager; use std::ops::Deref; use std::sync::{Arc, Mutex}; +use std::time::Duration; pub enum Pool { Postgres(r2d2::Pool>), @@ -13,16 +15,24 @@ pub enum Client { static mut POOL: Option>> = None; -pub fn init() { +pub fn init() -> Result<(), Error> { let manager = PostgresConnectionManager::new( "host=localhost user=postgres dbname=test".parse().unwrap(), NoTls, ); - let pool = r2d2::Pool::new(manager).unwrap(); + let pool = r2d2::Pool::builder() + .max_size(64) + .min_idle(Some(2)) + .connection_timeout(Duration::from_secs(4)) + .max_lifetime(Some(Duration::from_secs(3600))) + .build(manager)?; + unsafe { POOL = Some(Arc::new(Mutex::new(Pool::Postgres(pool)))); } + + Ok(()) } pub fn client() -> Client { diff --git a/src/error.rs b/src/error.rs index ed1faa2..ac1cfd7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,34 +1,50 @@ - pub enum ErrorKind { - InvalidEndpoint, + InvalidEndpointError, JsonParseError, + DatabaseError, } pub struct Error { kind: ErrorKind, + desc: Option, } impl Error { pub fn new(kind: ErrorKind) -> Error { - Error { - kind, - } - } -} - -impl From for Error { - fn from(_: serde_json::Error) -> Self { - Error { - kind: ErrorKind::JsonParseError, - } + Error { kind, desc: None } } } impl ToString for Error { fn to_string(&self) -> String { - match self.kind { - ErrorKind::InvalidEndpoint => "invalid endpoint", - ErrorKind::JsonParseError => "JSON parse error", - }.to_string() + let mut error = match self.kind { + ErrorKind::InvalidEndpointError => "invalid endpoint", + ErrorKind::JsonParseError => "unable to parse JSON data", + ErrorKind::DatabaseError => "unable to connect to database", + } + .to_string(); + if let Some(desc) = &self.desc { + error += ": "; + error += desc; + } + error + } +} + +impl From for Error { + fn from(error: serde_json::Error) -> Self { + Error { + kind: ErrorKind::JsonParseError, + desc: Some(error.to_string()), + } + } +} + +impl From for Error { + fn from(error: r2d2::Error) -> Self { + Error { + kind: ErrorKind::DatabaseError, + desc: Some(error.to_string()), + } } } diff --git a/src/main.rs b/src/main.rs index ea96209..78f4211 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,16 +3,18 @@ use std::sync::Arc; use std::sync::Mutex; use std::thread; +use ansi_term::Color; use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; use rusty_pool; +use std::fmt::Formatter; use std::time::Duration; mod database; +mod error; mod http; mod udp; mod usimp; mod websocket; -mod error; enum SocketType { Http, @@ -20,12 +22,24 @@ enum SocketType { Udp, } +impl std::fmt::Display for SocketType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", match self { + SocketType::Http => "http+ws", + SocketType::Https => "https+wss", + SocketType::Udp => "udp", + }) + } +} + struct SocketConfig { address: SocketAddr, socket_type: SocketType, } fn main() { + println!("Locutus server"); + let socket_configs: Vec = vec![ SocketConfig { address: "[::]:8080".parse().unwrap(), @@ -41,7 +55,13 @@ fn main() { }, ]; - database::init(); + // Note: rust's stdout is line buffered! + eprint!("Initializing database connection pool..."); + if let Err(error) = database::init() { + eprintln!("\n{}", Color::Red.bold().paint(error.to_string())); + std::process::exit(1); + } + eprintln!(" {}", Color::Green.paint("success")); let thread_pool = rusty_pool::Builder::new() .core_size(4) @@ -55,6 +75,12 @@ fn main() { for socket_config in socket_configs { let thread_pool_mutex = thread_pool_mutex.clone(); + eprintln!( + "Creating listening thread for {} ({})", + ansi_term::Style::new().bold().paint(socket_config.address.to_string()), + socket_config.socket_type + ); + threads.push(match socket_config.socket_type { SocketType::Http => thread::spawn(move || { let mut tcp_socket = TcpListener::bind(socket_config.address).unwrap(); diff --git a/src/usimp/mod.rs b/src/usimp/mod.rs index 867939c..345b300 100644 --- a/src/usimp/mod.rs +++ b/src/usimp/mod.rs @@ -7,7 +7,7 @@ use crate::error::*; pub fn endpoint(endpoint: &str, input: serde_json::Value) -> Result { match endpoint { "echo" => Ok(serde_json::to_value(echo(serde_json::from_value(input)?))?), - _ => Err(Error::new(ErrorKind::InvalidEndpoint)), + _ => Err(Error::new(ErrorKind::InvalidEndpointError)), } }