net: websocket: new receiving algorithm

websocket_recv_msg() is reworked with using fsm. Now the function
return 0 when payload is empty, -ENOTCONN if socket close. Receiving
empty ping and sending empty pong were added in tests.
Fixes #52327

Signed-off-by: Grixa Yrev <grixayrev@yandex.ru>
This commit is contained in:
Grixa Yrev 2022-12-14 00:30:22 +03:00 committed by Fabio Baltieri
commit 2a992c65c0
6 changed files with 390 additions and 281 deletions

View file

@ -150,15 +150,18 @@ int mqtt_client_websocket_read(struct mqtt_client *client, uint8_t *data,
ret = websocket_recv_msg(client->transport.websocket.sock,
data, buflen, &message_type, NULL, timeout);
if (ret > 0 && message_type > 0) {
if (ret >= 0 && message_type > 0) {
if (message_type & WEBSOCKET_FLAG_CLOSE) {
return 0;
}
if (!(message_type & WEBSOCKET_FLAG_BINARY)) {
if ((ret == 0) || !(message_type & WEBSOCKET_FLAG_BINARY)) {
return -EAGAIN;
}
}
if (ret == -ENOTCONN) {
ret = 0;
}
return ret;
}

View file

@ -274,8 +274,8 @@ int websocket_connect(int sock, struct websocket_request *wreq,
}
ctx->real_sock = sock;
ctx->tmp_buf = wreq->tmp_buf;
ctx->tmp_buf_len = wreq->tmp_buf_len;
ctx->recv_buf.buf = wreq->tmp_buf;
ctx->recv_buf.size = wreq->tmp_buf_len;
ctx->sec_accept_key = sec_accept_key;
ctx->http_cb = wreq->http_cb;
@ -388,7 +388,10 @@ int websocket_connect(int sock, struct websocket_request *wreq,
/* We will re-use the temp buffer in receive function if needed but
* in order that to work the amount of data in buffer must be set to 0
*/
ctx->tmp_buf_pos = 0;
ctx->recv_buf.count = 0;
/* Init parser FSM */
ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
return fd;
@ -583,7 +586,11 @@ static int websocket_prepare_and_send(struct websocket_context *ctx,
if (HEXDUMP_SENT_PACKETS) {
LOG_HEXDUMP_DBG(header, header_len, "Header");
LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
if ((payload != NULL) && (payload_len > 0)) {
LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
} else {
LOG_DBG("No payload");
}
}
#if defined(CONFIG_NET_TEST)
@ -682,16 +689,17 @@ int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
header[hdr_len++] |= ctx->masking_value >> 8;
header[hdr_len++] |= ctx->masking_value;
data_to_send = k_malloc(payload_len);
if (!data_to_send) {
return -ENOMEM;
}
if ((payload != NULL) && (payload_len > 0)) {
data_to_send = k_malloc(payload_len);
if (!data_to_send) {
return -ENOMEM;
}
memcpy(data_to_send, payload, payload_len);
memcpy(data_to_send, payload, payload_len);
for (i = 0; i < payload_len; i++) {
data_to_send[i] ^=
ctx->masking_value >> (8 * (3 - i % 4));
for (i = 0; i < payload_len; i++) {
data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4));
}
}
}
@ -715,94 +723,153 @@ quit:
return ret - hdr_len;
}
static bool websocket_parse_header(uint8_t *buf, size_t buf_len, bool *masked,
uint32_t *mask_value, uint64_t *message_length,
uint32_t *message_type_flag,
size_t *header_len)
static uint32_t websocket_opcode2flag(uint8_t data)
{
uint8_t len_len; /* length of the length field in header */
uint8_t len; /* message length byte */
uint16_t value;
value = sys_get_be16(&buf[0]);
if (value & 0x8000) {
*message_type_flag |= WEBSOCKET_FLAG_FINAL;
}
switch (value & 0x0f00) {
case 0x0100:
*message_type_flag |= WEBSOCKET_FLAG_TEXT;
break;
case 0x0200:
*message_type_flag |= WEBSOCKET_FLAG_BINARY;
break;
case 0x0800:
*message_type_flag |= WEBSOCKET_FLAG_CLOSE;
break;
case 0x0900:
*message_type_flag |= WEBSOCKET_FLAG_PING;
break;
case 0x0A00:
*message_type_flag |= WEBSOCKET_FLAG_PONG;
switch (data & 0x0f) {
case WEBSOCKET_OPCODE_DATA_TEXT:
return WEBSOCKET_FLAG_TEXT;
case WEBSOCKET_OPCODE_DATA_BINARY:
return WEBSOCKET_FLAG_BINARY;
case WEBSOCKET_OPCODE_CLOSE:
return WEBSOCKET_FLAG_CLOSE;
case WEBSOCKET_OPCODE_PING:
return WEBSOCKET_FLAG_PING;
case WEBSOCKET_OPCODE_PONG:
return WEBSOCKET_FLAG_PONG;
default:
break;
}
return 0;
}
len = value & 0x007f;
if (len < 126) {
len_len = 0;
*message_length = len;
} else if (len == 126) {
len_len = 2;
*message_length = sys_get_be16(&buf[2]);
} else {
len_len = 8;
*message_length = sys_get_be64(&buf[2]);
}
static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload)
{
int len;
uint8_t data;
size_t parsed_count = 0;
/* Minimum websocket header is 2 bytes, header length might be
* bigger depending on length field len.
*/
*header_len = MIN_HEADER_LEN + len_len;
do {
if (parsed_count >= ctx->recv_buf.count) {
return parsed_count;
}
if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) {
data = ctx->recv_buf.buf[parsed_count++];
if (buf_len >= *header_len) {
if (value & 0x0080) {
*masked = true;
*mask_value = sys_get_be32(&buf[2 + len_len]);
*header_len += 4;
switch (ctx->parser_state) {
case WEBSOCKET_PARSER_STATE_OPCODE:
ctx->message_type = websocket_opcode2flag(data);
if ((data & 0x80) != 0) {
ctx->message_type |= WEBSOCKET_FLAG_FINAL;
}
ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH;
break;
case WEBSOCKET_PARSER_STATE_LENGTH:
ctx->masked = (data & 0x80) != 0;
len = data & 0x7f;
if (len < 126) {
ctx->message_len = len;
if (ctx->masked) {
ctx->masking_value = 0;
ctx->parser_remaining = 4;
ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
} else {
ctx->parser_remaining = ctx->message_len;
ctx->parser_state =
(ctx->parser_remaining == 0)
? WEBSOCKET_PARSER_STATE_OPCODE
: WEBSOCKET_PARSER_STATE_PAYLOAD;
}
} else {
ctx->message_len = 0;
ctx->parser_remaining = (len < 127) ? 2 : 8;
ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN;
}
break;
case WEBSOCKET_PARSER_STATE_EXT_LEN:
ctx->parser_remaining--;
ctx->message_len |= (data << (ctx->parser_remaining * 8));
if (ctx->parser_remaining == 0) {
if (ctx->masked) {
ctx->masking_value = 0;
ctx->parser_remaining = 4;
ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
} else {
ctx->parser_remaining = ctx->message_len;
ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
}
}
break;
case WEBSOCKET_PARSER_STATE_MASK:
ctx->parser_remaining--;
ctx->masking_value |= (data << (ctx->parser_remaining * 8));
if (ctx->parser_remaining == 0) {
if (ctx->message_len == 0) {
ctx->parser_remaining = 0;
ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
} else {
ctx->parser_remaining = ctx->message_len;
ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
}
}
break;
default:
return -EFAULT;
}
#if (LOG_LEVEL >= LOG_LEVEL_DBG)
if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) ||
((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) &&
(ctx->message_len == 0))) {
NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx,
ctx->masked ? "" : "un",
ctx->masked ? ctx->masking_value : 0, ctx->message_type,
(size_t)ctx->message_len);
}
#endif
} else {
*masked = false;
size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count;
size_t payload_in_recv_buf =
MIN(remaining_in_recv_buf, ctx->parser_remaining);
size_t free_in_payload_buf = payload->size - payload->count;
size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf);
if (free_in_payload_buf == 0) {
break;
}
memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count],
ready_to_copy);
parsed_count += ready_to_copy;
payload->count += ready_to_copy;
ctx->parser_remaining -= ready_to_copy;
if (ctx->parser_remaining == 0) {
ctx->parser_remaining = 0;
ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
}
}
return true;
}
} while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE);
return false;
return parsed_count;
}
int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
uint32_t *message_type, uint64_t *remaining, int32_t timeout)
{
struct websocket_context *ctx;
size_t header_len = 0;
int recv_len = 0;
size_t can_copy, left;
int ret;
k_timeout_t tout = K_FOREVER;
struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0};
if (timeout != SYS_FOREVER_MS) {
tout = K_MSEC(timeout);
}
#if defined(CONFIG_NET_TEST)
/* Websocket unit test does not use socket layer but feeds
* the data directly here when testing this function.
*/
struct test_data {
uint8_t *input_buf;
size_t input_len;
struct websocket_context *ctx;
};
if ((buf == NULL) || (buf_len == 0)) {
return -EINVAL;
}
#if defined(CONFIG_NET_TEST)
struct test_data *test_data =
UINT_TO_POINTER((unsigned int) ws_sock);
@ -818,178 +885,90 @@ int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
}
#endif /* CONFIG_NET_TEST */
/* If we have not received the websocket header yet, read it first */
if (!ctx->header_received) {
do {
size_t parsed_count;
if (ctx->recv_buf.count == 0) {
#if defined(CONFIG_NET_TEST)
size_t input_len = MIN(ctx->tmp_buf_len - ctx->tmp_buf_pos,
test_data->input_len);
size_t input_len = MIN(ctx->recv_buf.size,
test_data->input_len - test_data->input_pos);
memcpy(&ctx->tmp_buf[ctx->tmp_buf_pos], test_data->input_buf,
input_len);
test_data->input_buf += input_len;
ret = input_len;
#else
ret = recv(ctx->real_sock, &ctx->tmp_buf[ctx->tmp_buf_pos],
ctx->tmp_buf_len - ctx->tmp_buf_pos,
K_TIMEOUT_EQ(tout, K_NO_WAIT) ? MSG_DONTWAIT : 0);
#endif /* CONFIG_NET_TEST */
if (ret < 0) {
return -errno;
}
if (ret == 0) {
/* Socket closed */
return 0;
}
ctx->tmp_buf_pos += ret;
if (ctx->tmp_buf_pos >= MIN_HEADER_LEN) {
bool masked;
/* Now we will be able to figure out what is the
* actual size of the header.
*/
if (websocket_parse_header(&ctx->tmp_buf[0],
ctx->tmp_buf_pos,
&masked,
&ctx->masking_value,
&ctx->message_len,
&ctx->message_type,
&header_len)) {
ctx->masked = masked;
if (message_type) {
*message_type = ctx->message_type;
}
if (input_len > 0) {
memcpy(ctx->recv_buf.buf,
&test_data->input_buf[test_data->input_pos], input_len);
test_data->input_pos += input_len;
ret = input_len;
} else {
return -EAGAIN;
/* emulate timeout */
errno = EAGAIN;
ret = -1;
}
} else {
return -EAGAIN;
}
if (ctx->tmp_buf_pos < header_len) {
return -EAGAIN;
}
/* All of the header is now received, we can read the payload
* data next.
*/
ctx->header_received = true;
if (HEXDUMP_RECV_PACKETS) {
LOG_HEXDUMP_DBG(&ctx->tmp_buf[0], header_len,
"Header");
NET_DBG("[%p] masked %d mask 0x%04x hdr %zd msg %zd",
ctx, ctx->masked,
ctx->masked ? ctx->masking_value : 0,
header_len, (size_t)ctx->message_len);
}
ctx->total_read = 0;
memmove(ctx->tmp_buf, &ctx->tmp_buf[header_len],
ctx->tmp_buf_len - header_len);
ctx->tmp_buf_pos -= header_len;
if (ctx->tmp_buf_pos == 0) {
/* No data after the header, let the caller call
* this function again to get the payload.
*/
return -EAGAIN;
}
NET_DBG("There is %zd bytes of data", ctx->tmp_buf_pos);
}
/* Now read the whole payload or parts of it */
if (ctx->tmp_buf_pos == 0) {
/* Read more data into temp buffer */
#if defined(CONFIG_NET_TEST)
size_t input_len = MIN(ctx->tmp_buf_len, test_data->input_len);
memcpy(ctx->tmp_buf, test_data->input_buf, input_len);
test_data->input_buf += input_len;
ret = input_len;
#else
ret = recv(ctx->real_sock, ctx->tmp_buf, ctx->tmp_buf_len,
K_TIMEOUT_EQ(tout, K_NO_WAIT) ? MSG_DONTWAIT : 0);
ret = recv(ctx->real_sock, ctx->recv_buf.buf, ctx->recv_buf.size,
K_TIMEOUT_EQ(tout, K_NO_WAIT) ? MSG_DONTWAIT : 0);
#endif /* CONFIG_NET_TEST */
if (ret < 0) {
ret = -errno;
if ((ret == -EAGAIN) && (payload.count > 0)) {
/* go to unmasking */
break;
}
return ret;
}
if (ret == 0) {
/* Socket closed */
return -ENOTCONN;
}
ctx->recv_buf.count = ret;
NET_DBG("[%p] Received %d bytes", ctx, ret);
}
ret = websocket_parse(ctx, &payload);
if (ret < 0) {
return -errno;
return ret;
}
parsed_count = ret;
if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) ||
(payload.count >= payload.size)) {
if (remaining != NULL) {
*remaining = ctx->parser_remaining;
}
if (message_type != NULL) {
*message_type = ctx->message_type;
}
size_t left = ctx->recv_buf.count - parsed_count;
if (left > 0) {
memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left);
}
ctx->recv_buf.count = left;
break;
}
if (ret == 0) {
return 0;
}
ctx->recv_buf.count -= parsed_count;
ctx->tmp_buf_pos = ret;
}
if (ctx->tmp_buf_pos <= buf_len) {
/* Is there already any data in the temp buffer? If yes,
* just return it to the caller.
*/
can_copy = MIN(ctx->message_len - ctx->total_read,
ctx->tmp_buf_pos);
} else {
/* We have more data in tmp buffer that will fit into
* user buffer.
*/
can_copy = MIN(ctx->message_len - ctx->total_read, buf_len);
}
left = ctx->tmp_buf_pos - can_copy;
NET_ASSERT(ctx->tmp_buf_pos >= can_copy);
memmove(buf, ctx->tmp_buf, can_copy);
recv_len = can_copy;
if (left > 0) {
memmove(ctx->tmp_buf, &ctx->tmp_buf[can_copy], left);
}
ctx->tmp_buf_pos = left;
ctx->total_read += recv_len;
} while (true);
/* Unmask the data */
if (ctx->masked) {
/* As we might have less than 4 received bytes, we must select
* which byte from masking value to take. The mask_shift will
* tell that.
*/
int mask_shift = (ctx->total_read - recv_len) % sizeof(uint32_t);
int i;
uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value;
size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count;
for (i = 0; i < recv_len; i++) {
buf[i] ^= ctx->masking_value >>
(8 * (3 - (i + mask_shift) % 4));
for (size_t i = 0; i < payload.count; i++) {
size_t m = data_buf_offset % 4;
payload.buf[i] ^= mask_as_bytes[3 - m];
data_buf_offset++;
}
}
#if HEXDUMP_RECV_PACKETS
LOG_HEXDUMP_DBG(buf, recv_len, "Payload");
#endif
if (remaining) {
*remaining = ctx->message_len - ctx->total_read;
}
/* Start to read the header again if all the data has been received */
if (ctx->message_len == ctx->total_read) {
ctx->header_received = false;
ctx->message_len = 0;
ctx->message_type = 0;
ctx->total_read = 0;
}
return recv_len;
return payload.count;
}
static int websocket_send(struct websocket_context *ctx, const uint8_t *buf,
@ -1027,8 +1006,12 @@ static int websocket_recv(struct websocket_context *ctx, uint8_t *buf,
ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type,
&remaining, timeout);
if (ret < 0) {
errno = -ret;
return -1;
if (ret == -ENOTCONN) {
ret = 0;
} else {
errno = -ret;
return -1;
}
}
NET_DBG("[%p] Received %d bytes", ctx, ret);

View file

@ -23,6 +23,29 @@
/* From RFC 6455 chapter 4.2.2 */
#define WS_MAGIC "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
/**
* Websocket parser states
*/
enum websocket_parser_state {
WEBSOCKET_PARSER_STATE_OPCODE,
WEBSOCKET_PARSER_STATE_LENGTH,
WEBSOCKET_PARSER_STATE_EXT_LEN,
WEBSOCKET_PARSER_STATE_MASK,
WEBSOCKET_PARSER_STATE_PAYLOAD,
};
/**
* Description of external buffers for payload and receiving
*/
struct websocket_buffer {
/* external buffer */
uint8_t *buf;
/* size of external buffer */
size_t size;
/* data length in external buffer */
size_t count;
};
/**
* Websocket connection information
*/
@ -59,19 +82,12 @@ __net_socket struct websocket_context {
int sock;
};
/** Temporary buffers used for HTTP handshakes and Websocket protocol
* headers. User must provide the actual buffer where the headers are
/** Buffer for receiving from TCP socket.
* This buffer used for HTTP handshakes and Websocket packet parser.
* User must provide the actual buffer where the data are
* stored temporarily.
*/
uint8_t *tmp_buf;
/** Temporary buffer length.
*/
size_t tmp_buf_len;
/** Current reading position in the tmp_buf
*/
size_t tmp_buf_pos;
struct websocket_buffer recv_buf;
/** The real TCP socket to use when sending Websocket data to peer.
*/
@ -80,15 +96,18 @@ __net_socket struct websocket_context {
/** Websocket connection masking value */
uint32_t masking_value;
/** Amount of data received. */
uint64_t total_read;
/** Message length */
uint64_t message_len;
/** Message type */
uint32_t message_type;
/** Parser remaining length in current state */
uint64_t parser_remaining;
/** Parser state */
enum websocket_parser_state parser_state;
/** Is the message masked */
uint8_t masked : 1;
@ -100,11 +119,28 @@ __net_socket struct websocket_context {
/** Did we receive all from peer during HTTP handshake */
uint8_t all_received : 1;
/** Header received */
uint8_t header_received : 1;
};
#if defined(CONFIG_NET_TEST)
/**
* Websocket unit test does not use socket layer but feeds
* the data directly here when testing this function.
*/
struct test_data {
/** pointer to data "tx" buffer */
uint8_t *input_buf;
/** "tx" buffer data length */
size_t input_len;
/** "tx" buffer read (recv) position */
size_t input_pos;
/** external test context */
struct websocket_context *ctx;
};
#endif /* CONFIG_NET_TEST */
/**
* @brief Disconnect the Websocket.
*