diff --git a/src/client.c b/src/client.c index 9e666e8..2cdd41d 100644 --- a/src/client.c +++ b/src/client.c @@ -20,7 +20,7 @@ #include "lib/compress.h" #include -#include +#include #include #include #include @@ -28,13 +28,13 @@ #include #include + int server_keep_alive = 1; struct timeval client_timeout = {.tv_sec = CLIENT_TIMEOUT, .tv_usec = 0}; int server_keep_alive; char *log_client_prefix, *log_conn_prefix, *log_req_prefix, *client_geoip; char *client_addr_str, *client_addr_str_ptr, *server_addr_str, *server_addr_str_ptr, *client_host_str; -struct timeval client_timeout; host_config *get_host_config(const char *host) { for (int i = 0; i < CONFIG_MAX_HOST_CONFIG; i++) { @@ -91,12 +91,8 @@ int client_request_handler(sock *client, unsigned long client_num, unsigned int clock_gettime(CLOCK_MONOTONIC, &begin); - fd_set socket_fds; - FD_ZERO(&socket_fds); - FD_SET(client->socket, &socket_fds); - client_timeout.tv_sec = CLIENT_TIMEOUT; - client_timeout.tv_usec = 0; - ret = select(client->socket + 1, &socket_fds, NULL, NULL, &client_timeout); + ret = sock_poll_read(&client, NULL, 1, CLIENT_TIMEOUT * 1000); + http_add_header_field(&res.hdr, "Date", http_get_date(buf0, sizeof(buf0))); http_add_header_field(&res.hdr, "Server", SERVER_STR); if (ret <= 0) { diff --git a/src/lib/sock.c b/src/lib/sock.c index 33746f8..46a5703 100644 --- a/src/lib/sock.c +++ b/src/lib/sock.c @@ -12,6 +12,7 @@ #include #include #include +#include int sock_enc_error(sock *s) { return (int) s->enc ? SSL_get_error(s->ssl, (int) s->_last_ret) : 0; @@ -117,3 +118,28 @@ int sock_check(sock *s) { char buf; return recv(s->socket, &buf, 1, MSG_PEEK | MSG_DONTWAIT) == 1; } + +int sock_poll(sock *sockets[], sock *ready[], short events, int n_sock, int timeout_ms) { + struct pollfd fds[n_sock]; + for (int i = 0; i < n_sock; i++) { + fds[i].fd = sockets[i]->socket; + fds[i].events = events; + } + + int ret = poll(fds, n_sock, timeout_ms); + if (ret < 0 || ready == NULL) return ret; + + for (int i = 0, j = 0; i < ret; j++) { + if (fds[i].revents & events) + ready[i++] = sockets[j]; + } + return ret; +} + +int sock_poll_read(sock *sockets[], sock *readable[], int n_sock, int timeout_ms) { + return sock_poll(sockets, readable, POLLIN, n_sock, timeout_ms); +} + +int sock_poll_write(sock *sockets[], sock *writable[], int n_sock, int timeout_ms) { + return sock_poll(sockets, writable, POLLOUT, n_sock, timeout_ms); +} diff --git a/src/lib/sock.h b/src/lib/sock.h index d427b47..5487f46 100644 --- a/src/lib/sock.h +++ b/src/lib/sock.h @@ -37,4 +37,10 @@ int sock_close(sock *s); int sock_check(sock *s); +int sock_poll(sock *sockets[], sock *readable[], short events, int n_sock, int timeout_ms); + +int sock_poll_read(sock *sockets[], sock *readable[], int n_sock, int timeout_ms); + +int sock_poll_write(sock *sockets[], sock *writable[], int n_sock, int timeout_ms); + #endif //NECRONDA_SERVER_SOCK_H diff --git a/src/server.c b/src/server.c index eb37419..1f11dfa 100644 --- a/src/server.c +++ b/src/server.c @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include #include #include @@ -34,6 +34,7 @@ #include #include + int active = 1; const char *config_file; int sockets[NUM_SOCKETS]; @@ -153,8 +154,7 @@ void terminate() { int main(int argc, const char *argv[]) { const int YES = 1; - fd_set socket_fds, read_socket_fds; - int max_socket_fd = 0; + struct pollfd poll_fds[NUM_SOCKETS]; int ready_sockets_num; long client_num = 0; char buf[1024]; @@ -169,8 +169,6 @@ int main(int argc, const char *argv[]) { memset(children, 0, sizeof(children)); memset(mmdbs, 0, sizeof(mmdbs)); - struct timeval timeout; - const struct sockaddr_in6 addresses[2] = { {.sin6_family = AF_INET6, .sin6_addr = IN6ADDR_ANY_INIT, .sin6_port = htons(80)}, {.sin6_family = AF_INET6, .sin6_addr = IN6ADDR_ANY_INIT, .sin6_port = htons(443)} @@ -338,29 +336,23 @@ int main(int argc, const char *argv[]) { } } - FD_ZERO(&socket_fds); for (int i = 0; i < NUM_SOCKETS; i++) { - FD_SET(sockets[i], &socket_fds); - if (sockets[i] > max_socket_fd) { - max_socket_fd = sockets[i]; - } + poll_fds[i].fd = sockets[i]; + poll_fds[i].events = POLLIN; } fprintf(stderr, "Ready to accept connections\n"); while (active) { - timeout.tv_sec = 1; - timeout.tv_usec = 0; - read_socket_fds = socket_fds; - ready_sockets_num = select(max_socket_fd + 1, &read_socket_fds, NULL, NULL, &timeout); + ready_sockets_num = poll(poll_fds, NUM_SOCKETS, 1000); if (ready_sockets_num < 0) { - fprintf(stderr, ERR_STR "Unable to select sockets: %s" CLR_STR "\n", strerror(errno)); + fprintf(stderr, ERR_STR "Unable to poll sockets: %s" CLR_STR "\n", strerror(errno)); terminate(); return 1; } for (int i = 0; i < NUM_SOCKETS; i++) { - if (FD_ISSET(sockets[i], &read_socket_fds)) { + if (poll_fds[i].revents & POLLIN) { client_fd = accept(sockets[i], (struct sockaddr *) &client_addr, &client_addr_len); if (client_fd < 0) { fprintf(stderr, ERR_STR "Unable to accept connection: %s" CLR_STR "\n", strerror(errno));