diff --git a/src/database.rs b/src/database.rs new file mode 100644 index 0000000..4544705 --- /dev/null +++ b/src/database.rs @@ -0,0 +1,32 @@ +use r2d2_postgres::postgres::NoTls; +use r2d2_postgres::PostgresConnectionManager; +use std::ops::Deref; +use std::sync::{Arc, Mutex}; + +pub enum Pool { + Postgres(r2d2::Pool>), +} + +pub enum Client { + Postgres(r2d2::PooledConnection>), +} + +static mut POOL: Option>> = None; + +pub fn init() { + let manager = PostgresConnectionManager::new( + "host=localhost user=postgres dbname=test".parse().unwrap(), + NoTls, + ); + + let pool = r2d2::Pool::new(manager).unwrap(); + unsafe { + POOL = Some(Arc::new(Mutex::new(Pool::Postgres(pool)))); + } +} + +pub fn client() -> Client { + match unsafe { POOL.as_ref().unwrap().clone().lock().unwrap().deref() } { + Pool::Postgres(pool) => Client::Postgres(pool.get().unwrap()), + } +} diff --git a/src/http/handler.rs b/src/http/handler.rs index bf1e426..61739de 100644 --- a/src/http/handler.rs +++ b/src/http/handler.rs @@ -3,8 +3,8 @@ use crate::usimp; use crate::websocket; use chrono; use json; -use std::sync::{Arc, Mutex}; use std::borrow::Borrow; +use std::sync::{Arc, Mutex}; pub struct HttpStream { pub stream: super::Stream, @@ -120,7 +120,7 @@ fn endpoint_handler( "Unable to parse header: Content-Length missing", client, ) - }, + } } .parse() { @@ -131,7 +131,7 @@ fn endpoint_handler( format!("Unable to parse Content-Length: {}", &e).as_str(), client, ) - }, + } }; client.stream.read_exact(&mut buf[..length]); @@ -145,7 +145,7 @@ fn endpoint_handler( format!("Unable to parse payload: {}", &e).as_str(), client, ) - }, + } }) { Ok(val) => val, Err(e) => { @@ -154,7 +154,7 @@ fn endpoint_handler( format!("Unable to parse JSON: {}", &e).as_str(), client, ) - }, + } }; let output = usimp::endpoint(endpoint, input); diff --git a/src/http/mod.rs b/src/http/mod.rs index ca3eb35..72871ec 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -274,7 +274,10 @@ impl Response { doc.replace("{code}", self.status.code.to_string().as_str()) .replace("{message}", self.status.message.as_str()) .replace("{desc}", self.status.desc) - .replace("{info}", self.status.info.as_ref().unwrap_or(&String::new()).as_str()) + .replace( + "{info}", + self.status.info.as_ref().unwrap_or(&String::new()).as_str(), + ) .as_str(), ) .replace("{{", "{") diff --git a/src/http/parser.rs b/src/http/parser.rs index 15bc12e..c30707a 100644 --- a/src/http/parser.rs +++ b/src/http/parser.rs @@ -5,7 +5,7 @@ pub fn parse_request(stream: &mut http::Stream) -> Result, let mut buf = [0; 4096]; let size = stream.peek(&mut buf).unwrap(); if size == 0 { - return Ok(None) + return Ok(None); } let mut parser = Parser::new_request_parser(&buf[..size]); diff --git a/src/main.rs b/src/main.rs index 8d1413f..e52e0bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,19 @@ +mod database; mod http; mod udp; mod usimp; mod websocket; use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslStream}; -use std::net::{TcpListener, UdpSocket, SocketAddr}; +use r2d2; +use r2d2::{ManageConnection, Pool}; +use r2d2_postgres::{postgres::NoTls, PostgresConnectionManager}; +use std::net::{SocketAddr, TcpListener, UdpSocket}; +use std::ops::Deref; use std::sync::Arc; use std::sync::Mutex; use std::thread; use threadpool::ThreadPool; -use r2d2_postgres::{PostgresConnectionManager, postgres::NoTls}; -use r2d2; -use r2d2::{ManageConnection, Pool}; -use std::ops::Deref; enum SocketType { Http, @@ -25,26 +26,6 @@ struct SocketConfig { socket_type: SocketType, } -pub enum BackendPool { - Postgres(Pool>), -} - -pub enum BackendClient { - Postgres(r2d2::PooledConnection>), -} - -static mut DB_POOL: Option>> = None; - -pub fn get_backend() -> BackendClient { - unsafe { - match DB_POOL.as_ref().unwrap().clone().lock().unwrap().deref() { - BackendPool::Postgres(pool) => { - BackendClient::Postgres(pool.get().unwrap()) - } - } - } -} - fn main() { let socket_configs: Vec = vec![ SocketConfig { @@ -58,17 +39,10 @@ fn main() { SocketConfig { address: "[::]:3126".parse().unwrap(), socket_type: SocketType::Udp, - } + }, ]; - let db_manager = PostgresConnectionManager::new( - "host=localhost user=postgres dbname=test".parse().unwrap(), - NoTls, - ); - let db_pool = r2d2::Pool::new(db_manager).unwrap(); - unsafe { - DB_POOL = Some(Arc::new(Mutex::new(BackendPool::Postgres(db_pool)))); - } + database::init(); let thread_pool = ThreadPool::new(256); let thread_pool_mutex = Arc::new(Mutex::new(thread_pool)); @@ -121,7 +95,10 @@ fn main() { loop { let (size, addr) = udp_socket.recv_from(&mut buf).unwrap(); let req = udp::Request::new(&udp_socket, addr, size, &buf); - thread_pool_mutex.lock().unwrap().execute(|| udp::handler(req)); + thread_pool_mutex + .lock() + .unwrap() + .execute(|| udp::handler(req)); } }), }); diff --git a/src/usimp/mod.rs b/src/usimp/mod.rs index bccc530..acb1e7e 100644 --- a/src/usimp/mod.rs +++ b/src/usimp/mod.rs @@ -1,11 +1,7 @@ +use crate::database; use json; -use r2d2::{ManageConnection, Pool, PooledConnection}; -use std::sync::{Arc, Mutex}; -use crate::{get_backend, BackendClient}; -static ENDPOINTS: [(&str, fn(json::JsonValue) -> json::JsonValue); 1] = [ - ("echo", echo) -]; +static ENDPOINTS: [(&str, fn(json::JsonValue) -> json::JsonValue); 1] = [("echo", echo)]; pub fn is_valid_endpoint(endpoint: &str) -> bool { for (name, _func) in &ENDPOINTS { @@ -26,10 +22,10 @@ pub fn endpoint(endpoint: &str, input: json::JsonValue) -> json::JsonValue { } pub fn echo(input: json::JsonValue) -> json::JsonValue { - let backend = get_backend(); + let backend = database::client(); let mut output = input.clone(); match backend { - BackendClient::Postgres(mut client) => { + database::Client::Postgres(mut client) => { let res = client.query("SELECT * FROM test", &[]).unwrap(); for row in res { let val: i32 = row.get(0);