Skip to content

Commit

Permalink
Fix incorrect results with multiple COUNT(DISTINCT..) aggregates on…
Browse files Browse the repository at this point in the history
… dictionaries (apache#9679)

* Add test for multiple count distincts on a dictionary

* Fix accumulator merge bug

* Fix cleanup code
  • Loading branch information
alamb authored and wiedld committed Apr 1, 2024
1 parent 9b6da0a commit 0219e89
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 8 deletions.
2 changes: 1 addition & 1 deletion datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1746,7 +1746,7 @@ impl ScalarValue {
}

/// Converts `Vec<ScalarValue>` where each element has type corresponding to
/// `data_type`, to a [`ListArray`].
/// `data_type`, to a single element [`ListArray`].
///
/// Example
/// ```
Expand Down
32 changes: 25 additions & 7 deletions datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use crate::binary_map::OutputType;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};

/// Expression for a COUNT(DISTINCT) aggregation.
/// Expression for a `COUNT(DISTINCT)` aggregation.
#[derive(Debug)]
pub struct DistinctCount {
/// Column name
Expand Down Expand Up @@ -100,6 +100,7 @@ impl AggregateExpr for DistinctCount {
use TimeUnit::*;

Ok(match &self.state_data_type {
// Try to use a specialized accumulator if possible, otherwise fall back to the generic accumulator
Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new()),
Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new()),
Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new()),
Expand Down Expand Up @@ -157,6 +158,7 @@ impl AggregateExpr for DistinctCount {
OutputType::Binary,
)),

// Use the generic accumulator based on `ScalarValue` for all other types
_ => Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
Expand All @@ -183,7 +185,11 @@ impl PartialEq<dyn Any> for DistinctCount {
}

/// General purpose distinct accumulator that works for any DataType by using
/// [`ScalarValue`]. Some types have specialized accumulators that are (much)
/// [`ScalarValue`].
///
/// It stores intermediate results as a `ListArray`.
///
/// Note that many types have specialized accumulators that are (much)
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
/// [`BytesDistinctCountAccumulator`]
#[derive(Debug)]
Expand All @@ -193,8 +199,9 @@ struct DistinctCountAccumulator {
}

impl DistinctCountAccumulator {
// calculating the size for fixed length values, taking first batch size * number of batches
// This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types
// calculating the size for fixed length values, taking first batch size *
// number of batches. This method is faster than .full_size(), however it is
// not suitable for variable length values like strings or complex types
fn fixed_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
Expand All @@ -207,7 +214,8 @@ impl DistinctCountAccumulator {
+ std::mem::size_of::<DataType>()
}

// calculates the size as accurate as possible, call to this method is expensive
// calculates the size as accurately as possible. Note that calling this
// method is expensive
fn full_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
Expand All @@ -221,6 +229,7 @@ impl DistinctCountAccumulator {
}

impl Accumulator for DistinctCountAccumulator {
/// Returns the distinct values seen so far as a single-element `ListArray`.
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let scalars = self.values.iter().cloned().collect::<Vec<_>>();
let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type);
Expand All @@ -246,15 +255,24 @@ impl Accumulator for DistinctCountAccumulator {
})
}

/// Merges multiple sets of distinct values into the current set.
///
/// The input to this function is a `ListArray` with **multiple** rows,
/// where each row contains the distinct values from one partial aggregate
/// (e.g. the results of calling `Self::state` on multiple accumulators).
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(states.len(), 1, "array_agg states must be singleton!");
let array = &states[0];
let list_array = array.as_list::<i32>();
let inner_array = list_array.value(0);
self.update_batch(&[inner_array])
for inner_array in list_array.iter() {
let inner_array = inner_array
.expect("counts are always non null, so are intermediate results");
self.update_batch(&[inner_array])?;
}
Ok(())
}

fn evaluate(&mut self) -> Result<ScalarValue> {
Expand Down
67 changes: 67 additions & 0 deletions datafusion/sqllogictest/test_files/dictionary.slt
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,70 @@ ORDER BY
2023-12-20T01:20:00 1000 f2 foo
2023-12-20T01:30:00 1000 f1 32.0
2023-12-20T01:30:00 1000 f2 foo

# Cleanup
statement ok
drop view m1;

statement ok
drop view m2;

######
# Create a table using UNION ALL to get 2 partitions (very important)
######
statement ok
create table m3_source as
select * from (values('foo', 'bar', 1))
UNION ALL
select * from (values('foo', 'baz', 1));

######
# Now, create a table with the same data, but column2 has type `Dictionary(Int32, Utf8)` to trigger the fallback code
######
statement ok
create table m3 as
select
column1,
arrow_cast(column2, 'Dictionary(Int32, Utf8)') as "column2",
column3
from m3_source;

# there are two values in column2
query T?I rowsort
SELECT *
FROM m3;
----
foo bar 1
foo baz 1

# There is 1 distinct value in column1
query I
SELECT count(distinct column1)
FROM m3
GROUP BY column3;
----
1

# There are 2 distinct values in column2
query I
SELECT count(distinct column2)
FROM m3
GROUP BY column3;
----
2

# Should still get the same results when querying in the same query
query II
SELECT count(distinct column1), count(distinct column2)
FROM m3
GROUP BY column3;
----
1 2


# Cleanup
statement ok
drop table m3;

statement ok
drop table m3_source;

0 comments on commit 0219e89

Please sign in to comment.