From d25e03975131cdf5ccd0f762f09e3184e834f9e6 Mon Sep 17 00:00:00 2001
From: Lorenz Stechauner <lorenz.stechauner@necronda.net>
Date: Tue, 18 May 2021 18:44:17 +0200
Subject: [PATCH] Refactored database handling

---
 src/database.rs     | 32 ++++++++++++++++++++++++++++++
 src/http/handler.rs | 10 +++++-----
 src/http/mod.rs     |  5 ++++-
 src/http/parser.rs  |  2 +-
 src/main.rs         | 47 ++++++++++++---------------------------------
 src/usimp/mod.rs    | 12 ++++--------
 6 files changed, 58 insertions(+), 50 deletions(-)
 create mode 100644 src/database.rs

diff --git a/src/database.rs b/src/database.rs
new file mode 100644
index 0000000..4544705
--- /dev/null
+++ b/src/database.rs
@@ -0,0 +1,32 @@
+use r2d2_postgres::postgres::NoTls;
+use r2d2_postgres::PostgresConnectionManager;
+use std::ops::Deref;
+use std::sync::{Arc, Mutex};
+
+pub enum Pool {
+    Postgres(r2d2::Pool<PostgresConnectionManager<NoTls>>),
+}
+
+pub enum Client {
+    Postgres(r2d2::PooledConnection<PostgresConnectionManager<NoTls>>),
+}
+
+static mut POOL: Option<Arc<Mutex<Pool>>> = None;
+
+pub fn init() {
+    let manager = PostgresConnectionManager::new(
+        "host=localhost user=postgres dbname=test".parse().unwrap(),
+        NoTls,
+    );
+
+    let pool = r2d2::Pool::new(manager).unwrap();
+    unsafe {
+        POOL = Some(Arc::new(Mutex::new(Pool::Postgres(pool))));
+    }
+}
+
+pub fn client() -> Client {
+    match unsafe { POOL.as_ref().unwrap().clone().lock().unwrap().deref() } {
+        Pool::Postgres(pool) => Client::Postgres(pool.get().unwrap()),
+    }
+}
diff --git a/src/http/handler.rs b/src/http/handler.rs
index bf1e426..61739de 100644
--- a/src/http/handler.rs
+++ b/src/http/handler.rs
@@ -3,8 +3,8 @@ use crate::usimp;
 use crate::websocket;
 use chrono;
 use json;
-use std::sync::{Arc, Mutex};
 use std::borrow::Borrow;
+use std::sync::{Arc, Mutex};
 
 pub struct HttpStream {
     pub stream: super::Stream,
@@ -120,7 +120,7 @@ fn endpoint_handler(
                 "Unable to parse header: Content-Length missing",
                 client,
             )
-        },
+        }
     }
     .parse()
     {
@@ -131,7 +131,7 @@ fn endpoint_handler(
                 format!("Unable to parse Content-Length: {}", &e).as_str(),
                 client,
             )
-        },
+        }
     };
 
     client.stream.read_exact(&mut buf[..length]);
@@ -145,7 +145,7 @@ fn endpoint_handler(
                 format!("Unable to parse payload: {}", &e).as_str(),
                 client,
             )
-        },
+        }
     }) {
         Ok(val) => val,
         Err(e) => {
@@ -154,7 +154,7 @@ fn endpoint_handler(
                 format!("Unable to parse JSON: {}", &e).as_str(),
                 client,
             )
-        },
+        }
     };
     let output = usimp::endpoint(endpoint, input);
 
diff --git a/src/http/mod.rs b/src/http/mod.rs
index ca3eb35..72871ec 100644
--- a/src/http/mod.rs
+++ b/src/http/mod.rs
@@ -274,7 +274,10 @@ impl Response {
                 doc.replace("{code}", self.status.code.to_string().as_str())
                     .replace("{message}", self.status.message.as_str())
                     .replace("{desc}", self.status.desc)
-                    .replace("{info}", self.status.info.as_ref().unwrap_or(&String::new()).as_str())
+                    .replace(
+                        "{info}",
+                        self.status.info.as_ref().unwrap_or(&String::new()).as_str(),
+                    )
                     .as_str(),
             )
             .replace("{{", "{")
diff --git a/src/http/parser.rs b/src/http/parser.rs
index 15bc12e..c30707a 100644
--- a/src/http/parser.rs
+++ b/src/http/parser.rs
@@ -5,7 +5,7 @@ pub fn parse_request(stream: &mut http::Stream) -> Result<Option<http::Request>,
     let mut buf = [0; 4096];
     let size = stream.peek(&mut buf).unwrap();
     if size == 0 {
-        return Ok(None)
+        return Ok(None);
     }
 
     let mut parser = Parser::new_request_parser(&buf[..size]);
diff --git a/src/main.rs b/src/main.rs
index 8d1413f..e52e0bf 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,18 +1,19 @@
+mod database;
 mod http;
 mod udp;
 mod usimp;
 mod websocket;
 
 use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslStream};
-use std::net::{TcpListener, UdpSocket, SocketAddr};
+use r2d2;
+use r2d2::{ManageConnection, Pool};
+use r2d2_postgres::{postgres::NoTls, PostgresConnectionManager};
+use std::net::{SocketAddr, TcpListener, UdpSocket};
+use std::ops::Deref;
 use std::sync::Arc;
 use std::sync::Mutex;
 use std::thread;
 use threadpool::ThreadPool;
-use r2d2_postgres::{PostgresConnectionManager, postgres::NoTls};
-use r2d2;
-use r2d2::{ManageConnection, Pool};
-use std::ops::Deref;
 
 enum SocketType {
     Http,
@@ -25,26 +26,6 @@ struct SocketConfig {
     socket_type: SocketType,
 }
 
-pub enum BackendPool {
-    Postgres(Pool<PostgresConnectionManager<NoTls>>),
-}
-
-pub enum BackendClient {
-    Postgres(r2d2::PooledConnection<PostgresConnectionManager<NoTls>>),
-}
-
-static mut DB_POOL: Option<Arc<Mutex<BackendPool>>> = None;
-
-pub fn get_backend() -> BackendClient {
-    unsafe {
-        match DB_POOL.as_ref().unwrap().clone().lock().unwrap().deref() {
-            BackendPool::Postgres(pool) => {
-                BackendClient::Postgres(pool.get().unwrap())
-            }
-        }
-    }
-}
-
 fn main() {
     let socket_configs: Vec<SocketConfig> = vec![
         SocketConfig {
@@ -58,17 +39,10 @@ fn main() {
         SocketConfig {
             address: "[::]:3126".parse().unwrap(),
             socket_type: SocketType::Udp,
-        }
+        },
     ];
 
-    let db_manager = PostgresConnectionManager::new(
-        "host=localhost user=postgres dbname=test".parse().unwrap(),
-        NoTls,
-    );
-    let db_pool = r2d2::Pool::new(db_manager).unwrap();
-    unsafe {
-        DB_POOL = Some(Arc::new(Mutex::new(BackendPool::Postgres(db_pool))));
-    }
+    database::init();
 
     let thread_pool = ThreadPool::new(256);
     let thread_pool_mutex = Arc::new(Mutex::new(thread_pool));
@@ -121,7 +95,10 @@ fn main() {
                 loop {
                     let (size, addr) = udp_socket.recv_from(&mut buf).unwrap();
                     let req = udp::Request::new(&udp_socket, addr, size, &buf);
-                    thread_pool_mutex.lock().unwrap().execute(|| udp::handler(req));
+                    thread_pool_mutex
+                        .lock()
+                        .unwrap()
+                        .execute(|| udp::handler(req));
                 }
             }),
         });
diff --git a/src/usimp/mod.rs b/src/usimp/mod.rs
index bccc530..acb1e7e 100644
--- a/src/usimp/mod.rs
+++ b/src/usimp/mod.rs
@@ -1,11 +1,7 @@
+use crate::database;
 use json;
-use r2d2::{ManageConnection, Pool, PooledConnection};
-use std::sync::{Arc, Mutex};
-use crate::{get_backend, BackendClient};
 
-static ENDPOINTS: [(&str, fn(json::JsonValue) -> json::JsonValue); 1] = [
-    ("echo", echo)
-];
+static ENDPOINTS: [(&str, fn(json::JsonValue) -> json::JsonValue); 1] = [("echo", echo)];
 
 pub fn is_valid_endpoint(endpoint: &str) -> bool {
     for (name, _func) in &ENDPOINTS {
@@ -26,10 +22,10 @@ pub fn endpoint(endpoint: &str, input: json::JsonValue) -> json::JsonValue {
 }
 
 pub fn echo(input: json::JsonValue) -> json::JsonValue {
-    let backend = get_backend();
+    let backend = database::client();
     let mut output = input.clone();
     match backend {
-        BackendClient::Postgres(mut client) => {
+        database::Client::Postgres(mut client) => {
             let res = client.query("SELECT * FROM test", &[]).unwrap();
             for row in res {
                 let val: i32 = row.get(0);