diff --git a/include/zephyr/net/websocket.h b/include/zephyr/net/websocket.h index 5d0301c0316..1a0853891a0 100644 --- a/include/zephyr/net/websocket.h +++ b/include/zephyr/net/websocket.h @@ -166,7 +166,10 @@ int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len, * The value is in milliseconds. Value SYS_FOREVER_MS means to wait * forever. * - * @return <0 if error, >=0 amount of bytes received + * @retval >=0 amount of bytes received. + * @retval -EAGAIN on timeout. + * @retval -ENOTCONN on socket close. + * @retval -errno other negative errno value in case of failure. */ int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len, uint32_t *message_type, uint64_t *remaining, diff --git a/samples/net/sockets/websocket_client/src/main.c b/samples/net/sockets/websocket_client/src/main.c index b8f84c9a430..6162d025c98 100644 --- a/samples/net/sockets/websocket_client/src/main.c +++ b/samples/net/sockets/websocket_client/src/main.c @@ -201,7 +201,7 @@ static void recv_data_wso_api(int sock, size_t amount, uint8_t *buf, &message_type, &remaining, 0); - if (ret <= 0) { + if (ret < 0) { if (ret == -EAGAIN) { k_sleep(K_MSEC(50)); continue; diff --git a/subsys/net/lib/mqtt/mqtt_transport_websocket.c b/subsys/net/lib/mqtt/mqtt_transport_websocket.c index ed33da0aa82..a4f8b477d1b 100644 --- a/subsys/net/lib/mqtt/mqtt_transport_websocket.c +++ b/subsys/net/lib/mqtt/mqtt_transport_websocket.c @@ -150,15 +150,18 @@ int mqtt_client_websocket_read(struct mqtt_client *client, uint8_t *data, ret = websocket_recv_msg(client->transport.websocket.sock, data, buflen, &message_type, NULL, timeout); - if (ret > 0 && message_type > 0) { + if (ret >= 0 && message_type > 0) { if (message_type & WEBSOCKET_FLAG_CLOSE) { return 0; } - if (!(message_type & WEBSOCKET_FLAG_BINARY)) { + if ((ret == 0) || !(message_type & WEBSOCKET_FLAG_BINARY)) { return -EAGAIN; } } + if (ret == -ENOTCONN) { + ret = 0; + } return ret; } diff --git a/subsys/net/lib/websocket/websocket.c b/subsys/net/lib/websocket/websocket.c index 2c7d5c0e253..490518f0f3a 100644 --- a/subsys/net/lib/websocket/websocket.c +++ b/subsys/net/lib/websocket/websocket.c @@ -274,8 +274,8 @@ int websocket_connect(int sock, struct websocket_request *wreq, } ctx->real_sock = sock; - ctx->tmp_buf = wreq->tmp_buf; - ctx->tmp_buf_len = wreq->tmp_buf_len; + ctx->recv_buf.buf = wreq->tmp_buf; + ctx->recv_buf.size = wreq->tmp_buf_len; ctx->sec_accept_key = sec_accept_key; ctx->http_cb = wreq->http_cb; @@ -388,7 +388,10 @@ int websocket_connect(int sock, struct websocket_request *wreq, /* We will re-use the temp buffer in receive function if needed but * in order that to work the amount of data in buffer must be set to 0 */ - ctx->tmp_buf_pos = 0; + ctx->recv_buf.count = 0; + + /* Init parser FSM */ + ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE; return fd; @@ -583,7 +586,11 @@ static int websocket_prepare_and_send(struct websocket_context *ctx, if (HEXDUMP_SENT_PACKETS) { LOG_HEXDUMP_DBG(header, header_len, "Header"); - LOG_HEXDUMP_DBG(payload, payload_len, "Payload"); + if ((payload != NULL) && (payload_len > 0)) { + LOG_HEXDUMP_DBG(payload, payload_len, "Payload"); + } else { + LOG_DBG("No payload"); + } } #if defined(CONFIG_NET_TEST) @@ -682,16 +689,17 @@ int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len, header[hdr_len++] |= ctx->masking_value >> 8; header[hdr_len++] |= ctx->masking_value; - data_to_send = k_malloc(payload_len); - if (!data_to_send) { - return -ENOMEM; - } + if ((payload != NULL) && (payload_len > 0)) { + data_to_send = k_malloc(payload_len); + if (!data_to_send) { + return -ENOMEM; + } - memcpy(data_to_send, payload, payload_len); + memcpy(data_to_send, payload, payload_len); - for (i = 0; i < payload_len; i++) { - data_to_send[i] ^= - ctx->masking_value >> (8 * (3 - i % 4)); + for (i = 0; i < payload_len; i++) { + data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4)); + } } } @@ -715,94 +723,153 @@ quit: return ret - hdr_len; } -static bool websocket_parse_header(uint8_t *buf, size_t buf_len, bool *masked, - uint32_t *mask_value, uint64_t *message_length, - uint32_t *message_type_flag, - size_t *header_len) +static uint32_t websocket_opcode2flag(uint8_t data) { - uint8_t len_len; /* length of the length field in header */ - uint8_t len; /* message length byte */ - uint16_t value; - - value = sys_get_be16(&buf[0]); - if (value & 0x8000) { - *message_type_flag |= WEBSOCKET_FLAG_FINAL; - } - - switch (value & 0x0f00) { - case 0x0100: - *message_type_flag |= WEBSOCKET_FLAG_TEXT; - break; - case 0x0200: - *message_type_flag |= WEBSOCKET_FLAG_BINARY; - break; - case 0x0800: - *message_type_flag |= WEBSOCKET_FLAG_CLOSE; - break; - case 0x0900: - *message_type_flag |= WEBSOCKET_FLAG_PING; - break; - case 0x0A00: - *message_type_flag |= WEBSOCKET_FLAG_PONG; + switch (data & 0x0f) { + case WEBSOCKET_OPCODE_DATA_TEXT: + return WEBSOCKET_FLAG_TEXT; + case WEBSOCKET_OPCODE_DATA_BINARY: + return WEBSOCKET_FLAG_BINARY; + case WEBSOCKET_OPCODE_CLOSE: + return WEBSOCKET_FLAG_CLOSE; + case WEBSOCKET_OPCODE_PING: + return WEBSOCKET_FLAG_PING; + case WEBSOCKET_OPCODE_PONG: + return WEBSOCKET_FLAG_PONG; + default: break; } + return 0; +} - len = value & 0x007f; - if (len < 126) { - len_len = 0; - *message_length = len; - } else if (len == 126) { - len_len = 2; - *message_length = sys_get_be16(&buf[2]); - } else { - len_len = 8; - *message_length = sys_get_be64(&buf[2]); - } +static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload) +{ + int len; + uint8_t data; + size_t parsed_count = 0; - /* Minimum websocket header is 2 bytes, header length might be - * bigger depending on length field len. - */ - *header_len = MIN_HEADER_LEN + len_len; + do { + if (parsed_count >= ctx->recv_buf.count) { + return parsed_count; + } + if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) { + data = ctx->recv_buf.buf[parsed_count++]; - if (buf_len >= *header_len) { - if (value & 0x0080) { - *masked = true; - *mask_value = sys_get_be32(&buf[2 + len_len]); - *header_len += 4; + switch (ctx->parser_state) { + case WEBSOCKET_PARSER_STATE_OPCODE: + ctx->message_type = websocket_opcode2flag(data); + if ((data & 0x80) != 0) { + ctx->message_type |= WEBSOCKET_FLAG_FINAL; + } + ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH; + break; + case WEBSOCKET_PARSER_STATE_LENGTH: + ctx->masked = (data & 0x80) != 0; + len = data & 0x7f; + if (len < 126) { + ctx->message_len = len; + if (ctx->masked) { + ctx->masking_value = 0; + ctx->parser_remaining = 4; + ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK; + } else { + ctx->parser_remaining = ctx->message_len; + ctx->parser_state = + (ctx->parser_remaining == 0) + ? WEBSOCKET_PARSER_STATE_OPCODE + : WEBSOCKET_PARSER_STATE_PAYLOAD; + } + } else { + ctx->message_len = 0; + ctx->parser_remaining = (len < 127) ? 2 : 8; + ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN; + } + break; + case WEBSOCKET_PARSER_STATE_EXT_LEN: + ctx->parser_remaining--; + ctx->message_len |= (data << (ctx->parser_remaining * 8)); + if (ctx->parser_remaining == 0) { + if (ctx->masked) { + ctx->masking_value = 0; + ctx->parser_remaining = 4; + ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK; + } else { + ctx->parser_remaining = ctx->message_len; + ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD; + } + } + break; + case WEBSOCKET_PARSER_STATE_MASK: + ctx->parser_remaining--; + ctx->masking_value |= (data << (ctx->parser_remaining * 8)); + if (ctx->parser_remaining == 0) { + if (ctx->message_len == 0) { + ctx->parser_remaining = 0; + ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE; + } else { + ctx->parser_remaining = ctx->message_len; + ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD; + } + } + break; + default: + return -EFAULT; + } + +#if (LOG_LEVEL >= LOG_LEVEL_DBG) + if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) || + ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) && + (ctx->message_len == 0))) { + NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx, + ctx->masked ? "" : "un", + ctx->masked ? ctx->masking_value : 0, ctx->message_type, + (size_t)ctx->message_len); + } +#endif } else { - *masked = false; + size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count; + size_t payload_in_recv_buf = + MIN(remaining_in_recv_buf, ctx->parser_remaining); + size_t free_in_payload_buf = payload->size - payload->count; + size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf); + + if (free_in_payload_buf == 0) { + break; + } + + memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count], + ready_to_copy); + parsed_count += ready_to_copy; + payload->count += ready_to_copy; + ctx->parser_remaining -= ready_to_copy; + if (ctx->parser_remaining == 0) { + ctx->parser_remaining = 0; + ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE; + } } - return true; - } + } while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE); - return false; + return parsed_count; } int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len, uint32_t *message_type, uint64_t *remaining, int32_t timeout) { struct websocket_context *ctx; - size_t header_len = 0; - int recv_len = 0; - size_t can_copy, left; int ret; k_timeout_t tout = K_FOREVER; + struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0}; if (timeout != SYS_FOREVER_MS) { tout = K_MSEC(timeout); } -#if defined(CONFIG_NET_TEST) - /* Websocket unit test does not use socket layer but feeds - * the data directly here when testing this function. - */ - struct test_data { - uint8_t *input_buf; - size_t input_len; - struct websocket_context *ctx; - }; + if ((buf == NULL) || (buf_len == 0)) { + return -EINVAL; + } +#if defined(CONFIG_NET_TEST) struct test_data *test_data = UINT_TO_POINTER((unsigned int) ws_sock); @@ -818,178 +885,90 @@ int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len, } #endif /* CONFIG_NET_TEST */ - /* If we have not received the websocket header yet, read it first */ - if (!ctx->header_received) { + do { + size_t parsed_count; + + if (ctx->recv_buf.count == 0) { #if defined(CONFIG_NET_TEST) - size_t input_len = MIN(ctx->tmp_buf_len - ctx->tmp_buf_pos, - test_data->input_len); + size_t input_len = MIN(ctx->recv_buf.size, + test_data->input_len - test_data->input_pos); - memcpy(&ctx->tmp_buf[ctx->tmp_buf_pos], test_data->input_buf, - input_len); - test_data->input_buf += input_len; - ret = input_len; -#else - ret = recv(ctx->real_sock, &ctx->tmp_buf[ctx->tmp_buf_pos], - ctx->tmp_buf_len - ctx->tmp_buf_pos, - K_TIMEOUT_EQ(tout, K_NO_WAIT) ? MSG_DONTWAIT : 0); -#endif /* CONFIG_NET_TEST */ - - if (ret < 0) { - return -errno; - } - - if (ret == 0) { - /* Socket closed */ - return 0; - } - - ctx->tmp_buf_pos += ret; - - if (ctx->tmp_buf_pos >= MIN_HEADER_LEN) { - bool masked; - - /* Now we will be able to figure out what is the - * actual size of the header. - */ - if (websocket_parse_header(&ctx->tmp_buf[0], - ctx->tmp_buf_pos, - &masked, - &ctx->masking_value, - &ctx->message_len, - &ctx->message_type, - &header_len)) { - ctx->masked = masked; - - if (message_type) { - *message_type = ctx->message_type; - } + if (input_len > 0) { + memcpy(ctx->recv_buf.buf, + &test_data->input_buf[test_data->input_pos], input_len); + test_data->input_pos += input_len; + ret = input_len; } else { - return -EAGAIN; + /* emulate timeout */ + errno = EAGAIN; + ret = -1; } - } else { - return -EAGAIN; - } - - if (ctx->tmp_buf_pos < header_len) { - return -EAGAIN; - } - - /* All of the header is now received, we can read the payload - * data next. - */ - ctx->header_received = true; - - if (HEXDUMP_RECV_PACKETS) { - LOG_HEXDUMP_DBG(&ctx->tmp_buf[0], header_len, - "Header"); - NET_DBG("[%p] masked %d mask 0x%04x hdr %zd msg %zd", - ctx, ctx->masked, - ctx->masked ? ctx->masking_value : 0, - header_len, (size_t)ctx->message_len); - } - - ctx->total_read = 0; - - memmove(ctx->tmp_buf, &ctx->tmp_buf[header_len], - ctx->tmp_buf_len - header_len); - ctx->tmp_buf_pos -= header_len; - - if (ctx->tmp_buf_pos == 0) { - /* No data after the header, let the caller call - * this function again to get the payload. - */ - return -EAGAIN; - } - - NET_DBG("There is %zd bytes of data", ctx->tmp_buf_pos); - } - - /* Now read the whole payload or parts of it */ - - if (ctx->tmp_buf_pos == 0) { - /* Read more data into temp buffer */ -#if defined(CONFIG_NET_TEST) - size_t input_len = MIN(ctx->tmp_buf_len, test_data->input_len); - - memcpy(ctx->tmp_buf, test_data->input_buf, input_len); - test_data->input_buf += input_len; - - ret = input_len; #else - ret = recv(ctx->real_sock, ctx->tmp_buf, ctx->tmp_buf_len, - K_TIMEOUT_EQ(tout, K_NO_WAIT) ? MSG_DONTWAIT : 0); + ret = recv(ctx->real_sock, ctx->recv_buf.buf, ctx->recv_buf.size, + K_TIMEOUT_EQ(tout, K_NO_WAIT) ? MSG_DONTWAIT : 0); #endif /* CONFIG_NET_TEST */ + if (ret < 0) { + ret = -errno; + if ((ret == -EAGAIN) && (payload.count > 0)) { + /* go to unmasking */ + break; + } + return ret; + } + + if (ret == 0) { + /* Socket closed */ + return -ENOTCONN; + } + + ctx->recv_buf.count = ret; + + NET_DBG("[%p] Received %d bytes", ctx, ret); + } + + ret = websocket_parse(ctx, &payload); if (ret < 0) { - return -errno; + return ret; + } + parsed_count = ret; + + if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) || + (payload.count >= payload.size)) { + if (remaining != NULL) { + *remaining = ctx->parser_remaining; + } + if (message_type != NULL) { + *message_type = ctx->message_type; + } + + size_t left = ctx->recv_buf.count - parsed_count; + + if (left > 0) { + memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left); + } + ctx->recv_buf.count = left; + break; } - if (ret == 0) { - return 0; - } + ctx->recv_buf.count -= parsed_count; - ctx->tmp_buf_pos = ret; - } - - if (ctx->tmp_buf_pos <= buf_len) { - /* Is there already any data in the temp buffer? If yes, - * just return it to the caller. - */ - can_copy = MIN(ctx->message_len - ctx->total_read, - ctx->tmp_buf_pos); - } else { - /* We have more data in tmp buffer that will fit into - * user buffer. - */ - can_copy = MIN(ctx->message_len - ctx->total_read, buf_len); - } - - left = ctx->tmp_buf_pos - can_copy; - - NET_ASSERT(ctx->tmp_buf_pos >= can_copy); - - memmove(buf, ctx->tmp_buf, can_copy); - recv_len = can_copy; - - if (left > 0) { - memmove(ctx->tmp_buf, &ctx->tmp_buf[can_copy], left); - } - - ctx->tmp_buf_pos = left; - ctx->total_read += recv_len; + } while (true); /* Unmask the data */ if (ctx->masked) { - /* As we might have less than 4 received bytes, we must select - * which byte from masking value to take. The mask_shift will - * tell that. - */ - int mask_shift = (ctx->total_read - recv_len) % sizeof(uint32_t); - int i; + uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value; + size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count; - for (i = 0; i < recv_len; i++) { - buf[i] ^= ctx->masking_value >> - (8 * (3 - (i + mask_shift) % 4)); + for (size_t i = 0; i < payload.count; i++) { + size_t m = data_buf_offset % 4; + + payload.buf[i] ^= mask_as_bytes[3 - m]; + data_buf_offset++; } } -#if HEXDUMP_RECV_PACKETS - LOG_HEXDUMP_DBG(buf, recv_len, "Payload"); -#endif - - if (remaining) { - *remaining = ctx->message_len - ctx->total_read; - } - - /* Start to read the header again if all the data has been received */ - if (ctx->message_len == ctx->total_read) { - ctx->header_received = false; - ctx->message_len = 0; - ctx->message_type = 0; - ctx->total_read = 0; - } - - return recv_len; + return payload.count; } static int websocket_send(struct websocket_context *ctx, const uint8_t *buf, @@ -1027,8 +1006,12 @@ static int websocket_recv(struct websocket_context *ctx, uint8_t *buf, ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type, &remaining, timeout); if (ret < 0) { - errno = -ret; - return -1; + if (ret == -ENOTCONN) { + ret = 0; + } else { + errno = -ret; + return -1; + } } NET_DBG("[%p] Received %d bytes", ctx, ret); diff --git a/subsys/net/lib/websocket/websocket_internal.h b/subsys/net/lib/websocket/websocket_internal.h index adcc32bc921..b830bac4b23 100644 --- a/subsys/net/lib/websocket/websocket_internal.h +++ b/subsys/net/lib/websocket/websocket_internal.h @@ -23,6 +23,29 @@ /* From RFC 6455 chapter 4.2.2 */ #define WS_MAGIC "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +/** + * Websocket parser states + */ +enum websocket_parser_state { + WEBSOCKET_PARSER_STATE_OPCODE, + WEBSOCKET_PARSER_STATE_LENGTH, + WEBSOCKET_PARSER_STATE_EXT_LEN, + WEBSOCKET_PARSER_STATE_MASK, + WEBSOCKET_PARSER_STATE_PAYLOAD, +}; + +/** + * Description of external buffers for payload and receiving + */ +struct websocket_buffer { + /* external buffer */ + uint8_t *buf; + /* size of external buffer */ + size_t size; + /* data length in external buffer */ + size_t count; +}; + /** * Websocket connection information */ @@ -59,19 +82,12 @@ __net_socket struct websocket_context { int sock; }; - /** Temporary buffers used for HTTP handshakes and Websocket protocol - * headers. User must provide the actual buffer where the headers are + /** Buffer for receiving from TCP socket. + * This buffer used for HTTP handshakes and Websocket packet parser. + * User must provide the actual buffer where the data are * stored temporarily. */ - uint8_t *tmp_buf; - - /** Temporary buffer length. - */ - size_t tmp_buf_len; - - /** Current reading position in the tmp_buf - */ - size_t tmp_buf_pos; + struct websocket_buffer recv_buf; /** The real TCP socket to use when sending Websocket data to peer. */ @@ -80,15 +96,18 @@ __net_socket struct websocket_context { /** Websocket connection masking value */ uint32_t masking_value; - /** Amount of data received. */ - uint64_t total_read; - /** Message length */ uint64_t message_len; /** Message type */ uint32_t message_type; + /** Parser remaining length in current state */ + uint64_t parser_remaining; + + /** Parser state */ + enum websocket_parser_state parser_state; + /** Is the message masked */ uint8_t masked : 1; @@ -100,11 +119,28 @@ __net_socket struct websocket_context { /** Did we receive all from peer during HTTP handshake */ uint8_t all_received : 1; - - /** Header received */ - uint8_t header_received : 1; }; +#if defined(CONFIG_NET_TEST) +/** + * Websocket unit test does not use socket layer but feeds + * the data directly here when testing this function. + */ +struct test_data { + /** pointer to data "tx" buffer */ + uint8_t *input_buf; + + /** "tx" buffer data length */ + size_t input_len; + + /** "tx" buffer read (recv) position */ + size_t input_pos; + + /** external test context */ + struct websocket_context *ctx; +}; +#endif /* CONFIG_NET_TEST */ + /** * @brief Disconnect the Websocket. * diff --git a/tests/net/socket/websocket/src/main.c b/tests/net/socket/websocket/src/main.c index 39d135104f9..27415701e88 100644 --- a/tests/net/socket/websocket/src/main.c +++ b/tests/net/socket/websocket/src/main.c @@ -52,12 +52,6 @@ static uint8_t temp_recv_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE]; static uint8_t feed_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE]; static size_t test_msg_len; -struct test_data { - uint8_t *input_buf; - size_t input_len; - struct websocket_context *ctx; -}; - static int test_recv_buf(uint8_t *feed_buf, size_t feed_len, struct websocket_context *ctx, uint32_t *msg_type, uint64_t *remaining, @@ -69,6 +63,7 @@ static int test_recv_buf(uint8_t *feed_buf, size_t feed_len, test_data.ctx = ctx; test_data.input_buf = feed_buf; test_data.input_len = feed_len; + test_data.input_pos = 0; ctx_ptr = POINTER_TO_INT(&test_data); @@ -103,6 +98,9 @@ static const unsigned char frame2[] = { 0xe9, 0xdc }; +/* Empty websocket frame, opcode is ping, without mask */ +static const unsigned char ping[] = {0x89, 0x00}; + #define FRAME1_HDR_SIZE (sizeof(frame1) - (sizeof(frame1_msg) - 1)) static void test_recv(int count) @@ -115,9 +113,9 @@ static void test_recv(int count) memset(&ctx, 0, sizeof(ctx)); - ctx.tmp_buf = temp_recv_buf; - ctx.tmp_buf_len = sizeof(temp_recv_buf); - ctx.tmp_buf_pos = 0; + ctx.recv_buf.buf = temp_recv_buf; + ctx.recv_buf.size = sizeof(temp_recv_buf); + ctx.recv_buf.count = 0; memcpy(feed_buf, &frame1, sizeof(frame1)); @@ -173,6 +171,7 @@ static void test_recv(int count) frame1_msg, recv_buf); zassert_equal(remaining, 0, "Msg not empty"); + zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text"); } static void test_recv_1_byte(void) @@ -225,6 +224,28 @@ static void test_recv_whole_msg(void) test_recv(sizeof(frame1)); } +static void test_recv_empty_ping(void) +{ + struct websocket_context ctx; + int total_read = 0; + uint32_t msg_type = -1; + uint64_t remaining = -1; + + memset(&ctx, 0, sizeof(ctx)); + + ctx.recv_buf.buf = temp_recv_buf; + ctx.recv_buf.size = sizeof(temp_recv_buf); + ctx.recv_buf.count = 0; + + memcpy(feed_buf, &ping, sizeof(ping)); + + total_read = test_recv_buf(&feed_buf[0], sizeof(ping), &ctx, &msg_type, &remaining, + recv_buf, sizeof(recv_buf)); + + zassert_equal(total_read, 0, "Msg not empty (ret %d)", total_read); + zassert_equal(msg_type & WEBSOCKET_FLAG_PING, WEBSOCKET_FLAG_PING, "Msg is not ping"); +} + static void test_recv_2(int count) { struct websocket_context ctx; @@ -235,8 +256,8 @@ static void test_recv_2(int count) memset(&ctx, 0, sizeof(ctx)); - ctx.tmp_buf = temp_recv_buf; - ctx.tmp_buf_len = sizeof(temp_recv_buf); + ctx.recv_buf.buf = temp_recv_buf; + ctx.recv_buf.size = sizeof(temp_recv_buf); memcpy(feed_buf, &frame2, sizeof(frame2)); @@ -251,17 +272,19 @@ static void test_recv_2(int count) frame1_msg, recv_buf); zassert_equal(remaining, 0, "Msg not empty"); + zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text"); - /* Then read again, now we should get EAGAIN as the second message - * header is partially read. + /* Then read again. Take in account that part of second frame + * have read from tx buffer to rx buffer. */ - ret = test_recv_buf(&feed_buf[sizeof(frame1)], count, &ctx, &msg_type, - &remaining, recv_buf, sizeof(recv_buf)); + ret = test_recv_buf(&feed_buf[count], sizeof(frame2) - count, &ctx, &msg_type, &remaining, + recv_buf, sizeof(recv_buf)); - zassert_equal(ret, sizeof(frame1_msg) - 1, - "2nd header parse failed (ret %d)", ret); + zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1, + "Invalid 2nd message, should be '%s' was '%s'", frame1_msg, recv_buf); zassert_equal(remaining, 0, "Msg not empty"); + zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text"); } static void test_recv_two_msg(void) @@ -279,15 +302,19 @@ int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg) memset(&ctx, 0, sizeof(ctx)); - ctx.tmp_buf = temp_recv_buf; - ctx.tmp_buf_len = sizeof(temp_recv_buf); + ctx.recv_buf.buf = temp_recv_buf; + ctx.recv_buf.size = sizeof(temp_recv_buf); /* Read first the header */ ret = test_recv_buf(msg->msg_iov[0].iov_base, msg->msg_iov[0].iov_len, &ctx, &msg_type, &remaining, recv_buf, sizeof(recv_buf)); - zassert_equal(ret, -EAGAIN, "Msg header not found"); + if (remaining > 0) { + zassert_equal(ret, -EAGAIN, "Msg header not found"); + } else { + zassert_equal(ret, 0, "Msg header read error (ret %d)", ret); + } /* Then the first split if it is enabled */ if (split_msg) { @@ -339,8 +366,8 @@ static void test_send_and_recv_lorem_ipsum(void) memset(&ctx, 0, sizeof(ctx)); - ctx.tmp_buf = temp_recv_buf; - ctx.tmp_buf_len = sizeof(temp_recv_buf); + ctx.recv_buf.buf = temp_recv_buf; + ctx.recv_buf.size = sizeof(temp_recv_buf); test_msg_len = sizeof(lorem_ipsum) - 1; @@ -360,8 +387,8 @@ static void test_recv_two_large_split_msg(void) memset(&ctx, 0, sizeof(ctx)); - ctx.tmp_buf = temp_recv_buf; - ctx.tmp_buf_len = sizeof(temp_recv_buf); + ctx.recv_buf.buf = temp_recv_buf; + ctx.recv_buf.size = sizeof(temp_recv_buf); test_msg_len = sizeof(lorem_ipsum) - 1; @@ -373,12 +400,67 @@ static void test_recv_two_large_split_msg(void) test_msg_len, ret); } +static void test_send_and_recv_empty_pong(void) +{ + static struct websocket_context ctx; + int ret; + + memset(&ctx, 0, sizeof(ctx)); + + ctx.recv_buf.buf = temp_recv_buf; + ctx.recv_buf.size = sizeof(temp_recv_buf); + + test_msg_len = 0; + + ret = websocket_send_msg(POINTER_TO_INT(&ctx), NULL, test_msg_len, WEBSOCKET_OPCODE_PING, + true, true, SYS_FOREVER_MS); + zassert_equal(ret, test_msg_len, "Should have sent %zd bytes but sent %d instead", + test_msg_len, ret); +} + +static void test_recv_in_small_buffer(void) +{ + struct websocket_context ctx; + uint32_t msg_type = -1; + uint64_t remaining = -1; + int total_read = 0; + int ret; + const size_t frame1_msg_size = sizeof(frame1_msg) - 1; + const size_t recv_buf_size = 7; + + memset(&ctx, 0, sizeof(ctx)); + + ctx.recv_buf.buf = temp_recv_buf; + ctx.recv_buf.size = sizeof(temp_recv_buf); + + memcpy(feed_buf, &frame1, sizeof(frame1)); + + /* Receive first part of message */ + ret = test_recv_buf(&feed_buf[0], sizeof(frame1), &ctx, &msg_type, &remaining, recv_buf, + recv_buf_size); + zassert_equal(ret, recv_buf_size, "Should have received %zd bytes but ret %d", + recv_buf_size, ret); + total_read += ret; + + /* Receive second part of message */ + ret = test_recv_buf(&feed_buf[sizeof(frame1)], 0, &ctx, &msg_type, &remaining, + &recv_buf[recv_buf_size], recv_buf_size); + zassert_equal(ret, frame1_msg_size - recv_buf_size, + "Should have received %zd bytes but ret %d", frame1_msg_size - recv_buf_size, + ret); + total_read += ret; + + /* Check receiving whole message */ + zassert_equal(total_read, frame1_msg_size, "Received not whole message"); + zassert_mem_equal(recv_buf, frame1_msg, frame1_msg_size, + "Invalid message, should be '%s' was '%s'", frame1_msg, recv_buf); +} + void test_main(void) { k_thread_system_pool_assign(k_current_get()); - ztest_test_suite(websocket, - ztest_unit_test(test_recv_1_byte), + ztest_test_suite(websocket, ztest_unit_test(test_recv_1_byte), ztest_unit_test(test_recv_2_byte), ztest_unit_test(test_recv_3_byte), ztest_unit_test(test_recv_6_byte), @@ -388,10 +470,12 @@ void test_main(void) ztest_unit_test(test_recv_10_byte), ztest_unit_test(test_recv_12_byte), ztest_unit_test(test_recv_whole_msg), + ztest_unit_test(test_recv_empty_ping), ztest_unit_test(test_recv_two_msg), ztest_unit_test(test_send_and_recv_lorem_ipsum), - ztest_unit_test(test_recv_two_large_split_msg) - ); + ztest_unit_test(test_recv_two_large_split_msg), + ztest_unit_test(test_send_and_recv_empty_pong), + ztest_unit_test(test_recv_in_small_buffer)); ztest_run_test_suite(websocket); }