diff --git a/include/net/net_pkt.h b/include/net/net_pkt.h index 9ddc0557ee4..e0692496f97 100644 --- a/include/net/net_pkt.h +++ b/include/net/net_pkt.h @@ -1705,6 +1705,22 @@ size_t net_pkt_available_payload_buffer(struct net_pkt *pkt, */ void net_pkt_trim_buffer(struct net_pkt *pkt); +/** + * @brief Remove @a length bytes from tail of packet + * + * @details This function does not take packet cursor into account. It is a + * helper to remove unneeded bytes from tail of packet (like appended + * CRC). It takes care of buffer deallocation if removed bytes span + * whole buffer(s). + * + * @param pkt Network packet + * @param length Number of bytes to be removed + * + * @retval 0 On success. + * @retval -EINVAL If packet length is shorter than @a length. + */ +int net_pkt_remove_tail(struct net_pkt *pkt, size_t length); + /** * @brief Initialize net_pkt cursor * diff --git a/subsys/net/ip/net_pkt.c b/subsys/net/ip/net_pkt.c index 14ec127773d..e295201b8d3 100644 --- a/subsys/net/ip/net_pkt.c +++ b/subsys/net/ip/net_pkt.c @@ -1089,6 +1089,36 @@ void net_pkt_trim_buffer(struct net_pkt *pkt) } } +int net_pkt_remove_tail(struct net_pkt *pkt, size_t length) +{ + struct net_buf *buf = pkt->buffer; + size_t remaining_len = net_pkt_get_len(pkt); + + if (remaining_len < length) { + return -EINVAL; + } + + remaining_len -= length; + + while (buf) { + if (buf->len >= remaining_len) { + buf->len = remaining_len; + + if (buf->frags) { + net_pkt_frag_unref(buf->frags); + buf->frags = NULL; + } + + break; + } + + remaining_len -= buf->len; + buf = buf->frags; + } + + return 0; +} + #if NET_LOG_LEVEL >= LOG_LEVEL_DBG int net_pkt_alloc_buffer_debug(struct net_pkt *pkt, size_t size, diff --git a/tests/net/net_pkt/src/main.c b/tests/net/net_pkt/src/main.c index 39c54cb934b..43f0fa67533 100644 --- a/tests/net/net_pkt/src/main.c +++ b/tests/net/net_pkt/src/main.c @@ -974,6 +974,90 @@ static void test_net_pkt_get_contiguous_len(void) net_pkt_unref(pkt); } +void test_net_pkt_remove_tail(void) +{ + struct net_pkt *pkt; + int err; + + pkt = net_pkt_alloc_with_buffer(NULL, + CONFIG_NET_BUF_DATA_SIZE * 2 + 3, + AF_UNSPEC, 0, K_NO_WAIT); + zassert_true(pkt != NULL, "Pkt not allocated"); + + net_pkt_cursor_init(pkt); + net_pkt_write(pkt, small_buffer, CONFIG_NET_BUF_DATA_SIZE * 2 + 3); + + zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE * 2 + 3, + "Pkt length is invalid"); + zassert_equal(pkt->frags->frags->frags->len, 3, + "3rd buffer length is invalid"); + + /* Remove some bytes from last buffer */ + err = net_pkt_remove_tail(pkt, 2); + zassert_equal(err, 0, "Failed to remove tail"); + + zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE * 2 + 1, + "Pkt length is invalid"); + zassert_not_equal(pkt->frags->frags->frags, NULL, + "3rd buffer was removed"); + zassert_equal(pkt->frags->frags->frags->len, 1, + "3rd buffer length is invalid"); + + /* Remove last byte from last buffer */ + err = net_pkt_remove_tail(pkt, 1); + zassert_equal(err, 0, "Failed to remove tail"); + + zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE * 2, + "Pkt length is invalid"); + zassert_equal(pkt->frags->frags->frags, NULL, + "3rd buffer was not removed"); + zassert_equal(pkt->frags->frags->len, CONFIG_NET_BUF_DATA_SIZE, + "2nd buffer length is invalid"); + + /* Remove 2nd buffer and one byte from 1st buffer */ + err = net_pkt_remove_tail(pkt, CONFIG_NET_BUF_DATA_SIZE + 1); + zassert_equal(err, 0, "Failed to remove tail"); + + zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE - 1, + "Pkt length is invalid"); + zassert_equal(pkt->frags->frags, NULL, + "2nd buffer was not removed"); + zassert_equal(pkt->frags->len, CONFIG_NET_BUF_DATA_SIZE - 1, + "1st buffer length is invalid"); + + net_pkt_unref(pkt); + + pkt = net_pkt_rx_alloc_with_buffer(NULL, + CONFIG_NET_BUF_DATA_SIZE * 2 + 3, + AF_UNSPEC, 0, K_NO_WAIT); + + net_pkt_cursor_init(pkt); + net_pkt_write(pkt, small_buffer, CONFIG_NET_BUF_DATA_SIZE * 2 + 3); + + zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE * 2 + 3, + "Pkt length is invalid"); + zassert_equal(pkt->frags->frags->frags->len, 3, + "3rd buffer length is invalid"); + + /* Remove bytes spanning 3 buffers */ + err = net_pkt_remove_tail(pkt, CONFIG_NET_BUF_DATA_SIZE + 5); + zassert_equal(err, 0, "Failed to remove tail"); + + zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE - 2, + "Pkt length is invalid"); + zassert_equal(pkt->frags->frags, NULL, + "2nd buffer was not removed"); + zassert_equal(pkt->frags->len, CONFIG_NET_BUF_DATA_SIZE - 2, + "1st buffer length is invalid"); + + /* Try to remove more bytes than packet has */ + err = net_pkt_remove_tail(pkt, CONFIG_NET_BUF_DATA_SIZE); + zassert_equal(err, -EINVAL, + "Removing more bytes than available should fail"); + + net_pkt_unref(pkt); +} + void test_main(void) { eth_if = net_if_get_default(); @@ -989,7 +1073,8 @@ void test_main(void) ztest_unit_test(test_net_pkt_clone), ztest_unit_test(test_net_pkt_headroom), ztest_unit_test(test_net_pkt_headroom_copy), - ztest_unit_test(test_net_pkt_get_contiguous_len) + ztest_unit_test(test_net_pkt_get_contiguous_len), + ztest_unit_test(test_net_pkt_remove_tail) ); ztest_run_test_suite(net_pkt_tests);