From 5b311ee52538a0da63c295dc2ec40ca578ddc5fa Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 15 Nov 2024 19:24:55 +0100 Subject: [PATCH] Add `LogicalPlanStat`s --- Cargo.toml | 1 + datafusion-cli/Cargo.lock | 62 + datafusion-cli/src/exec.rs | 8 +- datafusion-examples/examples/analyzer_rule.rs | 2 +- datafusion-examples/examples/expr_api.rs | 8 +- .../examples/optimizer_rule.rs | 2 +- .../examples/simplify_udaf_expression.rs | 2 +- .../examples/simplify_udwf_expression.rs | 2 +- datafusion-examples/examples/sql_analysis.rs | 8 +- datafusion/core/src/dataframe/mod.rs | 8 +- .../core/src/datasource/listing/helpers.rs | 63 +- .../core/src/datasource/listing/table.rs | 2 +- datafusion/core/src/execution/context/mod.rs | 31 +- .../core/src/execution/session_state.rs | 6 +- datafusion/core/src/physical_planner.rs | 311 ++-- .../core/tests/custom_sources_cases/mod.rs | 13 +- .../provider_filter_pushdown.rs | 4 +- .../tests/dataframe/dataframe_functions.rs | 7 +- datafusion/core/tests/dataframe/mod.rs | 8 +- .../core/tests/execution/logical_plan.rs | 6 +- .../core/tests/expr_api/simplification.rs | 14 +- datafusion/core/tests/optimizer/mod.rs | 4 +- datafusion/core/tests/sql/mod.rs | 13 +- .../core/tests/user_defined/expr_planner.rs | 16 +- .../tests/user_defined/user_defined_plan.rs | 21 +- .../user_defined_scalar_functions.rs | 2 +- .../user_defined_table_functions.rs | 2 +- datafusion/expr/Cargo.toml | 1 + .../expr/src/conditional_expressions.rs | 2 +- datafusion/expr/src/expr.rs | 953 ++++++++---- datafusion/expr/src/expr_fn.rs | 82 +- datafusion/expr/src/expr_rewriter/mod.rs | 28 +- datafusion/expr/src/expr_rewriter/order_by.rs | 14 +- datafusion/expr/src/expr_schema.rs | 180 +-- datafusion/expr/src/logical_plan/builder.rs | 77 +- datafusion/expr/src/logical_plan/ddl.rs | 41 + datafusion/expr/src/logical_plan/display.rs | 147 +- datafusion/expr/src/logical_plan/dml.rs | 16 +- datafusion/expr/src/logical_plan/plan.rs | 1310 ++++++++++++----- 
datafusion/expr/src/logical_plan/statement.rs | 11 + datafusion/expr/src/logical_plan/tree_node.rs | 702 +++++---- datafusion/expr/src/operation.rs | 56 +- datafusion/expr/src/test/function_stub.rs | 10 +- datafusion/expr/src/tree_node.rs | 335 +++-- datafusion/expr/src/udaf.rs | 2 +- datafusion/expr/src/udf.rs | 2 +- datafusion/expr/src/udwf.rs | 2 +- datafusion/expr/src/utils.rs | 157 +- datafusion/functions-aggregate/src/count.rs | 2 +- datafusion/functions-aggregate/src/macros.rs | 4 +- datafusion/functions-nested/src/macros.rs | 4 +- datafusion/functions-nested/src/map.rs | 2 +- datafusion/functions-nested/src/planner.rs | 8 +- datafusion/functions/src/core/arrow_cast.rs | 2 +- datafusion/functions/src/core/planner.rs | 6 +- datafusion/functions/src/math/log.rs | 2 +- datafusion/functions/src/math/power.rs | 2 +- datafusion/functions/src/planner.rs | 6 +- datafusion/functions/src/string/concat.rs | 2 +- datafusion/functions/src/string/concat_ws.rs | 2 +- .../src/analyzer/count_wildcard_rule.rs | 10 +- .../src/analyzer/expand_wildcard_rule.rs | 14 +- .../src/analyzer/function_rewrite.rs | 2 +- .../src/analyzer/inline_table_scan.rs | 6 +- datafusion/optimizer/src/analyzer/mod.rs | 6 +- .../src/analyzer/resolve_grouping_function.rs | 36 +- datafusion/optimizer/src/analyzer/subquery.rs | 77 +- .../optimizer/src/analyzer/type_coercion.rs | 309 ++-- .../optimizer/src/common_subexpr_eliminate.rs | 121 +- datafusion/optimizer/src/decorrelate.rs | 45 +- .../src/decorrelate_predicate_subquery.rs | 90 +- .../optimizer/src/eliminate_cross_join.rs | 78 +- .../src/eliminate_duplicated_expr.rs | 8 +- datafusion/optimizer/src/eliminate_filter.rs | 13 +- .../src/eliminate_group_by_constant.rs | 16 +- datafusion/optimizer/src/eliminate_join.rs | 6 +- datafusion/optimizer/src/eliminate_limit.rs | 6 +- .../optimizer/src/eliminate_nested_union.rs | 18 +- .../optimizer/src/eliminate_one_union.rs | 4 +- .../optimizer/src/eliminate_outer_join.rs | 116 +- 
.../src/extract_equijoin_predicate.rs | 40 +- .../optimizer/src/filter_null_join_keys.rs | 13 +- .../optimizer/src/optimize_projections/mod.rs | 122 +- datafusion/optimizer/src/optimizer.rs | 6 +- datafusion/optimizer/src/plan_signature.rs | 4 +- .../optimizer/src/propagate_empty_relation.rs | 26 +- datafusion/optimizer/src/push_down_filter.rs | 209 +-- datafusion/optimizer/src/push_down_limit.rs | 72 +- .../src/replace_distinct_aggregate.rs | 23 +- .../optimizer/src/scalar_subquery_to_join.rs | 45 +- .../simplify_expressions/expr_simplifier.rs | 1254 +++++++++------- .../src/simplify_expressions/guarantees.rs | 48 +- .../simplify_expressions/inlist_simplifier.rs | 13 +- .../src/simplify_expressions/regex.rs | 6 +- .../simplify_expressions/simplify_exprs.rs | 10 +- .../src/simplify_expressions/utils.rs | 58 +- .../src/single_distinct_to_groupby.rs | 63 +- datafusion/optimizer/src/test/user_defined.rs | 2 +- .../src/unwrap_cast_in_comparison.rs | 67 +- datafusion/optimizer/src/utils.rs | 2 +- datafusion/physical-expr/src/planner.rs | 94 +- datafusion/physical-optimizer/src/pruning.rs | 22 +- .../proto/src/logical_plan/from_proto.rs | 103 +- datafusion/proto/src/logical_plan/mod.rs | 266 ++-- datafusion/proto/src/logical_plan/to_proto.rs | 172 ++- .../tests/cases/roundtrip_logical_plan.rs | 88 +- datafusion/proto/tests/cases/serialize.rs | 2 +- datafusion/sql/src/cte.rs | 2 +- datafusion/sql/src/expr/function.rs | 18 +- datafusion/sql/src/expr/grouping_set.rs | 6 +- datafusion/sql/src/expr/identifier.rs | 2 +- datafusion/sql/src/expr/mod.rs | 46 +- datafusion/sql/src/expr/subquery.rs | 6 +- datafusion/sql/src/expr/unary_op.rs | 4 +- datafusion/sql/src/expr/value.rs | 2 +- datafusion/sql/src/query.rs | 6 +- datafusion/sql/src/relation/mod.rs | 14 +- datafusion/sql/src/select.rs | 22 +- datafusion/sql/src/statement.rs | 56 +- datafusion/sql/src/unparser/expr.rs | 201 +-- datafusion/sql/src/unparser/plan.rs | 88 +- datafusion/sql/src/unparser/rewrite.rs | 49 +- 
datafusion/sql/src/unparser/utils.rs | 38 +- datafusion/sql/src/utils.rs | 24 +- datafusion/sql/tests/sql_integration.rs | 35 +- .../substrait/src/logical_plan/consumer.rs | 74 +- .../substrait/src/logical_plan/producer.rs | 154 +- .../tests/cases/roundtrip_logical_plan.rs | 8 +- .../building-logical-plans.md | 2 +- 129 files changed, 5708 insertions(+), 3756 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0011539156326..93f7f1dc3650e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -121,6 +121,7 @@ datafusion-sql = { path = "datafusion/sql", version = "43.0.0" } datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "43.0.0" } datafusion-substrait = { path = "datafusion/substrait", version = "43.0.0" } doc-comment = "0.3" +enumset = "1.1.5" env_logger = "0.11" futures = "0.3" half = { version = "2.2.1", default-features = false } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index c5576b7e7d444..504a26f003452 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1166,6 +1166,40 @@ dependencies = [ "syn", ] +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn", +] + [[package]] name = "dary_heap" version = "0.3.7" @@ -1352,6 +1386,7 @@ dependencies = [ 
"datafusion-functions-aggregate-common", "datafusion-functions-window-common", "datafusion-physical-expr-common", + "enumset", "indexmap", "paste", "recursive", @@ -1671,6 +1706,27 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +[[package]] +name = "enumset" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d07a4b049558765cef5f0c1a273c3fc57084d768b44d2f98127aef4cceb17293" +dependencies = [ + "enumset_derive", +] + +[[package]] +name = "enumset_derive" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59c3b24c345d8c314966bdc1832f6c2635bfcce8e7cf363bd115987bba2ee242" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "env_filter" version = "0.1.2" @@ -2332,6 +2388,12 @@ dependencies = [ "syn", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.0.3" diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 18906536691ef..52c85d1ba3219 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -274,9 +274,9 @@ impl AdjustedPrintOptions { // all rows if matches!( plan, - LogicalPlan::Explain(_) + LogicalPlan::Explain(_, _) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Analyze(_) + | LogicalPlan::Analyze(_, _) ) { self.inner.maxrows = MaxRows::Unlimited; } @@ -311,7 +311,7 @@ async fn create_plan( // Note that cmd is a mutable reference so that create_external_table function can remove all // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion // will raise Configuration errors. 
- if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd), _) = &plan { // To support custom formats, treat error as None let format = config_file_type_from_str(&cmd.file_type); register_object_store_and_config_extensions( @@ -323,7 +323,7 @@ async fn create_plan( .await?; } - if let LogicalPlan::Copy(copy_to) = &mut plan { + if let LogicalPlan::Copy(copy_to, _) = &mut plan { let format = config_file_type_from_str(©_to.file_type.get_ext()); register_object_store_and_config_extensions( diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/analyzer_rule.rs index bd067be97b8b3..15695992521c0 100644 --- a/datafusion-examples/examples/analyzer_rule.rs +++ b/datafusion-examples/examples/analyzer_rule.rs @@ -175,7 +175,7 @@ impl AnalyzerRule for RowLevelAccessControl { } fn is_employee_table_scan(plan: &LogicalPlan) -> bool { - if let LogicalPlan::TableScan(scan) = plan { + if let LogicalPlan::TableScan(scan, _) = plan { scan.table_name.table() == "employee" } else { false diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index cb0796bdcf735..de45ad7784edc 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -61,7 +61,7 @@ async fn main() -> Result<()> { let expr = col("a") + lit(5); // The same same expression can be created directly, with much more code: - let expr2 = Expr::BinaryExpr(BinaryExpr::new( + let expr2 = Expr::binary_expr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, Box::new(Expr::Literal(ScalarValue::Int32(Some(5)))), @@ -396,20 +396,20 @@ fn type_coercion_demo() -> Result<()> { let coerced_expr = expr .transform(|e| { // Only type coerces binary expressions. 
- let Expr::BinaryExpr(e) = e else { + let Expr::BinaryExpr(e, _) = e else { return Ok(Transformed::no(e)); }; if let Expr::Column(ref col_expr) = *e.left { let field = df_schema.field_with_name(None, col_expr.name())?; let cast_to_type = field.data_type(); let coerced_right = e.right.cast_to(cast_to_type, &df_schema)?; - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( + Ok(Transformed::yes(Expr::binary_expr(BinaryExpr::new( e.left, e.op, Box::new(coerced_right), )))) } else { - Ok(Transformed::no(Expr::BinaryExpr(e))) + Ok(Transformed::no(Expr::binary_expr(e))) } })? .data; diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index e0b552620a9af..f83dce6d04acc 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -145,7 +145,7 @@ impl MyOptimizerRule { expr.transform_up(|expr| { // Closure called for each sub tree match expr { - Expr::BinaryExpr(binary_expr) if is_binary_eq(&binary_expr) => { + Expr::BinaryExpr(binary_expr, _) if is_binary_eq(&binary_expr) => { // destruture the expression let BinaryExpr { left, op: _, right } = binary_expr; // rewrite to `my_eq(left, right)` diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs index 52a27317e3c3d..707d2ed1d8ff0 100644 --- a/datafusion-examples/examples/simplify_udaf_expression.rs +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -91,7 +91,7 @@ impl AggregateUDFImpl for BetterAvgUdaf { // as an example for this functionality we replace UDF function // with build-in aggregate function to illustrate the use let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + Ok(Expr::aggregate_function(AggregateFunction::new_udf( avg_udaf(), // yes it is the same Avg, `BetterAvgUdaf` was just a // marketing pitch :) diff --git 
a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs index 117063df4e0d8..410fe73d28fef 100644 --- a/datafusion-examples/examples/simplify_udwf_expression.rs +++ b/datafusion-examples/examples/simplify_udwf_expression.rs @@ -71,7 +71,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`. fn simplify(&self) -> Option { let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { - Ok(Expr::WindowFunction(WindowFunction { + Ok(Expr::window_function(WindowFunction { fun: datafusion_expr::WindowFunctionDefinition::AggregateUDF(avg_udaf()), args: window_function.args, partition_by: window_function.partition_by, diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_analysis.rs index 2158b8e4b016e..db118eb26cf36 100644 --- a/datafusion-examples/examples/sql_analysis.rs +++ b/datafusion-examples/examples/sql_analysis.rs @@ -39,7 +39,7 @@ fn total_join_count(plan: &LogicalPlan) -> usize { // We can use the TreeNode API to walk over a LogicalPlan. 
plan.apply(|node| { // if we encounter a join we update the running count - if matches!(node, LogicalPlan::Join(_)) { + if matches!(node, LogicalPlan::Join(_, _)) { total += 1; } Ok(TreeNodeRecursion::Continue) @@ -89,7 +89,7 @@ fn count_trees(plan: &LogicalPlan) -> (usize, Vec) { while let Some(node) = to_visit.pop() { // if we encounter a join, we know were at the root of the tree // count this tree and recurse on it's inputs - if matches!(node, LogicalPlan::Join(_)) { + if matches!(node, LogicalPlan::Join(_, _)) { let (group_count, inputs) = count_tree(node); total += group_count; groups.push(group_count); @@ -146,12 +146,12 @@ fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) { // / \ // B C // we can continue the recursion in this case - if let LogicalPlan::Projection(_) = node { + if let LogicalPlan::Projection(_, _) = node { return Ok(TreeNodeRecursion::Continue); } // any join we count - if matches!(node, LogicalPlan::Join(_)) { + if matches!(node, LogicalPlan::Join(_, _)) { total += 1; Ok(TreeNodeRecursion::Continue) } else { diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index bcf803573cdfd..25c8efa03ff06 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -512,7 +512,7 @@ impl DataFrame { group_expr: Vec, aggr_expr: Vec, ) -> Result { - let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_, _)]); let aggr_expr_len = aggr_expr.len(); let plan = LogicalPlanBuilder::from(self.plan) .aggregate(group_expr, aggr_expr)? 
@@ -1404,7 +1404,7 @@ impl DataFrame { /// # } /// ``` pub fn explain(self, verbose: bool, analyze: bool) -> Result { - if matches!(self.plan, LogicalPlan::Explain(_)) { + if matches!(self.plan, LogicalPlan::Explain(_, _)) { return plan_err!("Nested EXPLAINs are not supported"); } let plan = LogicalPlanBuilder::from(self.plan) @@ -2176,7 +2176,7 @@ mod tests { async fn select_with_window_exprs() -> Result<()> { // build plan using Table API let t = test_table().await?; - let first_row = Expr::WindowFunction(WindowFunction::new( + let first_row = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::WindowUDF(first_value_udwf()), vec![col("aggregate_test_100.c1")], )) @@ -2742,7 +2742,7 @@ mod tests { vec![col("c3")], ); - Expr::WindowFunction(w) + Expr::window_function(w) .null_treatment(NullTreatment::IgnoreNulls) .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) .window_frame(WindowFrame::new_bounds( diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 04c64156b125b..8bb1672223fa9 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -63,33 +63,33 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { } } Expr::Literal(_) - | Expr::Alias(_) + | Expr::Alias(_, _) | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) - | Expr::Not(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) - | Expr::Negative(_) - | Expr::Cast(_) - | Expr::TryCast(_) - | Expr::BinaryExpr(_) - | Expr::Between(_) - | Expr::Like(_) - | Expr::SimilarTo(_) - | Expr::InList(_) - | Expr::Exists(_) - | Expr::InSubquery(_) - | Expr::ScalarSubquery(_) - | Expr::GroupingSet(_) - | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), - - Expr::ScalarFunction(scalar_function) 
=> { + | Expr::Not(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) + | Expr::Negative(_, _) + | Expr::Cast(_, _) + | Expr::TryCast(_, _) + | Expr::BinaryExpr(_, _) + | Expr::Between(_, _) + | Expr::Like(_, _) + | Expr::SimilarTo(_, _) + | Expr::InList(_, _) + | Expr::Exists(_, _) + | Expr::InSubquery(_, _) + | Expr::ScalarSubquery(_, _) + | Expr::GroupingSet(_, _) + | Expr::Case(_, _) => Ok(TreeNodeRecursion::Continue), + + Expr::ScalarFunction(scalar_function, _) => { match scalar_function.func.signature().volatility { Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context @@ -330,11 +330,14 @@ fn populate_partition_values<'a>( partition_values: &mut HashMap<&'a str, PartitionValue>, filter: &'a Expr, ) { - if let Expr::BinaryExpr(BinaryExpr { - ref left, - op, - ref right, - }) = filter + if let Expr::BinaryExpr( + BinaryExpr { + ref left, + op, + ref right, + }, + _, + ) = filter { match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index ffe49dd2ba116..efbb687c6341c 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -1954,7 +1954,7 @@ mod tests { false, )])); - let filter_predicate = Expr::BinaryExpr(BinaryExpr::new( + let filter_predicate = Expr::binary_expr(BinaryExpr::new( Box::new(Expr::Column("column1".into())), Operator::GtEq, Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))), diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 5f01d41c31e73..99fcb38edf42a 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ 
b/datafusion/core/src/execution/context/mod.rs @@ -648,7 +648,7 @@ impl SessionContext { /// [`SQLOptions::verify_plan`]. pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result { match plan { - LogicalPlan::Ddl(ddl) => { + LogicalPlan::Ddl(ddl, _) => { // Box::pin avoids allocating the stack space within this function's frame // for every one of these individual async functions, decreasing the risk of // stack overflows. @@ -681,18 +681,21 @@ impl SessionContext { DdlStatement::DropFunction(cmd) => { Box::pin(self.drop_function(cmd)).await } - ddl => Ok(DataFrame::new(self.state(), LogicalPlan::Ddl(ddl))), + ddl => Ok(DataFrame::new(self.state(), LogicalPlan::ddl(ddl))), } } // TODO what about the other statements (like TransactionStart and TransactionEnd) - LogicalPlan::Statement(Statement::SetVariable(stmt)) => { + LogicalPlan::Statement(Statement::SetVariable(stmt), _) => { self.set_variable(stmt).await } - LogicalPlan::Statement(Statement::Prepare(Prepare { - name, - input, - data_types, - })) => { + LogicalPlan::Statement( + Statement::Prepare(Prepare { + name, + input, + data_types, + }), + _, + ) => { // The number of parameters must match the specified data types length. 
if !data_types.is_empty() { let param_names = input.get_parameter_names()?; @@ -712,10 +715,10 @@ impl SessionContext { self.state.write().store_prepared(name, data_types, input)?; self.return_empty_dataframe() } - LogicalPlan::Statement(Statement::Execute(execute)) => { + LogicalPlan::Statement(Statement::Execute(execute), _) => { self.execute_prepared(execute) } - LogicalPlan::Statement(Statement::Deallocate(deallocate)) => { + LogicalPlan::Statement(Statement::Deallocate(deallocate), _) => { self.state .write() .remove_prepared(deallocate.name.as_str())?; @@ -1769,16 +1772,16 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> { fn f_down(&mut self, node: &'n Self::Node) -> Result { match node { - LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { + LogicalPlan::Ddl(ddl, _) if !self.options.allow_ddl => { plan_err!("DDL not supported: {}", ddl.name()) } - LogicalPlan::Dml(dml) if !self.options.allow_dml => { + LogicalPlan::Dml(dml, _) if !self.options.allow_dml => { plan_err!("DML not supported: {}", dml.op) } - LogicalPlan::Copy(_) if !self.options.allow_dml => { + LogicalPlan::Copy(_, _) if !self.options.allow_dml => { plan_err!("DML not supported: COPY") } - LogicalPlan::Statement(stmt) if !self.options.allow_statements => { + LogicalPlan::Statement(stmt, _) if !self.options.allow_statements => { plan_err!("Statement not supported: {}", stmt.name()) } _ => Ok(TreeNodeRecursion::Continue), diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 9fc081dd53298..8fe529cbea1e8 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -630,7 +630,7 @@ impl SessionState { /// Optimizes the logical plan by applying optimizer rules. 
pub fn optimize(&self, plan: &LogicalPlan) -> datafusion_common::Result { - if let LogicalPlan::Explain(e) = plan { + if let LogicalPlan::Explain(e, _) = plan { let mut stringified_plans = e.stringified_plans.clone(); // analyze & capture output of each rule @@ -650,7 +650,7 @@ impl SessionState { stringified_plans .push(StringifiedPlan::new(plan_type, err.to_string())); - return Ok(LogicalPlan::Explain(Explain { + return Ok(LogicalPlan::explain(Explain { verbose: e.verbose, plan: Arc::clone(&e.plan), stringified_plans, @@ -686,7 +686,7 @@ impl SessionState { Err(e) => return Err(e), }; - Ok(LogicalPlan::Explain(Explain { + Ok(LogicalPlan::explain(Explain { verbose: e.verbose, plan, stringified_plans, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 44537c951f945..78bf4cd19ffb2 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -438,13 +438,16 @@ impl DefaultPhysicalPlanner { ) -> Result> { let exec_node: Arc = match node { // Leaves (no children) - LogicalPlan::TableScan(TableScan { - source, - projection, - filters, - fetch, - .. - }) => { + LogicalPlan::TableScan( + TableScan { + source, + projection, + filters, + fetch, + .. + }, + _, + ) => { let source = source_as_provider(source)?; // Remove all qualifiers from the scan as the provider // doesn't know (nor should care) how the relation was @@ -454,7 +457,7 @@ impl DefaultPhysicalPlanner { .scan(session_state, projection.as_ref(), &filters, *fetch) .await? 
} - LogicalPlan::Values(Values { values, schema }) => { + LogicalPlan::Values(Values { values, schema }, _) => { let exec_schema = schema.as_ref().to_owned().into(); let exprs = values .iter() @@ -490,13 +493,16 @@ impl DefaultPhysicalPlanner { } // 1 Child - LogicalPlan::Copy(CopyTo { - input, - output_url, - file_type, - partition_by, - options: source_option_tuples, - }) => { + LogicalPlan::Copy( + CopyTo { + input, + output_url, + file_type, + partition_by, + options: source_option_tuples, + }, + _, + ) => { let input_exec = children.one()?; let parsed_url = ListingTableUrl::parse(output_url)?; let object_store_url = parsed_url.object_store(); @@ -539,11 +545,14 @@ impl DefaultPhysicalPlanner { .create_writer_physical_plan(input_exec, session_state, config, None) .await? } - LogicalPlan::Dml(DmlStatement { - table_name, - op: WriteOp::Insert(insert_op), - .. - }) => { + LogicalPlan::Dml( + DmlStatement { + table_name, + op: WriteOp::Insert(insert_op), + .. + }, + _, + ) => { let name = table_name.table(); let schema = session_state.schema_for_ref(table_name.clone())?; if let Some(provider) = schema.table(name).await? { @@ -555,9 +564,12 @@ impl DefaultPhysicalPlanner { return exec_err!("Table '{table_name}' does not exist"); } } - LogicalPlan::Window(Window { - input, window_expr, .. - }) => { + LogicalPlan::Window( + Window { + input, window_expr, .. + }, + _, + ) => { if window_expr.is_empty() { return internal_err!("Impossibly got empty window expression"); } @@ -584,19 +596,25 @@ impl DefaultPhysicalPlanner { }; let get_sort_keys = |expr: &Expr| match expr { - Expr::WindowFunction(WindowFunction { - ref partition_by, - ref order_by, - .. - }) => generate_sort_key(partition_by, order_by), - Expr::Alias(Alias { expr, .. }) => { + Expr::WindowFunction( + WindowFunction { + ref partition_by, + ref order_by, + .. + }, + _, + ) => generate_sort_key(partition_by, order_by), + Expr::Alias(Alias { expr, .. 
}, _) => { // Convert &Box to &T match &**expr { - Expr::WindowFunction(WindowFunction { - ref partition_by, - ref order_by, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction( + WindowFunction { + ref partition_by, + ref order_by, + .. + }, + _, + ) => generate_sort_key(partition_by, order_by), _ => unreachable!(), } } @@ -643,12 +661,15 @@ impl DefaultPhysicalPlanner { )?) } } - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - .. - }) => { + LogicalPlan::Aggregate( + Aggregate { + input, + group_expr, + aggr_expr, + .. + }, + _, + ) => { let options = session_state.config().options(); // Initially need to perform the aggregate and then merge the partitions let input_exec = children.one()?; @@ -758,16 +779,19 @@ impl DefaultPhysicalPlanner { Arc::clone(&physical_input_schema), )?) } - LogicalPlan::Projection(Projection { input, expr, .. }) => self + LogicalPlan::Projection(Projection { input, expr, .. }, _) => self .create_project_physical_exec( session_state, children.one()?, input, expr, )?, - LogicalPlan::Filter(Filter { - predicate, input, .. - }) => { + LogicalPlan::Filter( + Filter { + predicate, input, .. + }, + _, + ) => { let physical_input = children.one()?; let input_dfschema = input.schema(); @@ -781,10 +805,13 @@ impl DefaultPhysicalPlanner { let filter = FilterExec::try_new(runtime_expr, physical_input)?; Arc::new(filter.with_default_selectivity(selectivity)?) } - LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - }) => { + LogicalPlan::Repartition( + Repartition { + input, + partitioning_scheme, + }, + _, + ) => { let physical_input = children.one()?; let input_dfschema = input.as_ref().schema(); let physical_partitioning = match partitioning_scheme { @@ -815,9 +842,12 @@ impl DefaultPhysicalPlanner { physical_partitioning, )?) } - LogicalPlan::Sort(Sort { - expr, input, fetch, .. - }) => { + LogicalPlan::Sort( + Sort { + expr, input, fetch, .. 
+ }, + _, + ) => { let physical_input = children.one()?; let input_dfschema = input.as_ref().schema(); let sort_expr = create_physical_sort_exprs( @@ -829,9 +859,9 @@ impl DefaultPhysicalPlanner { SortExec::new(sort_expr, physical_input).with_fetch(*fetch); Arc::new(new_sort) } - LogicalPlan::Subquery(_) => todo!(), - LogicalPlan::SubqueryAlias(_) => children.one()?, - LogicalPlan::Limit(limit) => { + LogicalPlan::Subquery(_, _) => todo!(), + LogicalPlan::SubqueryAlias(_, _) => children.one()?, + LogicalPlan::Limit(limit, _) => { let input = children.one()?; let SkipType::Literal(skip) = limit.get_skip_type()? else { return not_impl_err!( @@ -861,13 +891,16 @@ impl DefaultPhysicalPlanner { Arc::new(GlobalLimitExec::new(input, skip, fetch)) } - LogicalPlan::Unnest(Unnest { - list_type_columns, - struct_type_columns, - schema, - options, - .. - }) => { + LogicalPlan::Unnest( + Unnest { + list_type_columns, + struct_type_columns, + schema, + options, + .. + }, + _, + ) => { let input = children.one()?; let schema = SchemaRef::new(schema.as_ref().to_owned().into()); let list_column_indices = list_type_columns @@ -887,16 +920,19 @@ impl DefaultPhysicalPlanner { } // 2 Children - LogicalPlan::Join(Join { - left, - right, - on: keys, - filter, - join_type, - null_equals_null, - schema: join_schema, - .. - }) => { + LogicalPlan::Join( + Join { + left, + right, + on: keys, + filter, + join_type, + null_equals_null, + schema: join_schema, + .. 
+ }, + _, + ) => { let null_equals_null = *null_equals_null; let [physical_left, physical_right] = children.two()?; @@ -925,7 +961,7 @@ impl DefaultPhysicalPlanner { let left = Arc::new(left); let right = Arc::new(right); - let new_join = LogicalPlan::Join(Join::try_new_with_project_input( + let new_join = LogicalPlan::join(Join::try_new_with_project_input( node, Arc::clone(&left), Arc::clone(&right), @@ -938,7 +974,7 @@ impl DefaultPhysicalPlanner { // If left_projected is true we are guaranteed that left is a Projection ( true, - LogicalPlan::Projection(Projection { input, expr, .. }), + LogicalPlan::Projection(Projection { input, expr, .. }, _), ) => self.create_project_physical_exec( session_state, physical_left, @@ -951,7 +987,7 @@ impl DefaultPhysicalPlanner { // If right_projected is true we are guaranteed that right is a Projection ( true, - LogicalPlan::Projection(Projection { input, expr, .. }), + LogicalPlan::Projection(Projection { input, expr, .. }, _), ) => self.create_project_physical_exec( session_state, physical_right, @@ -965,7 +1001,7 @@ impl DefaultPhysicalPlanner { if left_projected || right_projected { let final_join_result = join_schema.iter().map(Expr::from).collect::>(); - let projection = LogicalPlan::Projection(Projection::try_new( + let projection = LogicalPlan::projection(Projection::try_new( final_join_result, Arc::new(new_join), )?); @@ -982,19 +1018,25 @@ impl DefaultPhysicalPlanner { // Retrieving new left/right and join keys (in case plan was mutated above) let (left, right, keys, new_project) = match new_logical.as_ref() { - LogicalPlan::Projection(Projection { input, expr, .. }) => { - if let LogicalPlan::Join(Join { - left, right, on, .. - }) = input.as_ref() + LogicalPlan::Projection(Projection { input, expr, .. }, _) => { + if let LogicalPlan::Join( + Join { + left, right, on, .. + }, + _, + ) = input.as_ref() { (left, right, on, Some((input, expr))) } else { unreachable!() } } - LogicalPlan::Join(Join { - left, right, on, .. 
- }) => (left, right, on, None), + LogicalPlan::Join( + Join { + left, right, on, .. + }, + _, + ) => (left, right, on, None), // Should either be the original Join, or Join with a Projection on top _ => unreachable!(), }; @@ -1170,9 +1212,12 @@ impl DefaultPhysicalPlanner { join } } - LogicalPlan::RecursiveQuery(RecursiveQuery { - name, is_distinct, .. - }) => { + LogicalPlan::RecursiveQuery( + RecursiveQuery { + name, is_distinct, .. + }, + _, + ) => { let [static_term, recursive_term] = children.two()?; Arc::new(RecursiveQueryExec::try_new( name.clone(), @@ -1183,8 +1228,8 @@ impl DefaultPhysicalPlanner { } // N Children - LogicalPlan::Union(_) => Arc::new(UnionExec::new(children.vec())), - LogicalPlan::Extension(Extension { node }) => { + LogicalPlan::Union(_, _) => Arc::new(UnionExec::new(children.vec())), + LogicalPlan::Extension(Extension { node }, _) => { let mut maybe_plan = None; let children = children.vec(); for planner in &self.extension_planners { @@ -1224,16 +1269,16 @@ impl DefaultPhysicalPlanner { } // Other - LogicalPlan::Statement(statement) => { + LogicalPlan::Statement(statement, _) => { // DataFusion is a read-only query engine, but also a library, so consumers may implement this let name = statement.name(); return not_impl_err!("Unsupported logical plan: Statement({name})"); } - LogicalPlan::Dml(dml) => { + LogicalPlan::Dml(dml, _) => { // DataFusion is a read-only query engine, but also a library, so consumers may implement this return not_impl_err!("Unsupported logical plan: Dml({0})", dml.op); } - LogicalPlan::Ddl(ddl) => { + LogicalPlan::Ddl(ddl, _) => { // There is no default plan for DDl statements -- // it must be handled at a higher level (so that // the appropriate table can be registered with @@ -1241,17 +1286,17 @@ impl DefaultPhysicalPlanner { let name = ddl.name(); return not_impl_err!("Unsupported logical plan: {name}"); } - LogicalPlan::Explain(_) => { + LogicalPlan::Explain(_, _) => { return internal_err!( "Unsupported logical 
plan: Explain must be root of the plan" ) } - LogicalPlan::Distinct(_) => { + LogicalPlan::Distinct(_, _) => { return internal_err!( "Unsupported logical plan: Distinct should be replaced to Aggregate" ) } - LogicalPlan::Analyze(_) => { + LogicalPlan::Analyze(_, _) => { return internal_err!( "Unsupported logical plan: Analyze must be root of the plan" ) @@ -1269,7 +1314,7 @@ impl DefaultPhysicalPlanner { ) -> Result { if group_expr.len() == 1 { match &group_expr[0] { - Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets), _) => { merge_grouping_set_physical_expr( grouping_sets, input_dfschema, @@ -1277,13 +1322,15 @@ impl DefaultPhysicalPlanner { session_state, ) } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => create_cube_physical_expr( - exprs, - input_dfschema, - input_schema, - session_state, - ), - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + Expr::GroupingSet(GroupingSet::Cube(exprs), _) => { + create_cube_physical_expr( + exprs, + input_dfschema, + input_schema, + session_state, + ) + } + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) => { create_rollup_physical_expr( exprs, input_dfschema, @@ -1516,14 +1563,17 @@ pub fn create_window_expr_with_name( let name = name.into(); let physical_schema: &Schema = &logical_schema.into(); match e { - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + _, + ) => { let physical_args = create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = @@ -1564,7 +1614,7 @@ pub fn create_window_expr( ) -> Result> { // unpack aliased logical expressions, e.g. "sum(col) over () as total" let (name, e) = match e { - Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()), + Expr::Alias(Alias { expr, name, .. 
}, _) => (name.clone(), expr.as_ref()), _ => (e.schema_name().to_string(), e), }; create_window_expr_with_name(e, name, logical_schema, execution_props) @@ -1587,14 +1637,17 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( execution_props: &ExecutionProps, ) -> Result { match e { - Expr::AggregateFunction(AggregateFunction { - func, - distinct, - args, - filter, - order_by, - null_treatment, - }) => { + Expr::AggregateFunction( + AggregateFunction { + func, + distinct, + args, + filter, + order_by, + null_treatment, + }, + _, + ) => { let name = if let Some(name) = name { name } else { @@ -1656,8 +1709,8 @@ pub fn create_aggregate_expr_and_maybe_filter( ) -> Result { // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" let (name, e) = match e { - Expr::Alias(Alias { expr, name, .. }) => (Some(name.clone()), expr.as_ref()), - Expr::AggregateFunction(_) => (Some(e.schema_name().to_string()), e), + Expr::Alias(Alias { expr, name, .. }, _) => (Some(name.clone()), expr.as_ref()), + Expr::AggregateFunction(_, _) => (Some(e.schema_name().to_string()), e), _ => (None, e), }; @@ -1713,7 +1766,7 @@ impl DefaultPhysicalPlanner { logical_plan: &LogicalPlan, session_state: &SessionState, ) -> Result>> { - if let LogicalPlan::Explain(e) = logical_plan { + if let LogicalPlan::Explain(e, _) = logical_plan { use PlanType::*; let mut stringified_plans = vec![]; @@ -1836,7 +1889,7 @@ impl DefaultPhysicalPlanner { stringified_plans, e.verbose, )))) - } else if let LogicalPlan::Analyze(a) = logical_plan { + } else if let LogicalPlan::Analyze(a, _) = logical_plan { let input = self.create_physical_plan(&a.input, session_state).await?; let schema = SchemaRef::new((*a.schema).clone().into()); let show_statistics = session_state.config_options().explain.show_statistics; @@ -2175,7 +2228,7 @@ mod tests { ErrorExtensionPlanner {}, )]); - let logical_plan = LogicalPlan::Extension(Extension { + let logical_plan = LogicalPlan::extension(Extension { node: 
Arc::new(NoOpExtensionNode::default()), }); match planner @@ -2233,7 +2286,7 @@ mod tests { async fn default_extension_planner() { let session_state = make_session_state(); let planner = DefaultPhysicalPlanner::default(); - let logical_plan = LogicalPlan::Extension(Extension { + let logical_plan = LogicalPlan::extension(Extension { node: Arc::new(NoOpExtensionNode::default()), }); let plan = planner @@ -2260,7 +2313,7 @@ mod tests { BadExtensionPlanner {}, )]); - let logical_plan = LogicalPlan::Extension(Extension { + let logical_plan = LogicalPlan::extension(Extension { node: Arc::new(NoOpExtensionNode::default()), }); let plan = planner @@ -2371,7 +2424,7 @@ mod tests { #[tokio::test] async fn hash_agg_grouping_set_input_schema() -> Result<()> { - let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("c1")], vec![col("c2")], vec![col("c1"), col("c2")], @@ -2446,7 +2499,7 @@ mod tests { #[tokio::test] async fn hash_agg_grouping_set_by_partitioned() -> Result<()> { - let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("c1")], vec![col("c2")], vec![col("c1"), col("c2")], @@ -2702,7 +2755,7 @@ mod tests { let options = CsvReadOptions::new().schema_infer_max_records(100); let logical_plan = match ctx.read_csv(path, options).await?.into_optimized_plan()? 
{ - LogicalPlan::TableScan(ref scan) => { + LogicalPlan::TableScan(ref scan, _) => { let mut scan = scan.clone(); let table_reference = TableReference::from(name); scan.table_name = table_reference; @@ -2712,7 +2765,7 @@ mod tests { .clone() .replace_qualifier(name.to_string()); scan.projected_schema = Arc::new(new_schema); - LogicalPlan::TableScan(scan) + LogicalPlan::table_scan(scan) } _ => unimplemented!(), }; diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index e1bd14105e23e..38773ef11f89f 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -236,11 +236,14 @@ async fn custom_source_dataframe() -> Result<()> { let optimized_plan = state.optimize(&logical_plan)?; match &optimized_plan { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. - }) => { + LogicalPlan::TableScan( + TableScan { + source, + projected_schema, + .. + }, + _, + ) => { assert_eq!(source.schema().fields().len(), 2); assert_eq!(projected_schema.fields().len(), 1); } diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 09f7265d639a7..40df4c3533813 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -172,13 +172,13 @@ impl TableProvider for CustomProvider { let empty = Vec::new(); let projection = projection.unwrap_or(&empty); match &filters[0] { - Expr::BinaryExpr(BinaryExpr { right, .. }) => { + Expr::BinaryExpr(BinaryExpr { right, .. 
}, _) => { let int_value = match &**right { Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64, Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64, Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64, Expr::Literal(ScalarValue::Int64(Some(i))) => *i, - Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { + Expr::Cast(Cast { expr, data_type: _ }, _) => match expr.deref() { Expr::Literal(lit_value) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 1bd90fce839d0..c025c6deb5cc5 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -31,7 +31,6 @@ use datafusion::prelude::*; use datafusion::assert_batches_eq; use datafusion_common::{DFSchema, ScalarValue}; -use datafusion_expr::expr::Alias; use datafusion_expr::ExprSchemable; use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; use datafusion_functions_nested::map::map; @@ -376,11 +375,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_batches_eq!(expected, &batches); // the arg2 parameter is a complex expr, but it can be evaluated to the literal value - let alias_expr = Expr::Alias(Alias::new( - cast(lit(0.5), DataType::Float32), - None::<&str>, - "arg_2".to_string(), - )); + let alias_expr = cast(lit(0.5), DataType::Float32).alias("arg_2".to_string()); let expr = approx_percentile_cont(col("b"), alias_expr, None); let df = create_test_table().await?; let expected = [ diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 439aa6147e9b6..c2f64974de88f 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -180,7 +180,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { let df_results = ctx 
.table("t1") .await? - .select(vec![Expr::WindowFunction(expr::WindowFunction::new( + .select(vec![Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) @@ -589,7 +589,7 @@ async fn select_with_alias_overwrite() -> Result<()> { #[tokio::test] async fn test_grouping_sets() -> Result<()> { - let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("a")], vec![col("b")], vec![col("a"), col("b")], @@ -631,7 +631,7 @@ async fn test_grouping_sets() -> Result<()> { async fn test_grouping_sets_count() -> Result<()> { let ctx = SessionContext::new(); - let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("c1")], vec![col("c2")], ])); @@ -671,7 +671,7 @@ async fn test_grouping_sets_count() -> Result<()> { async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { let ctx = SessionContext::new(); - let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("c1")], vec![col("c2")], vec![col("c1"), col("c2")], diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index 168bf484e5411..97ee32604aea2 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -40,7 +40,7 @@ async fn count_only_nulls() -> Result<()> { vec![Field::new("col", DataType::Null, true)].into(), HashMap::new(), )?); - let input = Arc::new(LogicalPlan::Values(Values { + let input = Arc::new(LogicalPlan::values(Values { schema: input_schema, values: vec![ vec![Expr::Literal(ScalarValue::Null)], @@ -54,10 +54,10 @@ async fn count_only_nulls() -> Result<()> { }); // Aggregation: count(col) AS count - let aggregate 
= LogicalPlan::Aggregate(Aggregate::try_new( + let aggregate = LogicalPlan::aggregate(Aggregate::try_new( input, vec![], - vec![Expr::AggregateFunction(AggregateFunction { + vec![Expr::aggregate_function(AggregateFunction { func: Arc::new(AggregateUDF::new_from_impl(Count::new())), args: vec![input_col_ref], distinct: false, diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 1e6ff8088d0af..ea32b65aa17f4 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -190,7 +190,7 @@ fn make_udf_add(volatility: Volatility) -> Arc { } fn cast_to_int64_expr(expr: Expr) -> Expr { - Expr::Cast(Cast::new(expr.into(), DataType::Int64)) + Expr::cast(Cast::new(expr.into(), DataType::Int64)) } fn to_timestamp_expr(arg: impl Into) -> Expr { @@ -391,7 +391,7 @@ fn test_const_evaluator_scalar_functions() { // rand() + (1 + 2) --> rand() + 3 let fun = math::random(); assert_eq!(fun.signature().volatility, Volatility::Volatile); - let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![])); + let rand = Expr::scalar_function(ScalarFunction::new_udf(fun, vec![])); let expr = rand.clone() + (lit(1) + lit(2)); let expected = rand + lit(3); test_evaluate(expr, expected); @@ -399,7 +399,7 @@ fn test_const_evaluator_scalar_functions() { // parenthesization matters: can't rewrite // (rand() + 1) + 2 --> (rand() + 1) + 2) let fun = math::random(); - let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![])); + let rand = Expr::scalar_function(ScalarFunction::new_udf(fun, vec![])); let expr = (rand + lit(1)) + lit(2); test_evaluate(expr.clone(), expr); } @@ -429,7 +429,7 @@ fn test_evaluator_udfs() { // immutable UDF should get folded // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarFunction(ScalarFunction::new_udf( + let expr = Expr::scalar_function(ScalarFunction::new_udf( make_udf_add(Volatility::Immutable), args.clone(), )); @@ -439,15 
+439,15 @@ fn test_evaluator_udfs() { // udf_add(1+2, 30+40) --> 73 let fun = make_udf_add(Volatility::Stable); let expr = - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args.clone())); + Expr::scalar_function(ScalarFunction::new_udf(Arc::clone(&fun), args.clone())); test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) let fun = make_udf_add(Volatility::Volatile); - let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args)); + let expr = Expr::scalar_function(ScalarFunction::new_udf(Arc::clone(&fun), args)); let expected_expr = - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), folded_args)); + Expr::scalar_function(ScalarFunction::new_udf(Arc::clone(&fun), folded_args)); test_evaluate(expr, expected_expr); } diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index f17d13a420607..8fd5b74876dd9 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -301,7 +301,7 @@ fn test_inequalities_non_null_bounded() { (col("x").not_between(lit(0), lit(5)), false), (col("x").not_between(lit(5), lit(10)), true), ( - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit(ScalarValue::Null)), @@ -309,7 +309,7 @@ fn test_inequalities_non_null_bounded() { true, ), ( - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit(5)), diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 177427b47d218..32705163bcb2d 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -326,11 +326,14 @@ async fn nyc() -> Result<()> { let optimized_plan = dataframe.into_optimized_plan().unwrap(); match &optimized_plan { - LogicalPlan::Aggregate(Aggregate { input, .. 
}) => match input.as_ref() { - LogicalPlan::TableScan(TableScan { - ref projected_schema, - .. - }) => { + LogicalPlan::Aggregate(Aggregate { input, .. }, _) => match input.as_ref() { + LogicalPlan::TableScan( + TableScan { + ref projected_schema, + .. + }, + _, + ) => { assert_eq!(2, projected_schema.fields().len()); assert_eq!(projected_schema.field(0).name(), "passenger_count"); assert_eq!(projected_schema.field(1).name(), "fare_amount"); diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs index ad9c1280d6b11..3704a7a93a4ae 100644 --- a/datafusion/core/tests/user_defined/expr_planner.rs +++ b/datafusion/core/tests/user_defined/expr_planner.rs @@ -25,7 +25,6 @@ use datafusion::logical_expr::Operator; use datafusion::prelude::*; use datafusion::sql::sqlparser::ast::BinaryOperator; use datafusion_common::ScalarValue; -use datafusion_expr::expr::Alias; use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr}; use datafusion_expr::BinaryExpr; @@ -40,26 +39,23 @@ impl ExprPlanner for MyCustomPlanner { ) -> Result> { match &expr.op { BinaryOperator::Arrow => { - Ok(PlannerResult::Planned(Expr::BinaryExpr(BinaryExpr { + Ok(PlannerResult::Planned(Expr::binary_expr(BinaryExpr { left: Box::new(expr.left.clone()), right: Box::new(expr.right.clone()), op: Operator::StringConcat, }))) } BinaryOperator::LongArrow => { - Ok(PlannerResult::Planned(Expr::BinaryExpr(BinaryExpr { + Ok(PlannerResult::Planned(Expr::binary_expr(BinaryExpr { left: Box::new(expr.left.clone()), right: Box::new(expr.right.clone()), op: Operator::Plus, }))) } - BinaryOperator::Question => { - Ok(PlannerResult::Planned(Expr::Alias(Alias::new( - Expr::Literal(ScalarValue::Boolean(Some(true))), - None::<&str>, - format!("{} ? {}", expr.left, expr.right), - )))) - } + BinaryOperator::Question => Ok(PlannerResult::Planned( + Expr::Literal(ScalarValue::Boolean(Some(true))) + .alias(format!("{} ? 
{}", expr.left, expr.right)), + )), _ => Ok(PlannerResult::Original(expr)), } } diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 520a91aeb4d6f..e1af4f2ccc422 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -360,22 +360,25 @@ impl OptimizerRule for TopKOptimizerRule { // Note: this code simply looks for the pattern of a Limit followed by a // Sort and replaces it by a TopK node. It does not handle many // edge cases (e.g multiple sort columns, sort ASC / DESC), etc. - let LogicalPlan::Limit(ref limit) = plan else { + let LogicalPlan::Limit(ref limit, _) = plan else { return Ok(Transformed::no(plan)); }; let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else { return Ok(Transformed::no(plan)); }; - if let LogicalPlan::Sort(Sort { - ref expr, - ref input, - .. - }) = limit.input.as_ref() + if let LogicalPlan::Sort( + Sort { + ref expr, + ref input, + .. 
+ }, + _, + ) = limit.input.as_ref() { if expr.len() == 1 { // we found a sort with a single sort expr, replace with a a TopK - return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + return Ok(Transformed::yes(LogicalPlan::extension(Extension { node: Arc::new(TopKPlanNode { k: fetch, input: input.as_ref().clone(), @@ -705,9 +708,9 @@ impl MyAnalyzerRule { fn analyze_plan(plan: LogicalPlan) -> Result { plan.transform(|plan| { Ok(match plan { - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { let expr = Self::analyze_expr(projection.expr.clone())?; - Transformed::yes(LogicalPlan::Projection(Projection::try_new( + Transformed::yes(LogicalPlan::projection(Projection::try_new( expr, projection.input, )?)) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index cf403e5d640f1..99a991bffe4fd 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -711,7 +711,7 @@ impl ScalarUDFImpl for CastToI64UDF { arg } else { // need to use an actual cast to get the correct type - Expr::Cast(datafusion_expr::Cast { + Expr::cast(datafusion_expr::Cast { expr: Box::new(arg), data_type: DataType::Int64, }) diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index 0cc156866d4d1..74b035fcbb4bc 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -176,7 +176,7 @@ impl SimpleCsvTable { )?], Arc::new(plan), ) - .map(LogicalPlan::Projection)?; + .map(LogicalPlan::projection)?; let rbs = collect( state.create_physical_plan(&logical_plan).await?, Arc::new(TaskContext::from(state)), diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 
19cd5ed3158b4..0d0a7c95d76aa 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -48,6 +48,7 @@ datafusion-expr-common = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } +enumset = { workspace = true } indexmap = { workspace = true } paste = "^1.0" recursive = { workspace = true } diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 9cb51612d0cab..e4cbfa5a03962 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -88,7 +88,7 @@ impl CaseBuilder { } } - Ok(Expr::Case(Case::new( + Ok(Expr::case(Case::new( self.expr.clone(), self.when_expr .iter() diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0e1c6c72c5111..717582d144e09 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -29,6 +29,7 @@ use crate::utils::expr_to_columns; use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; +use crate::logical_plan::tree_node::LogicalPlanStats; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{ @@ -221,10 +222,11 @@ use sqlparser::ast::{ /// assert!(rewritten.transformed); /// // to 42 = 5 AND b = 6 /// assert_eq!(rewritten.data, lit(42).eq(lit(5)).and(col("b").eq(lit(6)))); +#[allow(clippy::large_enum_variant)] #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub enum Expr { /// An expression with a specific name. - Alias(Alias), + Alias(Alias, LogicalPlanStats), /// A named reference to a qualified field in a schema. Column(Column), /// A named reference to a variable in a registry. @@ -232,33 +234,33 @@ pub enum Expr { /// A constant value. 
Literal(ScalarValue), /// A binary expression such as "age > 21" - BinaryExpr(BinaryExpr), + BinaryExpr(BinaryExpr, LogicalPlanStats), /// LIKE expression - Like(Like), + Like(Like, LogicalPlanStats), /// LIKE expression that uses regular expressions - SimilarTo(Like), + SimilarTo(Like, LogicalPlanStats), /// Negation of an expression. The expression's type must be a boolean to make sense. - Not(Box), + Not(Box, LogicalPlanStats), /// True if argument is not NULL, false otherwise. This expression itself is never NULL. - IsNotNull(Box), + IsNotNull(Box, LogicalPlanStats), /// True if argument is NULL, false otherwise. This expression itself is never NULL. - IsNull(Box), + IsNull(Box, LogicalPlanStats), /// True if argument is true, false otherwise. This expression itself is never NULL. - IsTrue(Box), + IsTrue(Box, LogicalPlanStats), /// True if argument is false, false otherwise. This expression itself is never NULL. - IsFalse(Box), + IsFalse(Box, LogicalPlanStats), /// True if argument is NULL, false otherwise. This expression itself is never NULL. - IsUnknown(Box), + IsUnknown(Box, LogicalPlanStats), /// True if argument is FALSE or NULL, false otherwise. This expression itself is never NULL. - IsNotTrue(Box), + IsNotTrue(Box, LogicalPlanStats), /// True if argument is TRUE OR NULL, false otherwise. This expression itself is never NULL. - IsNotFalse(Box), + IsNotFalse(Box, LogicalPlanStats), /// True if argument is TRUE or FALSE, false otherwise. This expression itself is never NULL. - IsNotUnknown(Box), + IsNotUnknown(Box, LogicalPlanStats), /// arithmetic negation of an expression, the operand must be of a signed numeric data type - Negative(Box), + Negative(Box, LogicalPlanStats), /// Whether an expression is between a given range. - Between(Between), + Between(Between, LogicalPlanStats), /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. 
The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. @@ -280,41 +282,41 @@ pub enum Expr { /// [ELSE result] /// END /// ``` - Case(Case), + Case(Case, LogicalPlanStats), /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. - Cast(Cast), + Cast(Cast, LogicalPlanStats), /// Casts the expression to a given type and will return a null value if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. - TryCast(TryCast), + TryCast(TryCast, LogicalPlanStats), /// Represents the call of a scalar function with a set of arguments. - ScalarFunction(ScalarFunction), + ScalarFunction(ScalarFunction, LogicalPlanStats), /// Calls an aggregate function with arguments, and optional /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. /// /// See also [`ExprFunctionExt`] to set these fields. /// /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt - AggregateFunction(AggregateFunction), + AggregateFunction(AggregateFunction, LogicalPlanStats), /// Represents the call of a window function with arguments. - WindowFunction(WindowFunction), + WindowFunction(WindowFunction, LogicalPlanStats), /// Returns whether the list contains the expr value. - InList(InList), + InList(InList, LogicalPlanStats), /// EXISTS subquery - Exists(Exists), + Exists(Exists, LogicalPlanStats), /// IN subquery - InSubquery(InSubquery), + InSubquery(InSubquery, LogicalPlanStats), /// Scalar subquery - ScalarSubquery(Subquery), + ScalarSubquery(Subquery, LogicalPlanStats), /// Represents a reference to all available fields in a specific schema, /// with an optional (schema) qualifier. /// /// This expr has to be resolved to a list of columns before translating logical /// plan into physical plan. 
- Wildcard(Wildcard), + Wildcard(Wildcard, LogicalPlanStats), /// List of grouping set expressions. Only valid in the context of an aggregate /// GROUP BY expression list - GroupingSet(GroupingSet), + GroupingSet(GroupingSet, LogicalPlanStats), /// A place holder for parameters in a prepared statement /// (e.g. `$foo` or `$1`) Placeholder(Placeholder), @@ -322,7 +324,7 @@ pub enum Expr { /// in the outer query, used for correlated sub queries. OuterReferenceColumn(DataType, Column), /// Unnest expression - Unnest(Unnest), + Unnest(Unnest, LogicalPlanStats), } impl Default for Expr { @@ -371,6 +373,12 @@ pub struct Wildcard { pub options: WildcardOptions, } +impl Wildcard { + fn stats(&self) -> LogicalPlanStats { + self.options.stats() + } +} + /// UNNEST expression. #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Unnest { @@ -389,6 +397,10 @@ impl Unnest { pub fn new_boxed(boxed: Box) -> Self { Self { expr: boxed } } + + fn stats(&self) -> LogicalPlanStats { + self.expr.stats() + } } /// Alias expression @@ -412,6 +424,10 @@ impl Alias { name: name.into(), } } + + fn stats(&self) -> LogicalPlanStats { + self.expr.stats() + } } /// Binary expression @@ -430,6 +446,10 @@ impl BinaryExpr { pub fn new(left: Box, op: Operator, right: Box) -> Self { Self { left, op, right } } + + fn stats(&self) -> LogicalPlanStats { + self.left.stats().merge(self.right.stats()) + } } impl Display for BinaryExpr { @@ -445,7 +465,7 @@ impl Display for BinaryExpr { precedence: u8, ) -> fmt::Result { match expr { - Expr::BinaryExpr(child) => { + Expr::BinaryExpr(child, _) => { let p = child.op.precedence(); if p == 0 || p < precedence { write!(f, "({child})")?; @@ -489,6 +509,18 @@ impl Case { else_expr, } } + + fn stats(&self) -> LogicalPlanStats { + self.expr + .iter() + .chain( + self.when_then_expr + .iter() + .flat_map(|(w, t)| vec![w, t]) + .chain(self.else_expr.iter()), + ) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } } /// LIKE expression @@ 
-519,6 +551,10 @@ impl Like { case_insensitive, } } + + fn stats(&self) -> LogicalPlanStats { + self.expr.stats().merge(self.pattern.stats()) + } } /// BETWEEN expression @@ -544,6 +580,13 @@ impl Between { high, } } + + fn stats(&self) -> LogicalPlanStats { + self.expr + .stats() + .merge(self.low.stats()) + .merge(self.high.stats()) + } } /// ScalarFunction expression invokes a built-in scalar function @@ -560,13 +603,17 @@ impl ScalarFunction { pub fn name(&self) -> &str { self.func.name() } -} -impl ScalarFunction { /// Create a new ScalarFunction expression with a user-defined function (UDF) pub fn new_udf(udf: Arc, args: Vec) -> Self { Self { func: udf, args } } + + fn stats(&self) -> LogicalPlanStats { + self.args + .iter() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } } /// Access a sub field of a nested type, such as `Field` or `List` @@ -598,6 +645,10 @@ impl Cast { pub fn new(expr: Box, data_type: DataType) -> Self { Self { expr, data_type } } + + fn stats(&self) -> LogicalPlanStats { + self.expr.stats() + } } /// TryCast Expression @@ -614,6 +665,10 @@ impl TryCast { pub fn new(expr: Box, data_type: DataType) -> Self { Self { expr, data_type } } + + fn stats(&self) -> LogicalPlanStats { + self.expr.stats() + } } /// SORT expression @@ -730,6 +785,14 @@ impl AggregateFunction { null_treatment, } } + + fn stats(&self) -> LogicalPlanStats { + self.args + .iter() + .chain(self.filter.iter().map(|e| e.as_ref())) + .chain(self.order_by.iter().flatten().map(|s| &s.expr)) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } } /// A function used as a SQL window function @@ -842,6 +905,14 @@ impl WindowFunction { null_treatment: None, } } + + fn stats(&self) -> LogicalPlanStats { + self.args + .iter() + .chain(self.partition_by.iter()) + .chain(self.order_by.iter().map(|s| &s.expr)) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } } /// EXISTS expression @@ -858,6 +929,10 @@ impl Exists { pub fn new(subquery: 
Subquery, negated: bool) -> Self { Self { subquery, negated } } + + fn stats(&self) -> LogicalPlanStats { + self.subquery.stats() + } } /// User Defined Aggregate Function @@ -912,6 +987,12 @@ impl InList { negated, } } + + fn stats(&self) -> LogicalPlanStats { + self.list + .iter() + .fold(self.expr.stats(), |s, e| s.merge(e.stats())) + } } /// IN subquery @@ -934,6 +1015,10 @@ impl InSubquery { negated, } } + + fn stats(&self) -> LogicalPlanStats { + self.expr.stats().merge(self.subquery.stats()) + } } /// Placeholder, representing bind parameter values such as `$1` or `$name`. @@ -991,6 +1076,18 @@ impl GroupingSet { } } } + + fn stats(&self) -> LogicalPlanStats { + match self { + GroupingSet::Rollup(exprs) | GroupingSet::Cube(exprs) => exprs + .iter() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + GroupingSet::GroupingSets(groups) => groups + .iter() + .flatten() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + } + } } /// Additional options for wildcards, e.g. Snowflake `EXCLUDE`/`RENAME` and Bigquery `EXCEPT`. @@ -1026,6 +1123,13 @@ impl WildcardOptions { rename: self.rename, } } + + fn stats(&self) -> LogicalPlanStats { + self.replace + .iter() + .flat_map(|prsi| prsi.planned_expressions.iter()) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } } impl Display for WildcardOptions { @@ -1116,7 +1220,9 @@ impl Expr { pub fn qualified_name(&self) -> (Option, String) { match self { Expr::Column(Column { relation, name }) => (relation.clone(), name.clone()), - Expr::Alias(Alias { relation, name, .. }) => (relation.clone(), name.clone()), + Expr::Alias(Alias { relation, name, .. }, _) => { + (relation.clone(), name.clone()) + } _ => (None, self.schema_name().to_string()), } } @@ -1156,7 +1262,7 @@ impl Expr { Expr::Literal(..) => "Literal", Expr::Negative(..) => "Negative", Expr::Not(..) => "Not", - Expr::Placeholder(_) => "Placeholder", + Expr::Placeholder { .. } => "Placeholder", Expr::ScalarFunction(..) 
=> "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", Expr::ScalarVariable(..) => "ScalarVariable", @@ -1209,40 +1315,28 @@ impl Expr { /// Return `self LIKE other` pub fn like(self, other: Expr) -> Expr { - Expr::Like(Like::new( - false, - Box::new(self), - Box::new(other), - None, - false, - )) + let like = Like::new(false, Box::new(self), Box::new(other), None, false); + let stats = like.stats(); + Expr::Like(like, stats) } /// Return `self NOT LIKE other` pub fn not_like(self, other: Expr) -> Expr { - Expr::Like(Like::new( - true, - Box::new(self), - Box::new(other), - None, - false, - )) + let like = Like::new(true, Box::new(self), Box::new(other), None, false); + let stats = like.stats(); + Expr::Like(like, stats) } /// Return `self ILIKE other` pub fn ilike(self, other: Expr) -> Expr { - Expr::Like(Like::new( - false, - Box::new(self), - Box::new(other), - None, - true, - )) + let like = Like::new(false, Box::new(self), Box::new(other), None, true); + let stats = like.stats(); + Expr::Like(like, stats) } /// Return `self NOT ILIKE other` pub fn not_ilike(self, other: Expr) -> Expr { - Expr::Like(Like::new(true, Box::new(self), Box::new(other), None, true)) + Expr::_like(Like::new(true, Box::new(self), Box::new(other), None, true)) } /// Return the name to use for the specific Expr @@ -1263,7 +1357,9 @@ impl Expr { /// Return `self AS name` alias expression pub fn alias(self, name: impl Into) -> Expr { - Expr::Alias(Alias::new(self, None::<&str>, name.into())) + let alias = Alias::new(self, None::<&str>, name.into()); + let stats = alias.stats(); + Expr::Alias(alias, stats) } /// Return `self AS name` alias expression with a specific qualifier @@ -1272,7 +1368,9 @@ impl Expr { relation: Option>, name: impl Into, ) -> Expr { - Expr::Alias(Alias::new(self, relation, name.into())) + let alias = Alias::new(self, relation, name.into()); + let stats = alias.stats(); + Expr::Alias(alias, stats) } /// Remove an alias from an expression if one exists. 
@@ -1297,7 +1395,7 @@ impl Expr { /// ``` pub fn unalias(self) -> Expr { match self { - Expr::Alias(alias) => *alias.expr, + Expr::Alias(alias, _) => *alias.expr, _ => self, } } @@ -1328,7 +1426,9 @@ impl Expr { // f_down: skip subqueries. Check in f_down to avoid recursing into them let recursion = if matches!( expr, - Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) + Expr::Exists { .. } + | Expr::ScalarSubquery { .. } + | Expr::InSubquery { .. } ) { // Subqueries could contain aliases so don't recurse into those TreeNodeRecursion::Jump @@ -1340,7 +1440,7 @@ impl Expr { |expr| { // f_up: unalias on up so we can remove nested aliases like // `(x as foo) as bar` - if let Expr::Alias(Alias { expr, .. }) = expr { + if let Expr::Alias(Alias { expr, .. }, _) = expr { Ok(Transformed::yes(*expr)) } else { Ok(Transformed::no(expr)) @@ -1354,17 +1454,17 @@ impl Expr { /// Return `self IN ` if `negated` is false, otherwise /// return `self NOT IN `.a pub fn in_list(self, list: Vec, negated: bool) -> Expr { - Expr::InList(InList::new(Box::new(self), list, negated)) + Expr::_in_list(InList::new(Box::new(self), list, negated)) } /// Return `IsNull(Box(self)) pub fn is_null(self) -> Expr { - Expr::IsNull(Box::new(self)) + Expr::_is_null(Box::new(self)) } /// Return `IsNotNull(Box(self)) pub fn is_not_null(self) -> Expr { - Expr::IsNotNull(Box::new(self)) + Expr::_is_not_null(Box::new(self)) } /// Create a sort configuration from an existing expression. 
@@ -1379,37 +1479,37 @@ impl Expr { /// Return `IsTrue(Box(self))` pub fn is_true(self) -> Expr { - Expr::IsTrue(Box::new(self)) + Expr::_is_true(Box::new(self)) } /// Return `IsNotTrue(Box(self))` pub fn is_not_true(self) -> Expr { - Expr::IsNotTrue(Box::new(self)) + Expr::_is_not_true(Box::new(self)) } /// Return `IsFalse(Box(self))` pub fn is_false(self) -> Expr { - Expr::IsFalse(Box::new(self)) + Expr::_is_false(Box::new(self)) } /// Return `IsNotFalse(Box(self))` pub fn is_not_false(self) -> Expr { - Expr::IsNotFalse(Box::new(self)) + Expr::_is_not_false(Box::new(self)) } /// Return `IsUnknown(Box(self))` pub fn is_unknown(self) -> Expr { - Expr::IsUnknown(Box::new(self)) + Expr::_is_unknown(Box::new(self)) } /// Return `IsNotUnknown(Box(self))` pub fn is_not_unknown(self) -> Expr { - Expr::IsNotUnknown(Box::new(self)) + Expr::_is_not_unknown(Box::new(self)) } /// return `self BETWEEN low AND high` pub fn between(self, low: Expr, high: Expr) -> Expr { - Expr::Between(Between::new( + Expr::_between(Between::new( Box::new(self), false, Box::new(low), @@ -1419,7 +1519,7 @@ impl Expr { /// Return `self NOT BETWEEN low AND high` pub fn not_between(self, low: Expr, high: Expr) -> Expr { - Expr::Between(Between::new( + Expr::_between(Between::new( Box::new(self), true, Box::new(low), @@ -1469,7 +1569,7 @@ impl Expr { pub fn get_as_join_column(&self) -> Option<&Column> { match self { Expr::Column(c) => Some(c), - Expr::Cast(Cast { expr, .. }) => match &**expr { + Expr::Cast(Cast { expr, .. }, _) => match &**expr { Expr::Column(c) => Some(c), _ => None, }, @@ -1573,7 +1673,7 @@ impl Expr { /// - `rand()` returns `true`, /// - `a + rand()` returns `false` pub fn is_volatile_node(&self) -> bool { - matches!(self, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile) + matches!(self, Expr::ScalarFunction(func, _) if func.func.signature().volatility == Volatility::Volatile) } /// Returns true if the expression is volatile, i.e. 
whether it can return different @@ -1600,16 +1700,19 @@ impl Expr { let mut has_placeholder = false; self.transform(|mut expr| { // Default to assuming the arguments are the same type - if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { + if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }, _) = &mut expr { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; }; - if let Expr::Between(Between { - expr, - negated: _, - low, - high, - }) = &mut expr + if let Expr::Between( + Between { + expr, + negated: _, + low, + high, + }, + _, + ) = &mut expr { rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; @@ -1627,8 +1730,8 @@ impl Expr { /// and thus any side effects (like divide by zero) may not be encountered pub fn short_circuits(&self) -> bool { match self { - Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(), - Expr::BinaryExpr(BinaryExpr { op, .. }) => { + Expr::ScalarFunction(ScalarFunction { func, .. }, _) => func.short_circuits(), + Expr::BinaryExpr(BinaryExpr { op, .. }, _) => { matches!(op, Operator::And | Operator::Or) } Expr::Case { .. } => true, @@ -1667,6 +1770,141 @@ impl Expr { | Expr::Placeholder(..) 
=> false, } } + + pub fn wildcard(wildcard: Wildcard) -> Self { + let stats = wildcard.stats(); + Expr::Wildcard(wildcard, stats) + } + + pub fn binary_expr(binary_expr: BinaryExpr) -> Self { + let stats = binary_expr.stats(); + Expr::BinaryExpr(binary_expr, stats) + } + + pub fn similar_to(like: Like) -> Self { + let stats = like.stats(); + Expr::SimilarTo(like, stats) + } + + pub fn _like(like: Like) -> Self { + let stats = like.stats(); + Expr::Like(like, stats) + } + + pub fn unnest(unnest: Unnest) -> Self { + let stats = unnest.stats(); + Expr::Unnest(unnest, stats) + } + + pub fn in_subquery(in_subquery: InSubquery) -> Self { + let stats = in_subquery.stats(); + Expr::InSubquery(in_subquery, stats) + } + + pub fn scalar_subquery(subquery: Subquery) -> Self { + let stats = subquery.stats(); + Expr::ScalarSubquery(subquery, stats) + } + + pub fn _not(expr: Box) -> Self { + let stats = expr.stats(); + Expr::Not(expr, stats) + } + + pub fn _is_not_null(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsNotNull(expr, stats) + } + + pub fn _is_null(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsNull(expr, stats) + } + + pub fn _is_true(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsTrue(expr, stats) + } + + pub fn _is_false(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsFalse(expr, stats) + } + + pub fn _is_unknown(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsUnknown(expr, stats) + } + + pub fn _is_not_true(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsNotTrue(expr, stats) + } + + pub fn _is_not_false(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsNotFalse(expr, stats) + } + + pub fn _is_not_unknown(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsNotUnknown(expr, stats) + } + + pub fn negative(expr: Box) -> Self { + let stats = expr.stats(); + Expr::Negative(expr, stats) + } + + pub fn _between(between: Between) -> Self { + let stats = between.stats(); + 
Expr::Between(between, stats) + } + + pub fn case(case: Case) -> Self { + let stats = case.stats(); + Expr::Case(case, stats) + } + + pub fn cast(cast: Cast) -> Self { + let stats = cast.stats(); + Expr::Cast(cast, stats) + } + + pub fn try_cast(try_cast: TryCast) -> Self { + let stats = try_cast.stats(); + Expr::TryCast(try_cast, stats) + } + + pub fn scalar_function(scalar_function: ScalarFunction) -> Self { + let stats = scalar_function.stats(); + Expr::ScalarFunction(scalar_function, stats) + } + + pub fn window_function(window_function: WindowFunction) -> Self { + let stats = window_function.stats(); + Expr::WindowFunction(window_function, stats) + } + + pub fn aggregate_function(aggregate_function: AggregateFunction) -> Self { + let stats = aggregate_function.stats(); + Expr::AggregateFunction(aggregate_function, stats) + } + + pub fn grouping_set(grouping_set: GroupingSet) -> Self { + let stats = grouping_set.stats(); + Expr::GroupingSet(grouping_set, stats) + } + + pub fn _in_list(in_list: InList) -> Self { + let stats = in_list.stats(); + Expr::InList(in_list, stats) + } + + pub fn exists(exists: Exists) -> Self { + let stats = exists.stats(); + Expr::Exists(exists, stats) + } } impl HashNode for Expr { @@ -1676,11 +1914,14 @@ impl HashNode for Expr { fn hash_node(&self, state: &mut H) { mem::discriminant(self).hash(state); match self { - Expr::Alias(Alias { - expr: _expr, - relation, - name, - }) => { + Expr::Alias( + Alias { + expr: _, + relation, + name, + }, + _, + ) => { relation.hash(state); name.hash(state); } @@ -1694,122 +1935,143 @@ impl HashNode for Expr { Expr::Literal(scalar_value) => { scalar_value.hash(state); } - Expr::BinaryExpr(BinaryExpr { - left: _left, - op, - right: _right, - }) => { + Expr::BinaryExpr( + BinaryExpr { + left: _, + op, + right: _, + }, + _, + ) => { op.hash(state); } - Expr::Like(Like { - negated, - expr: _expr, - pattern: _pattern, - escape_char, - case_insensitive, - }) - | Expr::SimilarTo(Like { - negated, - expr: 
_expr, - pattern: _pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr: _, + pattern: _, + escape_char, + case_insensitive, + }, + _, + ) + | Expr::SimilarTo( + Like { + negated, + expr: _, + pattern: _, + escape_char, + case_insensitive, + }, + _, + ) => { negated.hash(state); escape_char.hash(state); case_insensitive.hash(state); } - Expr::Not(_expr) - | Expr::IsNotNull(_expr) - | Expr::IsNull(_expr) - | Expr::IsTrue(_expr) - | Expr::IsFalse(_expr) - | Expr::IsUnknown(_expr) - | Expr::IsNotTrue(_expr) - | Expr::IsNotFalse(_expr) - | Expr::IsNotUnknown(_expr) - | Expr::Negative(_expr) => {} - Expr::Between(Between { - expr: _expr, - negated, - low: _low, - high: _high, - }) => { + Expr::Not(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) + | Expr::Negative(_, _) => {} + Expr::Between( + Between { + expr: _, + negated, + low: _, + high: _, + }, + _, + ) => { negated.hash(state); } - Expr::Case(Case { - expr: _expr, - when_then_expr: _when_then_expr, - else_expr: _else_expr, - }) => {} - Expr::Cast(Cast { - expr: _expr, - data_type, - }) - | Expr::TryCast(TryCast { - expr: _expr, - data_type, - }) => { + Expr::Case( + Case { + expr: _, + when_then_expr: _, + else_expr: _, + }, + _, + ) => {} + Expr::Cast(Cast { expr: _, data_type }, _) + | Expr::TryCast(TryCast { expr: _, data_type }, _) => { data_type.hash(state); } - Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { + Expr::ScalarFunction(ScalarFunction { func, args: _ }, _) => { func.hash(state); } - Expr::AggregateFunction(AggregateFunction { - func, - args: _args, - distinct, - filter: _filter, - order_by: _order_by, - null_treatment, - }) => { + Expr::AggregateFunction( + AggregateFunction { + func, + args: _, + distinct, + filter: _, + order_by: _, + null_treatment, + }, + _, + ) => { 
func.hash(state); distinct.hash(state); null_treatment.hash(state); } - Expr::WindowFunction(WindowFunction { - fun, - args: _args, - partition_by: _partition_by, - order_by: _order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args: _, + partition_by: _, + order_by: _, + window_frame, + null_treatment, + }, + _, + ) => { fun.hash(state); window_frame.hash(state); null_treatment.hash(state); } - Expr::InList(InList { - expr: _expr, - list: _list, - negated, - }) => { + Expr::InList( + InList { + expr: _, + list: _, + negated, + }, + _, + ) => { negated.hash(state); } - Expr::Exists(Exists { subquery, negated }) => { + Expr::Exists(Exists { subquery, negated }, _) => { subquery.hash(state); negated.hash(state); } - Expr::InSubquery(InSubquery { - expr: _expr, - subquery, - negated, - }) => { + Expr::InSubquery( + InSubquery { + expr: _, + subquery, + negated, + }, + _, + ) => { subquery.hash(state); negated.hash(state); } - Expr::ScalarSubquery(subquery) => { + Expr::ScalarSubquery(subquery, _) => { subquery.hash(state); } - Expr::Wildcard(wildcard) => { + Expr::Wildcard(wildcard, _) => { wildcard.hash(state); wildcard.hash(state); } - Expr::GroupingSet(grouping_set) => { + Expr::GroupingSet(grouping_set, _) => { mem::discriminant(grouping_set).hash(state); match grouping_set { - GroupingSet::Rollup(_exprs) | GroupingSet::Cube(_exprs) => {} - GroupingSet::GroupingSets(_exprs) => {} + GroupingSet::Rollup(_) | GroupingSet::Cube(_) => {} + GroupingSet::GroupingSets(_) => {} } } Expr::Placeholder(place_holder) => { @@ -1819,7 +2081,7 @@ impl HashNode for Expr { data_type.hash(state); column.hash(state); } - Expr::Unnest(Unnest { expr: _expr }) => {} + Expr::Unnest(Unnest { expr: _ }, _) => {} }; } } @@ -1867,14 +2129,17 @@ impl<'a> Display for SchemaDisplay<'a> { | Expr::Placeholder(_) | Expr::Wildcard { .. 
} => write!(f, "{}", self.0), - Expr::AggregateFunction(AggregateFunction { - func, - args, - distinct, - filter, - order_by, - null_treatment, - }) => { + Expr::AggregateFunction( + AggregateFunction { + func, + args, + distinct, + filter, + order_by, + null_treatment, + }, + _, + ) => { write!( f, "{}({}{})", @@ -1898,13 +2163,16 @@ impl<'a> Display for SchemaDisplay<'a> { Ok(()) } // Expr is not shown since it is aliased - Expr::Alias(Alias { name, .. }) => write!(f, "{name}"), - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Alias(Alias { name, .. }, _) => write!(f, "{name}"), + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { if *negated { write!( f, @@ -1923,14 +2191,17 @@ impl<'a> Display for SchemaDisplay<'a> { ) } } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { write!(f, "{} {op} {}", SchemaDisplay(left), SchemaDisplay(right),) } - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => { + Expr::Case( + Case { + expr, + when_then_expr, + else_expr, + }, + _, + ) => { write!(f, "CASE ")?; if let Some(e) = expr { @@ -1953,14 +2224,17 @@ impl<'a> Display for SchemaDisplay<'a> { write!(f, "END") } // Cast expr is not shown to be consistant with Postgres and Spark - Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) => { + Expr::Cast(Cast { expr, .. }, _) | Expr::TryCast(TryCast { expr, .. }, _) => { write!(f, "{}", SchemaDisplay(expr)) } - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { let inlist_name = schema_name_from_exprs(list)?; if *negated { @@ -1969,50 +2243,53 @@ impl<'a> Display for SchemaDisplay<'a> { write!(f, "{} IN {}", SchemaDisplay(expr), inlist_name) } } - Expr::Exists(Exists { negated: true, .. }) => write!(f, "NOT EXISTS"), - Expr::Exists(Exists { negated: false, .. 
}) => write!(f, "EXISTS"), - Expr::GroupingSet(GroupingSet::Cube(exprs)) => { + Expr::Exists(Exists { negated: true, .. }, _) => write!(f, "NOT EXISTS"), + Expr::Exists(Exists { negated: false, .. }, _) => write!(f, "EXISTS"), + Expr::GroupingSet(GroupingSet::Cube(exprs), _) => { write!(f, "ROLLUP ({})", schema_name_from_exprs(exprs)?) } - Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs), _) => { write!(f, "GROUPING SETS (")?; for exprs in lists_of_exprs.iter() { write!(f, "({})", schema_name_from_exprs(exprs)?)?; } write!(f, ")") } - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) => { write!(f, "ROLLUP ({})", schema_name_from_exprs(exprs)?) } - Expr::IsNull(expr) => write!(f, "{} IS NULL", SchemaDisplay(expr)), - Expr::IsNotNull(expr) => { + Expr::IsNull(expr, _) => write!(f, "{} IS NULL", SchemaDisplay(expr)), + Expr::IsNotNull(expr, _) => { write!(f, "{} IS NOT NULL", SchemaDisplay(expr)) } - Expr::IsUnknown(expr) => { + Expr::IsUnknown(expr, _) => { write!(f, "{} IS UNKNOWN", SchemaDisplay(expr)) } - Expr::IsNotUnknown(expr) => { + Expr::IsNotUnknown(expr, _) => { write!(f, "{} IS NOT UNKNOWN", SchemaDisplay(expr)) } - Expr::InSubquery(InSubquery { negated: true, .. }) => { + Expr::InSubquery(InSubquery { negated: true, .. }, _) => { write!(f, "NOT IN") } - Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"), - Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)), - Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)), - Expr::IsNotTrue(expr) => { + Expr::InSubquery(InSubquery { negated: false, .. 
}, _) => write!(f, "IN"), + Expr::IsTrue(expr, _) => write!(f, "{} IS TRUE", SchemaDisplay(expr)), + Expr::IsFalse(expr, _) => write!(f, "{} IS FALSE", SchemaDisplay(expr)), + Expr::IsNotTrue(expr, _) => { write!(f, "{} IS NOT TRUE", SchemaDisplay(expr)) } - Expr::IsNotFalse(expr) => { + Expr::IsNotFalse(expr, _) => { write!(f, "{} IS NOT FALSE", SchemaDisplay(expr)) } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { write!( f, "{} {}{} {}", @@ -2028,12 +2305,12 @@ impl<'a> Display for SchemaDisplay<'a> { Ok(()) } - Expr::Negative(expr) => write!(f, "(- {})", SchemaDisplay(expr)), - Expr::Not(expr) => write!(f, "NOT {}", SchemaDisplay(expr)), - Expr::Unnest(Unnest { expr }) => { + Expr::Negative(expr, _) => write!(f, "(- {})", SchemaDisplay(expr)), + Expr::Not(expr, _) => write!(f, "NOT {}", SchemaDisplay(expr)), + Expr::Unnest(Unnest { expr }, _) => { write!(f, "UNNEST({})", SchemaDisplay(expr)) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { match func.schema_name(args) { Ok(name) => { write!(f, "{name}") @@ -2043,16 +2320,19 @@ impl<'a> Display for SchemaDisplay<'a> { } } } - Expr::ScalarSubquery(Subquery { subquery, .. }) => { + Expr::ScalarSubquery(Subquery { subquery, .. }, _) => { write!(f, "{}", subquery.schema().field(0).name()) } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - .. - }) => { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + .. 
+ }, + _, + ) => { write!( f, "{} {} {}", @@ -2070,14 +2350,17 @@ impl<'a> Display for SchemaDisplay<'a> { Ok(()) } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + _, + ) => { write!( f, "{}({})", @@ -2158,12 +2441,12 @@ pub fn schema_name_from_sorts(sorts: &[Sort]) -> Result { impl Display for Expr { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { - Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"), + Expr::Alias(Alias { expr, name, .. }, _) => write!(f, "{expr} AS {name}"), Expr::Column(c) => write!(f, "{c}"), Expr::OuterReferenceColumn(_, c) => write!(f, "outer_ref({c})"), Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")), Expr::Literal(v) => write!(f, "{v:?}"), - Expr::Case(case) => { + Expr::Case(case, _) => { write!(f, "CASE ")?; if let Some(e) = &case.expr { write!(f, "{e} ")?; @@ -2176,57 +2459,72 @@ impl Display for Expr { } write!(f, "END") } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, data_type }, _) => { write!(f, "CAST({expr} AS {data_type:?})") } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, data_type }, _) => { write!(f, "TRY_CAST({expr} AS {data_type:?})") } - Expr::Not(expr) => write!(f, "NOT {expr}"), - Expr::Negative(expr) => write!(f, "(- {expr})"), - Expr::IsNull(expr) => write!(f, "{expr} IS NULL"), - Expr::IsNotNull(expr) => write!(f, "{expr} IS NOT NULL"), - Expr::IsTrue(expr) => write!(f, "{expr} IS TRUE"), - Expr::IsFalse(expr) => write!(f, "{expr} IS FALSE"), - Expr::IsUnknown(expr) => write!(f, "{expr} IS UNKNOWN"), - Expr::IsNotTrue(expr) => write!(f, "{expr} IS NOT TRUE"), - Expr::IsNotFalse(expr) => write!(f, "{expr} IS NOT FALSE"), - Expr::IsNotUnknown(expr) => write!(f, "{expr} IS NOT UNKNOWN"), - 
Expr::Exists(Exists { - subquery, - negated: true, - }) => write!(f, "NOT EXISTS ({subquery:?})"), - Expr::Exists(Exists { - subquery, - negated: false, - }) => write!(f, "EXISTS ({subquery:?})"), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated: true, - }) => write!(f, "{expr} NOT IN ({subquery:?})"), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated: false, - }) => write!(f, "{expr} IN ({subquery:?})"), - Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), - Expr::BinaryExpr(expr) => write!(f, "{expr}"), - Expr::ScalarFunction(fun) => { + Expr::Not(expr, _) => write!(f, "NOT {expr}"), + Expr::Negative(expr, _) => write!(f, "(- {expr})"), + Expr::IsNull(expr, _) => write!(f, "{expr} IS NULL"), + Expr::IsNotNull(expr, _) => write!(f, "{expr} IS NOT NULL"), + Expr::IsTrue(expr, _) => write!(f, "{expr} IS TRUE"), + Expr::IsFalse(expr, _) => write!(f, "{expr} IS FALSE"), + Expr::IsUnknown(expr, _) => write!(f, "{expr} IS UNKNOWN"), + Expr::IsNotTrue(expr, _) => write!(f, "{expr} IS NOT TRUE"), + Expr::IsNotFalse(expr, _) => write!(f, "{expr} IS NOT FALSE"), + Expr::IsNotUnknown(expr, _) => write!(f, "{expr} IS NOT UNKNOWN"), + Expr::Exists( + Exists { + subquery, + negated: true, + }, + _, + ) => write!(f, "NOT EXISTS ({subquery:?})"), + Expr::Exists( + Exists { + subquery, + negated: false, + }, + _, + ) => write!(f, "EXISTS ({subquery:?})"), + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated: true, + }, + _, + ) => write!(f, "{expr} NOT IN ({subquery:?})"), + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated: false, + }, + _, + ) => write!(f, "{expr} IN ({subquery:?})"), + Expr::ScalarSubquery(subquery, _) => write!(f, "({subquery:?})"), + Expr::BinaryExpr(expr, _) => write!(f, "{expr}"), + Expr::ScalarFunction(fun, _) => { fmt_function(f, fun.name(), false, &fun.args, true) } // TODO: use udf's display_name, need to fix the seperator issue, // Expr::ScalarFunction(ScalarFunction { func, args }) => { // 
write!(f, "{}", func.display_name(args).unwrap()) // } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + _, + ) => { fmt_function(f, &fun.to_string(), false, args, true)?; if let Some(nt) = null_treatment { @@ -2246,15 +2544,18 @@ impl Display for Expr { )?; Ok(()) } - Expr::AggregateFunction(AggregateFunction { - func, - distinct, - ref args, - filter, - order_by, - null_treatment, - .. - }) => { + Expr::AggregateFunction( + AggregateFunction { + func, + distinct, + ref args, + filter, + order_by, + null_treatment, + .. + }, + _, + ) => { fmt_function(f, func.name(), *distinct, args, true)?; if let Some(nt) = null_treatment { write!(f, " {}", nt)?; @@ -2267,25 +2568,31 @@ impl Display for Expr { } Ok(()) } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { if *negated { write!(f, "{expr} NOT BETWEEN {low} AND {high}") } else { write!(f, "{expr} BETWEEN {low} AND {high}") } } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { write!(f, "{expr}")?; let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" }; if *negated { @@ -2297,13 +2604,16 @@ impl Display for Expr { write!(f, " {op_name} {pattern}") } } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }, + _, + ) => { write!(f, "{expr}")?; if *negated { write!(f, " NOT")?; @@ -2314,22 +2624,25 @@ impl Display for Expr { write!(f, " SIMILAR TO {pattern}") } } - Expr::InList(InList { - expr, - list, - negated, - }) 
=> { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { if *negated { write!(f, "{expr} NOT IN ([{}])", expr_vec_fmt!(list)) } else { write!(f, "{expr} IN ([{}])", expr_vec_fmt!(list)) } } - Expr::Wildcard(Wildcard { qualifier, options }) => match qualifier { + Expr::Wildcard(Wildcard { qualifier, options }, _) => match qualifier { Some(qualifier) => write!(f, "{qualifier}.*{options}"), None => write!(f, "*{options}"), }, - Expr::GroupingSet(grouping_sets) => match grouping_sets { + Expr::GroupingSet(grouping_sets, _) => match grouping_sets { GroupingSet::Rollup(exprs) => { // ROLLUP (c0, c1, c2) write!(f, "ROLLUP ({})", expr_vec_fmt!(exprs)) @@ -2352,7 +2665,7 @@ impl Display for Expr { } }, Expr::Placeholder(Placeholder { id, .. }) => write!(f, "{id}"), - Expr::Unnest(Unnest { expr }) => { + Expr::Unnest(Unnest { expr }, _) => { write!(f, "UNNEST({expr})") } } @@ -2415,7 +2728,7 @@ mod test { #[test] #[allow(deprecated)] fn format_cast() -> Result<()> { - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), data_type: DataType::Utf8, }); @@ -2445,7 +2758,7 @@ mod test { fn test_collect_expr() -> Result<()> { // single column { - let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)); + let expr = &Expr::cast(Cast::new(Box::new(col("a")), DataType::Float64)); let columns = expr.column_refs(); assert_eq!(1, columns.len()); assert!(columns.contains(&Column::from_name("a"))); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c38ffb888f024..0bd41876f7664 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -121,7 +121,7 @@ pub fn placeholder(id: impl Into) -> Expr { /// assert_eq!(p.to_string(), "*") /// ``` pub fn wildcard() -> Expr { - Expr::Wildcard(Wildcard { + Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), }) @@ -129,7 +129,7 @@ pub fn wildcard() -> Expr { 
/// Create an '*' [`Expr::Wildcard`] expression with the wildcard options pub fn wildcard_with_options(options: WildcardOptions) -> Expr { - Expr::Wildcard(Wildcard { + Expr::wildcard(Wildcard { qualifier: None, options, }) @@ -146,7 +146,7 @@ pub fn wildcard_with_options(options: WildcardOptions) -> Expr { /// assert_eq!(p.to_string(), "t.*") /// ``` pub fn qualified_wildcard(qualifier: impl Into) -> Expr { - Expr::Wildcard(Wildcard { + Expr::wildcard(Wildcard { qualifier: Some(qualifier.into()), options: WildcardOptions::default(), }) @@ -157,7 +157,7 @@ pub fn qualified_wildcard_with_options( qualifier: impl Into, options: WildcardOptions, ) -> Expr { - Expr::Wildcard(Wildcard { + Expr::wildcard(Wildcard { qualifier: Some(qualifier.into()), options, }) @@ -165,12 +165,12 @@ pub fn qualified_wildcard_with_options( /// Return a new expression `left right` pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) + Expr::binary_expr(BinaryExpr::new(Box::new(left), op, Box::new(right))) } /// Return a new expression with a logical AND pub fn and(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::And, Box::new(right), @@ -179,7 +179,7 @@ pub fn and(left: Expr, right: Expr) -> Expr { /// Return a new expression with a logical OR pub fn or(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::Or, Box::new(right), @@ -193,7 +193,7 @@ pub fn not(expr: Expr) -> Expr { /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::BitwiseAnd, Box::new(right), @@ -202,7 +202,7 @@ pub fn bitwise_and(left: Expr, right: Expr) -> Expr { /// Return a new expression with bitwise OR pub fn bitwise_or(left: 
Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::BitwiseOr, Box::new(right), @@ -211,7 +211,7 @@ pub fn bitwise_or(left: Expr, right: Expr) -> Expr { /// Return a new expression with bitwise XOR pub fn bitwise_xor(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::BitwiseXor, Box::new(right), @@ -220,7 +220,7 @@ pub fn bitwise_xor(left: Expr, right: Expr) -> Expr { /// Return a new expression with bitwise SHIFT RIGHT pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::BitwiseShiftRight, Box::new(right), @@ -229,7 +229,7 @@ pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr { /// Return a new expression with bitwise SHIFT LEFT pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::BitwiseShiftLeft, Box::new(right), @@ -238,13 +238,13 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { /// Create an in_list expression pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { - Expr::InList(InList::new(Box::new(expr), list, negated)) + Expr::_in_list(InList::new(Box::new(expr), list, negated)) } /// Create an EXISTS subquery expression pub fn exists(subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); - Expr::Exists(Exists { + Expr::exists(Exists { subquery: Subquery { subquery, outer_ref_columns, @@ -256,7 +256,7 @@ pub fn exists(subquery: Arc) -> Expr { /// Create a NOT EXISTS subquery expression pub fn not_exists(subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); - Expr::Exists(Exists { + Expr::exists(Exists { subquery: Subquery { subquery, outer_ref_columns, @@ -268,7 +268,7 @@ pub fn not_exists(subquery: Arc) -> 
Expr { /// Create an IN subquery expression pub fn in_subquery(expr: Expr, subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); - Expr::InSubquery(InSubquery::new( + Expr::in_subquery(InSubquery::new( Box::new(expr), Subquery { subquery, @@ -281,7 +281,7 @@ pub fn in_subquery(expr: Expr, subquery: Arc) -> Expr { /// Create a NOT IN subquery expression pub fn not_in_subquery(expr: Expr, subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); - Expr::InSubquery(InSubquery::new( + Expr::in_subquery(InSubquery::new( Box::new(expr), Subquery { subquery, @@ -294,7 +294,7 @@ pub fn not_in_subquery(expr: Expr, subquery: Arc) -> Expr { /// Create a scalar subquery expression pub fn scalar_subquery(subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); - Expr::ScalarSubquery(Subquery { + Expr::scalar_subquery(Subquery { subquery, outer_ref_columns, }) @@ -302,62 +302,62 @@ pub fn scalar_subquery(subquery: Arc) -> Expr { /// Create a grouping set pub fn grouping_set(exprs: Vec>) -> Expr { - Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) + Expr::grouping_set(GroupingSet::GroupingSets(exprs)) } /// Create a grouping set for all combination of `exprs` pub fn cube(exprs: Vec) -> Expr { - Expr::GroupingSet(GroupingSet::Cube(exprs)) + Expr::grouping_set(GroupingSet::Cube(exprs)) } /// Create a grouping set for rollup pub fn rollup(exprs: Vec) -> Expr { - Expr::GroupingSet(GroupingSet::Rollup(exprs)) + Expr::grouping_set(GroupingSet::Rollup(exprs)) } /// Create a cast expression pub fn cast(expr: Expr, data_type: DataType) -> Expr { - Expr::Cast(Cast::new(Box::new(expr), data_type)) + Expr::cast(Cast::new(Box::new(expr), data_type)) } /// Create a try cast expression pub fn try_cast(expr: Expr, data_type: DataType) -> Expr { - Expr::TryCast(TryCast::new(Box::new(expr), data_type)) + Expr::try_cast(TryCast::new(Box::new(expr), data_type)) } /// Create is null expression pub fn is_null(expr: Expr) -> Expr { 
- Expr::IsNull(Box::new(expr)) + Expr::_is_null(Box::new(expr)) } /// Create is true expression pub fn is_true(expr: Expr) -> Expr { - Expr::IsTrue(Box::new(expr)) + Expr::_is_true(Box::new(expr)) } /// Create is not true expression pub fn is_not_true(expr: Expr) -> Expr { - Expr::IsNotTrue(Box::new(expr)) + Expr::_is_not_true(Box::new(expr)) } /// Create is false expression pub fn is_false(expr: Expr) -> Expr { - Expr::IsFalse(Box::new(expr)) + Expr::_is_false(Box::new(expr)) } /// Create is not false expression pub fn is_not_false(expr: Expr) -> Expr { - Expr::IsNotFalse(Box::new(expr)) + Expr::_is_not_false(Box::new(expr)) } /// Create is unknown expression pub fn is_unknown(expr: Expr) -> Expr { - Expr::IsUnknown(Box::new(expr)) + Expr::_is_unknown(Box::new(expr)) } /// Create is not unknown expression pub fn is_not_unknown(expr: Expr) -> Expr { - Expr::IsNotUnknown(Box::new(expr)) + Expr::_is_not_unknown(Box::new(expr)) } /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. 
@@ -372,7 +372,7 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { /// Create a Unnest expression pub fn unnest(expr: Expr) -> Expr { - Expr::Unnest(Unnest { + Expr::unnest(Unnest { expr: Box::new(expr), }) } @@ -812,7 +812,7 @@ impl ExprFuncBuilder { udaf.filter = filter.map(Box::new); udaf.distinct = distinct; udaf.null_treatment = null_treatment; - Expr::AggregateFunction(udaf) + Expr::aggregate_function(udaf) } ExprFuncKind::Window(mut udwf) => { let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); @@ -821,7 +821,7 @@ impl ExprFuncBuilder { udwf.window_frame = window_frame.unwrap_or(WindowFrame::new(has_order_by)); udwf.null_treatment = null_treatment; - Expr::WindowFunction(udwf) + Expr::window_function(udwf) } }; @@ -871,10 +871,10 @@ impl ExprFunctionExt for ExprFuncBuilder { impl ExprFunctionExt for Expr { fn order_by(self, order_by: Vec) -> ExprFuncBuilder { let mut builder = match self { - Expr::AggregateFunction(udaf) => { + Expr::AggregateFunction(udaf, _) => { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } - Expr::WindowFunction(udwf) => { + Expr::WindowFunction(udwf, _) => { ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) } _ => ExprFuncBuilder::new(None), @@ -886,7 +886,7 @@ impl ExprFunctionExt for Expr { } fn filter(self, filter: Expr) -> ExprFuncBuilder { match self { - Expr::AggregateFunction(udaf) => { + Expr::AggregateFunction(udaf, _) => { let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); builder.filter = Some(filter); @@ -897,7 +897,7 @@ impl ExprFunctionExt for Expr { } fn distinct(self) -> ExprFuncBuilder { match self { - Expr::AggregateFunction(udaf) => { + Expr::AggregateFunction(udaf, _) => { let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); builder.distinct = true; @@ -911,10 +911,10 @@ impl ExprFunctionExt for Expr { null_treatment: impl Into>, ) -> ExprFuncBuilder { let mut builder = match self { - Expr::AggregateFunction(udaf) => { + 
Expr::AggregateFunction(udaf, _) => { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } - Expr::WindowFunction(udwf) => { + Expr::WindowFunction(udwf, _) => { ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) } _ => ExprFuncBuilder::new(None), @@ -927,7 +927,7 @@ impl ExprFunctionExt for Expr { fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { match self { - Expr::WindowFunction(udwf) => { + Expr::WindowFunction(udwf, _) => { let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); builder.partition_by = Some(partition_by); builder @@ -938,7 +938,7 @@ impl ExprFunctionExt for Expr { fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { match self { - Expr::WindowFunction(udwf) => { + Expr::WindowFunction(udwf, _) => { let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); builder.window_frame = Some(window_frame); builder diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index b944428977c4c..cd1d2747afb68 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -82,13 +82,13 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( using_columns: &[HashSet], ) -> Result { // Normalize column inside Unnest - if let Expr::Unnest(Unnest { expr }) = expr { + if let Expr::Unnest(Unnest { expr }, _) = expr { let e = normalize_col_with_schemas_and_ambiguity_check( expr.as_ref().clone(), schemas, using_columns, )?; - return Ok(Expr::Unnest(Unnest { expr: Box::new(e) })); + return Ok(Expr::unnest(Unnest { expr: Box::new(e) })); } expr.transform(|expr| { @@ -177,7 +177,7 @@ pub fn create_col_from_scalar_expr( subqry_alias: String, ) -> Result { match scalar_expr { - Expr::Alias(Alias { name, .. }) => Ok(Column::new( + Expr::Alias(Alias { name, .. 
}, _) => Ok(Column::new( Some::(subqry_alias.into()), name, )), @@ -225,10 +225,10 @@ pub fn coerce_plan_expr_for_schema( ) -> Result { match plan { // special case Projection to avoid adding multiple projections - LogicalPlan::Projection(Projection { expr, input, .. }) => { + LogicalPlan::Projection(Projection { expr, input, .. }, _) => { let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?; let projection = Projection::try_new(new_exprs, input)?; - Ok(LogicalPlan::Projection(projection)) + Ok(LogicalPlan::projection(projection)) } _ => { let exprs: Vec = plan.schema().iter().map(Expr::from).collect(); @@ -236,7 +236,7 @@ pub fn coerce_plan_expr_for_schema( let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none()); if add_project { let projection = Projection::try_new(new_exprs, Arc::new(plan))?; - Ok(LogicalPlan::Projection(projection)) + Ok(LogicalPlan::projection(projection)) } else { Ok(plan) } @@ -256,7 +256,7 @@ fn coerce_exprs_for_schema( let new_type = dst_schema.field(idx).data_type(); if new_type != &expr.get_type(src_schema)? { match expr { - Expr::Alias(Alias { expr, name, .. }) => { + Expr::Alias(Alias { expr, name, .. }, _) => { Ok(expr.cast_to(new_type, src_schema)?.alias(name)) } Expr::Wildcard { .. } => Ok(expr), @@ -273,7 +273,7 @@ fn coerce_exprs_for_schema( #[inline] pub fn unalias(expr: Expr) -> Expr { match expr { - Expr::Alias(Alias { expr, .. }) => unalias(*expr), + Expr::Alias(Alias { expr, .. }, _) => unalias(*expr), _ => expr, } } @@ -310,11 +310,11 @@ impl NamePreserver { // so there is no need to preserve expression names to prevent a schema change. 
use_alias: !matches!( plan, - LogicalPlan::Filter(_) - | LogicalPlan::Join(_) - | LogicalPlan::TableScan(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Statement(_) + LogicalPlan::Filter(_, _) + | LogicalPlan::Join(_, _) + | LogicalPlan::TableScan(_, _) + | LogicalPlan::Limit(_, _) + | LogicalPlan::Statement(_, _) ), } } @@ -514,7 +514,7 @@ mod test { // cast data types test_rewrite( col("a"), - Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)), + Expr::cast(Cast::new(Box::new(col("a")), DataType::Int32)), ); // change literal type from i32 to i64 diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index f0d3d8fcd0c15..319d464799299 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -107,14 +107,16 @@ fn rewrite_in_terms_of_projection( if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) { let found = found.clone(); return Ok(Transformed::yes(match normalized_expr { - Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { - expr: Box::new(found), - data_type, - }), - Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast { + Expr::Cast(Cast { expr: _, data_type }, _) => Expr::cast(Cast { expr: Box::new(found), data_type, }), + Expr::TryCast(TryCast { expr: _, data_type }, _) => { + Expr::try_cast(TryCast { + expr: Box::new(found), + data_type, + }) + } _ => found, })); } @@ -128,7 +130,7 @@ fn rewrite_in_terms_of_projection( /// so avg(c) as average will match avgc fn expr_match(needle: &Expr, expr: &Expr) -> bool { // check inside aliases - if let Expr::Alias(Alias { expr, .. }) = &expr { + if let Expr::Alias(Alias { expr, .. 
}, _) = &expr { expr.as_ref() == needle } else { expr == needle diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index b1a461eca41db..ea16513738712 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -103,19 +103,19 @@ impl ExprSchemable for Expr { #[recursive] fn get_type(&self, schema: &dyn ExprSchema) -> Result { match self { - Expr::Alias(Alias { expr, name, .. }) => match &**expr { + Expr::Alias(Alias { expr, name, .. }, _) => match &**expr { Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { None => schema.data_type(&Column::from_name(name)).cloned(), Some(dt) => Ok(dt.clone()), }, _ => expr.get_type(schema), }, - Expr::Negative(expr) => expr.get_type(schema), + Expr::Negative(expr, _) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.data_type()), - Expr::Case(case) => { + Expr::Case(case, _) => { for (_, then_expr) in &case.when_then_expr { let then_type = then_expr.get_type(schema)?; if !then_type.is_null() { @@ -126,9 +126,9 @@ impl ExprSchemable for Expr { .as_ref() .map_or(Ok(DataType::Null), |e| e.get_type(schema)) } - Expr::Cast(Cast { data_type, .. }) - | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), - Expr::Unnest(Unnest { expr }) => { + Expr::Cast(Cast { data_type, .. }, _) + | Expr::TryCast(TryCast { data_type, .. 
}, _) => Ok(data_type.clone()), + Expr::Unnest(Unnest { expr }, _) => { let arg_data_type = expr.get_type(schema)?; // Unnest's output type is the inner type of the list match arg_data_type { @@ -146,7 +146,7 @@ impl ExprSchemable for Expr { } } } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { let arg_data_types = args .iter() .map(|e| e.get_type(schema)) @@ -170,10 +170,10 @@ impl ExprSchemable for Expr { // expressiveness of `TypeSignature`), then infer return type Ok(func.return_type_from_exprs(args, schema, &new_data_types)?) } - Expr::WindowFunction(window_function) => self + Expr::WindowFunction(window_function, _) => self .data_type_and_nullable_with_window_function(schema, window_function) .map(|(return_type, _)| return_type), - Expr::AggregateFunction(AggregateFunction { func, args, .. }) => { + Expr::AggregateFunction(AggregateFunction { func, args, .. }, _) => { let data_types = args .iter() .map(|e| e.get_type(schema)) @@ -192,27 +192,30 @@ impl ExprSchemable for Expr { })?; Ok(func.return_type(&new_types)?) } - Expr::Not(_) - | Expr::IsNull(_) - | Expr::Exists { .. } - | Expr::InSubquery(_) + Expr::Not(_, _) + | Expr::IsNull(_, _) + | Expr::Exists(_, _) + | Expr::InSubquery(_, _) | Expr::Between { .. } | Expr::InList { .. 
} - | Expr::IsNotNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) => Ok(DataType::Boolean), - Expr::ScalarSubquery(subquery) => { + | Expr::IsNotNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) => Ok(DataType::Boolean), + Expr::ScalarSubquery(subquery, _) => { Ok(subquery.subquery.schema().field(0).data_type().clone()) } - Expr::BinaryExpr(BinaryExpr { - ref left, - ref right, - ref op, - }) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?), + Expr::BinaryExpr( + BinaryExpr { + ref left, + ref right, + ref op, + }, + _, + ) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?), Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean), Expr::Placeholder(Placeholder { data_type, .. }) => { data_type.clone().ok_or_else(|| { @@ -224,7 +227,7 @@ impl ExprSchemable for Expr { }) } Expr::Wildcard { .. } => Ok(DataType::Null), - Expr::GroupingSet(_) => { + Expr::GroupingSet(_, _) => { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } @@ -244,11 +247,11 @@ impl ExprSchemable for Expr { /// column that does not exist in the schema. fn nullable(&self, input_schema: &dyn ExprSchema) -> Result { match self { - Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => { - expr.nullable(input_schema) - } + Expr::Alias(Alias { expr, .. }, _) + | Expr::Not(expr, _) + | Expr::Negative(expr, _) => expr.nullable(input_schema), - Expr::InList(InList { expr, list, .. }) => { + Expr::InList(InList { expr, list, .. }, _) => { // Avoid inspecting too many expressions. const MAX_INSPECT_LIMIT: usize = 6; // Stop if a nullable expression is found or an error occurs. 
@@ -271,16 +274,19 @@ impl ExprSchemable for Expr { }) } - Expr::Between(Between { - expr, low, high, .. - }) => Ok(expr.nullable(input_schema)? + Expr::Between( + Between { + expr, low, high, .. + }, + _, + ) => Ok(expr.nullable(input_schema)? || low.nullable(input_schema)? || high.nullable(input_schema)?), Expr::Column(c) => input_schema.nullable(c), Expr::OuterReferenceColumn(_, _) => Ok(true), Expr::Literal(value) => Ok(value.is_null()), - Expr::Case(case) => { + Expr::Case(case, _) => { // This expression is nullable if any of the input expressions are nullable let then_nullable = case .when_then_expr @@ -297,14 +303,14 @@ impl ExprSchemable for Expr { Ok(true) } } - Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::Cast(Cast { expr, .. }, _) => expr.nullable(input_schema), + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { Ok(func.is_nullable(args, input_schema)) } - Expr::AggregateFunction(AggregateFunction { func, .. }) => { + Expr::AggregateFunction(AggregateFunction { func, .. }, _) => { Ok(func.is_nullable()) } - Expr::WindowFunction(window_function) => self + Expr::WindowFunction(window_function, _) => self .data_type_and_nullable_with_window_function( input_schema, window_function, @@ -312,32 +318,35 @@ impl ExprSchemable for Expr { .map(|(_, nullable)| nullable), Expr::ScalarVariable(_, _) | Expr::TryCast { .. } - | Expr::Unnest(_) + | Expr::Unnest(_, _) | Expr::Placeholder(_) => Ok(true), - Expr::IsNull(_) - | Expr::IsNotNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) + Expr::IsNull(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) | Expr::Exists { .. } => Ok(false), - Expr::InSubquery(InSubquery { expr, .. 
}) => expr.nullable(input_schema), - Expr::ScalarSubquery(subquery) => { + Expr::InSubquery(InSubquery { expr, .. }, _) => expr.nullable(input_schema), + Expr::ScalarSubquery(subquery, _) => { Ok(subquery.subquery.schema().field(0).is_nullable()) } - Expr::BinaryExpr(BinaryExpr { - ref left, - ref right, - .. - }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), - Expr::Like(Like { expr, pattern, .. }) - | Expr::SimilarTo(Like { expr, pattern, .. }) => { + Expr::BinaryExpr( + BinaryExpr { + ref left, + ref right, + .. + }, + _, + ) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), + Expr::Like(Like { expr, pattern, .. }, _) + | Expr::SimilarTo(Like { expr, pattern, .. }, _) => { Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) } Expr::Wildcard { .. } => Ok(false), - Expr::GroupingSet(_) => { + Expr::GroupingSet(_, _) => { // Grouping sets do not really have the concept of nullable and do not appear // in projections Ok(true) @@ -348,8 +357,8 @@ impl ExprSchemable for Expr { fn metadata(&self, schema: &dyn ExprSchema) -> Result> { match self { Expr::Column(c) => Ok(schema.metadata(c)?.clone()), - Expr::Alias(Alias { expr, .. }) => expr.metadata(schema), - Expr::Cast(Cast { expr, .. }) => expr.metadata(schema), + Expr::Alias(Alias { expr, .. }, _) => expr.metadata(schema), + Expr::Cast(Cast { expr, .. }, _) => expr.metadata(schema), _ => Ok(HashMap::new()), } } @@ -369,7 +378,7 @@ impl ExprSchemable for Expr { schema: &dyn ExprSchema, ) -> Result<(DataType, bool)> { match self { - Expr::Alias(Alias { expr, name, .. }) => match &**expr { + Expr::Alias(Alias { expr, name, .. }, _) => match &**expr { Expr::Placeholder(Placeholder { data_type, .. 
}) => match &data_type { None => schema .data_type_and_nullable(&Column::from_name(name)) @@ -378,36 +387,39 @@ impl ExprSchemable for Expr { }, _ => expr.data_type_and_nullable(schema), }, - Expr::Negative(expr) => expr.data_type_and_nullable(schema), + Expr::Negative(expr, _) => expr.data_type_and_nullable(schema), Expr::Column(c) => schema .data_type_and_nullable(c) .map(|(d, n)| (d.clone(), n)), Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)), Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)), Expr::Literal(l) => Ok((l.data_type(), l.is_null())), - Expr::IsNull(_) - | Expr::IsNotNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) + Expr::IsNull(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) | Expr::Exists { .. } => Ok((DataType::Boolean, false)), - Expr::ScalarSubquery(subquery) => Ok(( + Expr::ScalarSubquery(subquery, _) => Ok(( subquery.subquery.schema().field(0).data_type().clone(), subquery.subquery.schema().field(0).is_nullable(), )), - Expr::BinaryExpr(BinaryExpr { - ref left, - ref right, - ref op, - }) => { + Expr::BinaryExpr( + BinaryExpr { + ref left, + ref right, + ref op, + }, + _, + ) => { let left = left.data_type_and_nullable(schema)?; let right = right.data_type_and_nullable(schema)?; Ok((get_result_type(&left.0, op, &right.0)?, left.1 || right.1)) } - Expr::WindowFunction(window_function) => { + Expr::WindowFunction(window_function, _) => { self.data_type_and_nullable_with_window_function(schema, window_function) } _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), @@ -448,10 +460,10 @@ impl ExprSchemable for Expr { if can_cast_types(&this_type, cast_to_type) { match self { - Expr::ScalarSubquery(subquery) => { - Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?)) - } - _ => 
Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))), + Expr::ScalarSubquery(subquery, _) => Ok(Expr::scalar_subquery( + cast_subquery(subquery, cast_to_type)?, + )), + _ => Ok(Expr::cast(Cast::new(Box::new(self), cast_to_type.clone()))), } } else { plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}") @@ -538,11 +550,11 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { + LogicalPlan::Projection(projection, _) => { let cast_expr = projection.expr[0] .clone() .cast_to(cast_to_type, projection.input.schema())?; - LogicalPlan::Projection(Projection::try_new( + LogicalPlan::projection(Projection::try_new( vec![cast_expr], Arc::clone(&projection.input), )?) @@ -550,7 +562,7 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { let cast_expr = Expr::Column(Column::from(plan.schema().qualified_field(0))) .cast_to(cast_to_type, subquery.subquery.schema())?; - LogicalPlan::Projection(Projection::try_new( + LogicalPlan::projection(Projection::try_new( vec![cast_expr], subquery.subquery, )?) 
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 90235e3f84c48..5c5be53919ee0 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -164,7 +164,7 @@ impl LogicalPlanBuilder { // Ensure that the recursive term has the same field types as the static term let coerced_recursive_term = coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?; - Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { + Ok(Self::from(LogicalPlan::recursive_query(RecursiveQuery { name, static_term: self.plan, recursive_term: Arc::new(coerced_recursive_term), @@ -326,7 +326,7 @@ impl LogicalPlanBuilder { let dfschema = DFSchema::from_unqualified_fields(fields.into(), HashMap::new())?; let schema = DFSchemaRef::new(dfschema); - Ok(Self::new(LogicalPlan::Values(Values { schema, values }))) + Ok(Self::new(LogicalPlan::values(Values { schema, values }))) } /// Convert a table provider into a builder with a TableScan @@ -377,7 +377,7 @@ impl LogicalPlanBuilder { options: HashMap, partition_by: Vec, ) -> Result { - Ok(Self::new(LogicalPlan::Copy(CopyTo { + Ok(Self::new(LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url, partition_by, @@ -395,7 +395,7 @@ impl LogicalPlanBuilder { ) -> Result { let table_schema = table_schema.clone().to_dfschema_ref()?; - Ok(Self::new(LogicalPlan::Dml(DmlStatement::new( + Ok(Self::new(LogicalPlan::dml(DmlStatement::new( table_name.into(), table_schema, WriteOp::Insert(insert_op), @@ -411,7 +411,7 @@ impl LogicalPlanBuilder { filters: Vec, ) -> Result { TableScan::try_new(table_name, table_source, projection, filters, None) - .map(LogicalPlan::TableScan) + .map(LogicalPlan::table_scan) .map(Self::new) } @@ -424,7 +424,7 @@ impl LogicalPlanBuilder { fetch: Option, ) -> Result { TableScan::try_new(table_name, table_source, projection, filters, fetch) - .map(LogicalPlan::TableScan) + .map(LogicalPlan::table_scan) .map(Self::new) } @@ 
-486,7 +486,7 @@ impl LogicalPlanBuilder { pub fn filter(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; Filter::try_new(expr, self.plan) - .map(LogicalPlan::Filter) + .map(LogicalPlan::filter) .map(Self::new) } @@ -494,13 +494,13 @@ impl LogicalPlanBuilder { pub fn having(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; Filter::try_new_with_having(expr, self.plan) - .map(LogicalPlan::Filter) + .map(LogicalPlan::filter) .map(Self::from) } /// Make a builder for a prepare logical plan from the builder's plan pub fn prepare(self, name: String, data_types: Vec) -> Result { - Ok(Self::new(LogicalPlan::Statement(Statement::Prepare( + Ok(Self::new(LogicalPlan::statement(Statement::Prepare( Prepare { name, data_types, @@ -529,7 +529,7 @@ impl LogicalPlanBuilder { /// /// Similar to `limit` but uses expressions for `skip` and `fetch` pub fn limit_by_expr(self, skip: Option, fetch: Option) -> Result { - Ok(Self::new(LogicalPlan::Limit(Limit { + Ok(Self::new(LogicalPlan::limit(Limit { skip: skip.map(Box::new), fetch: fetch.map(Box::new), input: self.plan, @@ -575,11 +575,14 @@ impl LogicalPlanBuilder { is_distinct: bool, ) -> Result { match curr_plan { - LogicalPlan::Projection(Projection { - input, - mut expr, - schema: _, - }) if missing_cols.iter().all(|c| input.schema().has_column(c)) => { + LogicalPlan::Projection( + Projection { + input, + mut expr, + schema: _, + }, + _, + ) if missing_cols.iter().all(|c| input.schema().has_column(c)) => { let mut missing_exprs = missing_cols .iter() .map(|c| normalize_col(Expr::Column(c.clone()), &input)) @@ -597,7 +600,7 @@ impl LogicalPlanBuilder { } _ => { let is_distinct = - is_distinct || matches!(curr_plan, LogicalPlan::Distinct(_)); + is_distinct || matches!(curr_plan, LogicalPlan::Distinct(_, _)); let new_inputs = curr_plan .inputs() .into_iter() @@ -632,7 +635,7 @@ impl LogicalPlanBuilder { // As described in 
https://github.com/apache/datafusion/issues/5293 let all_aliases = missing_exprs.iter().all(|e| { projection_exprs.iter().any(|proj_expr| { - if let Expr::Alias(Alias { expr, .. }) = proj_expr { + if let Expr::Alias(Alias { expr, .. }, _) = proj_expr { e == expr.as_ref() } else { false @@ -696,7 +699,7 @@ impl LogicalPlanBuilder { })?; if missing_cols.is_empty() { - return Ok(Self::new(LogicalPlan::Sort(Sort { + return Ok(Self::new(LogicalPlan::sort(Sort { expr: normalize_sorts(sorts, &self.plan)?, input: self.plan, fetch, @@ -712,14 +715,14 @@ impl LogicalPlanBuilder { &missing_cols, is_distinct, )?; - let sort_plan = LogicalPlan::Sort(Sort { + let sort_plan = LogicalPlan::sort(Sort { expr: normalize_sorts(sorts, &plan)?, input: Arc::new(plan), fetch, }); Projection::try_new(new_expr, Arc::new(sort_plan)) - .map(LogicalPlan::Projection) + .map(LogicalPlan::projection) .map(Self::new) } @@ -733,14 +736,14 @@ impl LogicalPlanBuilder { let left_plan: LogicalPlan = Arc::unwrap_or_clone(self.plan); let right_plan: LogicalPlan = plan; - Ok(Self::new(LogicalPlan::Distinct(Distinct::All(Arc::new( + Ok(Self::new(LogicalPlan::distinct(Distinct::All(Arc::new( union(left_plan, right_plan)?, ))))) } /// Apply deduplication: Only distinct (different) values are returned) pub fn distinct(self) -> Result { - Ok(Self::new(LogicalPlan::Distinct(Distinct::All(self.plan)))) + Ok(Self::new(LogicalPlan::distinct(Distinct::All(self.plan)))) } /// Project first values of the specified expression list according to the provided @@ -751,7 +754,7 @@ impl LogicalPlanBuilder { select_expr: Vec, sort_expr: Option>, ) -> Result { - Ok(Self::new(LogicalPlan::Distinct(Distinct::On( + Ok(Self::new(LogicalPlan::distinct(Distinct::On( DistinctOn::try_new(on_expr, select_expr, sort_expr, self.plan)?, )))) } @@ -967,7 +970,7 @@ impl LogicalPlanBuilder { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - Ok(Self::new(LogicalPlan::Join(Join { + 
Ok(Self::new(LogicalPlan::join(Join { left: self.plan, right: Arc::new(right), on, @@ -1031,7 +1034,7 @@ impl LogicalPlanBuilder { DataFusionError::Internal("filters should not be None here".to_string()) })?) } else { - Ok(Self::new(LogicalPlan::Join(Join { + Ok(Self::new(LogicalPlan::join(Join { left: self.plan, right: Arc::new(right), on: join_on, @@ -1048,7 +1051,7 @@ impl LogicalPlanBuilder { pub fn cross_join(self, right: LogicalPlan) -> Result { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; - Ok(Self::new(LogicalPlan::Join(Join { + Ok(Self::new(LogicalPlan::join(Join { left: self.plan, right: Arc::new(right), on: vec![], @@ -1062,7 +1065,7 @@ impl LogicalPlanBuilder { /// Repartition pub fn repartition(self, partitioning_scheme: Partitioning) -> Result { - Ok(Self::new(LogicalPlan::Repartition(Repartition { + Ok(Self::new(LogicalPlan::repartition(Repartition { input: self.plan, partitioning_scheme, }))) @@ -1075,7 +1078,7 @@ impl LogicalPlanBuilder { ) -> Result { let window_expr = normalize_cols(window_expr, &self.plan)?; validate_unique_names("Windows", &window_expr)?; - Ok(Self::new(LogicalPlan::Window(Window::try_new( + Ok(Self::new(LogicalPlan::window(Window::try_new( window_expr, self.plan, )?))) @@ -1095,7 +1098,7 @@ impl LogicalPlanBuilder { let group_expr = add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; Aggregate::try_new(self.plan, group_expr, aggr_expr) - .map(LogicalPlan::Aggregate) + .map(LogicalPlan::aggregate) .map(Self::new) } @@ -1110,7 +1113,7 @@ impl LogicalPlanBuilder { let schema = schema.to_dfschema_ref()?; if analyze { - Ok(Self::new(LogicalPlan::Analyze(Analyze { + Ok(Self::new(LogicalPlan::analyze(Analyze { verbose, input: self.plan, schema, @@ -1119,7 +1122,7 @@ impl LogicalPlanBuilder { let stringified_plans = vec![self.plan.to_stringified(PlanType::InitialLogicalPlan)]; - Ok(Self::new(LogicalPlan::Explain(Explain { + Ok(Self::new(LogicalPlan::explain(Explain { 
verbose, plan: self.plan, stringified_plans, @@ -1266,7 +1269,7 @@ impl LogicalPlanBuilder { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - Ok(Self::new(LogicalPlan::Join(Join { + Ok(Self::new(LogicalPlan::join(Join { left: self.plan, right: Arc::new(right), on: join_key_pairs, @@ -1534,7 +1537,7 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result, ) -> Result { - SubqueryAlias::try_new(Arc::new(plan), alias).map(LogicalPlan::SubqueryAlias) + SubqueryAlias::try_new(Arc::new(plan), alias).map(LogicalPlan::subquery_alias) } /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema. @@ -1641,7 +1644,7 @@ pub fn wrap_projection_for_join_if_necessary( // // then a and cast(a as int) will use the same field name - `a` in projection schema. // https://github.com/apache/datafusion/issues/4478 - if matches!(key, Expr::Cast(_)) || matches!(key, Expr::TryCast(_)) { + if matches!(key, Expr::Cast(_, _)) || matches!(key, Expr::TryCast(_, _)) { let alias = format!("{key}"); key.clone().alias(alias) } else { @@ -1948,7 +1951,7 @@ pub fn unnest_with_options( let deps = input_schema.functional_dependencies().clone(); let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); - Ok(LogicalPlan::Unnest(Unnest { + Ok(LogicalPlan::unnest(Unnest { input: Arc::new(input), exec_columns: columns_to_unnest, list_type_columns: list_columns, diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 8c64a017988e9..bc9ebf6752bba 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -25,6 +25,7 @@ use std::{ }; use crate::expr::Sort; +use crate::logical_plan::tree_node::LogicalPlanStats; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNodeContainer, TreeNodeRecursion}; use datafusion_common::{ @@ -188,6 +189,46 @@ impl DdlStatement { } Wrapper(self) } + + 
pub(crate) fn stats(&self) -> LogicalPlanStats { + match self { + DdlStatement::CreateExternalTable(CreateExternalTable { + order_exprs, + column_defaults, + .. + }) => order_exprs + .iter() + .flatten() + .map(|s| &s.expr) + .chain(column_defaults.values()) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + DdlStatement::CreateMemoryTable(CreateMemoryTable { + input, + column_defaults, + .. + }) => column_defaults + .iter() + .map(|(_, e)| e) + .fold(input.stats(), |s, e| s.merge(e.stats())), + DdlStatement::CreateView(CreateView { input, .. }) => input.stats(), + DdlStatement::CreateIndex(CreateIndex { columns, .. }) => columns + .iter() + .map(|s| &s.expr) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + DdlStatement::CreateFunction(CreateFunction { args, params, .. }) => args + .iter() + .flatten() + .flat_map(|a| a.default_expr.as_slice()) + .chain(params.function_body.as_slice()) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + DdlStatement::CreateCatalogSchema(_) + | DdlStatement::CreateCatalog(_) + | DdlStatement::DropTable(_) + | DdlStatement::DropView(_) + | DdlStatement::DropCatalogSchema(_) + | DdlStatement::DropFunction(_) => LogicalPlanStats::empty(), + } + } } /// Creates an external table. diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index b808defcb959c..f7322a0f8cee4 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -317,13 +317,13 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Node Type": "EmptyRelation", }) } - LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. }) => { + LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. }, _) => { json!({ "Node Type": "RecursiveQuery", "Is Distinct": is_distinct, }) } - LogicalPlan::Values(Values { ref values, .. }) => { + LogicalPlan::Values(Values { ref values, .. 
}, _) => { let str_values = values .iter() // limit to only 5 values to avoid horrible display @@ -347,13 +347,16 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Values": values_str }) } - LogicalPlan::TableScan(TableScan { - ref source, - ref table_name, - ref filters, - ref fetch, - .. - }) => { + LogicalPlan::TableScan( + TableScan { + ref source, + ref table_name, + ref filters, + ref fetch, + .. + }, + _, + ) => { let mut object = json!({ "Node Type": "TableScan", "Relation Name": table_name.table(), @@ -407,26 +410,29 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { object } - LogicalPlan::Projection(Projection { ref expr, .. }) => { + LogicalPlan::Projection(Projection { ref expr, .. }, _) => { json!({ "Node Type": "Projection", "Expressions": expr.iter().map(|e| e.to_string()).collect::>() }) } - LogicalPlan::Dml(DmlStatement { table_name, op, .. }) => { + LogicalPlan::Dml(DmlStatement { table_name, op, .. }, _) => { json!({ "Node Type": "Projection", "Operation": op.name(), "Table Name": table_name.table() }) } - LogicalPlan::Copy(CopyTo { - input: _, - output_url, - file_type, - partition_by: _, - options, - }) => { + LogicalPlan::Copy( + CopyTo { + input: _, + output_url, + file_type, + partition_by: _, + options, + }, + _, + ) => { let op_str = options .iter() .map(|(k, v)| format!("{}={}", k, v)) @@ -439,41 +445,50 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Options": op_str }) } - LogicalPlan::Ddl(ddl) => { + LogicalPlan::Ddl(ddl, _) => { json!({ "Node Type": "Ddl", "Operation": format!("{}", ddl.display()) }) } - LogicalPlan::Filter(Filter { - predicate: ref expr, - .. - }) => { + LogicalPlan::Filter( + Filter { + predicate: ref expr, + .. + }, + _, + ) => { json!({ "Node Type": "Filter", "Condition": format!("{}", expr) }) } - LogicalPlan::Window(Window { - ref window_expr, .. - }) => { + LogicalPlan::Window( + Window { + ref window_expr, .. 
+ }, + _, + ) => { json!({ "Node Type": "WindowAggr", "Expressions": expr_vec_fmt!(window_expr) }) } - LogicalPlan::Aggregate(Aggregate { - ref group_expr, - ref aggr_expr, - .. - }) => { + LogicalPlan::Aggregate( + Aggregate { + ref group_expr, + ref aggr_expr, + .. + }, + _, + ) => { json!({ "Node Type": "Aggregate", "Group By": expr_vec_fmt!(group_expr), "Aggregates": expr_vec_fmt!(aggr_expr) }) } - LogicalPlan::Sort(Sort { expr, fetch, .. }) => { + LogicalPlan::Sort(Sort { expr, fetch, .. }, _) => { let mut object = json!({ "Node Type": "Sort", "Sort Key": expr_vec_fmt!(expr), @@ -485,13 +500,16 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { object } - LogicalPlan::Join(Join { - on: ref keys, - filter, - join_constraint, - join_type, - .. - }) => { + LogicalPlan::Join( + Join { + on: ref keys, + filter, + join_constraint, + join_type, + .. + }, + _, + ) => { let join_expr: Vec = keys.iter().map(|(l, r)| format!("{l} = {r}")).collect(); let filter_expr = filter @@ -505,10 +523,13 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Filter": format!("{}", filter_expr) }) } - LogicalPlan::Repartition(Repartition { - partitioning_scheme, - .. - }) => match partitioning_scheme { + LogicalPlan::Repartition( + Repartition { + partitioning_scheme, + .. + }, + _, + ) => match partitioning_scheme { Partitioning::RoundRobinBatch(n) => { json!({ "Node Type": "Repartition", @@ -537,11 +558,14 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { }) } }, - LogicalPlan::Limit(Limit { - ref skip, - ref fetch, - .. - }) => { + LogicalPlan::Limit( + Limit { + ref skip, + ref fetch, + .. + }, + _, + ) => { let mut object = serde_json::json!( { "Node Type": "Limit", @@ -555,24 +579,24 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { }; object } - LogicalPlan::Subquery(Subquery { .. }) => { + LogicalPlan::Subquery(Subquery { .. }, _) => { json!({ "Node Type": "Subquery" }) } - LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. 
}, _) => { json!({ "Node Type": "Subquery", "Alias": alias.table(), }) } - LogicalPlan::Statement(statement) => { + LogicalPlan::Statement(statement, _) => { json!({ "Node Type": "Statement", "Statement": format!("{}", statement.display()) }) } - LogicalPlan::Distinct(distinct) => match distinct { + LogicalPlan::Distinct(distinct, _) => match distinct { Distinct::All(_) => { json!({ "Node Type": "DistinctAll" @@ -607,12 +631,12 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Node Type": "Analyze" }) } - LogicalPlan::Union(_) => { + LogicalPlan::Union(_, _) => { json!({ "Node Type": "Union" }) } - LogicalPlan::Extension(e) => { + LogicalPlan::Extension(e, _) => { json!({ "Node Type": e.node.name(), "Detail": format!("{:?}", e.node) @@ -623,12 +647,15 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Node Type": "DescribeTable" }) } - LogicalPlan::Unnest(Unnest { - input: plan, - list_type_columns: list_col_indices, - struct_type_columns: struct_col_indices, - .. - }) => { + LogicalPlan::Unnest( + Unnest { + input: plan, + list_type_columns: list_col_indices, + struct_type_columns: struct_col_indices, + .. 
+ }, + _, + ) => { let input_columns = plan.schema().columns(); let list_type_columns = list_col_indices .iter() diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 669bc8e8a7d34..8b75d372485ec 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -21,11 +21,12 @@ use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::logical_plan::tree_node::{LogicalPlanNodePattern, LogicalPlanStats}; +use crate::LogicalPlan; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DFSchemaRef, TableReference}; - -use crate::LogicalPlan; +use enumset::enum_set; /// Operator that copies the contents of a database to file(s) #[derive(Clone)] @@ -89,6 +90,12 @@ impl Hash for CopyTo { } } +impl CopyTo { + pub(crate) fn stats(&self) -> LogicalPlanStats { + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanCopy)) + } +} + /// The operator that modifies the content of a database (adapted from /// substrait WriteRel) #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -128,6 +135,11 @@ impl DmlStatement { pub fn name(&self) -> &str { self.op.name() } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanDdl)) + .merge(self.input.stats()) + } } // Manual implementation needed because of `table_schema` and `output_schema` fields. 
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e9ea2170cc7ab..9778832139206 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -54,10 +54,12 @@ use datafusion_common::{ FunctionalDependencies, ParamValues, Result, ScalarValue, TableReference, UnnestOptions, }; +use enumset::enum_set; use indexmap::IndexSet; // backwards compatibility use crate::display::PgJsonVisitor; +use crate::logical_plan::tree_node::{LogicalPlanNodePattern, LogicalPlanStats}; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -198,7 +200,7 @@ pub use datafusion_common::{JoinConstraint, JoinType}; pub enum LogicalPlan { /// Evaluates an arbitrary list of expressions (essentially a /// SELECT with an expression list) on its input. - Projection(Projection), + Projection(Projection, LogicalPlanStats), /// Filters rows from its input that do not match an /// expression (essentially a WHERE clause with a predicate /// expression). @@ -207,79 +209,266 @@ pub enum LogicalPlan { /// input; If the value of `` is true, the input row is /// passed to the output. If the value of `` is false /// (or null), the row is discarded. - Filter(Filter), + Filter(Filter, LogicalPlanStats), /// Windows input based on a set of window spec and window /// function (e.g. SUM or RANK). This is used to implement SQL /// window functions, and the `OVER` clause. - Window(Window), + Window(Window, LogicalPlanStats), /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). This is used to implement SQL aggregates /// and `GROUP BY`. - Aggregate(Aggregate), + Aggregate(Aggregate, LogicalPlanStats), /// Sorts its input according to a list of sort expressions. 
This /// is used to implement SQL `ORDER BY` - Sort(Sort), + Sort(Sort, LogicalPlanStats), /// Join two logical plans on one or more join columns. /// This is used to implement SQL `JOIN` - Join(Join), + Join(Join, LogicalPlanStats), /// Repartitions the input based on a partitioning scheme. This is /// used to add parallelism and is sometimes referred to as an /// "exchange" operator in other systems - Repartition(Repartition), + Repartition(Repartition, LogicalPlanStats), /// Union multiple inputs with the same schema into a single /// output stream. This is used to implement SQL `UNION [ALL]` and /// `INTERSECT [ALL]`. - Union(Union), + Union(Union, LogicalPlanStats), /// Produces rows from a [`TableSource`], used to implement SQL /// `FROM` tables or views. - TableScan(TableScan), + TableScan(TableScan, LogicalPlanStats), /// Produces no rows: An empty relation with an empty schema that /// produces 0 or 1 row. This is used to implement SQL `SELECT` /// that has no values in the `FROM` clause. EmptyRelation(EmptyRelation), /// Produces the output of running another query. This is used to /// implement SQL subqueries - Subquery(Subquery), + Subquery(Subquery, LogicalPlanStats), /// Aliased relation provides, or changes, the name of a relation. - SubqueryAlias(SubqueryAlias), + SubqueryAlias(SubqueryAlias, LogicalPlanStats), /// Skip some number of rows, and then fetch some number of rows. - Limit(Limit), + Limit(Limit, LogicalPlanStats), /// A DataFusion [`Statement`] such as `SET VARIABLE` or `START TRANSACTION` - Statement(Statement), + Statement(Statement, LogicalPlanStats), /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. This is used to implement SQL such as /// `VALUES (1, 2), (3, 4)` - Values(Values), + Values(Values, LogicalPlanStats), /// Produces a relation with string representations of /// various parts of the plan. 
This is used to implement SQL `EXPLAIN`. - Explain(Explain), + Explain(Explain, LogicalPlanStats), /// Runs the input, and prints annotated physical plan as a string /// with execution metric. This is used to implement SQL /// `EXPLAIN ANALYZE`. - Analyze(Analyze), + Analyze(Analyze, LogicalPlanStats), /// Extension operator defined outside of DataFusion. This is used /// to extend DataFusion with custom relational operations that - Extension(Extension), + Extension(Extension, LogicalPlanStats), /// Remove duplicate rows from the input. This is used to /// implement SQL `SELECT DISTINCT ...`. - Distinct(Distinct), + Distinct(Distinct, LogicalPlanStats), /// Data Manipulation Language (DML): Insert / Update / Delete - Dml(DmlStatement), + Dml(DmlStatement, LogicalPlanStats), /// Data Definition Language (DDL): CREATE / DROP TABLES / VIEWS / SCHEMAS - Ddl(DdlStatement), + Ddl(DdlStatement, LogicalPlanStats), /// `COPY TO` for writing plan results to files - Copy(CopyTo), + Copy(CopyTo, LogicalPlanStats), /// Describe the schema of the table. This is used to implement the /// SQL `DESCRIBE` command from MySQL. DescribeTable(DescribeTable), /// Unnest a column that contains a nested list type such as an /// ARRAY. This is used to implement SQL `UNNEST` - Unnest(Unnest), + Unnest(Unnest, LogicalPlanStats), /// A variadic query (e.g. 
"Recursive CTEs") - RecursiveQuery(RecursiveQuery), + RecursiveQuery(RecursiveQuery, LogicalPlanStats), } +// impl From for LogicalPlan { +// fn from(projection: Projection) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Projection(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Filter) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Filter(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Window) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Window(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Aggregate) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Aggregate(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Sort) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Sort(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Join) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Join(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Repartition) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Repartition(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Union) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Union(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: TableScan) -> Self { +// let stats = projection.stats(); +// LogicalPlan::TableScan(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: EmptyRelation) -> Self { +// LogicalPlan::EmptyRelation(projection) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Subquery) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Subquery(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn 
from(projection: SubqueryAlias) -> Self { +// let stats = projection.stats(); +// LogicalPlan::SubqueryAlias(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Limit) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Limit(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Statement) -> Self { +// LogicalPlan::Statement(projection) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Values) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Values(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Explain) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Explain(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Analyze) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Analyze(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Extension) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Extension(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Distinct) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Distinct(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Prepare) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Prepare(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Execute) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Execute(projection, stats) +// } +// } +// +// +// impl From for LogicalPlan { +// fn from(projection: DmlStatement) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Dml(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: DdlStatement) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Ddl(projection, stats) +// } +// } +// +// 
impl From for LogicalPlan { +// fn from(projection: CopyTo) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Copy(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: DescribeTable) -> Self { +// LogicalPlan::DescribeTable(projection) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Unnest) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Unnest(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: RecursiveQuery) -> Self { +// let stats = projection.stats(); +// LogicalPlan::RecursiveQuery(projection, stats) +// } +// } + impl Default for LogicalPlan { fn default() -> Self { LogicalPlan::EmptyRelation(EmptyRelation { @@ -310,35 +499,38 @@ impl LogicalPlan { pub fn schema(&self) -> &DFSchemaRef { match self { LogicalPlan::EmptyRelation(EmptyRelation { schema, .. }) => schema, - LogicalPlan::Values(Values { schema, .. }) => schema, - LogicalPlan::TableScan(TableScan { - projected_schema, .. - }) => projected_schema, - LogicalPlan::Projection(Projection { schema, .. }) => schema, - LogicalPlan::Filter(Filter { input, .. }) => input.schema(), - LogicalPlan::Distinct(Distinct::All(input)) => input.schema(), - LogicalPlan::Distinct(Distinct::On(DistinctOn { schema, .. })) => schema, - LogicalPlan::Window(Window { schema, .. }) => schema, - LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, - LogicalPlan::Sort(Sort { input, .. }) => input.schema(), - LogicalPlan::Join(Join { schema, .. }) => schema, - LogicalPlan::Repartition(Repartition { input, .. }) => input.schema(), - LogicalPlan::Limit(Limit { input, .. }) => input.schema(), - LogicalPlan::Statement(statement) => statement.schema(), - LogicalPlan::Subquery(Subquery { subquery, .. }) => subquery.schema(), - LogicalPlan::SubqueryAlias(SubqueryAlias { schema, .. 
}) => schema, - LogicalPlan::Explain(explain) => &explain.schema, - LogicalPlan::Analyze(analyze) => &analyze.schema, - LogicalPlan::Extension(extension) => extension.node.schema(), - LogicalPlan::Union(Union { schema, .. }) => schema, + LogicalPlan::Values(Values { schema, .. }, _) => schema, + LogicalPlan::TableScan( + TableScan { + projected_schema, .. + }, + _, + ) => projected_schema, + LogicalPlan::Projection(Projection { schema, .. }, _) => schema, + LogicalPlan::Filter(Filter { input, .. }, _) => input.schema(), + LogicalPlan::Distinct(Distinct::All(input), _) => input.schema(), + LogicalPlan::Distinct(Distinct::On(DistinctOn { schema, .. }), _) => schema, + LogicalPlan::Window(Window { schema, .. }, _) => schema, + LogicalPlan::Aggregate(Aggregate { schema, .. }, _) => schema, + LogicalPlan::Sort(Sort { input, .. }, _) => input.schema(), + LogicalPlan::Join(Join { schema, .. }, _) => schema, + LogicalPlan::Repartition(Repartition { input, .. }, _) => input.schema(), + LogicalPlan::Limit(Limit { input, .. }, _) => input.schema(), + LogicalPlan::Statement(statement, _) => statement.schema(), + LogicalPlan::Subquery(Subquery { subquery, .. }, _) => subquery.schema(), + LogicalPlan::SubqueryAlias(SubqueryAlias { schema, .. }, _) => schema, + LogicalPlan::Explain(explain, _) => &explain.schema, + LogicalPlan::Analyze(analyze, _) => &analyze.schema, + LogicalPlan::Extension(extension, _) => extension.node.schema(), + LogicalPlan::Union(Union { schema, .. }, _) => schema, LogicalPlan::DescribeTable(DescribeTable { output_schema, .. }) => { output_schema } - LogicalPlan::Dml(DmlStatement { output_schema, .. }) => output_schema, - LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(), - LogicalPlan::Ddl(ddl) => ddl.schema(), - LogicalPlan::Unnest(Unnest { schema, .. }) => schema, - LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + LogicalPlan::Dml(DmlStatement { output_schema, .. 
}, _) => output_schema, + LogicalPlan::Copy(CopyTo { input, .. }, _) => input.schema(), + LogicalPlan::Ddl(ddl, _) => ddl.schema(), + LogicalPlan::Unnest(Unnest { schema, .. }, _) => schema, + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }, _) => { // we take the schema of the static term as the schema of the entire recursive query static_term.schema() } @@ -349,11 +541,11 @@ impl LogicalPlan { /// of the plan. pub fn fallback_normalize_schemas(&self) -> Vec<&DFSchema> { match self { - LogicalPlan::Window(_) - | LogicalPlan::Projection(_) - | LogicalPlan::Aggregate(_) - | LogicalPlan::Unnest(_) - | LogicalPlan::Join(_) => self + LogicalPlan::Window { .. } + | LogicalPlan::Projection { .. } + | LogicalPlan::Aggregate { .. } + | LogicalPlan::Unnest { .. } + | LogicalPlan::Join { .. } => self .inputs() .iter() .map(|input| input.schema().as_ref()) @@ -436,35 +628,39 @@ impl LogicalPlan { /// Note does not include inputs to inputs, or subqueries. pub fn inputs(&self) -> Vec<&LogicalPlan> { match self { - LogicalPlan::Projection(Projection { input, .. }) => vec![input], - LogicalPlan::Filter(Filter { input, .. }) => vec![input], - LogicalPlan::Repartition(Repartition { input, .. }) => vec![input], - LogicalPlan::Window(Window { input, .. }) => vec![input], - LogicalPlan::Aggregate(Aggregate { input, .. }) => vec![input], - LogicalPlan::Sort(Sort { input, .. }) => vec![input], - LogicalPlan::Join(Join { left, right, .. }) => vec![left, right], - LogicalPlan::Limit(Limit { input, .. }) => vec![input], - LogicalPlan::Subquery(Subquery { subquery, .. }) => vec![subquery], - LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => vec![input], - LogicalPlan::Extension(extension) => extension.node.inputs(), - LogicalPlan::Union(Union { inputs, .. }) => { + LogicalPlan::Projection(Projection { input, .. }, _) => vec![input], + LogicalPlan::Filter(Filter { input, .. }, _) => vec![input], + LogicalPlan::Repartition(Repartition { input, .. 
}, _) => vec![input], + LogicalPlan::Window(Window { input, .. }, _) => vec![input], + LogicalPlan::Aggregate(Aggregate { input, .. }, _) => vec![input], + LogicalPlan::Sort(Sort { input, .. }, _) => vec![input], + LogicalPlan::Join(Join { left, right, .. }, _) => vec![left, right], + LogicalPlan::Limit(Limit { input, .. }, _) => vec![input], + LogicalPlan::Subquery(Subquery { subquery, .. }, _) => vec![subquery], + LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }, _) => vec![input], + LogicalPlan::Extension(extension, _) => extension.node.inputs(), + LogicalPlan::Union(Union { inputs, .. }, _) => { inputs.iter().map(|arc| arc.as_ref()).collect() } LogicalPlan::Distinct( Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + _, ) => vec![input], - LogicalPlan::Explain(explain) => vec![&explain.plan], - LogicalPlan::Analyze(analyze) => vec![&analyze.input], - LogicalPlan::Dml(write) => vec![&write.input], - LogicalPlan::Copy(copy) => vec![©.input], - LogicalPlan::Ddl(ddl) => ddl.inputs(), - LogicalPlan::Unnest(Unnest { input, .. }) => vec![input], - LogicalPlan::RecursiveQuery(RecursiveQuery { - static_term, - recursive_term, - .. - }) => vec![static_term, recursive_term], - LogicalPlan::Statement(stmt) => stmt.inputs(), + LogicalPlan::Explain(explain, _) => vec![&explain.plan], + LogicalPlan::Analyze(analyze, _) => vec![&analyze.input], + LogicalPlan::Dml(write, _) => vec![&write.input], + LogicalPlan::Copy(copy, _) => vec![©.input], + LogicalPlan::Ddl(ddl, _) => ddl.inputs(), + LogicalPlan::Unnest(Unnest { input, .. }, _) => vec![input], + LogicalPlan::RecursiveQuery( + RecursiveQuery { + static_term, + recursive_term, + .. + }, + _, + ) => vec![static_term, recursive_term], + LogicalPlan::Statement(stmt, _) => stmt.inputs(), // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. 
} @@ -478,11 +674,14 @@ impl LogicalPlan { let mut using_columns: Vec> = vec![]; self.apply_with_subqueries(|plan| { - if let LogicalPlan::Join(Join { - join_constraint: JoinConstraint::Using, - on, - .. - }) = plan + if let LogicalPlan::Join( + Join { + join_constraint: JoinConstraint::Using, + on, + .. + }, + _, + ) = plan { // The join keys in using-join must be columns. let columns = @@ -512,31 +711,34 @@ impl LogicalPlan { /// returns the first output expression of this `LogicalPlan` node. pub fn head_output_expr(&self) -> Result> { match self { - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { Ok(Some(projection.expr.as_slice()[0].clone())) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { if agg.group_expr.is_empty() { Ok(Some(agg.aggr_expr.as_slice()[0].clone())) } else { Ok(Some(agg.group_expr.as_slice()[0].clone())) } } - LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, .. })) => { + LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, .. }), _) => { Ok(Some(select_expr[0].clone())) } - LogicalPlan::Filter(Filter { input, .. }) - | LogicalPlan::Distinct(Distinct::All(input)) - | LogicalPlan::Sort(Sort { input, .. }) - | LogicalPlan::Limit(Limit { input, .. }) - | LogicalPlan::Repartition(Repartition { input, .. }) - | LogicalPlan::Window(Window { input, .. }) => input.head_output_expr(), - LogicalPlan::Join(Join { - left, - right, - join_type, - .. - }) => match join_type { + LogicalPlan::Filter(Filter { input, .. }, _) + | LogicalPlan::Distinct(Distinct::All(input), _) + | LogicalPlan::Sort(Sort { input, .. }, _) + | LogicalPlan::Limit(Limit { input, .. }, _) + | LogicalPlan::Repartition(Repartition { input, .. }, _) + | LogicalPlan::Window(Window { input, .. }, _) => input.head_output_expr(), + LogicalPlan::Join( + Join { + left, + right, + join_type, + .. 
+ }, + _, + ) => match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { if left.schema().fields().is_empty() { right.head_output_expr() @@ -549,16 +751,16 @@ impl LogicalPlan { } JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(), }, - LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }, _) => { static_term.head_output_expr() } - LogicalPlan::Union(union) => Ok(Some(Expr::Column(Column::from( + LogicalPlan::Union(union, _) => Ok(Some(Expr::Column(Column::from( union.schema.qualified_field(0), )))), - LogicalPlan::TableScan(table) => Ok(Some(Expr::Column(Column::from( + LogicalPlan::TableScan(table, _) => Ok(Some(Expr::Column(Column::from( table.projected_schema.qualified_field(0), )))), - LogicalPlan::SubqueryAlias(subquery_alias) => { + LogicalPlan::SubqueryAlias(subquery_alias, _) => { let expr_opt = subquery_alias.input.head_output_expr()?; expr_opt .map(|expr| { @@ -569,18 +771,18 @@ impl LogicalPlan { }) .map_or(Ok(None), |v| v.map(Some)) } - LogicalPlan::Subquery(_) => Ok(None), - LogicalPlan::EmptyRelation(_) - | LogicalPlan::Statement(_) - | LogicalPlan::Values(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Extension(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Copy(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Unnest(_) => Ok(None), + LogicalPlan::Subquery { .. } => Ok(None), + LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Statement { .. } + | LogicalPlan::Values { .. } + | LogicalPlan::Explain { .. } + | LogicalPlan::Analyze { .. } + | LogicalPlan::Extension { .. } + | LogicalPlan::Dml { .. } + | LogicalPlan::Copy { .. } + | LogicalPlan::Ddl { .. } + | LogicalPlan::DescribeTable { .. } + | LogicalPlan::Unnest { .. 
} => Ok(None), } } @@ -610,47 +812,62 @@ impl LogicalPlan { match self { // Since expr may be different than the previous expr, schema of the projection // may change. We need to use try_new method instead of try_new_with_schema method. - LogicalPlan::Projection(Projection { - expr, - input, - schema: _, - }) => Projection::try_new(expr, input).map(LogicalPlan::Projection), - LogicalPlan::Dml(_) => Ok(self), - LogicalPlan::Copy(_) => Ok(self), - LogicalPlan::Values(Values { schema, values }) => { + LogicalPlan::Projection( + Projection { + expr, + input, + schema: _, + }, + _, + ) => Projection::try_new(expr, input).map(LogicalPlan::projection), + LogicalPlan::Dml { .. } => Ok(self), + LogicalPlan::Copy { .. } => Ok(self), + LogicalPlan::Values(Values { schema, values }, _) => { // todo it isn't clear why the schema is not recomputed here - Ok(LogicalPlan::Values(Values { schema, values })) + Ok(LogicalPlan::values(Values { schema, values })) } - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => Filter::try_new_internal(predicate, input, having) - .map(LogicalPlan::Filter), - LogicalPlan::Repartition(_) => Ok(self), - LogicalPlan::Window(Window { - input, - window_expr, - schema: _, - }) => Window::try_new(window_expr, input).map(LogicalPlan::Window), - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema: _, - }) => Aggregate::try_new(input, group_expr, aggr_expr) - .map(LogicalPlan::Aggregate), - LogicalPlan::Sort(_) => Ok(self), - LogicalPlan::Join(Join { - left, - right, - filter, - join_type, - join_constraint, - on, - schema: _, - null_equals_null, - }) => { + LogicalPlan::Filter( + Filter { + predicate, + input, + having, + }, + _, + ) => Filter::try_new_internal(predicate, input, having) + .map(LogicalPlan::filter), + LogicalPlan::Repartition { .. 
} => Ok(self), + LogicalPlan::Window( + Window { + input, + window_expr, + schema: _, + }, + _, + ) => Window::try_new(window_expr, input).map(LogicalPlan::window), + LogicalPlan::Aggregate( + Aggregate { + input, + group_expr, + aggr_expr, + schema: _, + }, + _, + ) => Aggregate::try_new(input, group_expr, aggr_expr) + .map(LogicalPlan::aggregate), + LogicalPlan::Sort { .. } => Ok(self), + LogicalPlan::Join( + Join { + left, + right, + filter, + join_type, + join_constraint, + on, + schema: _, + null_equals_null, + }, + _, + ) => { let schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -662,7 +879,7 @@ impl LogicalPlan { }) .collect(); - Ok(LogicalPlan::Join(Join { + Ok(LogicalPlan::join(Join { left, right, join_type, @@ -673,24 +890,27 @@ impl LogicalPlan { null_equals_null, })) } - LogicalPlan::Subquery(_) => Ok(self), - LogicalPlan::SubqueryAlias(SubqueryAlias { - input, - alias, - schema: _, - }) => SubqueryAlias::try_new(input, alias).map(LogicalPlan::SubqueryAlias), - LogicalPlan::Limit(_) => Ok(self), - LogicalPlan::Ddl(_) => Ok(self), - LogicalPlan::Extension(Extension { node }) => { + LogicalPlan::Subquery { .. } => Ok(self), + LogicalPlan::SubqueryAlias( + SubqueryAlias { + input, + alias, + schema: _, + }, + _, + ) => SubqueryAlias::try_new(input, alias).map(LogicalPlan::subquery_alias), + LogicalPlan::Limit { .. } => Ok(self), + LogicalPlan::Ddl { .. 
} => Ok(self), + LogicalPlan::Extension(Extension { node }, _) => { // todo make an API that does not require cloning // This requires a copy of the extension nodes expressions and inputs let expr = node.expressions(); let inputs: Vec<_> = node.inputs().into_iter().cloned().collect(); - Ok(LogicalPlan::Extension(Extension { + Ok(LogicalPlan::extension(Extension { node: node.with_exprs_and_inputs(expr, inputs)?, })) } - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(Union { inputs, schema }, _) => { let input_schema = inputs[0].schema(); // If inputs are not pruned do not change schema // TODO this seems wrong (shouldn't we always use the schema of the input?) @@ -699,9 +919,9 @@ impl LogicalPlan { } else { Arc::clone(input_schema) }; - Ok(LogicalPlan::Union(Union { inputs, schema })) + Ok(LogicalPlan::union(Union { inputs, schema })) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(distinct, _) => { let distinct = match distinct { Distinct::All(input) => Distinct::All(input), Distinct::On(DistinctOn { @@ -717,21 +937,24 @@ impl LogicalPlan { input, )?), }; - Ok(LogicalPlan::Distinct(distinct)) + Ok(LogicalPlan::distinct(distinct)) } - LogicalPlan::RecursiveQuery(_) => Ok(self), - LogicalPlan::Analyze(_) => Ok(self), - LogicalPlan::Explain(_) => Ok(self), - LogicalPlan::TableScan(_) => Ok(self), - LogicalPlan::EmptyRelation(_) => Ok(self), - LogicalPlan::Statement(_) => Ok(self), - LogicalPlan::DescribeTable(_) => Ok(self), - LogicalPlan::Unnest(Unnest { - input, - exec_columns, - options, - .. - }) => { + LogicalPlan::RecursiveQuery { .. } => Ok(self), + LogicalPlan::Analyze { .. } => Ok(self), + LogicalPlan::Explain { .. } => Ok(self), + LogicalPlan::TableScan { .. } => Ok(self), + LogicalPlan::EmptyRelation { .. } => Ok(self), + LogicalPlan::Statement { .. } => Ok(self), + LogicalPlan::DescribeTable { .. } => Ok(self), + LogicalPlan::Unnest( + Unnest { + input, + exec_columns, + options, + .. 
+ }, + _, + ) => { // Update schema with unnested column type. unnest_with_options(Arc::unwrap_or_clone(input), exec_columns, options) } @@ -771,35 +994,41 @@ impl LogicalPlan { match self { // Since expr may be different than the previous expr, schema of the projection // may change. We need to use try_new method instead of try_new_with_schema method. - LogicalPlan::Projection(Projection { .. }) => { + LogicalPlan::Projection(Projection { .. }, _) => { let input = self.only_input(inputs)?; - Projection::try_new(expr, Arc::new(input)).map(LogicalPlan::Projection) + Projection::try_new(expr, Arc::new(input)).map(LogicalPlan::projection) } - LogicalPlan::Dml(DmlStatement { - table_name, - table_schema, - op, - .. - }) => { + LogicalPlan::Dml( + DmlStatement { + table_name, + table_schema, + op, + .. + }, + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Dml(DmlStatement::new( + Ok(LogicalPlan::dml(DmlStatement::new( table_name.clone(), Arc::clone(table_schema), op.clone(), Arc::new(input), ))) } - LogicalPlan::Copy(CopyTo { - input: _, - output_url, - file_type, - options, - partition_by, - }) => { + LogicalPlan::Copy( + CopyTo { + input: _, + output_url, + file_type, + options, + partition_by, + }, + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Copy(CopyTo { + Ok(LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: output_url.clone(), file_type: Arc::clone(file_type), @@ -807,9 +1036,9 @@ impl LogicalPlan { partition_by: partition_by.clone(), })) } - LogicalPlan::Values(Values { schema, .. }) => { + LogicalPlan::Values(Values { schema, .. 
}, _) => { self.assert_no_inputs(inputs)?; - Ok(LogicalPlan::Values(Values { + Ok(LogicalPlan::values(Values { schema: Arc::clone(schema), values: expr .chunks_exact(schema.fields().len()) @@ -821,55 +1050,61 @@ impl LogicalPlan { let predicate = self.only_expr(expr)?; let input = self.only_input(inputs)?; - Filter::try_new(predicate, Arc::new(input)).map(LogicalPlan::Filter) + Filter::try_new(predicate, Arc::new(input)).map(LogicalPlan::filter) } - LogicalPlan::Repartition(Repartition { - partitioning_scheme, - .. - }) => match partitioning_scheme { + LogicalPlan::Repartition( + Repartition { + partitioning_scheme, + .. + }, + _, + ) => match partitioning_scheme { Partitioning::RoundRobinBatch(n) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Repartition(Repartition { + Ok(LogicalPlan::repartition(Repartition { partitioning_scheme: Partitioning::RoundRobinBatch(*n), input: Arc::new(input), })) } Partitioning::Hash(_, n) => { let input = self.only_input(inputs)?; - Ok(LogicalPlan::Repartition(Repartition { + Ok(LogicalPlan::repartition(Repartition { partitioning_scheme: Partitioning::Hash(expr, *n), input: Arc::new(input), })) } Partitioning::DistributeBy(_) => { let input = self.only_input(inputs)?; - Ok(LogicalPlan::Repartition(Repartition { + Ok(LogicalPlan::repartition(Repartition { partitioning_scheme: Partitioning::DistributeBy(expr), input: Arc::new(input), })) } }, - LogicalPlan::Window(Window { window_expr, .. }) => { + LogicalPlan::Window(Window { window_expr, .. }, _) => { assert_eq!(window_expr.len(), expr.len()); let input = self.only_input(inputs)?; - Window::try_new(expr, Arc::new(input)).map(LogicalPlan::Window) + Window::try_new(expr, Arc::new(input)).map(LogicalPlan::window) } - LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { + LogicalPlan::Aggregate(Aggregate { group_expr, .. 
}, _) => { let input = self.only_input(inputs)?; // group exprs are the first expressions let agg_expr = expr.split_off(group_expr.len()); Aggregate::try_new(Arc::new(input), expr, agg_expr) - .map(LogicalPlan::Aggregate) + .map(LogicalPlan::aggregate) } - LogicalPlan::Sort(Sort { - expr: sort_expr, - fetch, - .. - }) => { + LogicalPlan::Sort( + Sort { + expr: sort_expr, + fetch, + .. + }, + _, + ) => { let input = self.only_input(inputs)?; - Ok(LogicalPlan::Sort(Sort { + Ok(LogicalPlan::sort(Sort { expr: expr .into_iter() .zip(sort_expr.iter()) @@ -879,13 +1114,16 @@ impl LogicalPlan { fetch: *fetch, })) } - LogicalPlan::Join(Join { - join_type, - join_constraint, - on, - null_equals_null, - .. - }) => { + LogicalPlan::Join( + Join { + join_type, + join_constraint, + on, + null_equals_null, + .. + }, + _, + ) => { let (left, right) = self.only_two_inputs(inputs)?; let schema = build_join_schema(left.schema(), right.schema(), join_type)?; @@ -906,7 +1144,7 @@ impl LogicalPlan { let new_on = expr.into_iter().map(|equi_expr| { // SimplifyExpression rule may add alias to the equi_expr. let unalias_expr = equi_expr.clone().unalias(); - if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr { + if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }, _) = unalias_expr { Ok((*left, *right)) } else { internal_err!( @@ -915,7 +1153,7 @@ impl LogicalPlan { } }).collect::>>()?; - Ok(LogicalPlan::Join(Join { + Ok(LogicalPlan::join(Join { left: Arc::new(left), right: Arc::new(right), join_type: *join_type, @@ -926,24 +1164,27 @@ impl LogicalPlan { null_equals_null: *null_equals_null, })) } - LogicalPlan::Subquery(Subquery { - outer_ref_columns, .. - }) => { + LogicalPlan::Subquery( + Subquery { + outer_ref_columns, .. 
+ }, + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; let subquery = LogicalPlanBuilder::from(input).build()?; - Ok(LogicalPlan::Subquery(Subquery { + Ok(LogicalPlan::subquery(Subquery { subquery: Arc::new(subquery), outer_ref_columns: outer_ref_columns.clone(), })) } - LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }, _) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; SubqueryAlias::try_new(Arc::new(input), alias.clone()) - .map(LogicalPlan::SubqueryAlias) + .map(LogicalPlan::subquery_alias) } - LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + LogicalPlan::Limit(Limit { skip, fetch, .. }, _) => { let old_expr_len = skip.iter().chain(fetch.iter()).count(); if old_expr_len != expr.len() { return internal_err!( @@ -955,23 +1196,26 @@ impl LogicalPlan { let new_skip = skip.as_ref().and_then(|_| expr.pop()); let new_fetch = fetch.as_ref().and_then(|_| expr.pop()); let input = self.only_input(inputs)?; - Ok(LogicalPlan::Limit(Limit { + Ok(LogicalPlan::limit(Limit { skip: new_skip.map(Box::new), fetch: new_fetch.map(Box::new), input: Arc::new(input), })) } - LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { - name, - if_not_exists, - or_replace, - column_defaults, - temporary, - .. - })) => { + LogicalPlan::Ddl( + DdlStatement::CreateMemoryTable(CreateMemoryTable { + name, + if_not_exists, + or_replace, + column_defaults, + temporary, + .. + }), + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + Ok(LogicalPlan::ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { input: Arc::new(input), constraints: Constraints::empty(), @@ -983,16 +1227,19 @@ impl LogicalPlan { }, ))) } - LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { - name, - or_replace, - definition, - temporary, - .. 
- })) => { + LogicalPlan::Ddl( + DdlStatement::CreateView(CreateView { + name, + or_replace, + definition, + temporary, + .. + }), + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { + Ok(LogicalPlan::ddl(DdlStatement::CreateView(CreateView { input: Arc::new(input), name: name.clone(), or_replace: *or_replace, @@ -1000,10 +1247,10 @@ impl LogicalPlan { definition: definition.clone(), }))) } - LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { + LogicalPlan::Extension(e, _) => Ok(LogicalPlan::extension(Extension { node: e.node.with_exprs_and_inputs(expr, inputs)?, })), - LogicalPlan::Union(Union { schema, .. }) => { + LogicalPlan::Union(Union { schema, .. }, _) => { self.assert_no_expressions(expr)?; let input_schema = inputs[0].schema(); // If inputs are not pruned do not change schema. @@ -1012,12 +1259,12 @@ impl LogicalPlan { } else { Arc::clone(input_schema) }; - Ok(LogicalPlan::Union(Union { + Ok(LogicalPlan::union(Union { inputs: inputs.into_iter().map(Arc::new).collect(), schema, })) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(distinct, _) => { let distinct = match distinct { Distinct::All(_) => { self.assert_no_expressions(expr)?; @@ -1041,33 +1288,36 @@ impl LogicalPlan { )?) } }; - Ok(LogicalPlan::Distinct(distinct)) + Ok(LogicalPlan::distinct(distinct)) } - LogicalPlan::RecursiveQuery(RecursiveQuery { - name, is_distinct, .. - }) => { + LogicalPlan::RecursiveQuery( + RecursiveQuery { + name, is_distinct, .. 
+ }, + _, + ) => { self.assert_no_expressions(expr)?; let (static_term, recursive_term) = self.only_two_inputs(inputs)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + Ok(LogicalPlan::recursive_query(RecursiveQuery { name: name.clone(), static_term: Arc::new(static_term), recursive_term: Arc::new(recursive_term), is_distinct: *is_distinct, })) } - LogicalPlan::Analyze(a) => { + LogicalPlan::Analyze(a, _) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Analyze(Analyze { + Ok(LogicalPlan::analyze(Analyze { verbose: a.verbose, schema: Arc::clone(&a.schema), input: Arc::new(input), })) } - LogicalPlan::Explain(e) => { + LogicalPlan::Explain(e, _) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Explain(Explain { + Ok(LogicalPlan::explain(Explain { verbose: e.verbose, plan: Arc::new(input), stringified_plans: e.stringified_plans.clone(), @@ -1075,47 +1325,51 @@ impl LogicalPlan { logical_optimization_succeeded: e.logical_optimization_succeeded, })) } - LogicalPlan::Statement(Statement::Prepare(Prepare { - name, - data_types, - .. - })) => { + LogicalPlan::Statement( + Statement::Prepare(Prepare { + name, data_types, .. + }), + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Statement(Statement::Prepare(Prepare { + Ok(LogicalPlan::statement(Statement::Prepare(Prepare { name: name.clone(), data_types: data_types.clone(), input: Arc::new(input), }))) } - LogicalPlan::Statement(Statement::Execute(Execute { name, .. })) => { + LogicalPlan::Statement(Statement::Execute(Execute { name, .. 
}), _) => { self.assert_no_inputs(inputs)?; - Ok(LogicalPlan::Statement(Statement::Execute(Execute { + Ok(LogicalPlan::statement(Statement::Execute(Execute { name: name.clone(), parameters: expr, }))) } - LogicalPlan::TableScan(ts) => { + LogicalPlan::TableScan(ts, _) => { self.assert_no_inputs(inputs)?; - Ok(LogicalPlan::TableScan(TableScan { + Ok(LogicalPlan::table_scan(TableScan { filters: expr, ..ts.clone() })) } - LogicalPlan::EmptyRelation(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Statement(_) - | LogicalPlan::DescribeTable(_) => { + LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Ddl { .. } + | LogicalPlan::Statement { .. } + | LogicalPlan::DescribeTable { .. } => { // All of these plan types have no inputs / exprs so should not be called self.assert_no_expressions(expr)?; self.assert_no_inputs(inputs)?; Ok(self.clone()) } - LogicalPlan::Unnest(Unnest { - exec_columns: columns, - options, - .. - }) => { + LogicalPlan::Unnest( + Unnest { + exec_columns: columns, + options, + .. + }, + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; // Update schema with unnested column type. @@ -1249,7 +1503,7 @@ impl LogicalPlan { // unwrap Prepare Ok( - if let LogicalPlan::Statement(Statement::Prepare(prepare_lp)) = + if let LogicalPlan::Statement(Statement::Prepare(prepare_lp), _) = plan_with_values { param_values.verify(&prepare_lp.data_types)?; @@ -1267,18 +1521,21 @@ impl LogicalPlan { /// If `Some(n)` then the plan can return at most `n` rows but may return fewer. pub fn max_rows(self: &LogicalPlan) -> Option { match self { - LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(), - LogicalPlan::Filter(filter) => { + LogicalPlan::Projection(Projection { input, .. }, _) => input.max_rows(), + LogicalPlan::Filter(filter, _) => { if filter.is_scalar() { Some(1) } else { filter.input.max_rows() } } - LogicalPlan::Window(Window { input, .. 
}) => input.max_rows(), - LogicalPlan::Aggregate(Aggregate { - input, group_expr, .. - }) => { + LogicalPlan::Window(Window { input, .. }, _) => input.max_rows(), + LogicalPlan::Aggregate( + Aggregate { + input, group_expr, .. + }, + _, + ) => { // Empty group_expr will return Some(1) if group_expr .iter() @@ -1289,7 +1546,7 @@ impl LogicalPlan { input.max_rows() } } - LogicalPlan::Sort(Sort { input, fetch, .. }) => { + LogicalPlan::Sort(Sort { input, fetch, .. }, _) => { match (fetch, input.max_rows()) { (Some(fetch_limit), Some(input_max)) => { Some(input_max.min(*fetch_limit)) @@ -1299,12 +1556,15 @@ impl LogicalPlan { (None, None) => None, } } - LogicalPlan::Join(Join { - left, - right, - join_type, - .. - }) => match join_type { + LogicalPlan::Join( + Join { + left, + right, + join_type, + .. + }, + _, + ) => match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { match (left.max_rows(), right.max_rows()) { (Some(left_max), Some(right_max)) => { @@ -1324,8 +1584,8 @@ impl LogicalPlan { } JoinType::RightSemi | JoinType::RightAnti => right.max_rows(), }, - LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(), - LogicalPlan::Union(Union { inputs, .. }) => inputs + LogicalPlan::Repartition(Repartition { input, .. }, _) => input.max_rows(), + LogicalPlan::Union(Union { inputs, .. }, _) => inputs .iter() .map(|plan| plan.max_rows()) .try_fold(0usize, |mut acc, input_max| { @@ -1336,28 +1596,31 @@ impl LogicalPlan { None } }), - LogicalPlan::TableScan(TableScan { fetch, .. }) => *fetch, + LogicalPlan::TableScan(TableScan { fetch, .. }, _) => *fetch, LogicalPlan::EmptyRelation(_) => Some(0), - LogicalPlan::RecursiveQuery(_) => None, - LogicalPlan::Subquery(_) => None, - LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. 
}) => input.max_rows(), - LogicalPlan::Limit(limit) => match limit.get_fetch_type() { + LogicalPlan::RecursiveQuery(_, _) => None, + LogicalPlan::Subquery(_, _) => None, + LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }, _) => { + input.max_rows() + } + LogicalPlan::Limit(limit, _) => match limit.get_fetch_type() { Ok(FetchType::Literal(s)) => s, _ => None, }, LogicalPlan::Distinct( Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + _, ) => input.max_rows(), - LogicalPlan::Values(v) => Some(v.values.len()), - LogicalPlan::Unnest(_) => None, - LogicalPlan::Ddl(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Copy(_) + LogicalPlan::Values(v, _) => Some(v.values.len()), + LogicalPlan::Unnest(_, _) => None, + LogicalPlan::Ddl(_, _) + | LogicalPlan::Explain(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Dml(_, _) + | LogicalPlan::Copy(_, _) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Statement(_) - | LogicalPlan::Extension(_) => None, + | LogicalPlan::Statement(_, _) + | LogicalPlan::Extension(_, _) => None, } } @@ -1385,16 +1648,19 @@ impl LogicalPlan { /// See also: [`crate::utils::columnize_expr`] pub fn columnized_output_exprs(&self) -> Result> { match self { - LogicalPlan::Aggregate(aggregate) => Ok(aggregate + LogicalPlan::Aggregate(aggregate, _) => Ok(aggregate .output_expressions()? .into_iter() .zip(self.schema().columns()) .collect()), - LogicalPlan::Window(Window { - window_expr, - input, - schema, - }) => { + LogicalPlan::Window( + Window { + window_expr, + input, + schema, + }, + _, + ) => { // The input could be another Window, so the result should also include the input's. For Example: // `EXPLAIN SELECT RANK() OVER (PARTITION BY a ORDER BY b), SUM(b) OVER (PARTITION BY a) FROM t` // Its plan is: @@ -1702,10 +1968,10 @@ impl LogicalPlan { LogicalPlan::EmptyRelation(_) => write!(f, "EmptyRelation"), LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. 
- }) => { + }, _) => { write!(f, "RecursiveQuery: is_distinct={}", is_distinct) } - LogicalPlan::Values(Values { ref values, .. }) => { + LogicalPlan::Values(Values { ref values, .. }, _) => { let str_values: Vec<_> = values .iter() // limit to only 5 values to avoid horrible display @@ -1731,7 +1997,7 @@ impl LogicalPlan { ref filters, ref fetch, .. - }) => { + }, _) => { let projected_fields = match projection { Some(indices) => { let schema = source.schema(); @@ -1799,7 +2065,7 @@ impl LogicalPlan { Ok(()) } - LogicalPlan::Projection(Projection { ref expr, .. }) => { + LogicalPlan::Projection(Projection { ref expr, .. }, _) => { write!(f, "Projection: ")?; for (i, expr_item) in expr.iter().enumerate() { if i > 0 { @@ -1809,7 +2075,7 @@ impl LogicalPlan { } Ok(()) } - LogicalPlan::Dml(DmlStatement { table_name, op, .. }) => { + LogicalPlan::Dml(DmlStatement { table_name, op, .. }, _) => { write!(f, "Dml: op=[{op}] table=[{table_name}]") } LogicalPlan::Copy(CopyTo { @@ -1818,7 +2084,7 @@ impl LogicalPlan { file_type, options, .. - }) => { + }, _) => { let op_str = options .iter() .map(|(k, v)| format!("{k} {v}")) @@ -1827,16 +2093,16 @@ impl LogicalPlan { write!(f, "CopyTo: format={} output_url={output_url} options: ({op_str})", file_type.get_ext()) } - LogicalPlan::Ddl(ddl) => { + LogicalPlan::Ddl(ddl, _) => { write!(f, "{}", ddl.display()) } LogicalPlan::Filter(Filter { predicate: ref expr, .. - }) => write!(f, "Filter: {expr}"), + }, _) => write!(f, "Filter: {expr}"), LogicalPlan::Window(Window { ref window_expr, .. - }) => { + }, _) => { write!( f, "WindowAggr: windowExpr=[[{}]]", @@ -1847,13 +2113,13 @@ impl LogicalPlan { ref group_expr, ref aggr_expr, .. - }) => write!( + }, _) => write!( f, "Aggregate: groupBy=[[{}]], aggr=[[{}]]", expr_vec_fmt!(group_expr), expr_vec_fmt!(aggr_expr) ), - LogicalPlan::Sort(Sort { expr, fetch, .. }) => { + LogicalPlan::Sort(Sort { expr, fetch, .. 
}, _) => { write!(f, "Sort: ")?; for (i, expr_item) in expr.iter().enumerate() { if i > 0 { @@ -1873,7 +2139,7 @@ impl LogicalPlan { join_constraint, join_type, .. - }) => { + }, _) => { let join_expr: Vec = keys.iter().map(|(l, r)| format!("{l} = {r}")).collect(); let filter_expr = filter @@ -1909,7 +2175,7 @@ impl LogicalPlan { LogicalPlan::Repartition(Repartition { partitioning_scheme, .. - }) => match partitioning_scheme { + }, _) => match partitioning_scheme { Partitioning::RoundRobinBatch(n) => { write!(f, "Repartition: RoundRobinBatch partition_count={n}") } @@ -1933,7 +2199,7 @@ impl LogicalPlan { ) } }, - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { // Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions. let skip_str = match limit.get_skip_type() { Ok(SkipType::Literal(n)) => n.to_string(), @@ -1949,16 +2215,16 @@ impl LogicalPlan { "Limit: skip={}, fetch={}", skip_str,fetch_str, ) } - LogicalPlan::Subquery(Subquery { .. }) => { + LogicalPlan::Subquery(Subquery { .. }, _) => { write!(f, "Subquery:") } - LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }, _) => { write!(f, "SubqueryAlias: {alias}") } - LogicalPlan::Statement(statement) => { + LogicalPlan::Statement(statement, _) => { write!(f, "{}", statement.display()) } - LogicalPlan::Distinct(distinct) => match distinct { + LogicalPlan::Distinct(distinct, _) => match distinct { Distinct::All(_) => write!(f, "Distinct:"), Distinct::On(DistinctOn { on_expr, @@ -1975,15 +2241,15 @@ impl LogicalPlan { }, LogicalPlan::Explain { .. } => write!(f, "Explain"), LogicalPlan::Analyze { .. } => write!(f, "Analyze"), - LogicalPlan::Union(_) => write!(f, "Union"), - LogicalPlan::Extension(e) => e.node.fmt_for_explain(f), + LogicalPlan::Union(_, _) => write!(f, "Union"), + LogicalPlan::Extension(e, _) => e.node.fmt_for_explain(f), LogicalPlan::DescribeTable(DescribeTable { .. 
}) => { write!(f, "DescribeTable") } LogicalPlan::Unnest(Unnest { input: plan, list_type_columns: list_col_indices, - struct_type_columns: struct_col_indices, .. }) => { + struct_type_columns: struct_col_indices, .. }, _) => { let input_columns = plan.schema().columns(); let list_type_columns = list_col_indices .iter() @@ -2005,6 +2271,133 @@ impl LogicalPlan { } Wrapper(self) } + + pub fn projection(projection: Projection) -> Self { + let stats = LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanProjection + )) + .merge(projection.stats()); + LogicalPlan::Projection(projection, stats) + } + + pub fn filter(filter: Filter) -> Self { + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanFilter)) + .merge(filter.stats()); + LogicalPlan::Filter(filter, stats) + } + + pub fn statement(statement: Statement) -> Self { + let stats = LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanStatement + )) + .merge(statement.stats()); + LogicalPlan::Statement(statement, stats) + } + + pub fn window(window: Window) -> Self { + let stats = window.stats(); + LogicalPlan::Window(window, stats) + } + + pub fn aggregate(aggregate: Aggregate) -> Self { + let stats = aggregate.stats(); + LogicalPlan::Aggregate(aggregate, stats) + } + + pub fn sort(sort: Sort) -> Self { + let stats = sort.stats(); + LogicalPlan::Sort(sort, stats) + } + + pub fn join(join: Join) -> Self { + let stats = join.stats(); + LogicalPlan::Join(join, stats) + } + + pub fn repartition(repartition: Repartition) -> Self { + let stats = repartition.stats(); + LogicalPlan::Repartition(repartition, stats) + } + + pub fn union(projection: Union) -> Self { + let stats = projection.stats(); + LogicalPlan::Union(projection, stats) + } + + pub fn table_scan(table_scan: TableScan) -> Self { + let stats = table_scan.stats(); + LogicalPlan::TableScan(table_scan, stats) + } + + pub fn subquery(subquery: Subquery) -> Self { + let stats = subquery.stats(); + 
LogicalPlan::Subquery(subquery, stats) + } + + pub fn subquery_alias(subquery_alias: SubqueryAlias) -> Self { + let stats = subquery_alias.stats(); + LogicalPlan::SubqueryAlias(subquery_alias, stats) + } + + pub fn limit(limit: Limit) -> Self { + let stats = limit.stats(); + LogicalPlan::Limit(limit, stats) + } + + pub fn values(values: Values) -> Self { + let stats = values.stats(); + LogicalPlan::Values(values, stats) + } + + pub fn explain(explain: Explain) -> Self { + let stats = explain.stats(); + LogicalPlan::Explain(explain, stats) + } + + pub fn analyze(analyze: Analyze) -> Self { + let stats = analyze.stats(); + LogicalPlan::Analyze(analyze, stats) + } + + pub fn extension(extension: Extension) -> Self { + let stats = extension.stats(); + LogicalPlan::Extension(extension, stats) + } + + pub fn distinct(distinct: Distinct) -> Self { + let stats = distinct.stats(); + LogicalPlan::Distinct(distinct, stats) + } + + pub fn dml(dml_statement: DmlStatement) -> Self { + let stats = dml_statement.stats(); + LogicalPlan::Dml(dml_statement, stats) + } + + pub fn ddl(ddl_statement: DdlStatement) -> Self { + let stats = ddl_statement.stats(); + LogicalPlan::Ddl(ddl_statement, stats) + } + + pub fn copy(copy_to: CopyTo) -> Self { + let stats = copy_to.stats(); + LogicalPlan::Copy(copy_to, stats) + } + + pub fn describe_table(describe_table: DescribeTable) -> Self { + LogicalPlan::DescribeTable(describe_table) + } + + pub fn unnest(unnest: Unnest) -> Self { + let stats = unnest.stats(); + LogicalPlan::Unnest(unnest, stats) + } + + pub fn recursive_query(recursive_query: RecursiveQuery) -> Self { + let stats = recursive_query.stats(); + LogicalPlan::RecursiveQuery(recursive_query, stats) + } } impl Display for LogicalPlan { @@ -2071,6 +2464,16 @@ pub struct RecursiveQuery { pub is_distinct: bool, } +impl RecursiveQuery { + fn stats(&self) -> LogicalPlanStats { + LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanRecursiveQuery + )) + 
.merge(self.static_term.stats()) + .merge(self.recursive_term.stats()) + } +} + /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. @@ -2082,6 +2485,15 @@ pub struct Values { pub values: Vec>, } +impl Values { + fn stats(&self) -> LogicalPlanStats { + self.values.iter().flatten().fold( + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanValues)), + |s, e| s.merge(e.stats()), + ) + } +} + // Manual implementation needed because of `schema` field. Comparison excludes this field. impl PartialOrd for Values { fn partial_cmp(&self, other: &Self) -> Option { @@ -2147,6 +2559,13 @@ impl Projection { schema, } } + + fn stats(&self) -> LogicalPlanStats { + self.expr.iter().fold( + LogicalPlanStats::empty().merge(self.input.stats()), + |s, e| s.merge(e.stats()), + ) + } } /// Computes the schema of the result produced by applying a projection to the input logical plan. @@ -2210,6 +2629,11 @@ impl SubqueryAlias { schema, }) } + + fn stats(&self) -> LogicalPlanStats { + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanSubqueryAlias)) + .merge(self.input.stats()) + } } // Manual implementation needed because of `schema` field. Comparison excludes this field. @@ -2328,11 +2752,14 @@ impl Filter { let eq_pred_cols: HashSet<_> = exprs .iter() .filter_map(|expr| { - let Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) = expr + let Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + ) = expr else { return None; }; @@ -2360,6 +2787,10 @@ impl Filter { } false } + + fn stats(&self) -> LogicalPlanStats { + self.input.stats() + } } /// Window its input based on a set of window spec and window function (e.g. 
SUM or RANK) @@ -2399,11 +2830,14 @@ impl Window { .iter() .enumerate() .filter_map(|(idx, expr)| { - if let Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::WindowUDF(udwf), - partition_by, - .. - }) = expr + if let Expr::WindowFunction( + WindowFunction { + fun: WindowFunctionDefinition::WindowUDF(udwf), + partition_by, + .. + }, + _, + ) = expr { // When there is no PARTITION BY, row number will be unique // across the entire table. @@ -2457,6 +2891,14 @@ impl Window { schema, }) } + + fn stats(&self) -> LogicalPlanStats { + self.window_expr.iter().fold( + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanWindow)) + .merge(self.input.stats()), + |s, e| s.merge(e.stats()), + ) + } } // Manual implementation needed because of `schema` field. Comparison excludes this field. @@ -2604,6 +3046,15 @@ impl TableScan { fetch, }) } + + fn stats(&self) -> LogicalPlanStats { + self.filters.iter().fold( + LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanTableScan + )), + |s, e| s.merge(e.stats()), + ) + } } // Repartition the plan based on a partitioning scheme. 
@@ -2615,6 +3066,21 @@ pub struct Repartition { pub partitioning_scheme: Partitioning, } +impl Repartition { + fn stats(&self) -> LogicalPlanStats { + let s = LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanRepartition + )) + .merge(self.input.stats()); + match &self.partitioning_scheme { + Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { + expr.iter().fold(s, |s, e| s.merge(e.stats())) + } + Partitioning::RoundRobinBatch(_) => s, + } + } +} + /// Union multiple inputs #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Union { @@ -2630,6 +3096,14 @@ impl PartialOrd for Union { self.inputs.partial_cmp(&other.inputs) } } +impl Union { + fn stats(&self) -> LogicalPlanStats { + self.inputs.iter().fold( + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanUnion)), + |s, e| s.merge(e.stats()), + ) + } +} /// Describe the schema of table /// @@ -2716,6 +3190,13 @@ impl PartialOrd for Explain { } } +impl Explain { + fn stats(&self) -> LogicalPlanStats { + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanExplain)) + .merge(self.plan.stats()) + } +} + /// Runs the actual plan, and then prints the physical plan with /// with execution metrics. #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -2738,6 +3219,13 @@ impl PartialOrd for Analyze { } } +impl Analyze { + fn stats(&self) -> LogicalPlanStats { + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanAnalyze)) + .merge(self.input.stats()) + } +} + /// Extension operator defined outside of DataFusion // TODO(clippy): This clippy `allow` should be removed if // the manual `PartialEq` is removed in favor of a derive. 
@@ -2764,6 +3252,20 @@ impl PartialOrd for Extension { } } +impl Extension { + fn stats(&self) -> LogicalPlanStats { + self.node.inputs().iter().fold( + self.node.expressions().iter().fold( + LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanExtension + )), + |s, e| s.merge(e.stats()), + ), + |s, e| s.merge(e.stats()), + ) + } +} + /// Produces the first `n` tuples from its input and discards the rest. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Limit { @@ -2831,6 +3333,13 @@ impl Limit { None => Ok(FetchType::Literal(None)), } } + + fn stats(&self) -> LogicalPlanStats { + self.skip.iter().chain(self.fetch.iter()).fold( + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanLimit)), + |s, e| s.merge(e.stats()), + ) + } } /// Removes duplicate rows from the input @@ -2850,6 +3359,16 @@ impl Distinct { Distinct::On(DistinctOn { input, .. }) => input, } } + + fn stats(&self) -> LogicalPlanStats { + let s = LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanDistinct + )); + match self { + Distinct::All(input) => s.merge(input.stats()), + Distinct::On(distinct_on) => s.merge(distinct_on.stats()), + } + } } /// Removes duplicate rows from the input @@ -2930,6 +3449,14 @@ impl DistinctOn { self.sort_expr = Some(sort_expr); Ok(self) } + + fn stats(&self) -> LogicalPlanStats { + self.on_expr + .iter() + .chain(self.select_expr.iter()) + .chain(self.sort_expr.iter().flatten().map(|s| &s.expr)) + .fold(self.input.stats(), |s, e| s.merge(e.stats())) + } } // Manual implementation needed because of `schema` field. Comparison excludes this field. 
@@ -2989,7 +3516,7 @@ impl Aggregate { ) -> Result { let group_expr = enumerate_grouping_sets(group_expr)?; - let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_, _)]); let grouping_expr: Vec<&Expr> = grouping_set_to_exprlist(group_expr.as_slice())?; @@ -3062,7 +3589,7 @@ impl Aggregate { } fn is_grouping_set(&self) -> bool { - matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)]) + matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_, _)]) } /// Get the output expressions. @@ -3120,6 +3647,16 @@ impl Aggregate { /// with `NULL` values. To handle these cases correctly, we must distinguish /// between an actual `NULL` value in a column and a column being excluded from the set. pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id"; + + fn stats(&self) -> LogicalPlanStats { + self.group_expr.iter().chain(self.aggr_expr.iter()).fold( + LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanAggregate + )) + .merge(self.input.stats()), + |s, e| s.merge(e.stats()), + ) + } } // Manual implementation needed because of `schema` field. Comparison excludes this field. @@ -3141,7 +3678,7 @@ impl PartialOrd for Aggregate { fn contains_grouping_set(group_expr: &[Expr]) -> bool { group_expr .iter() - .any(|expr| matches!(expr, Expr::GroupingSet(_))) + .any(|expr| matches!(expr, Expr::GroupingSet(_, _))) } /// Calculates functional dependencies for aggregate expressions. 
@@ -3187,9 +3724,9 @@ fn calc_func_dependencies_for_project( let proj_indices = exprs .iter() .map(|expr| match expr { - Expr::Wildcard(Wildcard { qualifier, options }) => { + Expr::Wildcard(Wildcard { qualifier, options }, _) => { let wildcard_fields = exprlist_to_fields( - vec![&Expr::Wildcard(Wildcard { + vec![&Expr::wildcard(Wildcard { qualifier: qualifier.clone(), options: options.clone(), })], @@ -3207,7 +3744,7 @@ fn calc_func_dependencies_for_project( .collect::>(), ) } - Expr::Alias(alias) => { + Expr::Alias(alias, _) => { let name = format!("{}", alias.expr); Ok(input_fields .iter() @@ -3247,6 +3784,16 @@ pub struct Sort { pub fetch: Option, } +impl Sort { + fn stats(&self) -> LogicalPlanStats { + self.expr.iter().map(|s| &s.expr).fold( + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanSort)) + .merge(self.input.stats()), + |s, e| s.merge(e.stats()), + ) + } +} + /// Join two logical plans on one or more join columns #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Join { @@ -3277,7 +3824,7 @@ impl Join { column_on: (Vec, Vec), ) -> Result { let original_join = match original { - LogicalPlan::Join(join) => join, + LogicalPlan::Join(join, _) => join, _ => return plan_err!("Could not create join with project input"), }; @@ -3301,6 +3848,19 @@ impl Join { null_equals_null: original_join.null_equals_null, }) } + + fn stats(&self) -> LogicalPlanStats { + self.on + .iter() + .flat_map(|(l, r)| [l, r]) + .chain(&self.filter) + .fold( + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanJoin)) + .merge(self.left.stats()) + .merge(self.right.stats()), + |s, e| s.merge(e.stats()), + ) + } } // Manual implementation needed because of `schema` field. Comparison excludes this field. 
@@ -3357,8 +3917,8 @@ pub struct Subquery { impl Subquery { pub fn try_from_expr(plan: &Expr) -> Result<&Subquery> { match plan { - Expr::ScalarSubquery(it) => Ok(it), - Expr::Cast(cast) => Subquery::try_from_expr(cast.expr.as_ref()), + Expr::ScalarSubquery(it, _) => Ok(it), + Expr::Cast(cast, _) => Subquery::try_from_expr(cast.expr.as_ref()), _ => plan_err!("Could not coerce into ScalarSubquery!"), } } @@ -3369,6 +3929,14 @@ impl Subquery { outer_ref_columns: self.outer_ref_columns.clone(), } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.outer_ref_columns.iter().fold( + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanSubquery)) + .merge(self.subquery.stats()), + |s, e| s.merge(e.stats()), + ) + } } impl Debug for Subquery { @@ -3488,6 +4056,12 @@ impl PartialOrd for Unnest { } } +impl Unnest { + fn stats(&self) -> LogicalPlanStats { + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::LogicalPlanUnnest)) + } +} + #[cfg(test)] mod tests { @@ -3980,7 +4554,7 @@ digraph { let plan = table_scan(TableReference::none(), &schema, None) .unwrap() .aggregate( - vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("foo")], vec![col("bar")], ]))], @@ -4016,7 +4590,7 @@ digraph { ) .unwrap(), ); - let scan = Arc::new(LogicalPlan::TableScan(TableScan { + let scan = Arc::new(LogicalPlan::table_scan(TableScan { table_name: TableReference::bare("tab"), source: Arc::clone(&source) as Arc, projection: None, @@ -4046,7 +4620,7 @@ digraph { ) .unwrap(), ); - let scan = Arc::new(LogicalPlan::TableScan(TableScan { + let scan = Arc::new(LogicalPlan::table_scan(TableScan { table_name: TableReference::bare("tab"), source, projection: None, @@ -4081,13 +4655,13 @@ digraph { // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs let plan = plan .transform(|plan| match plan { - LogicalPlan::TableScan(table) => { + LogicalPlan::TableScan(table, _) 
=> { let filter = Filter::try_new( external_filter.clone(), - Arc::new(LogicalPlan::TableScan(table)), + Arc::new(LogicalPlan::table_scan(table)), ) .unwrap(); - Ok(Transformed::yes(LogicalPlan::Filter(filter))) + Ok(Transformed::yes(LogicalPlan::filter(filter))) } x => Ok(Transformed::no(x)), }) @@ -4139,12 +4713,12 @@ digraph { #[test] fn test_limit_with_new_children() { - let limit = LogicalPlan::Limit(Limit { + let limit = LogicalPlan::limit(Limit { skip: None, fetch: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), ))), - input: Arc::new(LogicalPlan::Values(Values { + input: Arc::new(LogicalPlan::values(Values { schema: Arc::new(DFSchema::empty()), values: vec![vec![]], })), diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index 26df379f5e4ad..ca301b9711ccc 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -20,6 +20,7 @@ use datafusion_common::{DFSchema, DFSchemaRef}; use std::fmt::{self, Display}; use std::sync::{Arc, OnceLock}; +use crate::logical_plan::tree_node::LogicalPlanStats; use crate::{expr_vec_fmt, Expr, LogicalPlan}; /// Statements have a unchanging empty schema. @@ -130,6 +131,16 @@ impl Statement { } Wrapper(self) } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + match self { + Statement::Prepare(Prepare { input, .. }) => input.stats(), + Statement::Execute(Execute { parameters, .. 
}) => parameters + .iter() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + _ => LogicalPlanStats::empty(), + } + } } /// Indicates if a transaction was committed or aborted diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 6850c30f4f81b..b49337cb21cef 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -53,6 +53,96 @@ use datafusion_common::tree_node::{ TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{internal_err, Result}; +use enumset::{enum_set, EnumSet, EnumSetType}; + +#[derive(EnumSetType, Debug)] +pub enum LogicalPlanNodePattern { + /// [`Expr`] + ExprAlias, + ExprColumn, + ExprScalarVariable, + ExprLiteral, + ExprBinaryExpr, + ExprLike, + ExprSimilarTo, + ExprNot, + ExprIsNotNull, + ExprIsNull, + ExprIsTrue, + ExprIsFalse, + ExprIsUnknown, + ExprIsNotTrue, + ExprIsNotFalse, + ExprIsNotUnknown, + ExprNegative, + ExprGetIndexedField, + ExprBetween, + ExprCase, + ExprCast, + ExprTryCast, + ExprScalarFunction, + ExprAggregateFunction, + ExprWindowFunction, + ExprInList, + ExprExists, + ExprInSubquery, + ExprScalarSubquery, + ExprWildcard, + ExprGroupingSet, + ExprPlaceholder, + ExprOuterReferenceColumn, + ExprUnnest, + + /// [`LogicalPlan`] + LogicalPlanProjection, + LogicalPlanFilter, + LogicalPlanWindow, + LogicalPlanAggregate, + LogicalPlanSort, + LogicalPlanJoin, + LogicalPlanCrossJoin, + LogicalPlanRepartition, + LogicalPlanUnion, + LogicalPlanTableScan, + LogicalPlanEmptyRelation, + LogicalPlanSubquery, + LogicalPlanSubqueryAlias, + LogicalPlanLimit, + LogicalPlanStatement, + LogicalPlanValues, + LogicalPlanExplain, + LogicalPlanAnalyze, + LogicalPlanExtension, + LogicalPlanDistinct, + LogicalPlanDml, + LogicalPlanDdl, + LogicalPlanCopy, + LogicalPlanDescribeTable, + LogicalPlanUnnest, + LogicalPlanRecursiveQuery, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct 
LogicalPlanStats { + patterns: EnumSet, +} + +impl LogicalPlanStats { + pub(crate) fn new(patterns: EnumSet) -> Self { + Self { patterns } + } + + pub(crate) fn empty() -> Self { + Self { + patterns: EnumSet::empty(), + } + } + + pub(crate) fn merge(mut self, other: LogicalPlanStats) -> Self { + self.patterns.insert_all(other.patterns); + self + } +} impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( @@ -75,75 +165,93 @@ impl TreeNode for LogicalPlan { f: F, ) -> Result> { Ok(match self { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Projection(Projection { + LogicalPlan::Projection( + Projection { + expr, + input, + schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::projection(Projection { expr, input, schema, }) }), - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Filter(Filter { + LogicalPlan::Filter( + Filter { + predicate, + input, + having, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::filter(Filter { predicate, input, having, }) }), - LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Repartition(Repartition { + LogicalPlan::Repartition( + Repartition { + input, + partitioning_scheme, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::repartition(Repartition { input, partitioning_scheme, }) }), - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Window(Window { + LogicalPlan::Window( + Window { + input, + window_expr, + schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::window(Window { input, window_expr, schema, }) }), - LogicalPlan::Aggregate(Aggregate { - input, - 
group_expr, - aggr_expr, - schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Aggregate(Aggregate { + LogicalPlan::Aggregate( + Aggregate { + input, + group_expr, + aggr_expr, + schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::aggregate(Aggregate { input, group_expr, aggr_expr, schema, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => input + LogicalPlan::Sort(Sort { expr, input, fetch }, _) => input .map_elements(f)? - .update_data(|input| LogicalPlan::Sort(Sort { expr, input, fetch })), - LogicalPlan::Join(Join { - left, - right, - on, - filter, - join_type, - join_constraint, - schema, - null_equals_null, - }) => (left, right).map_elements(f)?.update_data(|(left, right)| { - LogicalPlan::Join(Join { + .update_data(|input| LogicalPlan::sort(Sort { expr, input, fetch })), + LogicalPlan::Join( + Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }, + _, + ) => (left, right).map_elements(f)?.update_data(|(left, right)| { + LogicalPlan::join(Join { left, right, on, @@ -154,35 +262,43 @@ impl TreeNode for LogicalPlan { null_equals_null, }) }), - LogicalPlan::Limit(Limit { skip, fetch, input }) => input + LogicalPlan::Limit(Limit { skip, fetch, input }, _) => input .map_elements(f)? 
- .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), - LogicalPlan::Subquery(Subquery { - subquery, - outer_ref_columns, - }) => subquery.map_elements(f)?.update_data(|subquery| { - LogicalPlan::Subquery(Subquery { + .update_data(|input| LogicalPlan::limit(Limit { skip, fetch, input })), + LogicalPlan::Subquery( + Subquery { + subquery, + outer_ref_columns, + }, + _, + ) => subquery.map_elements(f)?.update_data(|subquery| { + LogicalPlan::subquery(Subquery { subquery, outer_ref_columns, }) }), - LogicalPlan::SubqueryAlias(SubqueryAlias { - input, - alias, - schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::SubqueryAlias(SubqueryAlias { + LogicalPlan::SubqueryAlias( + SubqueryAlias { + input, + alias, + schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::subquery_alias(SubqueryAlias { input, alias, schema, }) }), - LogicalPlan::Extension(extension) => rewrite_extension_inputs(extension, f)? - .update_data(LogicalPlan::Extension), - LogicalPlan::Union(Union { inputs, schema }) => inputs + LogicalPlan::Extension(extension, _) => { + rewrite_extension_inputs(extension, f)? + .update_data(LogicalPlan::extension) + } + LogicalPlan::Union(Union { inputs, schema }, _) => inputs .map_elements(f)? 
- .update_data(|inputs| LogicalPlan::Union(Union { inputs, schema })), - LogicalPlan::Distinct(distinct) => match distinct { + .update_data(|inputs| LogicalPlan::union(Union { inputs, schema })), + LogicalPlan::Distinct(distinct, _) => match distinct { Distinct::All(input) => input.map_elements(f)?.update_data(Distinct::All), Distinct::On(DistinctOn { on_expr, @@ -200,15 +316,18 @@ impl TreeNode for LogicalPlan { }) }), } - .update_data(LogicalPlan::Distinct), - LogicalPlan::Explain(Explain { - verbose, - plan, - stringified_plans, - schema, - logical_optimization_succeeded, - }) => plan.map_elements(f)?.update_data(|plan| { - LogicalPlan::Explain(Explain { + .update_data(LogicalPlan::distinct), + LogicalPlan::Explain( + Explain { + verbose, + plan, + stringified_plans, + schema, + logical_optimization_succeeded, + }, + _, + ) => plan.map_elements(f)?.update_data(|plan| { + LogicalPlan::explain(Explain { verbose, plan, stringified_plans, @@ -216,25 +335,31 @@ impl TreeNode for LogicalPlan { logical_optimization_succeeded, }) }), - LogicalPlan::Analyze(Analyze { - verbose, - input, - schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Analyze(Analyze { + LogicalPlan::Analyze( + Analyze { + verbose, + input, + schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::analyze(Analyze { verbose, input, schema, }) }), - LogicalPlan::Dml(DmlStatement { - table_name, - table_schema, - op, - input, - output_schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Dml(DmlStatement { + LogicalPlan::Dml( + DmlStatement { + table_name, + table_schema, + op, + input, + output_schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::dml(DmlStatement { table_name, table_schema, op, @@ -242,14 +367,17 @@ impl TreeNode for LogicalPlan { output_schema, }) }), - LogicalPlan::Copy(CopyTo { - input, - output_url, - partition_by, - file_type, - options, - }) => 
input.map_elements(f)?.update_data(|input| { - LogicalPlan::Copy(CopyTo { + LogicalPlan::Copy( + CopyTo { + input, + output_url, + partition_by, + file_type, + options, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::copy(CopyTo { input, output_url, partition_by, @@ -257,7 +385,7 @@ impl TreeNode for LogicalPlan { options, }) }), - LogicalPlan::Ddl(ddl) => { + LogicalPlan::Ddl(ddl, _) => { match ddl { DdlStatement::CreateMemoryTable(CreateMemoryTable { name, @@ -304,18 +432,21 @@ impl TreeNode for LogicalPlan { | DdlStatement::CreateFunction(_) | DdlStatement::DropFunction(_) => Transformed::no(ddl), } - .update_data(LogicalPlan::Ddl) + .update_data(LogicalPlan::ddl) } - LogicalPlan::Unnest(Unnest { - input, - exec_columns: input_columns, - list_type_columns, - struct_type_columns, - dependency_indices, - schema, - options, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Unnest(Unnest { + LogicalPlan::Unnest( + Unnest { + input, + exec_columns: input_columns, + list_type_columns, + struct_type_columns, + dependency_indices, + schema, + options, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::unnest(Unnest { input, exec_columns: input_columns, dependency_indices, @@ -325,14 +456,17 @@ impl TreeNode for LogicalPlan { options, }) }), - LogicalPlan::RecursiveQuery(RecursiveQuery { - name, - static_term, - recursive_term, - is_distinct, - }) => (static_term, recursive_term).map_elements(f)?.update_data( + LogicalPlan::RecursiveQuery( + RecursiveQuery { + name, + static_term, + recursive_term, + is_distinct, + }, + _, + ) => (static_term, recursive_term).map_elements(f)?.update_data( |(static_term, recursive_term)| { - LogicalPlan::RecursiveQuery(RecursiveQuery { + LogicalPlan::recursive_query(RecursiveQuery { name, static_term, recursive_term, @@ -340,14 +474,14 @@ impl TreeNode for LogicalPlan { }) }, ), - LogicalPlan::Statement(stmt) => match stmt { + LogicalPlan::Statement(stmt, _) => match 
stmt { Statement::Prepare(p) => p .input .map_elements(f)? .update_data(|input| Statement::Prepare(Prepare { input, ..p })), _ => Transformed::no(stmt), } - .update_data(LogicalPlan::Statement), + .update_data(LogicalPlan::statement), // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } @@ -403,42 +537,48 @@ impl LogicalPlan { mut f: F, ) -> Result { match self { - LogicalPlan::Projection(Projection { expr, .. }) => expr.apply_elements(f), - LogicalPlan::Values(Values { values, .. }) => values.apply_elements(f), - LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), - LogicalPlan::Repartition(Repartition { - partitioning_scheme, - .. - }) => match partitioning_scheme { + LogicalPlan::Projection(Projection { expr, .. }, _) => expr.apply_elements(f), + LogicalPlan::Values(Values { values, .. }, _) => values.apply_elements(f), + LogicalPlan::Filter(Filter { predicate, .. }, _) => f(predicate), + LogicalPlan::Repartition( + Repartition { + partitioning_scheme, + .. + }, + _, + ) => match partitioning_scheme { Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { expr.apply_elements(f) } Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, - LogicalPlan::Window(Window { window_expr, .. }) => { + LogicalPlan::Window(Window { window_expr, .. }, _) => { window_expr.apply_elements(f) } - LogicalPlan::Aggregate(Aggregate { - group_expr, - aggr_expr, - .. - }) => (group_expr, aggr_expr).apply_ref_elements(f), + LogicalPlan::Aggregate( + Aggregate { + group_expr, + aggr_expr, + .. + }, + _, + ) => (group_expr, aggr_expr).apply_ref_elements(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). - LogicalPlan::Join(Join { on, filter, .. }) => { + LogicalPlan::Join(Join { on, filter, .. 
}, _) => { (on, filter).apply_ref_elements(f) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.apply_elements(f), - LogicalPlan::Extension(extension) => { + LogicalPlan::Sort(Sort { expr, .. }, _) => expr.apply_elements(f), + LogicalPlan::Extension(extension, _) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs extension.node.expressions().apply_elements(f) } - LogicalPlan::TableScan(TableScan { filters, .. }) => { + LogicalPlan::TableScan(TableScan { filters, .. }, _) => { filters.apply_elements(f) } - LogicalPlan::Unnest(unnest) => { + LogicalPlan::Unnest(unnest, _) => { let columns = unnest.exec_columns.clone(); let exprs = columns @@ -447,16 +587,19 @@ impl LogicalPlan { .collect::>(); exprs.apply_elements(f) } - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - .. - })) => (on_expr, select_expr, sort_expr).apply_ref_elements(f), - LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + LogicalPlan::Distinct( + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + }), + _, + ) => (on_expr, select_expr, sort_expr).apply_ref_elements(f), + LogicalPlan::Limit(Limit { skip, fetch, .. }, _) => { (skip, fetch).apply_ref_elements(f) } - LogicalPlan::Statement(stmt) => match stmt { + LogicalPlan::Statement(stmt, _) => match stmt { Statement::Execute(Execute { parameters, .. 
}) => { parameters.apply_elements(f) } @@ -464,16 +607,16 @@ impl LogicalPlan { }, // plans without expressions LogicalPlan::EmptyRelation(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Union(_) - | LogicalPlan::Distinct(Distinct::All(_)) - | LogicalPlan::Dml(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Copy(_) + | LogicalPlan::RecursiveQuery(_, _) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Explain(_, _) + | LogicalPlan::Union(_, _) + | LogicalPlan::Distinct(Distinct::All(_), _) + | LogicalPlan::Dml(_, _) + | LogicalPlan::Ddl(_, _) + | LogicalPlan::Copy(_, _) | LogicalPlan::DescribeTable(_) => Ok(TreeNodeRecursion::Continue), } } @@ -490,35 +633,44 @@ impl LogicalPlan { mut f: F, ) -> Result> { Ok(match self { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - }) => expr.map_elements(f)?.update_data(|expr| { - LogicalPlan::Projection(Projection { + LogicalPlan::Projection( + Projection { + expr, + input, + schema, + }, + _, + ) => expr.map_elements(f)?.update_data(|expr| { + LogicalPlan::projection(Projection { expr, input, schema, }) }), - LogicalPlan::Values(Values { schema, values }) => values + LogicalPlan::Values(Values { schema, values }, _) => values .map_elements(f)? 
- .update_data(|values| LogicalPlan::Values(Values { schema, values })), - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => f(predicate)?.update_data(|predicate| { - LogicalPlan::Filter(Filter { + .update_data(|values| LogicalPlan::values(Values { schema, values })), + LogicalPlan::Filter( + Filter { + predicate, + input, + having, + }, + _, + ) => f(predicate)?.update_data(|predicate| { + LogicalPlan::filter(Filter { predicate, input, having, }) }), - LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - }) => match partitioning_scheme { + LogicalPlan::Repartition( + Repartition { + input, + partitioning_scheme, + }, + _, + ) => match partitioning_scheme { Partitioning::Hash(expr, usize) => expr .map_elements(f)? .update_data(|expr| Partitioning::Hash(expr, usize)), @@ -528,30 +680,36 @@ impl LogicalPlan { Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), } .update_data(|partitioning_scheme| { - LogicalPlan::Repartition(Repartition { + LogicalPlan::repartition(Repartition { input, partitioning_scheme, }) }), - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) => window_expr.map_elements(f)?.update_data(|window_expr| { - LogicalPlan::Window(Window { + LogicalPlan::Window( + Window { + input, + window_expr, + schema, + }, + _, + ) => window_expr.map_elements(f)?.update_data(|window_expr| { + LogicalPlan::window(Window { input, window_expr, schema, }) }), - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - }) => (group_expr, aggr_expr).map_elements(f)?.update_data( + LogicalPlan::Aggregate( + Aggregate { + input, + group_expr, + aggr_expr, + schema, + }, + _, + ) => (group_expr, aggr_expr).map_elements(f)?.update_data( |(group_expr, aggr_expr)| { - LogicalPlan::Aggregate(Aggregate { + LogicalPlan::aggregate(Aggregate { input, group_expr, aggr_expr, @@ -563,17 +721,20 @@ impl LogicalPlan { // There are two part of expression for join, equijoin(on) and 
non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). - LogicalPlan::Join(Join { - left, - right, - on, - filter, - join_type, - join_constraint, - schema, - null_equals_null, - }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { - LogicalPlan::Join(Join { + LogicalPlan::Join( + Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }, + _, + ) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { + LogicalPlan::join(Join { left, right, on, @@ -584,14 +745,14 @@ impl LogicalPlan { null_equals_null, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => expr + LogicalPlan::Sort(Sort { expr, input, fetch }, _) => expr .map_elements(f)? - .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), - LogicalPlan::Extension(Extension { node }) => { + .update_data(|expr| LogicalPlan::sort(Sort { expr, input, fetch })), + LogicalPlan::Extension(Extension { node }, _) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs let exprs = node.expressions().map_elements(f)?; - let plan = LogicalPlan::Extension(Extension { + let plan = LogicalPlan::extension(Extension { node: UserDefinedLogicalNode::with_exprs_and_inputs( node.as_ref(), exprs.data, @@ -600,15 +761,18 @@ impl LogicalPlan { }); Transformed::new(plan, exprs.transformed, exprs.tnr) } - LogicalPlan::TableScan(TableScan { - table_name, - source, - projection, - projected_schema, - filters, - fetch, - }) => filters.map_elements(f)?.update_data(|filters| { - LogicalPlan::TableScan(TableScan { + LogicalPlan::TableScan( + TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }, + _, + ) => filters.map_elements(f)?.update_data(|filters| { + LogicalPlan::table_scan(TableScan { table_name, source, projection, @@ -617,16 +781,19 @@ impl 
LogicalPlan { fetch, }) }), - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - input, - schema, - })) => (on_expr, select_expr, sort_expr) + LogicalPlan::Distinct( + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + }), + _, + ) => (on_expr, select_expr, sort_expr) .map_elements(f)? .update_data(|(on_expr, select_expr, sort_expr)| { - LogicalPlan::Distinct(Distinct::On(DistinctOn { + LogicalPlan::distinct(Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, @@ -634,12 +801,12 @@ impl LogicalPlan { schema, })) }), - LogicalPlan::Limit(Limit { skip, fetch, input }) => { + LogicalPlan::Limit(Limit { skip, fetch, input }, _) => { (skip, fetch).map_elements(f)?.update_data(|(skip, fetch)| { - LogicalPlan::Limit(Limit { skip, fetch, input }) + LogicalPlan::limit(Limit { skip, fetch, input }) }) } - LogicalPlan::Statement(stmt) => match stmt { + LogicalPlan::Statement(stmt, _) => match stmt { Statement::Execute(e) => { e.parameters.map_elements(f)?.update_data(|parameters| { Statement::Execute(Execute { parameters, ..e }) @@ -647,20 +814,20 @@ impl LogicalPlan { } _ => Transformed::no(stmt), } - .update_data(LogicalPlan::Statement), + .update_data(LogicalPlan::statement), // plans without expressions LogicalPlan::EmptyRelation(_) - | LogicalPlan::Unnest(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Union(_) - | LogicalPlan::Distinct(Distinct::All(_)) - | LogicalPlan::Dml(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Copy(_) + | LogicalPlan::Unnest(_, _) + | LogicalPlan::RecursiveQuery(_, _) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Explain(_, _) + | LogicalPlan::Union(_, _) + | LogicalPlan::Distinct(Distinct::All(_), _) + | LogicalPlan::Ddl(_, _) + | LogicalPlan::Dml(_, _) + | 
LogicalPlan::Copy(_, _) | LogicalPlan::DescribeTable(_) => Transformed::no(self), }) } @@ -821,13 +988,13 @@ impl LogicalPlan { ) -> Result { self.apply_expressions(|expr| { expr.apply(|expr| match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { + Expr::Exists(Exists { subquery, .. }, _) + | Expr::InSubquery(InSubquery { subquery, .. }, _) + | Expr::ScalarSubquery(subquery, _) => { // use a synthetic plan so the collector sees a // LogicalPlan::Subquery (even though it is // actually a Subquery alias) - f(&LogicalPlan::Subquery(subquery.clone())) + f(&LogicalPlan::subquery(subquery.clone())) } _ => Ok(TreeNodeRecursion::Continue), }) @@ -844,30 +1011,35 @@ impl LogicalPlan { ) -> Result> { self.map_expressions(|expr| { expr.transform_down(|expr| match expr { - Expr::Exists(Exists { subquery, negated }) => { - f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { - LogicalPlan::Subquery(subquery) => { - Ok(Expr::Exists(Exists { subquery, negated })) + Expr::Exists(Exists { subquery, negated }, _) => { + f(LogicalPlan::subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery, _) => { + Ok(Expr::exists(Exists { subquery, negated })) } _ => internal_err!("Transformation should return Subquery"), }) } - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { - LogicalPlan::Subquery(subquery) => Ok(Expr::InSubquery(InSubquery { + Expr::InSubquery( + InSubquery { expr, subquery, negated, - })), + }, + _, + ) => f(LogicalPlan::subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery, _) => { + Ok(Expr::in_subquery(InSubquery { + expr, + subquery, + negated, + })) + } _ => internal_err!("Transformation should return Subquery"), }), - Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? 
+ Expr::ScalarSubquery(subquery, _) => f(LogicalPlan::subquery(subquery))? .map_data(|s| match s { - LogicalPlan::Subquery(subquery) => { - Ok(Expr::ScalarSubquery(subquery)) + LogicalPlan::Subquery(subquery, _) => { + Ok(Expr::scalar_subquery(subquery)) } _ => internal_err!("Transformation should return Subquery"), }), @@ -875,4 +1047,42 @@ impl LogicalPlan { }) }) } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + match self { + LogicalPlan::Projection(_, stats) => *stats, + LogicalPlan::Filter(_, stats) => *stats, + LogicalPlan::Window(_, stats) => *stats, + LogicalPlan::Aggregate(_, stats) => *stats, + LogicalPlan::Sort(_, stats) => *stats, + LogicalPlan::Join(_, stats) => *stats, + LogicalPlan::Repartition(_, stats) => *stats, + LogicalPlan::Union(_, stats) => *stats, + LogicalPlan::TableScan(_, stats) => *stats, + LogicalPlan::EmptyRelation { .. } => LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanEmptyRelation + )), + LogicalPlan::Subquery(_, stats) => *stats, + LogicalPlan::SubqueryAlias(_, stats) => *stats, + LogicalPlan::Limit(_, stats) => *stats, + LogicalPlan::Statement { .. } => LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanStatement + )), + LogicalPlan::Values(_, stats) => *stats, + LogicalPlan::Explain(_, stats) => *stats, + LogicalPlan::Analyze(_, stats) => *stats, + LogicalPlan::Extension { .. } => LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanExtension + )), + LogicalPlan::Distinct(_, stats) => *stats, + LogicalPlan::Dml(_, stats) => *stats, + LogicalPlan::Ddl(_, stats) => *stats, + LogicalPlan::Copy(_, stats) => *stats, + LogicalPlan::DescribeTable { .. 
} => LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::LogicalPlanDescribeTable + )), + LogicalPlan::Unnest(_, stats) => *stats, + LogicalPlan::RecursiveQuery(_, stats) => *stats, + } + } } diff --git a/datafusion/expr/src/operation.rs b/datafusion/expr/src/operation.rs index 6b79a8248b293..fd532bde9d690 100644 --- a/datafusion/expr/src/operation.rs +++ b/datafusion/expr/src/operation.rs @@ -117,7 +117,7 @@ impl ops::Neg for Expr { type Output = Self; fn neg(self) -> Self::Output { - Expr::Negative(Box::new(self)) + Expr::negative(Box::new(self)) } } @@ -127,33 +127,33 @@ impl Not for Expr { fn not(self) -> Self::Output { match self { - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => Expr::Like(Like::new( - !negated, - expr, - pattern, - escape_char, - case_insensitive, - )), - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => Expr::SimilarTo(Like::new( - !negated, - expr, - pattern, - escape_char, - case_insensitive, - )), - _ => Expr::Not(Box::new(self)), + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + stats, + ) => Expr::Like( + Like::new(!negated, expr, pattern, escape_char, case_insensitive), + stats, + ), + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + stats, + ) => Expr::SimilarTo( + Like::new(!negated, expr, pattern, escape_char, case_insensitive), + stats, + ), + _ => Expr::_not(Box::new(self)), } } } diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index 262aa99e50075..f8cdce45fbfff 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -60,7 +60,7 @@ macro_rules! 
create_func { create_func!(Sum, sum_udaf); pub fn sum(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( sum_udaf(), vec![expr], false, @@ -73,7 +73,7 @@ pub fn sum(expr: Expr) -> Expr { create_func!(Count, count_udaf); pub fn count(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( count_udaf(), vec![expr], false, @@ -86,7 +86,7 @@ pub fn count(expr: Expr) -> Expr { create_func!(Avg, avg_udaf); pub fn avg(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( avg_udaf(), vec![expr], false, @@ -284,7 +284,7 @@ impl AggregateUDFImpl for Count { create_func!(Min, min_udaf); pub fn min(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( min_udaf(), vec![expr], false, @@ -369,7 +369,7 @@ impl AggregateUDFImpl for Min { create_func!(Max, max_udaf); pub fn max(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( max_udaf(), vec![expr], false, diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index eacace5ed0461..afe467e1d300d 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -22,7 +22,9 @@ use crate::expr::{ InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, }; use crate::{Expr, ExprFunctionExt}; +use enumset::enum_set; +use crate::logical_plan::tree_node::{LogicalPlanNodePattern, LogicalPlanStats}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, }; @@ -43,27 +45,27 @@ impl TreeNode for Expr { f: F, ) -> Result { match self { - Expr::Alias(Alias { expr, .. 
}) - | Expr::Unnest(Unnest { expr }) - | Expr::Not(expr) - | Expr::IsNotNull(expr) - | Expr::IsTrue(expr) - | Expr::IsFalse(expr) - | Expr::IsUnknown(expr) - | Expr::IsNotTrue(expr) - | Expr::IsNotFalse(expr) - | Expr::IsNotUnknown(expr) - | Expr::IsNull(expr) - | Expr::Negative(expr) - | Expr::Cast(Cast { expr, .. }) - | Expr::TryCast(TryCast { expr, .. }) - | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f), - Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f), - Expr::ScalarFunction(ScalarFunction { args, .. }) => { + Expr::Alias(Alias { expr, .. }, _) + | Expr::Unnest(Unnest { expr }, _) + | Expr::Not(expr, _) + | Expr::IsNotNull(expr, _) + | Expr::IsTrue(expr, _) + | Expr::IsFalse(expr, _) + | Expr::IsUnknown(expr, _) + | Expr::IsNotTrue(expr, _) + | Expr::IsNotFalse(expr, _) + | Expr::IsNotUnknown(expr, _) + | Expr::IsNull(expr, _) + | Expr::Negative(expr, _) + | Expr::Cast(Cast { expr, .. }, _) + | Expr::TryCast(TryCast { expr, .. }, _) + | Expr::InSubquery(InSubquery { expr, .. }, _) => expr.apply_elements(f), + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) + | Expr::GroupingSet(GroupingSet::Cube(exprs), _) => exprs.apply_elements(f), + Expr::ScalarFunction(ScalarFunction { args, .. }, _) => { args.apply_elements(f) } - Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs), _) => { lists_of_exprs.apply_elements(f) } Expr::Column(_) @@ -72,32 +74,32 @@ impl TreeNode for Expr { | Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Exists { .. } - | Expr::ScalarSubquery(_) + | Expr::ScalarSubquery(_, _) | Expr::Wildcard { .. } | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue), - Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { + Expr::BinaryExpr(BinaryExpr { left, right, .. }, _) => { (left, right).apply_ref_elements(f) } - Expr::Like(Like { expr, pattern, .. 
}) - | Expr::SimilarTo(Like { expr, pattern, .. }) => { + Expr::Like(Like { expr, pattern, .. }, _) + | Expr::SimilarTo(Like { expr, pattern, .. }, _) => { (expr, pattern).apply_ref_elements(f) } Expr::Between(Between { expr, low, high, .. - }) => (expr, low, high).apply_ref_elements(f), - Expr::Case(Case { expr, when_then_expr, else_expr }) => + }, _) => (expr, low, high).apply_ref_elements(f), + Expr::Case(Case { expr, when_then_expr, else_expr }, _) => (expr, when_then_expr, else_expr).apply_ref_elements(f), - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => + Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }, _) => (args, filter, order_by).apply_ref_elements(f), Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. - }) => { + }, _) => { (args, partition_by, order_by).apply_ref_elements(f) } - Expr::InList(InList { expr, list, .. }) => { + Expr::InList(InList { expr, list, .. }, _) => { (expr, list).apply_ref_elements(f) } } @@ -117,40 +119,49 @@ impl TreeNode for Expr { | Expr::Placeholder(Placeholder { .. }) | Expr::OuterReferenceColumn(_, _) | Expr::Exists { .. } - | Expr::ScalarSubquery(_) + | Expr::ScalarSubquery(_, _) | Expr::ScalarVariable(_, _) | Expr::Literal(_) => Transformed::no(self), - Expr::Unnest(Unnest { expr, .. }) => expr + Expr::Unnest(Unnest { expr, .. }, _) => expr .map_elements(f)? 
- .update_data(|expr| Expr::Unnest(Unnest { expr })), - Expr::Alias(Alias { - expr, - relation, - name, - }) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => expr.map_elements(f)?.update_data(|be| { - Expr::InSubquery(InSubquery::new(be, subquery, negated)) + .update_data(|expr| Expr::unnest(Unnest { expr })), + Expr::Alias( + Alias { + expr, + relation, + name, + }, + _, + ) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)), + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated, + }, + _, + ) => expr.map_elements(f)?.update_data(|be| { + Expr::in_subquery(InSubquery::new(be, subquery, negated)) }), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right) + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => (left, right) .map_elements(f)? .update_data(|(new_left, new_right)| { - Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) + Expr::binary_expr(BinaryExpr::new(new_left, op, new_right)) }), - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { (expr, pattern) .map_elements(f)? .update_data(|(new_expr, new_pattern)| { - Expr::Like(Like::new( + Expr::_like(Like::new( negated, new_expr, new_pattern, @@ -159,17 +170,20 @@ impl TreeNode for Expr { )) }) } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { (expr, pattern) .map_elements(f)? 
.update_data(|(new_expr, new_pattern)| { - Expr::SimilarTo(Like::new( + Expr::similar_to(Like::new( negated, new_expr, new_pattern, @@ -178,60 +192,77 @@ impl TreeNode for Expr { )) }) } - Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not), - Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull), - Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull), - Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue), - Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse), - Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown), - Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue), - Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse), - Expr::IsNotUnknown(expr) => { - expr.map_elements(f)?.update_data(Expr::IsNotUnknown) + Expr::Not(expr, _) => expr.map_elements(f)?.update_data(Expr::_not), + Expr::IsNotNull(expr, _) => { + expr.map_elements(f)?.update_data(Expr::_is_not_null) } - Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative), - Expr::Between(Between { - expr, - negated, - low, - high, - }) => (expr, low, high).map_elements(f)?.update_data( + Expr::IsNull(expr, _) => expr.map_elements(f)?.update_data(Expr::_is_null), + Expr::IsTrue(expr, _) => expr.map_elements(f)?.update_data(Expr::_is_true), + Expr::IsFalse(expr, _) => expr.map_elements(f)?.update_data(Expr::_is_false), + Expr::IsUnknown(expr, _) => { + expr.map_elements(f)?.update_data(Expr::_is_unknown) + } + Expr::IsNotTrue(expr, _) => { + expr.map_elements(f)?.update_data(Expr::_is_not_true) + } + Expr::IsNotFalse(expr, _) => { + expr.map_elements(f)?.update_data(Expr::_is_not_false) + } + Expr::IsNotUnknown(expr, _) => { + expr.map_elements(f)?.update_data(Expr::_is_not_unknown) + } + Expr::Negative(expr, _) => expr.map_elements(f)?.update_data(Expr::negative), + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, 
+ ) => (expr, low, high).map_elements(f)?.update_data( |(new_expr, new_low, new_high)| { - Expr::Between(Between::new(new_expr, negated, new_low, new_high)) + Expr::_between(Between::new(new_expr, negated, new_low, new_high)) }, ), - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => (expr, when_then_expr, else_expr) + Expr::Case( + Case { + expr, + when_then_expr, + else_expr, + }, + _, + ) => (expr, when_then_expr, else_expr) .map_elements(f)? .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { - Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) + Expr::case(Case::new(new_expr, new_when_then_expr, new_else_expr)) }), - Expr::Cast(Cast { expr, data_type }) => expr + Expr::Cast(Cast { expr, data_type }, _) => expr .map_elements(f)? - .update_data(|be| Expr::Cast(Cast::new(be, data_type))), - Expr::TryCast(TryCast { expr, data_type }) => expr + .update_data(|be| Expr::cast(Cast::new(be, data_type))), + Expr::TryCast(TryCast { expr, data_type }, _) => expr .map_elements(f)? - .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), - Expr::ScalarFunction(ScalarFunction { func, args }) => { + .update_data(|be| Expr::try_cast(TryCast::new(be, data_type))), + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { args.map_elements(f)?.map_data(|new_args| { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Ok(Expr::scalar_function(ScalarFunction::new_udf( func, new_args, ))) })? 
} - Expr::WindowFunction(WindowFunction { - args, - fun, - partition_by, - order_by, - window_frame, - null_treatment, - }) => (args, partition_by, order_by).map_elements(f)?.update_data( + Expr::WindowFunction( + WindowFunction { + args, + fun, + partition_by, + order_by, + window_frame, + null_treatment, + }, + _, + ) => (args, partition_by, order_by).map_elements(f)?.update_data( |(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) + Expr::window_function(WindowFunction::new(fun, new_args)) .partition_by(new_partition_by) .order_by(new_order_by) .window_frame(window_frame) @@ -240,16 +271,19 @@ impl TreeNode for Expr { .unwrap() }, ), - Expr::AggregateFunction(AggregateFunction { - args, - func, - distinct, - filter, - order_by, - null_treatment, - }) => (args, filter, order_by).map_elements(f)?.map_data( + Expr::AggregateFunction( + AggregateFunction { + args, + func, + distinct, + filter, + order_by, + null_treatment, + }, + _, + ) => (args, filter, order_by).map_elements(f)?.map_data( |(new_args, new_filter, new_order_by)| { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + Ok(Expr::aggregate_function(AggregateFunction::new_udf( func, new_args, distinct, @@ -259,28 +293,79 @@ impl TreeNode for Expr { ))) }, )?, - Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => exprs - .map_elements(f)? - .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), - GroupingSet::Cube(exprs) => exprs - .map_elements(f)? - .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), + Expr::GroupingSet(grouping_set, _) => match grouping_set { + GroupingSet::Rollup(exprs) => { + exprs.map_elements(f)?.update_data(GroupingSet::Rollup) + } + GroupingSet::Cube(exprs) => { + exprs.map_elements(f)?.update_data(GroupingSet::Cube) + } GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs .map_elements(f)? 
- .update_data(|new_lists_of_exprs| { - Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs)) - }), - }, - Expr::InList(InList { - expr, - list, - negated, - }) => (expr, list) + .update_data(GroupingSet::GroupingSets), + } + .update_data(Expr::grouping_set), + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => (expr, list) .map_elements(f)? .update_data(|(new_expr, new_list)| { - Expr::InList(InList::new(new_expr, new_list, negated)) + Expr::_in_list(InList::new(new_expr, new_list, negated)) }), }) } } +impl Expr { + pub fn stats(&self) -> LogicalPlanStats { + match self { + Expr::Alias(_, stats) => *stats, + Expr::Column { .. } => { + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::ExprColumn)) + } + Expr::ScalarVariable { .. } => LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::ExprScalarVariable + )), + Expr::Literal { .. } => { + LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::ExprLiteral)) + } + Expr::BinaryExpr(_, stats) => *stats, + Expr::Like(_, stats) => *stats, + Expr::SimilarTo(_, stats) => *stats, + Expr::Not(_, stats) => *stats, + Expr::IsNotNull(_, stats) => *stats, + Expr::IsNull(_, stats) => *stats, + Expr::IsTrue(_, stats) => *stats, + Expr::IsFalse(_, stats) => *stats, + Expr::IsUnknown(_, stats) => *stats, + Expr::IsNotTrue(_, stats) => *stats, + Expr::IsNotFalse(_, stats) => *stats, + Expr::IsNotUnknown(_, stats) => *stats, + Expr::Negative(_, stats) => *stats, + Expr::Between(_, stats) => *stats, + Expr::Case(_, stats) => *stats, + Expr::Cast(_, stats) => *stats, + Expr::TryCast(_, stats) => *stats, + Expr::ScalarFunction(_, stats) => *stats, + Expr::AggregateFunction(_, stats) => *stats, + Expr::WindowFunction(_, stats) => *stats, + Expr::InList(_, stats) => *stats, + Expr::Exists(_, stats) => *stats, + Expr::InSubquery(_, stats) => *stats, + Expr::ScalarSubquery(_, stats) => *stats, + Expr::Wildcard(_, stats) => *stats, + Expr::GroupingSet(_, stats) => *stats, + Expr::Placeholder(_) => { 
+ LogicalPlanStats::new(enum_set!(LogicalPlanNodePattern::ExprPlaceholder)) + } + Expr::OuterReferenceColumn { .. } => LogicalPlanStats::new(enum_set!( + LogicalPlanNodePattern::ExprOuterReferenceColumn + )), + Expr::Unnest(_, stats) => *stats, + } + } +} diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index dbbf88447ba39..c3aadf8f1e7c6 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -145,7 +145,7 @@ impl AggregateUDF { /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( Arc::new(self.clone()), args, false, diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 1a5d50477b1c8..0c70908327669 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -123,7 +123,7 @@ impl ScalarUDF { /// let expr = my_func.call(vec![col("a"), lit(12.3)]); /// ``` pub fn call(&self, args: Vec) -> Expr { - Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf( + Expr::scalar_function(crate::expr::ScalarFunction::new_udf( Arc::new(self.clone()), args, )) diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 475b864a8a18d..aa7d6f6b3bf7d 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -130,7 +130,7 @@ impl WindowUDF { pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::window_function(WindowFunction::new(fun, args)) } /// Returns this function's name diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index a1ab142fa8355..8d18cb0ae413b 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -61,7 +61,7 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut 
HashSet) -> Result /// Count the number of distinct exprs in a list of group by expressions. If the /// first element is a `GroupingSet` expression then it must be the only expr. pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { - if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { + if let Some(Expr::GroupingSet(grouping_set, _)) = group_expr.first() { if group_expr.len() > 1 { return plan_err!( "Invalid group by expressions, GroupingSet must be the only expression" @@ -201,7 +201,7 @@ fn cross_join_grouping_sets( pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { let has_grouping_set = group_expr .iter() - .any(|expr| matches!(expr, Expr::GroupingSet(_))); + .any(|expr| matches!(expr, Expr::GroupingSet(_, _))); if !has_grouping_set || group_expr.len() == 1 { return Ok(group_expr); } @@ -210,17 +210,17 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { .iter() .map(|expr| { let exprs = match expr { - Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets), _) => { check_grouping_sets_size_limit(grouping_sets.len())?; grouping_sets.iter().map(|e| e.iter().collect()).collect() } - Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => { + Expr::GroupingSet(GroupingSet::Cube(group_exprs), _) => { let grouping_sets = powerset(group_exprs) .map_err(|e| plan_datafusion_err!("{}", e))?; check_grouping_sets_size_limit(grouping_sets.len())?; grouping_sets } - Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => { + Expr::GroupingSet(GroupingSet::Rollup(group_exprs), _) => { let size = group_exprs.len(); let slice = group_exprs.as_slice(); check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?; @@ -247,7 +247,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { }) .unwrap_or_default(); - Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets( + Ok(vec![Expr::grouping_set(GroupingSet::GroupingSets( grouping_sets, ))]) } @@ -255,7 +255,7 @@ 
pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { /// Find all distinct exprs in a list of group by expressions. If the /// first element is a `GroupingSet` expression then it must be the only expr. pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { - if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { + if let Some(Expr::GroupingSet(grouping_set, _)) = group_expr.first() { if group_expr.len() > 1 { return plan_err!( "Invalid group by expressions, GroupingSet must be the only expression" @@ -282,23 +282,23 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds // new Expr types, they will check here as well - Expr::Unnest(_) + Expr::Unnest(_, _) | Expr::ScalarVariable(_, _) - | Expr::Alias(_) + | Expr::Alias(_, _) | Expr::Literal(_) | Expr::BinaryExpr { .. } | Expr::Like { .. } | Expr::SimilarTo { .. } - | Expr::Not(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) - | Expr::Negative(_) + | Expr::Not(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) + | Expr::Negative(_, _) | Expr::Between { .. } | Expr::Case { .. } | Expr::Cast { .. } @@ -306,11 +306,11 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::ScalarFunction(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } - | Expr::GroupingSet(_) + | Expr::GroupingSet(_, _) | Expr::InList { .. } | Expr::Exists { .. } - | Expr::InSubquery(_) - | Expr::ScalarSubquery(_) + | Expr::InSubquery(_, _) + | Expr::ScalarSubquery(_, _) | Expr::Wildcard { .. } | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. 
} => {} @@ -582,7 +582,7 @@ pub fn group_window_expr_by_sort_keys( ) -> Result)>> { let mut result = vec![]; window_expr.into_iter().try_for_each(|expr| match &expr { - Expr::WindowFunction( WindowFunction{ partition_by, order_by, .. }) => { + Expr::WindowFunction( WindowFunction{ partition_by, order_by, .. }, _) => { let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), @@ -703,7 +703,7 @@ pub fn exprlist_to_fields<'a>( let result = exprs .into_iter() .map(|e| match e { - Expr::Wildcard(Wildcard { qualifier, options }) => match qualifier { + Expr::Wildcard(Wildcard { qualifier, options }, _) => match qualifier { None => { let excluded: Vec = get_excluded_columns( options.exclude.as_ref(), @@ -769,18 +769,18 @@ pub fn exprlist_to_fields<'a>( /// If we expand a wildcard expression basing the intermediate plan, we could get some duplicate fields. pub fn find_base_plan(input: &LogicalPlan) -> &LogicalPlan { match input { - LogicalPlan::Window(window) => find_base_plan(&window.input), - LogicalPlan::Aggregate(agg) => find_base_plan(&agg.input), + LogicalPlan::Window(window, _) => find_base_plan(&window.input), + LogicalPlan::Aggregate(agg, _) => find_base_plan(&agg.input), // [SqlToRel::try_process_unnest] will convert Expr(Unnest(Expr)) to Projection/Unnest/Projection // We should expand the wildcard expression based on the input plan of the inner Projection. - LogicalPlan::Unnest(unnest) => { - if let LogicalPlan::Projection(projection) = unnest.input.deref() { + LogicalPlan::Unnest(unnest, _) => { + if let LogicalPlan::Projection(projection, _) = unnest.input.deref() { find_base_plan(&projection.input) } else { input } } - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { if filter.having { // If a filter is used for a having clause, its input plan is an aggregation. 
// We should expand the wildcard expression based on the aggregation's input plan. @@ -802,10 +802,13 @@ pub fn exprlist_len( exprs .iter() .map(|e| match e { - Expr::Wildcard(Wildcard { - qualifier: None, - options, - }) => { + Expr::Wildcard( + Wildcard { + qualifier: None, + options, + }, + _, + ) => { let excluded = get_excluded_columns( options.exclude.as_ref(), options.except.as_ref(), @@ -819,10 +822,13 @@ pub fn exprlist_len( .len(), ) } - Expr::Wildcard(Wildcard { - qualifier: Some(qualifier), - options, - }) => { + Expr::Wildcard( + Wildcard { + qualifier: Some(qualifier), + options, + }, + _, + ) => { let related_wildcard_schema = wildcard_schema.as_ref().map_or_else( || Ok(Arc::clone(schema)), |schema| { @@ -1096,15 +1102,18 @@ pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { + Expr::BinaryExpr( + BinaryExpr { + right, + op: Operator::And, + left, + }, + _, + ) => { let exprs = split_conjunction_impl(left, exprs); split_conjunction_impl(right, exprs) } - Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs), + Expr::Alias(Alias { expr, .. }, _) => split_conjunction_impl(expr, exprs), other => { exprs.push(other); exprs @@ -1120,15 +1129,18 @@ pub fn iter_conjunction(expr: &Expr) -> impl Iterator { std::iter::from_fn(move || { while let Some(expr) = stack.pop() { match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { + Expr::BinaryExpr( + BinaryExpr { + right, + op: Operator::And, + left, + }, + _, + ) => { stack.push(right); stack.push(left); } - Expr::Alias(Alias { expr, .. }) => stack.push(expr), + Expr::Alias(Alias { expr, .. 
}, _) => stack.push(expr), other => return Some(other), } } @@ -1144,15 +1156,18 @@ pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator { std::iter::from_fn(move || { while let Some(expr) = stack.pop() { match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { + Expr::BinaryExpr( + BinaryExpr { + right, + op: Operator::And, + left, + }, + _, + ) => { stack.push(*right); stack.push(*left); } - Expr::Alias(Alias { expr, .. }) => stack.push(*expr), + Expr::Alias(Alias { expr, .. }, _) => stack.push(*expr), other => return Some(other), } } @@ -1217,11 +1232,11 @@ fn split_binary_owned_impl( mut exprs: Vec, ) -> Vec { match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + Expr::BinaryExpr(BinaryExpr { right, op, left }, _) if op == operator => { let exprs = split_binary_owned_impl(*left, operator, exprs); split_binary_owned_impl(*right, operator, exprs) } - Expr::Alias(Alias { expr, .. }) => { + Expr::Alias(Alias { expr, .. }, _) => { split_binary_owned_impl(*expr, operator, exprs) } other => { @@ -1244,11 +1259,11 @@ fn split_binary_impl<'a>( mut exprs: Vec<&'a Expr>, ) -> Vec<&'a Expr> { match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { + Expr::BinaryExpr(BinaryExpr { right, op, left }, _) if *op == operator => { let exprs = split_binary_impl(left, operator, exprs); split_binary_impl(right, operator, exprs) } - Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs), + Expr::Alias(Alias { expr, .. 
}, _) => split_binary_impl(expr, operator, exprs), other => { exprs.push(other); exprs @@ -1331,7 +1346,7 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result) -> Result<(Vec, Vec)> { for filter in exprs.into_iter() { // If the expression contains correlated predicates, add it to join filters if filter.contains_outer() { - if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right)) + if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }, _) if left.eq(right)) { joins.push(strip_outer_reference((*filter).clone())); } @@ -1422,19 +1437,19 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )); @@ -1452,25 +1467,25 @@ mod tests { let age_asc = Sort::new(col("age"), true, true); let name_desc = Sort::new(col("name"), false, true); let created_at_desc = Sort::new(col("created_at"), false, true); - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = 
Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )) @@ -1802,11 +1817,11 @@ mod tests { fn test_collect_expr() -> Result<()> { let mut accum: HashSet = HashSet::new(); expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &Expr::cast(Cast::new(Box::new(col("a")), DataType::Float64)), &mut accum, )?; expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &Expr::cast(Cast::new(Box::new(col("a")), DataType::Float64)), &mut accum, )?; assert_eq!(1, accum.len()); diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 8fdd702b5b7c6..0baecaee1c9f2 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -69,7 +69,7 @@ make_udaf_expr_and_func!( ); pub fn count_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::aggregate_function(datafusion_expr::expr::AggregateFunction::new_udf( count_udaf(), vec![expr], true, diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index ffb5183278e67..aebcfbbe409ac 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -22,7 +22,7 @@ macro_rules! 
make_udaf_expr { pub fn $EXPR_FN( $($arg: datafusion_expr::Expr,)* ) -> datafusion_expr::Expr { - datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + datafusion_expr::Expr::aggregate_function(datafusion_expr::expr::AggregateFunction::new_udf( $AGGREGATE_UDF_FN(), vec![$($arg),*], false, @@ -45,7 +45,7 @@ macro_rules! make_udaf_expr_and_func { pub fn $EXPR_FN( args: Vec, ) -> datafusion_expr::Expr { - datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + datafusion_expr::Expr::aggregate_function(datafusion_expr::expr::AggregateFunction::new_udf( $AGGREGATE_UDF_FN(), args, false, diff --git a/datafusion/functions-nested/src/macros.rs b/datafusion/functions-nested/src/macros.rs index 00247f39ac10f..b561e4ae76aa8 100644 --- a/datafusion/functions-nested/src/macros.rs +++ b/datafusion/functions-nested/src/macros.rs @@ -49,7 +49,7 @@ macro_rules! make_udf_expr_and_func { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { - datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf( + datafusion_expr::Expr::scalar_function(datafusion_expr::expr::ScalarFunction::new_udf( $SCALAR_UDF_FN(), vec![$($arg),*], )) @@ -62,7 +62,7 @@ macro_rules! 
make_udf_expr_and_func { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN(arg: Vec) -> datafusion_expr::Expr { - datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf( + datafusion_expr::Expr::scalar_function(datafusion_expr::expr::ScalarFunction::new_udf( $SCALAR_UDF_FN(), arg, )) diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index 73aad10a8e26e..ac92d8681035c 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -38,7 +38,7 @@ use crate::make_array::make_array; pub fn map(keys: Vec, values: Vec) -> Expr { let keys = make_array(keys); let values = make_array(values); - Expr::ScalarFunction(ScalarFunction::new_udf(map_udf(), vec![keys, values])) + Expr::scalar_function(ScalarFunction::new_udf(map_udf(), vec![keys, values])) } create_func!(MapFunc, map_udf); diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 1929b8222a1b6..a390998f60fc7 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -111,14 +111,14 @@ impl ExprPlanner for NestedFunctionPlanner { let keys = make_array(keys.into_iter().map(|(_, e)| e).collect()); let values = make_array(values.into_iter().map(|(_, e)| e).collect()); - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf(map_udf(), vec![keys, values]), ))) } fn plan_any(&self, expr: RawBinaryExpr) -> Result> { if expr.op == sqlparser::ast::BinaryOperator::Eq { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf( array_has_udf(), // left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)` @@ -150,8 +150,8 @@ impl ExprPlanner for FieldAccessPlanner { GetFieldAccess::ListIndex { key: index } => { match expr { // Special case for 
array_agg(expr)[index] to NTH_VALUE(expr, index) - Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { - Ok(PlannerResult::Planned(Expr::AggregateFunction( + Expr::AggregateFunction(agg_func, _) if is_array_agg(&agg_func) => { + Ok(PlannerResult::Planned(Expr::aggregate_function( datafusion_expr::expr::AggregateFunction::new_udf( nth_value_udaf(), agg_func diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index a3e3feaa17e3d..b3c54fd942b5f 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -125,7 +125,7 @@ impl ScalarUDFImpl for ArrowCastFunc { arg } else { // Use an actual cast to get the correct type - Expr::Cast(datafusion_expr::Cast { + Expr::cast(datafusion_expr::Cast { expr: Box::new(arg), data_type: target_type, }) diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs index 717a74797c0b5..a9e77a20acd69 100644 --- a/datafusion/functions/src/core/planner.rs +++ b/datafusion/functions/src/core/planner.rs @@ -46,7 +46,7 @@ impl ExprPlanner for CoreFunctionPlanner { args: Vec, is_named_struct: bool, ) -> Result>> { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf( if is_named_struct { named_struct() @@ -59,7 +59,7 @@ impl ExprPlanner for CoreFunctionPlanner { } fn plan_overlay(&self, args: Vec) -> Result>> { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf(crate::string::overlay(), args), ))) } @@ -78,7 +78,7 @@ impl ExprPlanner for CoreFunctionPlanner { // Iterate over nested_names and create nested get_field expressions for nested_name in nested_names { let get_field_args = vec![expr, lit(ScalarValue::from(nested_name.clone()))]; - expr = Expr::ScalarFunction(ScalarFunction::new_udf( + expr = Expr::scalar_function(ScalarFunction::new_udf( 
crate::core::get_field(), get_field_args, )); diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 9110f9f532d84..21f9e85d39f7b 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -222,7 +222,7 @@ impl ScalarUDFImpl for LogFunc { &info.get_data_type(&base)?, )?))) } - Expr::ScalarFunction(ScalarFunction { func, mut args }) + Expr::ScalarFunction(ScalarFunction { func, mut args }, _) if is_pow(&func) && args.len() == 2 && base == args[0] => { let b = args.pop().unwrap(); // length checked above diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index a24c613f52599..64cfc38ae0fba 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -155,7 +155,7 @@ impl ScalarUDFImpl for PowerFunc { Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => { Ok(ExprSimplifyResult::Simplified(base)) } - Expr::ScalarFunction(ScalarFunction { func, mut args }) + Expr::ScalarFunction(ScalarFunction { func, mut args }, _) if is_log(&func) && args.len() == 2 && base == args[0] => { let b = args.pop().unwrap(); // length checked above diff --git a/datafusion/functions/src/planner.rs b/datafusion/functions/src/planner.rs index 93edec7ece307..31a1c2d10b0f9 100644 --- a/datafusion/functions/src/planner.rs +++ b/datafusion/functions/src/planner.rs @@ -30,21 +30,21 @@ pub struct UserDefinedFunctionPlanner; impl ExprPlanner for UserDefinedFunctionPlanner { #[cfg(feature = "datetime_expressions")] fn plan_extract(&self, args: Vec) -> Result>> { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf(crate::datetime::date_part(), args), ))) } #[cfg(feature = "unicode_expressions")] fn plan_position(&self, args: Vec) -> Result>> { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( 
ScalarFunction::new_udf(crate::unicode::strpos(), args), ))) } #[cfg(feature = "unicode_expressions")] fn plan_substring(&self, args: Vec) -> Result>> { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf(crate::unicode::substr(), args), ))) } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index d49a2777b4ff8..8e0c86fb13df5 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -361,7 +361,7 @@ pub fn simplify_concat(args: Vec) -> Result { } if !args.eq(&new_args) { - Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + Ok(ExprSimplifyResult::Simplified(Expr::scalar_function( ScalarFunction { func: concat(), args: new_args, diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 98a75f121c35f..d30b5d9580091 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -349,7 +349,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result Result> { plan.map_expressions(|expr| { let original_name = name_preserver.save(&expr); let transformed_expr = expr.transform_up(|expr| match expr { - Expr::WindowFunction(mut window_function) + Expr::WindowFunction(mut window_function, _) if is_count_star_window_aggregate(&window_function) => { window_function.args = vec![lit(COUNT_STAR_EXPANSION)]; - Ok(Transformed::yes(Expr::WindowFunction(window_function))) + Ok(Transformed::yes(Expr::window_function(window_function))) } - Expr::AggregateFunction(mut aggregate_function) + Expr::AggregateFunction(mut aggregate_function, _) if is_count_star_aggregate(&aggregate_function) => { aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)]; - Ok(Transformed::yes(Expr::AggregateFunction( + Ok(Transformed::yes(Expr::aggregate_function( aggregate_function, ))) } @@ -219,7 +219,7 @@ mod tests { let 
table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .window(vec![Expr::WindowFunction(WindowFunction::new( + .window(vec![Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) diff --git a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs index ff9f3df39fd20..70c1633d52c9e 100644 --- a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs @@ -53,25 +53,25 @@ impl AnalyzerRule for ExpandWildcardRule { fn expand_internal(plan: LogicalPlan) -> Result> { match plan { - LogicalPlan::Projection(Projection { expr, input, .. }) => { + LogicalPlan::Projection(Projection { expr, input, .. }, _) => { let projected_expr = expand_exprlist(&input, expr)?; validate_unique_names("Projections", projected_expr.iter())?; Ok(Transformed::yes( Projection::try_new(projected_expr, Arc::clone(&input)) - .map(LogicalPlan::Projection)?, + .map(LogicalPlan::projection)?, )) } // The schema of the plan should also be updated if the child plan is transformed. - LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. 
}, _) => { Ok(Transformed::yes( - SubqueryAlias::try_new(input, alias).map(LogicalPlan::SubqueryAlias)?, + SubqueryAlias::try_new(input, alias).map(LogicalPlan::subquery_alias)?, )) } - LogicalPlan::Distinct(Distinct::On(distinct_on)) => { + LogicalPlan::Distinct(Distinct::On(distinct_on), _) => { let projected_expr = expand_exprlist(&distinct_on.input, distinct_on.select_expr)?; validate_unique_names("Distinct", projected_expr.iter())?; - Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::On( + Ok(Transformed::yes(LogicalPlan::distinct(Distinct::On( DistinctOn::try_new( distinct_on.on_expr, projected_expr, @@ -89,7 +89,7 @@ fn expand_exprlist(input: &LogicalPlan, expr: Vec) -> Result> { let input = find_base_plan(input); for e in expr { match e { - Expr::Wildcard(Wildcard { qualifier, options }) => { + Expr::Wildcard(Wildcard { qualifier, options }, _) => { if let Some(qualifier) = qualifier { let expanded = expand_qualified_wildcard( &qualifier, diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index c6bf14ebce2e3..d7ef7cfab0ed9 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -50,7 +50,7 @@ impl ApplyFunctionRewrites { // resolution only, so order does not matter here let mut schema = merge_schema(&plan.inputs()); - if let LogicalPlan::TableScan(ts) = &plan { + if let LogicalPlan::TableScan(ts, _) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 68edda671a7a7..2916960631bd3 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -57,7 +57,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { // Match only on scans without filter / projection / 
fetch // Views and DataFrames won't have those added // during the early stage of planning. - LogicalPlan::TableScan(table_scan) if table_scan.filters.is_empty() => { + LogicalPlan::TableScan(table_scan, _) if table_scan.filters.is_empty() => { if let Some(sub_plan) = table_scan.source.get_logical_plan() { let sub_plan = sub_plan.into_owned(); let projection_exprs = @@ -71,7 +71,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .build() .map(Transformed::yes) } else { - Ok(Transformed::no(LogicalPlan::TableScan(table_scan))) + Ok(Transformed::no(LogicalPlan::table_scan(table_scan))) } } _ => Ok(Transformed::no(plan)), @@ -93,7 +93,7 @@ fn generate_projection_expr( ))); } } else { - exprs.push(Expr::Wildcard(Wildcard { + exprs.push(Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })); diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index a9fd4900b2f4a..afad2fae2ca81 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -181,9 +181,9 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { // recursively look for subqueries expr.apply(|expr| { match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { + Expr::Exists(Exists { subquery, .. }, _) + | Expr::InSubquery(InSubquery { subquery, .. 
}, _) + | Expr::ScalarSubquery(subquery, _) => { check_subquery_expr(plan, &subquery.subquery, expr)?; } _ => {} diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index 16ebb8cd3972f..5acb154858e39 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -30,7 +30,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue, }; -use datafusion_expr::expr::{AggregateFunction, Alias}; +use datafusion_expr::expr::AggregateFunction; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::utils::grouping_set_to_exprlist; use datafusion_expr::{ @@ -79,7 +79,7 @@ fn replace_grouping_exprs( aggr_expr: Vec, ) -> Result { // Create HashMap from Expr to index in the grouping_id bitmap - let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_, _)]); let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?; let columns = schema.columns(); let mut new_agg_expr = Vec::new(); @@ -97,17 +97,14 @@ fn replace_grouping_exprs( .zip(columns.into_iter().skip(group_expr_len + grouping_id_len)) { match expr { - Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => { + Expr::AggregateFunction(ref function, _) if is_grouping_function(&expr) => { let grouping_expr = grouping_function_on_id( function, &group_expr_to_bitmap_index, is_grouping_set, )?; - projection_exprs.push(Expr::Alias(Alias::new( - grouping_expr, - column.relation, - column.name, - ))); + projection_exprs + .push(grouping_expr.alias_qualified(column.relation, column.name)); } _ => { projection_exprs.push(Expr::Column(column)); @@ -117,9 +114,9 @@ fn replace_grouping_exprs( } // 
Recreate aggregate without grouping functions let new_aggregate = - LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?); + LogicalPlan::aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?); // Create projection with grouping functions calculations - let projection = LogicalPlan::Projection(Projection::try_new( + let projection = LogicalPlan::projection(Projection::try_new( projection_exprs, new_aggregate.into(), )?); @@ -132,13 +129,16 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?; let transformed_plan = transformed_plan.transform_data(|plan| match plan { - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - .. - }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes( + LogicalPlan::Aggregate( + Aggregate { + input, + group_expr, + aggr_expr, + schema, + .. + }, + _, + ) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes( replace_grouping_exprs(input, schema, group_expr, aggr_expr)?, )), _ => Ok(Transformed::no(plan)), @@ -150,7 +150,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { fn is_grouping_function(expr: &Expr) -> bool { // TODO: Do something better than name here should grouping be a built // in expression? - matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping") + matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. 
}, _) if func.name() == "grouping") } fn contains_grouping_function(exprs: &[Expr]) -> bool { diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 0b54b302c2dfc..033b4a6ca3b16 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -38,7 +38,7 @@ pub fn check_subquery_expr( expr: &Expr, ) -> Result<()> { check_plan(inner_plan)?; - if let Expr::ScalarSubquery(subquery) = expr { + if let Expr::ScalarSubquery(subquery, _) = expr { // Scalar subquery should only return one column if subquery.subquery.schema().fields().len() > 1 { return plan_err!( @@ -50,13 +50,13 @@ pub fn check_subquery_expr( // Correlated scalar subquery must be aggregated to return at most one row if !subquery.outer_ref_columns.is_empty() { match strip_inner_query(inner_plan) { - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { check_aggregation_in_scalar_subquery(inner_plan, agg) } - LogicalPlan::Filter(Filter { input, .. }) - if matches!(input.as_ref(), LogicalPlan::Aggregate(_)) => + LogicalPlan::Filter(Filter { input, .. 
}, _) + if matches!(input.as_ref(), LogicalPlan::Aggregate(_, _)) => { - if let LogicalPlan::Aggregate(agg) = input.as_ref() { + if let LogicalPlan::Aggregate(agg, _) = input.as_ref() { check_aggregation_in_scalar_subquery(inner_plan, agg) } else { Ok(()) @@ -77,9 +77,9 @@ pub fn check_subquery_expr( } }?; match outer_plan { - LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) => Ok(()), - LogicalPlan::Aggregate(Aggregate {group_expr, aggr_expr,..}) => { + LogicalPlan::Projection(_, _) + | LogicalPlan::Filter(_, _) => Ok(()), + LogicalPlan::Aggregate(Aggregate {group_expr, aggr_expr,..}, _) => { if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic plan_err!( @@ -96,7 +96,7 @@ pub fn check_subquery_expr( } check_correlations_in_subquery(inner_plan) } else { - if let Expr::InSubquery(subquery) = expr { + if let Expr::InSubquery(subquery, _) = expr { // InSubquery should only return one column if subquery.subquery.subquery.schema().fields().len() > 1 { return plan_err!( @@ -107,11 +107,11 @@ pub fn check_subquery_expr( } } match outer_plan { - LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Window(_) - | LogicalPlan::Aggregate(_) - | LogicalPlan::Join(_) => Ok(()), + LogicalPlan::Projection(_, _) + | LogicalPlan::Filter(_, _) + | LogicalPlan::Window(_, _) + | LogicalPlan::Aggregate(_, _) + | LogicalPlan::Join(_, _) => Ok(()), _ => plan_err!( "In/Exist subquery can only be used in \ Projection, Filter, Window functions, Aggregate and Join plan nodes, \ @@ -136,17 +136,17 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re } // We want to support as many operators as possible inside the correlated subquery match inner_plan { - LogicalPlan::Aggregate(_) => { + LogicalPlan::Aggregate(_, _) => { inner_plan.apply_children(|plan| { check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } - LogicalPlan::Filter(Filter { input, .. 
}) => { + LogicalPlan::Filter(Filter { input, .. }, _) => { check_inner_plan(input, can_contain_outer_ref) } - LogicalPlan::Window(window) => { + LogicalPlan::Window(window, _) => { check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(|plan| { check_inner_plan(plan, can_contain_outer_ref)?; @@ -154,28 +154,31 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re })?; Ok(()) } - LogicalPlan::Projection(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Sort(_) - | LogicalPlan::Union(_) - | LogicalPlan::TableScan(_) + LogicalPlan::Projection(_, _) + | LogicalPlan::Distinct(_, _) + | LogicalPlan::Sort(_, _) + | LogicalPlan::Union(_, _) + | LogicalPlan::TableScan(_, _) | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Values(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) => { + | LogicalPlan::Limit(_, _) + | LogicalPlan::Values(_, _) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::SubqueryAlias(_, _) => { inner_plan.apply_children(|plan| { check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } - LogicalPlan::Join(Join { - left, - right, - join_type, - .. - }) => match join_type { + LogicalPlan::Join( + Join { + left, + right, + join_type, + .. 
+ }, + _, + ) => match join_type { JoinType::Inner => { inner_plan.apply_children(|plan| { check_inner_plan(plan, can_contain_outer_ref)?; @@ -202,7 +205,7 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re Ok(()) } }, - LogicalPlan::Extension(_) => Ok(()), + LogicalPlan::Extension(_, _) => Ok(()), _ => plan_err!("Unsupported operator in the subquery plan."), } } @@ -240,10 +243,10 @@ fn check_aggregation_in_scalar_subquery( fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { match inner_plan { - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { strip_inner_query(projection.input.as_ref()) } - LogicalPlan::SubqueryAlias(alias) => strip_inner_query(alias.input.as_ref()), + LogicalPlan::SubqueryAlias(alias, _) => strip_inner_query(alias.input.as_ref()), other => other, } } @@ -251,7 +254,7 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { let mut exprs = vec![]; inner_plan.apply_with_subqueries(|plan| { - if let LogicalPlan::Filter(Filter { predicate, .. }) = plan { + if let LogicalPlan::Filter(Filter { predicate, .. 
}, _) = plan { let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() .partition(|e| e.contains_outer()); @@ -339,7 +342,7 @@ mod test { #[test] fn wont_fail_extension_plan() { - let plan = LogicalPlan::Extension(Extension { + let plan = LogicalPlan::extension(Extension { node: Arc::new(MockUserDefinedLogicalPlan { empty_schema: DFSchemaRef::new(DFSchema::empty()), }), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index b56c2dc604a9b..fa653d078e1b4 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -114,7 +114,7 @@ fn analyze_internal( // resolution only, so order does not matter here let mut schema = merge_schema(&plan.inputs()); - if let LogicalPlan::TableScan(ts) = &plan { + if let LogicalPlan::TableScan(ts, _) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), @@ -128,9 +128,9 @@ fn analyze_internal( schema.merge(external_schema); // Coerce filter predicates to boolean (handles `WHERE NULL`) - let plan = if let LogicalPlan::Filter(mut filter) = plan { + let plan = if let LogicalPlan::Filter(mut filter, _) = plan { filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?; - LogicalPlan::Filter(filter) + LogicalPlan::filter(filter) } else { plan }; @@ -168,9 +168,9 @@ impl<'a> TypeCoercionRewriter<'a> { /// for type-coercion approach. 
pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result { match plan { - LogicalPlan::Join(join) => self.coerce_join(join), - LogicalPlan::Union(union) => Self::coerce_union(union), - LogicalPlan::Limit(limit) => Self::coerce_limit(limit), + LogicalPlan::Join(join, _) => self.coerce_join(join), + LogicalPlan::Union(union, _) => Self::coerce_union(union), + LogicalPlan::Limit(limit, _) => Self::coerce_limit(limit), _ => Ok(plan), } } @@ -201,7 +201,7 @@ impl<'a> TypeCoercionRewriter<'a> { .map(|expr| self.coerce_join_filter(expr)) .transpose()?; - Ok(LogicalPlan::Join(join)) + Ok(LogicalPlan::join(join)) } /// Coerce the union’s inputs to a common schema compatible with all inputs. @@ -215,7 +215,7 @@ impl<'a> TypeCoercionRewriter<'a> { let plan = coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?; match plan { - LogicalPlan::Projection(Projection { expr, input, .. }) => { + LogicalPlan::Projection(Projection { expr, input, .. }, _) => { Ok(Arc::new(project_with_column_index( expr, input, @@ -226,7 +226,7 @@ impl<'a> TypeCoercionRewriter<'a> { } }) .collect::>>()?; - Ok(LogicalPlan::Union(Union { + Ok(LogicalPlan::union(Union { inputs: new_inputs, schema: union_schema, })) @@ -256,7 +256,7 @@ impl<'a> TypeCoercionRewriter<'a> { .skip .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET")) .transpose()?; - Ok(LogicalPlan::Limit(Limit { + Ok(LogicalPlan::limit(Limit { input: limit.input, fetch: new_fetch.map(Box::new), skip: new_skip.map(Box::new), @@ -295,27 +295,30 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { fn f_up(&mut self, expr: Expr) -> Result> { match expr { - Expr::Unnest(_) => not_impl_err!( + Expr::Unnest(_, _) => not_impl_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" ), - Expr::ScalarSubquery(Subquery { - subquery, - outer_ref_columns, - }) => { + Expr::ScalarSubquery( + Subquery { + subquery, + outer_ref_columns, + }, + _, + ) => { let new_plan = analyze_internal(self.schema, 
Arc::unwrap_or_clone(subquery))?.data; - Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { + Ok(Transformed::yes(Expr::scalar_subquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, }))) } - Expr::Exists(Exists { subquery, negated }) => { + Expr::Exists(Exists { subquery, negated }, _) => { let new_plan = analyze_internal( self.schema, Arc::unwrap_or_clone(subquery.subquery), )? .data; - Ok(Transformed::yes(Expr::Exists(Exists { + Ok(Transformed::yes(Expr::exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, @@ -323,11 +326,14 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { negated, }))) } - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated, + }, + _, + ) => { let new_plan = analyze_internal( self.schema, Arc::unwrap_or_clone(subquery.subquery), @@ -343,41 +349,44 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }; - Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( + Ok(Transformed::yes(Expr::in_subquery(InSubquery::new( Box::new(expr.cast_to(&common_type, self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, )))) } - Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( + Expr::Not(expr, _) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, self.schema, )?))), - Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( + Expr::IsTrue(expr, _) => Ok(Transformed::yes(is_true( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), - Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( + Expr::IsNotTrue(expr, _) => Ok(Transformed::yes(is_not_true( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), - Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( + Expr::IsFalse(expr, _) => Ok(Transformed::yes(is_false( get_casted_expr_for_bool_op(*expr, self.schema)?, 
))), - Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( + Expr::IsNotFalse(expr, _) => Ok(Transformed::yes(is_not_false( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), - Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( + Expr::IsUnknown(expr, _) => Ok(Transformed::yes(is_unknown( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), - Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( + Expr::IsNotUnknown(expr, _) => Ok(Transformed::yes(is_not_unknown( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { let left_type = expr.get_type(self.schema)?; let right_type = pattern.get_type(self.schema)?; let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { @@ -395,7 +404,7 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { _ => Box::new(expr.cast_to(&coerced_type, self.schema)?), }; let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); - Ok(Transformed::yes(Expr::Like(Like::new( + Ok(Transformed::yes(Expr::_like(Like::new( negated, expr, pattern, @@ -403,20 +412,23 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { case_insensitive, )))) } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { let (left, right) = self.coerce_binary_op(*left, op, *right)?; - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( + Ok(Transformed::yes(Expr::binary_expr(BinaryExpr::new( Box::new(left), op, Box::new(right), )))) } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { let expr_type = expr.get_type(self.schema)?; let low_type = low.get_type(self.schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) @@ -439,18 
+451,21 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) })?; - Ok(Transformed::yes(Expr::Between(Between::new( + Ok(Transformed::yes(Expr::_between(Between::new( Box::new(expr.cast_to(&coercion_type, self.schema)?), negated, Box::new(low.cast_to(&coercion_type, self.schema)?), Box::new(high.cast_to(&coercion_type, self.schema)?), )))) } - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { let expr_data_type = expr.get_type(self.schema)?; let list_data_types = list .iter() @@ -471,7 +486,7 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { list_expr.cast_to(&coerced_type, self.schema) }) .collect::>>()?; - Ok(Transformed::yes(Expr::InList(InList ::new( + Ok(Transformed::yes(Expr::_in_list(InList ::new( Box::new(cast_expr), cast_list_expr, negated, @@ -479,34 +494,37 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { } } } - Expr::Case(case) => { + Expr::Case(case, _) => { let case = coerce_case_expression(case, self.schema)?; - Ok(Transformed::yes(Expr::Case(case))) + Ok(Transformed::yes(Expr::case(case))) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { let new_expr = coerce_arguments_for_signature_with_scalar_udf( args, self.schema, &func, )?; - Ok(Transformed::yes(Expr::ScalarFunction( + Ok(Transformed::yes(Expr::scalar_function( ScalarFunction::new_udf(func, new_expr), ))) } - Expr::AggregateFunction(expr::AggregateFunction { - func, - args, - distinct, - filter, - order_by, - null_treatment, - }) => { + Expr::AggregateFunction( + expr::AggregateFunction { + func, + args, + distinct, + filter, + order_by, + null_treatment, + }, + _, + ) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, self.schema, &func, )?; - Ok(Transformed::yes(Expr::AggregateFunction( + 
Ok(Transformed::yes(Expr::aggregate_function( expr::AggregateFunction::new_udf( func, new_expr, @@ -517,14 +535,17 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { ), ))) } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + _, + ) => { let window_frame = coerce_window_frame(window_frame, self.schema, &order_by)?; @@ -540,7 +561,7 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { }; Ok(Transformed::yes( - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::window_function(WindowFunction::new(fun, args)) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) @@ -548,18 +569,18 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { .build()?, )) } - Expr::Alias(_) + Expr::Alias(_, _) | Expr::Column(_) | Expr::ScalarVariable(_, _) | Expr::Literal(_) - | Expr::SimilarTo(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::Negative(_) - | Expr::Cast(_) - | Expr::TryCast(_) + | Expr::SimilarTo(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::Negative(_, _) + | Expr::Cast(_, _) + | Expr::TryCast(_, _) | Expr::Wildcard { .. } - | Expr::GroupingSet(_) + | Expr::GroupingSet(_, _) | Expr::Placeholder(_) | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), } @@ -993,7 +1014,7 @@ fn project_with_column_index( .into_iter() .enumerate() .map(|(i, e)| match e { - Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => { + Expr::Alias(Alias { ref name, .. 
}, _) if name != schema.field(i).name() => { e.unalias().alias(schema.field(i).name()) } Expr::Column(Column { @@ -1006,7 +1027,7 @@ fn project_with_column_index( .collect::>(); Projection::try_new_with_schema(alias_expr, input, schema) - .map(LogicalPlan::Projection) + .map(LogicalPlan::projection) } #[cfg(test)] @@ -1060,7 +1081,7 @@ mod test { fn simple_case() -> Result<()> { let expr = col("a").lt(lit(2_u32)); let empty = empty_with_type(DataType::Float64); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } @@ -1092,7 +1113,7 @@ mod test { // scenario: outermost utf8view projection let expr = col("a"); let empty = empty_with_type(DataType::Utf8View); - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![expr.clone()], Arc::clone(&empty), )?); @@ -1106,7 +1127,7 @@ mod test { // Plan B // scenario: outermost bool projection let bool_expr = col("a").lt(lit("foo")); - let bool_plan = LogicalPlan::Projection(Projection::try_new( + let bool_plan = LogicalPlan::projection(Projection::try_new( vec![bool_expr], Arc::clone(&empty), )?); @@ -1121,7 +1142,7 @@ mod test { // Plan C // scenario: with a non-projection root logical plan node let sort_expr = expr.sort(true, true); - let sort_plan = LogicalPlan::Sort(Sort { + let sort_plan = LogicalPlan::sort(Sort { expr: vec![sort_expr], input: Arc::new(plan), fetch: None, @@ -1136,7 +1157,7 @@ mod test { // Plan D // scenario: two layers of projections with view types - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![col("a")], Arc::new(sort_plan), )?); @@ -1156,7 +1177,7 @@ mod test { // scenario: outermost binaryview projection let expr 
= col("a"); let empty = empty_with_type(DataType::BinaryView); - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![expr.clone()], Arc::clone(&empty), )?); @@ -1170,7 +1191,7 @@ mod test { // Plan B // scenario: outermost bool projection let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1])); - let bool_plan = LogicalPlan::Projection(Projection::try_new( + let bool_plan = LogicalPlan::projection(Projection::try_new( vec![bool_expr], Arc::clone(&empty), )?); @@ -1185,7 +1206,7 @@ mod test { // Plan C // scenario: with a non-projection root logical plan node let sort_expr = expr.sort(true, true); - let sort_plan = LogicalPlan::Sort(Sort { + let sort_plan = LogicalPlan::sort(Sort { expr: vec![sort_expr], input: Arc::new(plan), fetch: None, @@ -1200,7 +1221,7 @@ mod test { // Plan D // scenario: two layers of projections with view types - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![col("a")], Arc::new(sort_plan), )?); @@ -1219,7 +1240,7 @@ mod test { let expr = col("a").lt(lit(2_u32)); let empty = empty_with_type(DataType::Float64); - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![expr.clone().or(expr)], empty, )?); @@ -1263,7 +1284,7 @@ mod test { signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), }) .call(vec![lit(123_i32)]); - let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![udf], empty)?); let expected = "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) @@ -1291,8 +1312,8 @@ mod test { signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), }); let scalar_function_expr = - 
Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr])); - let plan = LogicalPlan::Projection(Projection::try_new( + Expr::scalar_function(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr])); + let plan = LogicalPlan::projection(Projection::try_new( vec![scalar_function_expr], empty, )?); @@ -1312,7 +1333,7 @@ mod test { Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); - let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let udaf = Expr::aggregate_function(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit(10i64)], false, @@ -1320,7 +1341,7 @@ mod test { None, None, )); - let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![udaf], empty)?); let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } @@ -1341,7 +1362,7 @@ mod test { Field::new("avg", DataType::Float64, true), ], )); - let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let udaf = Expr::aggregate_function(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit("10")], false, @@ -1360,7 +1381,7 @@ mod test { #[test] fn agg_function_case() -> Result<()> { let empty = empty(); - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let agg_expr = Expr::aggregate_function(expr::AggregateFunction::new_udf( avg_udaf(), vec![lit(12f64)], false, @@ -1368,12 +1389,12 @@ mod test { None, None, )); - let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![agg_expr], empty)?); let expected = "Projection: avg(Float64(12))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let empty = empty_with_type(DataType::Int32); - let agg_expr = 
Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let agg_expr = Expr::aggregate_function(expr::AggregateFunction::new_udf( avg_udaf(), vec![cast(col("a"), DataType::Float64)], false, @@ -1381,7 +1402,7 @@ mod test { None, None, )); - let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![agg_expr], empty)?); let expected = "Projection: avg(CAST(a AS Float64))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) @@ -1390,7 +1411,7 @@ mod test { #[test] fn agg_function_invalid_input_avg() -> Result<()> { let empty = empty(); - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let agg_expr = Expr::aggregate_function(expr::AggregateFunction::new_udf( avg_udaf(), vec![lit("1")], false, @@ -1412,7 +1433,7 @@ mod test { let expr = cast(lit("1998-03-18"), DataType::Date32) + lit(ScalarValue::new_interval_dt(123, 456)); let empty = empty(); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"IntervalDayTime { days: 123, milliseconds: 456 }\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1424,7 +1445,7 @@ mod test { // a in (1,4,8), a is int64 let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); let empty = empty_with_type(DataType::Int64); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1437,7 +1458,7 @@ mod test { std::collections::HashMap::new(), )?), 
})); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } @@ -1451,7 +1472,7 @@ mod test { + lit(ScalarValue::new_interval_ym(0, 1)), ); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); + let plan = LogicalPlan::filter(Filter::try_new(expr, empty)?); let expected = "Filter: a BETWEEN Utf8(\"2002-05-08\") AND CAST(CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AS Utf8)\ \n EmptyRelation"; @@ -1467,7 +1488,7 @@ mod test { lit("2002-12-08"), ); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); + let plan = LogicalPlan::filter(Filter::try_new(expr, empty)?); // TODO: we should cast col(a). 
let expected = "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AND CAST(Utf8(\"2002-12-08\") AS Date32)\ @@ -1481,12 +1502,12 @@ mod test { let expr = col("a").is_true(); let empty = empty_with_type(DataType::Boolean); let plan = - LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); + LogicalPlan::projection(Projection::try_new(vec![expr.clone()], empty)?); let expected = "Projection: a IS TRUE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let empty = empty_with_type(DataType::Int64); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, ""); let err = ret.unwrap_err().to_string(); assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}"); @@ -1494,21 +1515,21 @@ mod test { // is not true let expr = col("a").is_not_true(); let empty = empty_with_type(DataType::Boolean); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT TRUE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; // is false let expr = col("a").is_false(); let empty = empty_with_type(DataType::Boolean); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS FALSE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; // is not false let expr = col("a").is_not_false(); let empty = empty_with_type(DataType::Boolean); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let 
plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT FALSE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1520,25 +1541,25 @@ mod test { // like : utf8 like "abc" let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); + let like_expr = Expr::_like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); - let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); + let like_expr = Expr::_like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); + let like_expr = Expr::_like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(DataType::Int64); - let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![like_expr], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); assert!(err.is_err()); 
assert!(err.unwrap_err().to_string().contains( @@ -1548,25 +1569,25 @@ mod test { // ilike let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); + let ilike_expr = Expr::_like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); - let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); + let ilike_expr = Expr::_like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE CAST(NULL AS Utf8)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); + let ilike_expr = Expr::_like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(DataType::Int64); - let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![ilike_expr], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); assert!(err.is_err()); assert!(err.unwrap_err().to_string().contains( @@ -1581,12 +1602,12 @@ mod test { let expr = col("a").is_unknown(); let empty = empty_with_type(DataType::Boolean); let plan 
= - LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); + LogicalPlan::projection(Projection::try_new(vec![expr.clone()], empty)?); let expected = "Projection: a IS UNKNOWN\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); let err = ret.unwrap_err().to_string(); assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}"); @@ -1594,7 +1615,7 @@ mod test { // is not unknown let expr = col("a").is_not_unknown(); let empty = empty_with_type(DataType::Boolean); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT UNKNOWN\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1612,7 +1633,7 @@ mod test { signature: Signature::variadic(vec![Utf8], Volatility::Immutable), }) .call(args.to_vec()); - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![expr], Arc::clone(&empty), )?); @@ -1670,7 +1691,7 @@ mod test { ) .eq(cast(lit("1998-03-18"), DataType::Date32)); let empty = empty(); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8(\"1998-03-18\") AS Date32) AS Timestamp(Nanosecond, None))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -2007,7 +2028,7 @@ mod test { #[test] fn 
interval_plus_timestamp() -> Result<()> { // SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp; - let expr = Expr::BinaryExpr(BinaryExpr::new( + let expr = Expr::binary_expr(BinaryExpr::new( Box::new(lit(ScalarValue::IntervalYearMonth(Some(12)))), Operator::Plus, Box::new(cast( @@ -2016,7 +2037,7 @@ mod test { )), )); let empty = empty(); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: IntervalYearMonth(\"12\") + CAST(Utf8(\"2000-01-01T00:00:00\") AS Timestamp(Nanosecond, None))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) @@ -2024,7 +2045,7 @@ mod test { #[test] fn timestamp_subtract_timestamp() -> Result<()> { - let expr = Expr::BinaryExpr(BinaryExpr::new( + let expr = Expr::binary_expr(BinaryExpr::new( Box::new(cast( lit("1998-03-18"), DataType::Timestamp(TimeUnit::Nanosecond, None), @@ -2036,7 +2057,7 @@ mod test { )), )); let empty = empty(); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) - CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -2048,7 +2069,7 @@ mod test { let empty_int32 = empty_with_type(DataType::Int32); let empty_int64 = empty_with_type(DataType::Int64); - let in_subquery_expr = Expr::InSubquery(InSubquery::new( + let in_subquery_expr = Expr::in_subquery(InSubquery::new( Box::new(col("a")), Subquery { subquery: empty_int32, @@ -2056,7 +2077,7 @@ mod test { }, false, )); - let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?); + let plan = LogicalPlan::filter(Filter::try_new(in_subquery_expr, empty_int64)?); // add cast for 
subquery let expected = "\ Filter: a IN ()\ @@ -2073,7 +2094,7 @@ mod test { let empty_int32 = empty_with_type(DataType::Int32); let empty_int64 = empty_with_type(DataType::Int64); - let in_subquery_expr = Expr::InSubquery(InSubquery::new( + let in_subquery_expr = Expr::in_subquery(InSubquery::new( Box::new(col("a")), Subquery { subquery: empty_int64, @@ -2081,7 +2102,7 @@ mod test { }, false, )); - let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?); + let plan = LogicalPlan::filter(Filter::try_new(in_subquery_expr, empty_int32)?); // add cast for subquery let expected = "\ Filter: CAST(a AS Int64) IN ()\ @@ -2097,7 +2118,7 @@ mod test { let empty_inside = empty_with_type(DataType::Decimal128(10, 5)); let empty_outside = empty_with_type(DataType::Decimal128(8, 8)); - let in_subquery_expr = Expr::InSubquery(InSubquery::new( + let in_subquery_expr = Expr::in_subquery(InSubquery::new( Box::new(col("a")), Subquery { subquery: empty_inside, @@ -2105,7 +2126,7 @@ mod test { }, false, )); - let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?); + let plan = LogicalPlan::filter(Filter::try_new(in_subquery_expr, empty_outside)?); // add cast for subquery let expected = "Filter: CAST(a AS Decimal128(13, 8)) IN ()\ \n Subquery:\ diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 16a4fa6be38d0..ce0757b7663e8 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -88,7 +88,7 @@ impl CommonSubexprEliminate { self.try_unary_plan(expr, input, config)? .map_data(|(new_expr, new_input)| { Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema) - .map(LogicalPlan::Projection) + .map(LogicalPlan::projection) }) } @@ -106,7 +106,7 @@ impl CommonSubexprEliminate { let new_sort = self .try_unary_plan(sort_expressions, input, config)? 
.update_data(|(new_expr, new_input)| { - LogicalPlan::Sort(Sort { + LogicalPlan::sort(Sort { expr: new_expr .into_iter() .zip(sort_params) @@ -138,7 +138,7 @@ impl CommonSubexprEliminate { assert_eq!(new_expr.len(), 1); // passed in vec![predicate] let new_predicate = new_expr.pop().unwrap(); Filter::try_new(new_predicate, Arc::new(new_input)) - .map(LogicalPlan::Filter) + .map(LogicalPlan::filter) }) } @@ -213,7 +213,7 @@ impl CommonSubexprEliminate { }) .collect::>(); Window::try_new(new_window_expr, Arc::new(plan)) - .map(LogicalPlan::Window) + .map(LogicalPlan::window) }, ) } else { @@ -226,7 +226,7 @@ impl CommonSubexprEliminate { Arc::new(plan), schema, ) - .map(LogicalPlan::Window) + .map(LogicalPlan::window) }) } }) @@ -331,7 +331,7 @@ impl CommonSubexprEliminate { rewritten_aggr_expr.into_iter().zip(new_aggr_expr) { if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. }) = + if let Expr::Alias(Alias { expr, name, .. }, _) = expr_rewritten { agg_exprs.push(expr.alias(&name)); @@ -356,13 +356,13 @@ impl CommonSubexprEliminate { } } - let agg = LogicalPlan::Aggregate(Aggregate::try_new( + let agg = LogicalPlan::aggregate(Aggregate::try_new( new_input, new_group_expr, agg_exprs, )?); Projection::try_new(proj_exprs, Arc::new(agg)) - .map(|p| Transformed::yes(LogicalPlan::Projection(p))) + .map(|p| Transformed::yes(LogicalPlan::projection(p))) } // If there aren't any common aggregate sub-expressions, then just @@ -399,7 +399,7 @@ impl CommonSubexprEliminate { // Since `group_expr` may have changed, schema may also. // Use `try_new()` method. 
Aggregate::try_new(new_input, new_group_expr, new_aggr_expr) - .map(LogicalPlan::Aggregate) + .map(LogicalPlan::aggregate) .map(Transformed::no) } else { Aggregate::try_new_with_schema( @@ -408,7 +408,7 @@ impl CommonSubexprEliminate { rewritten_aggr_expr, schema, ) - .map(LogicalPlan::Aggregate) + .map(LogicalPlan::aggregate) .map(Transformed::no) } } @@ -505,12 +505,15 @@ fn get_consecutive_window_exprs( ) -> (Vec>, Vec, LogicalPlan) { let mut window_expr_list = vec![]; let mut window_schemas = vec![]; - let mut plan = LogicalPlan::Window(window); - while let LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) = plan + let mut plan = LogicalPlan::window(window); + while let LogicalPlan::Window( + Window { + input, + window_expr, + schema, + }, + _, + ) = plan { window_expr_list.push(window_expr); window_schemas.push(schema); @@ -541,31 +544,31 @@ impl OptimizerRule for CommonSubexprEliminate { let original_schema = Arc::clone(plan.schema()); let optimized_plan = match plan { - LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?, - LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?, - LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?, - LogicalPlan::Window(window) => self.try_optimize_window(window, config)?, - LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?, - LogicalPlan::Join(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Union(_) - | LogicalPlan::TableScan(_) - | LogicalPlan::Values(_) + LogicalPlan::Projection(proj, _) => self.try_optimize_proj(proj, config)?, + LogicalPlan::Sort(sort, _) => self.try_optimize_sort(sort, config)?, + LogicalPlan::Filter(filter, _) => self.try_optimize_filter(filter, config)?, + LogicalPlan::Window(window, _) => self.try_optimize_window(window, config)?, + LogicalPlan::Aggregate(agg, _) => self.try_optimize_aggregate(agg, config)?, + LogicalPlan::Join(_, _) + | LogicalPlan::Repartition(_, _) + | LogicalPlan::Union(_, _) + | 
LogicalPlan::TableScan(_, _) + | LogicalPlan::Values(_, _) | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Statement(_) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Limit(_, _) + | LogicalPlan::Ddl(_, _) + | LogicalPlan::Explain(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Statement(_, _) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Extension(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Copy(_) - | LogicalPlan::Unnest(_) - | LogicalPlan::RecursiveQuery(_) => { + | LogicalPlan::Distinct(_, _) + | LogicalPlan::Extension(_, _) + | LogicalPlan::Dml(_, _) + | LogicalPlan::Copy(_, _) + | LogicalPlan::Unnest(_, _) + | LogicalPlan::RecursiveQuery(_, _) => { // This rule handles recursion itself in a `ApplyOrder::TopDown` like // manner. plan.map_children(|c| self.rewrite(c, config))? @@ -631,7 +634,7 @@ impl CSEController for ExprCSEController<'_> { // In case of `ScalarFunction`s we don't know which children are surely // executed so start visiting all children conditionally and stop the // recursion with `TreeNodeRecursion::Jump`. - Expr::ScalarFunction(ScalarFunction { func, args }) + Expr::ScalarFunction(ScalarFunction { func, args }, _) if func.short_circuits() => { Some((vec![], args.iter().collect())) @@ -639,20 +642,26 @@ impl CSEController for ExprCSEController<'_> { // In case of `And` and `Or` the first child is surely executed, but we // account subexpressions as conditional in the second. 
- Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::And | Operator::Or, - right, - }) => Some((vec![left.as_ref()], vec![right.as_ref()])), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::And | Operator::Or, + right, + }, + _, + ) => Some((vec![left.as_ref()], vec![right.as_ref()])), // In case of `Case` the optional base expression and the first when // expressions are surely executed, but we account subexpressions as // conditional in the others. - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => Some(( + Expr::Case( + Case { + expr, + when_then_expr, + else_expr, + }, + _, + ) => Some(( expr.iter() .map(|e| e.as_ref()) .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref())) @@ -711,12 +720,12 @@ impl CSEController for ExprCSEController<'_> { } fn rewrite_f_down(&mut self, node: &Expr) { - if matches!(node, Expr::Alias(_)) { + if matches!(node, Expr::Alias(_, _)) { self.alias_counter += 1; } } fn rewrite_f_up(&mut self, node: &Expr) { - if matches!(node, Expr::Alias(_)) { + if matches!(node, Expr::Alias(_, _)) { self.alias_counter -= 1 } } @@ -757,7 +766,7 @@ fn build_common_expr_project_plan( } } - Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection) + Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::projection) } /// Build the projection plan to eliminate unnecessary columns produced by @@ -770,11 +779,11 @@ fn build_recover_project_plan( input: LogicalPlan, ) -> Result { let col_exprs = schema.iter().map(Expr::from).collect(); - Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection) + Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::projection) } fn extract_expressions(expr: &Expr, result: &mut Vec) { - if let Expr::GroupingSet(groupings) = expr { + if let Expr::GroupingSet(groupings, _) = expr { for e in groupings.distinct_expr() { let (qualifier, field_name) = e.qualified_name(); let col = Column::new(qualifier, field_name); @@ 
-878,7 +887,7 @@ mod test { let return_type = DataType::UInt32; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); let udf_agg = |inner: Expr| { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::aggregate_function(datafusion_expr::expr::AggregateFunction::new_udf( Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature( "my_agg", Signature::exact(vec![DataType::UInt32], Volatility::Stable), diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b5726d9991379..ae07813c13ef3 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -123,8 +123,10 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { fn f_down(&mut self, plan: LogicalPlan) -> Result> { match plan { - LogicalPlan::Filter(_) => Ok(Transformed::no(plan)), - LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { + LogicalPlan::Filter(_, _) => Ok(Transformed::no(plan)), + LogicalPlan::Union(_, _) + | LogicalPlan::Sort(_, _) + | LogicalPlan::Extension(_, _) => { let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); if plan_hold_outer { // the unsupported case @@ -134,7 +136,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { Ok(Transformed::no(plan)) } } - LogicalPlan::Limit(_) => { + LogicalPlan::Limit(_, _) => { let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); match (self.exists_sub_query, plan_hold_outer) { (false, true) => { @@ -157,7 +159,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { fn f_up(&mut self, plan: LogicalPlan) -> Result> { let subquery_schema = plan.schema(); match &plan { - LogicalPlan::Filter(plan_filter) => { + LogicalPlan::Filter(plan_filter, _) => { let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); self.can_pull_over_aggregation = self.can_pull_over_aggregation && subquery_filter_exprs @@ -224,7 +226,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { } } } 
- LogicalPlan::Projection(projection) + LogicalPlan::Projection(projection, _) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => { let mut local_correlated_cols = BTreeSet::new(); @@ -266,7 +268,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { } Ok(Transformed::yes(new_plan)) } - LogicalPlan::Aggregate(aggregate) + LogicalPlan::Aggregate(aggregate, _) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => { // If the aggregation is from a distinct it will not change the result for @@ -314,7 +316,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { } Ok(Transformed::yes(new_plan)) } - LogicalPlan::SubqueryAlias(alias) => { + LogicalPlan::SubqueryAlias(alias, _) => { let mut local_correlated_cols = BTreeSet::new(); collect_local_correlated_cols( &plan, @@ -336,7 +338,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { } Ok(Transformed::no(plan)) } - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { let input_expr_map = self .collected_count_expr_map .get(limit.input.deref()) @@ -402,21 +404,24 @@ impl PullUpCorrelatedExpr { } fn can_pullup_over_aggregation(expr: &Expr) -> bool { - if let Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) = expr + if let Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + ) = expr { match (left.deref(), right.deref()) { (Expr::Column(_), right) => !right.any_column_refs(), (left, Expr::Column(_)) => !left.any_column_refs(), - (Expr::Cast(Cast { expr, .. }), right) + (Expr::Cast(Cast { expr, .. }, _), right) if matches!(expr.deref(), Expr::Column(_)) => { !right.any_column_refs() } - (left, Expr::Cast(Cast { expr, .. })) + (left, Expr::Cast(Cast { expr, .. 
}, _)) if matches!(expr.deref(), Expr::Column(_)) => { !left.any_column_refs() @@ -438,7 +443,7 @@ fn collect_local_correlated_cols( local_cols.extend(cols.clone()); } // SubqueryAlias is treated as the leaf node - if !matches!(child, LogicalPlan::SubqueryAlias(_)) { + if !matches!(child, LogicalPlan::SubqueryAlias(_, _)) { collect_local_correlated_cols(child, all_cols_map, local_cols); } } @@ -454,7 +459,7 @@ fn remove_duplicated_filter(filters: Vec, in_predicate: &Expr) -> Vec { + (Expr::BinaryExpr(a_expr, _), Expr::BinaryExpr(b_expr, _)) => { (a_expr.op == b_expr.op) && (a_expr.left == b_expr.left && a_expr.right == b_expr.right) || (a_expr.left == b_expr.right && a_expr.right == b_expr.left) @@ -475,7 +480,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( .clone() .transform_up(|expr| { let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { + Expr::AggregateFunction(expr::AggregateFunction { func, .. }, _) => { if func.name() == "count" { Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0)))) } else { @@ -531,7 +536,7 @@ fn proj_exprs_evaluation_result_on_empty_batch( let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; let expr_name = match expr { - Expr::Alias(Alias { name, .. }) => name.to_string(), + Expr::Alias(Alias { name, .. 
}, _) => name.to_string(), Expr::Column(Column { relation: _, name }) => name.to_string(), _ => expr.schema_name().to_string(), }; @@ -581,7 +586,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( // can not evaluate statically _ => { for input_expr in input_expr_result_map_for_count_bug.values() { - let new_expr = Expr::Case(expr::Case { + let new_expr = Expr::case(expr::Case { expr: None, when_then_expr: vec![( Box::new(result_expr.clone()), diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 7fdad5ba4b6e9..403834cee41bd 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -66,12 +66,12 @@ impl OptimizerRule for DecorrelatePredicateSubquery { })? .data; - let LogicalPlan::Filter(filter) = plan else { + let LogicalPlan::Filter(filter, _) = plan else { return Ok(Transformed::no(plan)); }; if !has_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = @@ -111,7 +111,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery { let expr = conjunction(other_exprs); if let Some(expr) = expr { let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; - cur_input = LogicalPlan::Filter(new_filter); + cur_input = LogicalPlan::filter(new_filter); } Ok(Transformed::yes(cur_input)) } @@ -133,10 +133,13 @@ fn rewrite_inner_subqueries( let mut cur_input = outer; let alias = config.alias_generator(); let expr_without_subqueries = expr.transform(|e| match e { - Expr::Exists(Exists { - subquery: Subquery { subquery, .. }, - negated, - }) => match mark_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? { + Expr::Exists( + Exists { + subquery: Subquery { subquery, .. 
}, + negated, + }, + _, + ) => match mark_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? { Some((plan, exists_expr)) => { cur_input = plan; Ok(Transformed::yes(exists_expr)) @@ -144,11 +147,14 @@ fn rewrite_inner_subqueries( None if negated => Ok(Transformed::no(not_exists(subquery))), None => Ok(Transformed::no(exists(subquery))), }, - Expr::InSubquery(InSubquery { - expr, - subquery: Subquery { subquery, .. }, - negated, - }) => { + Expr::InSubquery( + InSubquery { + expr, + subquery: Subquery { subquery, .. }, + negated, + }, + _, + ) => { let in_predicate = subquery .head_output_expr()? .map_or(plan_err!("single expression required."), |output_expr| { @@ -185,27 +191,33 @@ enum SubqueryPredicate { fn extract_subquery_info(expr: Expr) -> SubqueryPredicate { match expr { - Expr::Not(not_expr) => match *not_expr { - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + Expr::Not(not_expr, _) => match *not_expr { + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated, + }, + _, + ) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( subquery, *expr, !negated, )), - Expr::Exists(Exists { subquery, negated }) => { + Expr::Exists(Exists { subquery, negated }, _) => { SubqueryPredicate::Top(SubqueryInfo::new(subquery, !negated)) } expr => SubqueryPredicate::Embedded(not(expr)), }, - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated, + }, + _, + ) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( subquery, *expr, negated, )), - Expr::Exists(Exists { subquery, negated }) => { + Expr::Exists(Exists { subquery, negated }, _) => { SubqueryPredicate::Top(SubqueryInfo::new(subquery, negated)) } expr => SubqueryPredicate::Embedded(expr), @@ -214,7 +226,7 @@ fn extract_subquery_info(expr: Expr) -> SubqueryPredicate { fn 
has_subquery(expr: &Expr) -> bool { expr.exists(|e| match e { - Expr::InSubquery(_) | Expr::Exists(_) => Ok(true), + Expr::InSubquery(_, _) | Expr::Exists(_, _) => Ok(true), _ => Ok(false), }) .unwrap() @@ -345,11 +357,14 @@ fn build_join( if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) { ( Some(join_filter), - Some(Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - })), + Some(Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + )), ) => { let right_col = create_col_from_scalar_expr(right.deref(), alias)?; let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); @@ -358,11 +373,14 @@ fn build_join( (Some(join_filter), _) => Some(join_filter), ( _, - Some(Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - })), + Some(Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + )), ) => { let right_col = create_col_from_scalar_expr(right.deref(), alias)?; let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 32b7ce44a63a5..c5715d639ca9f 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -91,24 +91,27 @@ impl OptimizerRule for EliminateCrossJoin { let mut all_inputs: Vec = vec![]; let mut all_filters: Vec = vec![]; - let parent_predicate = if let LogicalPlan::Filter(filter) = plan { + let parent_predicate = if let LogicalPlan::Filter(filter, _) = plan { // if input isn't a join that can potentially be rewritten // avoid unwrapping the input let rewriteable = matches!( filter.input.as_ref(), - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) + LogicalPlan::Join( + Join { + join_type: JoinType::Inner, + .. 
+ }, + _ + ) ); if !rewriteable { // recursively try to rewrite children - return rewrite_children(self, LogicalPlan::Filter(filter), config); + return rewrite_children(self, LogicalPlan::filter(filter), config); } if !can_flatten_join_inputs(&filter.input) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } let Filter { @@ -125,10 +128,13 @@ impl OptimizerRule for EliminateCrossJoin { Some(predicate) } else if matches!( plan, - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) + LogicalPlan::Join( + Join { + join_type: JoinType::Inner, + .. + }, + _ + ) ) { if !can_flatten_join_inputs(&plan) { return Ok(Transformed::no(plan)); @@ -160,7 +166,7 @@ impl OptimizerRule for EliminateCrossJoin { left = rewrite_children(self, left, config)?.data; if &plan_schema != left.schema() { - left = LogicalPlan::Projection(Projection::new_from_schema( + left = LogicalPlan::projection(Projection::new_from_schema( Arc::new(left), Arc::clone(&plan_schema), )); @@ -170,7 +176,7 @@ impl OptimizerRule for EliminateCrossJoin { // Add any filters on top - PushDownFilter can push filters down to applicable join let first = all_filters.swap_remove(0); let predicate = all_filters.into_iter().fold(first, and); - left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?); + left = LogicalPlan::filter(Filter::try_new(predicate, Arc::new(left))?); } let Some(predicate) = parent_predicate else { @@ -180,12 +186,12 @@ impl OptimizerRule for EliminateCrossJoin { // If there are no join keys then do nothing: if all_join_keys.is_empty() { Filter::try_new(predicate, Arc::new(left)) - .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))) + .map(|filter| Transformed::yes(LogicalPlan::filter(filter))) } else { // Remove join expressions from filter: match remove_join_expressions(predicate, &all_join_keys) { Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) - .map(|filter| 
Transformed::yes(LogicalPlan::Filter(filter))), + .map(|filter| Transformed::yes(LogicalPlan::filter(filter))), _ => Ok(Transformed::yes(left)), } } @@ -224,7 +230,7 @@ fn flatten_join_inputs( all_filters: &mut Vec, ) -> Result<()> { match plan { - LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { + LogicalPlan::Join(join, _) if join.join_type == JoinType::Inner => { if let Some(filter) = join.filter { all_filters.push(filter); } @@ -256,15 +262,18 @@ fn flatten_join_inputs( fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { // can only flatten inner / cross joins match plan { - LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {} + LogicalPlan::Join(join, _) if join.join_type == JoinType::Inner => {} _ => return false, }; for child in plan.inputs() { - if let LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) = child + if let LogicalPlan::Join( + Join { + join_type: JoinType::Inner, + .. + }, + _, + ) = child { if !can_flatten_join_inputs(child) { return false; @@ -321,7 +330,7 @@ fn find_inner_join( &JoinType::Inner, )?); - return Ok(LogicalPlan::Join(Join { + return Ok(LogicalPlan::join(Join { left: Arc::new(left_input), right: Arc::new(right_input), join_type: JoinType::Inner, @@ -343,7 +352,7 @@ fn find_inner_join( &JoinType::Inner, )?); - Ok(LogicalPlan::Join(Join { + Ok(LogicalPlan::join(Join { left: Arc::new(left_input), right: Arc::new(right), schema: join_schema, @@ -357,7 +366,7 @@ fn find_inner_join( /// Extract join keys from a WHERE clause fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) { - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr { + if let Expr::BinaryExpr(BinaryExpr { left, op, right }, _) = expr { match op { Operator::Eq => { // insert handles ensuring we don't add the same Join keys multiple times @@ -389,20 +398,23 @@ fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) { /// * `None` otherwise fn remove_join_expressions(expr: Expr, 
join_keys: &JoinKeySet) -> Option { match expr { - Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) if join_keys.contains(&left, &right) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + ) if join_keys.contains(&left, &right) => { // was a join key, so remove it None } // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. - Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if op == Operator::And => { let l = remove_join_expressions(*left, join_keys); let r = remove_join_expressions(*right, join_keys); match (l, r) { - (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new( + (Some(ll), Some(rr)) => Some(Expr::binary_expr(BinaryExpr::new( Box::new(ll), op, Box::new(rr), @@ -412,11 +424,11 @@ fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option { _ => None, } } - Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if op == Operator::Or => { let l = remove_join_expressions(*left, join_keys); let r = remove_join_expressions(*right, join_keys); match (l, r) { - (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new( + (Some(ll), Some(rr)) => Some(Expr::binary_expr(BinaryExpr::new( Box::new(ll), op, Box::new(rr), diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 554985667fdf9..4fca0dfd4c4a6 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -63,7 +63,7 @@ impl OptimizerRule for EliminateDuplicatedExpr { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(sort, _) => { let len = sort.expr.len(); let unique_exprs: Vec<_> = sort .expr @@ -80,13 +80,13 @@ impl OptimizerRule for 
EliminateDuplicatedExpr { Transformed::no }; - Ok(transformed(LogicalPlan::Sort(Sort { + Ok(transformed(LogicalPlan::sort(Sort { expr: unique_exprs, input: sort.input, fetch: sort.fetch, }))) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { let len = agg.group_expr.len(); let unique_exprs: Vec = agg @@ -103,7 +103,7 @@ impl OptimizerRule for EliminateDuplicatedExpr { }; Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr) - .map(|f| transformed(LogicalPlan::Aggregate(f))) + .map(|f| transformed(LogicalPlan::aggregate(f))) } _ => Ok(Transformed::no(plan)), } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 4ed2ac8ba1a4e..ad7f5aa242c73 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -59,11 +59,14 @@ impl OptimizerRule for EliminateFilter { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Filter(Filter { - predicate: Expr::Literal(ScalarValue::Boolean(v)), - input, - .. - }) => match v { + LogicalPlan::Filter( + Filter { + predicate: Expr::Literal(ScalarValue::Boolean(v)), + input, + .. 
+ }, + _, + ) => match v { Some(true) => Ok(Transformed::yes(Arc::unwrap_or_clone(input))), Some(false) | None => Ok(Transformed::yes(LogicalPlan::EmptyRelation( EmptyRelation { diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 13d03d647fe20..6e0c8884e099a 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -46,7 +46,7 @@ impl OptimizerRule for EliminateGroupByConstant { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Aggregate(aggregate) => { + LogicalPlan::Aggregate(aggregate, _) => { let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate .group_expr .iter() @@ -60,10 +60,10 @@ impl OptimizerRule for EliminateGroupByConstant { && nonconst_group_expr.is_empty() && aggregate.aggr_expr.is_empty()) { - return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + return Ok(Transformed::no(LogicalPlan::aggregate(aggregate))); } - let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + let simplified_aggregate = LogicalPlan::aggregate(Aggregate::try_new( aggregate.input, nonconst_group_expr.into_iter().cloned().collect(), aggregate.aggr_expr.clone(), @@ -97,12 +97,12 @@ impl OptimizerRule for EliminateGroupByConstant { /// reiles on `SimplifyExpressions` result. 
fn is_constant_expression(expr: &Expr) -> bool { match expr { - Expr::Alias(e) => is_constant_expression(&e.expr), - Expr::BinaryExpr(e) => { + Expr::Alias(e, _) => is_constant_expression(&e.expr), + Expr::BinaryExpr(e, _) => { is_constant_expression(&e.left) && is_constant_expression(&e.right) } Expr::Literal(_) => true, - Expr::ScalarFunction(e) => { + Expr::ScalarFunction(e, _) => { matches!( e.func.signature().volatility, Volatility::Immutable | Volatility::Stable @@ -267,7 +267,7 @@ mod tests { Volatility::Immutable, )); let udf_expr = - Expr::ScalarFunction(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); + Expr::scalar_function(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); let scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(scan) .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? @@ -292,7 +292,7 @@ mod tests { Volatility::Volatile, )); let udf_expr = - Expr::ScalarFunction(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); + Expr::scalar_function(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); let scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(scan) .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? 
diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 789235595dabf..5374c08ef3c36 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -52,7 +52,9 @@ impl OptimizerRule for EliminateJoin { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { + LogicalPlan::Join(join, _) + if join.join_type == Inner && join.on.is_empty() => + { match join.filter { Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok( Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { @@ -60,7 +62,7 @@ impl OptimizerRule for EliminateJoin { schema: join.schema, })), ), - _ => Ok(Transformed::no(LogicalPlan::Join(join))), + _ => Ok(Transformed::no(LogicalPlan::join(join))), } } _ => Ok(Transformed::no(plan)), diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 267615c3e0d93..734aaf495c806 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -59,10 +59,10 @@ impl OptimizerRule for EliminateLimit { _config: &dyn OptimizerConfig, ) -> Result, datafusion_common::DataFusionError> { match plan { - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { // Only supports rewriting for literal fetch let FetchType::Literal(fetch) = limit.get_fetch_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; if let Some(v) = fetch { @@ -79,7 +79,7 @@ impl OptimizerRule for EliminateLimit { // we can remove it. Its input also can be Limit, so we should apply again. 
return self.rewrite(Arc::unwrap_or_clone(limit.input), _config); } - Ok(Transformed::no(LogicalPlan::Limit(limit))) + Ok(Transformed::no(LogicalPlan::limit(limit))) } _ => Ok(Transformed::no(plan)), } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 94da08243d78f..4979ddc2f3ac1 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -55,21 +55,21 @@ impl OptimizerRule for EliminateNestedUnion { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(Union { inputs, schema }, _) => { let inputs = inputs .into_iter() .flat_map(extract_plans_from_union) .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::union(Union { inputs: inputs.into_iter().map(Arc::new).collect_vec(), schema, }))) } - LogicalPlan::Distinct(Distinct::All(nested_plan)) => { + LogicalPlan::Distinct(Distinct::All(nested_plan), _) => { match Arc::unwrap_or_clone(nested_plan) { - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(Union { inputs, schema }, _) => { let inputs = inputs .into_iter() .map(extract_plan_from_distinct) @@ -77,14 +77,14 @@ impl OptimizerRule for EliminateNestedUnion { .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All( - Arc::new(LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::distinct(Distinct::All( + Arc::new(LogicalPlan::union(Union { inputs: inputs.into_iter().map(Arc::new).collect_vec(), schema: Arc::clone(&schema), })), )))) } - nested_plan => Ok(Transformed::no(LogicalPlan::Distinct( + nested_plan => Ok(Transformed::no(LogicalPlan::distinct( Distinct::All(Arc::new(nested_plan)), ))), } @@ -96,7 +96,7 @@ impl OptimizerRule for 
EliminateNestedUnion { fn extract_plans_from_union(plan: Arc) -> Vec { match Arc::unwrap_or_clone(plan) { - LogicalPlan::Union(Union { inputs, .. }) => inputs + LogicalPlan::Union(Union { inputs, .. }, _) => inputs .into_iter() .map(Arc::unwrap_or_clone) .collect::>(), @@ -106,7 +106,7 @@ fn extract_plans_from_union(plan: Arc) -> Vec { fn extract_plan_from_distinct(plan: Arc) -> Arc { match Arc::unwrap_or_clone(plan) { - LogicalPlan::Distinct(Distinct::All(plan)) => plan, + LogicalPlan::Distinct(Distinct::All(plan), _) => plan, plan => Arc::new(plan), } } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 3e027811420c4..ac3da4e8f65d8 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -50,7 +50,7 @@ impl OptimizerRule for EliminateOneUnion { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Union(Union { mut inputs, .. }) if inputs.len() == 1 => Ok( + LogicalPlan::Union(Union { mut inputs, .. 
}, _) if inputs.len() == 1 => Ok( Transformed::yes(Arc::unwrap_or_clone(inputs.pop().unwrap())), ), _ => Ok(Transformed::no(plan)), @@ -110,7 +110,7 @@ mod tests { &schema().to_dfschema()?, )?; let schema = Arc::clone(table_plan.schema()); - let single_union_plan = LogicalPlan::Union(Union { + let single_union_plan = LogicalPlan::union(Union { inputs: vec![Arc::new(table_plan)], schema, }); diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 1ecb32ca2a435..bc376df42a781 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -78,56 +78,58 @@ impl OptimizerRule for EliminateOuterJoin { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Filter(mut filter) => match Arc::unwrap_or_clone(filter.input) { - LogicalPlan::Join(join) => { - let mut non_nullable_cols: Vec = vec![]; + LogicalPlan::Filter(mut filter, _) => { + match Arc::unwrap_or_clone(filter.input) { + LogicalPlan::Join(join, _) => { + let mut non_nullable_cols: Vec = vec![]; - extract_non_nullable_columns( - &filter.predicate, - &mut non_nullable_cols, - join.left.schema(), - join.right.schema(), - true, - ); + extract_non_nullable_columns( + &filter.predicate, + &mut non_nullable_cols, + join.left.schema(), + join.right.schema(), + true, + ); - let new_join_type = if join.join_type.is_outer() { - let mut left_non_nullable = false; - let mut right_non_nullable = false; - for col in non_nullable_cols.iter() { - if join.left.schema().has_column(col) { - left_non_nullable = true; - } - if join.right.schema().has_column(col) { - right_non_nullable = true; + let new_join_type = if join.join_type.is_outer() { + let mut left_non_nullable = false; + let mut right_non_nullable = false; + for col in non_nullable_cols.iter() { + if join.left.schema().has_column(col) { + left_non_nullable = true; + } + if join.right.schema().has_column(col) { + right_non_nullable = 
true; + } } - } - eliminate_outer( - join.join_type, - left_non_nullable, - right_non_nullable, - ) - } else { - join.join_type - }; + eliminate_outer( + join.join_type, + left_non_nullable, + right_non_nullable, + ) + } else { + join.join_type + }; - let new_join = Arc::new(LogicalPlan::Join(Join { - left: join.left, - right: join.right, - join_type: new_join_type, - join_constraint: join.join_constraint, - on: join.on.clone(), - filter: join.filter.clone(), - schema: Arc::clone(&join.schema), - null_equals_null: join.null_equals_null, - })); - Filter::try_new(filter.predicate, new_join) - .map(|f| Transformed::yes(LogicalPlan::Filter(f))) - } - filter_input => { - filter.input = Arc::new(filter_input); - Ok(Transformed::no(LogicalPlan::Filter(filter))) + let new_join = Arc::new(LogicalPlan::join(Join { + left: join.left, + right: join.right, + join_type: new_join_type, + join_constraint: join.join_constraint, + on: join.on.clone(), + filter: join.filter.clone(), + schema: Arc::clone(&join.schema), + null_equals_null: join.null_equals_null, + })); + Filter::try_new(filter.predicate, new_join) + .map(|f| Transformed::yes(LogicalPlan::filter(f))) + } + filter_input => { + filter.input = Arc::new(filter_input); + Ok(Transformed::no(LogicalPlan::filter(filter))) + } } - }, + } _ => Ok(Transformed::no(plan)), } } @@ -183,7 +185,7 @@ fn extract_non_nullable_columns( Expr::Column(col) => { non_nullable_cols.push(col.clone()); } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => match op { // If one of the inputs are null for these operators, the results should be false. 
Operator::Eq | Operator::NotEq @@ -270,14 +272,14 @@ fn extract_non_nullable_columns( } _ => {} }, - Expr::Not(arg) => extract_non_nullable_columns( + Expr::Not(arg, _) => extract_non_nullable_columns( arg, non_nullable_cols, left_schema, right_schema, false, ), - Expr::IsNotNull(arg) => { + Expr::IsNotNull(arg, _) => { if !top_level { return; } @@ -289,14 +291,16 @@ fn extract_non_nullable_columns( false, ) } - Expr::Cast(Cast { expr, data_type: _ }) - | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns( - expr, - non_nullable_cols, - left_schema, - right_schema, - false, - ), + Expr::Cast(Cast { expr, data_type: _ }, _) + | Expr::TryCast(TryCast { expr, data_type: _ }, _) => { + extract_non_nullable_columns( + expr, + non_nullable_cols, + left_schema, + right_schema, + false, + ) + } _ => {} } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 48191ec206313..16c3355c3b8f6 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -67,16 +67,19 @@ impl OptimizerRule for ExtractEquijoinPredicate { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Join(Join { - left, - right, - mut on, - filter: Some(expr), - join_type, - join_constraint, - schema, - null_equals_null, - }) => { + LogicalPlan::Join( + Join { + left, + right, + mut on, + filter: Some(expr), + join_type, + join_constraint, + schema, + null_equals_null, + }, + _, + ) => { let left_schema = left.schema(); let right_schema = right.schema(); let (equijoin_predicates, non_equijoin_expr) = @@ -84,7 +87,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { if !equijoin_predicates.is_empty() { on.extend(equijoin_predicates); - Ok(Transformed::yes(LogicalPlan::Join(Join { + Ok(Transformed::yes(LogicalPlan::join(Join { left, right, on, @@ -95,7 +98,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { 
null_equals_null, }))) } else { - Ok(Transformed::no(LogicalPlan::Join(Join { + Ok(Transformed::no(LogicalPlan::join(Join { left, right, on, @@ -123,11 +126,14 @@ fn split_eq_and_noneq_join_predicate( let mut accum_filters: Vec = vec![]; for expr in exprs { match expr { - Expr::BinaryExpr(BinaryExpr { - ref left, - op: Operator::Eq, - ref right, - }) => { + Expr::BinaryExpr( + BinaryExpr { + ref left, + op: Operator::Eq, + ref right, + }, + _, + ) => { let join_key_pair = find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?; diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 2e7a751ca4c57..3f190ff326673 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -50,7 +50,7 @@ impl OptimizerRule for FilterNullJoinKeys { return Ok(Transformed::no(plan)); } match plan { - LogicalPlan::Join(mut join) + LogicalPlan::Join(mut join, _) if !join.on.is_empty() && !join.null_equals_null => { let (left_preserved, right_preserved) = @@ -74,17 +74,17 @@ impl OptimizerRule for FilterNullJoinKeys { if !left_filters.is_empty() { let predicate = create_not_null_predicate(left_filters); - join.left = Arc::new(LogicalPlan::Filter(Filter::try_new( + join.left = Arc::new(LogicalPlan::filter(Filter::try_new( predicate, join.left, )?)); } if !right_filters.is_empty() { let predicate = create_not_null_predicate(right_filters); - join.right = Arc::new(LogicalPlan::Filter(Filter::try_new( + join.right = Arc::new(LogicalPlan::filter(Filter::try_new( predicate, join.right, )?)); } - Ok(Transformed::yes(LogicalPlan::Join(join))) + Ok(Transformed::yes(LogicalPlan::join(join))) } _ => Ok(Transformed::no(plan)), } @@ -95,10 +95,7 @@ impl OptimizerRule for FilterNullJoinKeys { } fn create_not_null_predicate(filters: Vec) -> Expr { - let not_null_exprs: Vec = filters - .into_iter() - .map(|c| Expr::IsNotNull(Box::new(c))) - .collect(); + let 
not_null_exprs: Vec = filters.into_iter().map(Expr::is_not_null).collect(); // directly unwrap since it should always have a value conjunction(not_null_exprs).unwrap() diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 1519c54dbf68a..16494c2dc0929 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -119,12 +119,12 @@ fn optimize_projections( // Recursively rewrite any nodes that may be able to avoid computation given // their parents' required indices. match plan { - LogicalPlan::Projection(proj) => { + LogicalPlan::Projection(proj, _) => { return merge_consecutive_projections(proj)?.transform_data(|proj| { rewrite_projection_given_requirements(proj, config, &indices) }) } - LogicalPlan::Aggregate(aggregate) => { + LogicalPlan::Aggregate(aggregate, _) => { // Split parent requirements to GROUP BY and aggregate sections: let n_group_exprs = aggregate.group_expr_len()?; // Offset aggregate indices so that they point to valid indices at @@ -200,10 +200,10 @@ fn optimize_projections( new_group_bys, new_aggr_expr, ) - .map(LogicalPlan::Aggregate) + .map(LogicalPlan::aggregate) }); } - LogicalPlan::Window(window) => { + LogicalPlan::Window(window, _) => { let input_schema = Arc::clone(window.input.schema()); // Split parent requirements to child and window expression sections: let n_input_fields = input_schema.fields().len(); @@ -238,12 +238,12 @@ fn optimize_projections( add_projection_on_top_if_helpful(window_child, required_exprs)? 
.data; Window::try_new(new_window_expr, Arc::new(window_child)) - .map(LogicalPlan::Window) + .map(LogicalPlan::window) .map(Transformed::yes) } }); } - LogicalPlan::TableScan(table_scan) => { + LogicalPlan::TableScan(table_scan, _) => { let TableScan { table_name, source, @@ -266,7 +266,7 @@ fn optimize_projections( filters, fetch, ) - .map(LogicalPlan::TableScan) + .map(LogicalPlan::table_scan) .map(Transformed::yes); } // Other node types are handled below @@ -276,12 +276,12 @@ fn optimize_projections( // For other plan node types, calculate indices for columns they use and // try to rewrite their children let mut child_required_indices: Vec = match &plan { - LogicalPlan::Sort(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Union(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Distinct(Distinct::On(_)) => { + LogicalPlan::Sort(_, _) + | LogicalPlan::Filter(_, _) + | LogicalPlan::Repartition(_, _) + | LogicalPlan::Union(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Distinct(Distinct::On(_), _) => { // Pass index requirements from the parent as well as column indices // that appear in this plan's expressions to its child. All these // operators benefit from "small" inputs, so the projection_beneficial @@ -296,7 +296,7 @@ fn optimize_projections( }) .collect::>()? } - LogicalPlan::Limit(_) => { + LogicalPlan::Limit(_, _) => { // Pass index requirements from the parent as well as column indices // that appear in this plan's expressions to its child. These operators // do not benefit from "small" inputs, so the projection_beneficial @@ -306,14 +306,14 @@ fn optimize_projections( .map(|input| indices.clone().with_plan_exprs(&plan, input.schema())) .collect::>()? 
} - LogicalPlan::Copy(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::Statement(_) - | LogicalPlan::Distinct(Distinct::All(_)) => { + LogicalPlan::Copy(_, _) + | LogicalPlan::Ddl(_, _) + | LogicalPlan::Dml(_, _) + | LogicalPlan::Explain(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::Statement(_, _) + | LogicalPlan::Distinct(Distinct::All(_), _) => { // These plans require all their fields, and their children should // be treated as final plans -- otherwise, we may have schema a // mismatch. @@ -324,7 +324,7 @@ fn optimize_projections( .map(RequiredIndicies::new_for_all_exprs) .collect() } - LogicalPlan::Extension(extension) => { + LogicalPlan::Extension(extension, _) => { let Some(necessary_children_indices) = extension.node.necessary_children_exprs(indices.indices()) else { @@ -347,13 +347,13 @@ fn optimize_projections( .collect::>>()? } LogicalPlan::EmptyRelation(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Values(_) + | LogicalPlan::RecursiveQuery(_, _) + | LogicalPlan::Values(_, _) | LogicalPlan::DescribeTable(_) => { // These operators have no inputs, so stop the optimization process. 
return Ok(Transformed::no(plan)); } - LogicalPlan::Join(join) => { + LogicalPlan::Join(join, _) => { let left_len = join.left.schema().fields().len(); let (left_req_indices, right_req_indices) = split_join_requirements(left_len, indices, &join.join_type); @@ -369,17 +369,20 @@ fn optimize_projections( ] } // these nodes are explicitly rewritten in the match statement above - LogicalPlan::Projection(_) - | LogicalPlan::Aggregate(_) - | LogicalPlan::Window(_) - | LogicalPlan::TableScan(_) => { + LogicalPlan::Projection(_, _) + | LogicalPlan::Aggregate(_, _) + | LogicalPlan::Window(_, _) + | LogicalPlan::TableScan(_, _) => { return internal_err!( "OptimizeProjection: should have handled in the match statement above" ); } - LogicalPlan::Unnest(Unnest { - dependency_indices, .. - }) => { + LogicalPlan::Unnest( + Unnest { + dependency_indices, .. + }, + _, + ) => { vec![RequiredIndicies::new_from_indices( dependency_indices.clone(), )] @@ -452,7 +455,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result Result Result rewrite_expr(*expr, &prev_projection).map(|result| { - result.update_data(|expr| Expr::Alias(Alias::new(expr, relation, name))) + Expr::Alias( + Alias { + expr, + relation, + name, + }, + _, + ) => rewrite_expr(*expr, &prev_projection).map(|result| { + result.update_data(|expr| expr.alias_qualified(relation, name)) }), e => rewrite_expr(e, &prev_projection), } @@ -513,7 +519,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result Result> { expr.transform_up(|expr| { match expr { // remove any intermediate aliases - Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + Expr::Alias(alias, _) => Ok(Transformed::yes(*alias.expr)), Expr::Column(col) => { // Find index of column: let idx = input.schema.index_of_column(&col)?; @@ -607,13 +613,13 @@ fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) { Expr::OuterReferenceColumn(_, col) => { columns.insert(col); } - Expr::ScalarSubquery(subquery) => { + 
Expr::ScalarSubquery(subquery, _) => { outer_columns_helper_multi(&subquery.outer_ref_columns, columns); } - Expr::Exists(exists) => { + Expr::Exists(exists, _) => { outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns); } - Expr::InSubquery(insubquery) => { + Expr::InSubquery(insubquery, _) => { outer_columns_helper_multi( &insubquery.subquery.outer_ref_columns, columns, @@ -721,7 +727,7 @@ fn add_projection_on_top_if_helpful( Ok(Transformed::no(plan)) } else { Projection::try_new(project_exprs, Arc::new(plan)) - .map(LogicalPlan::Projection) + .map(LogicalPlan::projection) .map(Transformed::yes) } } @@ -763,7 +769,7 @@ fn rewrite_projection_given_requirements( Ok(Transformed::yes(input)) } else { Projection::try_new(exprs_used, Arc::new(input)) - .map(LogicalPlan::Projection) + .map(LogicalPlan::projection) .map(Transformed::yes) } }) @@ -1208,7 +1214,7 @@ mod tests { let expr = Box::new(col("a")); let pattern = Box::new(lit("[0-9]")); let similar_to_expr = - Expr::SimilarTo(Like::new(false, expr, pattern, None, false)); + Expr::similar_to(Like::new(false, expr, pattern, None, false)); let plan = LogicalPlanBuilder::from(table_scan) .project(vec![similar_to_expr])? 
.build()?; @@ -1276,7 +1282,7 @@ mod tests { #[test] fn test_user_defined_logical_plan_node() -> Result<()> { let table_scan = test_table_scan()?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(NoOpUserDefined::new( Arc::clone(table_scan.schema()), Arc::new(table_scan.clone()), @@ -1300,7 +1306,7 @@ mod tests { fn test_user_defined_logical_plan_node2() -> Result<()> { let table_scan = test_table_scan()?; let exprs = vec![Expr::Column(Column::from_qualified_name("b"))]; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new( NoOpUserDefined::new( Arc::clone(table_scan.schema()), @@ -1329,13 +1335,13 @@ mod tests { let table_scan = test_table_scan()?; let left_expr = Expr::Column(Column::from_qualified_name("b")); let right_expr = Expr::Column(Column::from_qualified_name("c")); - let binary_expr = Expr::BinaryExpr(BinaryExpr::new( + let binary_expr = Expr::binary_expr(BinaryExpr::new( Box::new(left_expr), Operator::Plus, Box::new(right_expr), )); let exprs = vec![binary_expr]; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new( NoOpUserDefined::new( Arc::clone(table_scan.schema()), @@ -1362,7 +1368,7 @@ mod tests { fn test_user_defined_logical_plan_node4() -> Result<()> { let left_table = test_table_scan_with_name("l")?; let right_table = test_table_scan_with_name("r")?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(UserDefinedCrossJoin::new( Arc::new(left_table), Arc::new(right_table), @@ -1691,7 +1697,7 @@ mod tests { let table_scan = test_table_scan()?; let projection = LogicalPlanBuilder::from(table_scan) - .project(vec![Expr::Cast(Cast::new( + .project(vec![Expr::cast(Cast::new( Box::new(col("c")), DataType::Float64, ))])? 
@@ -1731,7 +1737,7 @@ mod tests { // relation is `None`). PlanBuilder resolves the expressions let expr = vec![col("test.a"), col("test.b")]; let plan = - LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?); + LogicalPlan::projection(Projection::try_new(expr, Arc::new(table_scan))?); assert_fields_eq(&plan, vec!["a", "b"]); @@ -1942,7 +1948,7 @@ mod tests { fn test_window() -> Result<()> { let table_scan = test_table_scan()?; - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let max1 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], )) @@ -1950,7 +1956,7 @@ mod tests { .build() .unwrap(); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], )); diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 975150cd61220..579e9cca3d8bc 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -573,7 +573,7 @@ mod tests { let input = Arc::new(test_table_scan()?); let input_schema = Arc::clone(input.schema()); - let plan = LogicalPlan::Projection(Projection::try_new_with_schema( + let plan = LogicalPlan::projection(Projection::try_new_with_schema( vec![col("a"), col("b"), col("c")], input, add_metadata_to_fields(input_schema.as_ref()), @@ -740,7 +740,7 @@ mod tests { _config: &dyn OptimizerConfig, ) -> Result> { let projection = match plan { - LogicalPlan::Projection(p) if p.expr.len() >= 2 => p, + LogicalPlan::Projection(p, _) if p.expr.len() >= 2 => p, _ => return Ok(Transformed::no(plan)), }; @@ -754,7 +754,7 @@ mod tests { exprs.rotate_left(1); } - Ok(Transformed::yes(LogicalPlan::Projection( + Ok(Transformed::yes(LogicalPlan::projection( Projection::try_new(exprs, Arc::clone(&projection.input))?, ))) } diff --git 
a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 73e6b418272a9..1ce61b728c6d6 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -105,13 +105,13 @@ mod tests { assert_eq!(1, get_node_number(&one_node_plan).get()); - let two_node_plan = Arc::new(LogicalPlan::Projection( + let two_node_plan = Arc::new(LogicalPlan::projection( datafusion_expr::Projection::try_new(vec![lit(1), lit(2)], one_node_plan)?, )); assert_eq!(2, get_node_number(&two_node_plan).get()); - let five_node_plan = Arc::new(LogicalPlan::Union(datafusion_expr::Union { + let five_node_plan = Arc::new(LogicalPlan::union(datafusion_expr::Union { inputs: vec![Arc::clone(&two_node_plan), two_node_plan], schema, })); diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index d26df073dc6fd..f268a8ea84afa 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -59,20 +59,20 @@ impl OptimizerRule for PropagateEmptyRelation { ) -> Result> { match plan { LogicalPlan::EmptyRelation(_) => Ok(Transformed::no(plan)), - LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Window(_) - | LogicalPlan::Sort(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Limit(_) => { + LogicalPlan::Projection(_, _) + | LogicalPlan::Filter(_, _) + | LogicalPlan::Window(_, _) + | LogicalPlan::Sort(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Repartition(_, _) + | LogicalPlan::Limit(_, _) => { let empty = empty_child(&plan)?; if let Some(empty_plan) = empty { return Ok(Transformed::yes(empty_plan)); } Ok(Transformed::no(plan)) } - LogicalPlan::Join(ref join) => { + LogicalPlan::Join(ref join, _) => { // TODO: For Join, more join type need to be careful: // For LeftOut/Full Join, if the right side is empty, the Join can be eliminated 
with a Projection with left side // columns + right side columns replaced with null values. @@ -139,15 +139,15 @@ impl OptimizerRule for PropagateEmptyRelation { _ => Ok(Transformed::no(plan)), } } - LogicalPlan::Aggregate(ref agg) => { + LogicalPlan::Aggregate(ref agg, _) => { if !agg.group_expr.is_empty() { if let Some(empty_plan) = empty_child(&plan)? { return Ok(Transformed::yes(empty_plan)); } } - Ok(Transformed::no(LogicalPlan::Aggregate(agg.clone()))) + Ok(Transformed::no(LogicalPlan::aggregate(agg.clone()))) } - LogicalPlan::Union(ref union) => { + LogicalPlan::Union(ref union, _) => { let new_inputs = union .inputs .iter() @@ -174,7 +174,7 @@ impl OptimizerRule for PropagateEmptyRelation { if child.schema().eq(plan.schema()) { Ok(Transformed::yes(child)) } else { - Ok(Transformed::yes(LogicalPlan::Projection( + Ok(Transformed::yes(LogicalPlan::projection( Projection::new_from_schema( Arc::new(child), Arc::clone(plan.schema()), @@ -182,7 +182,7 @@ impl OptimizerRule for PropagateEmptyRelation { ))) } } else { - Ok(Transformed::yes(LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::union(Union { inputs: new_inputs, schema: Arc::clone(&union.schema), }))) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 23cd46803c78d..06c031e1a2d57 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -257,37 +257,37 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Placeholder(_) | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. 
} - | Expr::InSubquery(_) - | Expr::ScalarSubquery(_) + | Expr::InSubquery(_, _) + | Expr::ScalarSubquery(_, _) | Expr::OuterReferenceColumn(_, _) - | Expr::Unnest(_) => { + | Expr::Unnest(_, _) => { is_evaluate = false; Ok(TreeNodeRecursion::Stop) } - Expr::Alias(_) - | Expr::BinaryExpr(_) - | Expr::Like(_) - | Expr::SimilarTo(_) - | Expr::Not(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) - | Expr::Negative(_) - | Expr::Between(_) - | Expr::Case(_) - | Expr::Cast(_) - | Expr::TryCast(_) + Expr::Alias(_, _) + | Expr::BinaryExpr(_, _) + | Expr::Like(_, _) + | Expr::SimilarTo(_, _) + | Expr::Not(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) + | Expr::Negative(_, _) + | Expr::Between(_, _) + | Expr::Case(_, _) + | Expr::Cast(_, _) + | Expr::TryCast(_, _) | Expr::InList { .. } - | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), - Expr::AggregateFunction(_) - | Expr::WindowFunction(_) + | Expr::ScalarFunction(_, _) => Ok(TreeNodeRecursion::Continue), + Expr::AggregateFunction(_, _) + | Expr::WindowFunction(_, _) | Expr::Wildcard { .. 
} - | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), + | Expr::GroupingSet(_, _) => internal_err!("Unsupported predicate type"), })?; Ok(is_evaluate) } @@ -333,11 +333,14 @@ fn extract_or_clauses_for_join<'a>( // new formed OR clauses and their column references filters.iter().filter_map(move |expr| { - if let Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Or, - right, - }) = expr + if let Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Or, + right, + }, + _, + ) = expr { let left_expr = extract_or_clause(left.as_ref(), &schema_columns); let right_expr = extract_or_clause(right.as_ref(), &schema_columns); @@ -366,11 +369,14 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option { + Expr::BinaryExpr( + BinaryExpr { + left: l_expr, + op: Operator::Or, + right: r_expr, + }, + _, + ) => { let l_expr = extract_or_clause(l_expr, schema_columns); let r_expr = extract_or_clause(r_expr, schema_columns); @@ -378,11 +384,14 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option { + Expr::BinaryExpr( + BinaryExpr { + left: l_expr, + op: Operator::And, + right: r_expr, + }, + _, + ) => { let l_expr = extract_or_clause(l_expr, schema_columns); let r_expr = extract_or_clause(r_expr, schema_columns); @@ -498,11 +507,11 @@ fn push_down_all_join( } if let Some(predicate) = conjunction(left_push) { - join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); + join.left = Arc::new(LogicalPlan::filter(Filter::try_new(predicate, join.left)?)); } if let Some(predicate) = conjunction(right_push) { join.right = - Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?)); + Arc::new(LogicalPlan::filter(Filter::try_new(predicate, join.right)?)); } // Add any new join conditions as the non join predicates @@ -510,9 +519,9 @@ fn push_down_all_join( join.filter = conjunction(join_conditions); // wrap the join on the filter whose predicates must be kept, if any - let plan = 
LogicalPlan::Join(join); + let plan = LogicalPlan::join(join); let plan = if let Some(predicate) = conjunction(keep_predicates) { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?) + LogicalPlan::filter(Filter::try_new(predicate, Arc::new(plan))?) } else { plan }; @@ -541,7 +550,7 @@ fn push_down_join( && predicates.is_empty() && inferred_join_predicates.is_empty() { - return Ok(Transformed::no(LogicalPlan::Join(join))); + return Ok(Transformed::no(LogicalPlan::join(join))); } push_down_all_join(predicates, inferred_join_predicates, join, on_filters) @@ -765,18 +774,18 @@ impl OptimizerRule for PushDownFilter { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - if let LogicalPlan::Join(join) = plan { + if let LogicalPlan::Join(join, _) = plan { return push_down_join(join, None); }; let plan_schema = Arc::clone(plan.schema()); - let LogicalPlan::Filter(mut filter) = plan else { + let LogicalPlan::Filter(mut filter, _) = plan else { return Ok(Transformed::no(plan)); }; match Arc::unwrap_or_clone(filter.input) { - LogicalPlan::Filter(child_filter) => { + LogicalPlan::Filter(child_filter, _) => { let parents_predicates = split_conjunction_owned(filter.predicate); // remove duplicated filters @@ -792,31 +801,31 @@ impl OptimizerRule for PushDownFilter { let Some(new_predicate) = conjunction(new_predicates) else { return plan_err!("at least one expression exists"); }; - let new_filter = LogicalPlan::Filter(Filter::try_new( + let new_filter = LogicalPlan::filter(Filter::try_new( new_predicate, child_filter.input, )?); self.rewrite(new_filter, _config) } - LogicalPlan::Repartition(repartition) => { + LogicalPlan::Repartition(repartition, _) => { let new_filter = Filter::try_new(filter.predicate, Arc::clone(&repartition.input)) - .map(LogicalPlan::Filter)?; - insert_below(LogicalPlan::Repartition(repartition), new_filter) + .map(LogicalPlan::filter)?; + insert_below(LogicalPlan::repartition(repartition), new_filter) } - 
LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(distinct, _) => { let new_filter = Filter::try_new(filter.predicate, Arc::clone(distinct.input())) - .map(LogicalPlan::Filter)?; - insert_below(LogicalPlan::Distinct(distinct), new_filter) + .map(LogicalPlan::filter)?; + insert_below(LogicalPlan::distinct(distinct), new_filter) } - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(sort, _) => { let new_filter = Filter::try_new(filter.predicate, Arc::clone(&sort.input)) - .map(LogicalPlan::Filter)?; - insert_below(LogicalPlan::Sort(sort), new_filter) + .map(LogicalPlan::filter)?; + insert_below(LogicalPlan::sort(sort), new_filter) } - LogicalPlan::SubqueryAlias(subquery_alias) => { + LogicalPlan::SubqueryAlias(subquery_alias, _) => { let mut replace_map = HashMap::new(); for (i, (qualifier, field)) in subquery_alias.input.schema().iter().enumerate() @@ -830,13 +839,13 @@ impl OptimizerRule for PushDownFilter { } let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?; - let new_filter = LogicalPlan::Filter(Filter::try_new( + let new_filter = LogicalPlan::filter(Filter::try_new( new_predicate, Arc::clone(&subquery_alias.input), )?); - insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter) + insert_below(LogicalPlan::subquery_alias(subquery_alias), new_filter) } - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { let predicates = split_conjunction_owned(filter.predicate.clone()); let (new_projection, keep_predicate) = rewrite_projection(predicates, projection)?; @@ -845,15 +854,15 @@ impl OptimizerRule for PushDownFilter { None => Ok(new_projection), Some(keep_predicate) => new_projection.map_data(|child_plan| { Filter::try_new(keep_predicate, Arc::new(child_plan)) - .map(LogicalPlan::Filter) + .map(LogicalPlan::filter) }), } } else { filter.input = Arc::new(new_projection.data); - Ok(Transformed::no(LogicalPlan::Filter(filter))) + Ok(Transformed::no(LogicalPlan::filter(filter))) } } - 
LogicalPlan::Unnest(mut unnest) => { + LogicalPlan::Unnest(mut unnest, _) => { let predicates = split_conjunction_owned(filter.predicate.clone()); let mut non_unnest_predicates = vec![]; let mut unnest_predicates = vec![]; @@ -874,8 +883,8 @@ impl OptimizerRule for PushDownFilter { // Unnest predicates should not be pushed down. // If no non-unnest predicates exist, early return if non_unnest_predicates.is_empty() { - filter.input = Arc::new(LogicalPlan::Unnest(unnest)); - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + filter.input = Arc::new(LogicalPlan::unnest(unnest)); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } // Push down non-unnest filter predicate @@ -888,7 +897,7 @@ impl OptimizerRule for PushDownFilter { let unnest_input = std::mem::take(&mut unnest.input); - let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new( + let filter_with_unnest_input = LogicalPlan::filter(Filter::try_new( conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty. 
unnest_input, )?); @@ -897,16 +906,16 @@ impl OptimizerRule for PushDownFilter { // The new filter plan will go through another rewrite pass since the rule itself // is applied recursively to all the child from top to down let unnest_plan = - insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?; + insert_below(LogicalPlan::unnest(unnest), filter_with_unnest_input)?; match conjunction(unnest_predicates) { None => Ok(unnest_plan), - Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter( + Some(predicate) => Ok(Transformed::yes(LogicalPlan::filter( Filter::try_new(predicate, Arc::new(unnest_plan.data))?, ))), } } - LogicalPlan::Union(ref union) => { + LogicalPlan::Union(ref union, _) => { let mut inputs = Vec::with_capacity(union.inputs.len()); for input in &union.inputs { let mut replace_map = HashMap::new(); @@ -921,17 +930,17 @@ impl OptimizerRule for PushDownFilter { let push_predicate = replace_cols_by_name(filter.predicate.clone(), &replace_map)?; - inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new( + inputs.push(Arc::new(LogicalPlan::filter(Filter::try_new( push_predicate, Arc::clone(input), )?))) } - Ok(Transformed::yes(LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::union(Union { inputs, schema: Arc::clone(&plan_schema), }))) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { // We can push down Predicate which in groupby_expr. 
let group_expr_columns = agg .group_expr @@ -965,7 +974,7 @@ impl OptimizerRule for PushDownFilter { .collect::>>()?; let agg_input = Arc::clone(&agg.input); - Transformed::yes(LogicalPlan::Aggregate(agg)) + Transformed::yes(LogicalPlan::aggregate(agg)) .transform_data(|new_plan| { // If we have a filter to push, we push it down to the input of the aggregate if let Some(predicate) = conjunction(replaced_push_predicates) { @@ -985,8 +994,8 @@ impl OptimizerRule for PushDownFilter { } }) } - LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), - LogicalPlan::TableScan(scan) => { + LogicalPlan::Join(join, _) => push_down_join(join, Some(&filter.predicate)), + LogicalPlan::TableScan(scan, _) => { let filter_predicates = split_conjunction(&filter.predicate); let results = scan .source @@ -1016,7 +1025,7 @@ impl OptimizerRule for PushDownFilter { .map(|(pred, _)| pred.clone()) .collect(); - let new_scan = LogicalPlan::TableScan(TableScan { + let new_scan = LogicalPlan::table_scan(TableScan { filters: new_scan_filters, ..scan }); @@ -1029,7 +1038,7 @@ impl OptimizerRule for PushDownFilter { } }) } - LogicalPlan::Extension(extension_plan) => { + LogicalPlan::Extension(extension_plan, _) => { let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); @@ -1050,8 +1059,8 @@ impl OptimizerRule for PushDownFilter { // all predicates are kept, no changes needed if predicate_push_or_keep.iter().all(|&x| !x) { - filter.input = Arc::new(LogicalPlan::Extension(extension_plan)); - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + filter.input = Arc::new(LogicalPlan::extension(extension_plan)); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } // going to push some predicates down, so split the predicates @@ -1074,7 +1083,7 @@ impl OptimizerRule for PushDownFilter { .inputs() .into_iter() .map(|child| { - Ok(LogicalPlan::Filter(Filter::try_new( + Ok(LogicalPlan::filter(Filter::try_new( predicate.clone(), 
Arc::new(child.clone()), )?)) @@ -1083,12 +1092,12 @@ impl OptimizerRule for PushDownFilter { None => extension_plan.node.inputs().into_iter().cloned().collect(), }; // extension with new inputs. - let child_plan = LogicalPlan::Extension(extension_plan); + let child_plan = LogicalPlan::extension(extension_plan); let new_extension = child_plan.with_new_exprs(child_plan.expressions(), new_children)?; let new_plan = match conjunction(keep_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( + Some(predicate) => LogicalPlan::filter(Filter::try_new( predicate, Arc::new(new_extension), )?), @@ -1098,7 +1107,7 @@ impl OptimizerRule for PushDownFilter { } child => { filter.input = Arc::new(child); - Ok(Transformed::no(LogicalPlan::Filter(filter))) + Ok(Transformed::no(LogicalPlan::filter(filter))) } } } @@ -1164,7 +1173,7 @@ fn rewrite_projection( Some(expr) => { // re-write all filters based on this projection // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" - let new_filter = LogicalPlan::Filter(Filter::try_new( + let new_filter = LogicalPlan::filter(Filter::try_new( replace_cols_by_name(expr, &non_volatile_map)?, std::mem::take(&mut projection.input), )?); @@ -1172,17 +1181,17 @@ fn rewrite_projection( projection.input = Arc::new(new_filter); Ok(( - Transformed::yes(LogicalPlan::Projection(projection)), + Transformed::yes(LogicalPlan::projection(projection)), conjunction(keep_predicates), )) } - None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)), + None => Ok((Transformed::no(LogicalPlan::projection(projection)), None)), } } /// Creates a new LogicalPlan::Filter node. pub fn make_filter(predicate: Expr, input: Arc) -> Result { - Filter::try_new(predicate, input).map(LogicalPlan::Filter) + Filter::try_new(predicate, input).map(LogicalPlan::filter) } /// Replace the existing child of the single input node with `new_child`. 
@@ -1444,7 +1453,7 @@ mod tests { } fn add(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::Plus, Box::new(right), @@ -1452,7 +1461,7 @@ mod tests { } fn multiply(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::Multiply, Box::new(right), @@ -1580,7 +1589,7 @@ mod tests { fn user_defined_plan() -> Result<()> { let table_scan = test_table_scan()?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -1596,7 +1605,7 @@ mod tests { \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(plan, expected)?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -1613,7 +1622,7 @@ mod tests { \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(plan, expected)?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone(), table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -1630,7 +1639,7 @@ mod tests { \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(plan, expected)?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone(), table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -2520,7 +2529,7 @@ mod tests { ) -> Result { let test_provider = PushDownProvider { filter_support }; - let table_scan = LogicalPlan::TableScan(TableScan { + let table_scan = 
LogicalPlan::table_scan(TableScan { table_name: "test".into(), filters: vec![], projected_schema: Arc::new(DFSchema::try_from( @@ -2592,7 +2601,7 @@ mod tests { filter_support: TableProviderFilterPushDown::Inexact, }; - let table_scan = LogicalPlan::TableScan(TableScan { + let table_scan = LogicalPlan::table_scan(TableScan { table_name: "test".into(), filters: vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))], projected_schema: Arc::new(DFSchema::try_from( @@ -2621,7 +2630,7 @@ mod tests { filter_support: TableProviderFilterPushDown::Exact, }; - let table_scan = LogicalPlan::TableScan(TableScan { + let table_scan = LogicalPlan::table_scan(TableScan { table_name: "test".into(), filters: vec![], projected_schema: Arc::new(DFSchema::try_from( @@ -3313,7 +3322,7 @@ Projection: a, b let fun = ScalarUDF::new_from_impl(TestScalarUDF { signature: Signature::exact(vec![], Volatility::Volatile), }); - let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let expr = Expr::scalar_function(ScalarFunction::new_udf(Arc::new(fun), vec![])); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? 
@@ -3347,7 +3356,7 @@ Projection: a, b let fun = ScalarUDF::new_from_impl(TestScalarUDF { signature: Signature::exact(vec![], Volatility::Volatile), }); - let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let expr = Expr::scalar_function(ScalarFunction::new_udf(Arc::new(fun), vec![])); let left = LogicalPlanBuilder::from(table_scan).build()?; let right_table_scan = test_table_scan_with_name("test2")?; let right = LogicalPlanBuilder::from(right_table_scan).build()?; diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 8a3aa4bb84599..ed526e950eb28 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -53,29 +53,29 @@ impl OptimizerRule for PushDownLimit { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - let LogicalPlan::Limit(mut limit) = plan else { + let LogicalPlan::Limit(mut limit, _) = plan else { return Ok(Transformed::no(plan)); }; // Currently only rewrite if skip and fetch are both literals let SkipType::Literal(skip) = limit.get_skip_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; let FetchType::Literal(fetch) = limit.get_fetch_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; // Merge the Parent Limit and the Child Limit. - if let LogicalPlan::Limit(child) = limit.input.as_ref() { + if let LogicalPlan::Limit(child, _) = limit.input.as_ref() { let SkipType::Literal(child_skip) = child.get_skip_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; let FetchType::Literal(child_fetch) = child.get_fetch_type()? 
else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); - let plan = LogicalPlan::Limit(Limit { + let plan = LogicalPlan::limit(Limit { skip: Some(Box::new(lit(skip as i64))), fetch: fetch.map(|f| Box::new(lit(f as i64))), input: Arc::clone(&child.input), @@ -87,74 +87,76 @@ impl OptimizerRule for PushDownLimit { // no fetch to push, so return the original plan let Some(fetch) = fetch else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; match Arc::unwrap_or_clone(limit.input) { - LogicalPlan::TableScan(mut scan) => { + LogicalPlan::TableScan(mut scan, _) => { let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; let new_fetch = scan .fetch .map(|x| min(x, rows_needed)) .or(Some(rows_needed)); if new_fetch == scan.fetch { - original_limit(skip, fetch, LogicalPlan::TableScan(scan)) + original_limit(skip, fetch, LogicalPlan::table_scan(scan)) } else { // push limit into the table scan itself scan.fetch = scan .fetch .map(|x| min(x, rows_needed)) .or(Some(rows_needed)); - transformed_limit(skip, fetch, LogicalPlan::TableScan(scan)) + transformed_limit(skip, fetch, LogicalPlan::table_scan(scan)) } } - LogicalPlan::Union(mut union) => { + LogicalPlan::Union(mut union, _) => { // push limits to each input of the union union.inputs = union .inputs .into_iter() .map(|input| make_arc_limit(0, fetch + skip, input)) .collect(); - transformed_limit(skip, fetch, LogicalPlan::Union(union)) + transformed_limit(skip, fetch, LogicalPlan::union(union)) } - LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip) + LogicalPlan::Join(join, _) => Ok(push_down_join(join, fetch + skip) .update_data(|join| { - make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) + make_limit(skip, fetch, Arc::new(LogicalPlan::join(join))) })), - LogicalPlan::Sort(mut sort) => { 
+ LogicalPlan::Sort(mut sort, _) => { let new_fetch = { let sort_fetch = skip + fetch; Some(sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch)) }; if new_fetch == sort.fetch { if skip > 0 { - original_limit(skip, fetch, LogicalPlan::Sort(sort)) + original_limit(skip, fetch, LogicalPlan::sort(sort)) } else { - Ok(Transformed::yes(LogicalPlan::Sort(sort))) + Ok(Transformed::yes(LogicalPlan::sort(sort))) } } else { sort.fetch = new_fetch; - limit.input = Arc::new(LogicalPlan::Sort(sort)); - Ok(Transformed::yes(LogicalPlan::Limit(limit))) + limit.input = Arc::new(LogicalPlan::sort(sort)); + Ok(Transformed::yes(LogicalPlan::limit(limit))) } } - LogicalPlan::Projection(mut proj) => { + LogicalPlan::Projection(mut proj, _) => { // commute limit.input = Arc::clone(&proj.input); - let new_limit = LogicalPlan::Limit(limit); + let new_limit = LogicalPlan::limit(limit); proj.input = Arc::new(new_limit); - Ok(Transformed::yes(LogicalPlan::Projection(proj))) + Ok(Transformed::yes(LogicalPlan::projection(proj))) } - LogicalPlan::SubqueryAlias(mut subquery_alias) => { + LogicalPlan::SubqueryAlias(mut subquery_alias, _) => { // commute limit.input = Arc::clone(&subquery_alias.input); - let new_limit = LogicalPlan::Limit(limit); + let new_limit = LogicalPlan::limit(limit); subquery_alias.input = Arc::new(new_limit); - Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias))) + Ok(Transformed::yes(LogicalPlan::subquery_alias( + subquery_alias, + ))) } - LogicalPlan::Extension(extension_plan) + LogicalPlan::Extension(extension_plan, _) if extension_plan.node.supports_limit_pushdown() => { let new_children = extension_plan @@ -162,7 +164,7 @@ impl OptimizerRule for PushDownLimit { .inputs() .into_iter() .map(|child| { - LogicalPlan::Limit(Limit { + LogicalPlan::limit(Limit { skip: None, fetch: Some(Box::new(lit((fetch + skip) as i64))), input: Arc::new(child.clone()), @@ -171,7 +173,7 @@ impl OptimizerRule for PushDownLimit { .collect::>(); // Create a new extension 
node with updated inputs - let child_plan = LogicalPlan::Extension(extension_plan); + let child_plan = LogicalPlan::extension(extension_plan); let new_extension = child_plan.with_new_exprs(child_plan.expressions(), new_children)?; @@ -203,7 +205,7 @@ impl OptimizerRule for PushDownLimit { /// input /// ``` fn make_limit(skip: usize, fetch: usize, input: Arc) -> LogicalPlan { - LogicalPlan::Limit(Limit { + LogicalPlan::limit(Limit { skip: Some(Box::new(lit(skip as i64))), fetch: Some(Box::new(lit(fetch as i64))), input, @@ -400,7 +402,7 @@ mod test { #[test] fn limit_pushdown_basic() -> Result<()> { let table_scan = test_table_scan()?; - let noop_plan = LogicalPlan::Extension(Extension { + let noop_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -422,7 +424,7 @@ mod test { #[test] fn limit_pushdown_with_skip() -> Result<()> { let table_scan = test_table_scan()?; - let noop_plan = LogicalPlan::Extension(Extension { + let noop_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -444,7 +446,7 @@ mod test { #[test] fn limit_pushdown_multiple_limits() -> Result<()> { let table_scan = test_table_scan()?; - let noop_plan = LogicalPlan::Extension(Extension { + let noop_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -467,7 +469,7 @@ mod test { #[test] fn limit_pushdown_multiple_inputs() -> Result<()> { let table_scan = test_table_scan()?; - let noop_plan = LogicalPlan::Extension(Extension { + let noop_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone(), table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -491,7 +493,7 @@ mod test { #[test] fn limit_pushdown_disallowed_noop_plan() -> Result<()> { let table_scan = test_table_scan()?; - 
let no_limit_noop_plan = LogicalPlan::Extension(Extension { + let no_limit_noop_plan = LogicalPlan::extension(Extension { node: Arc::new(NoLimitNoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index f3e1673e72111..f2500a09104bb 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -77,7 +77,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Distinct(Distinct::All(input)) => { + LogicalPlan::Distinct(Distinct::All(input), _) => { let group_expr = expand_wildcard(input.schema(), &input, None)?; let field_count = input.schema().fields().len(); @@ -95,20 +95,23 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { } // Replace with aggregation: - let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new( + let aggr_plan = LogicalPlan::aggregate(Aggregate::try_new( input, group_expr, vec![], )?); Ok(Transformed::yes(aggr_plan)) } - LogicalPlan::Distinct(Distinct::On(DistinctOn { - select_expr, - on_expr, - sort_expr, - input, - schema, - })) => { + LogicalPlan::Distinct( + Distinct::On(DistinctOn { + select_expr, + on_expr, + sort_expr, + input, + schema, + }), + _, + ) => { let expr_cnt = on_expr.len(); // Construct the aggregation expression to be used to fetch the selected expressions. 
@@ -131,7 +134,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { let group_expr = normalize_cols(on_expr, input.as_ref())?; // Build the aggregation plan - let plan = LogicalPlan::Aggregate(Aggregate::try_new( + let plan = LogicalPlan::aggregate(Aggregate::try_new( input, group_expr, aggr_expr, )?); // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 2e2c8fb1d6f8c..8ea87ff95a09c 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -18,6 +18,7 @@ //! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s use std::collections::{BTreeSet, HashMap}; +use std::ops::Not; use std::sync::Arc; use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; @@ -79,11 +80,11 @@ impl OptimizerRule for ScalarSubqueryToJoin { config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { // Optimization: skip the rest of the rule and its copies if // there are no scalar subqueries if !contains_scalar_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( @@ -119,7 +120,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { cur_input = optimized_subquery; } else { // if we can't handle all of the subqueries then bail for now - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } } let new_plan = LogicalPlanBuilder::from(cur_input) @@ -127,11 +128,11 @@ impl OptimizerRule for ScalarSubqueryToJoin { .build()?; Ok(Transformed::yes(new_plan)) } - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { // 
Optimization: skip the rest of the rule and its copies if // there are no scalar subqueries if !projection.expr.iter().any(contains_scalar_subquery) { - return Ok(Transformed::no(LogicalPlan::Projection(projection))); + return Ok(Transformed::no(LogicalPlan::projection(projection))); } let mut all_subqueryies = vec![]; @@ -182,7 +183,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } else { // if we can't handle all of the subqueries then bail for now - return Ok(Transformed::no(LogicalPlan::Projection(projection))); + return Ok(Transformed::no(LogicalPlan::projection(projection))); } } @@ -219,7 +220,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { /// Returns true if the expression has a scalar subquery somewhere in it /// false otherwise fn contains_scalar_subquery(expr: &Expr) -> bool { - expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) + expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_, _)))) .expect("Inner is always Ok") } @@ -233,7 +234,7 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { fn f_down(&mut self, expr: Expr) -> Result> { match expr { - Expr::ScalarSubquery(subquery) => { + Expr::ScalarSubquery(subquery, _) => { let subqry_alias = self.alias_gen.next("__scalar_sq"); self.sub_query_info .push((subquery.clone(), subqry_alias.clone())); @@ -345,17 +346,20 @@ fn build_join( if let Some(expr_map) = collected_count_expr_map { for (name, result) in expr_map { let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr { - Expr::Case(expr::Case { + Expr::case(expr::Case { expr: None, when_then_expr: vec![ ( - Box::new(Expr::IsNull(Box::new(Expr::Column( - Column::new_unqualified(UN_MATCHED_ROW_INDICATOR), - )))), + Box::new( + Expr::Column(Column::new_unqualified( + UN_MATCHED_ROW_INDICATOR, + )) + .is_null(), + ), Box::new(result), ), ( - Box::new(Expr::Not(Box::new(filter.clone()))), + Box::new(filter.clone().not()), Box::new(Expr::Literal(ScalarValue::Null)), ), ], @@ -364,12 +368,15 @@ fn build_join( 
)))), }) } else { - Expr::Case(expr::Case { + Expr::case(expr::Case { expr: None, when_then_expr: vec![( - Box::new(Expr::IsNull(Box::new(Expr::Column( - Column::new_unqualified(UN_MATCHED_ROW_INDICATOR), - )))), + Box::new( + Expr::Column(Column::new_unqualified( + UN_MATCHED_ROW_INDICATOR, + )) + .is_null(), + ), Box::new(result), )], else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( @@ -1038,7 +1045,7 @@ mod tests { .build()?, ); - let between_expr = Expr::Between(Between { + let between_expr = Expr::_between(Between { expr: Box::new(col("customer.c_custkey")), negated: false, low: Box::new(scalar_subquery(sq1)), @@ -1087,7 +1094,7 @@ mod tests { .build()?, ); - let between_expr = Expr::Between(Between { + let between_expr = Expr::_between(Between { expr: Box::new(col("customer.c_custkey")), negated: false, low: Box::new(scalar_subquery(sq1)), diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 6564e722eaf89..5170bf1bf29db 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -422,7 +422,7 @@ impl TreeNodeRewriter for Canonicalizer { type Node = Expr; fn f_up(&mut self, expr: Expr) -> Result> { - let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else { + let Expr::BinaryExpr(BinaryExpr { left, op, right }, _) = expr else { return Ok(Transformed::no(expr)); }; match (left.as_ref(), right.as_ref(), op.swap()) { @@ -430,7 +430,7 @@ impl TreeNodeRewriter for Canonicalizer { (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op)) if right_col > left_col => { - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + Ok(Transformed::yes(Expr::binary_expr(BinaryExpr { left: right, op: swapped_op, right: left, @@ -438,13 +438,13 @@ impl TreeNodeRewriter for Canonicalizer { } // (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => { - 
Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + Ok(Transformed::yes(Expr::binary_expr(BinaryExpr { left: right, op: swapped_op, right: left, }))) } - _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + _ => Ok(Transformed::no(Expr::binary_expr(BinaryExpr { left, op, right, @@ -593,32 +593,32 @@ impl<'a> ConstEvaluator<'a> { | Expr::Column(_) | Expr::OuterReferenceColumn(_, _) | Expr::Exists { .. } - | Expr::InSubquery(_) - | Expr::ScalarSubquery(_) + | Expr::InSubquery(_, _) + | Expr::ScalarSubquery(_, _) | Expr::WindowFunction { .. } - | Expr::GroupingSet(_) + | Expr::GroupingSet(_, _) | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, - Expr::ScalarFunction(ScalarFunction { func, .. }) => { + Expr::ScalarFunction(ScalarFunction { func, .. }, _) => { Self::volatility_ok(func.signature().volatility) } Expr::Literal(_) - | Expr::Unnest(_) + | Expr::Unnest(_, _) | Expr::BinaryExpr { .. } - | Expr::Not(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) - | Expr::Negative(_) + | Expr::Not(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) + | Expr::Negative(_, _) | Expr::Between { .. } | Expr::Like { .. } | Expr::SimilarTo { .. } - | Expr::Case(_) + | Expr::Case(_, _) | Expr::Cast { .. } | Expr::TryCast { .. } | Expr::InList { .. } => true, @@ -730,28 +730,34 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // true = A --> A // false = A --> !A // null = A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Eq, - right, - }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Eq, + right, + }, + _, + ) if is_bool_lit(&left) && info.is_boolean_type(&right)? 
=> { Transformed::yes(match as_bool_lit(&left)? { Some(true) => *right, - Some(false) => Expr::Not(right), + Some(false) => Expr::_not(right), None => lit_bool_null(), }) } // A = true --> A // A = false --> !A // A = null --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Eq, - right, - }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Eq, + right, + }, + _, + ) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { Transformed::yes(match as_bool_lit(&right)? { Some(true) => *left, - Some(false) => Expr::Not(left), + Some(false) => Expr::_not(left), None => lit_bool_null(), }) } @@ -761,13 +767,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // true != A --> !A // false != A --> A // null != A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: NotEq, - right, - }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: NotEq, + right, + }, + _, + ) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { Transformed::yes(match as_bool_lit(&left)? { - Some(true) => Expr::Not(right), + Some(true) => Expr::_not(right), Some(false) => *right, None => lit_bool_null(), }) @@ -775,13 +784,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // A != true --> !A // A != false --> A // A != null --> null, - Expr::BinaryExpr(BinaryExpr { - left, - op: NotEq, - right, - }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: NotEq, + right, + }, + _, + ) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { Transformed::yes(match as_bool_lit(&right)? 
{ - Some(true) => Expr::Not(left), + Some(true) => Expr::_not(left), Some(false) => *left, None => lit_bool_null(), }) @@ -792,76 +804,109 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // true OR A --> true (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right: _, - }) if is_true(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right: _, + }, + _, + ) if is_true(&left) => Transformed::yes(*left), // false OR A --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_false(&left) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_false(&left) => Transformed::yes(*right), // A OR true --> true (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Or, - right, - }) if is_true(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: Or, + right, + }, + _, + ) if is_true(&right) => Transformed::yes(*right), // A OR false --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_false(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_false(&right) => Transformed::yes(*left), // A OR !A ---> true (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_not_of(&right, &left) && !info.nullable(&left)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_not_of(&right, &left) && !info.nullable(&left)? => { Transformed::yes(lit(true)) } // !A OR A ---> true (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_not_of(&left, &right) && !info.nullable(&right)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_not_of(&left, &right) && !info.nullable(&right)? => { Transformed::yes(lit(true)) } // (..A..) OR A --> (..A..) 
- Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if expr_contains(&left, &right, Or) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if expr_contains(&left, &right, Or) => Transformed::yes(*left), // A OR (..A..) --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if expr_contains(&right, &left, Or) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if expr_contains(&right, &left, Or) => Transformed::yes(*right), // A OR (A AND B) --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_op_with(And, &right, &left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_op_with(And, &right, &left) => Transformed::yes(*left), // (A AND B) OR A --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_op_with(And, &left, &right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_op_with(And, &left, &right) => Transformed::yes(*right), // Eliminate common factors in conjunctions e.g // (A AND B) OR (A AND C) -> A AND (B OR C) - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if has_common_conjunction(&left, &right) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if has_common_conjunction(&left, &right) => { let lhs: IndexSet = iter_conjunction_owned(*left).collect(); let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right) .partition(|e| lhs.contains(e) && !e.is_volatile()); @@ -882,116 +927,164 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // true AND A --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_true(&left) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_true(&left) => 
Transformed::yes(*right), // false AND A --> false (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right: _, - }) if is_false(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right: _, + }, + _, + ) if is_false(&left) => Transformed::yes(*left), // A AND true --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_true(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_true(&right) => Transformed::yes(*left), // A AND false --> false (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left: _, - op: And, - right, - }) if is_false(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: And, + right, + }, + _, + ) if is_false(&right) => Transformed::yes(*right), // A AND !A ---> false (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_not_of(&right, &left) && !info.nullable(&left)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_not_of(&right, &left) && !info.nullable(&left)? => { Transformed::yes(lit(false)) } // !A AND A ---> false (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_not_of(&left, &right) && !info.nullable(&right)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_not_of(&left, &right) && !info.nullable(&right)? => { Transformed::yes(lit(false)) } // (..A..) AND A --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if expr_contains(&left, &right, And) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if expr_contains(&left, &right, And) => Transformed::yes(*left), // A AND (..A..) --> (..A..) 
- Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if expr_contains(&right, &left, And) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if expr_contains(&right, &left, And) => Transformed::yes(*right), // A AND (A OR B) --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_op_with(Or, &right, &left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_op_with(Or, &right, &left) => Transformed::yes(*left), // (A OR B) AND A --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_op_with(Or, &left, &right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_op_with(Or, &left, &right) => Transformed::yes(*right), // // Rules for Multiply // // A * 1 --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Multiply, - right, - }) if is_one(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Multiply, + right, + }, + _, + ) if is_one(&right) => Transformed::yes(*left), // 1 * A --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Multiply, - right, - }) if is_one(&left) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Multiply, + right, + }, + _, + ) if is_one(&left) => Transformed::yes(*right), // A * null --> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Multiply, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: Multiply, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null * A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Multiply, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Multiply, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A * 0 --> 0 (if A is not null and not floating, 
since NAN * 0 -> NAN) - Expr::BinaryExpr(BinaryExpr { - left, - op: Multiply, - right, - }) if !info.nullable(&left)? + Expr::BinaryExpr( + BinaryExpr { + left, + op: Multiply, + right, + }, + _, + ) if !info.nullable(&left)? && !info.get_data_type(&left)?.is_floating() && is_zero(&right) => { Transformed::yes(*right) } // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN) - Expr::BinaryExpr(BinaryExpr { - left, - op: Multiply, - right, - }) if !info.nullable(&right)? + Expr::BinaryExpr( + BinaryExpr { + left, + op: Multiply, + right, + }, + _, + ) if !info.nullable(&right)? && !info.get_data_type(&right)?.is_floating() && is_zero(&left) => { @@ -1003,46 +1096,64 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // A / 1 --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Divide, - right, - }) if is_one(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Divide, + right, + }, + _, + ) if is_one(&right) => Transformed::yes(*left), // null / A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Divide, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Divide, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A / null --> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Divide, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: Divide, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // // Rules for Modulo // // A % null --> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Modulo, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: Modulo, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null % A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Modulo, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + 
Expr::BinaryExpr( + BinaryExpr { + left, + op: Modulo, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN) - Expr::BinaryExpr(BinaryExpr { - left, - op: Modulo, - right, - }) if !info.nullable(&left)? + Expr::BinaryExpr( + BinaryExpr { + left, + op: Modulo, + right, + }, + _, + ) if !info.nullable(&left)? && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { @@ -1056,84 +1167,114 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // A & null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseAnd, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: BitwiseAnd, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null & A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A & 0 -> 0 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right), // 0 & A -> 0 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left), // !A & A -> 0 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if is_negative_of(&left, &right) && !info.nullable(&right)? 
=> { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if is_negative_of(&left, &right) && !info.nullable(&right)? => { Transformed::yes(Expr::Literal(ScalarValue::new_zero( &info.get_data_type(&left)?, )?)) } // A & !A -> 0 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if is_negative_of(&right, &left) && !info.nullable(&left)? => { Transformed::yes(Expr::Literal(ScalarValue::new_zero( &info.get_data_type(&left)?, )?)) } // (..A..) & A --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left), // A & (..A..) --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right), // A & (A | B) --> A (if B not null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => { Transformed::yes(*left) } // (A | B) & A --> A (if B not null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if !info.nullable(&left)? 
&& is_op_with(BitwiseOr, &left, &right) => { Transformed::yes(*right) } @@ -1142,84 +1283,114 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // A | null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseOr, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: BitwiseOr, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null | A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A | 0 -> A (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if is_zero(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if is_zero(&right) => Transformed::yes(*left), // 0 | A -> A (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if is_zero(&left) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if is_zero(&left) => Transformed::yes(*right), // !A | A -> -1 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if is_negative_of(&left, &right) && !info.nullable(&right)? => { Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } // A | !A -> -1 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if is_negative_of(&right, &left) && !info.nullable(&left)? 
=> { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if is_negative_of(&right, &left) && !info.nullable(&left)? => { Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } // (..A..) | A --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left), // A | (..A..) --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right), // A | (A & B) --> A (if B not null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => { Transformed::yes(*left) } // (A & B) | A --> A (if B not null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if !info.nullable(&left)? 
&& is_op_with(BitwiseAnd, &left, &right) => { Transformed::yes(*right) } @@ -1228,61 +1399,82 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // A ^ null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseXor, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: BitwiseXor, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null ^ A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A ^ 0 -> A (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left), // 0 ^ A -> A (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right), // !A ^ A -> -1 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if is_negative_of(&left, &right) && !info.nullable(&right)? => { Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } // A ^ !A -> -1 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if is_negative_of(&right, &left) && !info.nullable(&left)? 
=> { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if is_negative_of(&right, &left) && !info.nullable(&left)? => { Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if expr_contains(&left, &right, BitwiseXor) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); Transformed::yes(if expr == *right { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) @@ -1292,11 +1484,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if expr_contains(&right, &left, BitwiseXor) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); Transformed::yes(if expr == *left { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) 
@@ -1310,60 +1505,78 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // A >> null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseShiftRight, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: BitwiseShiftRight, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null >> A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseShiftRight, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseShiftRight, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A >> 0 -> A (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseShiftRight, - right, - }) if is_zero(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseShiftRight, + right, + }, + _, + ) if is_zero(&right) => Transformed::yes(*left), // // Rules for BitwiseShiftRight // // A << null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseShiftLeft, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: BitwiseShiftLeft, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null << A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseShiftLeft, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseShiftLeft, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A << 0 -> A (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseShiftLeft, - right, - }) if is_zero(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseShiftLeft, + right, + }, + _, + ) if is_zero(&right) => Transformed::yes(*left), // // Rules for Not // - Expr::Not(inner) => Transformed::yes(negate_clause(*inner)), + Expr::Not(inner, _) => 
Transformed::yes(negate_clause(*inner)), // // Rules for Negative // - Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)), + Expr::Negative(inner, _) => Transformed::yes(distribute_negation(*inner)), // // Rules for Case @@ -1380,11 +1593,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // Note: the rationale for this rewrite is that the expr can then be further // simplified using the existing rules for AND/OR - Expr::Case(Case { - expr: None, - when_then_expr, - else_expr, - }) if !when_then_expr.is_empty() + Expr::Case( + Case { + expr: None, + when_then_expr, + else_expr, + }, + _, + ) if !when_then_expr.is_empty() && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number && info.is_boolean_type(&when_then_expr[0].1)? => { @@ -1412,10 +1628,10 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // Do a first pass at simplification out_expr.rewrite(self)? } - Expr::ScalarFunction(ScalarFunction { func: udf, args }) => { + Expr::ScalarFunction(ScalarFunction { func: udf, args }, _) => { match udf.simplify(args, info)? { ExprSimplifyResult::Original(args) => { - Transformed::no(Expr::ScalarFunction(ScalarFunction { + Transformed::no(Expr::scalar_function(ScalarFunction { func: udf, args, })) @@ -1424,21 +1640,24 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction { - ref func, - .. - }) => match (func.simplify(), expr) { - (Some(simplify_function), Expr::AggregateFunction(af)) => { + Expr::AggregateFunction( + datafusion_expr::expr::AggregateFunction { ref func, .. }, + _, + ) => match (func.simplify(), expr) { + (Some(simplify_function), Expr::AggregateFunction(af, _)) => { Transformed::yes(simplify_function(af, info)?) } (_, expr) => Transformed::no(expr), }, - Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::WindowUDF(ref udwf), - .. 
- }) => match (udwf.simplify(), expr) { - (Some(simplify_function), Expr::WindowFunction(wf)) => { + Expr::WindowFunction( + WindowFunction { + fun: WindowFunctionDefinition::WindowUDF(ref udwf), + .. + }, + _, + ) => match (udwf.simplify(), expr) { + (Some(simplify_function), Expr::WindowFunction(wf, _)) => { Transformed::yes(simplify_function(wf, info)?) } (_, expr) => Transformed::no(expr), @@ -1450,7 +1669,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // a between 3 and 5 --> a >= 3 AND a <=5 // a not between 3 and 5 --> a < 3 OR a > 5 - Expr::Between(between) => Transformed::yes(if between.negated { + Expr::Between(between, _) => Transformed::yes(if between.negated { let l = *between.expr.clone(); let r = *between.expr; or(l.lt(*between.low), r.gt(*between.high)) @@ -1464,14 +1683,17 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // Rules for regexes // - Expr::BinaryExpr(BinaryExpr { - left, - op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), - right, - }) => Transformed::yes(simplify_regex_expr(left, op, right)?), + Expr::BinaryExpr( + BinaryExpr { + left, + op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), + right, + }, + _, + ) => Transformed::yes(simplify_regex_expr(left, op, right)?), // Rules for Like - Expr::Like(like) => { + Expr::Like(like, _) => { // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291 let escape_char = like.escape_char.unwrap_or('\\'); match as_string_scalar(&like.pattern) { @@ -1489,8 +1711,10 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Transformed::yes(if !info.nullable(&like.expr)? 
{ result_for_non_null } else { - Expr::Case(Case { - expr: Some(Box::new(Expr::IsNotNull(like.expr))), + Expr::case(Case { + expr: Some(Box::new(Expr::_is_not_null( + like.expr, + ))), when_then_expr: vec![( Box::new(lit(true)), Box::new(result_for_non_null), @@ -1509,7 +1733,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { .unwrap() .replace_all(pattern_str, "%") .to_string(); - Transformed::yes(Expr::Like(Like { + Transformed::yes(Expr::_like(Like { pattern: Box::new(to_string_scalar( data_type, Some(simplified_pattern), @@ -1523,63 +1747,74 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { { // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression // TODO: handle escape characters - Transformed::yes(Expr::BinaryExpr(BinaryExpr { + Transformed::yes(Expr::binary_expr(BinaryExpr { left: like.expr.clone(), op: if like.negated { NotEq } else { Eq }, right: like.pattern.clone(), })) } - Some(_pattern_str) => Transformed::no(Expr::Like(like)), + Some(_pattern_str) => Transformed::no(Expr::_like(like)), } } - None => Transformed::no(Expr::Like(like)), + None => Transformed::no(Expr::_like(like)), } } // a is not null/unknown --> true (if a is not nullable) - Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr) + Expr::IsNotNull(expr, _) | Expr::IsNotUnknown(expr, _) if !info.nullable(&expr)? => { Transformed::yes(lit(true)) } // a is null/unknown --> false (if a is not nullable) - Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => { + Expr::IsNull(expr, _) | Expr::IsUnknown(expr, _) + if !info.nullable(&expr)? 
=> + { Transformed::yes(lit(false)) } // expr IN () --> false // expr NOT IN () --> true - Expr::InList(InList { - expr, - list, - negated, - }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { Transformed::yes(lit(negated)) } // null in (x, y, z) --> null // null not in (x, y, z) --> null - Expr::InList(InList { - expr, - list: _, - negated: _, - }) if is_null(expr.as_ref()) => Transformed::yes(lit_bool_null()), + Expr::InList( + InList { + expr, + list: _, + negated: _, + }, + _, + ) if is_null(expr.as_ref()) => Transformed::yes(lit_bool_null()), // expr IN ((subquery)) -> expr IN (subquery), see ##5529 - Expr::InList(InList { - expr, - mut list, - negated, - }) if list.len() == 1 + Expr::InList( + InList { + expr, + mut list, + negated, + }, + _, + ) if list.len() == 1 && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) => { - let Expr::ScalarSubquery(subquery) = list.remove(0) else { + let Expr::ScalarSubquery(subquery, _) = list.remove(0) else { unreachable!() }; - Transformed::yes(Expr::InSubquery(InSubquery::new( + Transformed::yes(Expr::in_subquery(InSubquery::new( expr, subquery, negated, ))) } @@ -1587,11 +1822,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // Combine multiple OR expressions into a single IN list expression if possible // // i.e. 
`a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)` - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { let lhs = to_inlist(*left).unwrap(); let rhs = to_inlist(*right).unwrap(); let mut seen: HashSet = HashSet::new(); @@ -1608,7 +1846,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { negated: false, }; - Transformed::yes(Expr::InList(merged_inlist)) + Transformed::yes(Expr::_in_list(merged_inlist)) } // Simplify expressions that is guaranteed to be true or false to a literal boolean expression @@ -1627,11 +1865,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // 6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false` // 7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5` // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if are_inlist_and_eq_and_match_neg( + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if are_inlist_and_eq_and_match_neg( left.as_ref(), right.as_ref(), false, @@ -1639,7 +1880,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ) => { match (*left, *right) { - (Expr::InList(l1), Expr::InList(l2)) => { + (Expr::InList(l1, _), Expr::InList(l2, _)) => { return inlist_intersection(l1, &l2, false).map(Transformed::yes); } // Matched previously once @@ -1647,11 +1888,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if are_inlist_and_eq_and_match_neg( + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if are_inlist_and_eq_and_match_neg( left.as_ref(), right.as_ref(), true, @@ -1659,7 +1903,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> 
{ ) => { match (*left, *right) { - (Expr::InList(l1), Expr::InList(l2)) => { + (Expr::InList(l1, _), Expr::InList(l2, _)) => { return inlist_union(l1, l2, true).map(Transformed::yes); } // Matched previously once @@ -1667,11 +1911,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if are_inlist_and_eq_and_match_neg( + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if are_inlist_and_eq_and_match_neg( left.as_ref(), right.as_ref(), false, @@ -1679,7 +1926,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ) => { match (*left, *right) { - (Expr::InList(l1), Expr::InList(l2)) => { + (Expr::InList(l1, _), Expr::InList(l2, _)) => { return inlist_except(l1, &l2).map(Transformed::yes); } // Matched previously once @@ -1687,11 +1934,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if are_inlist_and_eq_and_match_neg( + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if are_inlist_and_eq_and_match_neg( left.as_ref(), right.as_ref(), true, @@ -1699,7 +1949,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ) => { match (*left, *right) { - (Expr::InList(l1), Expr::InList(l2)) => { + (Expr::InList(l1, _), Expr::InList(l2, _)) => { return inlist_except(l2, &l1).map(Transformed::yes); } // Matched previously once @@ -1707,11 +1957,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if are_inlist_and_eq_and_match_neg( + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if are_inlist_and_eq_and_match_neg( left.as_ref(), right.as_ref(), true, @@ -1719,7 +1972,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ) => { match (*left, *right) { - (Expr::InList(l1), 
Expr::InList(l2)) => { + (Expr::InList(l1, _), Expr::InList(l2, _)) => { return inlist_intersection(l1, &l2, true).map(Transformed::yes); } // Matched previously once @@ -1764,7 +2017,7 @@ fn are_inlist_and_eq_and_match_neg( is_right_neg: bool, ) -> bool { match (left, right) { - (Expr::InList(l), Expr::InList(r)) => { + (Expr::InList(l, _), Expr::InList(r, _)) => { l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg } _ => false, @@ -1789,8 +2042,8 @@ fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { /// Try to convert an expression to an in-list expression fn as_inlist(expr: &Expr) -> Option> { match expr { - Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), - Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { + Expr::InList(inlist, _) => Some(Cow::Borrowed(inlist)), + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { (Expr::Column(_), Expr::Literal(_)) => Some(Cow::Owned(InList { expr: left.clone(), @@ -1811,12 +2064,15 @@ fn as_inlist(expr: &Expr) -> Option> { fn to_inlist(expr: Expr) -> Option { match expr { - Expr::InList(inlist) => Some(inlist), - Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) => match (left.as_ref(), right.as_ref()) { + Expr::InList(inlist, _) => Some(inlist), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + ) => match (left.as_ref(), right.as_ref()) { (Expr::Column(_), Expr::Literal(_)) => Some(InList { expr: left, list: vec![*right], @@ -1848,7 +2104,7 @@ fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { l1.list.extend(keep_l2); l1.negated = negated; - Ok(Expr::InList(l1)) + Ok(Expr::_in_list(l1)) } /// Return the intersection of two inlist expressions @@ -1864,7 +2120,7 @@ fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result Result { if l1.list.is_empty() { return Ok(lit(false)); } - Ok(Expr::InList(l1)) + 
Ok(Expr::_in_list(l1)) } #[cfg(test)] @@ -3025,7 +3281,7 @@ mod tests { } fn regex_match(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(left), op: Operator::RegexMatch, right: Box::new(right), @@ -3033,7 +3289,7 @@ mod tests { } fn regex_not_match(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(left), op: Operator::RegexNotMatch, right: Box::new(right), @@ -3041,7 +3297,7 @@ mod tests { } fn regex_imatch(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(left), op: Operator::RegexIMatch, right: Box::new(right), @@ -3049,7 +3305,7 @@ mod tests { } fn regex_not_imatch(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(left), op: Operator::RegexNotIMatch, right: Box::new(right), @@ -3151,13 +3407,13 @@ mod tests { #[test] fn simplify_expr_is_not_null() { assert_eq!( - simplify(Expr::IsNotNull(Box::new(col("c1")))), - Expr::IsNotNull(Box::new(col("c1"))) + simplify(Expr::_is_not_null(Box::new(col("c1")))), + Expr::_is_not_null(Box::new(col("c1"))) ); // 'c1_non_null IS NOT NULL' is always true assert_eq!( - simplify(Expr::IsNotNull(Box::new(col("c1_non_null")))), + simplify(Expr::_is_not_null(Box::new(col("c1_non_null")))), lit(true) ); } @@ -3165,13 +3421,13 @@ mod tests { #[test] fn simplify_expr_is_null() { assert_eq!( - simplify(Expr::IsNull(Box::new(col("c1")))), - Expr::IsNull(Box::new(col("c1"))) + simplify(Expr::_is_null(Box::new(col("c1")))), + Expr::_is_null(Box::new(col("c1"))) ); // 'c1_non_null IS NULL' is always false assert_eq!( - simplify(Expr::IsNull(Box::new(col("c1_non_null")))), + simplify(Expr::_is_null(Box::new(col("c1_non_null")))), lit(false) ); } @@ -3269,7 +3525,7 @@ mod tests { // --> // false assert_eq!( - simplify(Expr::Case(Case::new( + simplify(Expr::case(Case::new( None, vec![( 
Box::new(col("c2").not_eq(lit(false))), @@ -3289,7 +3545,7 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case(Case::new( + simplify(simplify(Expr::case(Case::new( None, vec![( Box::new(col("c2").not_eq(lit(false))), @@ -3307,7 +3563,7 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case(Case::new( + simplify(simplify(Expr::case(Case::new( None, vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)], Some(Box::new(col("c2"))), @@ -3325,7 +3581,7 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case(Case::new( + simplify(simplify(Expr::case(Case::new( None, vec![ (Box::new(col("c1")), Box::new(lit(true)),), @@ -3344,7 +3600,7 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case(Case::new( + simplify(simplify(Expr::case(Case::new( None, vec![ (Box::new(col("c1")), Box::new(lit(true)),), @@ -3963,7 +4219,7 @@ mod tests { fn test_simplify_udaf() { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify()); let aggregate_function_expr = - Expr::AggregateFunction(expr::AggregateFunction::new_udf( + Expr::aggregate_function(expr::AggregateFunction::new_udf( udaf.into(), vec![], false, @@ -3977,7 +4233,7 @@ mod tests { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify()); let aggregate_function_expr = - Expr::AggregateFunction(expr::AggregateFunction::new_udf( + Expr::aggregate_function(expr::AggregateFunction::new_udf( udaf.into(), vec![], false, @@ -4058,7 +4314,7 @@ mod tests { WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + 
Expr::window_function(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -4067,7 +4323,7 @@ mod tests { WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + Expr::window_function(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); @@ -4156,7 +4412,7 @@ mod tests { #[test] fn test_optimize_volatile_conditions() { let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new())); - let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![])); + let rand = Expr::scalar_function(ScalarFunction::new_udf(fun, vec![])); { let expr = rand .clone() @@ -4191,7 +4447,7 @@ mod tests { } fn if_not_null(expr: Expr, then: bool) -> Expr { - Expr::Case(Case { + Expr::case(Case { expr: Some(expr.is_not_null().into()), when_then_expr: vec![(lit(true).into(), lit(then).into())], else_expr: None, diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index afcbe528083b8..2c2e5686ebeeb 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -66,24 +66,27 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } match &expr { - Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { + Expr::IsNull(inner, _) => match self.guarantees.get(inner.as_ref()) { Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))), Some(NullableInterval::NotNull { .. }) => { Ok(Transformed::yes(lit(false))) } _ => Ok(Transformed::no(expr)), }, - Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { + Expr::IsNotNull(inner, _) => match self.guarantees.get(inner.as_ref()) { Some(NullableInterval::Null { .. 
}) => Ok(Transformed::yes(lit(false))), Some(NullableInterval::NotNull { .. }) => Ok(Transformed::yes(lit(true))), _ => Ok(Transformed::no(expr)), }, - Expr::Between(Between { - expr: inner, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr: inner, + negated, + low, + high, + }, + _, + ) => { if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = ( self.guarantees.get(inner.as_ref()), low.as_ref(), @@ -107,7 +110,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { // The left or right side of expression might either have a guarantee // or be a literal. Either way, we can resolve them to a NullableInterval. let left_interval = self @@ -158,11 +161,14 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } } - Expr::InList(InList { - expr: inner, - list, - negated, - }) => { + Expr::InList( + InList { + expr: inner, + list, + negated, + }, + _, + ) => { if let Some(interval) = self.guarantees.get(inner.as_ref()) { // Can remove items from the list that don't match the guarantee let new_list: Vec = list @@ -184,7 +190,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { }) .collect::>()?; - Ok(Transformed::yes(Expr::InList(InList { + Ok(Transformed::yes(Expr::_in_list(InList { expr: inner.clone(), list: new_list, negated: *negated, @@ -301,7 +307,7 @@ mod tests { true, ), ( - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit(ScalarValue::Null)), @@ -309,7 +315,7 @@ mod tests { true, ), ( - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit(ScalarValue::Date32(Some(17000)))), @@ -360,7 +366,7 @@ mod tests { // (original_expr, expected_simplification) let simplified_cases = &[ ( - Expr::BinaryExpr(BinaryExpr { + 
Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit("z")), @@ -368,7 +374,7 @@ mod tests { true, ), ( - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsNotDistinctFrom, right: Box::new(lit("z")), @@ -388,7 +394,7 @@ mod tests { col("x").not_eq(lit("a")), col("x").between(lit("a"), lit("z")), col("x").not_between(lit("a"), lit("z")), - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit(ScalarValue::Null)), @@ -467,7 +473,7 @@ mod tests { .collect(); assert_eq!( output, - Expr::InList(InList { + Expr::_in_list(InList { expr: Box::new(col(*column_name)), list: expected_list, negated: *negated, diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index c8638eb723955..4f6a7832532ce 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -38,11 +38,14 @@ impl TreeNodeRewriter for ShortenInListSimplifier { fn f_up(&mut self, expr: Expr) -> Result> { // if expr is a single column reference: // expr IN (A, B, ...) 
--> (expr = A) OR (expr = B) OR (expr = C) - if let Expr::InList(InList { - expr, - list, - negated, - }) = expr.clone() + if let Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) = expr.clone() { if !list.is_empty() && ( diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index 6c99f18ab0f64..52be76a1c972d 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -67,7 +67,7 @@ pub fn simplify_regex_expr( } // Leave untouched if optimization didn't work - Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })) + Ok(Expr::binary_expr(BinaryExpr { left, op, right })) } #[derive(Debug)] @@ -105,7 +105,7 @@ impl OperatorMode { case_insensitive: self.i, }; - Expr::Like(like) + Expr::_like(like) } /// Creates an [`Expr::BinaryExpr`] of "`left` = `right`" or "`left` != `right`". @@ -115,7 +115,7 @@ impl OperatorMode { } else { Operator::Eq }; - Expr::BinaryExpr(BinaryExpr { left, op, right }) + Expr::binary_expr(BinaryExpr { left, op, right }) } } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 200f1f159d813..c1bc086661f29 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -81,7 +81,7 @@ impl SimplifyExpressions { ) -> Result> { let schema = if !plan.inputs().is_empty() { DFSchemaRef::new(merge_schema(&plan.inputs())) - } else if let LogicalPlan::TableScan(scan) = &plan { + } else if let LogicalPlan::TableScan(scan, _) = &plan { // When predicates are pushed into a table scan, there is no input // schema to resolve predicates against, so it must be handled specially // @@ -114,7 +114,7 @@ impl SimplifyExpressions { // This is likely related to the fact that order of the columns must // match the order of the children. 
see // https://github.com/apache/datafusion/pull/8780 for more details - let simplifier = if let LogicalPlan::Join(_) = plan { + let simplifier = if let LogicalPlan::Join(_, _) = plan { simplifier.with_canonicalize(false) } else { simplifier @@ -406,12 +406,12 @@ mod tests { #[test] fn test_simplify_optimized_plan_support_values() -> Result<()> { - let expr1 = Expr::BinaryExpr(BinaryExpr::new( + let expr1 = Expr::binary_expr(BinaryExpr::new( Box::new(lit(1)), Operator::Plus, Box::new(lit(2)), )); - let expr2 = Expr::BinaryExpr(BinaryExpr::new( + let expr2 = Expr::binary_expr(BinaryExpr::new( Box::new(lit(2)), Operator::Minus, Box::new(lit(1)), @@ -439,7 +439,7 @@ mod tests { #[test] fn cast_expr() -> Result<()> { let table_scan = test_table_scan(); - let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), DataType::Int32))]; + let proj = vec![Expr::cast(Cast::new(Box::new(lit("0")), DataType::Int32))]; let plan = LogicalPlanBuilder::from(table_scan) .project(proj)? .build()?; diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index c30c3631c193a..15927aaab468c 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -23,6 +23,7 @@ use datafusion_expr::{ expr_fn::{and, bitwise_and, bitwise_or, or}, Expr, Like, Operator, }; +use std::ops::Not; pub static POWS_OF_TEN: [i128; 38] = [ 1, @@ -69,7 +70,7 @@ pub static POWS_OF_TEN: [i128; 38] = [ /// expressions. 
Such as: (A AND B) AND C fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if *op == search_op => { expr_contains_inner(left, needle, search_op) || expr_contains_inner(right, needle, search_op) } @@ -92,7 +93,7 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> xor_counter: &mut i32, ) -> Expr { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if *op == Operator::BitwiseXor => { let left_expr = recursive_delete_xor_in_expr(left, needle, xor_counter); @@ -105,7 +106,7 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> return left_expr; } - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left_expr), *op, Box::new(right_expr), @@ -121,13 +122,13 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> return needle.clone(); } else if xor_counter % 2 == 0 { if is_left { - return Expr::BinaryExpr(BinaryExpr::new( + return Expr::binary_expr(BinaryExpr::new( Box::new(needle.clone()), Operator::BitwiseXor, Box::new(result_expr), )); } else { - return Expr::BinaryExpr(BinaryExpr::new( + return Expr::binary_expr(BinaryExpr::new( Box::new(result_expr), Operator::BitwiseXor, Box::new(needle.clone()), @@ -211,17 +212,17 @@ pub fn is_false(expr: &Expr) -> bool { /// returns true if `haystack` looks like (needle OP X) or (X OP needle) pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { - matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()) && !needle.is_volatile()) + matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if op == &target_op && (needle == left.as_ref() || needle == 
right.as_ref()) && !needle.is_volatile()) } /// returns true if `not_expr` is !`expr` (not) pub fn is_not_of(not_expr: &Expr, expr: &Expr) -> bool { - matches!(not_expr, Expr::Not(inner) if expr == inner.as_ref()) + matches!(not_expr, Expr::Not(inner, _) if expr == inner.as_ref()) } /// returns true if `not_expr` is !`expr` (bitwise not) pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool { - matches!(not_expr, Expr::Negative(inner) if expr == inner.as_ref()) + matches!(not_expr, Expr::Negative(inner, _) if expr == inner.as_ref()) } /// returns the contained boolean value in `expr` as @@ -249,9 +250,9 @@ pub fn as_bool_lit(expr: &Expr) -> Result> { /// For others, use Not clause pub fn negate_clause(expr: Expr) -> Expr { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { if let Some(negated_op) = op.negate() { - return Expr::BinaryExpr(BinaryExpr::new(left, negated_op, right)); + return Expr::binary_expr(BinaryExpr::new(left, negated_op, right)); } match op { // not (A and B) ===> (not A) or (not B) @@ -269,34 +270,35 @@ pub fn negate_clause(expr: Expr) -> Expr { and(left, right) } // use not clause - _ => Expr::Not(Box::new(Expr::BinaryExpr(BinaryExpr::new( - left, op, right, - )))), + _ => Expr::binary_expr(BinaryExpr::new(left, op, right)).not(), } } // not (not A) ===> A - Expr::Not(expr) => *expr, + Expr::Not(expr, _) => *expr, // not (A is not null) ===> A is null - Expr::IsNotNull(expr) => expr.is_null(), + Expr::IsNotNull(expr, _) => expr.is_null(), // not (A is null) ===> A is not null - Expr::IsNull(expr) => expr.is_not_null(), + Expr::IsNull(expr, _) => expr.is_not_null(), // not (A not in (..)) ===> A in (..) // not (A in (..)) ===> A not in (..) 
- Expr::InList(InList { - expr, - list, - negated, - }) => expr.in_list(list, !negated), + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => expr.in_list(list, !negated), // not (A between B and C) ===> (A not between B and C) // not (A not between B and C) ===> (A between B and C) - Expr::Between(between) => Expr::Between(Between::new( + Expr::Between(between, _) => Expr::_between(Between::new( between.expr, !between.negated, between.low, between.high, )), // not (A like B) ===> A not like B - Expr::Like(like) => Expr::Like(Like::new( + Expr::Like(like, _) => Expr::_like(Like::new( !like.negated, like.expr, like.pattern, @@ -304,7 +306,7 @@ pub fn negate_clause(expr: Expr) -> Expr { like.case_insensitive, )), // use not clause - _ => Expr::Not(Box::new(expr)), + _ => expr.not(), } } @@ -318,7 +320,7 @@ pub fn negate_clause(expr: Expr) -> Expr { /// For others, use Negative clause pub fn distribute_negation(expr: Expr) -> Expr { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { match op { // ~(A & B) ===> ~A | ~B Operator::BitwiseAnd => { @@ -335,14 +337,14 @@ pub fn distribute_negation(expr: Expr) -> Expr { bitwise_and(left, right) } // use negative clause - _ => Expr::Negative(Box::new(Expr::BinaryExpr(BinaryExpr::new( + _ => Expr::negative(Box::new(Expr::binary_expr(BinaryExpr::new( left, op, right, )))), } } // ~(~A) ===> A - Expr::Negative(expr) => *expr, + Expr::Negative(expr, _) => *expr, // use negative clause - _ => Expr::Negative(Box::new(expr)), + _ => Expr::negative(Box::new(expr)), } } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index c8f3a4bc7859c..7303c77829033 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -66,14 +66,17 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { let mut 
fields_set = HashSet::new(); let mut aggregate_count = 0; for expr in aggr_expr { - if let Expr::AggregateFunction(AggregateFunction { - func, - distinct, - args, - filter, - order_by, - null_treatment: _, - }) = expr + if let Expr::AggregateFunction( + AggregateFunction { + func, + distinct, + args, + filter, + order_by, + null_treatment: _, + }, + _, + ) = expr { if filter.is_some() || order_by.is_some() { return Ok(false); @@ -98,7 +101,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { /// Check if the first expr is [Expr::GroupingSet]. fn contains_grouping_set(expr: &[Expr]) -> bool { - matches!(expr.first(), Some(Expr::GroupingSet(_))) + matches!(expr.first(), Some(Expr::GroupingSet(_, _))) } impl OptimizerRule for SingleDistinctToGroupBy { @@ -120,13 +123,16 @@ impl OptimizerRule for SingleDistinctToGroupBy { _config: &dyn OptimizerConfig, ) -> Result, DataFusionError> { match plan { - LogicalPlan::Aggregate(Aggregate { - input, - aggr_expr, - schema, - group_expr, - .. - }) if is_single_distinct_agg(&aggr_expr)? + LogicalPlan::Aggregate( + Aggregate { + input, + aggr_expr, + schema, + group_expr, + .. + }, + _, + ) if is_single_distinct_agg(&aggr_expr)? && !contains_grouping_set(&group_expr) => { let group_size = group_expr.len(); @@ -182,7 +188,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { mut args, distinct, .. 
- }) => { + }, _) => { if distinct { if args.len() != 1 { return internal_err!("DISTINCT aggregate should have exactly one argument"); @@ -193,7 +199,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { inner_group_exprs .push(arg.alias(SINGLE_DISTINCT_ALIAS)); } - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + Ok(Expr::aggregate_function(AggregateFunction::new_udf( func, vec![col(SINGLE_DISTINCT_ALIAS)], false, // intentional to remove distinct here @@ -206,7 +212,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { index += 1; let alias_str = format!("alias{}", index); inner_aggr_exprs.push( - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( Arc::clone(&func), args, false, @@ -216,7 +222,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { )) .alias(&alias_str), ); - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + Ok(Expr::aggregate_function(AggregateFunction::new_udf( func, vec![col(&alias_str)], false, @@ -231,7 +237,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .collect::>>()?; // construct the inner AggrPlan - let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( + let inner_agg = LogicalPlan::aggregate(Aggregate::try_new( input, inner_group_exprs, inner_aggr_exprs, @@ -263,7 +269,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { )) .collect(); - let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( + let outer_aggr = LogicalPlan::aggregate(Aggregate::try_new( Arc::new(inner_agg), outer_group_exprs, outer_aggr_exprs, @@ -288,7 +294,7 @@ mod tests { use datafusion_functions_aggregate::sum::sum_udaf; fn max_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( max_udaf(), vec![expr], true, @@ -345,7 +351,7 @@ mod tests { fn single_distinct_and_grouping_set() -> Result<()> { let table_scan = test_table_scan()?; - let grouping_set = 
Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("a")], vec![col("b")], ])); @@ -366,7 +372,8 @@ mod tests { fn single_distinct_and_cube() -> Result<()> { let table_scan = test_table_scan()?; - let grouping_set = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); + let grouping_set = + Expr::grouping_set(GroupingSet::Cube(vec![col("a"), col("b")])); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])? @@ -385,7 +392,7 @@ mod tests { let table_scan = test_table_scan()?; let grouping_set = - Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); + Expr::grouping_set(GroupingSet::Rollup(vec![col("a"), col("b")])); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])? @@ -569,7 +576,7 @@ mod tests { let table_scan = test_table_scan()?; // sum(a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(AggregateFunction::new_udf( + let expr = Expr::aggregate_function(AggregateFunction::new_udf( sum_udaf(), vec![col("a")], false, @@ -612,7 +619,7 @@ mod tests { let table_scan = test_table_scan()?; // SUM(a ORDER BY a) - let expr = Expr::AggregateFunction(AggregateFunction::new_udf( + let expr = Expr::aggregate_function(AggregateFunction::new_udf( sum_udaf(), vec![col("a")], false, diff --git a/datafusion/optimizer/src/test/user_defined.rs b/datafusion/optimizer/src/test/user_defined.rs index a39f90b5da5db..94875ae17fb92 100644 --- a/datafusion/optimizer/src/test/user_defined.rs +++ b/datafusion/optimizer/src/test/user_defined.rs @@ -30,7 +30,7 @@ use std::{ /// Create a new user defined plan node, for testing pub fn new(input: LogicalPlan) -> LogicalPlan { let node = Arc::new(TestUserDefinedPlanNode { input }); - LogicalPlan::Extension(Extension { node }) + LogicalPlan::extension(Extension { node }) } #[derive(PartialEq, Eq, PartialOrd, Hash)] diff 
--git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 31e21d08b569a..09d80a7679f01 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -101,7 +101,7 @@ impl OptimizerRule for UnwrapCastInComparison { ) -> Result> { let mut schema = merge_schema(&plan.inputs()); - if let LogicalPlan::TableScan(ts) = &plan { + if let LogicalPlan::TableScan(ts, _) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), @@ -136,7 +136,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { // For case: // try_cast/cast(expr as data_type) op literal // literal op try_cast/cast(expr as data_type) - Expr::BinaryExpr(BinaryExpr { left, op, right }) + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if { let Ok(left_type) = left.get_type(&self.schema) else { return Ok(Transformed::no(expr)); @@ -152,12 +152,18 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { match (left.as_mut(), right.as_mut()) { ( Expr::Literal(left_lit_value), - Expr::TryCast(TryCast { - expr: right_expr, .. - }) - | Expr::Cast(Cast { - expr: right_expr, .. - }), + Expr::TryCast( + TryCast { + expr: right_expr, .. + }, + _, + ) + | Expr::Cast( + Cast { + expr: right_expr, .. + }, + _, + ), ) => { // if the left_lit_value can be cast to the type of expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal @@ -181,12 +187,18 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { } } ( - Expr::TryCast(TryCast { - expr: left_expr, .. - }) - | Expr::Cast(Cast { - expr: left_expr, .. - }), + Expr::TryCast( + TryCast { + expr: left_expr, .. + }, + _, + ) + | Expr::Cast( + Cast { + expr: left_expr, .. 
+ }, + _, + ), Expr::Literal(right_lit_value), ) => { // if the right_lit_value can be cast to the type of expr @@ -215,15 +227,24 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { } // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) - Expr::InList(InList { - expr: left, list, .. - }) => { - let (Expr::TryCast(TryCast { - expr: left_expr, .. - }) - | Expr::Cast(Cast { - expr: left_expr, .. - })) = left.as_mut() + Expr::InList( + InList { + expr: left, list, .. + }, + _, + ) => { + let (Expr::TryCast( + TryCast { + expr: left_expr, .. + }, + _, + ) + | Expr::Cast( + Cast { + expr: left_expr, .. + }, + _, + )) = left.as_mut() else { return Ok(Transformed::no(expr)); }; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 9f325bc01b1d0..71e87c7839a57 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -195,7 +195,7 @@ mod tests { // a IS NULL (is_null(col("a")), false), // a IS NOT NULL - (Expr::IsNotNull(Box::new(col("a"))), true), + (Expr::_is_not_null(Box::new(col("a"))), true), // a = NULL ( binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)), diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index bffc2c46fc1e1..82738f221f4d3 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -111,7 +111,7 @@ pub fn create_physical_expr( let input_schema: &Schema = &input_dfschema.into(); match e { - Expr::Alias(Alias { expr, .. }) => { + Expr::Alias(Alias { expr, .. }, _) => { Ok(create_physical_expr(expr, input_dfschema, execution_props)?) 
} Expr::Column(c) => { @@ -138,7 +138,7 @@ pub fn create_physical_expr( } } } - Expr::IsTrue(expr) => { + Expr::IsTrue(expr, _) => { let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, @@ -146,12 +146,12 @@ pub fn create_physical_expr( ); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::IsNotTrue(expr) => { + Expr::IsNotTrue(expr, _) => { let binary_op = binary_expr(expr.as_ref().clone(), Operator::IsDistinctFrom, lit(true)); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::IsFalse(expr) => { + Expr::IsFalse(expr, _) => { let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, @@ -159,12 +159,12 @@ pub fn create_physical_expr( ); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::IsNotFalse(expr) => { + Expr::IsNotFalse(expr, _) => { let binary_op = binary_expr(expr.as_ref().clone(), Operator::IsDistinctFrom, lit(false)); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::IsUnknown(expr) => { + Expr::IsUnknown(expr, _) => { let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, @@ -172,7 +172,7 @@ pub fn create_physical_expr( ); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::IsNotUnknown(expr) => { + Expr::IsNotUnknown(expr, _) => { let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsDistinctFrom, @@ -180,7 +180,7 @@ pub fn create_physical_expr( ); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { // Create physical expressions for left and right operands let lhs = create_physical_expr(left, input_dfschema, execution_props)?; let rhs = create_physical_expr(right, input_dfschema, execution_props)?; @@ -193,13 +193,16 @@ pub fn create_physical_expr( // planning. 
binary(lhs, *op, rhs, input_schema) } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { if escape_char.is_some() { return exec_err!("LIKE does not support escape_char"); } @@ -215,13 +218,16 @@ pub fn create_physical_expr( input_schema, ) } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { if escape_char.is_some() { return exec_err!("SIMILAR TO does not support escape_char yet"); } @@ -231,7 +237,7 @@ pub fn create_physical_expr( create_physical_expr(pattern, input_dfschema, execution_props)?; similar_to(*negated, *case_insensitive, physical_expr, physical_pattern) } - Expr::Case(case) => { + Expr::Case(case, _) => { let expr: Option> = if let Some(e) = &case.expr { Some(create_physical_expr( e.as_ref(), @@ -268,34 +274,34 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) } - Expr::Cast(Cast { expr, data_type }) => expressions::cast( + Expr::Cast(Cast { expr, data_type }, _) => expressions::cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, data_type.clone(), ), - Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( + Expr::TryCast(TryCast { expr, data_type }, _) => expressions::try_cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, data_type.clone(), ), - Expr::Not(expr) => { + Expr::Not(expr, _) => { expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?) 
} - Expr::Negative(expr) => expressions::negative( + Expr::Negative(expr, _) => expressions::negative( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, ), - Expr::IsNull(expr) => expressions::is_null(create_physical_expr( + Expr::IsNull(expr, _) => expressions::is_null(create_physical_expr( expr, input_dfschema, execution_props, )?), - Expr::IsNotNull(expr) => expressions::is_not_null(create_physical_expr( + Expr::IsNotNull(expr, _) => expressions::is_not_null(create_physical_expr( expr, input_dfschema, execution_props, )?), - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; @@ -307,12 +313,15 @@ pub fn create_physical_expr( input_dfschema, ) } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { let value_expr = create_physical_expr(expr, input_dfschema, execution_props)?; let low_expr = create_physical_expr(low, input_dfschema, execution_props)?; let high_expr = create_physical_expr(high, input_dfschema, execution_props)?; @@ -341,11 +350,14 @@ pub fn create_physical_expr( binary_expr } } - Expr::InList(InList { - expr, - list, - negated, - }) => match expr.as_ref() { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => match expr.as_ref() { Expr::Literal(ScalarValue::Utf8(None)) => { Ok(expressions::lit(ScalarValue::Boolean(None))) } diff --git a/datafusion/physical-optimizer/src/pruning.rs b/datafusion/physical-optimizer/src/pruning.rs index 3cfb03b7205a5..52d7edff42ec0 100644 --- a/datafusion/physical-optimizer/src/pruning.rs +++ b/datafusion/physical-optimizer/src/pruning.rs @@ -2549,7 +2549,7 @@ mod tests { Field::new("c2", DataType::Int32, false), ]); // test c1 in(1, 2, 3) - let expr = Expr::InList(InList::new( + let expr = Expr::_in_list(InList::new( 
Box::new(col("c1")), vec![lit(1), lit(2), lit(3)], false, @@ -2580,7 +2580,7 @@ mod tests { Field::new("c2", DataType::Int32, false), ]); // test c1 in() - let expr = Expr::InList(InList::new(Box::new(col("c1")), vec![], false)); + let expr = Expr::_in_list(InList::new(Box::new(col("c1")), vec![], false)); let expected_expr = "true"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); @@ -2596,7 +2596,7 @@ mod tests { Field::new("c2", DataType::Int32, false), ]); // test c1 not in(1, 2, 3) - let expr = Expr::InList(InList::new( + let expr = Expr::_in_list(InList::new( Box::new(col("c1")), vec![lit(1), lit(2), lit(3)], true, @@ -2747,7 +2747,7 @@ mod tests { fn row_group_predicate_cast_list() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); // test cast(c1 as int64) in int64(1, 2, 3) - let expr = Expr::InList(InList::new( + let expr = Expr::_in_list(InList::new( Box::new(cast(col("c1"), DataType::Int64)), vec![ lit(ScalarValue::Int64(Some(1))), @@ -2772,7 +2772,7 @@ mod tests { test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); - let expr = Expr::InList(InList::new( + let expr = Expr::_in_list(InList::new( Box::new(cast(col("c1"), DataType::Int64)), vec![ lit(ScalarValue::Int64(Some(1))), @@ -3107,7 +3107,7 @@ mod tests { // -i < 0 prune_with_expr( - Expr::Negative(Box::new(col("i"))).lt(lit(0)), + Expr::negative(Box::new(col("i"))).lt(lit(0)), &schema, &statistics, expected_ret, @@ -3136,7 +3136,7 @@ mod tests { prune_with_expr( // -i >= 0 - Expr::Negative(Box::new(col("i"))).gt_eq(lit(0)), + Expr::negative(Box::new(col("i"))).gt_eq(lit(0)), &schema, &statistics, expected_ret, @@ -3173,7 +3173,7 @@ mod tests { prune_with_expr( // cast(-i as utf8) >= 0 - cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + cast(Expr::negative(Box::new(col("i"))), 
DataType::Utf8).gt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3181,7 +3181,7 @@ mod tests { prune_with_expr( // try_cast(-i as utf8) >= 0 - try_cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + try_cast(Expr::negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3281,7 +3281,7 @@ mod tests { prune_with_expr( // -i < 1 - Expr::Negative(Box::new(col("i"))).lt(lit(1)), + Expr::negative(Box::new(col("i"))).lt(lit(1)), &schema, &statistics, expected_ret, @@ -3431,7 +3431,7 @@ mod tests { prune_with_expr( // `-cast(i as int64) < 0` convert to `cast(i as int64) > -0` - Expr::Negative(Box::new(cast(col("i"), DataType::Int64))) + Expr::negative(Box::new(cast(col("i"), DataType::Int64))) .lt(lit(ScalarValue::Int64(Some(0)))), &schema, &statistics, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 64d8e24ce5182..b5fe83d619b27 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -22,7 +22,7 @@ use datafusion_common::{ exec_datafusion_err, internal_err, plan_datafusion_err, RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, }; -use datafusion_expr::expr::{Alias, Placeholder, Sort, Wildcard}; +use datafusion_expr::expr::{Placeholder, Sort, Wildcard}; use datafusion_expr::expr::{Unnest, WildcardOptions}; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ @@ -245,7 +245,11 @@ pub fn parse_expr( Ok(operands .into_iter() .reduce(|left, right| { - Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) + Expr::binary_expr(BinaryExpr::new( + Box::new(left), + op, + Box::new(right), + )) }) .expect("Binary expression could not be reduced to a single expression.")) } @@ -284,7 +288,7 @@ pub fn parse_expr( }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + 
Expr::window_function(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, )) @@ -301,7 +305,7 @@ pub fn parse_expr( }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::window_function(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, )) @@ -313,64 +317,59 @@ pub fn parse_expr( } } } - ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( - parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, - alias - .relation - .first() - .map(|r| TableReference::try_from(r.clone())) - .transpose()?, - alias.alias.clone(), + ExprType::Alias(alias) => { + Ok( + parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)? + .alias_qualified( + alias + .relation + .first() + .map(|r| TableReference::try_from(r.clone())) + .transpose()?, + alias.alias.clone(), + ), + ) + } + ExprType::IsNullExpr(is_null) => Ok(Expr::_is_null(Box::new( + parse_required_expr(is_null.expr.as_deref(), registry, "expr", codec)?, ))), - ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( - is_null.expr.as_deref(), - registry, - "expr", - codec, - )?))), - ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new( + ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::_is_not_null(Box::new( parse_required_expr(is_not_null.expr.as_deref(), registry, "expr", codec)?, ))), - ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr( + ExprType::NotExpr(not) => Ok(Expr::_not(Box::new(parse_required_expr( not.expr.as_deref(), registry, "expr", codec, )?))), - ExprType::IsTrue(msg) => Ok(Expr::IsTrue(Box::new(parse_required_expr( - msg.expr.as_deref(), - registry, - "expr", - codec, - )?))), - ExprType::IsFalse(msg) => Ok(Expr::IsFalse(Box::new(parse_required_expr( - msg.expr.as_deref(), - registry, - "expr", - codec, - )?))), - ExprType::IsUnknown(msg) => Ok(Expr::IsUnknown(Box::new(parse_required_expr( + 
ExprType::IsTrue(msg) => Ok(Expr::_is_true(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", codec, )?))), - ExprType::IsNotTrue(msg) => Ok(Expr::IsNotTrue(Box::new(parse_required_expr( + ExprType::IsFalse(msg) => Ok(Expr::_is_false(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", codec, )?))), - ExprType::IsNotFalse(msg) => Ok(Expr::IsNotFalse(Box::new(parse_required_expr( + ExprType::IsUnknown(msg) => Ok(Expr::_is_unknown(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", codec, )?))), - ExprType::IsNotUnknown(msg) => Ok(Expr::IsNotUnknown(Box::new( + ExprType::IsNotTrue(msg) => Ok(Expr::_is_not_true(Box::new( + parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, + ))), + ExprType::IsNotFalse(msg) => Ok(Expr::_is_not_false(Box::new( + parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, + ))), + ExprType::IsNotUnknown(msg) => Ok(Expr::_is_not_unknown(Box::new( parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, ))), - ExprType::Between(between) => Ok(Expr::Between(Between::new( + ExprType::Between(between) => Ok(Expr::_between(Between::new( Box::new(parse_required_expr( between.expr.as_deref(), registry, @@ -391,7 +390,7 @@ pub fn parse_expr( codec, )?), ))), - ExprType::Like(like) => Ok(Expr::Like(Like::new( + ExprType::Like(like) => Ok(Expr::_like(Like::new( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), @@ -408,7 +407,7 @@ pub fn parse_expr( parse_escape_char(&like.escape_char)?, false, ))), - ExprType::Ilike(like) => Ok(Expr::Like(Like::new( + ExprType::Ilike(like) => Ok(Expr::_like(Like::new( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), @@ -425,7 +424,7 @@ pub fn parse_expr( parse_escape_char(&like.escape_char)?, true, ))), - ExprType::SimilarTo(like) => Ok(Expr::SimilarTo(Like::new( + ExprType::SimilarTo(like) => Ok(Expr::similar_to(Like::new( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), @@ 
-462,7 +461,7 @@ pub fn parse_expr( Ok((Box::new(when_expr), Box::new(then_expr))) }) .collect::, Box)>, Error>>()?; - Ok(Expr::Case(Case::new( + Ok(Expr::case(Case::new( parse_optional_expr(case.expr.as_deref(), registry, codec)?.map(Box::new), when_then_expr, parse_optional_expr(case.else_expr.as_deref(), registry, codec)? @@ -477,7 +476,7 @@ pub fn parse_expr( codec, )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::Cast(Cast::new(expr, data_type))) + Ok(Expr::cast(Cast::new(expr, data_type))) } ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr( @@ -487,9 +486,9 @@ pub fn parse_expr( codec, )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::TryCast(TryCast::new(expr, data_type))) + Ok(Expr::try_cast(TryCast::new(expr, data_type))) } - ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( + ExprType::Negative(negative) => Ok(Expr::negative(Box::new( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::Unnest(unnest) => { @@ -497,9 +496,9 @@ pub fn parse_expr( if exprs.len() != 1 { return Err(proto_error("Unnest must have exactly one expression")); } - Ok(Expr::Unnest(Unnest::new(exprs.swap_remove(0)))) + Ok(Expr::unnest(Unnest::new(exprs.swap_remove(0)))) } - ExprType::InList(in_list) => Ok(Expr::InList(InList::new( + ExprType::InList(in_list) => Ok(Expr::_in_list(InList::new( Box::new(parse_required_expr( in_list.expr.as_deref(), registry, @@ -511,7 +510,7 @@ pub fn parse_expr( ))), ExprType::Wildcard(protobuf::Wildcard { qualifier }) => { let qualifier = qualifier.to_owned().map(|x| x.try_into()).transpose()?; - Ok(Expr::Wildcard(Wildcard { + Ok(Expr::wildcard(Wildcard { qualifier, options: WildcardOptions::default(), })) @@ -525,7 +524,7 @@ pub fn parse_expr( Some(buf) => codec.try_decode_udf(fun_name, buf)?, None => registry.udf(fun_name.as_str())?, }; - Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + 
Ok(Expr::scalar_function(expr::ScalarFunction::new_udf( scalar_fn, parse_exprs(args, registry, codec)?, ))) @@ -536,7 +535,7 @@ pub fn parse_expr( None => registry.udaf(&pb.fun_name)?, }; - Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + Ok(Expr::aggregate_function(expr::AggregateFunction::new_udf( agg_fn, parse_exprs(&pb.args, registry, codec)?, pb.distinct, @@ -550,16 +549,16 @@ pub fn parse_expr( } ExprType::GroupingSet(GroupingSetNode { expr }) => { - Ok(Expr::GroupingSet(GroupingSets( + Ok(Expr::grouping_set(GroupingSets( expr.iter() .map(|expr_list| parse_exprs(&expr_list.expr, registry, codec)) .collect::, Error>>()?, ))) } - ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( + ExprType::Cube(CubeNode { expr }) => Ok(Expr::grouping_set(GroupingSet::Cube( parse_exprs(expr, registry, codec)?, ))), - ExprType::Rollup(RollupNode { expr }) => Ok(Expr::GroupingSet( + ExprType::Rollup(RollupNode { expr }) => Ok(Expr::grouping_set( GroupingSet::Rollup(parse_exprs(expr, registry, codec)?), )), ExprType::Placeholder(PlaceholderNode { id, data_type }) => match data_type { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 50636048ebc96..035857ede2c29 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -297,7 +297,7 @@ impl AsLogicalPlan for LogicalPlanNode { match projection.optional_alias.as_ref() { Some(a) => match a { protobuf::projection_node::OptionalAlias::Alias(alias) => { - Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Ok(LogicalPlan::subquery_alias(SubqueryAlias::try_new( Arc::new(new_proj), alias.clone(), )?)) @@ -567,7 +567,7 @@ impl AsLogicalPlan for LogicalPlanNode { column_defaults.insert(col_name.clone(), expr); } - Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( + Ok(LogicalPlan::ddl(DdlStatement::CreateExternalTable( CreateExternalTable { schema: pb_schema.try_into()?, name: 
from_table_reference( @@ -602,7 +602,7 @@ impl AsLogicalPlan for LogicalPlanNode { None }; - Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { + Ok(LogicalPlan::ddl(DdlStatement::CreateView(CreateView { name: from_table_reference(create_view.name.as_ref(), "CreateView")?, temporary: create_view.temporary, input: Arc::new(plan), @@ -617,7 +617,7 @@ impl AsLogicalPlan for LogicalPlanNode { )) })?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema( + Ok(LogicalPlan::ddl(DdlStatement::CreateCatalogSchema( CreateCatalogSchema { schema_name: create_catalog_schema.schema_name.clone(), if_not_exists: create_catalog_schema.if_not_exists, @@ -632,7 +632,7 @@ impl AsLogicalPlan for LogicalPlanNode { )) })?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalog( + Ok(LogicalPlan::ddl(DdlStatement::CreateCatalog( CreateCatalog { catalog_name: create_catalog.catalog_name.clone(), if_not_exists: create_catalog.if_not_exists, @@ -772,7 +772,7 @@ impl AsLogicalPlan for LogicalPlanNode { let extension_node = extension_codec.try_decode(node, &input_plans, ctx)?; - Ok(LogicalPlan::Extension(extension_node)) + Ok(LogicalPlan::extension(extension_node)) } LogicalPlanType::Distinct(distinct) => { let input: LogicalPlan = @@ -848,7 +848,7 @@ impl AsLogicalPlan for LogicalPlanNode { .build() } LogicalPlanType::DropView(dropview) => { - Ok(LogicalPlan::Ddl(DdlStatement::DropView(DropView { + Ok(LogicalPlan::ddl(DdlStatement::DropView(DropView { name: from_table_reference(dropview.name.as_ref(), "DropView")?, if_exists: dropview.if_exists, schema: Arc::new(convert_required!(dropview.schema)?), @@ -862,7 +862,7 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec.try_decode_file_format(©.file_type, ctx)?, ); - Ok(LogicalPlan::Copy(dml::CopyTo { + Ok(LogicalPlan::copy(dml::CopyTo { input: Arc::new(input), output_url: copy.output_url.clone(), partition_by: copy.partition_by.clone(), @@ -873,7 +873,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Unnest(unnest) => 
{ let input: LogicalPlan = into_logical_plan!(unnest.input, ctx, extension_codec)?; - Ok(LogicalPlan::Unnest(Unnest { + Ok(LogicalPlan::unnest(Unnest { input: Arc::new(input), exec_columns: unnest.exec_columns.iter().map(|c| c.into()).collect(), list_type_columns: unnest @@ -925,7 +925,7 @@ impl AsLogicalPlan for LogicalPlanNode { )))? .try_into_logical_plan(ctx, extension_codec)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + Ok(LogicalPlan::recursive_query(RecursiveQuery { name: recursive_query_node.name.clone(), static_term: Arc::new(static_term), recursive_term: Arc::new(recursive_term), @@ -954,7 +954,7 @@ impl AsLogicalPlan for LogicalPlanNode { Self: Sized, { match plan { - LogicalPlan::Values(Values { values, .. }) => { + LogicalPlan::Values(Values { values, .. }, _) => { let n_cols = if values.is_empty() { 0 } else { @@ -971,13 +971,16 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } - LogicalPlan::TableScan(TableScan { - table_name, - source, - filters, - projection, - .. - }) => { + LogicalPlan::TableScan( + TableScan { + table_name, + source, + filters, + projection, + .. + }, + _, + ) => { let provider = source_as_provider(source)?; let schema = provider.schema(); let source = provider.as_any(); @@ -1131,7 +1134,7 @@ impl AsLogicalPlan for LogicalPlanNode { Ok(node) } } - LogicalPlan::Projection(Projection { expr, input, .. }) => { + LogicalPlan::Projection(Projection { expr, input, .. 
}, _) => { Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Projection(Box::new( protobuf::ProjectionNode { @@ -1147,7 +1150,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( filter.input.as_ref(), extension_codec, @@ -1164,7 +1167,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Distinct(Distinct::All(input)) => { + LogicalPlan::Distinct(Distinct::All(input), _) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1177,13 +1180,16 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - input, - .. - })) => { + LogicalPlan::Distinct( + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + .. + }), + _, + ) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1203,9 +1209,12 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Window(Window { - input, window_expr, .. - }) => { + LogicalPlan::Window( + Window { + input, window_expr, .. + }, + _, + ) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1219,12 +1228,15 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Aggregate(Aggregate { - group_expr, - aggr_expr, - input, - .. - }) => { + LogicalPlan::Aggregate( + Aggregate { + group_expr, + aggr_expr, + input, + .. + }, + _, + ) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1239,16 +1251,19 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Join(Join { - left, - right, - on, - filter, - join_type, - join_constraint, - null_equals_null, - .. 
- }) => { + LogicalPlan::Join( + Join { + left, + right, + on, + filter, + join_type, + join_constraint, + null_equals_null, + .. + }, + _, + ) => { let left: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( left.as_ref(), extension_codec, @@ -1290,10 +1305,10 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Subquery(_) => { + LogicalPlan::Subquery(_, _) => { not_impl_err!("LogicalPlan serde is not yet implemented for subqueries") } - LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }, _) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1307,7 +1322,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( limit.input.as_ref(), extension_codec, @@ -1333,7 +1348,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Sort(Sort { input, expr, fetch }) => { + LogicalPlan::Sort(Sort { input, expr, fetch }, _) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1350,10 +1365,13 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - }) => { + LogicalPlan::Repartition( + Repartition { + input, + partitioning_scheme, + }, + _, + ) => { use datafusion::logical_expr::Partitioning; let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), @@ -1397,8 +1415,8 @@ impl AsLogicalPlan for LogicalPlanNode { }, )), }), - LogicalPlan::Ddl(DdlStatement::CreateExternalTable( - CreateExternalTable { + LogicalPlan::Ddl( + DdlStatement::CreateExternalTable(CreateExternalTable { name, location, file_type, @@ -1412,8 +1430,9 @@ impl AsLogicalPlan for LogicalPlanNode { constraints, column_defaults, temporary, - }, - )) 
=> { + }), + _, + ) => { let mut converted_order_exprs: Vec = vec![]; for order in order_exprs { let temp = SortExprNodeCollection { @@ -1449,13 +1468,16 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } - LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { - name, - input, - or_replace, - definition, - temporary, - })) => Ok(LogicalPlanNode { + LogicalPlan::Ddl( + DdlStatement::CreateView(CreateView { + name, + input, + or_replace, + definition, + temporary, + }), + _, + ) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateView(Box::new( protobuf::CreateViewNode { name: Some(name.clone().into()), @@ -1469,13 +1491,14 @@ impl AsLogicalPlan for LogicalPlanNode { }, ))), }), - LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema( - CreateCatalogSchema { + LogicalPlan::Ddl( + DdlStatement::CreateCatalogSchema(CreateCatalogSchema { schema_name, if_not_exists, schema: df_schema, - }, - )) => Ok(LogicalPlanNode { + }), + _, + ) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateCatalogSchema( protobuf::CreateCatalogSchemaNode { schema_name: schema_name.clone(), @@ -1484,11 +1507,14 @@ impl AsLogicalPlan for LogicalPlanNode { }, )), }), - LogicalPlan::Ddl(DdlStatement::CreateCatalog(CreateCatalog { - catalog_name, - if_not_exists, - schema: df_schema, - })) => Ok(LogicalPlanNode { + LogicalPlan::Ddl( + DdlStatement::CreateCatalog(CreateCatalog { + catalog_name, + if_not_exists, + schema: df_schema, + }), + _, + ) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateCatalog( protobuf::CreateCatalogNode { catalog_name: catalog_name.clone(), @@ -1497,7 +1523,7 @@ impl AsLogicalPlan for LogicalPlanNode { }, )), }), - LogicalPlan::Analyze(a) => { + LogicalPlan::Analyze(a, _) => { let input = LogicalPlanNode::try_from_logical_plan( a.input.as_ref(), extension_codec, @@ -1511,7 +1537,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Explain(a) => { + LogicalPlan::Explain(a, _) => { let input = 
LogicalPlanNode::try_from_logical_plan( a.plan.as_ref(), extension_codec, @@ -1525,7 +1551,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Union(union) => { + LogicalPlan::Union(union, _) => { let inputs: Vec = union .inputs .iter() @@ -1537,7 +1563,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } - LogicalPlan::Extension(extension) => { + LogicalPlan::Extension(extension, _) => { let mut buf: Vec = vec![]; extension_codec.try_encode(extension, &mut buf)?; @@ -1554,11 +1580,14 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } - LogicalPlan::Statement(Statement::Prepare(Prepare { - name, - data_types, - input, - })) => { + LogicalPlan::Statement( + Statement::Prepare(Prepare { + name, + data_types, + input, + }), + _, + ) => { let input = LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; Ok(LogicalPlanNode { @@ -1574,15 +1603,18 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Unnest(Unnest { - input, - exec_columns, - list_type_columns, - struct_type_columns, - dependency_indices, - schema, - options, - }) => { + LogicalPlan::Unnest( + Unnest { + input, + exec_columns, + list_type_columns, + struct_type_columns, + dependency_indices, + schema, + options, + }, + _, + ) => { let input = LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; let proto_unnest_list_items = list_type_columns @@ -1618,20 +1650,23 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(_)) => Err(proto_error( + LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for CreateMemoryTable", )), - LogicalPlan::Ddl(DdlStatement::CreateIndex(_)) => Err(proto_error( + LogicalPlan::Ddl(DdlStatement::CreateIndex(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for CreateIndex", )), - LogicalPlan::Ddl(DdlStatement::DropTable(_)) => Err(proto_error( + 
LogicalPlan::Ddl(DdlStatement::DropTable(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for DropTable", )), - LogicalPlan::Ddl(DdlStatement::DropView(DropView { - name, - if_exists, - schema, - })) => Ok(LogicalPlanNode { + LogicalPlan::Ddl( + DdlStatement::DropView(DropView { + name, + if_exists, + schema, + }), + _, + ) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::DropView( protobuf::DropViewNode { name: Some(name.clone().into()), @@ -1640,28 +1675,31 @@ impl AsLogicalPlan for LogicalPlanNode { }, )), }), - LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(_)) => Err(proto_error( + LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for DropCatalogSchema", )), - LogicalPlan::Ddl(DdlStatement::CreateFunction(_)) => Err(proto_error( + LogicalPlan::Ddl(DdlStatement::CreateFunction(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for CreateFunction", )), - LogicalPlan::Ddl(DdlStatement::DropFunction(_)) => Err(proto_error( + LogicalPlan::Ddl(DdlStatement::DropFunction(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for DropFunction", )), - LogicalPlan::Statement(_) => Err(proto_error( + LogicalPlan::Statement(_, _) => Err(proto_error( "LogicalPlan serde is not yet implemented for Statement", )), - LogicalPlan::Dml(_) => Err(proto_error( + LogicalPlan::Dml(_, _) => Err(proto_error( "LogicalPlan serde is not yet implemented for Dml", )), - LogicalPlan::Copy(dml::CopyTo { - input, - output_url, - file_type, - partition_by, - .. - }) => { + LogicalPlan::Copy( + dml::CopyTo { + input, + output_url, + file_type, + partition_by, + .. 
+ }, + _, + ) => { let input = LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; let mut buf = Vec::new(); @@ -1682,7 +1720,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::DescribeTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), - LogicalPlan::RecursiveQuery(recursive) => { + LogicalPlan::RecursiveQuery(recursive, _) => { let static_term = LogicalPlanNode::try_from_logical_plan( recursive.static_term.as_ref(), extension_codec, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b3c91207ecf32..ef6ef0d62a2fd 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -194,11 +194,14 @@ pub fn serialize_expr( Expr::Column(c) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Column(c.into())), }, - Expr::Alias(Alias { - expr, - relation, - name, - }) => { + Expr::Alias( + Alias { + expr, + relation, + name, + }, + _, + ) => { let alias = Box::new(protobuf::AliasNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), relation: relation @@ -217,16 +220,19 @@ pub fn serialize_expr( expr_type: Some(ExprType::Literal(pb_value)), } } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { // Try to linerize a nested binary expression tree of the same operator // into a flat vector of expressions. 
let mut exprs = vec![right.as_ref()]; let mut current_expr = left.as_ref(); - while let Expr::BinaryExpr(BinaryExpr { - left, - op: current_op, - right, - }) = current_expr + while let Expr::BinaryExpr( + BinaryExpr { + left, + op: current_op, + right, + }, + _, + ) = current_expr { if current_op == op { exprs.push(right.as_ref()); @@ -248,13 +254,16 @@ pub fn serialize_expr( expr_type: Some(ExprType::BinaryExpr(binary_expr)), } } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { if *case_insensitive { let pb = Box::new(protobuf::ILikeNode { negated: *negated, @@ -279,13 +288,16 @@ pub fn serialize_expr( } } } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }, + _, + ) => { let pb = Box::new(protobuf::SimilarToNode { negated: *negated, expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), @@ -296,15 +308,18 @@ pub fn serialize_expr( expr_type: Some(ExprType::SimilarTo(pb)), } } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - ref args, - ref partition_by, - ref order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }) => { + Expr::WindowFunction( + expr::WindowFunction { + ref fun, + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, + }, + _, + ) => { let (window_function, fun_definition) = match fun { WindowFunctionDefinition::AggregateUDF(aggr_udf) => { let mut buf = Vec::new(); @@ -344,14 +359,17 @@ pub fn serialize_expr( expr_type: Some(ExprType::WindowExpr(window_expr)), } } - Expr::AggregateFunction(expr::AggregateFunction { - ref func, - ref args, - ref distinct, - ref filter, - ref order_by, - null_treatment: _, - }) => { 
+ Expr::AggregateFunction( + expr::AggregateFunction { + ref func, + ref args, + ref distinct, + ref filter, + ref order_by, + null_treatment: _, + }, + _, + ) => { let mut buf = Vec::new(); let _ = codec.try_encode_udaf(func, &mut buf); protobuf::LogicalExprNode { @@ -379,7 +397,7 @@ pub fn serialize_expr( "Proto serialization error: Scalar Variable not supported".to_string(), )) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { let mut buf = Vec::new(); let _ = codec.try_encode_udf(func, &mut buf); protobuf::LogicalExprNode { @@ -390,7 +408,7 @@ pub fn serialize_expr( })), } } - Expr::Not(expr) => { + Expr::Not(expr, _) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -398,7 +416,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::NotExpr(expr)), } } - Expr::IsNull(expr) => { + Expr::IsNull(expr, _) => { let expr = Box::new(protobuf::IsNull { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -406,7 +424,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsNullExpr(expr)), } } - Expr::IsNotNull(expr) => { + Expr::IsNotNull(expr, _) => { let expr = Box::new(protobuf::IsNotNull { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -414,7 +432,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsNotNullExpr(expr)), } } - Expr::IsTrue(expr) => { + Expr::IsTrue(expr, _) => { let expr = Box::new(protobuf::IsTrue { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -422,7 +440,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsTrue(expr)), } } - Expr::IsFalse(expr) => { + Expr::IsFalse(expr, _) => { let expr = Box::new(protobuf::IsFalse { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -430,7 +448,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsFalse(expr)), } } - Expr::IsUnknown(expr) => { + Expr::IsUnknown(expr, _) => { let expr = 
Box::new(protobuf::IsUnknown { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -438,7 +456,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsUnknown(expr)), } } - Expr::IsNotTrue(expr) => { + Expr::IsNotTrue(expr, _) => { let expr = Box::new(protobuf::IsNotTrue { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -446,7 +464,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsNotTrue(expr)), } } - Expr::IsNotFalse(expr) => { + Expr::IsNotFalse(expr, _) => { let expr = Box::new(protobuf::IsNotFalse { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -454,7 +472,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsNotFalse(expr)), } } - Expr::IsNotUnknown(expr) => { + Expr::IsNotUnknown(expr, _) => { let expr = Box::new(protobuf::IsNotUnknown { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -462,12 +480,15 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsNotUnknown(expr)), } } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { let expr = Box::new(protobuf::BetweenNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), negated: *negated, @@ -478,7 +499,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::Between(expr)), } } - Expr::Case(case) => { + Expr::Case(case, _) => { let when_then_expr = case .when_then_expr .iter() @@ -504,7 +525,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::Case(expr)), } } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, data_type }, _) => { let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), arrow_type: Some(data_type.try_into()?), @@ -513,7 +534,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::Cast(expr)), } } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, data_type }, _) => { let expr = Box::new(protobuf::TryCastNode 
{ expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), arrow_type: Some(data_type.try_into()?), @@ -522,7 +543,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::TryCast(expr)), } } - Expr::Negative(expr) => { + Expr::Negative(expr, _) => { let expr = Box::new(protobuf::NegativeNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -530,7 +551,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::Negative(expr)), } } - Expr::Unnest(Unnest { expr }) => { + Expr::Unnest(Unnest { expr }, _) => { let expr = protobuf::Unnest { exprs: vec![serialize_expr(expr.as_ref(), codec)?], }; @@ -538,11 +559,14 @@ pub fn serialize_expr( expr_type: Some(ExprType::Unnest(expr)), } } - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { let expr = Box::new(protobuf::InListNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), list: serialize_exprs(list, codec)?, @@ -552,30 +576,30 @@ pub fn serialize_expr( expr_type: Some(ExprType::InList(expr)), } } - Expr::Wildcard(Wildcard { qualifier, .. }) => protobuf::LogicalExprNode { + Expr::Wildcard(Wildcard { qualifier, .. }, _) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { qualifier: qualifier.to_owned().map(|x| x.into()), })), }, - Expr::ScalarSubquery(_) - | Expr::InSubquery(_) + Expr::ScalarSubquery(_, _) + | Expr::InSubquery(_, _) | Expr::Exists { .. } | Expr::OuterReferenceColumn { .. } => { // we would need to add logical plan operators to datafusion.proto to support this // see discussion in https://github.com/apache/datafusion/issues/2565 return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. 
} | Exp:OuterReferenceColumn not supported".to_string())); } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { + Expr::GroupingSet(GroupingSet::Cube(exprs), _) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Cube(CubeNode { expr: serialize_exprs(exprs, codec)?, })), }, - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => protobuf::LogicalExprNode { + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Rollup(RollupNode { expr: serialize_exprs(exprs, codec)?, })), }, - Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(exprs), _) => { protobuf::LogicalExprNode { expr_type: Some(ExprType::GroupingSet(GroupingSetNode { expr: exprs diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index deece2e54a5a6..9bb674ae67017 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -133,7 +133,7 @@ async fn roundtrip_logical_plan() -> Result<()> { ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) .await?; let scan = ctx.table("t1").await?.into_optimized_plan()?; - let topk_plan = LogicalPlan::Extension(Extension { + let topk_plan = LogicalPlan::extension(Extension { node: Arc::new(TopKPlanNode::new(3, scan, col("revenue"))), }); let extension_codec = TopKExtensionCodec {}; @@ -380,7 +380,7 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { let input = create_csv_scan(&ctx).await?; let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new())); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.csv".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], @@ -420,7 +420,7 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> 
{ ParquetFormatFactory::new_with_options(parquet_format), )); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.parquet".to_string(), file_type, @@ -434,7 +434,7 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); match logical_round_trip { - LogicalPlan::Copy(copy_to) => { + LogicalPlan::Copy(copy_to, _) => { assert_eq!("test.parquet", copy_to.output_url); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); @@ -452,7 +452,7 @@ async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> { let file_type = format_as_file_type(Arc::new(ArrowFormatFactory::new())); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.arrow".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], @@ -467,7 +467,7 @@ async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> { assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); match logical_round_trip { - LogicalPlan::Copy(copy_to) => { + LogicalPlan::Copy(copy_to, _) => { assert_eq!("test.arrow", copy_to.output_url); assert_eq!("arrow".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); @@ -499,7 +499,7 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { csv_format.clone(), ))); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.csv".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], @@ -514,7 +514,7 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); match logical_round_trip { - 
LogicalPlan::Copy(copy_to) => { + LogicalPlan::Copy(copy_to, _) => { assert_eq!("test.csv", copy_to.output_url); assert_eq!("csv".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); @@ -565,7 +565,7 @@ async fn roundtrip_logical_plan_copy_to_json() -> Result<()> { json_format.clone(), ))); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.json".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], @@ -581,7 +581,7 @@ async fn roundtrip_logical_plan_copy_to_json() -> Result<()> { assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); match logical_round_trip { - LogicalPlan::Copy(copy_to) => { + LogicalPlan::Copy(copy_to, _) => { assert_eq!("test.json", copy_to.output_url); assert_eq!("json".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); @@ -637,7 +637,7 @@ async fn roundtrip_logical_plan_copy_to_parquet() -> Result<()> { ParquetFormatFactory::new_with_options(parquet_format.clone()), )); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.parquet".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], @@ -653,7 +653,7 @@ async fn roundtrip_logical_plan_copy_to_parquet() -> Result<()> { assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); match logical_round_trip { - LogicalPlan::Copy(copy_to) => { + LogicalPlan::Copy(copy_to, _) => { assert_eq!("test.parquet", copy_to.output_url); assert_eq!("parquet".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); @@ -1902,7 +1902,7 @@ fn roundtrip_dfschema() { #[test] fn roundtrip_not() { - let test_expr = Expr::Not(Box::new(lit(1.0_f32))); + let test_expr = Expr::_not(Box::new(lit(1.0_f32))); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ 
-1910,7 +1910,7 @@ fn roundtrip_not() { #[test] fn roundtrip_is_null() { - let test_expr = Expr::IsNull(Box::new(col("id"))); + let test_expr = Expr::_is_null(Box::new(col("id"))); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -1918,7 +1918,7 @@ fn roundtrip_is_null() { #[test] fn roundtrip_is_not_null() { - let test_expr = Expr::IsNotNull(Box::new(col("id"))); + let test_expr = Expr::_is_not_null(Box::new(col("id"))); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -1926,7 +1926,7 @@ fn roundtrip_is_not_null() { #[test] fn roundtrip_between() { - let test_expr = Expr::Between(Between::new( + let test_expr = Expr::_between(Between::new( Box::new(lit(1.0_f32)), true, Box::new(lit(2.0_f32)), @@ -1940,7 +1940,7 @@ fn roundtrip_between() { #[test] fn roundtrip_binary_op() { fn test(op: Operator) { - let test_expr = Expr::BinaryExpr(BinaryExpr::new( + let test_expr = Expr::binary_expr(BinaryExpr::new( Box::new(lit(1.0_f32)), op, Box::new(lit(2.0_f32)), @@ -1974,7 +1974,7 @@ fn roundtrip_binary_op() { #[test] fn roundtrip_case() { - let test_expr = Expr::Case(Case::new( + let test_expr = Expr::case(Case::new( Some(Box::new(lit(1.0_f32))), vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], Some(Box::new(lit(4.0_f32))), @@ -1986,7 +1986,7 @@ fn roundtrip_case() { #[test] fn roundtrip_case_with_null() { - let test_expr = Expr::Case(Case::new( + let test_expr = Expr::case(Case::new( Some(Box::new(lit(1.0_f32))), vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], Some(Box::new(Expr::Literal(ScalarValue::Null))), @@ -2006,7 +2006,7 @@ fn roundtrip_null_literal() { #[test] fn roundtrip_cast() { - let test_expr = Expr::Cast(Cast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); + let test_expr = Expr::cast(Cast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2015,13 +2015,13 @@ fn roundtrip_cast() { #[test] fn roundtrip_try_cast() { let 
test_expr = - Expr::TryCast(TryCast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); + Expr::try_cast(TryCast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); let test_expr = - Expr::TryCast(TryCast::new(Box::new(lit("not a bool")), DataType::Boolean)); + Expr::try_cast(TryCast::new(Box::new(lit("not a bool")), DataType::Boolean)); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2029,7 +2029,7 @@ fn roundtrip_try_cast() { #[test] fn roundtrip_negative() { - let test_expr = Expr::Negative(Box::new(lit(1.0_f32))); + let test_expr = Expr::negative(Box::new(lit(1.0_f32))); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2037,7 +2037,7 @@ fn roundtrip_negative() { #[test] fn roundtrip_inlist() { - let test_expr = Expr::InList(InList::new( + let test_expr = Expr::_in_list(InList::new( Box::new(lit(1.0_f32)), vec![lit(2.0_f32)], true, @@ -2049,7 +2049,7 @@ fn roundtrip_inlist() { #[test] fn roundtrip_unnest() { - let test_expr = Expr::Unnest(Unnest { + let test_expr = Expr::unnest(Unnest { expr: Box::new(col("col")), }); @@ -2059,7 +2059,7 @@ fn roundtrip_unnest() { #[test] fn roundtrip_wildcard() { - let test_expr = Expr::Wildcard(Wildcard { + let test_expr = Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), }); @@ -2070,7 +2070,7 @@ fn roundtrip_wildcard() { #[test] fn roundtrip_qualified_wildcard() { - let test_expr = Expr::Wildcard(Wildcard { + let test_expr = Expr::wildcard(Wildcard { qualifier: Some("foo".into()), options: WildcardOptions::default(), }); @@ -2082,7 +2082,7 @@ fn roundtrip_qualified_wildcard() { #[test] fn roundtrip_like() { fn like(negated: bool, escape_char: Option) { - let test_expr = Expr::Like(Like::new( + let test_expr = Expr::_like(Like::new( negated, Box::new(col("col")), Box::new(lit("[0-9]+")), @@ -2101,7 +2101,7 @@ fn roundtrip_like() { #[test] fn roundtrip_ilike() { fn ilike(negated: 
bool, escape_char: Option) { - let test_expr = Expr::Like(Like::new( + let test_expr = Expr::_like(Like::new( negated, Box::new(col("col")), Box::new(lit("[0-9]+")), @@ -2120,7 +2120,7 @@ fn roundtrip_ilike() { #[test] fn roundtrip_similar_to() { fn similar_to(negated: bool, escape_char: Option) { - let test_expr = Expr::SimilarTo(Like::new( + let test_expr = Expr::similar_to(Like::new( negated, Box::new(col("col")), Box::new(lit("[0-9]+")), @@ -2195,7 +2195,7 @@ fn roundtrip_aggregate_udf() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let test_expr = Expr::aggregate_function(expr::AggregateFunction::new_udf( Arc::new(dummy_agg.clone()), vec![lit(1.0_f64)], false, @@ -2227,7 +2227,7 @@ fn roundtrip_scalar_udf() { scalar_fn, ); - let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + let test_expr = Expr::scalar_function(ScalarFunction::new_udf( Arc::new(udf.clone()), vec![lit("")], )); @@ -2266,7 +2266,7 @@ fn roundtrip_aggregate_udf_extension_codec() { #[test] fn roundtrip_grouping_sets() { - let test_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let test_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("a")], vec![col("b")], vec![col("a"), col("b")], @@ -2278,7 +2278,7 @@ fn roundtrip_grouping_sets() { #[test] fn roundtrip_rollup() { - let test_expr = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); + let test_expr = Expr::grouping_set(GroupingSet::Rollup(vec![col("a"), col("b")])); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2286,7 +2286,7 @@ fn roundtrip_rollup() { #[test] fn roundtrip_cube() { - let test_expr = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); + let test_expr = Expr::grouping_set(GroupingSet::Cube(vec![col("a"), col("b")])); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2305,13 +2305,13 @@ fn roundtrip_substr() { .unwrap(); // 
substr(string, position) - let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + let test_expr = Expr::scalar_function(ScalarFunction::new_udf( fun.clone(), vec![col("col"), lit(1_i64)], )); // substr(string, position, count) - let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new_udf( + let test_expr_with_count = Expr::scalar_function(ScalarFunction::new_udf( fun, vec![col("col"), lit(1_i64), lit(1_i64)], )); @@ -2324,7 +2324,7 @@ fn roundtrip_window() { let ctx = SessionContext::new(); // 1. without window_frame - let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr1 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2335,7 +2335,7 @@ fn roundtrip_window() { .unwrap(); // 2. with default window_frame - let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr2 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2352,7 +2352,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr3 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2369,7 +2369,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr4 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], )) @@ -2419,7 +2419,7 @@ fn roundtrip_window() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr5 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], )) @@ -2500,7 +2500,7 @@ fn 
roundtrip_window() { let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); - let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr6 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], )) @@ -2510,7 +2510,7 @@ fn roundtrip_window() { .build() .unwrap(); - let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( + let text_expr7 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], )) diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index d1b50105d053d..dac8cff12e530 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -261,7 +261,7 @@ fn test_expression_serialization_roundtrip() { // default to 4 args (though some exprs like substr have error checking) let num_args = 4; let args: Vec<_> = std::iter::repeat(&lit).take(num_args).cloned().collect(); - let expr = Expr::ScalarFunction(ScalarFunction::new_udf(function, args)); + let expr = Expr::scalar_function(ScalarFunction::new_udf(function, args)); let extension_codec = DefaultLogicalExtensionCodec {}; let proto = serialize_expr(&expr, &extension_codec).unwrap(); diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index c288d6ca70674..71c94890b93c6 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -193,7 +193,7 @@ fn has_work_table_reference( ) -> bool { let mut has_reference = false; plan.apply(|node| { - if let LogicalPlan::TableScan(scan) = node { + if let LogicalPlan::TableScan(scan, _) = node { if Arc::ptr_eq(&scan.source, work_table_source) { has_reference = true; return Ok(TreeNodeRecursion::Stop); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 0ce7c891a6085..90b3242ce1733 100644 --- a/datafusion/sql/src/expr/function.rs +++ 
b/datafusion/sql/src/expr/function.rs @@ -22,8 +22,8 @@ use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, }; -use datafusion_expr::expr::{Wildcard, WildcardOptions}; use datafusion_expr::expr::{ScalarFunction, Unnest}; +use datafusion_expr::expr::{Wildcard, WildcardOptions}; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{ expr, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition, @@ -235,7 +235,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); + return Ok(Expr::scalar_function(ScalarFunction::new_udf(fm, args))); } // Build Unnest expression @@ -246,7 +246,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let expr = exprs.swap_remove(0); Self::check_unnest_arg(&expr, schema)?; - return Ok(Expr::Unnest(Unnest::new(expr))); + return Ok(Expr::unnest(Unnest::new(expr))); } if !order_by.is_empty() && is_function_window { @@ -310,7 +310,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; if let Ok(fun) = self.find_window_func(&name) { - return Expr::WindowFunction(expr::WindowFunction::new( + return Expr::window_function(expr::WindowFunction::new( fun, self.function_args_to_expr(args, schema, planner_context)?, )) @@ -336,7 +336,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) .transpose()? 
.map(Box::new); - return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + return Ok(Expr::aggregate_function(expr::AggregateFunction::new_udf( fm, args, distinct, @@ -371,7 +371,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { internal_datafusion_err!("Unable to find expected '{fn_name}' function") })?; let args = vec![self.sql_expr_to_logical_expr(expr, schema, planner_context)?]; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) + Ok(Expr::scalar_function(ScalarFunction::new_udf(fun, args))) } pub(super) fn find_window_func( @@ -413,7 +413,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name: _, arg: FunctionArgExpr::Wildcard, operator: _, - } => Ok(Expr::Wildcard(Wildcard { + } => Ok(Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })), @@ -421,7 +421,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.sql_expr_to_logical_expr(arg, schema, planner_context) } FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { - Ok(Expr::Wildcard(Wildcard { + Ok(Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })) @@ -433,7 +433,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if qualified_indices.is_empty() { return plan_err!("Invalid qualifier {qualifier}"); } - Ok(Expr::Wildcard(Wildcard { + Ok(Expr::wildcard(Wildcard { qualifier: Some(qualifier), options: WildcardOptions::default(), })) diff --git a/datafusion/sql/src/expr/grouping_set.rs b/datafusion/sql/src/expr/grouping_set.rs index a8b3ef7e20ec2..ad8613e4596cc 100644 --- a/datafusion/sql/src/expr/grouping_set.rs +++ b/datafusion/sql/src/expr/grouping_set.rs @@ -36,7 +36,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect() }) .collect(); - Ok(Expr::GroupingSet(GroupingSet::GroupingSets(args?))) + Ok(Expr::grouping_set(GroupingSet::GroupingSets(args?))) } pub(super) fn sql_rollup_to_expr( @@ -57,7 +57,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }) .collect(); - 
Ok(Expr::GroupingSet(GroupingSet::Rollup(args?))) + Ok(Expr::grouping_set(GroupingSet::Rollup(args?))) } pub(super) fn sql_cube_to_expr( @@ -76,6 +76,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }) .collect(); - Ok(Expr::GroupingSet(GroupingSet::Cube(args?))) + Ok(Expr::grouping_set(GroupingSet::Cube(args?))) } } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index e103f68fc9275..f8a0e929d7bf4 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -223,7 +223,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { None }; - Ok(Expr::Case(Case::new( + Ok(Expr::case(Case::new( expr, when_expr .iter() diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 3e6049dc2f067..df2c2867814f7 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -124,7 +124,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let RawBinaryExpr { op, left, right } = binary_expr; - Ok(Expr::BinaryExpr(BinaryExpr::new( + Ok(Expr::binary_expr(BinaryExpr::new( Box::new(left), self.parse_sql_binary_op(op)?, Box::new(right), @@ -253,7 +253,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return not_impl_err!("CAST with format is not supported: {format}"); } - Ok(Expr::TryCast(TryCast::new( + Ok(Expr::try_cast(TryCast::new( Box::new(self.sql_expr_to_logical_expr( *expr, schema, @@ -263,21 +263,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } - SQLExpr::TypedString { data_type, value } => Ok(Expr::Cast(Cast::new( + SQLExpr::TypedString { data_type, value } => Ok(Expr::cast(Cast::new( Box::new(lit(value)), self.convert_data_type(&data_type)?, ))), - SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new( + SQLExpr::IsNull(expr) => Ok(Expr::_is_null(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new( + SQLExpr::IsNotNull(expr) => Ok(Expr::_is_not_null(Box::new( 
self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), SQLExpr::IsDistinctFrom(left, right) => { - Ok(Expr::BinaryExpr(BinaryExpr::new( + Ok(Expr::binary_expr(BinaryExpr::new( Box::new(self.sql_expr_to_logical_expr( *left, schema, @@ -293,7 +293,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::IsNotDistinctFrom(left, right) => { - Ok(Expr::BinaryExpr(BinaryExpr::new( + Ok(Expr::binary_expr(BinaryExpr::new( Box::new(self.sql_expr_to_logical_expr( *left, schema, @@ -308,27 +308,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } - SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new( + SQLExpr::IsTrue(expr) => Ok(Expr::_is_true(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new( + SQLExpr::IsFalse(expr) => Ok(Expr::_is_false(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new( + SQLExpr::IsNotTrue(expr) => Ok(Expr::_is_not_true(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new( + SQLExpr::IsNotFalse(expr) => Ok(Expr::_is_not_false(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsUnknown(expr) => Ok(Expr::IsUnknown(Box::new( + SQLExpr::IsUnknown(expr) => Ok(Expr::_is_unknown(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsNotUnknown(expr) => Ok(Expr::IsNotUnknown(Box::new( + SQLExpr::IsNotUnknown(expr) => Ok(Expr::_is_not_unknown(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), @@ -341,7 +341,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { negated, low, high, - } => Ok(Expr::Between(Between::new( + } => Ok(Expr::_between(Between::new( Box::new(self.sql_expr_to_logical_expr( *expr, schema, @@ -509,7 +509,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { 
SQLExpr::AtTimeZone { timestamp, time_zone, - } => Ok(Expr::Cast(Cast::new( + } => Ok(Expr::cast(Cast::new( Box::new(self.sql_expr_to_logical_expr_internal( *timestamp, schema, @@ -565,11 +565,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } not_impl_err!("AnyOp not supported by ExprPlanner: {binary_expr:?}") } - SQLExpr::Wildcard => Ok(Expr::Wildcard(Wildcard { + SQLExpr::Wildcard => Ok(Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })), - SQLExpr::QualifiedWildcard(object_name) => Ok(Expr::Wildcard(Wildcard { + SQLExpr::QualifiedWildcard(object_name) => Ok(Expr::wildcard(Wildcard { qualifier: Some(self.object_name_to_table_reference(object_name)?), options: WildcardOptions::default(), })), @@ -769,7 +769,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; - Ok(Expr::InList(InList::new( + Ok(Expr::_in_list(InList::new( Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), list_expr, negated, @@ -804,7 +804,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { None }; - Ok(Expr::Like(Like::new( + Ok(Expr::_like(Like::new( negated, Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), Box::new(pattern), @@ -835,7 +835,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { None }; - Ok(Expr::SimilarTo(Like::new( + Ok(Expr::similar_to(Like::new( negated, Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), Box::new(pattern), @@ -891,7 +891,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { internal_datafusion_err!("Unable to find expected '{fun_name}' function") })?; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) + Ok(Expr::scalar_function(ScalarFunction::new_udf(fun, args))) } fn sql_overlay_to_expr( @@ -946,7 +946,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { DataType::Timestamp(TimeUnit::Nanosecond, tz) if expr.get_type(schema)? 
== DataType::Int64 => { - Expr::Cast(Cast::new( + Expr::cast(Cast::new( Box::new(expr), DataType::Timestamp(TimeUnit::Second, tz.clone()), )) @@ -954,7 +954,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { _ => expr, }; - Ok(Expr::Cast(Cast::new(Box::new(expr), dt))) + Ok(Expr::cast(Cast::new(Box::new(expr), dt))) } fn sql_subscript_to_expr( diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index ff161c6ed644e..45df1e27acab1 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -37,7 +37,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); planner_context.set_outer_query_schema(old_outer_query_schema); - Ok(Expr::Exists(Exists { + Ok(Expr::exists(Exists { subquery: Subquery { subquery: Arc::new(sub_plan), outer_ref_columns, @@ -60,7 +60,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let outer_ref_columns = sub_plan.all_out_ref_exprs(); planner_context.set_outer_query_schema(old_outer_query_schema); let expr = Box::new(self.sql_to_expr(expr, input_schema, planner_context)?); - Ok(Expr::InSubquery(InSubquery::new( + Ok(Expr::in_subquery(InSubquery::new( expr, Subquery { subquery: Arc::new(sub_plan), @@ -81,7 +81,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); planner_context.set_outer_query_schema(old_outer_query_schema); - Ok(Expr::ScalarSubquery(Subquery { + Ok(Expr::scalar_subquery(Subquery { subquery: Arc::new(sub_plan), outer_ref_columns, })) diff --git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index 06988eb03893b..5b9a3d3151951 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -32,7 +32,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut 
PlannerContext, ) -> Result { match op { - UnaryOperator::Not => Ok(Expr::Not(Box::new( + UnaryOperator::Not => Ok(Expr::_not(Box::new( self.sql_expr_to_logical_expr(expr, schema, planner_context)?, ))), UnaryOperator::Plus => { @@ -59,7 +59,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.sql_interval_to_expr(true, interval) } // Not a literal, apply negative operator on expression - _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr( + _ => Ok(Expr::negative(Box::new(self.sql_expr_to_logical_expr( expr, schema, planner_context, diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 1cf090aa64aa6..34b7d7822ea3c 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -223,7 +223,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fractional_seconds_precision: None, }, )?; - return Ok(Expr::BinaryExpr(BinaryExpr::new( + return Ok(Expr::binary_expr(BinaryExpr::new( Box::new(left_expr), df_op, Box::new(right_expr), diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 740f9ad3b42c3..f88096852d621 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -116,11 +116,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(plan); } - if let LogicalPlan::Distinct(Distinct::On(ref distinct_on)) = plan { + if let LogicalPlan::Distinct(Distinct::On(ref distinct_on), _) = plan { // In case of `DISTINCT ON` we must capture the sort expressions since during the plan // optimization we're effectively doing a `first_value` aggregation according to them. 
let distinct_on = distinct_on.clone().with_sort_expr(order_by)?; - Ok(LogicalPlan::Distinct(Distinct::On(distinct_on))) + Ok(LogicalPlan::distinct(Distinct::On(distinct_on))) } else { LogicalPlanBuilder::from(plan).sort(order_by)?.build() } @@ -133,7 +133,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_into: Option, ) -> Result { match select_into { - Some(into) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + Some(into) => Ok(LogicalPlan::ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { name: self.object_name_to_table_reference(into.name)?, constraints: Constraints::empty(), diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 256cc58e71dc4..54ba85bb3f1e6 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -128,7 +128,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?; Self::check_unnest_arg(&expr, &schema)?; - Ok(Expr::Unnest(Unnest::new(expr))) + Ok(Expr::unnest(Unnest::new(expr))) }) .collect::>>()?; if unnest_exprs.is_empty() { @@ -189,16 +189,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context.set_outer_from_schema(Some(old_from_schema)); match plan { - LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }, _) => { subquery_alias( - LogicalPlan::Subquery(Subquery { + LogicalPlan::subquery(Subquery { subquery: input, outer_ref_columns, }), alias, ) } - plan => Ok(LogicalPlan::Subquery(Subquery { + plan => Ok(LogicalPlan::subquery(Subquery { subquery: Arc::new(plan), outer_ref_columns, })), @@ -215,17 +215,17 @@ fn optimize_subquery_sort(plan: LogicalPlan) -> Result> // 3. LIMIT => Handled by a `Sort`, so we need to search for it. 
let mut has_limit = false; let new_plan = plan.transform_down(|c| { - if let LogicalPlan::Limit(_) = c { + if let LogicalPlan::Limit(_, _) = c { has_limit = true; return Ok(Transformed::no(c)); } match c { - LogicalPlan::Sort(s) => { + LogicalPlan::Sort(s, _) => { if !has_limit { has_limit = false; return Ok(Transformed::yes(s.input.as_ref().clone())); } - Ok(Transformed::no(LogicalPlan::Sort(s))) + Ok(Transformed::no(LogicalPlan::sort(s))) } _ => Ok(Transformed::no(c)), } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 80a08da5e35d6..a12fcaa05f5fd 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -178,9 +178,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_exprs .iter() .filter(|select_expr| match select_expr { - Expr::AggregateFunction(_) => false, - Expr::Alias(Alias { expr, name: _, .. }) => { - !matches!(**expr, Expr::AggregateFunction(_)) + Expr::AggregateFunction(_, _) => false, + Expr::Alias(Alias { expr, name: _, .. }, _) => { + !matches!(**expr, Expr::AggregateFunction(_, _)) } _ => true, }) @@ -364,7 +364,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn try_process_aggregate_unnest(&self, input: LogicalPlan) -> Result { match input { - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { let agg_expr = agg.aggr_expr.clone(); let (new_input, new_group_by_exprs) = self.try_process_group_by_unnest(agg)?; @@ -372,12 +372,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .aggregate(new_group_by_exprs, agg_expr)? 
.build() } - LogicalPlan::Filter(mut filter) => { + LogicalPlan::Filter(mut filter, _) => { filter.input = Arc::new(self.try_process_aggregate_unnest(Arc::unwrap_or_clone( filter.input, ))?); - Ok(LogicalPlan::Filter(filter)) + Ok(LogicalPlan::filter(filter)) } _ => Ok(input), } @@ -519,7 +519,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[using_columns], )?; - Ok(LogicalPlan::Filter(Filter::try_new( + Ok(LogicalPlan::filter(Filter::try_new( filter_expr, Arc::new(plan), )?)) @@ -748,7 +748,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = LogicalPlanBuilder::from(input.clone()) .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? .build()?; - let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { + let group_by_exprs = if let LogicalPlan::Aggregate(agg, _) = &plan { &agg.group_expr } else { unreachable!(); @@ -764,13 +764,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut aggr_projection_exprs = vec![]; for expr in group_by_exprs { match expr { - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) => { aggr_projection_exprs.extend_from_slice(exprs) } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => { + Expr::GroupingSet(GroupingSet::Cube(exprs), _) => { aggr_projection_exprs.extend_from_slice(exprs) } - Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs), _) => { for exprs in lists_of_exprs { aggr_projection_exprs.extend_from_slice(exprs) } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 31b836f32b242..e6aad5272e9cd 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -443,7 +443,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan.schema(), )?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + Ok(LogicalPlan::ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { name: self.object_name_to_table_reference(name)?, 
constraints, @@ -466,7 +466,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &all_constraints, plan.schema(), )?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + Ok(LogicalPlan::ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { name: self.object_name_to_table_reference(name)?, constraints, @@ -530,7 +530,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut plan = self.query_to_plan(*query, &mut PlannerContext::new())?; plan = self.apply_expr_alias(plan, columns)?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { + Ok(LogicalPlan::ddl(DdlStatement::CreateView(CreateView { name: self.object_name_to_table_reference(name)?, input: Arc::new(plan), or_replace, @@ -547,7 +547,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Statement::CreateSchema { schema_name, if_not_exists, - } => Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema( + } => Ok(LogicalPlan::ddl(DdlStatement::CreateCatalogSchema( CreateCatalogSchema { schema_name: get_schema_name(&schema_name), if_not_exists, @@ -558,7 +558,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { db_name, if_not_exists, .. 
- } => Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalog( + } => Ok(LogicalPlan::ddl(DdlStatement::CreateCatalog( CreateCatalog { catalog_name: object_name_to_string(&db_name), if_not_exists, @@ -587,14 +587,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match object_type { ObjectType::Table => { - Ok(LogicalPlan::Ddl(DdlStatement::DropTable(DropTable { + Ok(LogicalPlan::ddl(DdlStatement::DropTable(DropTable { name, if_exists, schema: DFSchemaRef::new(DFSchema::empty()), }))) } ObjectType::View => { - Ok(LogicalPlan::Ddl(DdlStatement::DropView(DropView { + Ok(LogicalPlan::ddl(DdlStatement::DropView(DropView { name, if_exists, schema: DFSchemaRef::new(DFSchema::empty()), @@ -608,7 +608,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Err(ParserError("Invalid schema specifier (has 3 parts)".to_string())) } }?; - Ok(LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(DropCatalogSchema { + Ok(LogicalPlan::ddl(DdlStatement::DropCatalogSchema(DropCatalogSchema { name, if_exists, cascade, @@ -640,7 +640,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { *statement, &mut planner_context, )?; - Ok(LogicalPlan::Statement(PlanStatement::Prepare(Prepare { + Ok(LogicalPlan::statement(PlanStatement::Prepare(Prepare { name: ident_to_string(&name), data_types, input: Arc::new(plan), @@ -667,7 +667,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|expr| self.sql_to_expr(expr, &empty_schema, planner_context)) .collect::>>()?; - Ok(LogicalPlan::Statement(PlanStatement::Execute(Execute { + Ok(LogicalPlan::statement(PlanStatement::Execute(Execute { name: object_name_to_string(&name), parameters, }))) @@ -676,7 +676,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name, // Similar to PostgreSQL, the PREPARE keyword is ignored prepare: _, - } => Ok(LogicalPlan::Statement(PlanStatement::Deallocate( + } => Ok(LogicalPlan::statement(PlanStatement::Deallocate( Deallocate { name: ident_to_string(&name), }, @@ -860,14 +860,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { 
access_mode, isolation_level, }); - Ok(LogicalPlan::Statement(statement)) + Ok(LogicalPlan::statement(statement)) } Statement::Commit { chain } => { let statement = PlanStatement::TransactionEnd(TransactionEnd { conclusion: TransactionConclusion::Commit, chain, }); - Ok(LogicalPlan::Statement(statement)) + Ok(LogicalPlan::statement(statement)) } Statement::Rollback { chain, savepoint } => { if savepoint.is_some() { @@ -877,7 +877,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { conclusion: TransactionConclusion::Rollback, chain, }); - Ok(LogicalPlan::Statement(statement)) + Ok(LogicalPlan::statement(statement)) } Statement::CreateFunction { or_replace, @@ -971,7 +971,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: DFSchemaRef::new(DFSchema::empty()), }); - Ok(LogicalPlan::Ddl(statement)) + Ok(LogicalPlan::ddl(statement)) } Statement::DropFunction { if_exists, @@ -992,7 +992,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name, schema: DFSchemaRef::new(DFSchema::empty()), }); - Ok(LogicalPlan::Ddl(statement)) + Ok(LogicalPlan::ddl(statement)) } else { exec_err!("Function name not provided") } @@ -1021,7 +1021,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { false, None, )?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateIndex( + Ok(LogicalPlan::ddl(DdlStatement::CreateIndex( PlanCreateIndex { name, table, @@ -1166,7 +1166,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|f| f.name().to_owned()) .collect(); - Ok(LogicalPlan::Copy(CopyTo { + Ok(LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: statement.target, file_type, @@ -1288,7 +1288,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let name = self.object_name_to_table_reference(name)?; let constraints = Self::new_constraint_from_table_constraints(&all_constraints, &df_schema)?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( + Ok(LogicalPlan::ddl(DdlStatement::CreateExternalTable( PlanCreateExternalTable { schema: df_schema, name, @@ -1416,7 +1416,7 @@ impl<'a, S: 
ContextProvider> SqlToRel<'a, S> { statement: DFStatement, ) -> Result { let plan = self.statement_to_plan(statement)?; - if matches!(plan, LogicalPlan::Explain(_)) { + if matches!(plan, LogicalPlan::Explain(_, _)) { return plan_err!("Nested EXPLAINs are not supported"); } let plan = Arc::new(plan); @@ -1424,7 +1424,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let schema = schema.to_dfschema_ref()?; if analyze { - Ok(LogicalPlan::Analyze(Analyze { + Ok(LogicalPlan::analyze(Analyze { verbose, input: plan, schema, @@ -1432,7 +1432,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { let stringified_plans = vec![plan.to_stringified(PlanType::InitialLogicalPlan)]; - Ok(LogicalPlan::Explain(Explain { + Ok(LogicalPlan::explain(Explain { verbose, plan, stringified_plans, @@ -1552,7 +1552,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { value: value_string, }); - Ok(LogicalPlan::Statement(statement)) + Ok(LogicalPlan::statement(statement)) } fn delete_to_plan( @@ -1586,11 +1586,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[&[&schema]], &[using_columns], )?; - LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?) + LogicalPlan::filter(Filter::try_new(filter_expr, Arc::new(scan))?) } }; - let plan = LogicalPlan::Dml(DmlStatement::new( + let plan = LogicalPlan::dml(DmlStatement::new( table_ref, schema.into(), WriteOp::Delete, @@ -1660,7 +1660,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[&[scan.schema()]], &[using_columns], )?; - LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?) + LogicalPlan::filter(Filter::try_new(filter_expr, Arc::new(scan))?) 
} }; @@ -1703,7 +1703,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let source = project(source, exprs)?; - let plan = LogicalPlan::Dml(DmlStatement::new( + let plan = LogicalPlan::dml(DmlStatement::new( table_name, table_schema, WriteOp::Update, @@ -1828,7 +1828,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { (true, true) => plan_err!("Conflicting insert operations: `overwrite` and `replace_into` cannot both be true")?, }; - let plan = LogicalPlan::Dml(DmlStatement::new( + let plan = LogicalPlan::dml(DmlStatement::new( table_name, Arc::new(table_schema), WriteOp::Insert(insert_op), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index da268a3ff7387..aa6619fbfa8b4 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -93,11 +93,14 @@ impl Unparser<'_> { fn expr_to_sql_inner(&self, expr: &Expr) -> Result { match expr { - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { let list_expr = list .iter() .map(|e| self.expr_to_sql_inner(e)) @@ -108,7 +111,7 @@ impl Unparser<'_> { negated: *negated, }) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { let func_name = func.name(); if let Some(expr) = self @@ -120,12 +123,15 @@ impl Unparser<'_> { self.scalar_function_to_sql(func_name, args) } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { let sql_parser_expr = self.expr_to_sql_inner(expr)?; let sql_low = self.expr_to_sql_inner(low)?; let sql_high = self.expr_to_sql_inner(high)?; @@ -137,18 +143,21 @@ impl Unparser<'_> { )))) } Expr::Column(col) => self.col_to_sql(col), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { let l = self.expr_to_sql_inner(left.as_ref())?; let 
r = self.expr_to_sql_inner(right.as_ref())?; let op = self.op_to_sql(op)?; Ok(ast::Expr::Nested(Box::new(self.binary_op_to_sql(l, r, op)))) } - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => { + Expr::Case( + Case { + expr, + when_then_expr, + else_expr, + }, + _, + ) => { let conditions = when_then_expr .iter() .map(|(w, _)| self.expr_to_sql_inner(w)) @@ -179,19 +188,22 @@ impl Unparser<'_> { else_result, }) } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, data_type }, _) => { Ok(self.cast_to_sql(expr, data_type)?) } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), - Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { + Expr::Alias(Alias { expr, name: _, .. }, _) => self.expr_to_sql_inner(expr), + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + }, + _, + ) => { let func_name = fun.name(); let args = self.function_args_to_sql(args)?; @@ -246,27 +258,33 @@ impl Unparser<'_> { parameters: ast::FunctionArguments::None, })) } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) - | Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => Ok(ast::Expr::Like { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }, + _, + ) + | Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }, + _, + ) => Ok(ast::Expr::Like { negated: *negated, expr: Box::new(self.expr_to_sql_inner(expr)?), pattern: Box::new(self.expr_to_sql_inner(pattern)?), escape_char: escape_char.map(|c| c.to_string()), any: false, }), - Expr::AggregateFunction(agg) => { + Expr::AggregateFunction(agg, _) => { let func_name = agg.func.name(); let args = 
self.function_args_to_sql(&agg.args)?; @@ -293,7 +311,7 @@ impl Unparser<'_> { parameters: ast::FunctionArguments::None, })) } - Expr::ScalarSubquery(subq) => { + Expr::ScalarSubquery(subq, _) => { let sub_statement = self.plan_to_sql(subq.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement { @@ -305,7 +323,7 @@ impl Unparser<'_> { }; Ok(ast::Expr::Subquery(sub_query)) } - Expr::InSubquery(insubq) => { + Expr::InSubquery(insubq, _) => { let inexpr = Box::new(self.expr_to_sql_inner(insubq.expr.as_ref())?); let sub_statement = self.plan_to_sql(insubq.subquery.subquery.as_ref())?; @@ -323,7 +341,7 @@ impl Unparser<'_> { negated: insubq.negated, }) } - Expr::Exists(Exists { subquery, negated }) => { + Expr::Exists(Exists { subquery, negated }, _) => { let sub_statement = self.plan_to_sql(subquery.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement { @@ -338,38 +356,38 @@ impl Unparser<'_> { negated: *negated, }) } - Expr::IsNull(expr) => { + Expr::IsNull(expr, _) => { Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql_inner(expr)?))) } - Expr::IsNotNull(expr) => Ok(ast::Expr::IsNotNull(Box::new( + Expr::IsNotNull(expr, _) => Ok(ast::Expr::IsNotNull(Box::new( self.expr_to_sql_inner(expr)?, ))), - Expr::IsTrue(expr) => { + Expr::IsTrue(expr, _) => { Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql_inner(expr)?))) } - Expr::IsNotTrue(expr) => Ok(ast::Expr::IsNotTrue(Box::new( + Expr::IsNotTrue(expr, _) => Ok(ast::Expr::IsNotTrue(Box::new( self.expr_to_sql_inner(expr)?, ))), - Expr::IsFalse(expr) => { + Expr::IsFalse(expr, _) => { Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql_inner(expr)?))) } - Expr::IsNotFalse(expr) => Ok(ast::Expr::IsNotFalse(Box::new( + Expr::IsNotFalse(expr, _) => Ok(ast::Expr::IsNotFalse(Box::new( self.expr_to_sql_inner(expr)?, ))), - Expr::IsUnknown(expr) => Ok(ast::Expr::IsUnknown(Box::new( + Expr::IsUnknown(expr, _) => Ok(ast::Expr::IsUnknown(Box::new( 
self.expr_to_sql_inner(expr)?, ))), - Expr::IsNotUnknown(expr) => Ok(ast::Expr::IsNotUnknown(Box::new( + Expr::IsNotUnknown(expr, _) => Ok(ast::Expr::IsNotUnknown(Box::new( self.expr_to_sql_inner(expr)?, ))), - Expr::Not(expr) => { + Expr::Not(expr, _) => { let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Not, expr: Box::new(sql_parser_expr), }) } - Expr::Negative(expr) => { + Expr::Negative(expr, _) => { let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Minus, @@ -393,7 +411,7 @@ impl Unparser<'_> { ) }) } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, data_type }, _) => { let inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::TryCast, @@ -403,7 +421,7 @@ impl Unparser<'_> { }) } // TODO: unparsing wildcard addition options - Expr::Wildcard(Wildcard { qualifier, .. }) => { + Expr::Wildcard(Wildcard { qualifier, .. }, _) => { if let Some(qualifier) = qualifier { let idents: Vec = qualifier.to_vec().into_iter().map(Ident::new).collect(); @@ -412,7 +430,7 @@ impl Unparser<'_> { Ok(ast::Expr::Wildcard) } } - Expr::GroupingSet(grouping_set) => match grouping_set { + Expr::GroupingSet(grouping_set, _) => match grouping_set { GroupingSet::GroupingSets(grouping_sets) => { let expr_ast_sets = grouping_sets .iter() @@ -450,7 +468,7 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Placeholder(p.id.to_string()))) } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), - Expr::Unnest(unnest) => self.unnest_to_sql(unnest), + Expr::Unnest(unnest, _) => self.unnest_to_sql(unnest), } } @@ -602,10 +620,13 @@ impl Unparser<'_> { .map(|e| { if matches!( e, - Expr::Wildcard(Wildcard { - qualifier: None, - .. - }) + Expr::Wildcard( + Wildcard { + qualifier: None, + .. 
+ }, + _ + ) ) { Ok(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) } else { @@ -1577,7 +1598,7 @@ mod tests { fn expr_to_sql_ok() -> Result<()> { let dummy_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let dummy_logical_plan = table_scan(Some("t"), &dummy_schema, None)? - .project(vec![Expr::Wildcard(Wildcard { + .project(vec![Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })])? @@ -1610,14 +1631,14 @@ mod tests { r#"CASE WHEN a IS NOT NULL THEN true ELSE false END"#, ), ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Date64, }), r#"CAST(a AS DATETIME)"#, ), ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Timestamp( TimeUnit::Nanosecond, @@ -1627,14 +1648,14 @@ mod tests { r#"CAST(a AS TIMESTAMP WITH TIME ZONE)"#, ), ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Timestamp(TimeUnit::Millisecond, None), }), r#"CAST(a AS TIMESTAMP)"#, ), ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::UInt32, }), @@ -1665,7 +1686,7 @@ mod tests { r#"dummy_udf(a, b) IS NOT NULL"#, ), ( - Expr::Like(Like { + Expr::_like(Like { negated: true, expr: Box::new(col("a")), pattern: Box::new(lit("foo")), @@ -1675,7 +1696,7 @@ mod tests { r#"a NOT LIKE 'foo' ESCAPE 'o'"#, ), ( - Expr::SimilarTo(Like { + Expr::similar_to(Like { negated: false, expr: Box::new(col("a")), pattern: Box::new(lit("foo")), @@ -1771,7 +1792,7 @@ mod tests { (sum(col("a")), r#"sum(a)"#), ( count_udaf() - .call(vec![Expr::Wildcard(Wildcard { + .call(vec![Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })]) @@ -1782,7 +1803,7 @@ mod tests { ), ( count_udaf() - .call(vec![Expr::Wildcard(Wildcard { + .call(vec![Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })]) @@ -1792,7 +1813,7 @@ mod tests { "count(*) FILTER (WHERE true)", ), ( 
- Expr::WindowFunction(WindowFunction { + Expr::window_function(WindowFunction { fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), args: vec![col("col")], partition_by: vec![], @@ -1803,7 +1824,7 @@ mod tests { r#"row_number(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, ), ( - Expr::WindowFunction(WindowFunction { + Expr::window_function(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), args: vec![wildcard()], partition_by: vec![], @@ -1852,7 +1873,7 @@ mod tests { Expr::between(col("a"), lit(1), lit(7)), r#"(a BETWEEN 1 AND 7)"#, ), - (Expr::Negative(Box::new(col("a"))), r#"-a"#), + (Expr::negative(Box::new(col("a"))), r#"-a"#), ( exists(Arc::new(dummy_logical_plan.clone())), r#"EXISTS (SELECT * FROM t WHERE (t.a = 1))"#, @@ -1916,14 +1937,14 @@ mod tests { r#"((a + b) > 100.123)"#, ), ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Decimal128(10, -2), }), r#"CAST(a AS DECIMAL(12,0))"#, ), ( - Expr::Unnest(Unnest { + Expr::unnest(Unnest { expr: Box::new(Expr::Column(Column { relation: Some(TableReference::partial("schema", "table")), name: "array_col".to_string(), @@ -1994,7 +2015,7 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Date64, }); @@ -2019,7 +2040,7 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Float64, }); @@ -2246,7 +2267,7 @@ mod tests { fn test_cast_value_to_binary_expr() { let tests = [ ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( "blah".to_string(), )))), @@ -2255,7 +2276,7 @@ mod tests { "'blah'", ), ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( "blah".to_string(), )))), @@ -2291,7 +2312,7 @@ mod 
tests { ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type, }); @@ -2374,7 +2395,7 @@ mod tests { [(default_dialect, "BIGINT"), (mysql_dialect, "SIGNED")] { let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Int64, }); @@ -2402,7 +2423,7 @@ mod tests { [(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")] { let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Int32, }); @@ -2441,7 +2462,7 @@ mod tests { (&mysql_dialect, ×tamp_with_tz, "DATETIME"), ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type: data_type.clone(), }); @@ -2468,7 +2489,7 @@ mod tests { ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type, }); @@ -2485,7 +2506,7 @@ mod tests { #[test] fn test_cast_value_to_dict_expr() { let tests = [( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( "variation".to_string(), )))), @@ -2517,12 +2538,12 @@ mod tests { [(default_dialect, "DOUBLE"), (postgres_dialect, "NUMERIC")] { let unparser = Unparser::new(dialect.as_ref()); - let expr = Expr::ScalarFunction(ScalarFunction { + let expr = Expr::scalar_function(ScalarFunction { func: Arc::new(ScalarUDF::from( datafusion_functions::math::round::RoundFunc::new(), )), args: vec![ - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Float64, }), diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 81e47ed939f22..343f03baf71e5 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -91,31 +91,31 @@ impl 
Unparser<'_> { let plan = normalize_union_schema(plan)?; match plan { - LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Window(_) - | LogicalPlan::Aggregate(_) - | LogicalPlan::Sort(_) - | LogicalPlan::Join(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Union(_) - | LogicalPlan::TableScan(_) + LogicalPlan::Projection(_, _) + | LogicalPlan::Filter(_, _) + | LogicalPlan::Window(_, _) + | LogicalPlan::Aggregate(_, _) + | LogicalPlan::Sort(_, _) + | LogicalPlan::Join(_, _) + | LogicalPlan::Repartition(_, _) + | LogicalPlan::Union(_, _) + | LogicalPlan::TableScan(_, _) | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Statement(_) - | LogicalPlan::Values(_) - | LogicalPlan::Distinct(_) => self.select_to_sql_statement(&plan), - LogicalPlan::Dml(_) => self.dml_to_sql(&plan), - LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Extension(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Copy(_) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Limit(_, _) + | LogicalPlan::Statement(_, _) + | LogicalPlan::Values(_, _) + | LogicalPlan::Distinct(_, _) => self.select_to_sql_statement(&plan), + LogicalPlan::Dml(_, _) => self.dml_to_sql(&plan), + LogicalPlan::Explain(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Extension(_, _) + | LogicalPlan::Ddl(_, _) + | LogicalPlan::Copy(_, _) | LogicalPlan::DescribeTable(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Unnest(_) => not_impl_err!("Unsupported plan: {plan:?}"), + | LogicalPlan::RecursiveQuery(_, _) + | LogicalPlan::Unnest(_, _) => not_impl_err!("Unsupported plan: {plan:?}"), } } @@ -275,7 +275,7 @@ impl Unparser<'_> { relation: &mut RelationBuilder, ) -> Result<()> { match plan { - LogicalPlan::TableScan(scan) => { + LogicalPlan::TableScan(scan, _) => { if let Some(unparsed_table_scan) = Self::unparse_table_scan_pushdown(plan, None)? 
{ @@ -304,7 +304,7 @@ impl Unparser<'_> { Ok(()) } - LogicalPlan::Projection(p) => { + LogicalPlan::Projection(p, _) => { if let Some(new_plan) = rewrite_plan_for_sort_on_non_projected_fields(p) { return self .select_to_sql_recursively(&new_plan, query, select, relation); @@ -321,7 +321,7 @@ impl Unparser<'_> { self.reconstruct_select_statement(plan, p, select)?; self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) } - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { if let Some(agg) = find_agg_node_within_select(plan, select.already_projected()) { @@ -341,7 +341,7 @@ impl Unparser<'_> { relation, ) } - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { // Limit can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -378,7 +378,7 @@ impl Unparser<'_> { relation, ) } - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(sort, _) => { // Sort can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -419,7 +419,7 @@ impl Unparser<'_> { relation, ) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { // Aggregation can be already handled in the projection case if !select.already_projected() { // The query returns aggregate and group expressions. 
If that weren't the case, @@ -448,7 +448,7 @@ impl Unparser<'_> { relation, ) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(distinct, _) => { // Distinct can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -486,7 +486,7 @@ impl Unparser<'_> { select.distinct(Some(select_distinct)); self.select_to_sql_recursively(input, query, select, relation) } - LogicalPlan::Join(join) => { + LogicalPlan::Join(join, _) => { let mut table_scan_filters = vec![]; let left_plan = @@ -529,7 +529,7 @@ impl Unparser<'_> { // Combine `table_scan_filters` into a single filter using `AND` let Some(combined_filters) = table_scan_filters.into_iter().reduce(|acc, filter| { - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(acc), op: Operator::And, right: Box::new(filter), @@ -541,7 +541,7 @@ impl Unparser<'_> { // Combine `join.filter` with `combined_filters` using `AND` match &join.filter { - Some(filter) => Some(Expr::BinaryExpr(BinaryExpr { + Some(filter) => Some(Expr::binary_expr(BinaryExpr { left: Box::new(filter.clone()), op: Operator::And, right: Box::new(combined_filters), @@ -579,7 +579,7 @@ impl Unparser<'_> { Ok(()) } - LogicalPlan::SubqueryAlias(plan_alias) => { + LogicalPlan::SubqueryAlias(plan_alias, _) => { let (plan, mut columns) = subquery_alias_inner_query_and_columns(plan_alias); let unparsed_table_scan = Self::unparse_table_scan_pushdown( @@ -626,7 +626,7 @@ impl Unparser<'_> { Ok(()) } - LogicalPlan::Union(union) => { + LogicalPlan::Union(union, _) => { if union.inputs.len() != 2 { return not_impl_err!( "UNION ALL expected 2 inputs, but found {}", @@ -665,7 +665,7 @@ impl Unparser<'_> { Ok(()) } - LogicalPlan::Window(window) => { + LogicalPlan::Window(window, _) => { // Window nodes are handled simultaneously with Projection nodes self.select_to_sql_recursively( window.input.as_ref(), @@ -678,8 +678,10 @@ impl Unparser<'_> { relation.empty(); Ok(()) } - 
LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"), - LogicalPlan::Unnest(unnest) => { + LogicalPlan::Extension(_, _) => { + not_impl_err!("Unsupported operator: {plan:?}") + } + LogicalPlan::Unnest(unnest, _) => { if !unnest.struct_type_columns.is_empty() { return internal_err!( "Struct type columns are not currently supported in UNNEST: {:?}", @@ -694,7 +696,7 @@ impl Unparser<'_> { // | Projection: table.col1, table.col2 AS UNNEST(table.col2) // | Filter: table.col3 = Int64(3) // | TableScan: table projection=None - if let LogicalPlan::Projection(p) = unnest.input.as_ref() { + if let LogicalPlan::Projection(p, _) = unnest.input.as_ref() { // continue with projection input self.select_to_sql_recursively(&p.input, query, select, relation) } else { @@ -716,7 +718,7 @@ impl Unparser<'_> { alias: Option, ) -> Result> { match plan { - LogicalPlan::TableScan(table_scan) => { + LogicalPlan::TableScan(table_scan, _) => { if !Self::is_scan_with_pushdown(table_scan) { return Ok(None); } @@ -801,7 +803,7 @@ impl Unparser<'_> { Ok(Some(builder.build()?)) } - LogicalPlan::SubqueryAlias(subquery_alias) => { + LogicalPlan::SubqueryAlias(subquery_alias, _) => { Self::unparse_table_scan_pushdown( &subquery_alias.input, Some(subquery_alias.alias.clone()), @@ -809,7 +811,7 @@ impl Unparser<'_> { } // SubqueryAlias could be rewritten to a plan with a projection as the top node by [rewrite::subquery_alias_inner_query_and_columns]. // The inner table scan could be a scan with pushdown operations. - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { if let Some(plan) = Self::unparse_table_scan_pushdown(&projection.input, alias.clone())? { @@ -847,7 +849,7 @@ impl Unparser<'_> { fn select_item_to_sql(&self, expr: &Expr) -> Result { match expr { - Expr::Alias(Alias { expr, name, .. }) => { + Expr::Alias(Alias { expr, name, .. 
}, _) => { let inner = self.expr_to_sql(expr)?; Ok(ast::SelectItem::ExprWithAlias { diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 68af121a41179..84e1960863f80 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -23,7 +23,6 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, Column, HashMap, Result, TableReference, }; -use datafusion_expr::expr::Alias; use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; @@ -58,20 +57,20 @@ pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result let plan = plan.clone(); let transformed_plan = plan.transform_up(|plan| match plan { - LogicalPlan::Union(mut union) => { + LogicalPlan::Union(mut union, _) => { let schema = Arc::unwrap_or_clone(union.schema); let schema = schema.strip_qualifiers(); union.schema = Arc::new(schema); - Ok(Transformed::yes(LogicalPlan::Union(union))) + Ok(Transformed::yes(LogicalPlan::union(union))) } - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(sort, _) => { // Only rewrite Sort expressions that have a UNION as their input - if !matches!(&*sort.input, LogicalPlan::Union(_)) { - return Ok(Transformed::no(LogicalPlan::Sort(sort))); + if !matches!(&*sort.input, LogicalPlan::Union(_, _)) { + return Ok(Transformed::no(LogicalPlan::sort(sort))); } - Ok(Transformed::yes(LogicalPlan::Sort(Sort { + Ok(Transformed::yes(LogicalPlan::sort(Sort { expr: rewrite_sort_expr_for_union(sort.expr)?, input: sort.input, fetch: sort.fetch, @@ -122,11 +121,11 @@ fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( p: &Projection, ) -> Option { - let LogicalPlan::Sort(sort) = p.input.as_ref() else { + let LogicalPlan::Sort(sort, _) = p.input.as_ref() else { return None; }; - let LogicalPlan::Projection(inner_p) = sort.input.as_ref() else { + let 
LogicalPlan::Projection(inner_p, _) = sort.input.as_ref() else { return None; }; @@ -136,7 +135,7 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( .iter() .enumerate() .map(|(i, f)| match f { - Expr::Alias(alias) => { + Expr::Alias(alias, _) => { let a = Expr::Column(alias.name.clone().into()); map.insert(a.clone(), f.clone()); a @@ -182,9 +181,9 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( .collect::>(); inner_p.expr.clone_from(&new_exprs); - sort.input = Arc::new(LogicalPlan::Projection(inner_p)); + sort.input = Arc::new(LogicalPlan::projection(inner_p)); - Some(LogicalPlan::Sort(sort)) + Some(LogicalPlan::sort(sort)) } else { None } @@ -222,7 +221,7 @@ pub(super) fn subquery_alias_inner_query_and_columns( ) -> (&LogicalPlan, Vec) { let plan: &LogicalPlan = subquery_alias.input.as_ref(); - let LogicalPlan::Projection(outer_projections) = plan else { + let LogicalPlan::Projection(outer_projections, _) = plan else { return (plan, vec![]); }; @@ -236,7 +235,7 @@ pub(super) fn subquery_alias_inner_query_and_columns( // Projection: j1.j1_id AS id // Projection: j1.j1_id for (i, inner_expr) in inner_projection.expr.iter().enumerate() { - let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else { + let Expr::Alias(ref outer_alias, _) = &outer_projections.expr[i] else { return (plan, vec![]); }; @@ -270,11 +269,13 @@ pub(super) fn inject_column_aliases_into_subquery( aliases: Vec, ) -> Result { match &plan { - LogicalPlan::Projection(inner_p) => Ok(inject_column_aliases(inner_p, aliases)), + LogicalPlan::Projection(inner_p, _) => { + Ok(inject_column_aliases(inner_p, aliases)) + } _ => { // projection is wrapped by other operator (LIMIT, SORT, etc), iterate through the plan to find it plan.map_children(|child| { - if let LogicalPlan::Projection(p) = &child { + if let LogicalPlan::Projection(p, _) = &child { Ok(Transformed::yes(inject_column_aliases(p, aliases.clone()))) } else { Ok(Transformed::no(child)) @@ -307,25 +308,21 @@ 
pub(super) fn inject_column_aliases( _ => None, }; - Expr::Alias(Alias { - expr: Box::new(expr.clone()), - relation, - name: col_alias.value, - }) + expr.clone().alias_qualified(relation, col_alias.value) }) .collect::>(); updated_projection.expr = new_exprs; - LogicalPlan::Projection(updated_projection) + LogicalPlan::projection(updated_projection) } fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> { match logical_plan { - LogicalPlan::Projection(p) => Some(p), - LogicalPlan::Limit(p) => find_projection(p.input.as_ref()), - LogicalPlan::Distinct(p) => find_projection(p.input().as_ref()), - LogicalPlan::Sort(p) => find_projection(p.input.as_ref()), + LogicalPlan::Projection(p, _) => Some(p), + LogicalPlan::Limit(p, _) => find_projection(p.input.as_ref()), + LogicalPlan::Distinct(p, _) => find_projection(p.input().as_ref()), + LogicalPlan::Sort(p, _) => find_projection(p.input.as_ref()), _ => None, } } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index d0f80da83d63f..88167c9e30d8c 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -50,11 +50,11 @@ pub(crate) fn find_agg_node_within_select( input.first()? }; // Agg nodes explicitly return immediately with a single node - if let LogicalPlan::Aggregate(agg) = input { + if let LogicalPlan::Aggregate(agg, _) = input { Some(agg) - } else if let LogicalPlan::TableScan(_) = input { + } else if let LogicalPlan::TableScan(_, _) = input { None - } else if let LogicalPlan::Projection(_) = input { + } else if let LogicalPlan::Projection(_, _) = input { if already_projected { None } else { @@ -76,11 +76,11 @@ pub(crate) fn find_unnest_node_within_select(plan: &LogicalPlan) -> Option<&Unne input.first()? 
}; - if let LogicalPlan::Unnest(unnest) = input { + if let LogicalPlan::Unnest(unnest, _) = input { Some(unnest) - } else if let LogicalPlan::TableScan(_) = input { + } else if let LogicalPlan::TableScan(_, _) = input { None - } else if let LogicalPlan::Projection(_) = input { + } else if let LogicalPlan::Projection(_, _) = input { None } else { find_unnest_node_within_select(input) @@ -107,7 +107,7 @@ pub(crate) fn find_window_nodes_within_select<'a>( // Window nodes accumulate in a vec until encountering a TableScan or 2nd projection match input { - LogicalPlan::Window(window) => { + LogicalPlan::Window(window, _) => { prev_windows = match &mut prev_windows { Some(windows) => { windows.push(window); @@ -117,14 +117,14 @@ pub(crate) fn find_window_nodes_within_select<'a>( }; find_window_nodes_within_select(input, prev_windows, already_projected) } - LogicalPlan::Projection(_) => { + LogicalPlan::Projection(_, _) => { if already_projected { prev_windows } else { find_window_nodes_within_select(input, prev_windows, true) } } - LogicalPlan::TableScan(_) => prev_windows, + LogicalPlan::TableScan(_, _) => prev_windows, _ => find_window_nodes_within_select(input, prev_windows, already_projected), } } @@ -140,9 +140,9 @@ pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting. if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { if let Ok(idx) = unnest.schema.index_of_column(col_ref) { - if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() { + if let LogicalPlan::Projection(Projection { expr, .. 
}, _) = unnest.input.as_ref() { if let Some(unprojected_expr) = expr.get(idx) { - let unnest_expr = Expr::Unnest(expr::Unnest::new(unprojected_expr.clone())); + let unnest_expr = Expr::unnest(expr::Unnest::new(unprojected_expr.clone())); return Ok(Transformed::yes(unnest_expr)); } } @@ -211,7 +211,7 @@ pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result< fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result> { if let Ok(index) = agg.schema.index_of_column(column) { - if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) { + if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_, _)]) { // For grouping set expr, we must operate by expression list from the grouping set let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?; match index.cmp(&grouping_expr.len()) { @@ -255,7 +255,7 @@ pub(crate) fn unproject_sort_expr( let mut sort_expr = sort_expr.clone(); // Remove alias if present, because ORDER BY cannot use aliases - if let Expr::Alias(alias) = &sort_expr.expr { + if let Expr::Alias(alias, _) = &sort_expr.expr { sort_expr.expr = *alias.expr.clone(); } @@ -279,10 +279,10 @@ pub(crate) fn unproject_sort_expr( // If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will // be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need // to transform it back to the actual expression. - if let LogicalPlan::Projection(Projection { expr, schema, .. }) = input { + if let LogicalPlan::Projection(Projection { expr, schema, .. 
}, _) = input { if let Ok(idx) = schema.index_of_column(col_ref) { - if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) { - sort_expr.expr = Expr::ScalarFunction(scalar_fn.clone()); + if let Some(Expr::ScalarFunction(scalar_fn, _)) = expr.get(idx) { + sort_expr.expr = Expr::scalar_function(scalar_fn.clone()); } } return Ok(sort_expr); @@ -316,15 +316,15 @@ pub(crate) fn try_transform_to_simple_table_scan_with_filters( while let Some(current_plan) = plan_stack.pop() { match current_plan { - LogicalPlan::SubqueryAlias(alias) => { + LogicalPlan::SubqueryAlias(alias, _) => { table_alias = Some(alias.alias.clone()); plan_stack.push(alias.input.as_ref()); } - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { filters.push(filter.predicate.clone()); plan_stack.push(filter.input.as_ref()); } - LogicalPlan::TableScan(table_scan) => { + LogicalPlan::TableScan(table_scan, _) => { let table_schema = table_scan.source.schema(); // optional rewriter if table has an alias let mut filter_alias_rewriter = diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index e479bdbacd839..060958afdcecd 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -103,17 +103,17 @@ pub(crate) fn check_columns_satisfy_exprs( let column_exprs = find_column_exprs(exprs); for e in &column_exprs { match e { - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) => { for e in exprs { check_column_satisfies_expr(columns, e, message_prefix)?; } } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => { + Expr::GroupingSet(GroupingSet::Cube(exprs), _) => { for e in exprs { check_column_satisfies_expr(columns, e, message_prefix)?; } } - Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs), _) => { for exprs in lists_of_exprs { for e in exprs { check_column_satisfies_expr(columns, e, message_prefix)?; @@ -148,7 +148,9 @@ 
pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap { exprs .iter() .filter_map(|expr| match expr { - Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())), + Expr::Alias(Alias { expr, name, .. }, _) => { + Some((name.clone(), *expr.clone())) + } _ => None, }) .collect::>() @@ -171,7 +173,7 @@ pub(crate) fn resolve_positions_to_exprs( let index = (position - 1) as usize; let select_expr = &select_exprs[index]; Ok(match select_expr { - Expr::Alias(Alias { expr, .. }) => *expr.clone(), + Expr::Alias(Alias { expr, .. }, _) => *expr.clone(), _ => select_expr.clone(), }) } @@ -208,9 +210,11 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr let all_partition_keys = window_exprs .iter() .map(|expr| match expr { - Expr::WindowFunction(WindowFunction { partition_by, .. }) => Ok(partition_by), - Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { - Expr::WindowFunction(WindowFunction { partition_by, .. }) => { + Expr::WindowFunction(WindowFunction { partition_by, .. }, _) => { + Ok(partition_by) + } + Expr::Alias(Alias { expr, .. }, _) => match expr.as_ref() { + Expr::WindowFunction(WindowFunction { partition_by, .. 
}, _) => { Ok(partition_by) } expr => exec_err!("Impossibly got non-window expr {expr:?}"), @@ -424,7 +428,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { /// - If some unnest expr has been visited, maintain a stack of such information, this /// is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))** fn f_down(&mut self, expr: Expr) -> Result> { - if let Expr::Unnest(ref unnest_expr) = expr { + if let Expr::Unnest(ref unnest_expr, _) = expr { let (data_type, _) = unnest_expr.expr.data_type_and_nullable(self.input_schema)?; self.consecutive_unnest.push(Some(unnest_expr.clone())); @@ -481,7 +485,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { /// ``` /// fn f_up(&mut self, expr: Expr) -> Result> { - if let Expr::Unnest(ref traversing_unnest) = expr { + if let Expr::Unnest(ref traversing_unnest, _) = expr { if traversing_unnest == self.top_most_unnest.as_ref().unwrap() { self.top_most_unnest = None; } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index ab7e6c8d0bb73..914526a9f996b 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -206,10 +206,14 @@ fn test_parse_options_value_normalization() { assert_eq!(expected_plan, format!("{plan}")); match plan { - LogicalPlan::Ddl(DdlStatement::CreateExternalTable( - CreateExternalTable { options, .. }, - )) - | LogicalPlan::Copy(CopyTo { options, .. }) => { + LogicalPlan::Ddl( + DdlStatement::CreateExternalTable(CreateExternalTable { + options, + .. + }), + _, + ) + | LogicalPlan::Copy(CopyTo { options, .. }, _) => { expected_options.iter().for_each(|(k, v)| { assert_eq!(Some(&v.to_string()), options.get(*k)); }); @@ -2711,7 +2715,7 @@ fn prepare_stmt_quick_test( assert_eq!(format!("{assert_plan}"), expected_plan); // verify data types - if let LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. 
})) = + if let LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. }), _) = assert_plan { let dt = format!("{data_types:?}"); @@ -4436,15 +4440,18 @@ fn plan_create_index() { "CREATE UNIQUE INDEX IF NOT EXISTS idx_name ON test USING btree (name, age DESC)"; let plan = logical_plan_with_options(sql, ParserOptions::default()).unwrap(); match plan { - LogicalPlan::Ddl(DdlStatement::CreateIndex(CreateIndex { - name, - table, - using, - columns, - unique, - if_not_exists, - .. - })) => { + LogicalPlan::Ddl( + DdlStatement::CreateIndex(CreateIndex { + name, + table, + using, + columns, + unique, + if_not_exists, + .. + }), + _, + ) => { assert_eq!(name, Some("idx_name".to_string())); assert_eq!(format!("{table}"), "test"); assert_eq!(using, Some("btree".to_string())); diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 1cce228527ecf..728818ac40619 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -167,7 +167,7 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( for expr in exprs { #[allow(clippy::collapsible_match)] match expr { - Expr::BinaryExpr(binary_expr) => match binary_expr { + Expr::BinaryExpr(binary_expr, _) => match binary_expr { x @ (BinaryExpr { left, op: Operator::Eq, @@ -293,16 +293,16 @@ pub async fn from_substrait_plan( match plan { // If the last node of the plan produces expressions, bake the renames into those expressions. // This isn't necessary for correctness, but helps with roundtrip tests. 
- LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)), - LogicalPlan::Aggregate(a) => { + LogicalPlan::Projection(p, _) => Ok(LogicalPlan::projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)), + LogicalPlan::Aggregate(a, _) => { let (group_fields, expr_fields) = renamed_schema.fields().split_at(a.group_expr.len()); let new_group_exprs = rename_expressions(a.group_expr, a.input.schema(), group_fields)?; let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), expr_fields)?; - Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?)) + Ok(LogicalPlan::aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?)) }, // There are probably more plans where we could bake things in, can add them later as needed. // Otherwise, add a new Project to handle the renaming. - _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?)) + _ => Ok(LogicalPlan::projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?)) } } }, @@ -438,7 +438,7 @@ fn rename_expressions( .map(|(old_expr, new_field)| { // Check if type (i.e. nested struct field names) match, use Cast to rename if needed let new_expr = if &old_expr.get_type(input_schema)? 
!= new_field.data_type() { - Expr::Cast(Cast::new( + Expr::cast(Cast::new( Box::new(old_expr), new_field.data_type().to_owned(), )) @@ -591,7 +591,7 @@ pub async fn from_substrait_rel( from_substrait_rex(ctx, expr, input.clone().schema(), extensions) .await?; // if the expression is WindowFunction, wrap in a Window relation - if let Expr::WindowFunction(_) = &e { + if let Expr::WindowFunction(_, _) = &e { // Adding the same expression here and in the project below // works because the project's builder uses columnize_expr(..) // to transform it into a column reference @@ -707,7 +707,7 @@ pub async fn from_substrait_rel( // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when // parsed by the producer and consumer, since Substrait does not have a type dedicated // to ROLLUP. Only vector of Groupings (grouping sets) is available. - group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets( + group_exprs.push(Expr::grouping_set(GroupingSet::GroupingSets( grouping_sets, ))); } @@ -923,7 +923,7 @@ pub async fn from_substrait_rel( }) .collect::>()?; - Ok(LogicalPlan::Values(Values { + Ok(LogicalPlan::values(Values { schema: DFSchemaRef::new(substrait_schema), values, })) @@ -1019,7 +1019,7 @@ pub async fn from_substrait_rel( .state() .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) + Ok(LogicalPlan::extension(Extension { node: plan })) } Some(RelType::ExtensionSingle(extension)) => { let Some(ext_detail) = &extension.detail else { @@ -1037,7 +1037,7 @@ pub async fn from_substrait_rel( let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; - Ok(LogicalPlan::Extension(Extension { node: plan })) + Ok(LogicalPlan::extension(Extension { node: plan })) } Some(RelType::ExtensionMulti(extension)) => { let Some(ext_detail) = &extension.detail else { @@ -1053,7 
+1053,7 @@ pub async fn from_substrait_rel( inputs.push(input_plan); } let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) + Ok(LogicalPlan::extension(Extension { node: plan })) } Some(RelType::Exchange(exchange)) => { let Some(input) = exchange.input.as_ref() else { @@ -1089,7 +1089,7 @@ pub async fn from_substrait_rel( return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); } }; - Ok(LogicalPlan::Repartition(Repartition { + Ok(LogicalPlan::repartition(Repartition { input, partitioning_scheme, })) @@ -1156,7 +1156,7 @@ fn apply_emit_kind( // expressions in the projection are volatile. This is to avoid issues like // converting a single call of the random() function into multiple calls due to // duplicate fields in the output_mapping. - LogicalPlan::Projection(proj) if !contains_volatile_expr(&proj) => { + LogicalPlan::Projection(proj, _) if !contains_volatile_expr(&proj) => { let mut exprs: Vec = vec![]; for field in output_mapping { let expr = proj.expr @@ -1263,7 +1263,7 @@ fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> Result { + LogicalPlan::TableScan(mut scan, _) => { let column_indices: Vec = substrait_schema .strip_qualifiers() .fields() @@ -1287,7 +1287,7 @@ fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> Result plan_err!("DataFrame passed to apply_projection must be a TableScan"), } @@ -1502,7 +1502,7 @@ pub async fn from_substrait_agg_func( args }; - Ok(Arc::new(Expr::AggregateFunction( + Ok(Arc::new(Expr::aggregate_function( expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), ))) } else { @@ -1526,7 +1526,7 @@ pub async fn from_substrait_rex( Some(RexType::SingularOrList(s)) => { let substrait_expr = s.value.as_ref().unwrap(); let substrait_list = s.options.as_ref(); - Ok(Expr::InList(InList { + Ok(Expr::_in_list(InList { expr: Box::new( from_substrait_rex(ctx, substrait_expr, input_schema, 
extensions) .await?, @@ -1593,7 +1593,7 @@ pub async fn from_substrait_rex( )), None => None, }; - Ok(Expr::Case(Case { + Ok(Expr::case(Case { expr, when_then_expr, else_expr, @@ -1615,7 +1615,7 @@ pub async fn from_substrait_rex( // try to first match the requested function into registered udfs, then built-in ops // and finally built-in expressions if let Some(func) = ctx.state().scalar_functions().get(fn_name) { - Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Ok(Expr::scalar_function(expr::ScalarFunction::new_udf( func.to_owned(), args, ))) @@ -1632,7 +1632,7 @@ pub async fn from_substrait_rex( .into_iter() .fold(None, |combined_expr: Option, arg: Expr| { Some(match combined_expr { - Some(expr) => Expr::BinaryExpr(BinaryExpr { + Some(expr) => Expr::binary_expr(BinaryExpr { left: Box::new(expr), op, right: Box::new(arg), @@ -1654,7 +1654,7 @@ pub async fn from_substrait_rex( Ok(Expr::Literal(scalar_value)) } Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() { - Some(output_type) => Ok(Expr::Cast(Cast::new( + Some(output_type) => Ok(Expr::cast(Cast::new( Box::new( from_substrait_rex( ctx, @@ -1712,7 +1712,7 @@ pub async fn from_substrait_rex( } } }; - Ok(Expr::WindowFunction(expr::WindowFunction { + Ok(Expr::window_function(expr::WindowFunction { fun, args: from_substrait_func_args( ctx, @@ -1750,7 +1750,7 @@ pub async fn from_substrait_rex( from_substrait_rel(ctx, haystack_expr, extensions) .await?; let outer_refs = haystack_expr.all_out_ref_exprs(); - Ok(Expr::InSubquery(InSubquery { + Ok(Expr::in_subquery(InSubquery { expr: Box::new( from_substrait_rex( ctx, @@ -1779,7 +1779,7 @@ pub async fn from_substrait_rex( ) .await?; let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::ScalarSubquery(Subquery { + Ok(Expr::scalar_subquery(Subquery { subquery: Arc::new(plan), outer_ref_columns, })) @@ -1796,7 +1796,7 @@ pub async fn from_substrait_rex( ) .await?; let outer_ref_columns = plan.all_out_ref_exprs(); - 
Ok(Expr::Exists(Exists::new( + Ok(Expr::exists(Exists::new( Subquery { subquery: Arc::new(plan), outer_ref_columns, @@ -2876,16 +2876,16 @@ impl BuiltinExprBuilder { let arg = Box::new(arg); let expr = match fn_name { - "not" => Expr::Not(arg), - "negative" | "negate" => Expr::Negative(arg), - "is_null" => Expr::IsNull(arg), - "is_not_null" => Expr::IsNotNull(arg), - "is_true" => Expr::IsTrue(arg), - "is_false" => Expr::IsFalse(arg), - "is_not_true" => Expr::IsNotTrue(arg), - "is_not_false" => Expr::IsNotFalse(arg), - "is_unknown" => Expr::IsUnknown(arg), - "is_not_unknown" => Expr::IsNotUnknown(arg), + "not" => Expr::_not(arg), + "negative" | "negate" => Expr::negative(arg), + "is_null" => Expr::_is_null(arg), + "is_not_null" => Expr::_is_not_null(arg), + "is_true" => Expr::_is_true(arg), + "is_false" => Expr::_is_false(arg), + "is_not_true" => Expr::_is_not_true(arg), + "is_not_false" => Expr::_is_not_false(arg), + "is_unknown" => Expr::_is_unknown(arg), + "is_not_unknown" => Expr::_is_not_unknown(arg), _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), }; @@ -2941,7 +2941,7 @@ impl BuiltinExprBuilder { None }; - Ok(Expr::Like(Like { + Ok(Expr::_like(Like { negated: false, expr: Box::new(expr), pattern: Box::new(pattern), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 4d864e4334ce6..15043add37bf8 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -187,7 +187,7 @@ pub fn to_substrait_rel( extensions: &mut Extensions, ) -> Result> { match plan { - LogicalPlan::TableScan(scan) => { + LogicalPlan::TableScan(scan, _) => { let projection = scan.projection.as_ref().map(|p| { p.iter() .map(|i| StructItem { @@ -241,7 +241,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Values(v) => { + LogicalPlan::Values(v, _) => { let values = v .values .iter() @@ -250,7 +250,7 @@ pub fn to_substrait_rel( .iter() 
.map(|v| match v { Expr::Literal(sv) => to_substrait_literal(sv, extensions), - Expr::Alias(alias) => match alias.expr.as_ref() { + Expr::Alias(alias, _) => match alias.expr.as_ref() { // The schema gives us the names, so we can skip aliases Expr::Literal(sv) => to_substrait_literal(sv, extensions), _ => Err(substrait_datafusion_err!( @@ -280,7 +280,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Projection(p) => { + LogicalPlan::Projection(p, _) => { let expressions = p .expr .iter() @@ -306,7 +306,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { let input = to_substrait_rel(filter.input.as_ref(), ctx, extensions)?; let filter_expr = to_substrait_rex( ctx, @@ -324,7 +324,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { let input = to_substrait_rel(limit.input.as_ref(), ctx, extensions)?; let FetchType::Literal(fetch) = limit.get_fetch_type()? else { return not_impl_err!("Non-literal limit fetch"); @@ -343,7 +343,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(sort, _) => { let input = to_substrait_rel(sort.input.as_ref(), ctx, extensions)?; let sort_fields = sort .expr @@ -359,7 +359,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?; let (grouping_expressions, groupings) = to_substrait_groupings( ctx, @@ -384,7 +384,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Distinct(Distinct::All(plan)) => { + LogicalPlan::Distinct(Distinct::All(plan), _) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` let input = to_substrait_rel(plan.as_ref(), ctx, extensions)?; // Get grouping keys from the input relation's number of output fields @@ -406,7 +406,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Join(join) => 
{ + LogicalPlan::Join(join, _) => { let left = to_substrait_rel(join.left.as_ref(), ctx, extensions)?; let right = to_substrait_rel(join.right.as_ref(), ctx, extensions)?; let join_type = to_substrait_jointype(join.join_type); @@ -476,12 +476,12 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::SubqueryAlias(alias) => { + LogicalPlan::SubqueryAlias(alias, _) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait to_substrait_rel(alias.input.as_ref(), ctx, extensions) } - LogicalPlan::Union(union) => { + LogicalPlan::Union(union, _) => { let input_rels = union .inputs .iter() @@ -499,7 +499,7 @@ pub fn to_substrait_rel( })), })) } - LogicalPlan::Window(window) => { + LogicalPlan::Window(window, _) => { let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?; // create a field reference for each input field @@ -538,7 +538,7 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Project(project_rel)), })) } - LogicalPlan::Repartition(repartition) => { + LogicalPlan::Repartition(repartition, _) => { let input = to_substrait_rel(repartition.input.as_ref(), ctx, extensions)?; let partition_count = match repartition.partitioning_scheme { Partitioning::RoundRobinBatch(num) => num, @@ -584,7 +584,7 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), })) } - LogicalPlan::Extension(extension_plan) => { + LogicalPlan::Extension(extension_plan, _) => { let extension_bytes = ctx .state() .serializer_registry() @@ -800,7 +800,7 @@ pub fn to_substrait_groupings( let mut ref_group_exprs = vec![]; let groupings = match exprs.len() { 1 => match &exprs[0] { - Expr::GroupingSet(gs) => match gs { + Expr::GroupingSet(gs, _) => match gs { GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented( "GroupingSet CUBE is not yet supported".to_string(), )), @@ -863,7 +863,7 @@ pub fn to_substrait_agg_measure( extensions: &mut Extensions, ) -> Result { match expr { - 
Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }) => { + Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }, _) => { let sorts = if let Some(order_by) = order_by { order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? } else { @@ -895,7 +895,7 @@ pub fn to_substrait_agg_measure( }) } - Expr::Alias(Alias{expr,..})=> { + Expr::Alias(Alias{expr,..}, _)=> { to_substrait_agg_measure(ctx, expr, schema, extensions) } _ => internal_err!( @@ -984,11 +984,14 @@ pub fn to_substrait_rex( extensions: &mut Extensions, ) -> Result { match expr { - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { let substrait_list = list .iter() .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extensions)) @@ -1021,7 +1024,7 @@ pub fn to_substrait_rex( Ok(substrait_or_list) } } - Expr::ScalarFunction(fun) => { + Expr::ScalarFunction(fun, _) => { let mut arguments: Vec = vec![]; for arg in &fun.args { arguments.push(FunctionArgument { @@ -1046,12 +1049,15 @@ pub fn to_substrait_rex( })), }) } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = @@ -1114,17 +1120,20 @@ pub fn to_substrait_rex( let index = schema.index_of_column(col)?; substrait_field_ref(index + col_ref_offset) } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extensions)?; let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extensions)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) } - Expr::Case(Case { - expr, - 
when_then_expr, - else_expr, - }) => { + Expr::Case( + Case { + expr, + when_then_expr, + else_expr, + }, + _, + ) => { let mut ifs: Vec = vec![]; // Parse base if let Some(e) = expr { @@ -1176,7 +1185,7 @@ pub fn to_substrait_rex( rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), }) } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, data_type }, _) => { Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { @@ -1194,17 +1203,20 @@ pub fn to_substrait_rex( }) } Expr::Literal(value) => to_substrait_literal_expr(value, extensions), - Expr::Alias(Alias { expr, .. }) => { + Expr::Alias(Alias { expr, .. }, _) => { to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions) } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + }, + _, + ) => { // function reference let function_anchor = extensions.register_function(fun.to_string()); // arguments @@ -1242,13 +1254,16 @@ pub fn to_substrait_rex( bound_type, )) } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => make_substrait_like_expr( + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => make_substrait_like_expr( ctx, *case_insensitive, *negated, @@ -1259,11 +1274,14 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated, + }, + _, + ) => { let substrait_expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; @@ -1300,7 +1318,7 @@ pub fn to_substrait_rex( Ok(substrait_subquery) } } - Expr::Not(arg) => to_substrait_unary_scalar_fn( + Expr::Not(arg, _) => to_substrait_unary_scalar_fn( ctx, "not", 
arg, @@ -1308,7 +1326,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsNull(arg) => to_substrait_unary_scalar_fn( + Expr::IsNull(arg, _) => to_substrait_unary_scalar_fn( ctx, "is_null", arg, @@ -1316,7 +1334,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( + Expr::IsNotNull(arg, _) => to_substrait_unary_scalar_fn( ctx, "is_not_null", arg, @@ -1324,7 +1342,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( + Expr::IsTrue(arg, _) => to_substrait_unary_scalar_fn( ctx, "is_true", arg, @@ -1332,7 +1350,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( + Expr::IsFalse(arg, _) => to_substrait_unary_scalar_fn( ctx, "is_false", arg, @@ -1340,7 +1358,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( + Expr::IsUnknown(arg, _) => to_substrait_unary_scalar_fn( ctx, "is_unknown", arg, @@ -1348,7 +1366,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( + Expr::IsNotTrue(arg, _) => to_substrait_unary_scalar_fn( ctx, "is_not_true", arg, @@ -1356,7 +1374,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( + Expr::IsNotFalse(arg, _) => to_substrait_unary_scalar_fn( ctx, "is_not_false", arg, @@ -1364,7 +1382,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( + Expr::IsNotUnknown(arg, _) => to_substrait_unary_scalar_fn( ctx, "is_not_unknown", arg, @@ -1372,7 +1390,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::Negative(arg) => to_substrait_unary_scalar_fn( + Expr::Negative(arg, _) => to_substrait_unary_scalar_fn( ctx, "negate", arg, diff --git 
a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index d4e2d48885ae6..0026406f78478 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -898,7 +898,7 @@ async fn roundtrip_values() -> Result<()> { async fn roundtrip_values_no_columns() -> Result<()> { let ctx = create_context().await?; // "VALUES ()" is not yet supported by the SQL parser, so we construct the plan manually - let plan = LogicalPlan::Values(Values { + let plan = LogicalPlan::values(Values { values: vec![vec![], vec![]], // two rows, no columns schema: DFSchemaRef::new(DFSchema::empty()), }); @@ -971,7 +971,7 @@ async fn new_test_grammar() -> Result<()> { async fn extension_logical_plan() -> Result<()> { let ctx = create_context().await?; let validation_bytes = "MockUserDefinedLogicalPlan".as_bytes().to_vec(); - let ext_plan = LogicalPlan::Extension(Extension { + let ext_plan = LogicalPlan::extension(Extension { node: Arc::new(MockUserDefinedLogicalPlan { validation_bytes, inputs: vec![], @@ -1076,7 +1076,7 @@ async fn roundtrip_window_udf() -> Result<()> { async fn roundtrip_repartition_roundrobin() -> Result<()> { let ctx = create_context().await?; let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; - let plan = LogicalPlan::Repartition(Repartition { + let plan = LogicalPlan::repartition(Repartition { input: Arc::new(scan_plan), partitioning_scheme: Partitioning::RoundRobinBatch(8), }); @@ -1093,7 +1093,7 @@ async fn roundtrip_repartition_roundrobin() -> Result<()> { async fn roundtrip_repartition_hash() -> Result<()> { let ctx = create_context().await?; let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; - let plan = LogicalPlan::Repartition(Repartition { + let plan = LogicalPlan::repartition(Repartition { input: Arc::new(scan_plan), partitioning_scheme: 
Partitioning::Hash(vec![col("data.a")], 8), }); diff --git a/docs/source/library-user-guide/building-logical-plans.md b/docs/source/library-user-guide/building-logical-plans.md index 556deb02e9800..2c5770cbffecf 100644 --- a/docs/source/library-user-guide/building-logical-plans.md +++ b/docs/source/library-user-guide/building-logical-plans.md @@ -55,7 +55,7 @@ fn main() -> Result<(), DataFusionError> { let projection = None; // optional projection let filters = vec![]; // optional filters to push down let fetch = None; // optional LIMIT - let table_scan = LogicalPlan::TableScan(TableScan::try_new( + let table_scan = LogicalPlan::table_scan(TableScan::try_new( "person", Arc::new(table_source), projection,