From adaff12a2137094b5ef119114d234a7662587ad3 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Fri, 5 Jul 2024 21:01:44 +0800 Subject: [PATCH] Implement user defined planner --- .../core/src/execution/session_state.rs | 4 +- datafusion/expr/src/planner.rs | 7 -- datafusion/functions/src/datetime/mod.rs | 1 - datafusion/functions/src/lib.rs | 3 + .../functions/src/{datetime => }/planner.rs | 30 +++++--- datafusion/sql/src/expr/mod.rs | 71 +++++++++---------- 6 files changed, 61 insertions(+), 55 deletions(-) rename datafusion/functions/src/{datetime => }/planner.rs (50%) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index a831f92def50..175bbc0a9187 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -238,8 +238,8 @@ impl SessionState { Arc::new(functions_array::planner::ArrayFunctionPlanner), #[cfg(feature = "array_expressions")] Arc::new(functions_array::planner::FieldAccessPlanner), - #[cfg(feature = "datetime_expressions")] - Arc::new(functions::datetime::planner::ExtractPlanner), + #[cfg(any(feature = "datetime_expressions", feature ="unicode_expressions"))] + Arc::new(functions::planner::UserDefinedFunctionPlanner), ]; let mut new_self = SessionState { diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index e25de083d7f8..4b3216f7a47c 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -135,13 +135,6 @@ pub trait UserDefinedSQLPlanner: Send + Sync { ) -> Result> { Ok(PlannerResult::Original(expr)) } - - /// Plan an extract expression, e.g., `EXTRACT(month FROM foo)` - /// - /// Returns origin expression arguments if not possible - fn plan_extract(&self, args: Vec) -> Result>> { - Ok(PlannerResult::Original(args)) - } } /// An operator with two arguments to plan diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index 8365a38f41f2..9c2f80856bf8 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -30,7 +30,6 @@ pub mod date_trunc; pub mod from_unixtime; pub mod make_date; pub mod now; -pub mod planner; pub mod to_char; pub mod to_date; pub mod to_timestamp; diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 4bc24931d06b..433a4f90d95b 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -130,6 +130,9 @@ make_stub_package!(crypto, "crypto_expressions"); pub mod unicode; make_stub_package!(unicode, "unicode_expressions"); +#[cfg(any(feature = "datetime_expressions", feature = "unicode_expressions"))] +pub mod planner; + mod utils; /// Fluent-style API for creating `Expr`s diff --git a/datafusion/functions/src/datetime/planner.rs b/datafusion/functions/src/planner.rs similarity index 50% rename from datafusion/functions/src/datetime/planner.rs rename to datafusion/functions/src/planner.rs index 4265ce42a51a..91956a20fa95 100644 --- a/datafusion/functions/src/datetime/planner.rs +++ b/datafusion/functions/src/planner.rs @@ -15,22 +15,36 @@ // specific language governing permissions and limitations // under the License. -//! SQL planning extensions like [`ExtractPlanner`] +//! SQL planning extensions like [`UserDefinedFunctionPlanner`] use datafusion_common::Result; use datafusion_expr::{ expr::ScalarFunction, planner::{PlannerResult, UserDefinedSQLPlanner}, - Expr, + sqlparser, Expr, }; +use sqlparser::ast::Expr as SQLExpr; #[derive(Default)] -pub struct ExtractPlanner; +pub struct UserDefinedFunctionPlanner; -impl UserDefinedSQLPlanner for ExtractPlanner { - fn plan_extract(&self, args: Vec) -> Result>> { - Ok(PlannerResult::Planned(Expr::ScalarFunction( - ScalarFunction::new_udf(crate::datetime::date_part(), args), - ))) +impl UserDefinedSQLPlanner for UserDefinedFunctionPlanner { + // Plan the user defined function, returns origin expression arguments if not possible + fn plan_udf( + &self, + sql: &SQLExpr, + args: Vec, + ) -> Result>> { + match sql { + #[cfg(feature = "datetime_expressions")] + SQLExpr::Extract { .. } => Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::datetime::date_part(), args), + ))), + #[cfg(feature = "unicode_expressions")] + SQLExpr::Position { .. } => Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::unicode::strpos(), args), + ))), + _ => Ok(PlannerResult::Original(args)), + } } } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 828039cea45e..8888652ad27c 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -173,29 +173,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let sql_not_moved = sql.clone(); match sql { SQLExpr::Value(value) => { self.parse_value(value, planner_context.prepare_param_data_types()) } - SQLExpr::Extract { field, expr } => { - let mut extract_args = vec![ - Expr::Literal(ScalarValue::from(format!("{field}"))), - self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, - ]; - - for planner in self.planners.iter() { - match planner.plan_extract(extract_args)? { - PlannerResult::Planned(expr) => return Ok(expr), - PlannerResult::Original(args) => { - extract_args = args; - } - } - } - - not_impl_err!("Extract not supported by UserDefinedExtensionPlanners: {extract_args:?}") - } - SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema), SQLExpr::Interval(interval) => { self.sql_interval_to_expr(false, interval, schema, planner_context) @@ -600,24 +581,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Struct { values, fields } => { self.parse_struct(values, fields, schema, planner_context) } - SQLExpr::Position { expr, r#in } => { - let substr = - self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - let fullstr = - self.sql_expr_to_logical_expr(*r#in, schema, planner_context)?; - let mut extracted_args = vec![fullstr, substr]; - - for planner in self.planners.iter() { - match planner.plan_udf(&sql_not_moved, extracted_args)? { - PlannerResult::Planned(expr) => return Ok(expr), - PlannerResult::Original(args) => { - extracted_args = args; - } - } - } - - not_impl_err!("Position not supported by UserDefinedExtensionPlanners: {extracted_args:?}") - } SQLExpr::AtTimeZone { timestamp, time_zone, @@ -641,6 +604,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Dictionary(fields) => { self.try_plan_dictionary_literal(fields, schema, planner_context) } + SQLExpr::Extract { .. } | SQLExpr::Position { .. } => { + self.sql_udf_plan(sql, schema, planner_context) + } _ => not_impl_err!("Unsupported ast node in sqltorel: {sql:?}"), } } @@ -829,6 +795,37 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } + fn sql_udf_plan( + &self, + sql: SQLExpr, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let sql_not_moved = sql.clone(); + let mut extracted_args = match sql { + SQLExpr::Position { expr, r#in } => Ok(vec![ + self.sql_expr_to_logical_expr(*r#in, schema, planner_context)?, + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, + ]), + SQLExpr::Extract { field, expr } => Ok(vec![ + Expr::Literal(ScalarValue::from(format!("{field}"))), + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, + ]), + _ => not_impl_err!("sql_udf_plan not support sql expression: {sql:?}"), + }?; + + for planner in self.planners.iter() { + match planner.plan_udf(&sql_not_moved, extracted_args)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => { + extracted_args = args; + } + } + } + + not_impl_err!("sql_udf_plan not support sql expression: {sql_not_moved:?}") + } + fn sql_similarto_to_expr( &self, negated: bool,