From 1c29a865aa42091dffc595723d39dedd5e1c6676 Mon Sep 17 00:00:00 2001 From: Lorenz Stechauner Date: Mon, 17 May 2021 22:49:43 +0200 Subject: [PATCH] Database pool working --- Cargo.toml | 2 + src/http/handler.rs | 4 +- src/main.rs | 158 ++++++++++++++++++++++++++++++-------------- src/usimp/mod.rs | 20 +++++- 4 files changed, 132 insertions(+), 52 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a5adee7..f6ffd1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,5 @@ json = "0.12.4" openssl = {version = "0.10", features = ["vendored"]} chrono = "0.4" flate2 = "1.0.0" +r2d2 = "0.8.9" +r2d2_postgres = "0.18.0" diff --git a/src/http/handler.rs b/src/http/handler.rs index 231c7c2..bf1e426 100644 --- a/src/http/handler.rs +++ b/src/http/handler.rs @@ -3,6 +3,8 @@ use crate::usimp; use crate::websocket; use chrono; use json; +use std::sync::{Arc, Mutex}; +use std::borrow::Borrow; pub struct HttpStream { pub stream: super::Stream, @@ -157,7 +159,7 @@ fn endpoint_handler( let output = usimp::endpoint(endpoint, input); // TODO compress - let buf = output.to_string() + "\r\n"; + let buf = json::stringify(output) + "\r\n"; let length = buf.as_bytes().len(); res.add_header("Content-Length", length.to_string().as_str()); res.add_header("Content-Type", "application/json; charset=utf-8"); diff --git a/src/main.rs b/src/main.rs index 89a8470..8d1413f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,68 +4,128 @@ mod usimp; mod websocket; use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslStream}; -use std::net::{TcpListener, UdpSocket}; +use std::net::{TcpListener, UdpSocket, SocketAddr}; use std::sync::Arc; use std::sync::Mutex; use std::thread; use threadpool::ThreadPool; +use r2d2_postgres::{PostgresConnectionManager, postgres::NoTls}; +use r2d2; +use r2d2::{ManageConnection, Pool}; +use std::ops::Deref; + +enum SocketType { + Http, + Https, + Udp, +} + +struct SocketConfig { + address: SocketAddr, + socket_type: SocketType, +} + +pub enum BackendPool { + Postgres(Pool>), +} + +pub enum BackendClient { + Postgres(r2d2::PooledConnection>), +} + +static mut DB_POOL: Option>> = None; + +pub fn get_backend() -> BackendClient { + unsafe { + match DB_POOL.as_ref().unwrap().clone().lock().unwrap().deref() { + BackendPool::Postgres(pool) => { + BackendClient::Postgres(pool.get().unwrap()) + } + } + } +} fn main() { + let socket_configs: Vec = vec![ + SocketConfig { + address: "[::]:8080".parse().unwrap(), + socket_type: SocketType::Http, + }, + SocketConfig { + address: "[::]:8443".parse().unwrap(), + socket_type: SocketType::Https, + }, + SocketConfig { + address: "[::]:3126".parse().unwrap(), + socket_type: SocketType::Udp, + } + ]; + + let db_manager = PostgresConnectionManager::new( + "host=localhost user=postgres dbname=test".parse().unwrap(), + NoTls, + ); + let db_pool = r2d2::Pool::new(db_manager).unwrap(); + unsafe { + DB_POOL = Some(Arc::new(Mutex::new(BackendPool::Postgres(db_pool)))); + } + + let thread_pool = ThreadPool::new(256); + let thread_pool_mutex = Arc::new(Mutex::new(thread_pool)); + let mut threads = Vec::new(); - let pool = ThreadPool::new(256); - let pool_mutex = Arc::new(Mutex::new(pool)); + for socket_config in socket_configs { + let thread_pool_mutex = thread_pool_mutex.clone(); - let pool_mutex_ref = pool_mutex.clone(); - threads.push(thread::spawn(move || { - let mut tcp_socket = TcpListener::bind("[::]:8080").unwrap(); + threads.push(match socket_config.socket_type { + SocketType::Http => thread::spawn(move || { + let mut tcp_socket = TcpListener::bind(socket_config.address).unwrap(); - for stream in tcp_socket.incoming() { - pool_mutex_ref.lock().unwrap().execute(|| { - let stream = stream.unwrap(); - http::connection_handler(http::Stream::Tcp(stream)); - }); - } - })); + for stream in tcp_socket.incoming() { + thread_pool_mutex.lock().unwrap().execute(|| { + let stream = stream.unwrap(); + http::connection_handler(http::Stream::Tcp(stream)); + }); + } + }), + SocketType::Https => thread::spawn(move || { + let mut ssl_socket = TcpListener::bind(socket_config.address).unwrap(); - let pool_mutex_ref = pool_mutex.clone(); - threads.push(thread::spawn(move || { - let mut ssl_socket = TcpListener::bind("[::]:8443").unwrap(); + let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + acceptor + .set_certificate_chain_file("/home/lorenz/Certificates/chakotay.pem") + .unwrap(); + acceptor + .set_private_key_file( + "/home/lorenz/Certificates/priv/chakotay.key", + SslFiletype::PEM, + ) + .unwrap(); + acceptor.check_private_key().unwrap(); + let acceptor = Arc::new(acceptor.build()); - let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - acceptor - .set_certificate_chain_file("/home/lorenz/Certificates/chakotay.pem") - .unwrap(); - acceptor - .set_private_key_file( - "/home/lorenz/Certificates/priv/chakotay.key", - SslFiletype::PEM, - ) - .unwrap(); - acceptor.check_private_key().unwrap(); - let acceptor = Arc::new(acceptor.build()); + for stream in ssl_socket.incoming() { + let acceptor = acceptor.clone(); + thread_pool_mutex.lock().unwrap().execute(move || { + let stream = stream.unwrap(); + let stream = acceptor.accept(stream).unwrap(); + http::connection_handler(http::Stream::Ssl(stream)); + }); + } + }), + SocketType::Udp => thread::spawn(move || { + let mut udp_socket = UdpSocket::bind(socket_config.address).unwrap(); + let mut buf = [0; 65_536]; - for stream in ssl_socket.incoming() { - let acceptor = acceptor.clone(); - pool_mutex_ref.lock().unwrap().execute(move || { - let stream = stream.unwrap(); - let stream = acceptor.accept(stream).unwrap(); - http::connection_handler(http::Stream::Ssl(stream)); - }); - } - })); - - let pool_mutex_ref = pool_mutex.clone(); - threads.push(thread::spawn(move || { - let mut udp_socket = UdpSocket::bind("[::]:12345").unwrap(); - let mut buf = [0; 65_536]; - - loop { - let (size, addr) = udp_socket.recv_from(&mut buf).unwrap(); - let req = udp::Request::new(&udp_socket, addr, size, &buf); - pool_mutex_ref.lock().unwrap().execute(|| udp::handler(req)); - } - })); + loop { + let (size, addr) = udp_socket.recv_from(&mut buf).unwrap(); + let req = udp::Request::new(&udp_socket, addr, size, &buf); + thread_pool_mutex.lock().unwrap().execute(|| udp::handler(req)); + } + }), + }); + } for thread in threads { thread.join().unwrap(); diff --git a/src/usimp/mod.rs b/src/usimp/mod.rs index 7cbc2f0..bccc530 100644 --- a/src/usimp/mod.rs +++ b/src/usimp/mod.rs @@ -1,6 +1,11 @@ use json; +use r2d2::{ManageConnection, Pool, PooledConnection}; +use std::sync::{Arc, Mutex}; +use crate::{get_backend, BackendClient}; -static ENDPOINTS: [(&str, fn(json::JsonValue) -> json::JsonValue); 1] = [("echo", echo)]; +static ENDPOINTS: [(&str, fn(json::JsonValue) -> json::JsonValue); 1] = [ + ("echo", echo) +]; pub fn is_valid_endpoint(endpoint: &str) -> bool { for (name, _func) in &ENDPOINTS { @@ -21,5 +26,16 @@ pub fn endpoint(endpoint: &str, input: json::JsonValue) -> json::JsonValue { } pub fn echo(input: json::JsonValue) -> json::JsonValue { - input + let backend = get_backend(); + let mut output = input.clone(); + match backend { + BackendClient::Postgres(mut client) => { + let res = client.query("SELECT * FROM test", &[]).unwrap(); + for row in res { + let val: i32 = row.get(0); + output["database"] = val.into(); + } + } + } + output }