diff --git a/libs/bits/bit_array.go b/libs/bits/bit_array.go index ad4efe31540..18daf50119c 100644 --- a/libs/bits/bit_array.go +++ b/libs/bits/bit_array.go @@ -264,9 +264,17 @@ func (bA *BitArray) PickRandom() (int, bool) { func (bA *BitArray) getNumTrueIndices() int { count := 0 numElems := len(bA.Elems) - for i := 0; i < numElems; i++ { + // handle all elements except the last one + for i := 0; i < numElems-1; i++ { count += bits.OnesCount64(bA.Elems[i]) } + // handle last element + numFinalBits := bA.Bits - (numElems-1)*64 + for i := 0; i < numFinalBits; i++ { + if (bA.Elems[numElems-1] & (uint64(1) << uint64(i))) > 0 { + count++ + } + } return count } diff --git a/libs/bits/bit_array_test.go b/libs/bits/bit_array_test.go index dce587ca00f..ad6a3a9f00e 100644 --- a/libs/bits/bit_array_test.go +++ b/libs/bits/bit_array_test.go @@ -173,6 +173,8 @@ func TestGetNumTrueIndices(t *testing.T) { require.NoError(t, err) result := bitArr.getNumTrueIndices() require.Equal(t, tc.ExpectedResult, result, "for input %s, expected %d, got %d", tc.Input, tc.ExpectedResult, result) + result = bitArr.Not().getNumTrueIndices() + require.Equal(t, bitArr.Bits-result, bitArr.getNumTrueIndices()) } }