diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 6e9c42480c32..2c1470a1d6ec 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -258,7 +258,7 @@ pub fn physical_expr(schema: &Schema, expr: Expr) -> Result { + pub(crate) schema: &'a DFSchema, } -impl TreeNodeRewriter for TypeCoercionRewriter { +impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { type Node = Expr; fn f_up(&mut self, expr: Expr) -> Result> { @@ -132,14 +130,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { subquery, outer_ref_columns, }) => { - let new_plan = analyze_internal(&self.schema, &subquery)?; + let new_plan = analyze_internal(self.schema, &subquery)?; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, }))) } Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; + let new_plan = analyze_internal(self.schema, &subquery.subquery)?; Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), @@ -153,8 +151,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { subquery, negated, }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - let expr_type = expr.get_type(&self.schema)?; + let new_plan = analyze_internal(self.schema, &subquery.subquery)?; + let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" @@ -165,32 +163,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter { outer_ref_columns: subquery.outer_ref_columns, }; Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, &self.schema)?), + Box::new(expr.cast_to(&common_type, self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, )))) } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, - &self.schema, + self.schema, )?))), Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::Like(Like { negated, @@ -199,8 +197,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { escape_char, case_insensitive, }) => { - let left_type = expr.get_type(&self.schema)?; - let right_type = pattern.get_type(&self.schema)?; + let left_type = expr.get_type(self.schema)?; + let right_type = pattern.get_type(self.schema)?; let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { let op_name = if case_insensitive { "ILIKE" @@ -211,8 +209,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" ) })?; - let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); - let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); + let expr = Box::new(expr.cast_to(&coerced_type, self.schema)?); + let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, @@ -223,14 +221,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left_type, right_type) = get_input_types( - &left.get_type(&self.schema)?, + &left.get_type(self.schema)?, &op, - &right.get_type(&self.schema)?, + &right.get_type(self.schema)?, )?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.cast_to(&left_type, &self.schema)?), + Box::new(left.cast_to(&left_type, self.schema)?), op, - Box::new(right.cast_to(&right_type, &self.schema)?), + Box::new(right.cast_to(&right_type, self.schema)?), )))) } Expr::Between(Between { @@ -239,15 +237,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter { low, high, }) => { - let expr_type = expr.get_type(&self.schema)?; - let low_type = low.get_type(&self.schema)?; + let expr_type = expr.get_type(self.schema)?; + let low_type = low.get_type(self.schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { DataFusionError::Internal(format!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" )) })?; - let high_type = high.get_type(&self.schema)?; + let high_type = high.get_type(self.schema)?; let high_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { DataFusionError::Internal(format!( @@ -262,10 +260,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter { )) })?; Ok(Transformed::yes(Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, &self.schema)?), + Box::new(expr.cast_to(&coercion_type, self.schema)?), negated, - Box::new(low.cast_to(&coercion_type, &self.schema)?), - Box::new(high.cast_to(&coercion_type, &self.schema)?), + Box::new(low.cast_to(&coercion_type, self.schema)?), + Box::new(high.cast_to(&coercion_type, self.schema)?), )))) } Expr::InList(InList { @@ -273,10 +271,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter { list, negated, }) => { - let expr_data_type = expr.get_type(&self.schema)?; + let expr_data_type = expr.get_type(self.schema)?; let list_data_types = list .iter() - .map(|list_expr| list_expr.get_type(&self.schema)) + .map(|list_expr| list_expr.get_type(self.schema)) .collect::>>()?; let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); @@ -286,11 +284,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { ), Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, &self.schema)?; + let cast_expr = expr.cast_to(&coerced_type, self.schema)?; let cast_list_expr = list .into_iter() .map(|list_expr| { - list_expr.cast_to(&coerced_type, &self.schema) + list_expr.cast_to(&coerced_type, self.schema) }) .collect::>>()?; Ok(Transformed::yes(Expr::InList(InList ::new( @@ -302,18 +300,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } } Expr::Case(case) => { - let case = coerce_case_expression(case, &self.schema)?; + let case = coerce_case_expression(case, self.schema)?; Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { ScalarFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( args, - &self.schema, + self.schema, fun.signature(), )?; - let new_expr = - coerce_arguments_for_fun(new_expr, &self.schema, &fun)?; + let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &fun)?; Ok(Transformed::yes(Expr::ScalarFunction( ScalarFunction::new_udf(fun, new_expr), ))) @@ -331,7 +328,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let new_expr = coerce_agg_exprs_for_signature( &fun, args, - &self.schema, + self.schema, &fun.signature(), )?; Ok(Transformed::yes(Expr::AggregateFunction( @@ -348,7 +345,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( args, - &self.schema, + self.schema, fun.signature(), )?; Ok(Transformed::yes(Expr::AggregateFunction( @@ -375,14 +372,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { null_treatment, }) => { let window_frame = - coerce_window_frame(window_frame, &self.schema, &order_by)?; + coerce_window_frame(window_frame, self.schema, &order_by)?; let args = match &fun { expr::WindowFunctionDefinition::AggregateFunction(fun) => { coerce_agg_exprs_for_signature( fun, args, - &self.schema, + self.schema, &fun.signature(), )? } @@ -495,7 +492,7 @@ fn coerce_frame_bound( // For example, ROWS and GROUPS frames use `UInt64` during calculations. fn coerce_window_frame( window_frame: WindowFrame, - schema: &DFSchemaRef, + schema: &DFSchema, expressions: &[Expr], ) -> Result { let mut window_frame = window_frame; @@ -531,7 +528,7 @@ fn coerce_window_frame( // Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion. // The above op will be rewrite to the binary op when creating the physical op. -fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchemaRef) -> Result { +fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result { let left_type = expr.get_type(schema)?; get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; expr.cast_to(&DataType::Boolean, schema) @@ -615,7 +612,7 @@ fn coerce_agg_exprs_for_signature( .collect() } -fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { +fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { // Given expressions like: // // CASE a1 @@ -1238,7 +1235,7 @@ mod test { vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; @@ -1249,7 +1246,7 @@ mod test { vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; @@ -1260,7 +1257,7 @@ mod test { vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index fb5125f09769..4d7a207afb1b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -31,9 +31,7 @@ use datafusion_common::{ cast::{as_large_list_array, as_list_array}, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; -use datafusion_common::{ - internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{InList, InSubquery}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ @@ -208,14 +206,8 @@ impl ExprSimplifier { /// /// See the [type coercion module](datafusion_expr::type_coercion) /// documentation for more details on type coercion - /// - // Would be nice if this API could use the SimplifyInfo - // rather than creating an DFSchemaRef coerces rather than doing - // it manually. - // https://github.com/apache/datafusion/issues/3793 - pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { + pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite).data() } @@ -1686,7 +1678,7 @@ mod tests { sync::Arc, }; - use datafusion_common::{assert_contains, ToDFSchema}; + use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{interval_arithmetic::Interval, *}; use crate::simplify_expressions::SimplifyContext; @@ -1721,11 +1713,7 @@ mod tests { // should fully simplify to 3 < i (though i has been coerced to i64) let expected = lit(3i64).lt(col("i")); - // Would be nice if this API could use the SimplifyInfo - // rather than creating an DFSchemaRef coerces rather than doing - // it manually. - // https://github.com/apache/datafusion/issues/3793 - let expr = simplifier.coerce(expr, schema).unwrap(); + let expr = simplifier.coerce(expr, &schema).unwrap(); assert_eq!(expected, simplifier.simplify(expr).unwrap()); }