From c92742275a0c86922f8b6f385095330ceb370f29 Mon Sep 17 00:00:00 2001 From: Lorenz Stechauner Date: Thu, 18 Aug 2022 03:07:54 +0200 Subject: [PATCH] Implement WebSocket reverse proxy --- README.md | 2 +- src/client.c | 73 +++++++++++----- src/lib/http.h | 1 + src/lib/rev_proxy.c | 26 ++++-- src/lib/sock.c | 7 +- src/lib/sock.h | 1 + src/lib/utils.c | 1 + src/lib/websocket.c | 202 ++++++++++++++++++++++++++++++++++++++++++++ src/lib/websocket.h | 36 ++++++++ src/necronda.h | 2 + src/server.h | 1 - 11 files changed, 319 insertions(+), 33 deletions(-) create mode 100644 src/lib/websocket.c create mode 100644 src/lib/websocket.h diff --git a/README.md b/README.md index 1f1844d..a303b35 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Necronda web server * File compression ([gzip](https://www.gzip.org/), [Brotli](https://www.brotli.org/)) * Disk cache for compressed files * Reverse proxy for other HTTP and HTTPS servers - * Transparent WebSocket reverse proxy **[WIP]** + * Transparent WebSocket reverse proxy * FastCGI support (e.g. [PHP-FPM](https://php-fpm.org/)) * Automatic path info detection (e.g. `/my/file/extra/path` -> script: `/my/file.php`, path info: `extra/path`) * Support for [MaxMind's GeoIP Database](https://www.maxmind.com/en/geoip2-services-and-databases) diff --git a/src/client.c b/src/client.c index b4ffa36..e7de31b 100644 --- a/src/client.c +++ b/src/client.c @@ -5,8 +5,8 @@ * Lorenz Stechauner, 2020-12-03 */ -#include "client.h" #include "necronda.h" +#include "client.h" #include "server.h" #include "lib/utils.h" @@ -18,6 +18,7 @@ #include "lib/cache.h" #include "lib/geoip.h" #include "lib/compress.h" +#include "lib/websocket.h" #include #include @@ -31,7 +32,6 @@ int server_keep_alive = 1; struct timeval client_timeout = {.tv_sec = CLIENT_TIMEOUT, .tv_usec = 0}; -int server_keep_alive; char *log_client_prefix, *log_conn_prefix, *log_req_prefix, *client_geoip; char *client_addr_str, *client_addr_str_ptr, *server_addr_str, *server_addr_str_ptr, *client_host_str; @@ -52,11 +52,6 @@ void client_terminate() { server_keep_alive = 0; } -int client_websocket_handler() { - // TODO implement client_websocket_handler - return 0; -} - int client_request_handler(sock *client, unsigned long client_num, unsigned int req_num) { struct timespec begin, end; long ret; @@ -86,7 +81,7 @@ int client_request_handler(sock *client, unsigned long client_num, unsigned int http_status custom_status; http_res res = {.version = "1.1", .status = http_get_status(501), .hdr.field_num = 0, .hdr.last_field_num = -1}; - http_status_ctx ctx = {.status = 0, .origin = NONE}; + http_status_ctx ctx = {.status = 0, .origin = NONE, .ws_key = NULL}; clock_gettime(CLOCK_MONOTONIC, &begin); @@ -125,7 +120,7 @@ int client_request_handler(sock *client, unsigned long client_num, unsigned int } hdr_connection = http_get_header_field(&req.hdr, "Connection"); - client_keep_alive = (hdr_connection != NULL && (strcmp(hdr_connection, "keep-alive") == 0 || strcmp(hdr_connection, "Keep-Alive") == 0)); + client_keep_alive = (hdr_connection != NULL && (strstr(hdr_connection, "keep-alive") != NULL || strstr(hdr_connection, "Keep-Alive") != NULL)); host_ptr = http_get_header_field(&req.hdr, "Host"); if (host_ptr != NULL && strlen(host_ptr) > 255) { host[0] = 0; @@ -488,6 +483,25 @@ int client_request_handler(sock *client, unsigned long client_num, unsigned int ret = rev_proxy_init(&req, &res, &ctx, conf, client, &custom_status, err_msg); use_rev_proxy = (ret == 0); + if (res.status->code == 101) { + const char *connection = http_get_header_field(&res.hdr, "Connection"); + const char *upgrade = http_get_header_field(&res.hdr, "Upgrade"); + if (connection != NULL && upgrade != NULL && + (strstr(connection, "upgrade") != NULL || strstr(connection, "Upgrade") != NULL) && + strcmp(upgrade, "websocket") == 0) + { + const char *ws_accept = http_get_header_field(&res.hdr, "Sec-WebSocket-Accept"); + if (ws_calc_accept_key(ctx.ws_key, buf0) == 0) { + use_rev_proxy = (strcmp(buf0, ws_accept) == 0) ? 2 : 1; + } + } else { + print("Fail Test1"); + ctx.status = 101; + ctx.origin = INTERNAL; + res.status = http_get_status(501); + } + } + // Let 300 be formatted by origin server if (use_rev_proxy && res.status->code >= 301 && res.status->code < 600) { const char *content_type = http_get_header_field(&res.hdr, "Content-Type"); @@ -496,8 +510,10 @@ int client_request_handler(sock *client, unsigned long client_num, unsigned int if (content_encoding == NULL && content_type != NULL && content_length_f != NULL && strncmp(content_type, "text/html", 9) == 0) { long content_len = strtol(content_length_f, NULL, 10); if (content_len <= sizeof(msg_content) - 1) { - ctx.status = res.status->code; - ctx.origin = res.status->code >= 400 ? SERVER : NONE; + if (ctx.status != 101) { + ctx.status = res.status->code; + ctx.origin = res.status->code >= 400 ? SERVER : NONE; + } use_rev_proxy = 0; rev_proxy_dump(msg_content, content_len); } @@ -604,16 +620,19 @@ int client_request_handler(sock *client, unsigned long client_num, unsigned int } } - const char *conn = http_get_header_field(&res.hdr, "Connection"); - int close_proxy = (conn == NULL || (strcmp(conn, "keep-alive") != 0 && strcmp(conn, "Keep-Alive") != 0)); - http_remove_header_field(&res.hdr, "Connection", HTTP_REMOVE_ALL); - http_remove_header_field(&res.hdr, "Keep-Alive", HTTP_REMOVE_ALL); - if (server_keep_alive && client_keep_alive) { - http_add_header_field(&res.hdr, "Connection", "keep-alive"); - sprintf(buf0, "timeout=%i, max=%i", CLIENT_TIMEOUT, REQ_PER_CONNECTION); - http_add_header_field(&res.hdr, "Keep-Alive", buf0); - } else { - http_add_header_field(&res.hdr, "Connection", "close"); + int close_proxy = 0; + if (use_rev_proxy != 2) { + const char *conn = http_get_header_field(&res.hdr, "Connection"); + close_proxy = (conn == NULL || (strstr(conn, "keep-alive") == NULL && strstr(conn, "Keep-Alive") == NULL)); + http_remove_header_field(&res.hdr, "Connection", HTTP_REMOVE_ALL); + http_remove_header_field(&res.hdr, "Keep-Alive", HTTP_REMOVE_ALL); + if (server_keep_alive && client_keep_alive) { + http_add_header_field(&res.hdr, "Connection", "keep-alive"); + sprintf(buf0, "timeout=%i, max=%i", CLIENT_TIMEOUT, REQ_PER_CONNECTION); + http_add_header_field(&res.hdr, "Keep-Alive", buf0); + } else { + http_add_header_field(&res.hdr, "Connection", "close"); + } } http_send_response(client, &res); @@ -626,7 +645,17 @@ int client_request_handler(sock *client, unsigned long client_num, unsigned int // TODO access/error log file - if (strcmp(req.method, "HEAD") != 0) { + if (use_rev_proxy == 2) { + // WebSocket + print("Upgrading connection to WebSocket connection"); + ret = ws_handle_connection(client, &rev_proxy); + if (ret != 0) { + client_keep_alive = 0; + close_proxy = 1; + } + print("WebSocket connection closed"); + } else if (strcmp(req.method, "HEAD") != 0) { + // default response unsigned long snd_len = 0; unsigned long len; if (msg_buf[0] != 0) { diff --git a/src/lib/http.h b/src/lib/http.h index 6d76246..d0e5589 100644 --- a/src/lib/http.h +++ b/src/lib/http.h @@ -108,6 +108,7 @@ typedef enum { typedef struct { unsigned short status; http_error_origin origin; + const char* ws_key; } http_status_ctx; extern const http_status http_statuses[]; diff --git a/src/lib/rev_proxy.c b/src/lib/rev_proxy.c index dd09b5a..8b64601 100644 --- a/src/lib/rev_proxy.c +++ b/src/lib/rev_proxy.c @@ -5,10 +5,11 @@ * Lorenz Stechauner, 2021-01-07 */ +#include "../necronda.h" +#include "../server.h" #include "rev_proxy.h" #include "utils.h" #include "compress.h" -#include "../server.h" #include #include @@ -33,8 +34,6 @@ int rev_proxy_preload() { int rev_proxy_request_header(http_req *req, int enc) { char buf1[256], buf2[256]; int p_len; - http_remove_header_field(&req->hdr, "Connection", HTTP_REMOVE_ALL); - http_add_header_field(&req->hdr, "Connection", "keep-alive"); const char *via = http_get_header_field(&req->hdr, "Via"); sprintf(buf1, "HTTP/%s %s", req->version, SERVER_NAME); @@ -184,12 +183,12 @@ int rev_proxy_response_header(http_req *req, http_res *res, host_config *conf) { int rev_proxy_init(http_req *req, http_res *res, http_status_ctx *ctx, host_config *conf, sock *client, http_status *custom_status, char *err_msg) { char buffer[CHUNK_SIZE]; + const char *connection, *upgrade, *ws_version; long ret; int tries = 0, retry = 0; - if (rev_proxy.socket != 0 && strcmp(rev_proxy_host, conf->name) == 0 && sock_check(&rev_proxy) == 0) { + if (rev_proxy.socket != 0 && strcmp(rev_proxy_host, conf->name) == 0 && sock_check(&rev_proxy) == 0) goto rev_proxy; - } retry: if (rev_proxy.socket != 0) { @@ -290,6 +289,22 @@ int rev_proxy_init(http_req *req, http_res *res, http_status_ctx *ctx, host_conf print(BLUE_STR "Established new connection with " BLD_STR "[%s]:%i" CLR_STR, buffer, conf->rev_proxy.port); rev_proxy: + connection = http_get_header_field(&req->hdr, "Connection"); + if (connection != NULL && (strstr(connection, "upgrade") != NULL || strstr(connection, "Upgrade") != NULL)) { + upgrade = http_get_header_field(&req->hdr, "Upgrade"); + ws_version = http_get_header_field(&req->hdr, "Sec-WebSocket-Version"); + if (upgrade != NULL && ws_version != NULL && strcmp(upgrade, "websocket") == 0 && strcmp(ws_version, "13") == 0) { + ctx->ws_key = http_get_header_field(&req->hdr, "Sec-WebSocket-Key"); + } else { + res->status = http_get_status(501); + ctx->origin = INTERNAL; + return -1; + } + } else { + http_remove_header_field(&req->hdr, "Connection", HTTP_REMOVE_ALL); + http_add_header_field(&req->hdr, "Connection", "keep-alive"); + } + ret = rev_proxy_request_header(req, (int) client->enc); if (ret != 0) { res->status = http_get_status(500); @@ -454,7 +469,6 @@ int rev_proxy_init(http_req *req, http_res *res, http_status_ctx *ctx, host_conf } int rev_proxy_send(sock *client, unsigned long len_to_send, int flags) { - // TODO handle websockets char buffer[CHUNK_SIZE], comp_out[CHUNK_SIZE], buf[256], *ptr; long ret = 0, len, snd_len; int finish_comp = 0; diff --git a/src/lib/sock.c b/src/lib/sock.c index 5c0cfd5..ba79cd5 100644 --- a/src/lib/sock.c +++ b/src/lib/sock.c @@ -130,11 +130,12 @@ int sock_poll(sock *sockets[], sock *ready[], short events, int n_sock, int time int ret = poll(fds, n_sock, timeout_ms); if (ret < 0 || ready == NULL) return ret; - for (int i = 0, j = 0; i < ret; j++) { + int j = 0; + for (int i = 0; i < n_sock; i++) { if (fds[i].revents & events) - ready[i++] = sockets[j]; + ready[j++] = sockets[i]; } - return ret; + return j; } int sock_poll_read(sock *sockets[], sock *readable[], int n_sock, int timeout_ms) { diff --git a/src/lib/sock.h b/src/lib/sock.h index 5487f46..ba89d6d 100644 --- a/src/lib/sock.h +++ b/src/lib/sock.h @@ -9,6 +9,7 @@ #define NECRONDA_SERVER_SOCK_H #include +#include typedef struct { unsigned int enc:1; diff --git a/src/lib/utils.c b/src/lib/utils.c index 2661f0d..dc29307 100644 --- a/src/lib/utils.c +++ b/src/lib/utils.c @@ -192,6 +192,7 @@ int base64_encode(void *data, unsigned long data_len, char *output, unsigned lon for (int i = 0; i < base64_mod_table[data_len % 3]; i++) output[out_len - 1 - i] = '='; + output[out_len] = 0; return 0; } diff --git a/src/lib/websocket.c b/src/lib/websocket.c new file mode 100644 index 0000000..0070485 --- /dev/null +++ b/src/lib/websocket.c @@ -0,0 +1,202 @@ +/** + * Necronda Web Server + * WebSocket reverse proxy + * src/lib/websocket.c + * Lorenz Stechauner, 2022-08-16 + */ + +#include "../necronda.h" +#include "websocket.h" +#include "utils.h" + +#include +#include +#include +#include + + +int terminate = 0; + +void ws_terminate() { + terminate = 1; +} + +int ws_calc_accept_key(const char *key, char *accept_key) { + if (key == NULL || accept_key == NULL) + return -1; + + char input[256] = ""; + unsigned char output[SHA_DIGEST_LENGTH]; + strcat(input, key); + strcat(input, ws_key_uuid); + + if (SHA1((unsigned char *) input, strlen(input), output) == NULL) { + return -2; + } + + base64_encode(output, sizeof(output), accept_key, NULL); + + return 0; +} + +int ws_recv_frame_header(sock *s, ws_frame *frame) { + unsigned char buf[12]; + + long ret = sock_recv(s, buf, 2, 0); + if (ret < 0) { + print(ERR_STR "Unable to receive from socket: %s" CLR_STR, strerror(errno)); + return -1; + } else if (ret != 2) { + print(ERR_STR "Unable to receive 2 bytes from socket" CLR_STR); + return -2; + } + + unsigned short bits = (buf[0] << 8) | buf[1]; + frame->f_fin = (bits >> 15) & 1; + frame->f_rsv1 = (bits >> 14) & 1; + frame->f_rsv2 = (bits >> 13) & 1; + frame->f_rsv3 = (bits >> 12) & 1; + frame->opcode = (bits >> 8) & 0xF; + frame->f_mask = (bits >> 7) & 1; + unsigned short len = (bits & 0x7F); + + int remaining = frame->f_mask ? 4 : 0; + if (len == 126) { + remaining += 2; + } else if (len == 127) { + remaining += 8; + } + + ret = sock_recv(s, buf, remaining, 0); + if (ret < 0) { + print(ERR_STR "Unable to receive from socket: %s" CLR_STR, strerror(errno)); + return -1; + } else if (ret != remaining) { + print(ERR_STR "Unable to receive correct number of bytes from socket" CLR_STR); + return -2; + } + + if (len == 126) { + frame->len = (((unsigned long) buf[0]) << 8) | ((unsigned long) buf[1]); + } else if (len == 127) { + frame->len = + (((unsigned long) buf[0]) << 56) | + (((unsigned long) buf[1]) << 48) | + (((unsigned long) buf[2]) << 40) | + (((unsigned long) buf[3]) << 32) | + (((unsigned long) buf[4]) << 24) | + (((unsigned long) buf[5]) << 16) | + (((unsigned long) buf[6]) << 8) | + (((unsigned long) buf[7]) << 0); + } else { + frame->len = len; + } + + if (frame->f_mask) memcpy(frame->masking_key, buf + (remaining - 4), 4); + + return 0; +} + +int ws_send_frame_header(sock *s, ws_frame *frame) { + unsigned char buf[14], *ptr = buf; + + unsigned short len; + if (frame->len > 0x7FFF) { + len = 127; + } else if (frame->len > 125) { + len = 126; + } else { + len = frame->len; + } + + unsigned short bits = + (frame->f_fin << 15) | + (frame->f_rsv1 << 14) | + (frame->f_rsv2 << 13) | + (frame->f_rsv3 << 12) | + (frame->opcode << 8) | + (frame->f_mask << 7) | + len; + + ptr++[0] = bits >> 8; + ptr++[0] = bits & 0xFF; + + if (len >= 126) { + for (int i = (len == 126 ? 2 : 8) - 1; i >= 0; i--) + ptr++[0] = (unsigned char) ((frame->len >> (i * 8)) & 0xFF); + } + + if (frame->f_mask) { + memcpy(ptr, frame->masking_key, 4); + ptr += 4; + } + + long ret = sock_send(s, buf, ptr - buf, frame->len != 0 ? MSG_MORE : 0); + if (ret < 0) { + print(ERR_STR "Unable to send to socket: %s" CLR_STR, strerror(errno)); + return -1; + } else if (ret != ptr - buf) { + print(ERR_STR "Unable to send to socket" CLR_STR); + return -2; + } + + return 0; +} + +int ws_handle_connection(sock *s1, sock *s2) { + sock *poll_socks[2] = {s1, s2}; + sock *readable[2]; + int n_sock = 2; + ws_frame frame; + char buf[CHUNK_SIZE]; + int poll, closes = 0; + long ret; + + signal(SIGINT, ws_terminate); + signal(SIGTERM, ws_terminate); + + while (!terminate && closes != 3) { + poll = sock_poll_read(poll_socks, readable, n_sock, WS_TIMEOUT * 1000); + if (terminate) { + break; + } else if (poll < 0) { + print(ERR_STR "Unable to poll sockets: %s" CLR_STR, strerror(errno)); + return -1; + } else if (poll == 0) { + print(ERR_STR "Connection timed out" CLR_STR); + return -2; + } + + for (int i = 0; i < poll; i++) { + sock *s = readable[i]; + sock *o = (s == s1) ? s2 : s1; + if (ws_recv_frame_header(s, &frame) != 0) return -3; + + if (frame.opcode == 0x8) { + n_sock--; + if (s == s1) { + poll_socks[0] = s2; + closes |= 1; + } else { + closes |= 2; + } + } + + if (ws_send_frame_header(o, &frame) != 0) return -3; + + if (frame.len > 0) { + ret = sock_splice(o, s, buf, sizeof(buf), frame.len); + if (ret < 0) { + print(ERR_STR "Unable to forward data in WebSocket: %s" CLR_STR, strerror(errno)); + return -3; + } else if (ret != frame.len) { + print(ERR_STR "Unable to forward correct number of bytes in WebSocket" CLR_STR); + return -3; + } + } + } + } + + return 0; +} + diff --git a/src/lib/websocket.h b/src/lib/websocket.h new file mode 100644 index 0000000..250efce --- /dev/null +++ b/src/lib/websocket.h @@ -0,0 +1,36 @@ +/** + * Necronda Web Server + * WebSocket reverse proxy (header file) + * src/lib/websocket.h + * Lorenz Stechauner, 2022-08-16 + */ + +#ifndef NECRONDA_SERVER_WEBSOCKET_H +#define NECRONDA_SERVER_WEBSOCKET_H + +#include "sock.h" + +#define WS_TIMEOUT 3600 + +const char *ws_key_uuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +typedef struct { + unsigned char f_fin:1; + unsigned char f_rsv1:1; + unsigned char f_rsv2:1; + unsigned char f_rsv3:1; + unsigned char opcode:4; + unsigned char f_mask:1; + unsigned long len; + char masking_key[4]; +} ws_frame; + +int ws_calc_accept_key(const char *key, char *accept_key); + +int ws_recv_frame_header(sock *s, ws_frame *frame); + +int ws_send_frame_header(sock *s, ws_frame *frame); + +int ws_handle_connection(sock *s1, sock *s2); + +#endif // NECRONDA_SERVER_WEBSOCKET_H diff --git a/src/necronda.h b/src/necronda.h index 28b621c..6ca8a7a 100644 --- a/src/necronda.h +++ b/src/necronda.h @@ -12,6 +12,8 @@ #define SERVER_STR "Necronda/" NECRONDA_VERSION #define SERVER_STR_HTML "Necronda web server " NECRONDA_VERSION +#define CHUNK_SIZE 8192 + #ifndef DEFAULT_HOST # define DEFAULT_HOST "www.necronda.net" #endif diff --git a/src/server.h b/src/server.h index 5ba457a..157ca03 100644 --- a/src/server.h +++ b/src/server.h @@ -20,7 +20,6 @@ #define SERVER_TIMEOUT_INIT 4 #define SERVER_TIMEOUT 3600 -#define CHUNK_SIZE 8192 extern int sockets[NUM_SOCKETS]; extern pid_t children[MAX_CHILDREN];