Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement for compare #1570

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions encodings/fastlanes/src/for/compare.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use vortex_array::compute::{compare, CompareFn, Operator};
use vortex_array::ArrayData;
use vortex_error::VortexResult;

use crate::{decompress, FoRArray, FoREncoding};

impl CompareFn<FoRArray> for FoREncoding {
fn compare(
&self,
lhs: &FoRArray,
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<Option<ArrayData>> {
// this is cheap
let owned_lhs = lhs.clone();
let decompressed_lhs = decompress(owned_lhs)?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So by default, the compare function will fall back to canonicalizing an array into the Arrow representation and then running the compare kernel over it.

So we should only implement a compute function for an encoding when it is able to short-cut some work.

In the case of FoR, if the RHS is constant rhs.as_constant(), then you could subtract the lhs.reference() value and then push-down the comparison without decompressing.

For FoR, this doesn't itself make the compare any cheaper, but it does allow FoR's child to possible implement a cheaper comparison.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Get it, thanks

compare(decompressed_lhs, rhs, operator).map(Some)
}
}

#[cfg(test)]
mod tests {
use vortex_array::array::PrimitiveArray;
use vortex_array::compute::{compare, Operator};
use vortex_array::validity::Validity;
use vortex_array::IntoArrayVariant;

use crate::for_compress;

#[test]
fn test_for_compare() {
let lhs = PrimitiveArray::from_vec(vec![1i32, 2, 3, 4, 5], Validity::AllValid);
let lhs = for_compress(&lhs).unwrap();
let rhs = PrimitiveArray::from_vec(vec![1i32, 2, 9, 4, 5], Validity::AllValid);
assert_eq!(
compare(lhs, rhs, Operator::Eq)
.unwrap()
.into_bool()
.unwrap()
.boolean_buffer(),
vec![true, true, false, true, true].into()
);
}
}
8 changes: 6 additions & 2 deletions encodings/fastlanes/src/for/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::ops::AddAssign;

use num_traits::{CheckedShl, CheckedShr, WrappingAdd, WrappingSub};
use vortex_array::compute::{
filter, scalar_at, search_sorted, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn,
SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, TakeOptions,
filter, scalar_at, search_sorted, slice, take, CompareFn, ComputeVTable, FilterFn, FilterMask,
ScalarAtFn, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, TakeOptions,
};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
Expand All @@ -14,6 +14,10 @@ use vortex_scalar::{PValue, Scalar};
use crate::{FoRArray, FoREncoding};

impl ComputeVTable for FoREncoding {
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}

fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Some(self)
}
Expand Down
1 change: 1 addition & 0 deletions encodings/fastlanes/src/for/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use vortex_dtype::DType;
use vortex_error::{vortex_bail, VortexExpect as _, VortexResult};
use vortex_scalar::{Scalar, ScalarValue};

mod compare;
mod compress;
mod compute;

Expand Down
Loading