net: mqtt: Fix packet length decryption

The standard allows up to 4 bytes of packet length data, while current
implementation parsed up to 5 bytes.

Add additional unit test, which verifies that error is reported in case
of invalid packet length.

Signed-off-by: Robert Lubos <robert.lubos@nordicsemi.no>
This commit is contained in:
Robert Lubos 2020-03-26 10:22:50 +01:00 committed by Jukka Rissanen
commit 1ad165a62d
2 changed files with 74 additions and 2 deletions

View file

@ -158,14 +158,14 @@ static int unpack_data(u32_t length, struct buf_ctx *buf,
* @retval -EINVAL if the length decoding would use more that 4 bytes.
* @retval -EAGAIN if the buffer would be exceeded during the read.
*/
int packet_length_decode(struct buf_ctx *buf, u32_t *length)
static int packet_length_decode(struct buf_ctx *buf, u32_t *length)
{
u8_t shift = 0U;
u8_t bytes = 0U;
*length = 0U;
do {
if (bytes > MQTT_MAX_LENGTH_BYTES) {
if (bytes >= MQTT_MAX_LENGTH_BYTES) {
return -EINVAL;
}
@ -179,6 +179,10 @@ int packet_length_decode(struct buf_ctx *buf, u32_t *length)
bytes++;
} while ((*(buf->cur++) & MQTT_LENGTH_CONTINUATION_BIT) != 0U);
if (*length > MQTT_MAX_PAYLOAD_SIZE) {
return -EINVAL;
}
MQTT_TRC("length:0x%08x", *length);
return 0;

View file

@ -170,6 +170,24 @@ static int eval_msg_unsuback(struct mqtt_test *mqtt_test);
*/
static int eval_msg_disconnect(struct mqtt_test *mqtt_test);
/**
* @brief eval_max_pkt_len Evaluate header with maximum allowed packet
* length.
* @param [in] mqtt_test MQTT test structure
* @return TC_PASS on success
* @return TC_FAIL on error
*/
static int eval_max_pkt_len(struct mqtt_test *mqtt_test);
/**
* @brief eval_corrupted_pkt_len Evaluate header exceeding maximum
* allowed packet length.
* @param [in] mqtt_test MQTT test structure
* @return TC_PASS on success
* @return TC_FAIL on error
*/
static int eval_corrupted_pkt_len(struct mqtt_test *mqtt_test);
/**
* @brief eval_buffers Evaluate if two given buffers are equal
* @param [in] buf Input buffer 1, mostly used as the 'computed'
@ -182,6 +200,7 @@ static int eval_msg_disconnect(struct mqtt_test *mqtt_test);
static int eval_buffers(const struct buf_ctx *buf,
const u8_t *expected, u16_t len);
/**
* @brief print_array Prints the array 'a' of 'size' elements
* @param a The array
@ -513,6 +532,19 @@ static ZTEST_DMEM
u8_t unsuback1[] = {0xb0, 0x02, 0x00, 0x01};
static ZTEST_DMEM struct mqtt_unsuback_param msg_unsuback1 = {.message_id = 1};
static ZTEST_DMEM
u8_t max_pkt_len[] = {0x30, 0xff, 0xff, 0xff, 0x7f};
static ZTEST_DMEM struct buf_ctx max_pkt_len_buf = {
.cur = max_pkt_len, .end = max_pkt_len + sizeof(max_pkt_len)
};
static ZTEST_DMEM
u8_t corrupted_pkt_len[] = {0x30, 0xff, 0xff, 0xff, 0xff, 0x01};
static ZTEST_DMEM struct buf_ctx corrupted_pkt_len_buf = {
.cur = corrupted_pkt_len,
.end = corrupted_pkt_len + sizeof(corrupted_pkt_len)
};
static ZTEST_DMEM
struct mqtt_test mqtt_tests[] = {
@ -608,6 +640,12 @@ struct mqtt_test mqtt_tests[] = {
.ctx = &msg_unsuback1, .eval_fcn = eval_msg_unsuback,
.expected = unsuback1, .expected_len = sizeof(unsuback1)},
{.test_name = "Maximum packet length",
.ctx = &max_pkt_len_buf, .eval_fcn = eval_max_pkt_len},
{.test_name = "Corrupted packet length",
.ctx = &corrupted_pkt_len_buf, .eval_fcn = eval_corrupted_pkt_len},
/* last test case, do not remove it */
{.test_name = NULL}
};
@ -1018,6 +1056,36 @@ static int eval_msg_unsuback(struct mqtt_test *mqtt_test)
return TC_PASS;
}
static int eval_max_pkt_len(struct mqtt_test *mqtt_test)
{
struct buf_ctx *buf = (struct buf_ctx *)mqtt_test->ctx;
int rc;
u8_t flags;
u32_t length;
rc = fixed_header_decode(buf, &flags, &length);
zassert_equal(rc, 0, "fixed_header_decode failed");
zassert_equal(length, MQTT_MAX_PAYLOAD_SIZE,
"Invalid packet length decoded");
return TC_PASS;
}
static int eval_corrupted_pkt_len(struct mqtt_test *mqtt_test)
{
struct buf_ctx *buf = (struct buf_ctx *)mqtt_test->ctx;
int rc;
u8_t flags;
u32_t length;
rc = fixed_header_decode(buf, &flags, &length);
zassert_equal(rc, -EINVAL, "fixed_header_decode should fail");
return TC_PASS;
}
void test_mqtt_packet(void)
{
TC_START("MQTT Library test");