Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid copies in TypeCoercion via TreeNode API #10039

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,12 @@ impl<T> Transformed<T> {
Self::new(data, false, TreeNodeRecursion::Continue)
}

/// If not already, sets `self.transformed` to true if `transformed` is true.
pub fn update_transformed(mut self, transformed: bool) -> Self {
self.transformed |= transformed;
self
}

/// Applies the given `f` to the data of this [`Transformed`] object.
pub fn update_data<U, F: FnOnce(T) -> U>(self, f: F) -> Transformed<U> {
Transformed::new(f(self.data), self.transformed, self.tnr)
Expand Down
61 changes: 60 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ impl LogicalPlan {
mut expr: Vec<Expr>,
mut inputs: Vec<LogicalPlan>,
) -> Result<LogicalPlan> {
match self {

match self {
// Since expr may be different than the previous expr, schema of the projection
// may change. We need to use try_new method instead of try_new_with_schema method.
LogicalPlan::Projection(Projection { .. }) => {
Expand Down Expand Up @@ -815,6 +816,64 @@ impl LogicalPlan {
}
}
}

/// Recalculates the schema of a LogicalPlan. This should be invoked if the
/// types of any expressions or inputs are changed (e.g. by an analyzer pass) using the tree node API.
pub fn recalculate_schema(self) -> Result<LogicalPlan> {
match self {
/*
LogicalPlan::Projection(Projection{ expr, input, schema: _ }) => {
Projection::try_new(expr, input)
.map(LogicalPlan::Projection)
}
*/

// These nodes do not change their schema
//LogicalPlan::Filter(_) => Ok(self),
/*
LogicalPlan::Window(_) => {}
LogicalPlan::Aggregate(_) => {}
LogicalPlan::Sort(_) => {}
LogicalPlan::Join(_) => {}
LogicalPlan::CrossJoin(_) => {}
LogicalPlan::Repartition(_) => {}
LogicalPlan::Union(_) => {}
LogicalPlan::TableScan(_) => {}
LogicalPlan::EmptyRelation(_) => {}
LogicalPlan::Subquery(_) => {}
LogicalPlan::SubqueryAlias(_) => {}
LogicalPlan::Limit(_) => {}
LogicalPlan::Statement(_) => {}
LogicalPlan::Values(_) => {}
LogicalPlan::Explain(_) => {}
LogicalPlan::Analyze(_) => {}
LogicalPlan::Extension(_) => {}
LogicalPlan::Distinct(_) => {}
LogicalPlan::Prepare(_) => {}
LogicalPlan::Dml(_) => {}
LogicalPlan::Ddl(_) => {}
LogicalPlan::Copy(_) => {}
LogicalPlan::DescribeTable(_) => {}
LogicalPlan::Unnest(_) => {}
LogicalPlan::RecursiveQuery(_) => {}

*/

_ => {
// default implementation avoids a copy
// TODO avoid this copy
let new_inputs = self
.inputs()
.into_iter()
.map(|input| input.clone())
.collect::<Vec<_>>();

self.with_new_exprs(self.expressions(), new_inputs)
}

}
}

/// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`]
/// with the specified `param_values`.
///
Expand Down
63 changes: 38 additions & 25 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::sync::Arc;
use arrow::datatypes::{DataType, IntervalUnit};

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{
exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef,
DataFusionError, Result, ScalarValue,
Expand All @@ -31,8 +31,8 @@ use datafusion_expr::expr::{
self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList,
InSubquery, Like, ScalarFunction, WindowFunction,
};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::expr_schema::cast_subquery;
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{
comparison_coercion, get_input_types, like_coercion,
Expand All @@ -51,6 +51,7 @@ use datafusion_expr::{
};

use crate::analyzer::AnalyzerRule;
use crate::utils::NamePreserver;

#[derive(Default)]
pub struct TypeCoercion {}
Expand All @@ -67,26 +68,31 @@ impl AnalyzerRule for TypeCoercion {
}

fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
analyze_internal(&DFSchema::empty(), &plan)
Ok(analyze_internal(&DFSchema::empty(), plan)?.data)
}
}

fn analyze_internal(
// use the external schema to handle the correlated subqueries case
external_schema: &DFSchema,
plan: &LogicalPlan,
) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| analyze_internal(external_schema, p))
.collect::<Result<Vec<_>>>()?;
plan: LogicalPlan,
) -> Result<Transformed<LogicalPlan>> {
// optimize child plans first (since we use external_schema here, can't use LogicalPlan::transform)
let Transformed {
data: plan,
transformed: children_transformed,
..
} = plan.map_children(|plan| analyze_internal(external_schema, plan))?;

// if any of the expressions were rewritten, we need to recreate the plan to
// recalculate the schema. At the moment this requires a copy
let plan = plan.recalculate_schema()?;

// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(new_inputs.iter().collect());
let mut schema = merge_schema(plan.inputs());

if let LogicalPlan::TableScan(ts) = plan {
if let LogicalPlan::TableScan(ts) = &plan {
let source_schema = DFSchema::try_from_qualified_schema(
ts.table_name.clone(),
&ts.source.schema(),
Expand All @@ -103,17 +109,22 @@ fn analyze_internal(
schema: Arc::new(schema),
};

let new_expr = plan
.expressions()
.into_iter()
.map(|expr| {
// ensure aggregate names don't change:
// https://github.com/apache/datafusion/issues/3555
rewrite_preserving_name(expr, &mut expr_rewrite)
let preserver = NamePreserver::new(&plan);
plan.map_expressions(|expr| {
// ensure aggregate names don't change:
// https://github.com/apache/datafusion/issues/3555
let original_name = preserver.save(&expr)?;
expr.rewrite(&mut expr_rewrite)?
.map_data(|expr| original_name.restore(expr))
})?
.transform_data(|plan| {
// recalculate the schema after the rewrites
plan.recalculate_schema().map(Transformed::yes)
})
.collect::<Result<Vec<_>>>()?;

plan.with_new_exprs(new_expr, new_inputs)
//} else {
// Ok(transformed_plan)
//}
}

pub(crate) struct TypeCoercionRewriter {
Expand All @@ -132,14 +143,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
outer_ref_columns,
}) => {
let new_plan = analyze_internal(&self.schema, &subquery)?;
let new_plan = analyze_internal(&self.schema, unwrap_arc(subquery))?.data;
Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns,
})))
}
Expr::Exists(Exists { subquery, negated }) => {
let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
let new_plan =
analyze_internal(&self.schema, unwrap_arc(subquery.subquery))?.data;
Ok(Transformed::yes(Expr::Exists(Exists {
subquery: Subquery {
subquery: Arc::new(new_plan),
Expand All @@ -153,7 +165,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
negated,
}) => {
let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
let new_plan =
analyze_internal(&self.schema, unwrap_arc(subquery.subquery))?.data;
let expr_type = expr.get_type(&self.schema)?;
let subquery_type = new_plan.schema().field(0).data_type();
let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
Expand Down
Loading