fix: improve memory management and error handling in WebSocket and HTTP2 implementations
All checks were successful
All checks were successful
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -59,3 +59,4 @@ ssl/*
|
||||
src/bin
|
||||
docker-push.sh
|
||||
.idea
|
||||
error.txt
|
||||
@@ -258,6 +258,8 @@ static int on_frame_recv_callback(nghttp2_session* session,
|
||||
{
|
||||
close(fd);
|
||||
free(mime_type);
|
||||
stream_data->mime_type = NULL; // Prevent double-free in cleanup
|
||||
stream_data->fd = -1; // Mark fd as already closed
|
||||
log_event("HTTP/2: Memory allocation failed");
|
||||
break;
|
||||
}
|
||||
@@ -403,8 +405,12 @@ static int on_begin_headers_callback(nghttp2_session* session,
|
||||
stream_data->stream_id = frame->hd.stream_id;
|
||||
stream_data->fd = -1;
|
||||
|
||||
nghttp2_session_set_stream_user_data(session, frame->hd.stream_id, stream_data);
|
||||
if (nghttp2_session_set_stream_user_data(session, frame->hd.stream_id, stream_data) != 0)
|
||||
{
|
||||
free(stream_data);
|
||||
return NGHTTP2_ERR_CALLBACK_FAILURE;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
44
src/server.c
44
src/server.c
@@ -528,12 +528,11 @@ static int is_websocket_upgrade(const char* request)
|
||||
// Handle WebSocket connection
|
||||
static void* handle_websocket(void* arg)
|
||||
{
|
||||
ws_connection_t* conn = (ws_connection_t*)arg;
|
||||
ws_connection_t* conn = arg;
|
||||
|
||||
|
||||
// Remove socket timeout for WebSocket (connections should stay open)
|
||||
// Remove socket timeout for WebSocket
|
||||
struct timeval ws_timeout;
|
||||
ws_timeout.tv_sec = 0; // No timeout - wait indefinitely
|
||||
ws_timeout.tv_sec = 0;
|
||||
ws_timeout.tv_usec = 0;
|
||||
setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &ws_timeout, sizeof(ws_timeout));
|
||||
|
||||
@@ -571,18 +570,19 @@ static void* handle_websocket(void* arg)
|
||||
if (parsed < 0)
|
||||
{
|
||||
log_event("Failed to parse WebSocket frame");
|
||||
free(payload);
|
||||
free(payload); // Safe - only frees if allocated
|
||||
ws_close_connection(conn, 1002);
|
||||
free(conn);
|
||||
pthread_exit(NULL);
|
||||
}
|
||||
|
||||
bool should_exit = false;
|
||||
|
||||
switch (header.opcode)
|
||||
{
|
||||
case WS_OPCODE_TEXT:
|
||||
if (ws_is_valid_utf8(payload, header.payload_length))
|
||||
{
|
||||
// Echo back the text message
|
||||
ws_send_frame(conn, WS_OPCODE_TEXT, payload, header.payload_length);
|
||||
log_event("WebSocket text frame received and echoed");
|
||||
}
|
||||
@@ -593,7 +593,6 @@ static void* handle_websocket(void* arg)
|
||||
break;
|
||||
|
||||
case WS_OPCODE_BINARY:
|
||||
// Echo back binary data
|
||||
ws_send_frame(conn, WS_OPCODE_BINARY, payload, header.payload_length);
|
||||
log_event("WebSocket binary frame received and echoed");
|
||||
break;
|
||||
@@ -605,16 +604,22 @@ static void* handle_websocket(void* arg)
|
||||
|
||||
case WS_OPCODE_CLOSE:
|
||||
log_event("WebSocket close frame received");
|
||||
free(payload);
|
||||
ws_close_connection(conn, 1000);
|
||||
free(conn);
|
||||
pthread_exit(NULL);
|
||||
should_exit = true;
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
free(payload);
|
||||
payload = NULL;
|
||||
|
||||
if (should_exit)
|
||||
{
|
||||
ws_close_connection(conn, 1000);
|
||||
free(conn);
|
||||
pthread_exit(NULL);
|
||||
}
|
||||
}
|
||||
|
||||
ws_close_connection(conn, 1000);
|
||||
@@ -622,6 +627,7 @@ static void* handle_websocket(void* arg)
|
||||
pthread_exit(NULL);
|
||||
}
|
||||
|
||||
|
||||
void* handle_http_client(void* arg)
|
||||
{
|
||||
int client_socket = *((int*)arg);
|
||||
@@ -1066,7 +1072,7 @@ void* handle_http_client(void* arg)
|
||||
done_serving:
|
||||
continue;
|
||||
}
|
||||
else if (bytes_received < 0)
|
||||
if (bytes_received < 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
@@ -1081,7 +1087,7 @@ cleanup:
|
||||
|
||||
void* handle_https_client(void* arg)
|
||||
{
|
||||
int client_socket = *((int*)arg);
|
||||
int client_socket = *(int*)arg;
|
||||
free(arg);
|
||||
|
||||
SSL* ssl = SSL_new(ssl_ctx);
|
||||
@@ -1176,17 +1182,14 @@ void* handle_https_client(void* arg)
|
||||
log_event("SSL_read failed");
|
||||
goto cleanup;
|
||||
}
|
||||
else if (bytes_received == 0)
|
||||
if (bytes_received == 0)
|
||||
{
|
||||
log_event("Client closed connection");
|
||||
goto cleanup;
|
||||
}
|
||||
else
|
||||
{
|
||||
buffer[bytes_received] = '\0';
|
||||
log_event("Received HTTPS request:");
|
||||
log_event(buffer);
|
||||
}
|
||||
|
||||
// Check for WebSocket upgrade request on HTTPS
|
||||
if (config.enable_websocket && is_websocket_upgrade(buffer))
|
||||
@@ -1220,20 +1223,14 @@ void* handle_https_client(void* arg)
|
||||
handle_websocket(ws_conn);
|
||||
pthread_exit(NULL);
|
||||
}
|
||||
else
|
||||
{
|
||||
SSL_shutdown(ssl);
|
||||
SSL_free(ssl);
|
||||
close(client_socket);
|
||||
pthread_exit(NULL);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
log_event("Secure WebSocket handshake failed");
|
||||
goto cleanup;
|
||||
}
|
||||
}
|
||||
|
||||
char method[8], url[256], protocol[16];
|
||||
if (parse_request_line(buffer, method, url, protocol) != 0)
|
||||
@@ -1791,7 +1788,6 @@ void* worker_thread(void* arg)
|
||||
*socket_ptr = task->socket_fd;
|
||||
handle_https_client(socket_ptr);
|
||||
}
|
||||
free(socket_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
110
src/websocket.c
110
src/websocket.c
@@ -111,7 +111,7 @@ int ws_handle_handshake(int client_socket, const char* request, char* response,
|
||||
}
|
||||
|
||||
// Create handshake response
|
||||
int written = snprintf(response, response_size,
|
||||
const int written = snprintf(response, response_size,
|
||||
"HTTP/1.1 101 Switching Protocols\r\n"
|
||||
"Upgrade: websocket\r\n"
|
||||
"Connection: Upgrade\r\n"
|
||||
@@ -125,34 +125,39 @@ int ws_handle_handshake(int client_socket, const char* request, char* response,
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Handle WebSocket handshake for SSL connections
|
||||
int ws_handle_handshake_ssl(SSL* ssl, const char* request, char* response, size_t response_size)
|
||||
int ws_handle_handshake_ssl(const SSL* ssl, const char* request, char* response, const size_t response_size)
|
||||
{
|
||||
(void)ssl; // Use the same logic, just different transport
|
||||
return ws_handle_handshake(0, request, response, response_size);
|
||||
}
|
||||
|
||||
void ws_free_payload(uint8_t* payload)
|
||||
{
|
||||
if (payload)
|
||||
{
|
||||
free(payload);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse WebSocket frame
|
||||
int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, uint8_t** payload)
|
||||
int ws_parse_frame(const uint8_t* data, const size_t len, ws_frame_header_t* header, uint8_t** payload)
|
||||
{
|
||||
// Maximum allowed WebSocket payload size (10MB)
|
||||
#define MAX_WEBSOCKET_PAYLOAD (10 * 1024 * 1024)
|
||||
|
||||
if (len < 2)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
header->fin = (data[0] & 0x80) >> 7;
|
||||
header->opcode = data[0] & 0x0F;
|
||||
header->mask = (data[1] & 0x80) >> 7;
|
||||
|
||||
size_t offset = 2;
|
||||
uint8_t payload_len = data[1] & 0x7F;
|
||||
const uint8_t payload_len = data[1] & 0x7F;
|
||||
|
||||
if (payload_len == 126)
|
||||
{
|
||||
@@ -177,10 +182,7 @@ int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, u
|
||||
header->payload_length = payload_len;
|
||||
}
|
||||
|
||||
if (header->payload_length > MAX_WEBSOCKET_PAYLOAD)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
if (header->payload_length > MAX_WEBSOCKET_PAYLOAD) return -1;
|
||||
|
||||
if (header->mask)
|
||||
{
|
||||
@@ -195,11 +197,13 @@ int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, u
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Unmask payload if masked
|
||||
*payload = (uint8_t*)malloc(header->payload_length);
|
||||
// Allocate payload buffer if there is payload data
|
||||
if (header->payload_length > 0)
|
||||
{
|
||||
*payload = malloc(header->payload_length);
|
||||
if (!*payload)
|
||||
{
|
||||
return -1;
|
||||
return -1; // Allocation failed
|
||||
}
|
||||
|
||||
if (header->mask)
|
||||
@@ -213,12 +217,13 @@ int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, u
|
||||
{
|
||||
memcpy(*payload, data + offset, header->payload_length);
|
||||
}
|
||||
}
|
||||
|
||||
return offset + header->payload_length;
|
||||
}
|
||||
|
||||
// Create WebSocket frame
|
||||
int ws_create_frame(uint8_t* buffer, size_t buffer_size, uint8_t opcode, const uint8_t* payload, size_t payload_len)
|
||||
int ws_create_frame(uint8_t* buffer,const size_t buffer_size,const uint8_t opcode, const uint8_t* payload, size_t payload_len)
|
||||
{
|
||||
size_t header_size;
|
||||
|
||||
@@ -276,9 +281,69 @@ int ws_create_frame(uint8_t* buffer, size_t buffer_size, uint8_t opcode, const u
|
||||
|
||||
return (int)offset;
|
||||
}
|
||||
int ws_send_frame_chunk(const ws_connection_t* conn,const uint8_t opcode,const uint8_t fin, const uint8_t* payload,const size_t payload_len) {
|
||||
uint8_t buffer[65536];
|
||||
size_t header_size;
|
||||
size_t offset = 0;
|
||||
|
||||
if (payload_len < 126) {
|
||||
header_size = 2;
|
||||
} else if (payload_len < 65536) {
|
||||
header_size = 4;
|
||||
} else {
|
||||
header_size = 10;
|
||||
}
|
||||
|
||||
if (header_size + payload_len >sizeof(buffer)) return -1;
|
||||
|
||||
buffer[offset++] = (fin ? 0x80 : 0x00) | (opcode & 0x0F);
|
||||
|
||||
if (payload_len < 126) {
|
||||
buffer[offset++] = (uint8_t)payload_len;
|
||||
} else if (payload_len < 65536) {
|
||||
buffer[offset++] = 126;
|
||||
buffer[offset++] = (payload_len >> 8) & 0xFF;
|
||||
buffer[offset++] = payload_len & 0xFF;
|
||||
} else {
|
||||
buffer[offset++] = 127;
|
||||
for (int i = 7; i >= 0; i--) {
|
||||
buffer[offset++] = (payload_len >> (i * 8)) & 0xFF;
|
||||
}
|
||||
}
|
||||
|
||||
if (payload && payload_len > 0) {
|
||||
memcpy(buffer + offset, payload, payload_len);
|
||||
offset += payload_len;
|
||||
}
|
||||
|
||||
ssize_t sent;
|
||||
if (conn->is_ssl && conn->ssl) {
|
||||
sent = SSL_write(conn->ssl, buffer, offset);
|
||||
} else {
|
||||
sent = write(conn->socket_fd, buffer, offset);
|
||||
}
|
||||
|
||||
return (sent > 0) ? sent : -1;
|
||||
}
|
||||
|
||||
int ws_send_frame_fragmented(const ws_connection_t* conn,const uint8_t opcode, const uint8_t* payload,const size_t payload_len) {
|
||||
size_t offset = 0;
|
||||
|
||||
while (offset < payload_len) {
|
||||
const size_t CHUNK_SIZE = 65526;
|
||||
const size_t chunk = (payload_len - offset > CHUNK_SIZE) ? CHUNK_SIZE : (payload_len - offset);
|
||||
const uint8_t frame_opcode = (offset == 0) ? opcode : WS_OPCODE_CONTINUATION;
|
||||
const uint8_t fin = (offset + chunk >= payload_len) ? 0x80 : 0x00;
|
||||
|
||||
const int result = ws_send_frame_chunk(conn, frame_opcode, fin, payload + offset, chunk);
|
||||
|
||||
if (result < 0) return result;
|
||||
offset += chunk;
|
||||
}
|
||||
return offset;
|
||||
}
|
||||
// Send WebSocket frame
|
||||
int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload, size_t payload_len)
|
||||
int ws_send_frame(const ws_connection_t* conn, const uint8_t opcode, const uint8_t* payload, const size_t payload_len)
|
||||
{
|
||||
// Allocate buffer with enough space for header (max 10 bytes) + payload
|
||||
// Check for integer overflow
|
||||
@@ -290,13 +355,15 @@ int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload,
|
||||
uint8_t buffer[65536];
|
||||
|
||||
// Limit payload to avoid overflow (65536 - 10 bytes for max header)
|
||||
|
||||
//TODO: Logging errors
|
||||
size_t safe_payload_len = payload_len;
|
||||
if (safe_payload_len > 65526)
|
||||
{
|
||||
safe_payload_len = 65526;
|
||||
}
|
||||
|
||||
int frame_len = ws_create_frame(buffer, sizeof(buffer), opcode, payload, safe_payload_len);
|
||||
const int frame_len = ws_create_frame(buffer, sizeof(buffer), opcode, payload, safe_payload_len);
|
||||
|
||||
if (frame_len < 0)
|
||||
{
|
||||
@@ -307,10 +374,7 @@ int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload,
|
||||
{
|
||||
return SSL_write(conn->ssl, buffer, frame_len);
|
||||
}
|
||||
else
|
||||
{
|
||||
return write(conn->socket_fd, buffer, frame_len);
|
||||
}
|
||||
}
|
||||
|
||||
// Send text message
|
||||
@@ -326,7 +390,7 @@ int ws_send_pong(ws_connection_t* conn, const uint8_t* payload, size_t payload_l
|
||||
}
|
||||
|
||||
// Close WebSocket connection
|
||||
void ws_close_connection(ws_connection_t* conn, uint16_t status_code)
|
||||
void ws_close_connection(ws_connection_t* conn, const uint16_t status_code)
|
||||
{
|
||||
uint8_t close_payload[2];
|
||||
close_payload[0] = (status_code >> 8) & 0xFF;
|
||||
@@ -338,12 +402,14 @@ void ws_close_connection(ws_connection_t* conn, uint16_t status_code)
|
||||
{
|
||||
SSL_shutdown(conn->ssl);
|
||||
SSL_free(conn->ssl);
|
||||
conn->ssl = NULL; // Prevent double-free
|
||||
}
|
||||
close(conn->socket_fd);
|
||||
conn->socket_fd = -1; // Mark as closed
|
||||
}
|
||||
|
||||
// Validate UTF-8 encoding
|
||||
bool ws_is_valid_utf8(const uint8_t* data, size_t len)
|
||||
bool ws_is_valid_utf8(const uint8_t* data, const size_t len)
|
||||
{
|
||||
size_t i = 0;
|
||||
while (i < len)
|
||||
|
||||
@@ -34,10 +34,10 @@ typedef struct
|
||||
|
||||
// Function prototypes
|
||||
int ws_handle_handshake(int client_socket, const char* request, char* response, size_t response_size);
|
||||
int ws_handle_handshake_ssl(SSL* ssl, const char* request, char* response, size_t response_size);
|
||||
int ws_handle_handshake_ssl(const SSL* ssl, const char* request, char* response, size_t response_size);
|
||||
int ws_parse_frame(const uint8_t* data, size_t len, ws_frame_header_t* header, uint8_t** payload);
|
||||
int ws_create_frame(uint8_t* buffer, size_t buffer_size, uint8_t opcode, const uint8_t* payload, size_t payload_len);
|
||||
int ws_send_frame(ws_connection_t* conn, uint8_t opcode, const uint8_t* payload, size_t payload_len);
|
||||
int ws_send_frame(const ws_connection_t* conn, uint8_t opcode, const uint8_t* payload, size_t payload_len);
|
||||
int ws_send_text(ws_connection_t* conn, const char* text);
|
||||
int ws_send_pong(ws_connection_t* conn, const uint8_t* payload, size_t payload_len);
|
||||
void ws_close_connection(ws_connection_t* conn, uint16_t status_code);
|
||||
|
||||
Reference in New Issue
Block a user