Skip to content

Commit

Permalink
Relax combine partial final rule (apache#10913)
Browse files Browse the repository at this point in the history
* Minor changes

* Minor changes

* Re-introduce group by expression check
  • Loading branch information
mustafasrepo authored and xinlifoobar committed Jun 22, 2024
1 parent 0658959 commit 832e58b
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 61 deletions.
66 changes: 17 additions & 49 deletions datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ use crate::physical_plan::ExecutionPlan;

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
use datafusion_physical_expr::{physical_exprs_equal, AggregateExpr, PhysicalExpr};

/// CombinePartialFinalAggregate optimizer rule combines adjacent Partial and Final AggregateExecs
/// into a Single AggregateExec if their grouping exprs and aggregate exprs are equal.
Expand Down Expand Up @@ -132,19 +131,23 @@ type GroupExprsRef<'a> = (
&'a [Option<Arc<dyn PhysicalExpr>>],
);

/// Owned form of the group-expression triple: the group-by definition,
/// the aggregate expressions, and one optional filter expression per entry
/// of the filter list.
type GroupExprs = (
PhysicalGroupBy,
Vec<Arc<dyn AggregateExpr>>,
Vec<Option<Arc<dyn PhysicalExpr>>>,
);

fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
let (final_group_by, final_aggr_expr, final_filter_expr) =
normalize_group_exprs(final_agg);
let (input_group_by, input_aggr_expr, input_filter_expr) =
normalize_group_exprs(partial_agg);

final_group_by.eq(&input_group_by)
let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg;
let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg;

// Compare output expressions of the partial, and input expressions of the final operator.
physical_exprs_equal(
&input_group_by.output_exprs(),
&final_group_by.input_exprs(),
) && input_group_by.groups() == final_group_by.groups()
&& input_group_by.null_expr().len() == final_group_by.null_expr().len()
&& input_group_by
.null_expr()
.iter()
.zip(final_group_by.null_expr().iter())
.all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| {
lhs_expr.eq(rhs_expr) && lhs_str == rhs_str
})
&& final_aggr_expr.len() == input_aggr_expr.len()
&& final_aggr_expr
.iter()
Expand All @@ -160,41 +163,6 @@ fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
)
}

// To compare the group expressions between the final and partial aggregations, need to discard all the column indexes and compare
fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs {
let (group, agg, filter) = group_exprs;
let new_group_expr = group
.expr()
.iter()
.map(|(expr, name)| (discard_column_index(expr.clone()), name.clone()))
.collect::<Vec<_>>();
let new_group = PhysicalGroupBy::new(
new_group_expr,
group.null_expr().to_vec(),
group.groups().to_vec(),
);
(new_group, agg.to_vec(), filter.to_vec())
}

/// Returns `group_expr` with every `Column` node replaced by a column of the
/// same name at index 0; all other nodes are left as they are.
///
/// If the tree transformation reports an error, the original expression is
/// returned unchanged.
fn discard_column_index(group_expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
    let rewritten = Arc::clone(&group_expr).transform(|node| {
        if let Some(col) = node.as_any().downcast_ref::<Column>() {
            // Keep the name, drop the positional index.
            let zeroed: Arc<dyn PhysicalExpr> = Arc::new(Column::new(col.name(), 0));
            Ok(Transformed::yes(zeroed))
        } else {
            Ok(Transformed::no(node))
        }
    });

    rewritten.data().unwrap_or(group_expr)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
71 changes: 71 additions & 0 deletions datafusion/sqllogictest/test_files/group_by.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5064,3 +5064,74 @@ statement error DataFusion error: Error during planning: Cannot find column with
SELECT a, b, COUNT(1)
FROM multiple_ordered_table
GROUP BY 1, 2, 4, 5, 6;

statement ok
set datafusion.execution.target_partitions = 1;

# Create a table that contains various keywords, with their corresponding timestamps
statement ok
CREATE TABLE keywords_stream (
ts TIMESTAMP,
sn INTEGER PRIMARY KEY,
keyword VARCHAR NOT NULL
);

statement ok
INSERT INTO keywords_stream(ts, sn, keyword) VALUES
('2024-01-01T00:00:00Z', '0', 'Drug'),
('2024-01-01T00:00:05Z', '1', 'Bomb'),
('2024-01-01T00:00:10Z', '2', 'Theft'),
('2024-01-01T00:00:15Z', '3', 'Gun'),
('2024-01-01T00:00:20Z', '4', 'Calm');

# Create a table that contains alert keywords
statement ok
CREATE TABLE ALERT_KEYWORDS(keyword VARCHAR NOT NULL);

statement ok
INSERT INTO ALERT_KEYWORDS VALUES
('Drug'),
('Bomb'),
('Theft'),
('Gun'),
('Knife'),
('Fire');

query TT
explain SELECT
DATE_BIN(INTERVAL '2' MINUTE, ts, '2000-01-01') AS ts_chunk,
COUNT(keyword) AS alert_keyword_count
FROM
keywords_stream
WHERE
keywords_stream.keyword IN (SELECT keyword FROM ALERT_KEYWORDS)
GROUP BY
ts_chunk;
----
logical_plan
01)Projection: date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01")) AS ts_chunk, COUNT(keywords_stream.keyword) AS alert_keyword_count
02)--Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"), keywords_stream.ts, TimestampNanosecond(946684800000000000, None)) AS date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))]], aggr=[[COUNT(keywords_stream.keyword)]]
03)----LeftSemi Join: keywords_stream.keyword = __correlated_sq_1.keyword
04)------TableScan: keywords_stream projection=[ts, keyword]
05)------SubqueryAlias: __correlated_sq_1
06)--------TableScan: alert_keywords projection=[keyword]
physical_plan
01)ProjectionExec: expr=[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))@0 as ts_chunk, COUNT(keywords_stream.keyword)@1 as alert_keyword_count]
02)--AggregateExec: mode=Single, gby=[date_bin(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }, ts@0, 946684800000000000) as date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))], aggr=[COUNT(keywords_stream.keyword)]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(keyword@0, keyword@1)]
05)--------MemoryExec: partitions=1, partition_sizes=[1]
06)--------MemoryExec: partitions=1, partition_sizes=[1]

query PI
SELECT
DATE_BIN(INTERVAL '2' MINUTE, ts, '2000-01-01') AS ts_chunk,
COUNT(keyword) AS alert_keyword_count
FROM
keywords_stream
WHERE
keywords_stream.keyword IN (SELECT keyword FROM ALERT_KEYWORDS)
GROUP BY
ts_chunk;
----
2024-01-01T00:00:00 4
23 changes: 11 additions & 12 deletions datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1382,18 +1382,17 @@ physical_plan
02)--AggregateExec: mode=Final, gby=[], aggr=[COUNT(alias1)]
03)----CoalescePartitionsExec
04)------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(alias1)]
05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[]
06)----------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[]
07)------------CoalesceBatchesExec: target_batch_size=2
08)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
09)----------------CoalesceBatchesExec: target_batch_size=2
10)------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
12)----------------------MemoryExec: partitions=1, partition_sizes=[1]
13)----------------CoalesceBatchesExec: target_batch_size=2
14)------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
15)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
16)----------------------MemoryExec: partitions=1, partition_sizes=[1]
05)--------AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as alias1], aggr=[]
06)----------CoalesceBatchesExec: target_batch_size=2
07)------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
08)--------------CoalesceBatchesExec: target_batch_size=2
09)----------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
10)------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
11)--------------------MemoryExec: partitions=1, partition_sizes=[1]
12)--------------CoalesceBatchesExec: target_batch_size=2
13)----------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
14)------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
15)--------------------MemoryExec: partitions=1, partition_sizes=[1]

statement ok
set datafusion.explain.logical_plan_only = true;
Expand Down

0 comments on commit 832e58b

Please sign in to comment.