From d63b6fc78ffe971cd23d4052d9bb3fdf27e79f2f Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 27 Nov 2024 18:05:03 +0100 Subject: [PATCH] fix more rules --- datafusion/core/Cargo.toml | 1 + .../tests/user_defined/user_defined_plan.rs | 74 +++--- datafusion/expr/src/logical_plan/plan.rs | 4 +- datafusion/expr/src/logical_plan/tree_node.rs | 2 +- .../src/replace_distinct_aggregate.rs | 233 ++++++++-------- .../optimizer/src/scalar_subquery_to_join.rs | 249 +++++++++--------- .../src/single_distinct_to_groupby.rs | 15 +- 7 files changed, 311 insertions(+), 267 deletions(-) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index d2365280937fa..98cdaf9f8ede1 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -111,6 +111,7 @@ datafusion-physical-expr-common = { workspace = true } datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } +enumset = { workspace = true } flate2 = { version = "1.0.24", optional = true } futures = { workspace = true } glob = "0.3.0" diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 8b313d13e83e1..12021039c1e1c 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -69,6 +69,7 @@ use arrow::{ util::pretty::pretty_format_batches, }; use async_trait::async_trait; +use enumset::enum_set; use futures::{Stream, StreamExt}; use datafusion::execution::session_state::SessionStateBuilder; @@ -97,8 +98,8 @@ use datafusion::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{FetchType, Projection, SortExpr}; -use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; /// Execute the specified sql and return the resulting record batches @@ -343,10 +344,6 @@ impl OptimizerRule for TopKOptimizerRule { "topk" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn supports_rewrite(&self) -> bool { true } @@ -357,38 +354,47 @@ impl OptimizerRule for TopKOptimizerRule { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result, DataFusionError> { - // Note: this code simply looks for the pattern of a Limit followed by a - // Sort and replaces it by a TopK node. It does not handle many - // edge cases (e.g multiple sort columns, sort ASC / DESC), etc. - let LogicalPlan::Limit(ref limit, _) = plan else { - return Ok(Transformed::no(plan)); - }; - let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else { - return Ok(Transformed::no(plan)); - }; + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_all_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanLimit + | LogicalPlanPattern::LogicalPlanSort + )) { + return Ok(Transformed::jump(plan)); + } - if let LogicalPlan::Sort( - Sort { - ref expr, - ref input, - .. - }, - _, - ) = limit.input.as_ref() - { - if expr.len() == 1 { - // we found a sort with a single sort expr, replace with a a TopK - return Ok(Transformed::yes(LogicalPlan::extension(Extension { - node: Arc::new(TopKPlanNode { - k: fetch, - input: input.as_ref().clone(), - expr: expr[0].clone(), - }), - }))); + // Note: this code simply looks for the pattern of a Limit followed by a + // Sort and replaces it by a TopK node. It does not handle many + // edge cases (e.g multiple sort columns, sort ASC / DESC), etc. + let LogicalPlan::Limit(ref limit, _) = plan else { + return Ok(Transformed::no(plan)); + }; + let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else { + return Ok(Transformed::no(plan)); + }; + + if let LogicalPlan::Sort( + Sort { + ref expr, + ref input, + .. + }, + _, + ) = limit.input.as_ref() + { + if expr.len() == 1 { + // we found a sort with a single sort expr, replace with a a TopK + return Ok(Transformed::yes(LogicalPlan::extension(Extension { + node: Arc::new(TopKPlanNode { + k: fetch, + input: input.as_ref().clone(), + expr: expr[0].clone(), + }), + }))); + } } - } - Ok(Transformed::no(plan)) + Ok(Transformed::no(plan)) + }) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index f678c6b234b72..ed4db56910948 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2377,7 +2377,9 @@ impl LogicalPlan { } pub fn distinct(distinct: Distinct) -> Self { - let stats = distinct.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanDistinct)) + .merge(distinct.stats()); LogicalPlan::Distinct(distinct, stats) } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 81c274dac9ea8..390cf6366b3f5 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -113,7 +113,7 @@ pub enum LogicalPlanPattern { // LogicalPlanExplain, // LogicalPlanAnalyze, // LogicalPlanExtension, - // LogicalPlanDistinct, + LogicalPlanDistinct, // LogicalPlanDml, // LogicalPlanDdl, // LogicalPlanCopy, diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index f2500a09104bb..47a29fa732250 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -17,16 +17,16 @@ //! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...` -use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; -use std::sync::Arc; - use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::utils::expand_wildcard; use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; +use std::cell::Cell; +use std::sync::Arc; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -76,115 +76,140 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Distinct(Distinct::All(input), _) => { - let group_expr = expand_wildcard(input.schema(), &input, None)?; - - let field_count = input.schema().fields().len(); - for dep in input.schema().functional_dependencies().iter() { - // If distinct is exactly the same with a previous GROUP BY, we can - // simply remove it: - if dep.source_indices.len() >= field_count - && dep.source_indices[..field_count] - .iter() - .enumerate() - .all(|(idx, f_idx)| idx == *f_idx) - { - return Ok(Transformed::yes(input.as_ref().clone())); - } + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanDistinct) + { + skip.set(true); + return Ok(Transformed::jump(plan)); + } + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); } - // Replace with aggregation: - let aggr_plan = LogicalPlan::aggregate(Aggregate::try_new( - input, - group_expr, - vec![], - )?); - Ok(Transformed::yes(aggr_plan)) - } - LogicalPlan::Distinct( - Distinct::On(DistinctOn { - select_expr, - on_expr, - sort_expr, - input, - schema, - }), - _, - ) => { - let expr_cnt = on_expr.len(); - - // Construct the aggregation expression to be used to fetch the selected expressions. - let first_value_udaf: Arc = - config.function_registry().unwrap().udaf("first_value")?; - let aggr_expr = select_expr.into_iter().map(|e| { - if let Some(order_by) = &sort_expr { - first_value_udaf - .call(vec![e]) - .order_by(order_by.clone()) - .build() - // guaranteed to be `Expr::AggregateFunction` - .unwrap() - } else { - first_value_udaf.call(vec![e]) + match plan { + LogicalPlan::Distinct(Distinct::All(input), _) => { + let group_expr = expand_wildcard(input.schema(), &input, None)?; + + let field_count = input.schema().fields().len(); + for dep in input.schema().functional_dependencies().iter() { + // If distinct is exactly the same with a previous GROUP BY, we can + // simply remove it: + if dep.source_indices.len() >= field_count + && dep.source_indices[..field_count] + .iter() + .enumerate() + .all(|(idx, f_idx)| idx == *f_idx) + { + return Ok(Transformed::yes(input.as_ref().clone())); + } + } + + // Replace with aggregation: + let aggr_plan = LogicalPlan::aggregate(Aggregate::try_new( + input, + group_expr, + vec![], + )?); + Ok(Transformed::yes(aggr_plan)) } - }); - - let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; - let group_expr = normalize_cols(on_expr, input.as_ref())?; - - // Build the aggregation plan - let plan = LogicalPlan::aggregate(Aggregate::try_new( - input, group_expr, aggr_expr, - )?); - // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate - // when https://github.com/apache/datafusion/issues/10485 is available - let lpb = LogicalPlanBuilder::from(plan); - - let plan = if let Some(mut sort_expr) = sort_expr { - // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, - // this on it's own isn't enough to guarantee the proper output order of the grouping - // (`ON`) expression, so we need to sort those as well. - - // truncate the sort_expr to the length of on_expr - sort_expr.truncate(expr_cnt); - - lpb.sort(sort_expr)?.build()? - } else { - lpb.build()? - }; - - // Whereas the aggregation plan by default outputs both the grouping and the aggregation - // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. - - let project_exprs = plan - .schema() - .iter() - .skip(expr_cnt) - .zip(schema.iter()) - .map(|((new_qualifier, new_field), (old_qualifier, old_field))| { - col(Column::from((new_qualifier, new_field))) - .alias_qualified(old_qualifier.cloned(), old_field.name()) - }) - .collect::>(); - - let plan = LogicalPlanBuilder::from(plan) - .project(project_exprs)? - .build()?; - - Ok(Transformed::yes(plan)) - } - _ => Ok(Transformed::no(plan)), - } + LogicalPlan::Distinct( + Distinct::On(DistinctOn { + select_expr, + on_expr, + sort_expr, + input, + schema, + }), + _, + ) => { + let expr_cnt = on_expr.len(); + + // Construct the aggregation expression to be used to fetch the selected expressions. + let first_value_udaf: Arc = + config.function_registry().unwrap().udaf("first_value")?; + let aggr_expr = select_expr.into_iter().map(|e| { + if let Some(order_by) = &sort_expr { + first_value_udaf + .call(vec![e]) + .order_by(order_by.clone()) + .build() + // guaranteed to be `Expr::AggregateFunction` + .unwrap() + } else { + first_value_udaf.call(vec![e]) + } + }); + + let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; + let group_expr = normalize_cols(on_expr, input.as_ref())?; + + // Build the aggregation plan + let plan = LogicalPlan::aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?); + // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate + // when https://github.com/apache/datafusion/issues/10485 is available + let lpb = LogicalPlanBuilder::from(plan); + + let plan = if let Some(mut sort_expr) = sort_expr { + // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, + // this on it's own isn't enough to guarantee the proper output order of the grouping + // (`ON`) expression, so we need to sort those as well. + + // truncate the sort_expr to the length of on_expr + sort_expr.truncate(expr_cnt); + + lpb.sort(sort_expr)?.build()? + } else { + lpb.build()? + }; + + // Whereas the aggregation plan by default outputs both the grouping and the aggregation + // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. + + let project_exprs = plan + .schema() + .iter() + .skip(expr_cnt) + .zip(schema.iter()) + .map( + |( + (new_qualifier, new_field), + (old_qualifier, old_field), + )| { + col(Column::from((new_qualifier, new_field))) + .alias_qualified( + old_qualifier.cloned(), + old_field.name(), + ) + }, + ) + .collect::>(); + + let plan = LogicalPlanBuilder::from(plan) + .project(project_exprs)? + .build()?; + + Ok(Transformed::yes(plan)) + } + _ => Ok(Transformed::no(plan)), + } + }, + ) } fn name(&self) -> &str { "replace_distinct_aggregate" } - - fn apply_order(&self) -> Option { - Some(BottomUp) - } } #[cfg(test)] diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index f9b247fc9a982..a29cc3c951f59 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -17,14 +17,13 @@ //! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s -use std::collections::{BTreeSet, HashMap}; -use std::ops::Not; -use std::sync::Arc; - use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; -use crate::optimizer::ApplyOrder; use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; +use enumset::enum_set; +use std::collections::{BTreeSet, HashMap}; +use std::ops::Not; +use std::sync::Arc; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ @@ -32,6 +31,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; @@ -79,142 +79,151 @@ impl OptimizerRule for ScalarSubqueryToJoin { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Filter(filter, _) => { - // Optimization: skip the rest of the rule and its copies if - // there are no scalar subqueries - if !contains_scalar_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::filter(filter))); - } + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanFilter + | LogicalPlanPattern::LogicalPlanProjection + )) { + return Ok(Transformed::jump(plan)); + } - let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( - &filter.predicate, - config.alias_generator(), - )?; + match plan { + LogicalPlan::Filter(filter, _) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !contains_scalar_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::filter(filter))); + } - if subqueries.is_empty() { - return internal_err!("Expected subqueries not found in filter"); - } + let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( + &filter.predicate, + config.alias_generator(), + )?; - // iterate through all subqueries in predicate, turning each into a left join - let mut cur_input = filter.input.as_ref().clone(); - for (subquery, alias) in subqueries { - if let Some((optimized_subquery, expr_check_map)) = - build_join(&subquery, &cur_input, &alias)? - { - if !expr_check_map.is_empty() { - rewrite_expr = rewrite_expr - .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = expr - .try_as_col() - .and_then(|col| expr_check_map.get(&col.name)) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }) - .data()?; + if subqueries.is_empty() { + return internal_err!("Expected subqueries not found in filter"); + } + + // iterate through all subqueries in predicate, turning each into a left join + let mut cur_input = filter.input.as_ref().clone(); + for (subquery, alias) in subqueries { + if let Some((optimized_subquery, expr_check_map)) = + build_join(&subquery, &cur_input, &alias)? + { + if !expr_check_map.is_empty() { + rewrite_expr = rewrite_expr + .transform_up(|expr| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .and_then(|col| expr_check_map.get(&col.name)) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }) + .data()?; + } + cur_input = optimized_subquery; + } else { + // if we can't handle all of the subqueries then bail for now + return Ok(Transformed::no(LogicalPlan::filter(filter))); } - cur_input = optimized_subquery; - } else { - // if we can't handle all of the subqueries then bail for now - return Ok(Transformed::no(LogicalPlan::filter(filter))); } + let new_plan = LogicalPlanBuilder::from(cur_input) + .filter(rewrite_expr)? + .build()?; + Ok(Transformed::yes(new_plan)) } - let new_plan = LogicalPlanBuilder::from(cur_input) - .filter(rewrite_expr)? - .build()?; - Ok(Transformed::yes(new_plan)) - } - LogicalPlan::Projection(projection, _) => { - // Optimization: skip the rest of the rule and its copies if - // there are no scalar subqueries - if !projection.expr.iter().any(contains_scalar_subquery) { - return Ok(Transformed::no(LogicalPlan::projection(projection))); - } + LogicalPlan::Projection(projection, _) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !projection.expr.iter().any(contains_scalar_subquery) { + return Ok(Transformed::no(LogicalPlan::projection(projection))); + } - let mut all_subqueryies = vec![]; - let mut expr_to_rewrite_expr_map = HashMap::new(); - let mut subquery_to_expr_map = HashMap::new(); - for expr in projection.expr.iter() { - let (subqueries, rewrite_exprs) = - self.extract_subquery_exprs(expr, config.alias_generator())?; - for (subquery, _) in &subqueries { - subquery_to_expr_map.insert(subquery.clone(), expr.clone()); + let mut all_subqueryies = vec![]; + let mut expr_to_rewrite_expr_map = HashMap::new(); + let mut subquery_to_expr_map = HashMap::new(); + for expr in projection.expr.iter() { + let (subqueries, rewrite_exprs) = + self.extract_subquery_exprs(expr, config.alias_generator())?; + for (subquery, _) in &subqueries { + subquery_to_expr_map.insert(subquery.clone(), expr.clone()); + } + all_subqueryies.extend(subqueries); + expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); } - all_subqueryies.extend(subqueries); - expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); - } - if all_subqueryies.is_empty() { - return internal_err!("Expected subqueries not found in projection"); - } - // iterate through all subqueries in predicate, turning each into a left join - let mut cur_input = projection.input.as_ref().clone(); - for (subquery, alias) in all_subqueryies { - if let Some((optimized_subquery, expr_check_map)) = - build_join(&subquery, &cur_input, &alias)? - { - cur_input = optimized_subquery; - if !expr_check_map.is_empty() { - if let Some(expr) = subquery_to_expr_map.get(&subquery) { - if let Some(rewrite_expr) = - expr_to_rewrite_expr_map.get(expr) - { - let new_expr = rewrite_expr - .clone() - .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = - expr.try_as_col().and_then(|col| { - expr_check_map.get(&col.name) - }) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }) - .data()?; - expr_to_rewrite_expr_map.insert(expr, new_expr); + if all_subqueryies.is_empty() { + return internal_err!( + "Expected subqueries not found in projection" + ); + } + // iterate through all subqueries in predicate, turning each into a left join + let mut cur_input = projection.input.as_ref().clone(); + for (subquery, alias) in all_subqueryies { + if let Some((optimized_subquery, expr_check_map)) = + build_join(&subquery, &cur_input, &alias)? + { + cur_input = optimized_subquery; + if !expr_check_map.is_empty() { + if let Some(expr) = subquery_to_expr_map.get(&subquery) { + if let Some(rewrite_expr) = + expr_to_rewrite_expr_map.get(expr) + { + let new_expr = rewrite_expr + .clone() + .transform_up(|expr| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = + expr.try_as_col().and_then(|col| { + expr_check_map.get(&col.name) + }) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }) + .data()?; + expr_to_rewrite_expr_map.insert(expr, new_expr); + } } } + } else { + // if we can't handle all of the subqueries then bail for now + return Ok(Transformed::no(LogicalPlan::projection( + projection, + ))); } - } else { - // if we can't handle all of the subqueries then bail for now - return Ok(Transformed::no(LogicalPlan::projection(projection))); } - } - let mut proj_exprs = vec![]; - for expr in projection.expr.iter() { - let old_expr_name = expr.schema_name().to_string(); - let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap(); - let new_expr_name = new_expr.schema_name().to_string(); - if new_expr_name != old_expr_name { - proj_exprs.push(new_expr.clone().alias(old_expr_name)) - } else { - proj_exprs.push(new_expr.clone()); + let mut proj_exprs = vec![]; + for expr in projection.expr.iter() { + let old_expr_name = expr.schema_name().to_string(); + let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap(); + let new_expr_name = new_expr.schema_name().to_string(); + if new_expr_name != old_expr_name { + proj_exprs.push(new_expr.clone().alias(old_expr_name)) + } else { + proj_exprs.push(new_expr.clone()); + } } + let new_plan = LogicalPlanBuilder::from(cur_input) + .project(proj_exprs)? + .build()?; + Ok(Transformed::yes(new_plan)) } - let new_plan = LogicalPlanBuilder::from(cur_input) - .project(proj_exprs)? - .build()?; - Ok(Transformed::yes(new_plan)) - } - plan => Ok(Transformed::no(plan)), - } + plan => Ok(Transformed::no(plan)), + } + }) } fn name(&self) -> &str { "scalar_subquery_to_join" } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } } /// Returns true if the expression has a scalar subquery somewhere in it diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 1cefe352fa939..d96e8447f106a 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -19,13 +19,13 @@ use std::sync::Arc; -use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{ internal_err, tree_node::Transformed, DataFusionError, HashSet, Result, }; use datafusion_expr::builder::project; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{ col, expr::AggregateFunction, @@ -109,10 +109,6 @@ impl OptimizerRule for SingleDistinctToGroupBy { "single_distinct_aggregation_to_group_by" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn supports_rewrite(&self) -> bool { true } @@ -122,7 +118,12 @@ impl OptimizerRule for SingleDistinctToGroupBy { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result, DataFusionError> { - match plan { + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_pattern(LogicalPlanPattern::LogicalPlanAggregate) { + return Ok(Transformed::jump(plan)); + } + + match plan { LogicalPlan::Aggregate( Aggregate { input, @@ -277,7 +278,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { Ok(Transformed::yes(project(outer_aggr, alias_expr)?)) } _ => Ok(Transformed::no(plan)), - } + }}) } }