Skip to content

Commit bb2f4a7

Browse files
rlubosnashif
authored andcommitted
net: mqtt: Fix packet length decryption
The standard allows up to 4 bytes of packet length data, while current implementation parsed up to 5 bytes. Add additional unit test, which verifies that error is reported in case of invalid packet length. Signed-off-by: Robert Lubos <[email protected]>
1 parent a506141 commit bb2f4a7

File tree

2 files changed

+74
-2
lines changed

2 files changed

+74
-2
lines changed

subsys/net/lib/mqtt/mqtt_decoder.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,14 @@ static int unpack_data(u32_t length, struct buf_ctx *buf,
158158
* @retval -EINVAL if the length decoding would use more that 4 bytes.
159159
* @retval -EAGAIN if the buffer would be exceeded during the read.
160160
*/
161-
int packet_length_decode(struct buf_ctx *buf, u32_t *length)
161+
static int packet_length_decode(struct buf_ctx *buf, u32_t *length)
162162
{
163163
u8_t shift = 0U;
164164
u8_t bytes = 0U;
165165

166166
*length = 0U;
167167
do {
168-
if (bytes > MQTT_MAX_LENGTH_BYTES) {
168+
if (bytes >= MQTT_MAX_LENGTH_BYTES) {
169169
return -EINVAL;
170170
}
171171

@@ -179,6 +179,10 @@ int packet_length_decode(struct buf_ctx *buf, u32_t *length)
179179
bytes++;
180180
} while ((*(buf->cur++) & MQTT_LENGTH_CONTINUATION_BIT) != 0U);
181181

182+
if (*length > MQTT_MAX_PAYLOAD_SIZE) {
183+
return -EINVAL;
184+
}
185+
182186
MQTT_TRC("length:0x%08x", *length);
183187

184188
return 0;

tests/net/lib/mqtt_packet/src/mqtt_packet.c

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,24 @@ static int eval_msg_unsuback(struct mqtt_test *mqtt_test);
190190
*/
191191
static int eval_msg_disconnect(struct mqtt_test *mqtt_test);
192192

193+
/**
194+
* @brief eval_max_pkt_len Evaluate header with maximum allowed packet
195+
* length.
196+
* @param [in] mqtt_test MQTT test structure
197+
* @return TC_PASS on success
198+
* @return TC_FAIL on error
199+
*/
200+
static int eval_max_pkt_len(struct mqtt_test *mqtt_test);
201+
202+
/**
203+
* @brief eval_corrupted_pkt_len Evaluate header exceeding maximum
204+
* allowed packet length.
205+
* @param [in] mqtt_test MQTT test structure
206+
* @return TC_PASS on success
207+
* @return TC_FAIL on error
208+
*/
209+
static int eval_corrupted_pkt_len(struct mqtt_test *mqtt_test);
210+
193211
/**
194212
* @brief eval_buffers Evaluate if two given buffers are equal
195213
* @param [in] buf Input buffer 1, mostly used as the 'computed'
@@ -202,6 +220,7 @@ static int eval_msg_disconnect(struct mqtt_test *mqtt_test);
202220
static int eval_buffers(const struct buf_ctx *buf,
203221
const u8_t *expected, u16_t len);
204222

223+
205224
/**
206225
* @brief print_array Prints the array 'a' of 'size' elements
207226
* @param a The array
@@ -543,6 +562,19 @@ static ZTEST_DMEM
543562
u8_t unsuback1[] = {0xb0, 0x02, 0x00, 0x01};
544563
static ZTEST_DMEM struct mqtt_unsuback_param msg_unsuback1 = {.message_id = 1};
545564

565+
static ZTEST_DMEM
566+
u8_t max_pkt_len[] = {0x30, 0xff, 0xff, 0xff, 0x7f};
567+
static ZTEST_DMEM struct buf_ctx max_pkt_len_buf = {
568+
.cur = max_pkt_len, .end = max_pkt_len + sizeof(max_pkt_len)
569+
};
570+
571+
static ZTEST_DMEM
572+
u8_t corrupted_pkt_len[] = {0x30, 0xff, 0xff, 0xff, 0xff, 0x01};
573+
static ZTEST_DMEM struct buf_ctx corrupted_pkt_len_buf = {
574+
.cur = corrupted_pkt_len,
575+
.end = corrupted_pkt_len + sizeof(corrupted_pkt_len)
576+
};
577+
546578
static ZTEST_DMEM
547579
struct mqtt_test mqtt_tests[] = {
548580

@@ -638,6 +670,12 @@ struct mqtt_test mqtt_tests[] = {
638670
.ctx = &msg_unsuback1, .eval_fcn = eval_msg_unsuback,
639671
.expected = unsuback1, .expected_len = sizeof(unsuback1)},
640672

673+
{.test_name = "Maximum packet length",
674+
.ctx = &max_pkt_len_buf, .eval_fcn = eval_max_pkt_len},
675+
676+
{.test_name = "Corrupted packet length",
677+
.ctx = &corrupted_pkt_len_buf, .eval_fcn = eval_corrupted_pkt_len},
678+
641679
/* last test case, do not remove it */
642680
{.test_name = NULL}
643681
};
@@ -1048,6 +1086,36 @@ static int eval_msg_unsuback(struct mqtt_test *mqtt_test)
10481086
return TC_PASS;
10491087
}
10501088

1089+
static int eval_max_pkt_len(struct mqtt_test *mqtt_test)
1090+
{
1091+
struct buf_ctx *buf = (struct buf_ctx *)mqtt_test->ctx;
1092+
int rc;
1093+
u8_t flags;
1094+
u32_t length;
1095+
1096+
rc = fixed_header_decode(buf, &flags, &length);
1097+
1098+
zassert_equal(rc, 0, "fixed_header_decode failed");
1099+
zassert_equal(length, MQTT_MAX_PAYLOAD_SIZE,
1100+
"Invalid packet length decoded");
1101+
1102+
return TC_PASS;
1103+
}
1104+
1105+
static int eval_corrupted_pkt_len(struct mqtt_test *mqtt_test)
1106+
{
1107+
struct buf_ctx *buf = (struct buf_ctx *)mqtt_test->ctx;
1108+
int rc;
1109+
u8_t flags;
1110+
u32_t length;
1111+
1112+
rc = fixed_header_decode(buf, &flags, &length);
1113+
1114+
zassert_equal(rc, -EINVAL, "fixed_header_decode should fail");
1115+
1116+
return TC_PASS;
1117+
}
1118+
10511119
void test_mqtt_packet(void)
10521120
{
10531121
TC_START("MQTT Library test");

0 commit comments

Comments
 (0)