Refactored database handling

This commit is contained in:
2021-05-18 18:44:17 +02:00
parent a89068047b
commit d25e039751
6 changed files with 58 additions and 50 deletions

32
src/database.rs Normal file
View File

@ -0,0 +1,32 @@
use r2d2_postgres::postgres::NoTls;
use r2d2_postgres::PostgresConnectionManager;
use std::ops::Deref;
use std::sync::{Arc, Mutex};
pub enum Pool {
Postgres(r2d2::Pool<PostgresConnectionManager<NoTls>>),
}
pub enum Client {
Postgres(r2d2::PooledConnection<PostgresConnectionManager<NoTls>>),
}
static mut POOL: Option<Arc<Mutex<Pool>>> = None;
pub fn init() {
let manager = PostgresConnectionManager::new(
"host=localhost user=postgres dbname=test".parse().unwrap(),
NoTls,
);
let pool = r2d2::Pool::new(manager).unwrap();
unsafe {
POOL = Some(Arc::new(Mutex::new(Pool::Postgres(pool))));
}
}
pub fn client() -> Client {
match unsafe { POOL.as_ref().unwrap().clone().lock().unwrap().deref() } {
Pool::Postgres(pool) => Client::Postgres(pool.get().unwrap()),
}
}

View File

@ -3,8 +3,8 @@ use crate::usimp;
use crate::websocket; use crate::websocket;
use chrono; use chrono;
use json; use json;
use std::sync::{Arc, Mutex};
use std::borrow::Borrow; use std::borrow::Borrow;
use std::sync::{Arc, Mutex};
pub struct HttpStream { pub struct HttpStream {
pub stream: super::Stream, pub stream: super::Stream,
@ -120,7 +120,7 @@ fn endpoint_handler(
"Unable to parse header: Content-Length missing", "Unable to parse header: Content-Length missing",
client, client,
) )
}, }
} }
.parse() .parse()
{ {
@ -131,7 +131,7 @@ fn endpoint_handler(
format!("Unable to parse Content-Length: {}", &e).as_str(), format!("Unable to parse Content-Length: {}", &e).as_str(),
client, client,
) )
}, }
}; };
client.stream.read_exact(&mut buf[..length]); client.stream.read_exact(&mut buf[..length]);
@ -145,7 +145,7 @@ fn endpoint_handler(
format!("Unable to parse payload: {}", &e).as_str(), format!("Unable to parse payload: {}", &e).as_str(),
client, client,
) )
}, }
}) { }) {
Ok(val) => val, Ok(val) => val,
Err(e) => { Err(e) => {
@ -154,7 +154,7 @@ fn endpoint_handler(
format!("Unable to parse JSON: {}", &e).as_str(), format!("Unable to parse JSON: {}", &e).as_str(),
client, client,
) )
}, }
}; };
let output = usimp::endpoint(endpoint, input); let output = usimp::endpoint(endpoint, input);

View File

@ -274,7 +274,10 @@ impl Response {
doc.replace("{code}", self.status.code.to_string().as_str()) doc.replace("{code}", self.status.code.to_string().as_str())
.replace("{message}", self.status.message.as_str()) .replace("{message}", self.status.message.as_str())
.replace("{desc}", self.status.desc) .replace("{desc}", self.status.desc)
.replace("{info}", self.status.info.as_ref().unwrap_or(&String::new()).as_str()) .replace(
"{info}",
self.status.info.as_ref().unwrap_or(&String::new()).as_str(),
)
.as_str(), .as_str(),
) )
.replace("{{", "{") .replace("{{", "{")

View File

@ -5,7 +5,7 @@ pub fn parse_request(stream: &mut http::Stream) -> Result<Option<http::Request>,
let mut buf = [0; 4096]; let mut buf = [0; 4096];
let size = stream.peek(&mut buf).unwrap(); let size = stream.peek(&mut buf).unwrap();
if size == 0 { if size == 0 {
return Ok(None) return Ok(None);
} }
let mut parser = Parser::new_request_parser(&buf[..size]); let mut parser = Parser::new_request_parser(&buf[..size]);

View File

@ -1,18 +1,19 @@
mod database;
mod http; mod http;
mod udp; mod udp;
mod usimp; mod usimp;
mod websocket; mod websocket;
use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslStream}; use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslStream};
use std::net::{TcpListener, UdpSocket, SocketAddr}; use r2d2;
use r2d2::{ManageConnection, Pool};
use r2d2_postgres::{postgres::NoTls, PostgresConnectionManager};
use std::net::{SocketAddr, TcpListener, UdpSocket};
use std::ops::Deref;
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
use std::thread; use std::thread;
use threadpool::ThreadPool; use threadpool::ThreadPool;
use r2d2_postgres::{PostgresConnectionManager, postgres::NoTls};
use r2d2;
use r2d2::{ManageConnection, Pool};
use std::ops::Deref;
enum SocketType { enum SocketType {
Http, Http,
@ -25,26 +26,6 @@ struct SocketConfig {
socket_type: SocketType, socket_type: SocketType,
} }
pub enum BackendPool {
Postgres(Pool<PostgresConnectionManager<NoTls>>),
}
pub enum BackendClient {
Postgres(r2d2::PooledConnection<PostgresConnectionManager<NoTls>>),
}
static mut DB_POOL: Option<Arc<Mutex<BackendPool>>> = None;
pub fn get_backend() -> BackendClient {
unsafe {
match DB_POOL.as_ref().unwrap().clone().lock().unwrap().deref() {
BackendPool::Postgres(pool) => {
BackendClient::Postgres(pool.get().unwrap())
}
}
}
}
fn main() { fn main() {
let socket_configs: Vec<SocketConfig> = vec![ let socket_configs: Vec<SocketConfig> = vec![
SocketConfig { SocketConfig {
@ -58,17 +39,10 @@ fn main() {
SocketConfig { SocketConfig {
address: "[::]:3126".parse().unwrap(), address: "[::]:3126".parse().unwrap(),
socket_type: SocketType::Udp, socket_type: SocketType::Udp,
} },
]; ];
let db_manager = PostgresConnectionManager::new( database::init();
"host=localhost user=postgres dbname=test".parse().unwrap(),
NoTls,
);
let db_pool = r2d2::Pool::new(db_manager).unwrap();
unsafe {
DB_POOL = Some(Arc::new(Mutex::new(BackendPool::Postgres(db_pool))));
}
let thread_pool = ThreadPool::new(256); let thread_pool = ThreadPool::new(256);
let thread_pool_mutex = Arc::new(Mutex::new(thread_pool)); let thread_pool_mutex = Arc::new(Mutex::new(thread_pool));
@ -121,7 +95,10 @@ fn main() {
loop { loop {
let (size, addr) = udp_socket.recv_from(&mut buf).unwrap(); let (size, addr) = udp_socket.recv_from(&mut buf).unwrap();
let req = udp::Request::new(&udp_socket, addr, size, &buf); let req = udp::Request::new(&udp_socket, addr, size, &buf);
thread_pool_mutex.lock().unwrap().execute(|| udp::handler(req)); thread_pool_mutex
.lock()
.unwrap()
.execute(|| udp::handler(req));
} }
}), }),
}); });

View File

@ -1,11 +1,7 @@
use crate::database;
use json; use json;
use r2d2::{ManageConnection, Pool, PooledConnection};
use std::sync::{Arc, Mutex};
use crate::{get_backend, BackendClient};
static ENDPOINTS: [(&str, fn(json::JsonValue) -> json::JsonValue); 1] = [ static ENDPOINTS: [(&str, fn(json::JsonValue) -> json::JsonValue); 1] = [("echo", echo)];
("echo", echo)
];
pub fn is_valid_endpoint(endpoint: &str) -> bool { pub fn is_valid_endpoint(endpoint: &str) -> bool {
for (name, _func) in &ENDPOINTS { for (name, _func) in &ENDPOINTS {
@ -26,10 +22,10 @@ pub fn endpoint(endpoint: &str, input: json::JsonValue) -> json::JsonValue {
} }
pub fn echo(input: json::JsonValue) -> json::JsonValue { pub fn echo(input: json::JsonValue) -> json::JsonValue {
let backend = get_backend(); let backend = database::client();
let mut output = input.clone(); let mut output = input.clone();
match backend { match backend {
BackendClient::Postgres(mut client) => { database::Client::Postgres(mut client) => {
let res = client.query("SELECT * FROM test", &[]).unwrap(); let res = client.query("SELECT * FROM test", &[]).unwrap();
for row in res { for row in res {
let val: i32 = row.get(0); let val: i32 = row.get(0);