diff --git a/subsys/net/lib/mqtt/mqtt_decoder.c b/subsys/net/lib/mqtt/mqtt_decoder.c index 36fbc8aef8e..7469a762bac 100644 --- a/subsys/net/lib/mqtt/mqtt_decoder.c +++ b/subsys/net/lib/mqtt/mqtt_decoder.c @@ -158,14 +158,14 @@ static int unpack_data(u32_t length, struct buf_ctx *buf, * @retval -EINVAL if the length decoding would use more that 4 bytes. * @retval -EAGAIN if the buffer would be exceeded during the read. */ -int packet_length_decode(struct buf_ctx *buf, u32_t *length) +static int packet_length_decode(struct buf_ctx *buf, u32_t *length) { u8_t shift = 0U; u8_t bytes = 0U; *length = 0U; do { - if (bytes > MQTT_MAX_LENGTH_BYTES) { + if (bytes >= MQTT_MAX_LENGTH_BYTES) { return -EINVAL; } @@ -179,6 +179,10 @@ int packet_length_decode(struct buf_ctx *buf, u32_t *length) bytes++; } while ((*(buf->cur++) & MQTT_LENGTH_CONTINUATION_BIT) != 0U); + if (*length > MQTT_MAX_PAYLOAD_SIZE) { + return -EINVAL; + } + MQTT_TRC("length:0x%08x", *length); return 0; diff --git a/tests/net/lib/mqtt_packet/src/mqtt_packet.c b/tests/net/lib/mqtt_packet/src/mqtt_packet.c index d6c3761ed85..89bb31c388c 100644 --- a/tests/net/lib/mqtt_packet/src/mqtt_packet.c +++ b/tests/net/lib/mqtt_packet/src/mqtt_packet.c @@ -170,6 +170,24 @@ static int eval_msg_unsuback(struct mqtt_test *mqtt_test); */ static int eval_msg_disconnect(struct mqtt_test *mqtt_test); +/** + * @brief eval_max_pkt_len Evaluate header with maximum allowed packet + * length. + * @param [in] mqtt_test MQTT test structure + * @return TC_PASS on success + * @return TC_FAIL on error + */ +static int eval_max_pkt_len(struct mqtt_test *mqtt_test); + +/** + * @brief eval_corrupted_pkt_len Evaluate header exceeding maximum + * allowed packet length. + * @param [in] mqtt_test MQTT test structure + * @return TC_PASS on success + * @return TC_FAIL on error + */ +static int eval_corrupted_pkt_len(struct mqtt_test *mqtt_test); + /** * @brief eval_buffers Evaluate if two given buffers are equal * @param [in] buf Input buffer 1, mostly used as the 'computed' @@ -182,6 +200,7 @@ static int eval_msg_disconnect(struct mqtt_test *mqtt_test); static int eval_buffers(const struct buf_ctx *buf, const u8_t *expected, u16_t len); + /** * @brief print_array Prints the array 'a' of 'size' elements * @param a The array @@ -513,6 +532,19 @@ static ZTEST_DMEM u8_t unsuback1[] = {0xb0, 0x02, 0x00, 0x01}; static ZTEST_DMEM struct mqtt_unsuback_param msg_unsuback1 = {.message_id = 1}; +static ZTEST_DMEM +u8_t max_pkt_len[] = {0x30, 0xff, 0xff, 0xff, 0x7f}; +static ZTEST_DMEM struct buf_ctx max_pkt_len_buf = { + .cur = max_pkt_len, .end = max_pkt_len + sizeof(max_pkt_len) +}; + +static ZTEST_DMEM +u8_t corrupted_pkt_len[] = {0x30, 0xff, 0xff, 0xff, 0xff, 0x01}; +static ZTEST_DMEM struct buf_ctx corrupted_pkt_len_buf = { + .cur = corrupted_pkt_len, + .end = corrupted_pkt_len + sizeof(corrupted_pkt_len) +}; + static ZTEST_DMEM struct mqtt_test mqtt_tests[] = { @@ -608,6 +640,12 @@ struct mqtt_test mqtt_tests[] = { .ctx = &msg_unsuback1, .eval_fcn = eval_msg_unsuback, .expected = unsuback1, .expected_len = sizeof(unsuback1)}, + {.test_name = "Maximum packet length", + .ctx = &max_pkt_len_buf, .eval_fcn = eval_max_pkt_len}, + + {.test_name = "Corrupted packet length", + .ctx = &corrupted_pkt_len_buf, .eval_fcn = eval_corrupted_pkt_len}, + /* last test case, do not remove it */ {.test_name = NULL} }; @@ -1018,6 +1056,36 @@ static int eval_msg_unsuback(struct mqtt_test *mqtt_test) return TC_PASS; } +static int eval_max_pkt_len(struct mqtt_test *mqtt_test) +{ + struct buf_ctx *buf = (struct buf_ctx *)mqtt_test->ctx; + int rc; + u8_t flags; + u32_t length; + + rc = fixed_header_decode(buf, &flags, &length); + + zassert_equal(rc, 0, "fixed_header_decode failed"); + zassert_equal(length, MQTT_MAX_PAYLOAD_SIZE, + "Invalid packet length decoded"); + + return TC_PASS; +} + +static int eval_corrupted_pkt_len(struct mqtt_test *mqtt_test) +{ + struct buf_ctx *buf = (struct buf_ctx *)mqtt_test->ctx; + int rc; + u8_t flags; + u32_t length; + + rc = fixed_header_decode(buf, &flags, &length); + + zassert_equal(rc, -EINVAL, "fixed_header_decode should fail"); + + return TC_PASS; +} + void test_mqtt_packet(void) { TC_START("MQTT Library test");