Skip to content

Commit

Permalink
feat: propagate statistics through compression (#1236)
Browse files Browse the repository at this point in the history
fixes #1174
  • Loading branch information
lwwmanning authored Nov 7, 2024
1 parent 994c01f commit c1cee33
Show file tree
Hide file tree
Showing 32 changed files with 477 additions and 131 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ resolver = "2"
version = "0.14.0"
homepage = "https://github.com/spiraldb/vortex"
repository = "https://github.com/spiraldb/vortex"
authors = ["Vortex Authors <hello@spiraldb.com>"]
authors = ["Vortex Authors <hello@vortex.dev>"]
license = "Apache-2.0"
keywords = ["vortex"]
include = [
Expand Down
36 changes: 36 additions & 0 deletions encodings/fastlanes/src/for/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ fn decompress_primitive<T: NativePType + WrappingAdd + PrimInt>(
mod test {
use vortex_array::compute::unary::ScalarAtFn;
use vortex_array::IntoArrayVariant;
use vortex_dtype::Nullability;

use super::*;

Expand All @@ -133,6 +134,41 @@ mod test {
assert_eq!(u32::try_from(compressed.reference()).unwrap(), 1_000_000u32);
}

#[test]
fn test_zeros() {
let array = PrimitiveArray::from(vec![0i32; 10_000]);
assert!(array.statistics().to_set().into_iter().next().is_none());

let compressed = for_compress(&array).unwrap();
let constant = ConstantArray::try_from(compressed).unwrap();
assert_eq!(constant.scalar_value(), &ScalarValue::from(0i32));
}

#[test]
fn test_nullable_zeros() {
let array = PrimitiveArray::from_nullable_vec(
vec![Some(0i32), None]
.into_iter()
.cycle()
.take(10_000)
.collect_vec(),
);
assert!(array.statistics().to_set().into_iter().next().is_none());

let compressed = for_compress(&array).unwrap();
let sparse = SparseArray::try_from(compressed).unwrap();
assert!(sparse.statistics().to_set().into_iter().next().is_none());
assert_eq!(sparse.fill_value(), &ScalarValue::Null);
assert_eq!(
sparse.scalar_at(0).unwrap(),
Scalar::primitive(0i32, Nullability::Nullable)
);
assert_eq!(
sparse.scalar_at(1).unwrap(),
Scalar::null(sparse.dtype().clone())
);
}

#[test]
fn test_decompress() {
// Create a range offset by a million
Expand Down
11 changes: 8 additions & 3 deletions vortex-array/src/array/bool/stats.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use arrow_buffer::BooleanBuffer;
use itertools::Itertools;
use vortex_dtype::{DType, Nullability};
use vortex_error::VortexResult;

Expand Down Expand Up @@ -43,7 +44,7 @@ impl ArrayStatisticsCompute for NullableBools<'_> {
acc.n_nulls(first_non_null);
self.0
.iter()
.zip(self.1.iter())
.zip_eq(self.1.iter())
.skip(first_non_null + 1)
.map(|(next, valid)| valid.then_some(next))
.for_each(|next| acc.nullable_next(next));
Expand All @@ -59,6 +60,10 @@ impl ArrayStatisticsCompute for NullableBools<'_> {

impl ArrayStatisticsCompute for BooleanBuffer {
fn compute_statistics(&self, _stat: Stat) -> VortexResult<StatsSet> {
if self.is_empty() {
return Ok(StatsSet::new());
}

let mut stats = BoolStatsAccumulator::new(self.value(0));
self.iter().skip(1).for_each(|next| stats.next(next));
Ok(stats.finish())
Expand All @@ -75,7 +80,7 @@ struct BoolStatsAccumulator {
}

impl BoolStatsAccumulator {
fn new(first_value: bool) -> Self {
pub fn new(first_value: bool) -> Self {
Self {
prev: first_value,
is_sorted: true,
Expand All @@ -86,7 +91,7 @@ impl BoolStatsAccumulator {
}
}

fn n_nulls(&mut self, n_nulls: usize) {
pub fn n_nulls(&mut self, n_nulls: usize) {
self.null_count += n_nulls;
self.len += n_nulls;
}
Expand Down
3 changes: 1 addition & 2 deletions vortex-array/src/array/chunked/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ impl ArrayStatisticsCompute for ChunkedArray {
.chunks()
.map(|c| {
let s = c.statistics();
// HACK(robert): This will compute all stats, but we could just compute one
s.compute(stat);
s.to_set()
})
.reduce(|mut acc, x| {
acc.merge(&x);
acc.merge_ordered(&x);
acc
})
.unwrap_or_default())
Expand Down
22 changes: 8 additions & 14 deletions vortex-array/src/array/constant/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@ use serde::{Deserialize, Serialize};
use vortex_error::{vortex_panic, VortexResult};
use vortex_scalar::{Scalar, ScalarValue};

use crate::aliases::hash_map::HashMap;
use crate::array::visitor::{AcceptArrayVisitor, ArrayVisitor};
use crate::encoding::ids;
use crate::stats::{Stat, StatsSet};
use crate::stats::{ArrayStatisticsCompute, Stat, StatsSet};
use crate::validity::{ArrayValidity, LogicalValidity};
use crate::{impl_encoding, ArrayDType, ArrayTrait};

mod canonical;
mod compute;
mod stats;
mod variants;

impl_encoding!("vortex.constant", ids::CONSTANT, Constant);
Expand All @@ -39,24 +37,14 @@ impl ConstantArray {
S: Into<Scalar>,
{
let scalar = scalar.into();
// TODO(aduffy): add stats for bools, ideally there should be a
// StatsSet::constant(Scalar) constructor that does this for us, like StatsSet::nulls.
let stats = StatsSet::from(HashMap::from([
(Stat::Max, scalar.clone()),
(Stat::Min, scalar.clone()),
(Stat::IsConstant, true.into()),
(Stat::IsSorted, true.into()),
(Stat::RunCount, 1.into()),
]));

Self::try_from_parts(
scalar.dtype().clone(),
length,
ConstantMetadata {
scalar_value: scalar.value().clone(),
},
[].into(),
stats,
StatsSet::constant(scalar.clone(), length),
)
.unwrap_or_else(|err| {
vortex_panic!(
Expand Down Expand Up @@ -93,6 +81,12 @@ impl ArrayValidity for ConstantArray {
}
}

impl ArrayStatisticsCompute for ConstantArray {
fn compute_statistics(&self, _stat: Stat) -> VortexResult<StatsSet> {
Ok(StatsSet::constant(self.owned_scalar(), self.len()))
}
}

impl AcceptArrayVisitor for ConstantArray {
fn accept(&self, _visitor: &mut dyn ArrayVisitor) -> VortexResult<()> {
Ok(())
Expand Down
33 changes: 0 additions & 33 deletions vortex-array/src/array/constant/stats.rs

This file was deleted.

69 changes: 67 additions & 2 deletions vortex-array/src/array/extension/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::fmt::{Debug, Display};
use std::sync::Arc;

use enum_iterator::all;
use serde::{Deserialize, Serialize};
use vortex_dtype::{DType, ExtDType, ExtID};
use vortex_error::{VortexExpect as _, VortexResult};

use crate::array::visitor::{AcceptArrayVisitor, ArrayVisitor};
use crate::encoding::ids;
use crate::stats::ArrayStatisticsCompute;
use crate::stats::{ArrayStatistics as _, ArrayStatisticsCompute, Stat, StatsSet};
use crate::validity::{ArrayValidity, LogicalValidity};
use crate::variants::{ArrayVariants, ExtensionArrayTrait};
use crate::{impl_encoding, Array, ArrayDType, ArrayTrait, Canonical, IntoCanonical};
Expand Down Expand Up @@ -93,5 +94,69 @@ impl AcceptArrayVisitor for ExtensionArray {
}

impl ArrayStatisticsCompute for ExtensionArray {
// TODO(ngates): pass through stats to the underlying and cast the scalars.
fn compute_statistics(&self, stat: Stat) -> VortexResult<StatsSet> {
let mut stats = self.storage().statistics().compute_all(&[stat])?;

// for e.g., min/max, we want to cast to the extension array's dtype
// for other stats, we don't need to change anything
for stat in all::<Stat>().filter(|s| s.has_same_dtype_as_array()) {
if let Some(value) = stats.get(stat) {
stats.set(stat, value.cast(self.dtype())?);
}
}

Ok(stats)
}
}

#[cfg(test)]
mod tests {
use itertools::Itertools;
use vortex_dtype::PType;
use vortex_scalar::{PValue, Scalar, ScalarValue};

use super::*;
use crate::array::PrimitiveArray;
use crate::validity::Validity;
use crate::IntoArray as _;

#[test]
fn compute_statistics() {
let ext_dtype = Arc::new(ExtDType::new(
ExtID::new("timestamp".into()),
DType::from(PType::I64).into(),
None,
));
let array = ExtensionArray::new(
ext_dtype.clone(),
PrimitiveArray::from_vec(vec![1i64, 2, 3, 4, 5], Validity::NonNullable).into_array(),
);

let stats = array
.statistics()
.compute_all(&[Stat::Min, Stat::Max, Stat::NullCount])
.unwrap();
let num_stats = stats.clone().into_iter().try_len().unwrap();
assert!(
num_stats >= 3,
"Expected at least 3 stats, got {}",
num_stats
);

assert_eq!(
stats.get(Stat::Min),
Some(&Scalar::extension(
ext_dtype.clone(),
ScalarValue::Primitive(PValue::I64(1))
))
);
assert_eq!(
stats.get(Stat::Max),
Some(&Scalar::extension(
ext_dtype.clone(),
ScalarValue::Primitive(PValue::I64(5))
))
);
assert_eq!(stats.get(Stat::NullCount), Some(&0u64.into()));
}
}
25 changes: 23 additions & 2 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::array::visitor::{AcceptArrayVisitor, ArrayVisitor};
use crate::compute::unary::scalar_at;
use crate::compute::{search_sorted, SearchResult, SearchSortedSide};
use crate::encoding::ids;
use crate::stats::{ArrayStatisticsCompute, StatsSet};
use crate::stats::{ArrayStatistics, ArrayStatisticsCompute, Stat, StatsSet};
use crate::validity::{ArrayValidity, LogicalValidity};
use crate::variants::PrimitiveArrayTrait;
use crate::{impl_encoding, Array, ArrayDType, ArrayTrait, IntoArray, IntoArrayVariant};
Expand Down Expand Up @@ -180,7 +180,28 @@ impl AcceptArrayVisitor for SparseArray {
}
}

impl ArrayStatisticsCompute for SparseArray {}
impl ArrayStatisticsCompute for SparseArray {
fn compute_statistics(&self, stat: Stat) -> VortexResult<StatsSet> {
let mut stats = self.values().statistics().compute_all(&[stat])?;
if self.len() == self.values().len() {
return Ok(stats);
}

let fill_len = self.len() - self.values().len();
let fill_stats = if self.fill_value().is_null() {
StatsSet::nulls(fill_len, self.dtype())
} else {
StatsSet::constant(self.fill_scalar(), fill_len)
};

if self.values().is_empty() {
return Ok(fill_stats);
}

stats.merge_unordered(&fill_stats);
Ok(stats)
}
}

impl ArrayValidity for SparseArray {
fn is_valid(&self, index: usize) -> bool {
Expand Down
27 changes: 27 additions & 0 deletions vortex-array/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use vortex_error::VortexResult;

use crate::aliases::hash_set::HashSet;
use crate::encoding::EncodingRef;
use crate::stats::{ArrayStatistics as _, PRUNING_STATS};
use crate::Array;

pub trait CompressionStrategy {
Expand Down Expand Up @@ -45,3 +46,29 @@ pub fn check_dtype_unchanged(arr: &Array, compressed: &Array) {
);
}
}

// Check that compression preserved the statistics.
pub fn check_statistics_unchanged(arr: &Array, compressed: &Array) {
let _ = arr;
let _ = compressed;
#[cfg(debug_assertions)]
{
for (stat, value) in arr.statistics().to_set().into_iter() {
debug_assert_eq!(
compressed.statistics().get(stat),
Some(value.clone()),
"Compression changed {stat} from {value} to {}",
compressed
.statistics()
.get(stat)
.map(|s| s.to_string())
.unwrap_or_else(|| "null".to_string())
);
}
}
}

/// Compute pruning stats for an array.
pub fn compute_pruning_stats(arr: &Array) -> VortexResult<()> {
arr.statistics().compute_all(PRUNING_STATS).map(|_| ())
}
Loading

0 comments on commit c1cee33

Please sign in to comment.