From bb895c5bca93a8443610d3768234bc6c3d9cf5e1 Mon Sep 17 00:00:00 2001 From: Lorenz Stechauner Date: Sat, 27 Aug 2022 13:14:25 +0200 Subject: [PATCH] Fix WebSocket connection close handling --- src/client.c | 2 +- src/lib/sock.c | 21 ++++++++++++--------- src/lib/sock.h | 6 +++--- src/lib/websocket.c | 21 ++++++++++++--------- 4 files changed, 28 insertions(+), 22 deletions(-) diff --git a/src/client.c b/src/client.c index 49bd59f..02e3349 100644 --- a/src/client.c +++ b/src/client.c @@ -85,7 +85,7 @@ int client_request_handler(sock *client, unsigned long client_num, unsigned int clock_gettime(CLOCK_MONOTONIC, &begin); - ret = sock_poll_read(&client, NULL, 1, CLIENT_TIMEOUT * 1000); + ret = sock_poll_read(&client, NULL, NULL, 1, NULL, NULL, CLIENT_TIMEOUT * 1000); http_add_header_field(&res.hdr, "Date", http_get_date(buf0, sizeof(buf0))); http_add_header_field(&res.hdr, "Server", SERVER_STR); diff --git a/src/lib/sock.c b/src/lib/sock.c index 850b0c1..bba367e 100644 --- a/src/lib/sock.c +++ b/src/lib/sock.c @@ -120,7 +120,7 @@ int sock_check(sock *s) { return recv(s->socket, &buf, 1, MSG_PEEK | MSG_DONTWAIT) == 1; } -int sock_poll(sock *sockets[], sock *ready[], short events, int n_sock, int timeout_ms) { +int sock_poll(sock *sockets[], sock *ready[], sock *error[], int n_sock, int *n_ready, int *n_error, short events, int timeout_ms) { struct pollfd fds[n_sock]; for (int i = 0; i < n_sock; i++) { fds[i].fd = sockets[i]->socket; @@ -128,20 +128,23 @@ int sock_poll(sock *sockets[], sock *ready[], short events, int n_sock, int time } int ret = poll(fds, n_sock, timeout_ms); - if (ret < 0 || ready == NULL) return ret; + if (ret < 0 || ready == NULL || error == NULL) return ret; - int j = 0; + *n_ready = 0, *n_error = 0; for (int i = 0; i < n_sock; i++) { if (fds[i].revents & events) - ready[j++] = sockets[i]; + ready[(*n_ready)++] = sockets[i]; + if (fds[i].revents & (POLLERR | POLLHUP | POLLNVAL)) + error[(*n_error)++] = sockets[i]; } - return j; + + return ret; } -int sock_poll_read(sock *sockets[], sock *readable[], int n_sock, int timeout_ms) { - return sock_poll(sockets, readable, POLLIN, n_sock, timeout_ms); +int sock_poll_read(sock *sockets[], sock *readable[], sock *error[], int n_sock, int *n_readable, int *n_error, int timeout_ms) { + return sock_poll(sockets, readable, error, n_sock, n_readable, n_error, POLLIN, timeout_ms); } -int sock_poll_write(sock *sockets[], sock *writable[], int n_sock, int timeout_ms) { - return sock_poll(sockets, writable, POLLOUT, n_sock, timeout_ms); +int sock_poll_write(sock *sockets[], sock *writable[], sock *error[], int n_sock, int *n_writable, int *n_error, int timeout_ms) { + return sock_poll(sockets, writable, error, n_sock, n_writable, n_error, POLLOUT, timeout_ms); } diff --git a/src/lib/sock.h b/src/lib/sock.h index 0fb4b5a..9213afd 100644 --- a/src/lib/sock.h +++ b/src/lib/sock.h @@ -38,10 +38,10 @@ int sock_close(sock *s); int sock_check(sock *s); -int sock_poll(sock *sockets[], sock *readable[], short events, int n_sock, int timeout_ms); +int sock_poll(sock *sockets[], sock *ready[], sock *error[], int n_sock, int *n_ready, int *n_error, short events, int timeout_ms); -int sock_poll_read(sock *sockets[], sock *readable[], int n_sock, int timeout_ms); +int sock_poll_read(sock *sockets[], sock *readable[], sock *error[], int n_sock, int *n_readable, int *n_error, int timeout_ms); -int sock_poll_write(sock *sockets[], sock *writable[], int n_sock, int timeout_ms); +int sock_poll_write(sock *sockets[], sock *writable[], sock *error[], int n_sock, int *n_writable, int *n_error, int timeout_ms); #endif //SESIMOS_SOCK_H diff --git a/src/lib/websocket.c b/src/lib/websocket.c index 822f126..23538c2 100644 --- a/src/lib/websocket.c +++ b/src/lib/websocket.c @@ -145,29 +145,32 @@ int ws_send_frame_header(sock *s, ws_frame *frame) { int ws_handle_connection(sock *s1, sock *s2) { sock *poll_socks[2] = {s1, s2}; - sock *readable[2]; - int n_sock = 2; + sock *readable[2], *error[2]; + int n_sock = 2, n_readable = 0, n_error = 0; ws_frame frame; char buf[CHUNK_SIZE]; - int poll, closes = 0; + int closes = 0; long ret; signal(SIGINT, ws_terminate); signal(SIGTERM, ws_terminate); while (!terminate && closes != 3) { - poll = sock_poll_read(poll_socks, readable, n_sock, WS_TIMEOUT * 1000); + ret = sock_poll_read(poll_socks, readable, error, n_sock, &n_readable, &n_error, WS_TIMEOUT * 1000); if (terminate) { break; - } else if (poll < 0) { + } else if (ret < 0) { print(ERR_STR "Unable to poll sockets: %s" CLR_STR, strerror(errno)); return -1; - } else if (poll == 0) { + } else if (n_readable == 0) { print(ERR_STR "Connection timed out" CLR_STR); return -2; + } else if (n_error > 0) { + print(ERR_STR "Peer closed connection" CLR_STR); + return -3; } - for (int i = 0; i < poll; i++) { + for (int i = 0; i < n_readable; i++) { sock *s = readable[i]; sock *o = (s == s1) ? s2 : s1; if (ws_recv_frame_header(s, &frame) != 0) return -3; @@ -188,10 +191,10 @@ int ws_handle_connection(sock *s1, sock *s2) { ret = sock_splice(o, s, buf, sizeof(buf), frame.len); if (ret < 0) { print(ERR_STR "Unable to forward data in WebSocket: %s" CLR_STR, strerror(errno)); - return -3; + return -4; } else if (ret != frame.len) { print(ERR_STR "Unable to forward correct number of bytes in WebSocket" CLR_STR); - return -3; + return -4; } } }