diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 912383e7da798..816a0e134eaf9 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -54,11 +54,12 @@ use datafusion_common::{ FunctionalDependencies, ParamValues, Result, ScalarValue, TableReference, UnnestOptions, }; +use enumset::enum_set; use indexmap::IndexSet; // backwards compatibility use crate::display::PgJsonVisitor; -use crate::logical_plan::tree_node::LogicalPlanStats; +use crate::logical_plan::tree_node::{LogicalPlanPattern, LogicalPlanStats}; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -2277,7 +2278,9 @@ impl LogicalPlan { } pub fn filter(filter: Filter) -> Self { - let stats = filter.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanFilter)) + .merge(filter.stats()); LogicalPlan::Filter(filter, stats) } @@ -2292,17 +2295,21 @@ impl LogicalPlan { } pub fn aggregate(aggregate: Aggregate) -> Self { - let stats = aggregate.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanAggregate)) + .merge(aggregate.stats()); LogicalPlan::Aggregate(aggregate, stats) } pub fn sort(sort: Sort) -> Self { - let stats = sort.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanSort)) + .merge(sort.stats()); LogicalPlan::Sort(sort, stats) } pub fn join(join: Join) -> Self { - let stats = join.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanJoin)) + .merge(join.stats()); LogicalPlan::Join(join, stats) } @@ -2312,7 +2319,9 @@ impl LogicalPlan { } pub fn union(projection: Union) -> Self { - let stats = projection.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanUnion)) + .merge(projection.stats()); LogicalPlan::Union(projection, stats) } @@ -2332,7 +2341,9 @@ impl LogicalPlan { } pub fn limit(limit: Limit) -> Self { - let stats = limit.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanLimit)) + .merge(limit.stats()); LogicalPlan::Limit(limit, stats) } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 7a1db7a94dc14..5a115dd9431c5 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -95,19 +95,19 @@ pub enum LogicalPlanPattern { // /// [`LogicalPlan`] // LogicalPlanProjection, - // LogicalPlanFilter, + LogicalPlanFilter, // LogicalPlanWindow, - // LogicalPlanAggregate, - // LogicalPlanSort, - // LogicalPlanJoin, + LogicalPlanAggregate, + LogicalPlanSort, + LogicalPlanJoin, // LogicalPlanCrossJoin, // LogicalPlanRepartition, - // LogicalPlanUnion, + LogicalPlanUnion, // LogicalPlanTableScan, // LogicalPlanEmptyRelation, // LogicalPlanSubquery, // LogicalPlanSubqueryAlias, - // LogicalPlanLimit, + LogicalPlanLimit, // LogicalPlanStatement, // LogicalPlanValues, // LogicalPlanExplain, diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 4fca0dfd4c4a6..ce99603356dd1 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -21,10 +21,14 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{Aggregate, Expr, Sort, SortExpr}; +use enumset::enum_set; use indexmap::IndexSet; +use std::cell::Cell; use std::hash::{Hash, Hasher}; + /// Optimization rule that eliminate duplicated expr. #[derive(Default, Debug)] pub struct EliminateDuplicatedExpr; @@ -49,10 +53,6 @@ impl Hash for SortExprWrapper { } } impl OptimizerRule for EliminateDuplicatedExpr { - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn supports_rewrite(&self) -> bool { true } @@ -62,51 +62,60 @@ impl OptimizerRule for EliminateDuplicatedExpr { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Sort(sort, _) => { - let len = sort.expr.len(); - let unique_exprs: Vec<_> = sort - .expr - .into_iter() - .map(SortExprWrapper) - .collect::>() - .into_iter() - .map(|wrapper| wrapper.0) - .collect(); + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanSort + | LogicalPlanPattern::LogicalPlanAggregate + )) { + return Ok(Transformed::jump(plan)); + } - let transformed = if len != unique_exprs.len() { - Transformed::yes - } else { - Transformed::no - }; + match plan { + LogicalPlan::Sort(sort, _) => { + let len = sort.expr.len(); + let unique_exprs: Vec<_> = sort + .expr + .into_iter() + .map(SortExprWrapper) + .collect::>() + .into_iter() + .map(|wrapper| wrapper.0) + .collect(); - Ok(transformed(LogicalPlan::sort(Sort { - expr: unique_exprs, - input: sort.input, - fetch: sort.fetch, - }))) - } - LogicalPlan::Aggregate(agg, _) => { - let len = agg.group_expr.len(); + let transformed = if len != unique_exprs.len() { + Transformed::yes + } else { + Transformed::no + }; + + Ok(transformed(LogicalPlan::sort(Sort { + expr: unique_exprs, + input: sort.input, + fetch: sort.fetch, + }))) + } + LogicalPlan::Aggregate(agg, _) => { + let len = agg.group_expr.len(); - let unique_exprs: Vec = agg - .group_expr - .into_iter() - .collect::>() - .into_iter() - .collect(); + let unique_exprs: Vec = agg + .group_expr + .into_iter() + .collect::>() + .into_iter() + .collect(); - let transformed = if len != unique_exprs.len() { - Transformed::yes - } else { - Transformed::no - }; + let transformed = if len != unique_exprs.len() { + Transformed::yes + } else { + Transformed::no + }; - Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr) - .map(|f| transformed(LogicalPlan::aggregate(f))) + Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr) + .map(|f| transformed(LogicalPlan::aggregate(f))) + } + _ => Ok(Transformed::no(plan)), } - _ => Ok(Transformed::no(plan)), - } + }) } fn name(&self) -> &str { "eliminate_duplicated_expr" diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 38d4a531444d2..158835019d672 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -17,14 +17,15 @@ //! [`EliminateFilter`] replaces `where false` or `where null` with an empty relation. +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{EmptyRelation, Expr, Filter, LogicalPlan}; +use enumset::enum_set; use std::sync::Arc; -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; - /// Optimization rule that eliminate the scalar value (true/false/null) filter /// with an [LogicalPlan::EmptyRelation] /// @@ -45,10 +46,6 @@ impl OptimizerRule for EliminateFilter { "eliminate_filter" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn supports_rewrite(&self) -> bool { true } @@ -58,25 +55,33 @@ impl OptimizerRule for EliminateFilter { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Filter( - Filter { - predicate: Expr::Literal(ScalarValue::Boolean(v), _), - input, - .. - }, - _, - ) => match v { - Some(true) => Ok(Transformed::yes(Arc::unwrap_or_clone(input))), - Some(false) | None => Ok(Transformed::yes(LogicalPlan::empty_relation( - EmptyRelation { - produce_one_row: false, - schema: Arc::clone(input.schema()), + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanFilter | LogicalPlanPattern::ExprLiteral + )) { + return Ok(Transformed::jump(plan)); + } + + match plan { + LogicalPlan::Filter( + Filter { + predicate: Expr::Literal(ScalarValue::Boolean(v), _), + input, + .. }, - ))), - }, - _ => Ok(Transformed::no(plan)), - } + _, + ) => match v { + Some(true) => Ok(Transformed::yes(Arc::unwrap_or_clone(input))), + Some(false) | None => Ok(Transformed::yes( + LogicalPlan::empty_relation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(input.schema()), + }), + )), + }, + _ => Ok(Transformed::no(plan)), + } + }) } } diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 0023b2fe012ae..9d2a336ce9642 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -16,11 +16,13 @@ // under the License. //! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY` clause -use crate::optimizer::ApplyOrder; + use crate::{OptimizerConfig, OptimizerRule}; +use std::cell::Cell; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility}; /// Optimizer rule that removes constant expressions from `GROUP BY` clause @@ -45,50 +47,71 @@ impl OptimizerRule for EliminateGroupByConstant { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Aggregate(aggregate, _) => { - let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate - .group_expr - .iter() - .partition(|expr| is_constant_expression(expr)); - - // If no constant expressions found (nothing to optimize) or - // constant expression is the only expression in aggregate, - // optimization is skipped - if const_group_expr.is_empty() - || (!const_group_expr.is_empty() - && nonconst_group_expr.is_empty() - && aggregate.aggr_expr.is_empty()) + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanAggregate) { - return Ok(Transformed::no(LogicalPlan::aggregate(aggregate))); + skip.set(true); + return Ok(Transformed::jump(plan)); } - let simplified_aggregate = LogicalPlan::aggregate(Aggregate::try_new( - aggregate.input, - nonconst_group_expr.into_iter().cloned().collect(), - aggregate.aggr_expr.clone(), - )?); - - let projection_expr = - aggregate.group_expr.into_iter().chain(aggregate.aggr_expr); - - let projection = LogicalPlanBuilder::from(simplified_aggregate) - .project(projection_expr)? - .build()?; + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } - Ok(Transformed::yes(projection)) - } - _ => Ok(Transformed::no(plan)), - } + match plan { + LogicalPlan::Aggregate(aggregate, _) => { + let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = + aggregate + .group_expr + .iter() + .partition(|expr| is_constant_expression(expr)); + + // If no constant expressions found (nothing to optimize) or + // constant expression is the only expression in aggregate, + // optimization is skipped + if const_group_expr.is_empty() + || (!const_group_expr.is_empty() + && nonconst_group_expr.is_empty() + && aggregate.aggr_expr.is_empty()) + { + return Ok(Transformed::no(LogicalPlan::aggregate( + aggregate, + ))); + } + + let simplified_aggregate = + LogicalPlan::aggregate(Aggregate::try_new( + aggregate.input, + nonconst_group_expr.into_iter().cloned().collect(), + aggregate.aggr_expr.clone(), + )?); + + let projection_expr = + aggregate.group_expr.into_iter().chain(aggregate.aggr_expr); + + let projection = LogicalPlanBuilder::from(simplified_aggregate) + .project(projection_expr)? + .build()?; + + Ok(Transformed::yes(projection)) + } + _ => Ok(Transformed::no(plan)), + } + }, + ) } fn name(&self) -> &str { "eliminate_group_by_constant" } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } } /// Checks if expression is constant, and can be eliminated from group by. diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index e36c20d6c8983..3b7e8d0725163 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -16,15 +16,18 @@ // under the License. //! [`EliminateJoin`] rewrites `INNER JOIN` with `true`/`null` + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::JoinType::Inner; use datafusion_expr::{ logical_plan::{EmptyRelation, LogicalPlan}, Expr, }; +use enumset::enum_set; /// Eliminates joins when join condition is false. /// Replaces joins when inner join condition is true with a cross join. @@ -42,31 +45,37 @@ impl OptimizerRule for EliminateJoin { "eliminate_join" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn rewrite( &self, plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Join(join, _) - if join.join_type == Inner && join.on.is_empty() => - { - match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(false)), _)) => Ok( - Transformed::yes(LogicalPlan::empty_relation(EmptyRelation { - produce_one_row: false, - schema: join.schema, - })), - ), - _ => Ok(Transformed::no(LogicalPlan::join(join))), + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_all_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanJoin | LogicalPlanPattern::ExprLiteral + )) { + return Ok(Transformed::jump(plan)); + } + + match plan { + LogicalPlan::Join(join, _) + if join.join_type == Inner && join.on.is_empty() => + { + match join.filter { + Some(Expr::Literal(ScalarValue::Boolean(Some(false)), _)) => { + Ok(Transformed::yes(LogicalPlan::empty_relation( + EmptyRelation { + produce_one_row: false, + schema: join.schema, + }, + ))) + } + _ => Ok(Transformed::no(LogicalPlan::join(join))), + } } + _ => Ok(Transformed::no(plan)), } - _ => Ok(Transformed::no(plan)), - } + }) } fn supports_rewrite(&self) -> bool { diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index d47aa3a48ec17..4d22653607a77 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -16,11 +16,15 @@ // under the License. //! [`EliminateLimit`] eliminates `LIMIT` when possible + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::{EmptyRelation, FetchType, LogicalPlan, SkipType}; +use enumset::enum_set; +use std::cell::Cell; use std::sync::Arc; /// Optimizer rule to replace `LIMIT 0` or `LIMIT` whose ancestor LIMIT's skip is @@ -45,10 +49,6 @@ impl OptimizerRule for EliminateLimit { "eliminate_limit" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } - fn supports_rewrite(&self) -> bool { true } @@ -58,31 +58,53 @@ impl OptimizerRule for EliminateLimit { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result, datafusion_common::DataFusionError> { - match plan { - LogicalPlan::Limit(limit, _) => { - // Only supports rewriting for literal fetch - let FetchType::Literal(fetch) = limit.get_fetch_type()? else { - return Ok(Transformed::no(LogicalPlan::limit(limit))); - }; + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanLimit) + { + skip.set(true); + return Ok(Transformed::jump(plan)); + } + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } + + match plan { + LogicalPlan::Limit(limit, _) => { + // Only supports rewriting for literal fetch + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::limit(limit))); + }; - if let Some(v) = fetch { - if v == 0 { - return Ok(Transformed::yes(LogicalPlan::empty_relation( - EmptyRelation { - produce_one_row: false, - schema: Arc::clone(limit.input.schema()), - }, - ))); + if let Some(v) = fetch { + if v == 0 { + return Ok(Transformed::yes( + LogicalPlan::empty_relation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(limit.input.schema()), + }), + )); + } + } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) { + // If fetch is `None` and skip is 0, then Limit takes no effect and + // we can remove it. Its input also can be Limit, so we should apply again. + return self + .rewrite(Arc::unwrap_or_clone(limit.input), _config); + } + Ok(Transformed::no(LogicalPlan::limit(limit))) } - } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) { - // If fetch is `None` and skip is 0, then Limit takes no effect and - // we can remove it. Its input also can be Limit, so we should apply again. - return self.rewrite(Arc::unwrap_or_clone(limit.input), _config); + _ => Ok(Transformed::no(plan)), } - Ok(Transformed::no(LogicalPlan::limit(limit))) - } - _ => Ok(Transformed::no(plan)), - } + }, + ) } } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 4979ddc2f3ac1..af9470707db3e 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -16,13 +16,17 @@ // under the License. //! [`EliminateNestedUnion`]: flattens nested `Union` to a single `Union` + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{Distinct, LogicalPlan, Union}; +use enumset::enum_set; use itertools::Itertools; +use std::cell::Cell; use std::sync::Arc; #[derive(Default, Debug)] @@ -41,10 +45,6 @@ impl OptimizerRule for EliminateNestedUnion { "eliminate_nested_union" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } - fn supports_rewrite(&self) -> bool { true } @@ -54,43 +54,69 @@ impl OptimizerRule for EliminateNestedUnion { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Union(Union { inputs, schema }, _) => { - let inputs = inputs - .into_iter() - .flat_map(extract_plans_from_union) - .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) - .collect::>>()?; - - Ok(Transformed::yes(LogicalPlan::union(Union { - inputs: inputs.into_iter().map(Arc::new).collect_vec(), - schema, - }))) - } - LogicalPlan::Distinct(Distinct::All(nested_plan), _) => { - match Arc::unwrap_or_clone(nested_plan) { + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanUnion) + { + skip.set(true); + return Ok(Transformed::jump(plan)); + } + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } + + match plan { LogicalPlan::Union(Union { inputs, schema }, _) => { let inputs = inputs .into_iter() - .map(extract_plan_from_distinct) .flat_map(extract_plans_from_union) .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::distinct(Distinct::All( - Arc::new(LogicalPlan::union(Union { - inputs: inputs.into_iter().map(Arc::new).collect_vec(), - schema: Arc::clone(&schema), - })), - )))) + Ok(Transformed::yes(LogicalPlan::union(Union { + inputs: inputs.into_iter().map(Arc::new).collect_vec(), + schema, + }))) } - nested_plan => Ok(Transformed::no(LogicalPlan::distinct( - Distinct::All(Arc::new(nested_plan)), - ))), + LogicalPlan::Distinct(Distinct::All(nested_plan), _) => { + match Arc::unwrap_or_clone(nested_plan) { + LogicalPlan::Union(Union { inputs, schema }, _) => { + let inputs = inputs + .into_iter() + .map(extract_plan_from_distinct) + .flat_map(extract_plans_from_union) + .map(|plan| { + coerce_plan_expr_for_schema(plan, &schema) + }) + .collect::>>()?; + + Ok(Transformed::yes(LogicalPlan::distinct( + Distinct::All(Arc::new(LogicalPlan::union(Union { + inputs: inputs + .into_iter() + .map(Arc::new) + .collect_vec(), + schema: Arc::clone(&schema), + }))), + ))) + } + nested_plan => Ok(Transformed::no(LogicalPlan::distinct( + Distinct::All(Arc::new(nested_plan)), + ))), + } + } + _ => Ok(Transformed::no(plan)), } - } - _ => Ok(Transformed::no(plan)), - } + }, + ) } } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index ac3da4e8f65d8..98e4e77ebe97e 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -17,13 +17,14 @@ //! [`EliminateOneUnion`] eliminates single element `Union` +use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{tree_node::Transformed, Result}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::{LogicalPlan, Union}; +use enumset::enum_set; use std::sync::Arc; -use crate::optimizer::ApplyOrder; - #[derive(Default, Debug)] /// An optimization rule that eliminates union with one element. pub struct EliminateOneUnion; @@ -49,16 +50,23 @@ impl OptimizerRule for EliminateOneUnion { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Union(Union { mut inputs, .. }, _) if inputs.len() == 1 => Ok( - Transformed::yes(Arc::unwrap_or_clone(inputs.pop().unwrap())), - ), - _ => Ok(Transformed::no(plan)), - } - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) + plan.transform_down_with_subqueries(|plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanUnion) + { + return Ok(Transformed::jump(plan)); + } + + match plan { + LogicalPlan::Union(Union { mut inputs, .. }, _) if inputs.len() == 1 => { + Ok(Transformed::yes(Arc::unwrap_or_clone( + inputs.pop().unwrap(), + ))) + } + _ => Ok(Transformed::no(plan)), + } + }) } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index bca5d61d4c449..f6a329fc3c47c 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -24,6 +24,8 @@ use datafusion_expr::{Expr, Filter, Operator}; use crate::optimizer::ApplyOrder; use datafusion_common::tree_node::Transformed; use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; +use enumset::enum_set; use std::sync::Arc; /// @@ -64,10 +66,6 @@ impl OptimizerRule for EliminateOuterJoin { "eliminate_outer_join" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn supports_rewrite(&self) -> bool { true } @@ -77,61 +75,70 @@ impl OptimizerRule for EliminateOuterJoin { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Filter(mut filter, _) => { - match Arc::unwrap_or_clone(filter.input) { - LogicalPlan::Join(join, _) => { - let mut non_nullable_cols: Vec = vec![]; + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_all_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanFilter + | LogicalPlanPattern::LogicalPlanJoin + )) { + return Ok(Transformed::jump(plan)); + } - extract_non_nullable_columns( - &filter.predicate, - &mut non_nullable_cols, - join.left.schema(), - join.right.schema(), - true, - ); + match plan { + LogicalPlan::Filter(mut filter, _) => { + match Arc::unwrap_or_clone(filter.input) { + LogicalPlan::Join(join, _) => { + let mut non_nullable_cols: Vec = vec![]; - let new_join_type = if join.join_type.is_outer() { - let mut left_non_nullable = false; - let mut right_non_nullable = false; - for col in non_nullable_cols.iter() { - if join.left.schema().has_column(col) { - left_non_nullable = true; - } - if join.right.schema().has_column(col) { - right_non_nullable = true; + extract_non_nullable_columns( + &filter.predicate, + &mut non_nullable_cols, + join.left.schema(), + join.right.schema(), + true, + ); + + let new_join_type = if join.join_type.is_outer() { + let mut left_non_nullable = false; + let mut right_non_nullable = false; + for col in non_nullable_cols.iter() { + if join.left.schema().has_column(col) { + left_non_nullable = true; + } + if join.right.schema().has_column(col) { + right_non_nullable = true; + } } - } - eliminate_outer( - join.join_type, - left_non_nullable, - right_non_nullable, - ) - } else { - join.join_type - }; + eliminate_outer( + join.join_type, + left_non_nullable, + right_non_nullable, + ) + } else { + join.join_type + }; - let new_join = Arc::new(LogicalPlan::join(Join { - left: join.left, - right: join.right, - join_type: new_join_type, - join_constraint: join.join_constraint, - on: join.on.clone(), - filter: join.filter.clone(), - schema: Arc::clone(&join.schema), - null_equals_null: join.null_equals_null, - })); - Filter::try_new(filter.predicate, new_join) - .map(|f| Transformed::yes(LogicalPlan::filter(f))) - } - filter_input => { - filter.input = Arc::new(filter_input); - Ok(Transformed::no(LogicalPlan::filter(filter))) + let new_join = Arc::new(LogicalPlan::join(Join { + left: join.left, + right: join.right, + join_type: new_join_type, + join_constraint: join.join_constraint, + on: join.on.clone(), + filter: join.filter.clone(), + schema: Arc::clone(&join.schema), + null_equals_null: join.null_equals_null, + })); + Filter::try_new(filter.predicate, new_join) + .map(|f| Transformed::yes(LogicalPlan::filter(f))) + } + filter_input => { + filter.input = Arc::new(filter_input); + Ok(Transformed::no(LogicalPlan::filter(filter))) + } } } + _ => Ok(Transformed::no(plan)), } - _ => Ok(Transformed::no(plan)), - } + }) } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 16c3355c3b8f6..bd05e5b0ce34f 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -16,14 +16,19 @@ // under the License. //! [`ExtractEquijoinPredicate`] identifies equality join (equijoin) predicates + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::DFSchema; use datafusion_common::Result; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::utils::split_conjunction_owned; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator}; +use enumset::enum_set; +use std::cell::Cell; + // equijoin predicate type EquijoinPredicate = (Expr, Expr); @@ -57,61 +62,82 @@ impl OptimizerRule for ExtractEquijoinPredicate { "extract_equijoin_predicate" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } - fn rewrite( &self, plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Join( - Join { - left, - right, - mut on, - filter: Some(expr), - join_type, - join_constraint, - schema, - null_equals_null, - }, - _, - ) => { - let left_schema = left.schema(); - let right_schema = right.schema(); - let (equijoin_predicates, non_equijoin_expr) = - split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?; - - if !equijoin_predicates.is_empty() { - on.extend(equijoin_predicates); - Ok(Transformed::yes(LogicalPlan::join(Join { - left, - right, - on, - filter: non_equijoin_expr, - join_type, - join_constraint, - schema, - null_equals_null, - }))) - } else { - Ok(Transformed::no(LogicalPlan::join(Join { - left, - right, - on, - filter: non_equijoin_expr, - join_type, - join_constraint, - schema, - null_equals_null, - }))) + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanJoin) + { + skip.set(true); + return Ok(Transformed::jump(plan)); } - } - _ => Ok(Transformed::no(plan)), - } + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } + + match plan { + LogicalPlan::Join( + Join { + left, + right, + mut on, + filter: Some(expr), + join_type, + join_constraint, + schema, + null_equals_null, + }, + _, + ) => { + let left_schema = left.schema(); + let right_schema = right.schema(); + let (equijoin_predicates, non_equijoin_expr) = + split_eq_and_noneq_join_predicate( + expr, + left_schema, + right_schema, + )?; + + if !equijoin_predicates.is_empty() { + on.extend(equijoin_predicates); + Ok(Transformed::yes(LogicalPlan::join(Join { + left, + right, + on, + filter: non_equijoin_expr, + join_type, + join_constraint, + schema, + null_equals_null, + }))) + } else { + Ok(Transformed::no(LogicalPlan::join(Join { + left, + right, + on, + filter: non_equijoin_expr, + join_type, + join_constraint, + schema, + null_equals_null, + }))) + } + } + _ => Ok(Transformed::no(plan)), + } + }, + ) } } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index f0ce4db625dce..3583582b31063 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -89,10 +89,6 @@ impl OptimizerRule for UnwrapCastInComparison { "unwrap_cast_in_comparison" } - fn apply_order(&self) -> Option { - None - } - fn supports_rewrite(&self) -> bool { true }