Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change array_agg to return null on no input rather than empty list #11299

Merged
merged 13 commits into from
Jul 10, 2024
10 changes: 10 additions & 0 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1981,6 +1981,16 @@ impl ScalarValue {
Self::new_list(values, data_type, true)
}

/// Creates a `ScalarValue::List` whose underlying `ListArray` consists of
/// `null_len` NULL entries with the given element `data_type`.
///
/// `nullable` is recorded as the nullability of the list's item field.
///
/// Example: `new_null_list(DataType::Int32, true, 1)` yields `ListArray[NULL]`
/// — a single-element list array whose one entry is NULL.
pub fn new_null_list(data_type: DataType, nullable: bool, null_len: usize) -> Self {
// Wrap the element type in a List field; `ArrayData::new_null` then builds
// an all-null array of that list type with `null_len` slots.
let data_type = DataType::List(Field::new_list_field(data_type, nullable).into());
Self::List(Arc::new(ListArray::from(ArrayData::new_null(
&data_type, null_len,
))))
}

/// Converts `IntoIterator<Item = ScalarValue>` where each element has type corresponding to
/// `data_type`, to a [`ListArray`].
///
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ async fn unnest_with_redundant_columns() -> Result<()> {
let expected = vec![
"Projection: shapes.shape_id [shape_id:UInt32]",
" Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} });N]",
" TableScan: shapes projection=[shape_id] [shape_id:UInt32]",
];

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
Schema::new(vec![Field::new_list(
"ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
Field::new("item", DataType::UInt32, false),
false
true
),])
);

Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ impl AggregateFunction {
/// Returns whether the result of this built-in aggregate function is nullable.
///
/// All variants handled here report `true`; for `ArrayAgg` this reflects
/// that it now returns NULL (rather than an empty list) when it receives
/// no input rows.
pub fn nullable(&self) -> Result<bool> {
match self {
AggregateFunction::Max | AggregateFunction::Min => Ok(true),
AggregateFunction::ArrayAgg => Ok(true),
}
}
}
Expand Down
17 changes: 11 additions & 6 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl AggregateExpr for ArrayAgg {
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
false,
true,
))
}

Expand All @@ -86,7 +86,7 @@ impl AggregateExpr for ArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
false,
true,
)])
}

Expand Down Expand Up @@ -137,8 +137,11 @@ impl Accumulator for ArrayAggAccumulator {
return Ok(());
}
assert!(values.len() == 1, "array_agg can only take 1 param!");

let val = Arc::clone(&values[0]);
self.values.push(val);
if val.len() > 0 {
self.values.push(val);
}
Ok(())
}

Expand All @@ -162,13 +165,15 @@ impl Accumulator for ArrayAggAccumulator {

fn evaluate(&mut self) -> Result<ScalarValue> {
// Transform Vec<ListArr> to ListArr

let element_arrays: Vec<&dyn Array> =
self.values.iter().map(|a| a.as_ref()).collect();

if element_arrays.is_empty() {
let arr = ScalarValue::new_list(&[], &self.datatype, self.nullable);
return Ok(ScalarValue::List(arr));
return Ok(ScalarValue::new_null_list(
self.datatype.clone(),
self.nullable,
1,
));
}

let concated_array = arrow::compute::concat(&element_arrays)?;
Expand Down
11 changes: 9 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl AggregateExpr for DistinctArrayAgg {
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
false,
true,
))
}

Expand All @@ -90,7 +90,7 @@ impl AggregateExpr for DistinctArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "distinct_array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
false,
true,
)])
}

Expand Down Expand Up @@ -165,6 +165,13 @@ impl Accumulator for DistinctArrayAggAccumulator {

/// Produces the final DISTINCT `array_agg` result.
///
/// When nothing was accumulated (no qualifying input rows) this returns a
/// one-entry null list — i.e. the aggregate evaluates to NULL — matching
/// the behavior introduced for plain `array_agg`. Otherwise it returns the
/// collected distinct values as a list.
fn evaluate(&mut self) -> Result<ScalarValue> {
    // Check for the empty case first so we skip the clone/collect entirely.
    if self.values.is_empty() {
        return Ok(ScalarValue::new_null_list(
            self.datatype.clone(),
            self.nullable,
            1,
        ));
    }
    let collected: Vec<ScalarValue> = self.values.iter().cloned().collect();
    Ok(ScalarValue::List(ScalarValue::new_list(
        &collected,
        &self.datatype,
        self.nullable,
    )))
}
Expand Down
12 changes: 10 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
false,
true,
))
}

Expand All @@ -111,7 +111,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
let mut fields = vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
false, // This should be the same as field()
true, // This should be the same as field()
)];
let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types);
fields.push(Field::new_list(
Expand Down Expand Up @@ -309,6 +309,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
if self.values.is_empty() {
return Ok(ScalarValue::new_null_list(
self.datatypes[0].clone(),
self.nullable,
1,
));
}

let values = self.values.clone();
let array = if self.reverse {
ScalarValue::new_list_from_iter(
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ mod tests {
Field::new_list(
"c1",
Field::new("item", data_type.clone(), true),
false,
true,
),
result_agg_phy_exprs.field().unwrap()
);
Expand All @@ -167,7 +167,7 @@ mod tests {
Field::new_list(
"c1",
Field::new("item", data_type.clone(), true),
false,
true,
),
result_agg_phy_exprs.field().unwrap()
);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-plan/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl FilterExec {

Ok(Self {
predicate,
input: input.clone(),
input: Arc::clone(&input),
metrics: ExecutionPlanMetricsSet::new(),
default_selectivity,
cache,
Expand Down
155 changes: 116 additions & 39 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1694,7 +1694,7 @@ SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT
query ?
SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test
----
[]
NULL

# csv_query_array_agg_one
query ?
Expand Down Expand Up @@ -1753,31 +1753,12 @@ NULL 4 29 1.260869565217 123 -117 23
NULL 5 -194 -13.857142857143 118 -101 14
NULL NULL 781 7.81 125 -117 100

# TODO: array_agg_distinct output is non-deterministic -- rewrite with array_sort(list_sort)
# unnest is also not available, so manually unnesting via CROSS JOIN
# additional count(1) forces array_agg_distinct instead of array_agg over aggregated by c2 data
#
# select with count to forces array_agg_distinct function, since single distinct expression is converted to group by by optimizer
# csv_query_array_agg_distinct
query III
WITH indices AS (
SELECT 1 AS idx UNION ALL
SELECT 2 AS idx UNION ALL
SELECT 3 AS idx UNION ALL
SELECT 4 AS idx UNION ALL
SELECT 5 AS idx
)
SELECT data.arr[indices.idx] as element, array_length(data.arr) as array_len, dummy
FROM (
SELECT array_agg(distinct c2) as arr, count(1) as dummy FROM aggregate_test_100
) data
CROSS JOIN indices
ORDER BY 1
----
1 5 100
2 5 100
3 5 100
4 5 100
5 5 100
Comment on lines -1761 to -1780
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rewrote it as a simpler one!

query ?I
SELECT array_sort(array_agg(distinct c2)), count(1) FROM aggregate_test_100
----
[1, 2, 3, 4, 5] 100

# aggregate_time_min_and_max
query TT
Expand Down Expand Up @@ -2732,6 +2713,16 @@ SELECT COUNT(DISTINCT c1) FROM test

# TODO: aggregate_with_alias

# test_approx_percentile_cont_decimal_support
query TI
SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
----
a 4
b 5
c 4
d 4
e 4

# array_agg_zero
query ?
SELECT ARRAY_AGG([])
Expand All @@ -2744,28 +2735,114 @@ SELECT ARRAY_AGG([1])
----
[[1]]

# test_approx_percentile_cont_decimal_support
query TI
SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
# test array_agg with no row qualified
statement ok
create table t(a int, b float, c bigint) as values (1, 1.2, 2);

# returns NULL, follows DuckDB's behaviour
query ?
select array_agg(a) from t where a > 2;
----
a 4
b 5
c 4
d 4
e 4
NULL

query ?
select array_agg(b) from t where b > 3.1;
----
NULL

# array_agg_zero
query ?
SELECT ARRAY_AGG([]);
select array_agg(c) from t where c > 3;
----
[[]]
NULL

# array_agg_one
query ?I
select array_agg(c), count(1) from t where c > 3;
----
NULL 0

# returns 0 rows if group by is applied, follows DuckDB's behaviour
query ?
SELECT ARRAY_AGG([1]);
select array_agg(a) from t where a > 3 group by a;
----
[[1]]

query ?I
select array_agg(a), count(1) from t where a > 3 group by a;
----

# returns NULL, follows DuckDB's behaviour
query ?
select array_agg(distinct a) from t where a > 3;
alamb marked this conversation as resolved.
Show resolved Hide resolved
----
NULL

query ?I
select array_agg(distinct a), count(1) from t where a > 3;
----
NULL 0

# returns 0 rows if group by is applied, follows DuckDB's behaviour
query ?
select array_agg(distinct a) from t where a > 3 group by a;
----

query ?I
select array_agg(distinct a), count(1) from t where a > 3 group by a;
----

# test order sensitive array agg
query ?
select array_agg(a order by a) from t where a > 3;
----
NULL

query ?
select array_agg(a order by a) from t where a > 3 group by a;
----

query ?I
select array_agg(a order by a), count(1) from t where a > 3 group by a;
----

statement ok
drop table t;

# test with no values
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add an `array_agg(distinct ...)` case on an empty table

statement ok
create table t(a int, b float, c bigint);

query ?
select array_agg(a) from t;
----
NULL

query ?
select array_agg(b) from t;
----
NULL

query ?
select array_agg(c) from t;
----
NULL

query ?I
select array_agg(distinct a), count(1) from t;
----
NULL 0

query ?I
select array_agg(distinct b), count(1) from t;
----
NULL 0

query ?I
select array_agg(distinct b), count(1) from t;
----
NULL 0

statement ok
drop table t;


# array_agg_i32
statement ok
Expand Down