lib: bitarray: add method to find nth bit set in region

This is part one of several changes to add more methods to the bitarray api
so that it can be used for broader usecases, specifically LoRaWAN forward
error correction.

Signed-off-by: Lucas Romero <luqasn@gmail.com>
This commit is contained in:
Lucas Romero 2024-05-16 21:07:09 +02:00 committed by Anas Nashif
commit 3f50197a90
3 changed files with 204 additions and 0 deletions

View file

@ -184,6 +184,28 @@ int sys_bitarray_alloc(sys_bitarray_t *bitarray, size_t num_bits,
*/
int sys_bitarray_xor(sys_bitarray_t *dst, sys_bitarray_t *other, size_t num_bits, size_t offset);
/**
* Find nth bit set in region
*
* This counts the number of bits set (@p count) in a
* region (@p offset, @p num_bits) and returns the index (@p found_at)
* of the nth set bit, if it exists, as long with a zero return value.
*
* If it does not exist, @p found_at is not updated and the method returns
*
* @param[in] bitarray Bitarray struct
* @param[in] n Nth bit set to look for
* @param[in] num_bits Number of bits to check, must be larger than 0
* @param[in] offset Starting bit position
* @param[out] found_at Index of the nth bit set, if found
*
* @retval 0 Operation successful
* @retval 1 Nth bit set was not found in region
* @retval -EINVAL Invalid argument (e.g. out-of-bounds access, trying to count 0 bits, etc.)
*/
int sys_bitarray_find_nth_set(sys_bitarray_t *bitarray, size_t n, size_t num_bits, size_t offset,
size_t *found_at);
/**
* Count bits set in a bit array region
*

View file

@ -559,6 +559,81 @@ out:
return ret;
}
int sys_bitarray_find_nth_set(sys_bitarray_t *bitarray, size_t n, size_t num_bits, size_t offset,
size_t *found_at)
{
k_spinlock_key_t key;
size_t count, idx;
uint32_t mask;
struct bundle_data bd;
int ret;
__ASSERT_NO_MSG(bitarray != NULL);
__ASSERT_NO_MSG(bitarray->num_bits > 0);
key = k_spin_lock(&bitarray->lock);
if (n == 0 || num_bits == 0 || offset + num_bits > bitarray->num_bits) {
ret = -EINVAL;
goto out;
}
ret = 1;
mask = 0;
setup_bundle_data(bitarray, &bd, offset, num_bits);
count = POPCOUNT(bitarray->bundles[bd.sidx] & bd.smask);
/* If we already found more bits set than n, we found the target bundle */
if (count >= n) {
idx = bd.sidx;
mask = bd.smask;
goto found;
}
/* Keep looking if there are more bundles */
if (bd.sidx != bd.eidx) {
/* We are now only looking for the remaining bits */
n -= count;
/* First bundle was already checked, keep looking in middle (complete)
* bundles.
*/
for (idx = bd.sidx + 1; idx < bd.eidx; idx++) {
count = POPCOUNT(bitarray->bundles[idx]);
if (count >= n) {
mask = ~(mask & 0);
goto found;
}
n -= count;
}
/* Continue searching in last bundle */
count = POPCOUNT(bitarray->bundles[bd.eidx] & bd.emask);
if (count >= n) {
idx = bd.eidx;
mask = bd.emask;
goto found;
}
}
goto out;
found:
/* The bit we are looking for must be in the current bundle idx.
* Find out the exact index of the bit.
*/
for (int j = 0; j <= bundle_bitness(bitarray) - 1; j++) {
if (bitarray->bundles[idx] & mask & BIT(j)) {
if (--n <= 0) {
*found_at = idx * bundle_bitness(bitarray) + j;
ret = 0;
break;
}
}
}
out:
k_spin_unlock(&bitarray->lock, key);
return ret;
}
int sys_bitarray_free(sys_bitarray_t *bitarray, size_t num_bits,
size_t offset)
{

View file

@ -760,6 +760,113 @@ ZTEST(bitarray, test_bitarray_xor)
zassert_equal(ret, -EINVAL, "sys_bitarray_xor() returned unexpected value: %d", ret);
}
ZTEST(bitarray, test_bitarray_find_nth_set)
{
int ret;
size_t found_at;
/* Bitarrays have embedded spinlocks and can't on the stack. */
if (IS_ENABLED(CONFIG_KERNEL_COHERENCE)) {
ztest_test_skip();
}
SYS_BITARRAY_DEFINE(ba, 128);
printk("Testing bit array nth bit set finding spanning single bundle\n");
/* Pre-populate the bits */
ba.bundles[0] = 0x80000001;
ba.bundles[1] = 0x80000001;
ba.bundles[2] = 0x80000001;
ba.bundles[3] = 0x80000001;
ret = sys_bitarray_find_nth_set(&ba, 1, 1, 0, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 0, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
ret = sys_bitarray_find_nth_set(&ba, 1, 32, 0, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 0, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
ret = sys_bitarray_find_nth_set(&ba, 2, 32, 0, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 31, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
ret = sys_bitarray_find_nth_set(&ba, 1, 31, 1, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 31, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
ret = sys_bitarray_find_nth_set(&ba, 2, 31, 1, &found_at);
zassert_equal(ret, 1, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
printk("Testing bit array nth bit set finding spanning multiple bundles\n");
ret = sys_bitarray_find_nth_set(&ba, 1, 128, 0, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 0, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
ret = sys_bitarray_find_nth_set(&ba, 8, 128, 0, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 127, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
ret = sys_bitarray_find_nth_set(&ba, 8, 128, 1, &found_at);
zassert_equal(ret, -EINVAL, "sys_bitarray_find_nth_set() returned unexpected value: %d",
ret);
ret = sys_bitarray_find_nth_set(&ba, 7, 127, 1, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 127, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
ret = sys_bitarray_find_nth_set(&ba, 7, 127, 0, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 96, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
ret = sys_bitarray_find_nth_set(&ba, 6, 127, 1, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 96, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
ret = sys_bitarray_find_nth_set(&ba, 6, 127, 1, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 96, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
ret = sys_bitarray_find_nth_set(&ba, 1, 32, 48, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 63, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
ret = sys_bitarray_find_nth_set(&ba, 2, 32, 48, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
zassert_equal(found_at, 64, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
found_at);
printk("Testing error cases\n");
ret = sys_bitarray_find_nth_set(&ba, 1, 128, 0, &found_at);
zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
ret = sys_bitarray_find_nth_set(&ba, 1, 128, 1, &found_at);
zassert_equal(ret, -EINVAL, "sys_bitarray_find_nth_set() returned unexpected value: %d",
ret);
ret = sys_bitarray_find_nth_set(&ba, 1, 129, 0, &found_at);
zassert_equal(ret, -EINVAL, "sys_bitarray_find_nth_set() returned unexpected value: %d",
ret);
ret = sys_bitarray_find_nth_set(&ba, 0, 128, 0, &found_at);
zassert_equal(ret, -EINVAL, "sys_bitarray_find_nth_set() returned unexpected value: %d",
ret);
}
ZTEST(bitarray, test_bitarray_region_set_clear)
{
int ret;