Skip to content

Commit

Permalink
Fix RunEndArray filter (#1380)
Browse files Browse the repository at this point in the history
fix #1368
  • Loading branch information
robert3005 authored Nov 20, 2024
1 parent 22552b3 commit 5b4679d
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 62 deletions.
62 changes: 13 additions & 49 deletions encodings/runend/src/compress.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use std::cmp::min;

use arrow_buffer::BooleanBufferBuilder;
use itertools::Itertools;
use num_traits::{AsPrimitive, FromPrimitive};
use vortex_array::array::{BoolArray, BooleanBuffer, PrimitiveArray};
use vortex_array::validity::Validity;
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::ArrayDType;
use vortex_dtype::{match_each_integer_ptype, match_each_native_ptype, NativePType, Nullability};
use vortex_error::{vortex_panic, VortexResult};
use vortex_error::VortexResult;

use crate::iter::trimmed_ends_iter;

pub fn runend_encode(array: &PrimitiveArray) -> (PrimitiveArray, PrimitiveArray) {
let validity = if array.dtype().nullability() == Nullability::NonNullable {
Expand Down Expand Up @@ -61,9 +60,8 @@ pub fn runend_decode_primitive(
match_each_native_ptype!(values.ptype(), |$P| {
match_each_integer_ptype!(ends.ptype(), |$E| {
Ok(PrimitiveArray::from_vec(runend_decode_typed_primitive(
ends.maybe_null_slice::<$E>(),
trimmed_ends_iter(ends.maybe_null_slice::<$E>(), offset, length),
values.maybe_null_slice::<$P>(),
offset,
length,
), validity))
})
Expand All @@ -79,67 +77,33 @@ pub fn runend_decode_bools(
) -> VortexResult<BoolArray> {
match_each_integer_ptype!(ends.ptype(), |$E| {
BoolArray::try_new(runend_decode_typed_bool(
ends.maybe_null_slice::<$E>(),
trimmed_ends_iter(ends.maybe_null_slice::<$E>(), offset, length),
values.boolean_buffer(),
offset,
length,
), validity)
})
}

#[inline]
fn trimmed_run_ends<E: NativePType + AsPrimitive<usize> + FromPrimitive + Ord>(
run_ends: &[E],
offset: usize,
length: usize,
) -> impl Iterator<Item = E> + use<'_, E> {
let offset_e = E::from_usize(offset).unwrap_or_else(|| {
vortex_panic!(
"offset {} cannot be converted to {}",
offset,
std::any::type_name::<E>()
)
});
let length_e = E::from_usize(length).unwrap_or_else(|| {
vortex_panic!(
"length {} cannot be converted to {}",
length,
std::any::type_name::<E>()
)
});
run_ends
.iter()
.map(move |&v| v - offset_e)
.map(move |v| min(v, length_e))
}

pub fn runend_decode_typed_primitive<
E: NativePType + AsPrimitive<usize> + FromPrimitive + Ord,
T: NativePType,
>(
run_ends: &[E],
pub fn runend_decode_typed_primitive<T: NativePType>(
run_ends: impl Iterator<Item = usize>,
values: &[T],
offset: usize,
length: usize,
) -> Vec<T> {
let trimmed_ends = trimmed_run_ends(run_ends, offset, length);
let mut decoded = Vec::with_capacity(length);
for (end, value) in trimmed_ends.zip_eq(values) {
decoded.extend(std::iter::repeat_n(value, end.as_() - decoded.len()));
for (end, value) in run_ends.zip_eq(values) {
decoded.extend(std::iter::repeat_n(value, end - decoded.len()));
}
decoded
}

pub fn runend_decode_typed_bool<E: NativePType + AsPrimitive<usize> + FromPrimitive + Ord>(
run_ends: &[E],
pub fn runend_decode_typed_bool(
run_ends: impl Iterator<Item = usize>,
values: BooleanBuffer,
offset: usize,
length: usize,
) -> BooleanBuffer {
let trimmed_ends = trimmed_run_ends(run_ends, offset, length);
let mut decoded = BooleanBufferBuilder::new(length);
for (end, value) in trimmed_ends.zip_eq(values.iter()) {
decoded.append_n(end.as_() - decoded.len(), value);
for (end, value) in run_ends.zip_eq(values.iter()) {
decoded.append_n(end - decoded.len(), value);
}
decoded.finish()
}
Expand Down
54 changes: 42 additions & 12 deletions encodings/runend/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cmp::min;
use std::ops::AddAssign;

use num_traits::AsPrimitive;
Expand Down Expand Up @@ -76,9 +77,8 @@ impl TakeFn for RunEndArray {
let primitive_indices = indices.clone().into_primitive()?;
let u64_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| {
primitive_indices
.maybe_null_slice::<$P>()
.iter()
.copied()
.into_maybe_null_slice::<$P>()
.into_iter()
.map(|idx| {
let usize_idx = idx as usize;
if usize_idx >= self.len() {
Expand All @@ -89,11 +89,11 @@ impl TakeFn for RunEndArray {
})
.collect::<VortexResult<Vec<u64>>>()?
});
let physical_indices: Vec<u64> = self
let physical_indices = self
.find_physical_indices(&u64_indices)?
.iter()
.map(|idx| *idx as u64)
.collect();
.into_iter()
.map(|idx| idx as u64)
.collect::<Vec<_>>();
let physical_indices_array = PrimitiveArray::from(physical_indices).into_array();
let dense_values = take(self.values(), &physical_indices_array, options)?;

Expand Down Expand Up @@ -146,12 +146,12 @@ impl SliceFn for RunEndArray {

impl FilterFn for RunEndArray {
fn filter(&self, mask: FilterMask) -> VortexResult<ArrayData> {
let validity = self.validity().filter(&mask)?;
let primitive_run_ends = self.ends().into_primitive()?;
let (run_ends, mask) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| {
filter_run_ends(primitive_run_ends.maybe_null_slice::<$P>(), mask)?
let (run_ends, values_mask) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| {
filter_run_ends(primitive_run_ends.maybe_null_slice::<$P>(), self.offset() as u64, self.len() as u64, mask)?
});
let validity = self.validity().filter(&mask)?;
let values = filter(&self.values(), mask)?;
let values = filter(&self.values(), values_mask)?;

RunEndArray::try_new(run_ends.into_array(), values, validity).map(|a| a.into_array())
}
Expand All @@ -160,6 +160,8 @@ impl FilterFn for RunEndArray {
// Code adapted from apache arrow-rs https://github.com/apache/arrow-rs/blob/b1f5c250ebb6c1252b4e7c51d15b8e77f4c361fa/arrow-select/src/filter.rs#L425
fn filter_run_ends<R: NativePType + AddAssign + From<bool> + AsPrimitive<u64>>(
run_ends: &[R],
offset: u64,
length: u64,
mask: FilterMask,
) -> VortexResult<(PrimitiveArray, FilterMask)> {
let mut new_run_ends = vec![R::zero(); run_ends.len()];
Expand All @@ -171,7 +173,7 @@ fn filter_run_ends<R: NativePType + AddAssign + From<bool> + AsPrimitive<u64>>(

let new_mask: FilterMask = BooleanBuffer::collect_bool(run_ends.len(), |i| {
let mut keep = false;
let end = run_ends[i].as_();
let end = min(run_ends[i].as_() - offset, length);

// Safety: predicate must be the same length as the array the ends have been taken from
for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
Expand Down Expand Up @@ -464,6 +466,34 @@ mod test {
);
}

#[test]
fn filter_sliced_run_end() {
let arr = slice(ree_array(), 2, 7).unwrap();
let filtered = filter(
&arr,
FilterMask::from_iter([true, false, false, true, true]),
)
.unwrap();
let filtered_run_end = RunEndArray::try_from(filtered).unwrap();

assert_eq!(
filtered_run_end
.ends()
.into_primitive()
.unwrap()
.maybe_null_slice::<u64>(),
[1, 2, 3]
);
assert_eq!(
filtered_run_end
.values()
.into_primitive()
.unwrap()
.maybe_null_slice::<i32>(),
[1, 4, 2]
);
}

#[test]
fn compare_run_end() {
let arr = ree_array();
Expand Down
33 changes: 33 additions & 0 deletions encodings/runend/src/iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use std::cmp::min;

use num_traits::{AsPrimitive, FromPrimitive};
use vortex_dtype::NativePType;
use vortex_error::vortex_panic;

#[inline]
pub fn trimmed_ends_iter<E: NativePType + FromPrimitive + AsPrimitive<usize> + Ord>(
run_ends: &[E],
offset: usize,
length: usize,
) -> impl Iterator<Item = usize> + use<'_, E> {
let offset_e = E::from_usize(offset).unwrap_or_else(|| {
vortex_panic!(
"offset {} cannot be converted to {}",
offset,
std::any::type_name::<E>()
)
});
let length_e = E::from_usize(length).unwrap_or_else(|| {
vortex_panic!(
"length {} cannot be converted to {}",
length,
std::any::type_name::<E>()
)
});
run_ends
.iter()
.copied()
.map(move |v| v - offset_e)
.map(move |v| min(v, length_e))
.map(|v| v.as_())
}
1 change: 1 addition & 0 deletions encodings/runend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pub use array::*;
mod array;
pub mod compress;
mod compute;
mod iter;
10 changes: 9 additions & 1 deletion fuzz/fuzz_targets/array_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,15 @@ fn assert_search_sorted(
}

fn assert_array_eq(lhs: &ArrayData, rhs: &ArrayData, step: usize) {
assert_eq!(lhs.len(), rhs.len());
assert_eq!(
lhs.len(),
rhs.len(),
"LHS len {} != RHS len {}, lhs is {} rhs is {} in step {step}",
lhs.len(),
rhs.len(),
lhs.encoding().id(),
rhs.encoding().id()
);
for idx in 0..lhs.len() {
let l = scalar_at(lhs, idx).unwrap();
let r = scalar_at(rhs, idx).unwrap();
Expand Down

0 comments on commit 5b4679d

Please sign in to comment.