fix: improve memory management and error handling in WebSocket and HTTP2 implementations
All checks were successful
CI Pipeline / build (push) Successful in 2m21s
CI Pipeline / security-scan (push) Successful in 1m19s
CI Pipeline / code-quality (push) Successful in 2m54s
CI Pipeline / docker-build (push) Successful in 2m55s
CI Pipeline / test (push) Successful in 1m19s

This commit is contained in:
2026-04-09 15:41:17 +02:00
parent 07e90ebb1b
commit 9188733f9f
5 changed files with 142 additions and 73 deletions

1
.gitignore vendored
View File

@@ -59,3 +59,4 @@ ssl/*
src/bin src/bin
docker-push.sh docker-push.sh
.idea .idea
error.txt

View File

@@ -258,6 +258,8 @@ static int on_frame_recv_callback(nghttp2_session* session,
{ {
close(fd); close(fd);
free(mime_type); free(mime_type);
stream_data->mime_type = NULL; // Prevent double-free in cleanup
stream_data->fd = -1; // Mark fd as already closed
log_event("HTTP/2: Memory allocation failed"); log_event("HTTP/2: Memory allocation failed");
break; break;
} }
@@ -403,8 +405,12 @@ static int on_begin_headers_callback(nghttp2_session* session,
stream_data->stream_id = frame->hd.stream_id; stream_data->stream_id = frame->hd.stream_id;
stream_data->fd = -1; stream_data->fd = -1;
nghttp2_session_set_stream_user_data(session, frame->hd.stream_id, stream_data); if (nghttp2_session_set_stream_user_data(session, frame->hd.stream_id, stream_data) != 0)
free(stream_data); {
free(stream_data);
return NGHTTP2_ERR_CALLBACK_FAILURE;
}
return 0; return 0;
} }

View File

@@ -528,12 +528,11 @@ static int is_websocket_upgrade(const char* request)
// Handle WebSocket connection // Handle WebSocket connection
static void* handle_websocket(void* arg) static void* handle_websocket(void* arg)
{ {
ws_connection_t* conn = (ws_connection_t*)arg; ws_connection_t* conn = arg;
// Remove socket timeout for WebSocket
// Remove socket timeout for WebSocket (connections should stay open)
struct timeval ws_timeout; struct timeval ws_timeout;
ws_timeout.tv_sec = 0; // No timeout - wait indefinitely ws_timeout.tv_sec = 0;
ws_timeout.tv_usec = 0; ws_timeout.tv_usec = 0;
setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &ws_timeout, sizeof(ws_timeout)); setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &ws_timeout, sizeof(ws_timeout));
@@ -571,18 +570,19 @@ static void* handle_websocket(void* arg)
if (parsed < 0) if (parsed < 0)
{ {
log_event("Failed to parse WebSocket frame"); log_event("Failed to parse WebSocket frame");
free(payload); free(payload); // Safe - only frees if allocated
ws_close_connection(conn, 1002); ws_close_connection(conn, 1002);
free(conn); free(conn);
pthread_exit(NULL); pthread_exit(NULL);
} }
bool should_exit = false;
switch (header.opcode) switch (header.opcode)
{ {
case WS_OPCODE_TEXT: case WS_OPCODE_TEXT:
if (ws_is_valid_utf8(payload, header.payload_length)) if (ws_is_valid_utf8(payload, header.payload_length))
{ {
// Echo back the text message
ws_send_frame(conn, WS_OPCODE_TEXT, payload, header.payload_length); ws_send_frame(conn, WS_OPCODE_TEXT, payload, header.payload_length);
log_event("WebSocket text frame received and echoed"); log_event("WebSocket text frame received and echoed");
} }
@@ -593,7 +593,6 @@ static void* handle_websocket(void* arg)
break; break;
case WS_OPCODE_BINARY: case WS_OPCODE_BINARY:
// Echo back binary data
ws_send_frame(conn, WS_OPCODE_BINARY, payload, header.payload_length); ws_send_frame(conn, WS_OPCODE_BINARY, payload, header.payload_length);
log_event("WebSocket binary frame received and echoed"); log_event("WebSocket binary frame received and echoed");
break; break;
@@ -605,16 +604,22 @@ static void* handle_websocket(void* arg)
case WS_OPCODE_CLOSE: case WS_OPCODE_CLOSE:
log_event("WebSocket close frame received"); log_event("WebSocket close frame received");
free(payload); should_exit = true;
ws_close_connection(conn, 1000); break;
free(conn);
pthread_exit(NULL);
default: default:
break; break;
} }
free(payload); free(payload);
payload = NULL;
if (should_exit)
{
ws_close_connection(conn, 1000);
free(conn);
pthread_exit(NULL);
}
} }
ws_close_connection(conn, 1000); ws_close_connection(conn, 1000);
@@ -622,6 +627,7 @@ static void* handle_websocket(void* arg)
pthread_exit(NULL); pthread_exit(NULL);
} }
void* handle_http_client(void* arg) void* handle_http_client(void* arg)
{ {
int client_socket = *((int*)arg); int client_socket = *((int*)arg);
@@ -1066,7 +1072,7 @@ void* handle_http_client(void* arg)
done_serving: done_serving:
continue; continue;
} }
else if (bytes_received < 0) if (bytes_received < 0)
{ {
break; break;
} }
@@ -1081,7 +1087,7 @@ cleanup:
void* handle_https_client(void* arg) void* handle_https_client(void* arg)
{ {
int client_socket = *((int*)arg); int client_socket = *(int*)arg;
free(arg); free(arg);
SSL* ssl = SSL_new(ssl_ctx); SSL* ssl = SSL_new(ssl_ctx);
@@ -1176,17 +1182,14 @@ void* handle_https_client(void* arg)
log_event("SSL_read failed"); log_event("SSL_read failed");
goto cleanup; goto cleanup;
} }
else if (bytes_received == 0) if (bytes_received == 0)
{ {
log_event("Client closed connection"); log_event("Client closed connection");
goto cleanup; goto cleanup;
} }
else buffer[bytes_received] = '\0';
{ log_event("Received HTTPS request:");
buffer[bytes_received] = '\0'; log_event(buffer);
log_event("Received HTTPS request:");
log_event(buffer);
}
// Check for WebSocket upgrade request on HTTPS // Check for WebSocket upgrade request on HTTPS
if (config.enable_websocket && is_websocket_upgrade(buffer)) if (config.enable_websocket && is_websocket_upgrade(buffer))
@@ -1220,19 +1223,13 @@ void* handle_https_client(void* arg)
handle_websocket(ws_conn); handle_websocket(ws_conn);
pthread_exit(NULL); pthread_exit(NULL);
} }
else SSL_shutdown(ssl);
{ SSL_free(ssl);
SSL_shutdown(ssl); close(client_socket);
SSL_free(ssl); pthread_exit(NULL);
close(client_socket);
pthread_exit(NULL);
}
}
else
{
log_event("Secure WebSocket handshake failed");
goto cleanup;
} }
log_event("Secure WebSocket handshake failed");
goto cleanup;
} }
char method[8], url[256], protocol[16]; char method[8], url[256], protocol[16];
@@ -1791,7 +1788,6 @@ void* worker_thread(void* arg)
*socket_ptr = task->socket_fd; *socket_ptr = task->socket_fd;
handle_https_client(socket_ptr); handle_https_client(socket_ptr);
} }
free(socket_ptr);
} }
else else
{ {

View File

@@ -111,7 +111,7 @@ int ws_handle_handshake(int client_socket, const char* request, char* response,
} }
// Create handshake response // Create handshake response
int written = snprintf(response, response_size, const int written = snprintf(response, response_size,
"HTTP/1.1 101 Switching Protocols\r\n" "HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n" "Upgrade: websocket\r\n"
"Connection: Upgrade\r\n" "Connection: Upgrade\r\n"
@@ -125,34 +125,39 @@ int ws_handle_handshake(int client_socket, const char* request, char* response,
{ {
return -1; return -1;
} }
return 0; return 0;
} }
// Handle WebSocket handshake for SSL connections // Handle WebSocket handshake for SSL connections
int ws_handle_handshake_ssl(SSL* ssl, const char* request, char* response, size_t response_size) int ws_handle_handshake_ssl(const SSL* ssl, const char* request, char* response, const size_t response_size)
{ {
(void)ssl; // Use the same logic, just different transport (void)ssl; // Use the same logic, just different transport
return ws_handle_handshake(0, request, response, response_size); return ws_handle_handshake(0, request, response, response_size);
} }
void ws_free_payload(uint8_t* payload)
{
if (payload)
{
free(payload);
}
}
// Parse WebSocket frame // Parse WebSocket frame
int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, uint8_t** payload) int ws_parse_frame(const uint8_t* data, const size_t len, ws_frame_header_t* header, uint8_t** payload)
{ {
// Maximum allowed WebSocket payload size (10MB) // Maximum allowed WebSocket payload size (10MB)
#define MAX_WEBSOCKET_PAYLOAD (10 * 1024 * 1024) #define MAX_WEBSOCKET_PAYLOAD (10 * 1024 * 1024)
if (len < 2) if (len < 2)
{
return -1; return -1;
}
header->fin = (data[0] & 0x80) >> 7; header->fin = (data[0] & 0x80) >> 7;
header->opcode = data[0] & 0x0F; header->opcode = data[0] & 0x0F;
header->mask = (data[1] & 0x80) >> 7; header->mask = (data[1] & 0x80) >> 7;
size_t offset = 2; size_t offset = 2;
uint8_t payload_len = data[1] & 0x7F; const uint8_t payload_len = data[1] & 0x7F;
if (payload_len == 126) if (payload_len == 126)
{ {
@@ -177,10 +182,7 @@ int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, u
header->payload_length = payload_len; header->payload_length = payload_len;
} }
if (header->payload_length > MAX_WEBSOCKET_PAYLOAD) if (header->payload_length > MAX_WEBSOCKET_PAYLOAD) return -1;
{
return -1;
}
if (header->mask) if (header->mask)
{ {
@@ -195,30 +197,33 @@ int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, u
return -1; return -1;
} }
// Unmask payload if masked // Allocate payload buffer if there is payload data
*payload = (uint8_t*)malloc(header->payload_length); if (header->payload_length > 0)
if (!*payload)
{ {
return -1; *payload = malloc(header->payload_length);
} if (!*payload)
if (header->mask)
{
for (uint64_t i = 0; i < header->payload_length; i++)
{ {
(*payload)[i] = data[offset + i] ^ header->masking_key[i % 4]; return -1; // Allocation failed
}
if (header->mask)
{
for (uint64_t i = 0; i < header->payload_length; i++)
{
(*payload)[i] = data[offset + i] ^ header->masking_key[i % 4];
}
}
else
{
memcpy(*payload, data + offset, header->payload_length);
} }
}
else
{
memcpy(*payload, data + offset, header->payload_length);
} }
return offset + header->payload_length; return offset + header->payload_length;
} }
// Create WebSocket frame // Create WebSocket frame
int ws_create_frame(uint8_t* buffer, size_t buffer_size, uint8_t opcode, const uint8_t* payload, size_t payload_len) int ws_create_frame(uint8_t* buffer,const size_t buffer_size,const uint8_t opcode, const uint8_t* payload, size_t payload_len)
{ {
size_t header_size; size_t header_size;
@@ -276,9 +281,69 @@ int ws_create_frame(uint8_t* buffer, size_t buffer_size, uint8_t opcode, const u
return (int)offset; return (int)offset;
} }
int ws_send_frame_chunk(const ws_connection_t* conn,const uint8_t opcode,const uint8_t fin, const uint8_t* payload,const size_t payload_len) {
uint8_t buffer[65536];
size_t header_size;
size_t offset = 0;
if (payload_len < 126) {
header_size = 2;
} else if (payload_len < 65536) {
header_size = 4;
} else {
header_size = 10;
}
if (header_size + payload_len >sizeof(buffer)) return -1;
buffer[offset++] = (fin ? 0x80 : 0x00) | (opcode & 0x0F);
if (payload_len < 126) {
buffer[offset++] = (uint8_t)payload_len;
} else if (payload_len < 65536) {
buffer[offset++] = 126;
buffer[offset++] = (payload_len >> 8) & 0xFF;
buffer[offset++] = payload_len & 0xFF;
} else {
buffer[offset++] = 127;
for (int i = 7; i >= 0; i--) {
buffer[offset++] = (payload_len >> (i * 8)) & 0xFF;
}
}
if (payload && payload_len > 0) {
memcpy(buffer + offset, payload, payload_len);
offset += payload_len;
}
ssize_t sent;
if (conn->is_ssl && conn->ssl) {
sent = SSL_write(conn->ssl, buffer, offset);
} else {
sent = write(conn->socket_fd, buffer, offset);
}
return (sent > 0) ? sent : -1;
}
int ws_send_frame_fragmented(const ws_connection_t* conn,const uint8_t opcode, const uint8_t* payload,const size_t payload_len) {
size_t offset = 0;
while (offset < payload_len) {
const size_t CHUNK_SIZE = 65526;
const size_t chunk = (payload_len - offset > CHUNK_SIZE) ? CHUNK_SIZE : (payload_len - offset);
const uint8_t frame_opcode = (offset == 0) ? opcode : WS_OPCODE_CONTINUATION;
const uint8_t fin = (offset + chunk >= payload_len) ? 0x80 : 0x00;
const int result = ws_send_frame_chunk(conn, frame_opcode, fin, payload + offset, chunk);
if (result < 0) return result;
offset += chunk;
}
return offset;
}
// Send WebSocket frame // Send WebSocket frame
int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload, size_t payload_len) int ws_send_frame(const ws_connection_t* conn, const uint8_t opcode, const uint8_t* payload, const size_t payload_len)
{ {
// Allocate buffer with enough space for header (max 10 bytes) + payload // Allocate buffer with enough space for header (max 10 bytes) + payload
// Check for integer overflow // Check for integer overflow
@@ -290,13 +355,15 @@ int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload,
uint8_t buffer[65536]; uint8_t buffer[65536];
// Limit payload to avoid overflow (65536 - 10 bytes for max header) // Limit payload to avoid overflow (65536 - 10 bytes for max header)
//TODO: Logging errors
size_t safe_payload_len = payload_len; size_t safe_payload_len = payload_len;
if (safe_payload_len > 65526) if (safe_payload_len > 65526)
{ {
safe_payload_len = 65526; safe_payload_len = 65526;
} }
int frame_len = ws_create_frame(buffer, sizeof(buffer), opcode, payload, safe_payload_len); const int frame_len = ws_create_frame(buffer, sizeof(buffer), opcode, payload, safe_payload_len);
if (frame_len < 0) if (frame_len < 0)
{ {
@@ -307,10 +374,7 @@ int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload,
{ {
return SSL_write(conn->ssl, buffer, frame_len); return SSL_write(conn->ssl, buffer, frame_len);
} }
else return write(conn->socket_fd, buffer, frame_len);
{
return write(conn->socket_fd, buffer, frame_len);
}
} }
// Send text message // Send text message
@@ -326,7 +390,7 @@ int ws_send_pong(ws_connection_t* conn, const uint8_t* payload, size_t payload_l
} }
// Close WebSocket connection // Close WebSocket connection
void ws_close_connection(ws_connection_t* conn, uint16_t status_code) void ws_close_connection(ws_connection_t* conn, const uint16_t status_code)
{ {
uint8_t close_payload[2]; uint8_t close_payload[2];
close_payload[0] = (status_code >> 8) & 0xFF; close_payload[0] = (status_code >> 8) & 0xFF;
@@ -338,12 +402,14 @@ void ws_close_connection(ws_connection_t* conn, uint16_t status_code)
{ {
SSL_shutdown(conn->ssl); SSL_shutdown(conn->ssl);
SSL_free(conn->ssl); SSL_free(conn->ssl);
conn->ssl = NULL; // Prevent double-free
} }
close(conn->socket_fd); close(conn->socket_fd);
conn->socket_fd = -1; // Mark as closed
} }
// Validate UTF-8 encoding // Validate UTF-8 encoding
bool ws_is_valid_utf8(const uint8_t* data, size_t len) bool ws_is_valid_utf8(const uint8_t* data, const size_t len)
{ {
size_t i = 0; size_t i = 0;
while (i < len) while (i < len)

View File

@@ -34,10 +34,10 @@ typedef struct
// Function prototypes // Function prototypes
int ws_handle_handshake(int client_socket, const char* request, char* response, size_t response_size); int ws_handle_handshake(int client_socket, const char* request, char* response, size_t response_size);
int ws_handle_handshake_ssl(SSL* ssl, const char* request, char* response, size_t response_size); int ws_handle_handshake_ssl(const SSL* ssl, const char* request, char* response, size_t response_size);
int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, uint8_t** payload); int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, uint8_t** payload);
int ws_create_frame(uint8_t* buffer, size_t buffer_size, uint8_t opcode, const uint8_t* payload, size_t payload_len); int ws_create_frame(uint8_t* buffer, size_t buffer_size, uint8_t opcode, const uint8_t* payload, size_t payload_len);
int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload, size_t payload_len); int ws_send_frame(const ws_connection_t* conn, uint8_t opcode, const uint8_t* payload, size_t payload_len);
int ws_send_text(ws_connection_t* conn, const char* text); int ws_send_text(ws_connection_t* conn, const char* text);
int ws_send_pong(ws_connection_t* conn, const uint8_t* payload, size_t payload_len); int ws_send_pong(ws_connection_t* conn, const uint8_t* payload, size_t payload_len);
void ws_close_connection(ws_connection_t* conn, uint16_t status_code); void ws_close_connection(ws_connection_t* conn, uint16_t status_code);