net: websocket: new receiving algorithm

websocket_recv_msg() is reworked with using fsm. Now the function
return 0 when payload is empty, -ENOTCONN if socket close. Receiving
empty ping and sending empty pong were added in tests.
Fixes #52327

Signed-off-by: Grixa Yrev <grixayrev@yandex.ru>
This commit is contained in:
Grixa Yrev 2022-12-14 00:30:22 +03:00 committed by Fabio Baltieri
commit 2a992c65c0
6 changed files with 390 additions and 281 deletions

View file

@ -166,7 +166,10 @@ int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
* The value is in milliseconds. Value SYS_FOREVER_MS means to wait
* forever.
*
* @return <0 if error, >=0 amount of bytes received
* @retval >=0 amount of bytes received.
* @retval -EAGAIN on timeout.
* @retval -ENOTCONN on socket close.
* @retval -errno other negative errno value in case of failure.
*/
int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
uint32_t *message_type, uint64_t *remaining,

View file

@ -201,7 +201,7 @@ static void recv_data_wso_api(int sock, size_t amount, uint8_t *buf,
&message_type,
&remaining,
0);
if (ret <= 0) {
if (ret < 0) {
if (ret == -EAGAIN) {
k_sleep(K_MSEC(50));
continue;

View file

@ -150,15 +150,18 @@ int mqtt_client_websocket_read(struct mqtt_client *client, uint8_t *data,
ret = websocket_recv_msg(client->transport.websocket.sock,
data, buflen, &message_type, NULL, timeout);
if (ret > 0 && message_type > 0) {
if (ret >= 0 && message_type > 0) {
if (message_type & WEBSOCKET_FLAG_CLOSE) {
return 0;
}
if (!(message_type & WEBSOCKET_FLAG_BINARY)) {
if ((ret == 0) || !(message_type & WEBSOCKET_FLAG_BINARY)) {
return -EAGAIN;
}
}
if (ret == -ENOTCONN) {
ret = 0;
}
return ret;
}

View file

@ -274,8 +274,8 @@ int websocket_connect(int sock, struct websocket_request *wreq,
}
ctx->real_sock = sock;
ctx->tmp_buf = wreq->tmp_buf;
ctx->tmp_buf_len = wreq->tmp_buf_len;
ctx->recv_buf.buf = wreq->tmp_buf;
ctx->recv_buf.size = wreq->tmp_buf_len;
ctx->sec_accept_key = sec_accept_key;
ctx->http_cb = wreq->http_cb;
@ -388,7 +388,10 @@ int websocket_connect(int sock, struct websocket_request *wreq,
/* We will re-use the temp buffer in receive function if needed but
* in order that to work the amount of data in buffer must be set to 0
*/
ctx->tmp_buf_pos = 0;
ctx->recv_buf.count = 0;
/* Init parser FSM */
ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
return fd;
@ -583,7 +586,11 @@ static int websocket_prepare_and_send(struct websocket_context *ctx,
if (HEXDUMP_SENT_PACKETS) {
LOG_HEXDUMP_DBG(header, header_len, "Header");
if ((payload != NULL) && (payload_len > 0)) {
LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
} else {
LOG_DBG("No payload");
}
}
#if defined(CONFIG_NET_TEST)
@ -682,6 +689,7 @@ int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
header[hdr_len++] |= ctx->masking_value >> 8;
header[hdr_len++] |= ctx->masking_value;
if ((payload != NULL) && (payload_len > 0)) {
data_to_send = k_malloc(payload_len);
if (!data_to_send) {
return -ENOMEM;
@ -690,8 +698,8 @@ int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
memcpy(data_to_send, payload, payload_len);
for (i = 0; i < payload_len; i++) {
data_to_send[i] ^=
ctx->masking_value >> (8 * (3 - i % 4));
data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4));
}
}
}
@ -715,94 +723,153 @@ quit:
return ret - hdr_len;
}
static bool websocket_parse_header(uint8_t *buf, size_t buf_len, bool *masked,
uint32_t *mask_value, uint64_t *message_length,
uint32_t *message_type_flag,
size_t *header_len)
static uint32_t websocket_opcode2flag(uint8_t data)
{
uint8_t len_len; /* length of the length field in header */
uint8_t len; /* message length byte */
uint16_t value;
value = sys_get_be16(&buf[0]);
if (value & 0x8000) {
*message_type_flag |= WEBSOCKET_FLAG_FINAL;
}
switch (value & 0x0f00) {
case 0x0100:
*message_type_flag |= WEBSOCKET_FLAG_TEXT;
break;
case 0x0200:
*message_type_flag |= WEBSOCKET_FLAG_BINARY;
break;
case 0x0800:
*message_type_flag |= WEBSOCKET_FLAG_CLOSE;
break;
case 0x0900:
*message_type_flag |= WEBSOCKET_FLAG_PING;
break;
case 0x0A00:
*message_type_flag |= WEBSOCKET_FLAG_PONG;
switch (data & 0x0f) {
case WEBSOCKET_OPCODE_DATA_TEXT:
return WEBSOCKET_FLAG_TEXT;
case WEBSOCKET_OPCODE_DATA_BINARY:
return WEBSOCKET_FLAG_BINARY;
case WEBSOCKET_OPCODE_CLOSE:
return WEBSOCKET_FLAG_CLOSE;
case WEBSOCKET_OPCODE_PING:
return WEBSOCKET_FLAG_PING;
case WEBSOCKET_OPCODE_PONG:
return WEBSOCKET_FLAG_PONG;
default:
break;
}
return 0;
}
len = value & 0x007f;
static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload)
{
int len;
uint8_t data;
size_t parsed_count = 0;
do {
if (parsed_count >= ctx->recv_buf.count) {
return parsed_count;
}
if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) {
data = ctx->recv_buf.buf[parsed_count++];
switch (ctx->parser_state) {
case WEBSOCKET_PARSER_STATE_OPCODE:
ctx->message_type = websocket_opcode2flag(data);
if ((data & 0x80) != 0) {
ctx->message_type |= WEBSOCKET_FLAG_FINAL;
}
ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH;
break;
case WEBSOCKET_PARSER_STATE_LENGTH:
ctx->masked = (data & 0x80) != 0;
len = data & 0x7f;
if (len < 126) {
len_len = 0;
*message_length = len;
} else if (len == 126) {
len_len = 2;
*message_length = sys_get_be16(&buf[2]);
ctx->message_len = len;
if (ctx->masked) {
ctx->masking_value = 0;
ctx->parser_remaining = 4;
ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
} else {
len_len = 8;
*message_length = sys_get_be64(&buf[2]);
ctx->parser_remaining = ctx->message_len;
ctx->parser_state =
(ctx->parser_remaining == 0)
? WEBSOCKET_PARSER_STATE_OPCODE
: WEBSOCKET_PARSER_STATE_PAYLOAD;
}
/* Minimum websocket header is 2 bytes, header length might be
* bigger depending on length field len.
*/
*header_len = MIN_HEADER_LEN + len_len;
if (buf_len >= *header_len) {
if (value & 0x0080) {
*masked = true;
*mask_value = sys_get_be32(&buf[2 + len_len]);
*header_len += 4;
} else {
*masked = false;
ctx->message_len = 0;
ctx->parser_remaining = (len < 127) ? 2 : 8;
ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN;
}
break;
case WEBSOCKET_PARSER_STATE_EXT_LEN:
ctx->parser_remaining--;
ctx->message_len |= (data << (ctx->parser_remaining * 8));
if (ctx->parser_remaining == 0) {
if (ctx->masked) {
ctx->masking_value = 0;
ctx->parser_remaining = 4;
ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
} else {
ctx->parser_remaining = ctx->message_len;
ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
}
}
break;
case WEBSOCKET_PARSER_STATE_MASK:
ctx->parser_remaining--;
ctx->masking_value |= (data << (ctx->parser_remaining * 8));
if (ctx->parser_remaining == 0) {
if (ctx->message_len == 0) {
ctx->parser_remaining = 0;
ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
} else {
ctx->parser_remaining = ctx->message_len;
ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
}
}
break;
default:
return -EFAULT;
}
return true;
#if (LOG_LEVEL >= LOG_LEVEL_DBG)
if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) ||
((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) &&
(ctx->message_len == 0))) {
NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx,
ctx->masked ? "" : "un",
ctx->masked ? ctx->masking_value : 0, ctx->message_type,
(size_t)ctx->message_len);
}
#endif
} else {
size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count;
size_t payload_in_recv_buf =
MIN(remaining_in_recv_buf, ctx->parser_remaining);
size_t free_in_payload_buf = payload->size - payload->count;
size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf);
if (free_in_payload_buf == 0) {
break;
}
return false;
memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count],
ready_to_copy);
parsed_count += ready_to_copy;
payload->count += ready_to_copy;
ctx->parser_remaining -= ready_to_copy;
if (ctx->parser_remaining == 0) {
ctx->parser_remaining = 0;
ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
}
}
} while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE);
return parsed_count;
}
int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
uint32_t *message_type, uint64_t *remaining, int32_t timeout)
{
struct websocket_context *ctx;
size_t header_len = 0;
int recv_len = 0;
size_t can_copy, left;
int ret;
k_timeout_t tout = K_FOREVER;
struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0};
if (timeout != SYS_FOREVER_MS) {
tout = K_MSEC(timeout);
}
#if defined(CONFIG_NET_TEST)
/* Websocket unit test does not use socket layer but feeds
* the data directly here when testing this function.
*/
struct test_data {
uint8_t *input_buf;
size_t input_len;
struct websocket_context *ctx;
};
if ((buf == NULL) || (buf_len == 0)) {
return -EINVAL;
}
#if defined(CONFIG_NET_TEST)
struct test_data *test_data =
UINT_TO_POINTER((unsigned int) ws_sock);
@ -818,178 +885,90 @@ int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
}
#endif /* CONFIG_NET_TEST */
/* If we have not received the websocket header yet, read it first */
if (!ctx->header_received) {
#if defined(CONFIG_NET_TEST)
size_t input_len = MIN(ctx->tmp_buf_len - ctx->tmp_buf_pos,
test_data->input_len);
do {
size_t parsed_count;
memcpy(&ctx->tmp_buf[ctx->tmp_buf_pos], test_data->input_buf,
input_len);
test_data->input_buf += input_len;
if (ctx->recv_buf.count == 0) {
#if defined(CONFIG_NET_TEST)
size_t input_len = MIN(ctx->recv_buf.size,
test_data->input_len - test_data->input_pos);
if (input_len > 0) {
memcpy(ctx->recv_buf.buf,
&test_data->input_buf[test_data->input_pos], input_len);
test_data->input_pos += input_len;
ret = input_len;
} else {
/* emulate timeout */
errno = EAGAIN;
ret = -1;
}
#else
ret = recv(ctx->real_sock, &ctx->tmp_buf[ctx->tmp_buf_pos],
ctx->tmp_buf_len - ctx->tmp_buf_pos,
ret = recv(ctx->real_sock, ctx->recv_buf.buf, ctx->recv_buf.size,
K_TIMEOUT_EQ(tout, K_NO_WAIT) ? MSG_DONTWAIT : 0);
#endif /* CONFIG_NET_TEST */
if (ret < 0) {
return -errno;
ret = -errno;
if ((ret == -EAGAIN) && (payload.count > 0)) {
/* go to unmasking */
break;
}
return ret;
}
if (ret == 0) {
/* Socket closed */
return 0;
return -ENOTCONN;
}
ctx->tmp_buf_pos += ret;
ctx->recv_buf.count = ret;
if (ctx->tmp_buf_pos >= MIN_HEADER_LEN) {
bool masked;
NET_DBG("[%p] Received %d bytes", ctx, ret);
}
/* Now we will be able to figure out what is the
* actual size of the header.
*/
if (websocket_parse_header(&ctx->tmp_buf[0],
ctx->tmp_buf_pos,
&masked,
&ctx->masking_value,
&ctx->message_len,
&ctx->message_type,
&header_len)) {
ctx->masked = masked;
ret = websocket_parse(ctx, &payload);
if (ret < 0) {
return ret;
}
parsed_count = ret;
if (message_type) {
if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) ||
(payload.count >= payload.size)) {
if (remaining != NULL) {
*remaining = ctx->parser_remaining;
}
if (message_type != NULL) {
*message_type = ctx->message_type;
}
} else {
return -EAGAIN;
}
} else {
return -EAGAIN;
}
if (ctx->tmp_buf_pos < header_len) {
return -EAGAIN;
}
/* All of the header is now received, we can read the payload
* data next.
*/
ctx->header_received = true;
if (HEXDUMP_RECV_PACKETS) {
LOG_HEXDUMP_DBG(&ctx->tmp_buf[0], header_len,
"Header");
NET_DBG("[%p] masked %d mask 0x%04x hdr %zd msg %zd",
ctx, ctx->masked,
ctx->masked ? ctx->masking_value : 0,
header_len, (size_t)ctx->message_len);
}
ctx->total_read = 0;
memmove(ctx->tmp_buf, &ctx->tmp_buf[header_len],
ctx->tmp_buf_len - header_len);
ctx->tmp_buf_pos -= header_len;
if (ctx->tmp_buf_pos == 0) {
/* No data after the header, let the caller call
* this function again to get the payload.
*/
return -EAGAIN;
}
NET_DBG("There is %zd bytes of data", ctx->tmp_buf_pos);
}
/* Now read the whole payload or parts of it */
if (ctx->tmp_buf_pos == 0) {
/* Read more data into temp buffer */
#if defined(CONFIG_NET_TEST)
size_t input_len = MIN(ctx->tmp_buf_len, test_data->input_len);
memcpy(ctx->tmp_buf, test_data->input_buf, input_len);
test_data->input_buf += input_len;
ret = input_len;
#else
ret = recv(ctx->real_sock, ctx->tmp_buf, ctx->tmp_buf_len,
K_TIMEOUT_EQ(tout, K_NO_WAIT) ? MSG_DONTWAIT : 0);
#endif /* CONFIG_NET_TEST */
if (ret < 0) {
return -errno;
}
if (ret == 0) {
return 0;
}
ctx->tmp_buf_pos = ret;
}
if (ctx->tmp_buf_pos <= buf_len) {
/* Is there already any data in the temp buffer? If yes,
* just return it to the caller.
*/
can_copy = MIN(ctx->message_len - ctx->total_read,
ctx->tmp_buf_pos);
} else {
/* We have more data in tmp buffer that will fit into
* user buffer.
*/
can_copy = MIN(ctx->message_len - ctx->total_read, buf_len);
}
left = ctx->tmp_buf_pos - can_copy;
NET_ASSERT(ctx->tmp_buf_pos >= can_copy);
memmove(buf, ctx->tmp_buf, can_copy);
recv_len = can_copy;
size_t left = ctx->recv_buf.count - parsed_count;
if (left > 0) {
memmove(ctx->tmp_buf, &ctx->tmp_buf[can_copy], left);
memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left);
}
ctx->recv_buf.count = left;
break;
}
ctx->tmp_buf_pos = left;
ctx->total_read += recv_len;
ctx->recv_buf.count -= parsed_count;
} while (true);
/* Unmask the data */
if (ctx->masked) {
/* As we might have less than 4 received bytes, we must select
* which byte from masking value to take. The mask_shift will
* tell that.
*/
int mask_shift = (ctx->total_read - recv_len) % sizeof(uint32_t);
int i;
uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value;
size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count;
for (i = 0; i < recv_len; i++) {
buf[i] ^= ctx->masking_value >>
(8 * (3 - (i + mask_shift) % 4));
for (size_t i = 0; i < payload.count; i++) {
size_t m = data_buf_offset % 4;
payload.buf[i] ^= mask_as_bytes[3 - m];
data_buf_offset++;
}
}
#if HEXDUMP_RECV_PACKETS
LOG_HEXDUMP_DBG(buf, recv_len, "Payload");
#endif
if (remaining) {
*remaining = ctx->message_len - ctx->total_read;
}
/* Start to read the header again if all the data has been received */
if (ctx->message_len == ctx->total_read) {
ctx->header_received = false;
ctx->message_len = 0;
ctx->message_type = 0;
ctx->total_read = 0;
}
return recv_len;
return payload.count;
}
static int websocket_send(struct websocket_context *ctx, const uint8_t *buf,
@ -1027,9 +1006,13 @@ static int websocket_recv(struct websocket_context *ctx, uint8_t *buf,
ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type,
&remaining, timeout);
if (ret < 0) {
if (ret == -ENOTCONN) {
ret = 0;
} else {
errno = -ret;
return -1;
}
}
NET_DBG("[%p] Received %d bytes", ctx, ret);

View file

@ -23,6 +23,29 @@
/* From RFC 6455 chapter 4.2.2 */
#define WS_MAGIC "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
/**
* Websocket parser states
*/
enum websocket_parser_state {
WEBSOCKET_PARSER_STATE_OPCODE,
WEBSOCKET_PARSER_STATE_LENGTH,
WEBSOCKET_PARSER_STATE_EXT_LEN,
WEBSOCKET_PARSER_STATE_MASK,
WEBSOCKET_PARSER_STATE_PAYLOAD,
};
/**
* Description of external buffers for payload and receiving
*/
struct websocket_buffer {
/* external buffer */
uint8_t *buf;
/* size of external buffer */
size_t size;
/* data length in external buffer */
size_t count;
};
/**
* Websocket connection information
*/
@ -59,19 +82,12 @@ __net_socket struct websocket_context {
int sock;
};
/** Temporary buffers used for HTTP handshakes and Websocket protocol
* headers. User must provide the actual buffer where the headers are
/** Buffer for receiving from TCP socket.
* This buffer used for HTTP handshakes and Websocket packet parser.
* User must provide the actual buffer where the data are
* stored temporarily.
*/
uint8_t *tmp_buf;
/** Temporary buffer length.
*/
size_t tmp_buf_len;
/** Current reading position in the tmp_buf
*/
size_t tmp_buf_pos;
struct websocket_buffer recv_buf;
/** The real TCP socket to use when sending Websocket data to peer.
*/
@ -80,15 +96,18 @@ __net_socket struct websocket_context {
/** Websocket connection masking value */
uint32_t masking_value;
/** Amount of data received. */
uint64_t total_read;
/** Message length */
uint64_t message_len;
/** Message type */
uint32_t message_type;
/** Parser remaining length in current state */
uint64_t parser_remaining;
/** Parser state */
enum websocket_parser_state parser_state;
/** Is the message masked */
uint8_t masked : 1;
@ -100,11 +119,28 @@ __net_socket struct websocket_context {
/** Did we receive all from peer during HTTP handshake */
uint8_t all_received : 1;
/** Header received */
uint8_t header_received : 1;
};
#if defined(CONFIG_NET_TEST)
/**
* Websocket unit test does not use socket layer but feeds
* the data directly here when testing this function.
*/
struct test_data {
/** pointer to data "tx" buffer */
uint8_t *input_buf;
/** "tx" buffer data length */
size_t input_len;
/** "tx" buffer read (recv) position */
size_t input_pos;
/** external test context */
struct websocket_context *ctx;
};
#endif /* CONFIG_NET_TEST */
/**
* @brief Disconnect the Websocket.
*

View file

@ -52,12 +52,6 @@ static uint8_t temp_recv_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
static uint8_t feed_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
static size_t test_msg_len;
struct test_data {
uint8_t *input_buf;
size_t input_len;
struct websocket_context *ctx;
};
static int test_recv_buf(uint8_t *feed_buf, size_t feed_len,
struct websocket_context *ctx,
uint32_t *msg_type, uint64_t *remaining,
@ -69,6 +63,7 @@ static int test_recv_buf(uint8_t *feed_buf, size_t feed_len,
test_data.ctx = ctx;
test_data.input_buf = feed_buf;
test_data.input_len = feed_len;
test_data.input_pos = 0;
ctx_ptr = POINTER_TO_INT(&test_data);
@ -103,6 +98,9 @@ static const unsigned char frame2[] = {
0xe9, 0xdc
};
/* Empty websocket frame, opcode is ping, without mask */
static const unsigned char ping[] = {0x89, 0x00};
#define FRAME1_HDR_SIZE (sizeof(frame1) - (sizeof(frame1_msg) - 1))
static void test_recv(int count)
@ -115,9 +113,9 @@ static void test_recv(int count)
memset(&ctx, 0, sizeof(ctx));
ctx.tmp_buf = temp_recv_buf;
ctx.tmp_buf_len = sizeof(temp_recv_buf);
ctx.tmp_buf_pos = 0;
ctx.recv_buf.buf = temp_recv_buf;
ctx.recv_buf.size = sizeof(temp_recv_buf);
ctx.recv_buf.count = 0;
memcpy(feed_buf, &frame1, sizeof(frame1));
@ -173,6 +171,7 @@ static void test_recv(int count)
frame1_msg, recv_buf);
zassert_equal(remaining, 0, "Msg not empty");
zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
}
static void test_recv_1_byte(void)
@ -225,6 +224,28 @@ static void test_recv_whole_msg(void)
test_recv(sizeof(frame1));
}
static void test_recv_empty_ping(void)
{
struct websocket_context ctx;
int total_read = 0;
uint32_t msg_type = -1;
uint64_t remaining = -1;
memset(&ctx, 0, sizeof(ctx));
ctx.recv_buf.buf = temp_recv_buf;
ctx.recv_buf.size = sizeof(temp_recv_buf);
ctx.recv_buf.count = 0;
memcpy(feed_buf, &ping, sizeof(ping));
total_read = test_recv_buf(&feed_buf[0], sizeof(ping), &ctx, &msg_type, &remaining,
recv_buf, sizeof(recv_buf));
zassert_equal(total_read, 0, "Msg not empty (ret %d)", total_read);
zassert_equal(msg_type & WEBSOCKET_FLAG_PING, WEBSOCKET_FLAG_PING, "Msg is not ping");
}
static void test_recv_2(int count)
{
struct websocket_context ctx;
@ -235,8 +256,8 @@ static void test_recv_2(int count)
memset(&ctx, 0, sizeof(ctx));
ctx.tmp_buf = temp_recv_buf;
ctx.tmp_buf_len = sizeof(temp_recv_buf);
ctx.recv_buf.buf = temp_recv_buf;
ctx.recv_buf.size = sizeof(temp_recv_buf);
memcpy(feed_buf, &frame2, sizeof(frame2));
@ -251,17 +272,19 @@ static void test_recv_2(int count)
frame1_msg, recv_buf);
zassert_equal(remaining, 0, "Msg not empty");
zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
/* Then read again, now we should get EAGAIN as the second message
* header is partially read.
/* Then read again. Take in account that part of second frame
* have read from tx buffer to rx buffer.
*/
ret = test_recv_buf(&feed_buf[sizeof(frame1)], count, &ctx, &msg_type,
&remaining, recv_buf, sizeof(recv_buf));
ret = test_recv_buf(&feed_buf[count], sizeof(frame2) - count, &ctx, &msg_type, &remaining,
recv_buf, sizeof(recv_buf));
zassert_equal(ret, sizeof(frame1_msg) - 1,
"2nd header parse failed (ret %d)", ret);
zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
"Invalid 2nd message, should be '%s' was '%s'", frame1_msg, recv_buf);
zassert_equal(remaining, 0, "Msg not empty");
zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
}
static void test_recv_two_msg(void)
@ -279,15 +302,19 @@ int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg)
memset(&ctx, 0, sizeof(ctx));
ctx.tmp_buf = temp_recv_buf;
ctx.tmp_buf_len = sizeof(temp_recv_buf);
ctx.recv_buf.buf = temp_recv_buf;
ctx.recv_buf.size = sizeof(temp_recv_buf);
/* Read first the header */
ret = test_recv_buf(msg->msg_iov[0].iov_base,
msg->msg_iov[0].iov_len,
&ctx, &msg_type, &remaining,
recv_buf, sizeof(recv_buf));
if (remaining > 0) {
zassert_equal(ret, -EAGAIN, "Msg header not found");
} else {
zassert_equal(ret, 0, "Msg header read error (ret %d)", ret);
}
/* Then the first split if it is enabled */
if (split_msg) {
@ -339,8 +366,8 @@ static void test_send_and_recv_lorem_ipsum(void)
memset(&ctx, 0, sizeof(ctx));
ctx.tmp_buf = temp_recv_buf;
ctx.tmp_buf_len = sizeof(temp_recv_buf);
ctx.recv_buf.buf = temp_recv_buf;
ctx.recv_buf.size = sizeof(temp_recv_buf);
test_msg_len = sizeof(lorem_ipsum) - 1;
@ -360,8 +387,8 @@ static void test_recv_two_large_split_msg(void)
memset(&ctx, 0, sizeof(ctx));
ctx.tmp_buf = temp_recv_buf;
ctx.tmp_buf_len = sizeof(temp_recv_buf);
ctx.recv_buf.buf = temp_recv_buf;
ctx.recv_buf.size = sizeof(temp_recv_buf);
test_msg_len = sizeof(lorem_ipsum) - 1;
@ -373,12 +400,67 @@ static void test_recv_two_large_split_msg(void)
test_msg_len, ret);
}
static void test_send_and_recv_empty_pong(void)
{
static struct websocket_context ctx;
int ret;
memset(&ctx, 0, sizeof(ctx));
ctx.recv_buf.buf = temp_recv_buf;
ctx.recv_buf.size = sizeof(temp_recv_buf);
test_msg_len = 0;
ret = websocket_send_msg(POINTER_TO_INT(&ctx), NULL, test_msg_len, WEBSOCKET_OPCODE_PING,
true, true, SYS_FOREVER_MS);
zassert_equal(ret, test_msg_len, "Should have sent %zd bytes but sent %d instead",
test_msg_len, ret);
}
static void test_recv_in_small_buffer(void)
{
struct websocket_context ctx;
uint32_t msg_type = -1;
uint64_t remaining = -1;
int total_read = 0;
int ret;
const size_t frame1_msg_size = sizeof(frame1_msg) - 1;
const size_t recv_buf_size = 7;
memset(&ctx, 0, sizeof(ctx));
ctx.recv_buf.buf = temp_recv_buf;
ctx.recv_buf.size = sizeof(temp_recv_buf);
memcpy(feed_buf, &frame1, sizeof(frame1));
/* Receive first part of message */
ret = test_recv_buf(&feed_buf[0], sizeof(frame1), &ctx, &msg_type, &remaining, recv_buf,
recv_buf_size);
zassert_equal(ret, recv_buf_size, "Should have received %zd bytes but ret %d",
recv_buf_size, ret);
total_read += ret;
/* Receive second part of message */
ret = test_recv_buf(&feed_buf[sizeof(frame1)], 0, &ctx, &msg_type, &remaining,
&recv_buf[recv_buf_size], recv_buf_size);
zassert_equal(ret, frame1_msg_size - recv_buf_size,
"Should have received %zd bytes but ret %d", frame1_msg_size - recv_buf_size,
ret);
total_read += ret;
/* Check receiving whole message */
zassert_equal(total_read, frame1_msg_size, "Received not whole message");
zassert_mem_equal(recv_buf, frame1_msg, frame1_msg_size,
"Invalid message, should be '%s' was '%s'", frame1_msg, recv_buf);
}
void test_main(void)
{
k_thread_system_pool_assign(k_current_get());
ztest_test_suite(websocket,
ztest_unit_test(test_recv_1_byte),
ztest_test_suite(websocket, ztest_unit_test(test_recv_1_byte),
ztest_unit_test(test_recv_2_byte),
ztest_unit_test(test_recv_3_byte),
ztest_unit_test(test_recv_6_byte),
@ -388,10 +470,12 @@ void test_main(void)
ztest_unit_test(test_recv_10_byte),
ztest_unit_test(test_recv_12_byte),
ztest_unit_test(test_recv_whole_msg),
ztest_unit_test(test_recv_empty_ping),
ztest_unit_test(test_recv_two_msg),
ztest_unit_test(test_send_and_recv_lorem_ipsum),
ztest_unit_test(test_recv_two_large_split_msg)
);
ztest_unit_test(test_recv_two_large_split_msg),
ztest_unit_test(test_send_and_recv_empty_pong),
ztest_unit_test(test_recv_in_small_buffer));
ztest_run_test_suite(websocket);
}