-
Notifications
You must be signed in to change notification settings - Fork 841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Vectorized lexicographical_partition_ranges (~80% faster) #4575
Changes from 1 commit
cc29399
1754705
ea83e26
086bf8e
58904d4
fa9d399
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,146 +17,74 @@ | |
|
||
//! Defines partition kernel for `ArrayRef` | ||
|
||
use crate::sort::{LexicographicalComparator, SortColumn}; | ||
use crate::comparison::neq_dyn; | ||
use crate::sort::SortColumn; | ||
use arrow_array::Array; | ||
use arrow_buffer::BooleanBuffer; | ||
use arrow_schema::ArrowError; | ||
use std::cmp::Ordering; | ||
use std::ops::Range; | ||
|
||
/// Given a list of already sorted columns, find partition ranges that would partition | ||
/// lexicographically equal values across columns. | ||
/// | ||
/// Here LexicographicalComparator is used in conjunction with binary | ||
/// search so the columns *MUST* be pre-sorted already. | ||
/// | ||
/// The returned vec would be of size k where k is cardinality of the sorted values; Consecutive | ||
/// values will be connected: (a, b) and (b, c), where start = 0 and end = n for the first and last | ||
/// range. | ||
pub fn lexicographical_partition_ranges( | ||
columns: &[SortColumn], | ||
) -> Result<impl Iterator<Item = Range<usize>> + '_, ArrowError> { | ||
LexicographicalPartitionIterator::try_new(columns) | ||
} | ||
|
||
struct LexicographicalPartitionIterator<'a> { | ||
comparator: LexicographicalComparator<'a>, | ||
num_rows: usize, | ||
previous_partition_point: usize, | ||
partition_point: usize, | ||
} | ||
if columns.is_empty() { | ||
return Err(ArrowError::InvalidArgumentError( | ||
"Sort requires at least one column".to_string(), | ||
)); | ||
} | ||
let num_rows = columns[0].values.len(); | ||
if columns.iter().any(|item| item.values.len() != num_rows) { | ||
return Err(ArrowError::ComputeError( | ||
"Lexical sort columns have different row counts".to_string(), | ||
)); | ||
}; | ||
|
||
impl<'a> LexicographicalPartitionIterator<'a> { | ||
fn try_new( | ||
columns: &'a [SortColumn], | ||
) -> Result<LexicographicalPartitionIterator, ArrowError> { | ||
if columns.is_empty() { | ||
return Err(ArrowError::InvalidArgumentError( | ||
"Sort requires at least one column".to_string(), | ||
)); | ||
} | ||
let num_rows = columns[0].values.len(); | ||
if columns.iter().any(|item| item.values.len() != num_rows) { | ||
return Err(ArrowError::ComputeError( | ||
"Lexical sort columns have different row counts".to_string(), | ||
)); | ||
}; | ||
let acc = find_boundaries(&columns[0])?; | ||
let acc = columns | ||
.iter() | ||
.skip(1) | ||
.try_fold(acc, |acc, c| find_boundaries(c).map(|b| &acc | &b))?; | ||
|
||
let comparator = LexicographicalComparator::try_new(columns)?; | ||
Ok(LexicographicalPartitionIterator { | ||
comparator, | ||
num_rows, | ||
previous_partition_point: 0, | ||
partition_point: 0, | ||
}) | ||
let mut out = vec![]; | ||
let mut current = 0; | ||
for idx in acc.set_indices() { | ||
let t = current; | ||
current = idx + 1; | ||
out.push(t..current) | ||
} | ||
} | ||
|
||
/// Returns the next partition point of the range `start..end` according to the given comparator. | ||
/// The return value is the index of the first element of the second partition, | ||
/// and is guaranteed to be between `start..=end` (inclusive). | ||
/// | ||
/// The values corresponding to those indices are assumed to be partitioned according to the given comparator. | ||
/// | ||
/// Exponential search is to remedy for the case when array size and cardinality are both large. | ||
/// In these cases the partition point would be near the beginning of the range and | ||
/// plain binary search would be doing some unnecessary iterations on each call. | ||
/// | ||
/// see <https://en.wikipedia.org/wiki/Exponential_search> | ||
#[inline] | ||
fn exponential_search_next_partition_point( | ||
start: usize, | ||
end: usize, | ||
comparator: &LexicographicalComparator<'_>, | ||
) -> usize { | ||
let target = start; | ||
let mut bound = 1; | ||
while bound + start < end | ||
&& comparator.compare(bound + start, target) != Ordering::Greater | ||
{ | ||
bound *= 2; | ||
if current != num_rows { | ||
out.push(current..num_rows) | ||
} | ||
|
||
// invariant after while loop: | ||
// (start + bound / 2) <= target < min(end, start + bound + 1) | ||
// where <= and < are defined by the comparator; | ||
// note here we have right = min(end, start + bound + 1) because (start + bound) might | ||
// actually be considered and must be included. | ||
partition_point(start + bound / 2, end.min(start + bound + 1), |idx| { | ||
comparator.compare(idx, target) != Ordering::Greater | ||
}) | ||
Ok(out.into_iter()) | ||
} | ||
|
||
/// Returns the partition point of the range `start..end` according to the given predicate. | ||
/// The return value is the index of the first element of the second partition, | ||
/// and is guaranteed to be between `start..=end` (inclusive). | ||
/// | ||
/// The algorithm is similar to a binary search. | ||
/// | ||
/// The values corresponding to those indices are assumed to be partitioned according to the given predicate. | ||
/// | ||
/// See [`slice::partition_point`] | ||
#[inline] | ||
fn partition_point<P: Fn(usize) -> bool>(start: usize, end: usize, pred: P) -> usize { | ||
let mut left = start; | ||
let mut right = end; | ||
let mut size = right - left; | ||
while left < right { | ||
let mid = left + size / 2; | ||
/// Returns a mask with bits set whenever the value or nullability changes | ||
fn find_boundaries(col: &SortColumn) -> Result<BooleanBuffer, ArrowError> { | ||
let v = &col.values; | ||
let slice_len = v.len().saturating_sub(1); | ||
let v1 = v.slice(0, slice_len); | ||
let v2 = v.slice(1, slice_len); | ||
|
||
let less = pred(mid); | ||
let array_ne = neq_dyn(v1.as_ref(), v2.as_ref())?; | ||
let values_ne = match array_ne.nulls().filter(|n| n.null_count() > 0) { | ||
Some(n) => n.inner() & array_ne.values(), | ||
None => array_ne.values().clone(), | ||
}; | ||
|
||
if less { | ||
left = mid + 1; | ||
} else { | ||
right = mid; | ||
Ok(match v.nulls().filter(|x| x.null_count() > 0) { | ||
Some(n) => { | ||
let n1 = n.inner().slice(0, slice_len); | ||
let n2 = n.inner().slice(1, slice_len); | ||
&(&n1 ^ &n2) | &values_ne | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is quite possibly some more clever way to bit transitions from a bitmask, however, this is already likely sufficiently fast as to be irrelevant There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree. Took me a while to follow the logic though, a comment could help. "values are either not-equal (and both non-null) or exactly one value is null" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is quite clever |
||
} | ||
|
||
size = right - left; | ||
} | ||
left | ||
} | ||
|
||
impl<'a> Iterator for LexicographicalPartitionIterator<'a> { | ||
type Item = Range<usize>; | ||
|
||
fn next(&mut self) -> Option<Self::Item> { | ||
if self.partition_point < self.num_rows { | ||
// invariant: | ||
// in the range [0..previous_partition_point] all values are <= the value at [previous_partition_point] | ||
// so in order to save time we can do binary search on the range [previous_partition_point..num_rows] | ||
// and find the index where any value is greater than the value at [previous_partition_point] | ||
self.partition_point = exponential_search_next_partition_point( | ||
self.partition_point, | ||
self.num_rows, | ||
&self.comparator, | ||
); | ||
let start = self.previous_partition_point; | ||
let end = self.partition_point; | ||
self.previous_partition_point = self.partition_point; | ||
Some(Range { start, end }) | ||
} else { | ||
None | ||
} | ||
} | ||
None => values_ne, | ||
}) | ||
} | ||
|
||
#[cfg(test)] | ||
|
@@ -167,44 +95,6 @@ mod tests { | |
use arrow_schema::DataType; | ||
use std::sync::Arc; | ||
|
||
#[test] | ||
fn test_partition_point() { | ||
let input = &[1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 4]; | ||
{ | ||
let median = input[input.len() / 2]; | ||
assert_eq!( | ||
9, | ||
partition_point(0, input.len(), |i: usize| input[i].cmp(&median) | ||
!= Ordering::Greater) | ||
); | ||
} | ||
{ | ||
let search = input[9]; | ||
assert_eq!( | ||
12, | ||
partition_point(9, input.len(), |i: usize| input[i].cmp(&search) | ||
!= Ordering::Greater) | ||
); | ||
} | ||
{ | ||
let search = input[0]; | ||
assert_eq!( | ||
3, | ||
partition_point(0, 9, |i: usize| input[i].cmp(&search) | ||
!= Ordering::Greater) | ||
); | ||
} | ||
let input = &[1, 2, 2, 2, 2, 2, 2, 2, 9]; | ||
{ | ||
let search = input[5]; | ||
assert_eq!( | ||
8, | ||
partition_point(5, 9, |i: usize| input[i].cmp(&search) | ||
!= Ordering::Greater) | ||
); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_lexicographical_partition_ranges_empty() { | ||
let input = vec![]; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In opted to preserve the existing function signature for now, I can definitely see a future incarnation returning the computed bitmask somehow to allow for more optimal processing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe worth a ticket (I can also update the docs in #4615)