From 62c8a16462944d688bbac9f78c9e791c5ea02a5d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 10 Apr 2024 18:49:12 -0400 Subject: [PATCH] Avoid copies in `TypeCoercion` via TreeNode API --- .../optimizer/src/analyzer/type_coercion.rs | 189 ++++++++---------- .../simplify_expressions/expr_simplifier.rs | 4 +- 2 files changed, 85 insertions(+), 108 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 1ea8b9534e808..b9e1040378cb1 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -22,16 +22,15 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, - DataFusionError, Result, ScalarValue, + exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, + Result, ScalarValue, }; use datafusion_expr::expr::{ - self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, WindowFunction, + self, AggregateFunctionDefinition, Between, BinaryExpr, Case, InList, InSubquery, + Like, ScalarFunction, WindowFunction, }; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ @@ -51,6 +50,7 @@ use datafusion_expr::{ }; use crate::analyzer::AnalyzerRule; +use crate::utils::NamePreserver; #[derive(Default)] pub struct TypeCoercion {} @@ -67,26 +67,21 @@ impl AnalyzerRule for TypeCoercion { } fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - analyze_internal(&DFSchema::empty(), &plan) + plan.transform_up_with_subqueries(&|plan| rewrite_plan(&DFSchema::empty(), plan)) + .map(|res| res.data) } } -fn analyze_internal( +fn rewrite_plan( // use the external schema to handle the correlated subqueries case external_schema: &DFSchema, - plan: &LogicalPlan, -) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| analyze_internal(external_schema, p)) - .collect::>>()?; + plan: LogicalPlan, +) -> Result> { // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); + let mut schema = merge_schema(plan.inputs()); - if let LogicalPlan::TableScan(ts) = plan { + if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), @@ -99,28 +94,23 @@ fn analyze_internal( // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3) schema.merge(external_schema); - let mut expr_rewrite = TypeCoercionRewriter { - schema: Arc::new(schema), - }; - - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure aggregate names don't change: - // https://github.com/apache/arrow-datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) - }) - .collect::>>()?; + let schema = &schema; - plan.with_new_exprs(new_expr, new_inputs) + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; + // recursively rewrite the expression bottom up + let mut rewriter = TypeCoercionRewriter { schema }; + expr.rewrite(&mut rewriter)? + .map_data(|expr| original_name.restore(expr)) + }) } -pub(crate) struct TypeCoercionRewriter { - pub(crate) schema: DFSchemaRef, +pub(crate) struct TypeCoercionRewriter<'a> { + pub(crate) schema: &'a DFSchema, } -impl TreeNodeRewriter for TypeCoercionRewriter { +impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { type Node = Expr; fn f_up(&mut self, expr: Expr) -> Result> { @@ -128,33 +118,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter { Expr::Unnest(_) => internal_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" ), - Expr::ScalarSubquery(Subquery { - subquery, - outer_ref_columns, - }) => { - let new_plan = analyze_internal(&self.schema, &subquery)?; - Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - }))) - } - Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - Ok(Transformed::yes(Expr::Exists(Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - }))) - } + // Note subqueries are handled by TreeNode visiting + Expr::ScalarSubquery(..) => Ok(Transformed::no(expr)), + Expr::Exists(..) => Ok(Transformed::no(expr)), Expr::InSubquery(InSubquery { expr, subquery, negated, }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - let expr_type = expr.get_type(&self.schema)?; + let sq_plan = match Arc::try_unwrap(subquery.subquery) { + Ok(sq_plan) => sq_plan, + Err(sq_plan) => sq_plan.as_ref().clone(), + }; + + let new_plan = rewrite_plan(self.schema, sq_plan)?.data; + let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" @@ -165,32 +143,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter { outer_ref_columns: subquery.outer_ref_columns, }; Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, &self.schema)?), + Box::new(expr.cast_to(&common_type, self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, )))) } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( &expr, - &self.schema, + self.schema, )?))), Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( - get_casted_expr_for_bool_op(&expr, &self.schema)?, + get_casted_expr_for_bool_op(&expr, self.schema)?, ))), Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( - get_casted_expr_for_bool_op(&expr, &self.schema)?, + get_casted_expr_for_bool_op(&expr, self.schema)?, ))), Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( - get_casted_expr_for_bool_op(&expr, &self.schema)?, + get_casted_expr_for_bool_op(&expr, self.schema)?, ))), Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( - get_casted_expr_for_bool_op(&expr, &self.schema)?, + get_casted_expr_for_bool_op(&expr, self.schema)?, ))), Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( - get_casted_expr_for_bool_op(&expr, &self.schema)?, + get_casted_expr_for_bool_op(&expr, self.schema)?, ))), Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( - get_casted_expr_for_bool_op(&expr, &self.schema)?, + get_casted_expr_for_bool_op(&expr, self.schema)?, ))), Expr::Like(Like { negated, @@ -199,8 +177,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { escape_char, case_insensitive, }) => { - let left_type = expr.get_type(&self.schema)?; - let right_type = pattern.get_type(&self.schema)?; + let left_type = expr.get_type(self.schema)?; + let right_type = pattern.get_type(self.schema)?; let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { let op_name = if case_insensitive { "ILIKE" @@ -211,8 +189,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" ) })?; - let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); - let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); + let expr = Box::new(expr.cast_to(&coerced_type, self.schema)?); + let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, @@ -223,14 +201,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left_type, right_type) = get_input_types( - &left.get_type(&self.schema)?, + &left.get_type(self.schema)?, &op, - &right.get_type(&self.schema)?, + &right.get_type(self.schema)?, )?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.cast_to(&left_type, &self.schema)?), + Box::new(left.cast_to(&left_type, self.schema)?), op, - Box::new(right.cast_to(&right_type, &self.schema)?), + Box::new(right.cast_to(&right_type, self.schema)?), )))) } Expr::Between(Between { @@ -239,15 +217,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter { low, high, }) => { - let expr_type = expr.get_type(&self.schema)?; - let low_type = low.get_type(&self.schema)?; + let expr_type = expr.get_type(self.schema)?; + let low_type = low.get_type(self.schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { DataFusionError::Internal(format!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" )) })?; - let high_type = high.get_type(&self.schema)?; + let high_type = high.get_type(self.schema)?; let high_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { DataFusionError::Internal(format!( @@ -262,10 +240,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter { )) })?; Ok(Transformed::yes(Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, &self.schema)?), + Box::new(expr.cast_to(&coercion_type, self.schema)?), negated, - Box::new(low.cast_to(&coercion_type, &self.schema)?), - Box::new(high.cast_to(&coercion_type, &self.schema)?), + Box::new(low.cast_to(&coercion_type, self.schema)?), + Box::new(high.cast_to(&coercion_type, self.schema)?), )))) } Expr::InList(InList { @@ -273,10 +251,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter { list, negated, }) => { - let expr_data_type = expr.get_type(&self.schema)?; + let expr_data_type = expr.get_type(self.schema)?; let list_data_types = list .iter() - .map(|list_expr| list_expr.get_type(&self.schema)) + .map(|list_expr| list_expr.get_type(self.schema)) .collect::>>()?; let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); @@ -286,11 +264,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { ), Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, &self.schema)?; + let cast_expr = expr.cast_to(&coerced_type, self.schema)?; let cast_list_expr = list .into_iter() .map(|list_expr| { - list_expr.cast_to(&coerced_type, &self.schema) + list_expr.cast_to(&coerced_type, self.schema) }) .collect::>>()?; Ok(Transformed::yes(Expr::InList(InList ::new( @@ -302,14 +280,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } } Expr::Case(case) => { - let case = coerce_case_expression(case, &self.schema)?; + let case = coerce_case_expression(case, self.schema)?; Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { let new_args = coerce_arguments_for_signature( args.as_slice(), - &self.schema, + self.schema, &fun.signature(), )?; Ok(Transformed::yes(Expr::ScalarFunction(ScalarFunction::new( @@ -319,14 +297,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { ScalarFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( args.as_slice(), - &self.schema, + self.schema, fun.signature(), )?; - let new_expr = coerce_arguments_for_fun( - new_expr.as_slice(), - &self.schema, - &fun, - )?; + let new_expr = + coerce_arguments_for_fun(new_expr.as_slice(), self.schema, &fun)?; Ok(Transformed::yes(Expr::ScalarFunction( ScalarFunction::new_udf(fun, new_expr), ))) @@ -347,7 +322,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let new_expr = coerce_agg_exprs_for_signature( &fun, &args, - &self.schema, + self.schema, &fun.signature(), )?; Ok(Transformed::yes(Expr::AggregateFunction( @@ -364,7 +339,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( args.as_slice(), - &self.schema, + self.schema, fun.signature(), )?; Ok(Transformed::yes(Expr::AggregateFunction( @@ -391,14 +366,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { null_treatment, }) => { let window_frame = - coerce_window_frame(window_frame, &self.schema, &order_by)?; + coerce_window_frame(window_frame, self.schema, &order_by)?; let args = match &fun { expr::WindowFunctionDefinition::AggregateFunction(fun) => { coerce_agg_exprs_for_signature( fun, &args, - &self.schema, + self.schema, &fun.signature(), )? } @@ -511,7 +486,7 @@ fn coerce_frame_bound( // For example, ROWS and GROUPS frames use `UInt64` during calculations. fn coerce_window_frame( window_frame: WindowFrame, - schema: &DFSchemaRef, + schema: &DFSchema, expressions: &[Expr], ) -> Result { let mut window_frame = window_frame; @@ -548,7 +523,7 @@ fn coerce_window_frame( // Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion. // The above op will be rewrite to the binary op when creating the physical op. -fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result { +fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchema) -> Result { let left_type = expr.get_type(schema)?; get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; cast_expr(expr, &DataType::Boolean, schema) @@ -643,7 +618,7 @@ fn coerce_agg_exprs_for_signature( .collect::>>() } -fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { +fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { // Given expressions like: // // CASE a1 @@ -1260,33 +1235,33 @@ mod test { #[test] fn test_type_coercion_rewrite() -> Result<()> { // gt - let schema = Arc::new(DFSchema::from_unqualifed_fields( + let schema = DFSchema::from_unqualifed_fields( vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), - )?); - let mut rewriter = TypeCoercionRewriter { schema }; + )?; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // eq - let schema = Arc::new(DFSchema::from_unqualifed_fields( + let schema = DFSchema::from_unqualifed_fields( vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), - )?); - let mut rewriter = TypeCoercionRewriter { schema }; + )?; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // lt - let schema = Arc::new(DFSchema::from_unqualifed_fields( + let schema = DFSchema::from_unqualifed_fields( vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), - )?); - let mut rewriter = TypeCoercionRewriter { schema }; + )?; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 3198807b04cfd..3c7879ab86361 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -212,7 +212,9 @@ impl ExprSimplifier { // it manually. // https://github.com/apache/arrow-datafusion/issues/3793 pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { - let mut expr_rewrite = TypeCoercionRewriter { schema }; + let mut expr_rewrite = TypeCoercionRewriter { + schema: schema.as_ref(), + }; expr.rewrite(&mut expr_rewrite).data() }