Skip to content

Commit

Permalink
Some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gatesn committed Dec 18, 2024
1 parent d82c095 commit 1340d8d
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 17 deletions.
6 changes: 5 additions & 1 deletion vortex-array/src/array/list/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ impl ScalarAtFn<ListArray> for ListEncoding {
let elem = array.elements_at(index)?;
let scalars: Vec<Scalar> = (0..elem.len()).map(|i| scalar_at(&elem, i)).try_collect()?;

Ok(Scalar::list(Arc::new(elem.dtype().clone()), scalars))
Ok(Scalar::list(
Arc::new(elem.dtype().clone()),
scalars,
array.dtype().nullability(),
))
}
}

Expand Down
20 changes: 16 additions & 4 deletions vortex-array/src/array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ impl ValidityVTable<ListArray> for ListEncoding {
mod test {
use std::sync::Arc;

use vortex_dtype::PType;
use vortex_dtype::{Nullability, PType};
use vortex_scalar::Scalar;

use crate::array::list::ListArray;
Expand Down Expand Up @@ -228,15 +228,27 @@ mod test {
ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

assert_eq!(
Scalar::list(Arc::new(PType::I32.into()), vec![1.into(), 2.into()]),
Scalar::list(
Arc::new(PType::I32.into()),
vec![1.into(), 2.into()],
Nullability::Nullable
),
scalar_at(&list, 0).unwrap()
);
assert_eq!(
Scalar::list(Arc::new(PType::I32.into()), vec![3.into(), 4.into()]),
Scalar::list(
Arc::new(PType::I32.into()),
vec![3.into(), 4.into()],
Nullability::Nullable
),
scalar_at(&list, 1).unwrap()
);
assert_eq!(
Scalar::list(Arc::new(PType::I32.into()), vec![5.into()]),
Scalar::list(
Arc::new(PType::I32.into()),
vec![5.into()],
Nullability::Nullable
),
scalar_at(&list, 2).unwrap()
);
}
Expand Down
16 changes: 13 additions & 3 deletions vortex-array/src/builders/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,27 @@ mod tests {

builder
.append_value(
Scalar::list(dtype.clone(), vec![1i32.into(), 2i32.into(), 3i32.into()]).as_list(),
Scalar::list(
dtype.clone(),
vec![1i32.into(), 2i32.into(), 3i32.into()],
Nullability::NonNullable,
)
.as_list(),
)
.unwrap();

builder
.append_value(Scalar::empty(dtype.clone()).as_list())
.append_value(Scalar::list_empty(dtype.clone(), Nullability::NonNullable).as_list())
.unwrap();

builder
.append_value(
Scalar::list(dtype, vec![4i32.into(), 5i32.into(), 6i32.into()]).as_list(),
Scalar::list(
dtype,
vec![4i32.into(), 5i32.into(), 6i32.into()],
Nullability::NonNullable,
)
.as_list(),
)
.unwrap();

Expand Down
5 changes: 4 additions & 1 deletion vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,10 @@ pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
Operator::Lte => lhs <= rhs,
};

Scalar::bool(b, Nullability::Nullable)
Scalar::bool(
b,
(lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(),
)
}
}

Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/compute/scalar_at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ pub fn scalar_at(array: impl AsRef<ArrayData>, index: usize) -> VortexResult<Sca
.unwrap_or_else(|| Err(vortex_err!(NotImplemented: "scalar_at", array.encoding().id())))?;

debug_assert_eq!(
scalar.dtype(),
array.dtype(),
scalar.dtype(),
"ScalarAt dtype mismatch {}",
array.encoding().id()
);
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/data/viewed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ impl Statistics for ViewedArrayData {
.stats()?
.bit_width_freq()
.map(|v| v.iter().map(Scalar::from).collect_vec())
.map(|v| Scalar::list(element_dtype, v))
.map(|v| Scalar::list(element_dtype, v, Nullability::NonNullable))
}
Stat::TrailingZeroFreq => self
.flatbuffer()
Expand Down
15 changes: 9 additions & 6 deletions vortex-scalar/src/list.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::ops::Deref;
use std::sync::Arc;

use vortex_dtype::DType;
use vortex_dtype::Nullability::{NonNullable, Nullable};
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexResult};

use crate::value::{InnerScalarValue, ScalarValue};
Expand Down Expand Up @@ -72,7 +71,11 @@ impl<'a> ListScalar<'a> {
}

impl Scalar {
pub fn list(element_dtype: Arc<DType>, children: Vec<Scalar>) -> Self {
pub fn list(
element_dtype: Arc<DType>,
children: Vec<Scalar>,
nullability: Nullability,
) -> Self {
for child in &children {
if child.dtype() != &*element_dtype {
vortex_panic!(
Expand All @@ -83,16 +86,16 @@ impl Scalar {
}
}
Self {
dtype: DType::List(element_dtype, NonNullable),
dtype: DType::List(element_dtype, nullability),
value: ScalarValue(InnerScalarValue::List(
children.into_iter().map(|x| x.value).collect::<Arc<[_]>>(),
)),
}
}

pub fn empty(element_dtype: Arc<DType>) -> Self {
pub fn list_empty(element_dtype: Arc<DType>, nullability: Nullability) -> Self {
Self {
dtype: DType::List(element_dtype, Nullable),
dtype: DType::List(element_dtype, nullability),
value: ScalarValue(InnerScalarValue::Null),
}
}
Expand Down

0 comments on commit 1340d8d

Please sign in to comment.