Skip to content

Commit

Permalink
fix more rules
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Nov 27, 2024
1 parent d63b6fc commit bcb0294
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 44 deletions.
12 changes: 8 additions & 4 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1799,12 +1799,15 @@ impl Expr {
}

pub fn in_subquery(in_subquery: InSubquery) -> Self {
let stats = in_subquery.stats();
let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprInSubquery))
.merge(in_subquery.stats());
Expr::InSubquery(in_subquery, stats)
}

pub fn scalar_subquery(subquery: Subquery) -> Self {
let stats = subquery.stats();
let stats =
LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprScalarSubquery))
.merge(subquery.stats());
Expr::ScalarSubquery(subquery, stats)
}

Expand Down Expand Up @@ -1919,7 +1922,8 @@ impl Expr {
}

pub fn exists(exists: Exists) -> Self {
let stats = exists.stats();
let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprExists))
.merge(exists.stats());
Expr::Exists(exists, stats)
}

Expand All @@ -1939,7 +1943,7 @@ impl Expr {
}

pub fn placeholder(placeholder: Placeholder) -> Self {
let stats = LogicalPlanStats::empty();
let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprPlaceholder));
Expr::Placeholder(placeholder, stats)
}

Expand Down
35 changes: 35 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,13 @@ impl LogicalPlan {
let mut using_columns: Vec<HashSet<Column>> = vec![];

self.apply_with_subqueries(|plan| {
if !plan
.stats()
.contains_pattern(LogicalPlanPattern::LogicalPlanJoin)
{
return Ok(TreeNodeRecursion::Jump);
}

if let LogicalPlan::Join(
Join {
join_constraint: JoinConstraint::Using,
Expand Down Expand Up @@ -1725,8 +1732,22 @@ impl LogicalPlan {
pub fn get_parameter_names(&self) -> Result<HashSet<String>> {
let mut param_names = HashSet::new();
self.apply_with_subqueries(|plan| {
if !plan
.stats()
.contains_pattern(LogicalPlanPattern::ExprPlaceholder)
{
return Ok(TreeNodeRecursion::Jump);
}

plan.apply_expressions(|expr| {
expr.apply(|expr| {
if !plan
.stats()
.contains_pattern(LogicalPlanPattern::ExprPlaceholder)
{
return Ok(TreeNodeRecursion::Jump);
}

if let Expr::Placeholder(Placeholder { id, .. }, _) = expr {
param_names.insert(id.clone());
}
Expand All @@ -1744,8 +1765,22 @@ impl LogicalPlan {
let mut param_types: HashMap<String, Option<DataType>> = HashMap::new();

self.apply_with_subqueries(|plan| {
if !plan
.stats()
.contains_pattern(LogicalPlanPattern::ExprPlaceholder)
{
return Ok(TreeNodeRecursion::Jump);
}

plan.apply_expressions(|expr| {
expr.apply(|expr| {
if !plan
.stats()
.contains_pattern(LogicalPlanPattern::ExprPlaceholder)
{
return Ok(TreeNodeRecursion::Jump);
}

if let Expr::Placeholder(Placeholder { id, data_type }, _) = expr {
let prev = param_types.get(id);
match (prev, data_type) {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ pub enum LogicalPlanPattern {
ExprAggregateFunction,
ExprWindowFunction,
ExprInList,
// ExprExists,
// ExprInSubquery,
// ExprScalarSubquery,
ExprExists,
ExprInSubquery,
ExprScalarSubquery,
// ExprWildcard,
// ExprGroupingSet,
// ExprPlaceholder,
ExprPlaceholder,
// ExprOuterReferenceColumn,
// ExprUnnest,

Expand Down
84 changes: 58 additions & 26 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
// under the License.

use crate::analyzer::AnalyzerRule;
use enumset::enum_set;
use std::cell::Cell;

use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_expr::expr::{AggregateFunction, WindowFunction};
use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};

Expand All @@ -39,7 +42,61 @@ impl CountWildcardRule {

impl AnalyzerRule for CountWildcardRule {
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
plan.transform_down_with_subqueries(analyze_internal).data()
plan.transform_down_with_subqueries(|plan| {
if !plan.stats().contains_any_patterns(enum_set!(
LogicalPlanPattern::ExprWindowFunction
| LogicalPlanPattern::ExprAggregateFunction
)) {
return Ok(Transformed::jump(plan));
}

let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr);
let skip = Cell::new(false);
let transformed_expr = expr.transform_down_up(
|expr| {
if !expr.stats().contains_any_patterns(enum_set!(
LogicalPlanPattern::ExprWindowFunction
| LogicalPlanPattern::ExprAggregateFunction
)) {
skip.set(true);
return Ok(Transformed::jump(expr));
}

Ok(Transformed::no(expr))
},
|expr| {
if skip.get() {
skip.set(false);
return Ok(Transformed::no(expr));
}

match expr {
Expr::WindowFunction(mut window_function, _)
if is_count_star_window_aggregate(&window_function) =>
{
window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::window_function(
window_function,
)))
}
Expr::AggregateFunction(mut aggregate_function, _)
if is_count_star_aggregate(&aggregate_function) =>
{
aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::aggregate_function(
aggregate_function,
)))
}
_ => Ok(Transformed::no(expr)),
}
},
)?;
Ok(transformed_expr.update_data(|data| original_name.restore(data)))
})
})
.data()
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -67,31 +124,6 @@ fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
if udaf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
}

fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr);
let transformed_expr = expr.transform_up(|expr| match expr {
Expr::WindowFunction(mut window_function, _)
if is_count_star_window_aggregate(&window_function) =>
{
window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::window_function(window_function)))
}
Expr::AggregateFunction(mut aggregate_function, _)
if is_count_star_aggregate(&aggregate_function) =>
{
aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::aggregate_function(
aggregate_function,
)))
}
_ => Ok(Transformed::no(expr)),
})?;
Ok(transformed_expr.update_data(|data| original_name.restore(data)))
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
36 changes: 26 additions & 10 deletions datafusion/optimizer/src/analyzer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,28 @@

//! [`Analyzer`] and [`AnalyzerRule`]
use enumset::enum_set;
use log::debug;
use std::fmt::Debug;
use std::sync::Arc;

use log::debug;

use crate::analyzer::count_wildcard_rule::CountWildcardRule;
use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule;
use crate::analyzer::inline_table_scan::InlineTableScan;
use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction;
use crate::analyzer::subquery::check_subquery_expr;
use crate::analyzer::type_coercion::TypeCoercion;
use crate::utils::log_plan;
use datafusion_common::config::ConfigOptions;
use datafusion_common::instant::Instant;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::expr::Exists;
use datafusion_expr::expr::InSubquery;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern;
use datafusion_expr::{Expr, LogicalPlan};

use crate::analyzer::count_wildcard_rule::CountWildcardRule;
use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule;
use crate::analyzer::inline_table_scan::InlineTableScan;
use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction;
use crate::analyzer::subquery::check_subquery_expr;
use crate::analyzer::type_coercion::TypeCoercion;
use crate::utils::log_plan;

use self::function_rewrite::ApplyFunctionRewrites;

pub mod count_wildcard_rule;
Expand Down Expand Up @@ -177,9 +177,25 @@ impl Analyzer {
/// Do necessary check and fail the invalid plan
fn check_plan(plan: &LogicalPlan) -> Result<()> {
plan.apply_with_subqueries(|plan: &LogicalPlan| {
if !plan.stats().contains_any_patterns(enum_set!(
LogicalPlanPattern::ExprExists
| LogicalPlanPattern::ExprInSubquery
| LogicalPlanPattern::ExprScalarSubquery
)) {
return Ok(TreeNodeRecursion::Jump);
}

plan.apply_expressions(|expr| {
// recursively look for subqueries
expr.apply(|expr| {
if !plan.stats().contains_any_patterns(enum_set!(
LogicalPlanPattern::ExprExists
| LogicalPlanPattern::ExprInSubquery
| LogicalPlanPattern::ExprScalarSubquery
)) {
return Ok(TreeNodeRecursion::Jump);
}

match expr {
Expr::Exists(Exists { subquery, .. }, _)
| Expr::InSubquery(InSubquery { subquery, .. }, _)
Expand Down
9 changes: 9 additions & 0 deletions datafusion/optimizer/src/analyzer/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

use crate::analyzer::check_plan;
use crate::utils::collect_subquery_cols;
use enumset::enum_set;
use recursive::recursive;

use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{plan_err, Result};
use datafusion_expr::expr_rewriter::strip_outer_reference;
use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern;
use datafusion_expr::utils::split_conjunction;
use datafusion_expr::{Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window};

Expand Down Expand Up @@ -255,6 +257,13 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan {
fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result<Vec<Expr>> {
let mut exprs = vec![];
inner_plan.apply_with_subqueries(|plan| {
if !plan
.stats()
.contains_pattern(LogicalPlanPattern::LogicalPlanFilter)
{
return Ok(TreeNodeRecursion::Jump);
}

if let LogicalPlan::Filter(Filter { predicate, .. }, _) = plan {
let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate)
.into_iter()
Expand Down

0 comments on commit bcb0294

Please sign in to comment.