diff --git a/.gitignore b/.gitignore index 0a8cd91..5c2c599 100644 --- a/.gitignore +++ b/.gitignore @@ -58,4 +58,5 @@ ssl/* !.github/workflows/ src/bin docker-push.sh -.idea \ No newline at end of file +.idea +error.txt \ No newline at end of file diff --git a/src/http2.c b/src/http2.c index 8f1f70b..1fc2446 100644 --- a/src/http2.c +++ b/src/http2.c @@ -258,6 +258,8 @@ static int on_frame_recv_callback(nghttp2_session* session, { close(fd); free(mime_type); + stream_data->mime_type = NULL; // Prevent double-free in cleanup + stream_data->fd = -1; // Mark fd as already closed log_event("HTTP/2: Memory allocation failed"); break; } @@ -403,8 +405,12 @@ static int on_begin_headers_callback(nghttp2_session* session, stream_data->stream_id = frame->hd.stream_id; stream_data->fd = -1; - nghttp2_session_set_stream_user_data(session, frame->hd.stream_id, stream_data); - free(stream_data); + if (nghttp2_session_set_stream_user_data(session, frame->hd.stream_id, stream_data) != 0) + { + free(stream_data); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + return 0; } @@ -576,4 +582,4 @@ int http2_handle_connection(http2_session_t* h2_session) } return 1; -} \ No newline at end of file +} diff --git a/src/server.c b/src/server.c index c2da7c0..869e9ac 100644 --- a/src/server.c +++ b/src/server.c @@ -528,12 +528,11 @@ static int is_websocket_upgrade(const char* request) // Handle WebSocket connection static void* handle_websocket(void* arg) { - ws_connection_t* conn = (ws_connection_t*)arg; + ws_connection_t* conn = arg; - - // Remove socket timeout for WebSocket (connections should stay open) + // Remove socket timeout for WebSocket struct timeval ws_timeout; - ws_timeout.tv_sec = 0; // No timeout - wait indefinitely + ws_timeout.tv_sec = 0; ws_timeout.tv_usec = 0; setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &ws_timeout, sizeof(ws_timeout)); @@ -571,18 +570,19 @@ static void* handle_websocket(void* arg) if (parsed < 0) { log_event("Failed to parse WebSocket frame"); - free(payload); + free(payload); // Safe - only frees if allocated ws_close_connection(conn, 1002); free(conn); pthread_exit(NULL); } + bool should_exit = false; + switch (header.opcode) { case WS_OPCODE_TEXT: if (ws_is_valid_utf8(payload, header.payload_length)) { - // Echo back the text message ws_send_frame(conn, WS_OPCODE_TEXT, payload, header.payload_length); log_event("WebSocket text frame received and echoed"); } @@ -593,7 +593,6 @@ static void* handle_websocket(void* arg) break; case WS_OPCODE_BINARY: - // Echo back binary data ws_send_frame(conn, WS_OPCODE_BINARY, payload, header.payload_length); log_event("WebSocket binary frame received and echoed"); break; @@ -605,16 +604,22 @@ static void* handle_websocket(void* arg) case WS_OPCODE_CLOSE: log_event("WebSocket close frame received"); - free(payload); - ws_close_connection(conn, 1000); - free(conn); - pthread_exit(NULL); + should_exit = true; + break; default: break; } free(payload); + payload = NULL; + + if (should_exit) + { + ws_close_connection(conn, 1000); + free(conn); + pthread_exit(NULL); + } } ws_close_connection(conn, 1000); @@ -622,6 +627,7 @@ static void* handle_websocket(void* arg) pthread_exit(NULL); } + void* handle_http_client(void* arg) { int client_socket = *((int*)arg); @@ -1066,7 +1072,7 @@ void* handle_http_client(void* arg) done_serving: continue; } - else if (bytes_received < 0) + if (bytes_received < 0) { break; } @@ -1081,7 +1087,7 @@ cleanup: void* handle_https_client(void* arg) { - int client_socket = *((int*)arg); + int client_socket = *(int*)arg; free(arg); SSL* ssl = SSL_new(ssl_ctx); @@ -1176,17 +1182,14 @@ void* handle_https_client(void* arg) log_event("SSL_read failed"); goto cleanup; } - else if (bytes_received == 0) + if (bytes_received == 0) { log_event("Client closed connection"); goto cleanup; } - else - { - buffer[bytes_received] = '\0'; - log_event("Received HTTPS request:"); - log_event(buffer); - } + buffer[bytes_received] = '\0'; + log_event("Received HTTPS request:"); + log_event(buffer); // Check for WebSocket upgrade request on HTTPS if (config.enable_websocket && is_websocket_upgrade(buffer)) @@ -1220,19 +1223,13 @@ void* handle_https_client(void* arg) handle_websocket(ws_conn); pthread_exit(NULL); } - else - { - SSL_shutdown(ssl); - SSL_free(ssl); - close(client_socket); - pthread_exit(NULL); - } - } - else - { - log_event("Secure WebSocket handshake failed"); - goto cleanup; + SSL_shutdown(ssl); + SSL_free(ssl); + close(client_socket); + pthread_exit(NULL); } + log_event("Secure WebSocket handshake failed"); + goto cleanup; } char method[8], url[256], protocol[16]; @@ -1791,7 +1788,6 @@ void* worker_thread(void* arg) *socket_ptr = task->socket_fd; handle_https_client(socket_ptr); } - free(socket_ptr); } else { diff --git a/src/websocket.c b/src/websocket.c index 3c12b44..fc981e2 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -111,7 +111,7 @@ int ws_handle_handshake(int client_socket, const char* request, char* response, } // Create handshake response - int written = snprintf(response, response_size, + const int written = snprintf(response, response_size, "HTTP/1.1 101 Switching Protocols\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" @@ -125,34 +125,39 @@ int ws_handle_handshake(int client_socket, const char* request, char* response, { return -1; } - return 0; } // Handle WebSocket handshake for SSL connections -int ws_handle_handshake_ssl(SSL* ssl, const char* request, char* response, size_t response_size) +int ws_handle_handshake_ssl(const SSL* ssl, const char* request, char* response, const size_t response_size) { (void)ssl; // Use the same logic, just different transport return ws_handle_handshake(0, request, response, response_size); } +void ws_free_payload(uint8_t* payload) +{ + if (payload) + { + free(payload); + } +} + // Parse WebSocket frame -int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, uint8_t** payload) +int ws_parse_frame(const uint8_t* data, const size_t len, ws_frame_header_t* header, uint8_t** payload) { // Maximum allowed WebSocket payload size (10MB) #define MAX_WEBSOCKET_PAYLOAD (10 * 1024 * 1024) if (len < 2) - { return -1; - } header->fin = (data[0] & 0x80) >> 7; header->opcode = data[0] & 0x0F; header->mask = (data[1] & 0x80) >> 7; size_t offset = 2; - uint8_t payload_len = data[1] & 0x7F; + const uint8_t payload_len = data[1] & 0x7F; if (payload_len == 126) { @@ -177,10 +182,7 @@ int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, u header->payload_length = payload_len; } - if (header->payload_length > MAX_WEBSOCKET_PAYLOAD) - { - return -1; - } + if (header->payload_length > MAX_WEBSOCKET_PAYLOAD) return -1; if (header->mask) { @@ -195,30 +197,33 @@ int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, u return -1; } - // Unmask payload if masked - *payload = (uint8_t*)malloc(header->payload_length); - if (!*payload) + // Allocate payload buffer if there is payload data + if (header->payload_length > 0) { - return -1; - } - - if (header->mask) - { - for (uint64_t i = 0; i < header->payload_length; i++) + *payload = malloc(header->payload_length); + if (!*payload) { - (*payload)[i] = data[offset + i] ^ header->masking_key[i % 4]; + return -1; // Allocation failed + } + + if (header->mask) + { + for (uint64_t i = 0; i < header->payload_length; i++) + { + (*payload)[i] = data[offset + i] ^ header->masking_key[i % 4]; + } + } + else + { + memcpy(*payload, data + offset, header->payload_length); } - } - else - { - memcpy(*payload, data + offset, header->payload_length); } return offset + header->payload_length; } // Create WebSocket frame -int ws_create_frame(uint8_t* buffer, size_t buffer_size, uint8_t opcode, const uint8_t* payload, size_t payload_len) +int ws_create_frame(uint8_t* buffer,const size_t buffer_size,const uint8_t opcode, const uint8_t* payload, size_t payload_len) { size_t header_size; @@ -276,9 +281,69 @@ int ws_create_frame(uint8_t* buffer, size_t buffer_size, uint8_t opcode, const u return (int)offset; } +int ws_send_frame_chunk(const ws_connection_t* conn,const uint8_t opcode,const uint8_t fin, const uint8_t* payload,const size_t payload_len) { + uint8_t buffer[65536]; + size_t header_size; + size_t offset = 0; + if (payload_len < 126) { + header_size = 2; + } else if (payload_len < 65536) { + header_size = 4; + } else { + header_size = 10; + } + + if (header_size + payload_len >sizeof(buffer)) return -1; + + buffer[offset++] = (fin ? 0x80 : 0x00) | (opcode & 0x0F); + + if (payload_len < 126) { + buffer[offset++] = (uint8_t)payload_len; + } else if (payload_len < 65536) { + buffer[offset++] = 126; + buffer[offset++] = (payload_len >> 8) & 0xFF; + buffer[offset++] = payload_len & 0xFF; + } else { + buffer[offset++] = 127; + for (int i = 7; i >= 0; i--) { + buffer[offset++] = (payload_len >> (i * 8)) & 0xFF; + } + } + + if (payload && payload_len > 0) { + memcpy(buffer + offset, payload, payload_len); + offset += payload_len; + } + + ssize_t sent; + if (conn->is_ssl && conn->ssl) { + sent = SSL_write(conn->ssl, buffer, offset); + } else { + sent = write(conn->socket_fd, buffer, offset); + } + + return (sent > 0) ? sent : -1; +} + +int ws_send_frame_fragmented(const ws_connection_t* conn,const uint8_t opcode, const uint8_t* payload,const size_t payload_len) { + size_t offset = 0; + + while (offset < payload_len) { + const size_t CHUNK_SIZE = 65526; + const size_t chunk = (payload_len - offset > CHUNK_SIZE) ? CHUNK_SIZE : (payload_len - offset); + const uint8_t frame_opcode = (offset == 0) ? opcode : WS_OPCODE_CONTINUATION; + const uint8_t fin = (offset + chunk >= payload_len) ? 0x80 : 0x00; + + const int result = ws_send_frame_chunk(conn, frame_opcode, fin, payload + offset, chunk); + + if (result < 0) return result; + offset += chunk; + } + return offset; +} // Send WebSocket frame -int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload, size_t payload_len) +int ws_send_frame(const ws_connection_t* conn, const uint8_t opcode, const uint8_t* payload, const size_t payload_len) { // Allocate buffer with enough space for header (max 10 bytes) + payload // Check for integer overflow @@ -290,13 +355,15 @@ int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload, uint8_t buffer[65536]; // Limit payload to avoid overflow (65536 - 10 bytes for max header) + + //TODO: Logging errors size_t safe_payload_len = payload_len; if (safe_payload_len > 65526) { safe_payload_len = 65526; } - int frame_len = ws_create_frame(buffer, sizeof(buffer), opcode, payload, safe_payload_len); + const int frame_len = ws_create_frame(buffer, sizeof(buffer), opcode, payload, safe_payload_len); if (frame_len < 0) { @@ -307,10 +374,7 @@ int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload, { return SSL_write(conn->ssl, buffer, frame_len); } - else - { - return write(conn->socket_fd, buffer, frame_len); - } + return write(conn->socket_fd, buffer, frame_len); } // Send text message @@ -326,7 +390,7 @@ int ws_send_pong(ws_connection_t* conn, const uint8_t* payload, size_t payload_l } // Close WebSocket connection -void ws_close_connection(ws_connection_t* conn, uint16_t status_code) +void ws_close_connection(ws_connection_t* conn, const uint16_t status_code) { uint8_t close_payload[2]; close_payload[0] = (status_code >> 8) & 0xFF; @@ -338,12 +402,14 @@ void ws_close_connection(ws_connection_t* conn, uint16_t status_code) { SSL_shutdown(conn->ssl); SSL_free(conn->ssl); + conn->ssl = NULL; // Prevent double-free } close(conn->socket_fd); + conn->socket_fd = -1; // Mark as closed } // Validate UTF-8 encoding -bool ws_is_valid_utf8(const uint8_t* data, size_t len) +bool ws_is_valid_utf8(const uint8_t* data, const size_t len) { size_t i = 0; while (i < len) diff --git a/src/websocket.h b/src/websocket.h index 1ff7b5e..98274df 100644 --- a/src/websocket.h +++ b/src/websocket.h @@ -34,10 +34,10 @@ typedef struct // Function prototypes int ws_handle_handshake(int client_socket, const char* request, char* response, size_t response_size); -int ws_handle_handshake_ssl(SSL* ssl, const char* request, char* response, size_t response_size); +int ws_handle_handshake_ssl(const SSL* ssl, const char* request, char* response, size_t response_size); int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, uint8_t** payload); int ws_create_frame(uint8_t* buffer, size_t buffer_size, uint8_t opcode, const uint8_t* payload, size_t payload_len); -int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload, size_t payload_len); +int ws_send_frame(const ws_connection_t* conn, uint8_t opcode, const uint8_t* payload, size_t payload_len); int ws_send_text(ws_connection_t* conn, const char* text); int ws_send_pong(ws_connection_t* conn, const uint8_t* payload, size_t payload_len); void ws_close_connection(ws_connection_t* conn, uint16_t status_code);