net: mqtt: Fix packet length decryption
The standard allows up to 4 bytes of packet length data, while current implementation parsed up to 5 bytes. Add additional unit test, which verifies that error is reported in case of invalid packet length. Signed-off-by: Robert Lubos <robert.lubos@nordicsemi.no>
This commit is contained in:
parent
70e2836f58
commit
1ad165a62d
2 changed files with 74 additions and 2 deletions
|
@ -158,14 +158,14 @@ static int unpack_data(u32_t length, struct buf_ctx *buf,
|
|||
* @retval -EINVAL if the length decoding would use more that 4 bytes.
|
||||
* @retval -EAGAIN if the buffer would be exceeded during the read.
|
||||
*/
|
||||
int packet_length_decode(struct buf_ctx *buf, u32_t *length)
|
||||
static int packet_length_decode(struct buf_ctx *buf, u32_t *length)
|
||||
{
|
||||
u8_t shift = 0U;
|
||||
u8_t bytes = 0U;
|
||||
|
||||
*length = 0U;
|
||||
do {
|
||||
if (bytes > MQTT_MAX_LENGTH_BYTES) {
|
||||
if (bytes >= MQTT_MAX_LENGTH_BYTES) {
|
||||
return -EINVAL;
|
||||
}
|
||||
|
||||
|
@ -179,6 +179,10 @@ int packet_length_decode(struct buf_ctx *buf, u32_t *length)
|
|||
bytes++;
|
||||
} while ((*(buf->cur++) & MQTT_LENGTH_CONTINUATION_BIT) != 0U);
|
||||
|
||||
if (*length > MQTT_MAX_PAYLOAD_SIZE) {
|
||||
return -EINVAL;
|
||||
}
|
||||
|
||||
MQTT_TRC("length:0x%08x", *length);
|
||||
|
||||
return 0;
|
||||
|
|
|
@ -170,6 +170,24 @@ static int eval_msg_unsuback(struct mqtt_test *mqtt_test);
|
|||
*/
|
||||
static int eval_msg_disconnect(struct mqtt_test *mqtt_test);
|
||||
|
||||
/**
|
||||
* @brief eval_max_pkt_len Evaluate header with maximum allowed packet
|
||||
* length.
|
||||
* @param [in] mqtt_test MQTT test structure
|
||||
* @return TC_PASS on success
|
||||
* @return TC_FAIL on error
|
||||
*/
|
||||
static int eval_max_pkt_len(struct mqtt_test *mqtt_test);
|
||||
|
||||
/**
|
||||
* @brief eval_corrupted_pkt_len Evaluate header exceeding maximum
|
||||
* allowed packet length.
|
||||
* @param [in] mqtt_test MQTT test structure
|
||||
* @return TC_PASS on success
|
||||
* @return TC_FAIL on error
|
||||
*/
|
||||
static int eval_corrupted_pkt_len(struct mqtt_test *mqtt_test);
|
||||
|
||||
/**
|
||||
* @brief eval_buffers Evaluate if two given buffers are equal
|
||||
* @param [in] buf Input buffer 1, mostly used as the 'computed'
|
||||
|
@ -182,6 +200,7 @@ static int eval_msg_disconnect(struct mqtt_test *mqtt_test);
|
|||
static int eval_buffers(const struct buf_ctx *buf,
|
||||
const u8_t *expected, u16_t len);
|
||||
|
||||
|
||||
/**
|
||||
* @brief print_array Prints the array 'a' of 'size' elements
|
||||
* @param a The array
|
||||
|
@ -513,6 +532,19 @@ static ZTEST_DMEM
|
|||
u8_t unsuback1[] = {0xb0, 0x02, 0x00, 0x01};
|
||||
static ZTEST_DMEM struct mqtt_unsuback_param msg_unsuback1 = {.message_id = 1};
|
||||
|
||||
static ZTEST_DMEM
|
||||
u8_t max_pkt_len[] = {0x30, 0xff, 0xff, 0xff, 0x7f};
|
||||
static ZTEST_DMEM struct buf_ctx max_pkt_len_buf = {
|
||||
.cur = max_pkt_len, .end = max_pkt_len + sizeof(max_pkt_len)
|
||||
};
|
||||
|
||||
static ZTEST_DMEM
|
||||
u8_t corrupted_pkt_len[] = {0x30, 0xff, 0xff, 0xff, 0xff, 0x01};
|
||||
static ZTEST_DMEM struct buf_ctx corrupted_pkt_len_buf = {
|
||||
.cur = corrupted_pkt_len,
|
||||
.end = corrupted_pkt_len + sizeof(corrupted_pkt_len)
|
||||
};
|
||||
|
||||
static ZTEST_DMEM
|
||||
struct mqtt_test mqtt_tests[] = {
|
||||
|
||||
|
@ -608,6 +640,12 @@ struct mqtt_test mqtt_tests[] = {
|
|||
.ctx = &msg_unsuback1, .eval_fcn = eval_msg_unsuback,
|
||||
.expected = unsuback1, .expected_len = sizeof(unsuback1)},
|
||||
|
||||
{.test_name = "Maximum packet length",
|
||||
.ctx = &max_pkt_len_buf, .eval_fcn = eval_max_pkt_len},
|
||||
|
||||
{.test_name = "Corrupted packet length",
|
||||
.ctx = &corrupted_pkt_len_buf, .eval_fcn = eval_corrupted_pkt_len},
|
||||
|
||||
/* last test case, do not remove it */
|
||||
{.test_name = NULL}
|
||||
};
|
||||
|
@ -1018,6 +1056,36 @@ static int eval_msg_unsuback(struct mqtt_test *mqtt_test)
|
|||
return TC_PASS;
|
||||
}
|
||||
|
||||
static int eval_max_pkt_len(struct mqtt_test *mqtt_test)
|
||||
{
|
||||
struct buf_ctx *buf = (struct buf_ctx *)mqtt_test->ctx;
|
||||
int rc;
|
||||
u8_t flags;
|
||||
u32_t length;
|
||||
|
||||
rc = fixed_header_decode(buf, &flags, &length);
|
||||
|
||||
zassert_equal(rc, 0, "fixed_header_decode failed");
|
||||
zassert_equal(length, MQTT_MAX_PAYLOAD_SIZE,
|
||||
"Invalid packet length decoded");
|
||||
|
||||
return TC_PASS;
|
||||
}
|
||||
|
||||
static int eval_corrupted_pkt_len(struct mqtt_test *mqtt_test)
|
||||
{
|
||||
struct buf_ctx *buf = (struct buf_ctx *)mqtt_test->ctx;
|
||||
int rc;
|
||||
u8_t flags;
|
||||
u32_t length;
|
||||
|
||||
rc = fixed_header_decode(buf, &flags, &length);
|
||||
|
||||
zassert_equal(rc, -EINVAL, "fixed_header_decode should fail");
|
||||
|
||||
return TC_PASS;
|
||||
}
|
||||
|
||||
void test_mqtt_packet(void)
|
||||
{
|
||||
TC_START("MQTT Library test");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue