Improve performance of set_bits by avoiding to set individual bits #6288

Merged
merged 48 commits into from
Sep 15, 2024
Changes from 47 commits
6e8c864
bench
kazuyukitanimura Aug 22, 2024
a81ba56
fix: Optimize set_bits
kazuyukitanimura Aug 22, 2024
57634f2
clippy
kazuyukitanimura Aug 23, 2024
32ff203
clippyj
kazuyukitanimura Aug 23, 2024
f94f312
miri
kazuyukitanimura Aug 23, 2024
e9cd77a
fix: Optimize set_bits
kazuyukitanimura Aug 23, 2024
842c2b1
fix: Optimize set_bits
kazuyukitanimura Aug 23, 2024
06de184
fix: Optimize set_bits
kazuyukitanimura Aug 23, 2024
03b0db8
fix: Optimize set_bits
kazuyukitanimura Aug 23, 2024
7faa5f3
fix: Optimize set_bits
kazuyukitanimura Aug 23, 2024
e3d812d
fix: Optimize set_bits
kazuyukitanimura Aug 23, 2024
13dec63
fix: Optimize set_bits
kazuyukitanimura Aug 23, 2024
68cdaf2
fix: Optimize set_bits
kazuyukitanimura Aug 23, 2024
1e9de38
miri
kazuyukitanimura Aug 24, 2024
f1e1bbd
miri
kazuyukitanimura Aug 24, 2024
f294663
miri
kazuyukitanimura Aug 24, 2024
39719c4
miri
kazuyukitanimura Aug 24, 2024
9fbb87d
miri
kazuyukitanimura Aug 24, 2024
74b9d80
miri
kazuyukitanimura Aug 24, 2024
25c309e
miri
kazuyukitanimura Aug 24, 2024
7905330
miri
kazuyukitanimura Aug 24, 2024
6dd9771
miri
kazuyukitanimura Aug 24, 2024
0e956cc
miri
kazuyukitanimura Aug 24, 2024
272ecbb
miri
kazuyukitanimura Aug 24, 2024
08ebf20
address review comments
kazuyukitanimura Sep 3, 2024
d751a7f
address review comments
kazuyukitanimura Sep 3, 2024
ef2864f
address review comments
kazuyukitanimura Sep 4, 2024
e69cf9a
Revert "address review comments"
kazuyukitanimura Sep 4, 2024
b5f8bca
address review comments
kazuyukitanimura Sep 5, 2024
9c15417
address review comments
kazuyukitanimura Sep 5, 2024
533381a
address review comments
kazuyukitanimura Sep 5, 2024
dca9ab8
address review comments
kazuyukitanimura Sep 5, 2024
7f3c3fb
address review comments
kazuyukitanimura Sep 5, 2024
6ccedd2
address review comments
kazuyukitanimura Sep 6, 2024
ff2f3ca
address review comments
kazuyukitanimura Sep 6, 2024
fb46cb0
address review comments
kazuyukitanimura Sep 6, 2024
3fd5e3e
address review comments
kazuyukitanimura Sep 6, 2024
be3076e
address review comments
kazuyukitanimura Sep 6, 2024
58868c1
address review comments
kazuyukitanimura Sep 6, 2024
a15db14
address review comments
kazuyukitanimura Sep 6, 2024
fefafa7
Revert "address review comments"
kazuyukitanimura Sep 6, 2024
d8c3f08
address review comments
kazuyukitanimura Sep 6, 2024
cc5ec2b
address review comments
kazuyukitanimura Sep 6, 2024
4c39dc8
address review comments
kazuyukitanimura Sep 9, 2024
f4789be
address review comments
kazuyukitanimura Sep 9, 2024
59fd805
address review comments
kazuyukitanimura Sep 10, 2024
7d81076
address review comments
kazuyukitanimura Sep 13, 2024
f185a19
address review comments
kazuyukitanimura Sep 13, 2024
184 changes: 158 additions & 26 deletions arrow-buffer/src/util/bit_mask.rs
@@ -17,48 +17,144 @@

 //! Utils for working with packed bit masks

-use crate::bit_chunk_iterator::BitChunks;
-use crate::bit_util::{ceil, get_bit, set_bit};
+use crate::bit_util::ceil;

 /// Sets all bits on `write_data` in the range `[offset_write..offset_write+len]` to be equal to the
 /// bits in `data` in the range `[offset_read..offset_read+len]`
 /// returns the number of `0` bits `data[offset_read..offset_read+len]`
+/// `offset_write`, `offset_read`, and `len` are in terms of bits
 pub fn set_bits(
     write_data: &mut [u8],
     data: &[u8],
     offset_write: usize,
     offset_read: usize,
     len: usize,
 ) -> usize {
+    assert!(offset_write + len <= write_data.len() * 8);
+    assert!(offset_read + len <= data.len() * 8);
     let mut null_count = 0;
-
-    let mut bits_to_align = offset_write % 8;
-    if bits_to_align > 0 {
-        bits_to_align = std::cmp::min(len, 8 - bits_to_align);
+    let mut acc = 0;
+    while len > acc {
+        // SAFETY: the arguments to `set_upto_64bits` are within the valid range because
+        // (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8
+        // (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8
+        let (n, len_set) = unsafe {
+            set_upto_64bits(
+                write_data,
+                data,
+                offset_write + acc,
+                offset_read + acc,
+                len - acc,
+            )
+        };
+        null_count += n;
+        acc += len_set;
     }
-    let mut write_byte_index = ceil(offset_write + bits_to_align, 8);
-
-    // Set full bytes provided by bit chunk iterator (which iterates in 64 bits at a time)
-    let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align);
-    chunks.iter().for_each(|chunk| {
-        null_count += chunk.count_zeros();
-        write_data[write_byte_index..write_byte_index + 8].copy_from_slice(&chunk.to_le_bytes());
-        write_byte_index += 8;
-    });
-
-    // Set individual bits both to align write_data to a byte offset and the remainder bits not covered by the bit chunk iterator
-    let remainder_offset = len - chunks.remainder_len();
-    (0..bits_to_align)
-        .chain(remainder_offset..len)
-        .for_each(|i| {
-            if get_bit(data, offset_read + i) {
-                set_bit(write_data, offset_write + i);
+
+    null_count
+}
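
For reference, the contract stated in the `set_bits` doc comment (copy `len` bits from `data` into `write_data`, return the number of `0` bits read) can be sketched as a bit-by-bit copy. This is only an illustrative baseline, not code from this PR: the name `set_bits_naive` is hypothetical, and the sketch assumes least-significant-bit-first numbering within each byte and a destination range that starts zeroed.

```rust
/// Hypothetical reference sketch (not part of the PR): copies `len` bits one
/// at a time and returns how many of the bits read from `data` were `0`.
fn set_bits_naive(
    write_data: &mut [u8],
    data: &[u8],
    offset_write: usize,
    offset_read: usize,
    len: usize,
) -> usize {
    // Same bounds checks as the asserts added to `set_bits` in the diff.
    assert!(offset_write + len <= write_data.len() * 8);
    assert!(offset_read + len <= data.len() * 8);
    let mut null_count = 0;
    for i in 0..len {
        let r = offset_read + i;
        let w = offset_write + i;
        // Assumption: bit 0 is the least significant bit of each byte.
        if (data[r / 8] >> (r % 8)) & 1 == 1 {
            // Bits are only ever set, never cleared, so the destination
            // range is assumed to be zero before the call.
            write_data[w / 8] |= 1 << (w % 8);
        } else {
            null_count += 1;
        }
    }
    null_count
}
```

A baseline like this could serve as an oracle when testing a word-at-a-time implementation against arbitrary offsets and lengths.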

+/// Similar to `set_bits` but sets only upto 64 bits, actual number of bits set may vary.
+/// Returns a pair of the number of `0` bits and the number of bits set
+///
+/// # Safety
+/// The caller must ensure all arguments are within the valid range.
+#[inline]
+unsafe fn set_upto_64bits(
+    write_data: &mut [u8],
+    data: &[u8],
+    offset_write: usize,
+    offset_read: usize,
+    len: usize,
+) -> (usize, usize) {
+    let read_byte = offset_read / 8;
+    let read_shift = offset_read % 8;
+    let write_byte = offset_write / 8;
+    let write_shift = offset_write % 8;
+
+    if len >= 64 {
+        let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() };
+        if read_shift == 0 {
+            if write_shift == 0 {
+                // no shifting necessary
+                let len = 64;
+                let null_count = chunk.count_zeros() as usize;
+                unsafe { write_u64_bytes(write_data, write_byte, chunk) };
+                (null_count, len)
             } else {
-                null_count += 1;
+                // only write shifting necessary
+                let len = 64 - write_shift;
+                let chunk = chunk << write_shift;
+                let null_count = len - chunk.count_ones() as usize;
+                unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
+                (null_count, len)
             }
-        });
+        } else if write_shift == 0 {
+            // only read shifting necessary
+            let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0
+            let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask
+            let null_count = len - chunk.count_ones() as usize;
+            unsafe { write_u64_bytes(write_data, write_byte, chunk) };
+            (null_count, len)
+        } else {
+            let len = 64 - std::cmp::max(read_shift, write_shift);
+            let chunk = (chunk >> read_shift) << write_shift;
+            let null_count = len - chunk.count_ones() as usize;
+            unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
+            (null_count, len)
+        }
+    } else if len == 1 {
+        let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1;
+        unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift };
+        ((byte_chunk ^ 1) as usize, 1)
+    } else {
+        let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift));
+        let bytes = ceil(len + read_shift, 8);
+        // SAFETY: the args of `read_bytes_to_u64` are valid as read_byte + bytes <= data.len()
+        let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) };

Contributor:
Could you add some // SAFETY: explanations to unsafe usages?

Contributor Author:
Updated

+        let mask = u64::MAX >> (64 - len);
+        let chunk = (chunk >> read_shift) & mask; // masking to read `len` bits only
+        let chunk = chunk << write_shift; // shifting back to align with `write_data`
+        let null_count = len - chunk.count_ones() as usize;
+        let bytes = ceil(len + write_shift, 8);
+        for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) {
+            unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c };
+        }
+        (null_count, len)
+    }
+}
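
The mask-and-shift steps of the short-run branch of `set_upto_64bits` may be easier to follow without the pointer arithmetic. The following is a minimal safe sketch under stated assumptions, not the PR's implementation: `copy_short_run` is a hypothetical name, slices are bounds-checked instead of accessed with `get_unchecked`, and bytes are assembled explicitly as little-endian.

```rust
/// Hypothetical safe sketch of the short-run branch (not the PR's code).
/// Returns (number of `0` bits read, number of bits written).
fn copy_short_run(
    write_data: &mut [u8],
    data: &[u8],
    offset_write: usize,
    offset_read: usize,
    len: usize,
) -> (usize, usize) {
    assert!(len >= 1);
    let (read_byte, read_shift) = (offset_read / 8, offset_read % 8);
    let (write_byte, write_shift) = (offset_write / 8, offset_write % 8);
    // Cap the run so that both the source bits and the shifted result fit in one u64.
    let len = len.min(64 - read_shift.max(write_shift));

    // Gather just enough source bytes to cover `read_shift + len` bits.
    let n_read = (len + read_shift + 7) / 8;
    let mut buf = [0u8; 8];
    buf[..n_read].copy_from_slice(&data[read_byte..read_byte + n_read]);
    let chunk = u64::from_le_bytes(buf);

    // Drop the leading `read_shift` bits, keep `len` bits, realign for the write.
    let mask = u64::MAX >> (64 - len);
    let chunk = ((chunk >> read_shift) & mask) << write_shift;
    let null_count = len - chunk.count_ones() as usize;

    // OR the touched bytes into the destination, leaving other bits intact.
    let n_write = (len + write_shift + 7) / 8;
    for (i, b) in chunk.to_le_bytes().iter().enumerate().take(n_write) {
        write_data[write_byte + i] |= *b;
    }
    (null_count, len)
}
```

Because `len + write_shift <= 64` after the cap, the left shift cannot push any of the kept bits out of the word, which is why counting ones after the shift still gives the right zero count.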

-    null_count as usize
+/// # Safety
+/// The caller must ensure all arguments are within the valid range.
+#[inline]
+unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 {

Member:
This function doesn't limit reading bytes to be up 8 bytes. Do you want to add an assert?

Contributor Author:
Thanks @viirya updated

+    debug_assert!(count <= 8);
+    let mut tmp = std::mem::MaybeUninit::<u64>::new(0);
+    let src = data.as_ptr().add(offset);
+    unsafe {
+        std::ptr::copy_nonoverlapping(src, tmp.as_mut_ptr() as *mut u8, count);
+        tmp.assume_init()
+    }
+}
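
As a point of comparison for `read_bytes_to_u64`, a bounds-checked version can be written with slice copies. This is a sketch only: `read_bytes_to_u64_safe` is a hypothetical name, and it fixes the byte order to little-endian, whereas the function in the diff copies raw bytes into a native-endian `u64` (the two should agree on little-endian targets).

```rust
/// Hypothetical safe counterpart (not part of the PR): reads `count` bytes
/// starting at `offset` into the low bytes of a u64, little-endian.
fn read_bytes_to_u64_safe(data: &[u8], offset: usize, count: usize) -> u64 {
    // Mirrors the `debug_assert!(count <= 8)` requested in review, but always on.
    assert!(count <= 8);
    let mut buf = [0u8; 8];
    // Slice indexing panics instead of reading out of bounds.
    buf[..count].copy_from_slice(&data[offset..offset + count]);
    u64::from_le_bytes(buf)
}
```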

+/// # Safety
+/// The caller must ensure `data` has `offset..(offset + 8)` range
+#[inline]
+unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
+    let ptr = data.as_mut_ptr().add(offset) as *mut u64;
+    ptr.write_unaligned(chunk);
+}
+
+/// Similar to `write_u64_bytes`, but this method ORs the offset addressed `data` and `chunk`
+/// instead of overwriting
+///
+/// # Safety
+/// The caller must ensure `data` has `offset..(offset + 8)` range
+#[inline]
+unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
+    let ptr = data.as_mut_ptr().add(offset);
+    let chunk = chunk | (*ptr) as u64;
+    (ptr as *mut u64).write_unaligned(chunk);
 }
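
The difference between the two write helpers can be illustrated with safe slice operations. This is a hedged sketch, not the PR's code: the `_safe` names are hypothetical, and little-endian layout is assumed. As the diff reads, `or_write_u64_bytes` dereferences a `*mut u8`, so only the first destination byte is ORed in before all eight bytes are written; the sketch mirrors that reading.

```rust
/// Hypothetical safe counterpart of `write_u64_bytes`: overwrites 8 bytes.
fn write_u64_bytes_safe(data: &mut [u8], offset: usize, chunk: u64) {
    data[offset..offset + 8].copy_from_slice(&chunk.to_le_bytes());
}

/// Hypothetical safe counterpart of `or_write_u64_bytes`: preserves the bits
/// already present in the first destination byte, then writes 8 bytes.
fn or_write_u64_bytes_safe(data: &mut [u8], offset: usize, chunk: u64) {
    // Only `data[offset]` is merged; the following 7 bytes are overwritten.
    let merged = chunk | u64::from(data[offset]);
    data[offset..offset + 8].copy_from_slice(&merged.to_le_bytes());
}
```

Preserving only the first byte is enough when bits are written in increasing order, since earlier writes can only have touched the low bits of the byte where the current write begins.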

 #[cfg(test)]
@@ -185,4 +281,40 @@ mod tests {
         assert_eq!(destination, expected_data);
         assert_eq!(result, expected_null_count);
     }
+
+    #[test]
+    fn test_set_upto_64bits() {
+        // len >= 64
+        let write_data: &mut [u8] = &mut [0; 9];
+        let data: &[u8] = &[

Contributor:
Could you please also add a test that is greater than 64 bits (not just = 64 bits)?

Contributor Author (@kazuyukitanimura, Sep 13, 2024):

Contributor:
I am working on some more tests too. Stay tuned...

+            0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001,
+            0b00000001,
+        ];
+        let offset_write = 1;
+        let offset_read = 0;
+        let len = 64;
+        let (n, len_set) =
+            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
+        assert_eq!(n, 55);
+        assert_eq!(len_set, 63);
+        assert_eq!(
+            write_data,
+            &[
+                0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010,
+                0b00000010, 0b00000000
+            ]
+        );
+
+        // len = 1
+        let write_data: &mut [u8] = &mut [0b00000000];
+        let data: &[u8] = &[0b00000001];
+        let offset_write = 1;
+        let offset_read = 0;
+        let len = 1;
+        let (n, len_set) =
+            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
+        assert_eq!(n, 0);
+        assert_eq!(len_set, 1);
+        assert_eq!(write_data, &[0b00000010]);
+    }
 }
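
The expected values in the first test case (`n == 55`, `len_set == 63`) follow from the "only write shifting necessary" branch and can be re-derived with plain integer arithmetic. The helper below is a hypothetical illustration, not part of the PR, and assumes a little-endian interpretation of the eight source bytes.

```rust
/// Hypothetical helper: reproduces the arithmetic of the write-shift-only
/// branch for one 64-bit chunk. Returns (zero bits, bits set, bytes written).
fn write_shift_only_case(data: [u8; 8], write_shift: usize) -> (usize, usize, [u8; 8]) {
    assert!(write_shift > 0 && write_shift < 8);
    let chunk = u64::from_le_bytes(data);
    // Shifting left by `write_shift` drops the top bits, so fewer than 64 bits land.
    let len = 64 - write_shift;
    let shifted = chunk << write_shift;
    let null_count = len - shifted.count_ones() as usize;
    (null_count, len, shifted.to_le_bytes())
}
```

With eight bytes of `0b00000001` and a shift of 1, each set bit moves to position 1 of its byte, all 8 ones survive, and 63 - 8 = 55 zero bits are counted, matching the assertions in the test.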