Skip to content

Commit

Permalink
fix more rules
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Nov 27, 2024
1 parent b78b57a commit d63b6fc
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 267 deletions.
1 change: 1 addition & 0 deletions datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ datafusion-physical-expr-common = { workspace = true }
datafusion-physical-optimizer = { workspace = true }
datafusion-physical-plan = { workspace = true }
datafusion-sql = { workspace = true }
enumset = { workspace = true }
flate2 = { version = "1.0.24", optional = true }
futures = { workspace = true }
glob = "0.3.0"
Expand Down
74 changes: 40 additions & 34 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ use arrow::{
util::pretty::pretty_format_batches,
};
use async_trait::async_trait;
use enumset::enum_set;
use futures::{Stream, StreamExt};

use datafusion::execution::session_state::SessionStateBuilder;
Expand Down Expand Up @@ -97,8 +98,8 @@ use datafusion::{
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern;
use datafusion_expr::{FetchType, Projection, SortExpr};
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::AnalyzerRule;

/// Execute the specified sql and return the resulting record batches
Expand Down Expand Up @@ -343,10 +344,6 @@ impl OptimizerRule for TopKOptimizerRule {
"topk"
}

fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}

fn supports_rewrite(&self) -> bool {
true
}
Expand All @@ -357,38 +354,47 @@ impl OptimizerRule for TopKOptimizerRule {
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>, DataFusionError> {
// Note: this code simply looks for the pattern of a Limit followed by a
// Sort and replaces it by a TopK node. It does not handle many
// edge cases (e.g multiple sort columns, sort ASC / DESC), etc.
let LogicalPlan::Limit(ref limit, _) = plan else {
return Ok(Transformed::no(plan));
};
let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else {
return Ok(Transformed::no(plan));
};
plan.transform_down_with_subqueries(|plan| {
if !plan.stats().contains_all_patterns(enum_set!(
LogicalPlanPattern::LogicalPlanLimit
| LogicalPlanPattern::LogicalPlanSort
)) {
return Ok(Transformed::jump(plan));
}

if let LogicalPlan::Sort(
Sort {
ref expr,
ref input,
..
},
_,
) = limit.input.as_ref()
{
if expr.len() == 1 {
// we found a sort with a single sort expr, replace with a a TopK
return Ok(Transformed::yes(LogicalPlan::extension(Extension {
node: Arc::new(TopKPlanNode {
k: fetch,
input: input.as_ref().clone(),
expr: expr[0].clone(),
}),
})));
// Note: this code simply looks for the pattern of a Limit followed by a
// Sort and replaces it by a TopK node. It does not handle many
// edge cases (e.g multiple sort columns, sort ASC / DESC), etc.
let LogicalPlan::Limit(ref limit, _) = plan else {
return Ok(Transformed::no(plan));
};
let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else {
return Ok(Transformed::no(plan));
};

if let LogicalPlan::Sort(
Sort {
ref expr,
ref input,
..
},
_,
) = limit.input.as_ref()
{
if expr.len() == 1 {
// we found a sort with a single sort expr, replace with a a TopK
return Ok(Transformed::yes(LogicalPlan::extension(Extension {
node: Arc::new(TopKPlanNode {
k: fetch,
input: input.as_ref().clone(),
expr: expr[0].clone(),
}),
})));
}
}
}

Ok(Transformed::no(plan))
Ok(Transformed::no(plan))
})
}
}

Expand Down
4 changes: 3 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2377,7 +2377,9 @@ impl LogicalPlan {
}

pub fn distinct(distinct: Distinct) -> Self {
let stats = distinct.stats();
let stats =
LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanDistinct))
.merge(distinct.stats());
LogicalPlan::Distinct(distinct, stats)
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ pub enum LogicalPlanPattern {
// LogicalPlanExplain,
// LogicalPlanAnalyze,
// LogicalPlanExtension,
// LogicalPlanDistinct,
LogicalPlanDistinct,
// LogicalPlanDml,
// LogicalPlanDdl,
// LogicalPlanCopy,
Expand Down
233 changes: 129 additions & 104 deletions datafusion/optimizer/src/replace_distinct_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@

//! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...`
use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp};
use crate::{OptimizerConfig, OptimizerRule};
use std::sync::Arc;

use datafusion_common::tree_node::Transformed;
use datafusion_common::{Column, Result};
use datafusion_expr::expr_rewriter::normalize_cols;
use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern;
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder};
use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
use std::cell::Cell;
use std::sync::Arc;

/// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]]
///
Expand Down Expand Up @@ -76,115 +76,140 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
plan: LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
match plan {
LogicalPlan::Distinct(Distinct::All(input), _) => {
let group_expr = expand_wildcard(input.schema(), &input, None)?;

let field_count = input.schema().fields().len();
for dep in input.schema().functional_dependencies().iter() {
// If distinct is exactly the same with a previous GROUP BY, we can
// simply remove it:
if dep.source_indices.len() >= field_count
&& dep.source_indices[..field_count]
.iter()
.enumerate()
.all(|(idx, f_idx)| idx == *f_idx)
{
return Ok(Transformed::yes(input.as_ref().clone()));
}
let skip = Cell::new(false);
plan.transform_down_up_with_subqueries(
|plan| {
if !plan
.stats()
.contains_pattern(LogicalPlanPattern::LogicalPlanDistinct)
{
skip.set(true);
return Ok(Transformed::jump(plan));
}

Ok(Transformed::no(plan))
},
|plan| {
if skip.get() {
skip.set(false);
return Ok(Transformed::no(plan));
}

// Replace with aggregation:
let aggr_plan = LogicalPlan::aggregate(Aggregate::try_new(
input,
group_expr,
vec![],
)?);
Ok(Transformed::yes(aggr_plan))
}
LogicalPlan::Distinct(
Distinct::On(DistinctOn {
select_expr,
on_expr,
sort_expr,
input,
schema,
}),
_,
) => {
let expr_cnt = on_expr.len();

// Construct the aggregation expression to be used to fetch the selected expressions.
let first_value_udaf: Arc<datafusion_expr::AggregateUDF> =
config.function_registry().unwrap().udaf("first_value")?;
let aggr_expr = select_expr.into_iter().map(|e| {
if let Some(order_by) = &sort_expr {
first_value_udaf
.call(vec![e])
.order_by(order_by.clone())
.build()
// guaranteed to be `Expr::AggregateFunction`
.unwrap()
} else {
first_value_udaf.call(vec![e])
match plan {
LogicalPlan::Distinct(Distinct::All(input), _) => {
let group_expr = expand_wildcard(input.schema(), &input, None)?;

let field_count = input.schema().fields().len();
for dep in input.schema().functional_dependencies().iter() {
// If distinct is exactly the same with a previous GROUP BY, we can
// simply remove it:
if dep.source_indices.len() >= field_count
&& dep.source_indices[..field_count]
.iter()
.enumerate()
.all(|(idx, f_idx)| idx == *f_idx)
{
return Ok(Transformed::yes(input.as_ref().clone()));
}
}

// Replace with aggregation:
let aggr_plan = LogicalPlan::aggregate(Aggregate::try_new(
input,
group_expr,
vec![],
)?);
Ok(Transformed::yes(aggr_plan))
}
});

let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
let group_expr = normalize_cols(on_expr, input.as_ref())?;

// Build the aggregation plan
let plan = LogicalPlan::aggregate(Aggregate::try_new(
input, group_expr, aggr_expr,
)?);
// TODO use LogicalPlanBuilder directly rather than recreating the Aggregate
// when https://github.com/apache/datafusion/issues/10485 is available
let lpb = LogicalPlanBuilder::from(plan);

let plan = if let Some(mut sort_expr) = sort_expr {
// While sort expressions were used in the `FIRST_VALUE` aggregation itself above,
// this on it's own isn't enough to guarantee the proper output order of the grouping
// (`ON`) expression, so we need to sort those as well.

// truncate the sort_expr to the length of on_expr
sort_expr.truncate(expr_cnt);

lpb.sort(sort_expr)?.build()?
} else {
lpb.build()?
};

// Whereas the aggregation plan by default outputs both the grouping and the aggregation
// expressions, for `DISTINCT ON` we only need to emit the original selection expressions.

let project_exprs = plan
.schema()
.iter()
.skip(expr_cnt)
.zip(schema.iter())
.map(|((new_qualifier, new_field), (old_qualifier, old_field))| {
col(Column::from((new_qualifier, new_field)))
.alias_qualified(old_qualifier.cloned(), old_field.name())
})
.collect::<Vec<Expr>>();

let plan = LogicalPlanBuilder::from(plan)
.project(project_exprs)?
.build()?;

Ok(Transformed::yes(plan))
}
_ => Ok(Transformed::no(plan)),
}
LogicalPlan::Distinct(
Distinct::On(DistinctOn {
select_expr,
on_expr,
sort_expr,
input,
schema,
}),
_,
) => {
let expr_cnt = on_expr.len();

// Construct the aggregation expression to be used to fetch the selected expressions.
let first_value_udaf: Arc<datafusion_expr::AggregateUDF> =
config.function_registry().unwrap().udaf("first_value")?;
let aggr_expr = select_expr.into_iter().map(|e| {
if let Some(order_by) = &sort_expr {
first_value_udaf
.call(vec![e])
.order_by(order_by.clone())
.build()
// guaranteed to be `Expr::AggregateFunction`
.unwrap()
} else {
first_value_udaf.call(vec![e])
}
});

let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
let group_expr = normalize_cols(on_expr, input.as_ref())?;

// Build the aggregation plan
let plan = LogicalPlan::aggregate(Aggregate::try_new(
input, group_expr, aggr_expr,
)?);
// TODO use LogicalPlanBuilder directly rather than recreating the Aggregate
// when https://github.com/apache/datafusion/issues/10485 is available
let lpb = LogicalPlanBuilder::from(plan);

let plan = if let Some(mut sort_expr) = sort_expr {
// While sort expressions were used in the `FIRST_VALUE` aggregation itself above,
// this on it's own isn't enough to guarantee the proper output order of the grouping
// (`ON`) expression, so we need to sort those as well.

// truncate the sort_expr to the length of on_expr
sort_expr.truncate(expr_cnt);

lpb.sort(sort_expr)?.build()?
} else {
lpb.build()?
};

// Whereas the aggregation plan by default outputs both the grouping and the aggregation
// expressions, for `DISTINCT ON` we only need to emit the original selection expressions.

let project_exprs = plan
.schema()
.iter()
.skip(expr_cnt)
.zip(schema.iter())
.map(
|(
(new_qualifier, new_field),
(old_qualifier, old_field),
)| {
col(Column::from((new_qualifier, new_field)))
.alias_qualified(
old_qualifier.cloned(),
old_field.name(),
)
},
)
.collect::<Vec<Expr>>();

let plan = LogicalPlanBuilder::from(plan)
.project(project_exprs)?
.build()?;

Ok(Transformed::yes(plan))
}
_ => Ok(Transformed::no(plan)),
}
},
)
}

fn name(&self) -> &str {
"replace_distinct_aggregate"
}

fn apply_order(&self) -> Option<ApplyOrder> {
Some(BottomUp)
}
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit d63b6fc

Please sign in to comment.