From 74480ac58ee5658a275b1e8b0ebd9764d0e48844 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Fri, 20 Dec 2024 22:35:53 +0800 Subject: [PATCH] feat: support normalized expr in CSE (#13315) * feat: support normalized expr in CSE * feat: support normalize_eq in cse optimization * feat: support cumulative binary expr result in normalize_eq --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/cse.rs | 150 +++++-- datafusion/expr/src/expr.rs | 389 +++++++++++++++++- datafusion/expr/src/logical_plan/plan.rs | 20 + .../optimizer/src/common_subexpr_eliminate.rs | 263 +++++++++++- 4 files changed, 790 insertions(+), 32 deletions(-) diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index ab02915858cd..f64571b8471e 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -50,12 +50,33 @@ impl HashNode for Arc { } } +/// The `Normalizeable` trait defines a method to determine whether a node can be normalized. +/// +/// Normalization is the process of converting a node into a canonical form that can be used +/// to compare nodes for equality. This is useful in optimizations like Common Subexpression Elimination (CSE), +/// where semantically equivalent nodes (e.g., `a + b` and `b + a`) should be treated as equal. +pub trait Normalizeable { + fn can_normalize(&self) -> bool; +} + +/// The `NormalizeEq` trait extends `Eq` and `Normalizeable` to provide a method for comparing +/// normlized nodes in optimizations like Common Subexpression Elimination (CSE). +/// +/// The `normalize_eq` method ensures that two nodes that are semantically equivalent (after normalization) +/// are considered equal in CSE optimization, even if their original forms differ. +/// +/// This trait allows for equality comparisons between nodes with equivalent semantics, regardless of their +/// internal representations. +pub trait NormalizeEq: Eq + Normalizeable { + fn normalize_eq(&self, other: &Self) -> bool; +} + /// Identifier that represents a [`TreeNode`] tree. /// /// This identifier is designed to be efficient and "hash", "accumulate", "equal" and /// "have no collision (as low as possible)" -#[derive(Debug, Eq, PartialEq)] -struct Identifier<'n, N> { +#[derive(Debug, Eq)] +struct Identifier<'n, N: NormalizeEq> { // Hash of `node` built up incrementally during the first, visiting traversal. // Its value is not necessarily equal to default hash of the node. E.g. it is not // equal to `expr.hash()` if the node is `Expr`. @@ -63,20 +84,29 @@ struct Identifier<'n, N> { node: &'n N, } -impl Clone for Identifier<'_, N> { +impl Clone for Identifier<'_, N> { fn clone(&self) -> Self { *self } } -impl Copy for Identifier<'_, N> {} +impl Copy for Identifier<'_, N> {} -impl Hash for Identifier<'_, N> { +impl Hash for Identifier<'_, N> { fn hash(&self, state: &mut H) { state.write_u64(self.hash); } } -impl<'n, N: HashNode> Identifier<'n, N> { +impl PartialEq for Identifier<'_, N> { + fn eq(&self, other: &Self) -> bool { + self.hash == other.hash && self.node.normalize_eq(other.node) + } +} + +impl<'n, N> Identifier<'n, N> +where + N: HashNode + NormalizeEq, +{ fn new(node: &'n N, random_state: &RandomState) -> Self { let mut hasher = random_state.build_hasher(); node.hash_node(&mut hasher); @@ -213,7 +243,11 @@ pub enum FoundCommonNodes { /// /// A [`TreeNode`] without any children (column, literal etc.) will not have identifier /// because they should not be recognized as common subtree. 
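To illustrate the intent of the `Normalizeable` and `NormalizeEq` traits added in `cse.rs` above, here is a minimal self-contained sketch (not part of this patch): a hypothetical `ToyExpr` type whose only commutative node is `Add`, so `a + b` and `b + a` compare equal under `normalize_eq` even though they are structurally different. The trait definitions are repeated locally so the sketch compiles on its own.

```rust
#[derive(Debug, PartialEq, Eq)]
enum ToyExpr {
    Column(&'static str),
    Add(Box<ToyExpr>, Box<ToyExpr>),
}

// Local copies of the traits so the sketch is self-contained.
trait Normalizeable {
    fn can_normalize(&self) -> bool;
}

trait NormalizeEq: Eq + Normalizeable {
    fn normalize_eq(&self, other: &Self) -> bool;
}

impl Normalizeable for ToyExpr {
    // Only commutative nodes participate in normalization.
    fn can_normalize(&self) -> bool {
        matches!(self, ToyExpr::Add(_, _))
    }
}

impl NormalizeEq for ToyExpr {
    fn normalize_eq(&self, other: &Self) -> bool {
        match (self, other) {
            // For a commutative operator, accept either operand order.
            (ToyExpr::Add(l1, r1), ToyExpr::Add(l2, r2)) => {
                (l1.normalize_eq(l2) && r1.normalize_eq(r2))
                    || (l1.normalize_eq(r2) && r1.normalize_eq(l2))
            }
            // Everything else falls back to structural equality.
            _ => self == other,
        }
    }
}

fn main() {
    let a_plus_b = ToyExpr::Add(
        Box::new(ToyExpr::Column("a")),
        Box::new(ToyExpr::Column("b")),
    );
    let b_plus_a = ToyExpr::Add(
        Box::new(ToyExpr::Column("b")),
        Box::new(ToyExpr::Column("a")),
    );
    assert_ne!(a_plus_b, b_plus_a); // structurally different
    assert!(a_plus_b.normalize_eq(&b_plus_a)); // semantically equal
}
```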
-struct CSEVisitor<'a, 'n, N, C: CSEController> { +struct CSEVisitor<'a, 'n, N, C> +where + N: NormalizeEq, + C: CSEController, +{ /// statistics of [`TreeNode`]s node_stats: &'a mut NodeStats<'n, N>, @@ -244,7 +278,10 @@ struct CSEVisitor<'a, 'n, N, C: CSEController> { } /// Record item that used when traversing a [`TreeNode`] tree. -enum VisitRecord<'n, N> { +enum VisitRecord<'n, N> +where + N: NormalizeEq, +{ /// Marks the beginning of [`TreeNode`]. It contains: /// - The post-order index assigned during the first, visiting traversal. EnterMark(usize), @@ -258,7 +295,11 @@ enum VisitRecord<'n, N> { NodeItem(Identifier<'n, N>, bool), } -impl<'n, N: TreeNode + HashNode, C: CSEController> CSEVisitor<'_, 'n, N, C> { +impl<'n, N, C> CSEVisitor<'_, 'n, N, C> +where + N: TreeNode + HashNode + NormalizeEq, + C: CSEController, +{ /// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before /// it. Returns a tuple that contains: /// - The pre-order index of the [`TreeNode`] we marked. @@ -271,17 +312,26 @@ impl<'n, N: TreeNode + HashNode, C: CSEController> CSEVisitor<'_, 'n, /// information up from children to parents via `visit_stack` during the first, /// visiting traversal and no need to test the expression's validity beforehand with /// an extra traversal). - fn pop_enter_mark(&mut self) -> (usize, Option>, bool) { - let mut node_id = None; + fn pop_enter_mark( + &mut self, + can_normalize: bool, + ) -> (usize, Option>, bool) { + let mut node_ids: Vec> = vec![]; let mut is_valid = true; while let Some(item) = self.visit_stack.pop() { match item { VisitRecord::EnterMark(down_index) => { + if can_normalize { + node_ids.sort_by_key(|i| i.hash); + } + let node_id = node_ids + .into_iter() + .fold(None, |accum, item| Some(item.combine(accum))); return (down_index, node_id, is_valid); } VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => { - node_id = Some(sub_node_id.combine(node_id)); + node_ids.push(sub_node_id); is_valid &= sub_node_is_valid; } } @@ -290,8 +340,10 @@ impl<'n, N: TreeNode + HashNode, C: CSEController> CSEVisitor<'_, 'n, } } -impl<'n, N: TreeNode + HashNode + Eq, C: CSEController> TreeNodeVisitor<'n> - for CSEVisitor<'_, 'n, N, C> +impl<'n, N, C> TreeNodeVisitor<'n> for CSEVisitor<'_, 'n, N, C> +where + N: TreeNode + HashNode + NormalizeEq, + C: CSEController, { type Node = N; @@ -331,7 +383,8 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController> TreeNodeVisito } fn f_up(&mut self, node: &'n Self::Node) -> Result { - let (down_index, sub_node_id, sub_node_is_valid) = self.pop_enter_mark(); + let (down_index, sub_node_id, sub_node_is_valid) = + self.pop_enter_mark(node.can_normalize()); let node_id = Identifier::new(node, self.random_state).combine(sub_node_id); let is_valid = C::is_valid(node) && sub_node_is_valid; @@ -369,7 +422,11 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController> TreeNodeVisito /// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the /// corresponding temporary [`TreeNode`], that column contains the evaluate result of /// replaced [`TreeNode`] tree. 
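The `pop_enter_mark` change above is the order-insensitivity half of the feature: when the parent node reports `can_normalize()`, the child identifiers collected on `visit_stack` are sorted by hash before being folded together, so permutations of the same children accumulate to the same identifier. A rough standalone sketch of that idea follows; the `combine` function is a hypothetical stand-in for `Identifier::combine`, and the real code folds full `Identifier` values rather than bare hashes.

```rust
// Mix one child hash into an accumulator; any order-sensitive step works here.
fn combine(acc: Option<u64>, child: u64) -> u64 {
    acc.unwrap_or(0).wrapping_mul(31).wrapping_add(child)
}

// Fold child hashes into one accumulated value, sorting first when the parent
// is a normalizable (e.g. commutative) node.
fn accumulate(mut child_hashes: Vec<u64>, can_normalize: bool) -> Option<u64> {
    if can_normalize {
        child_hashes.sort_unstable();
    }
    child_hashes
        .into_iter()
        .fold(None, |acc, h| Some(combine(acc, h)))
}

fn main() {
    let (a, b) = (0xA11CEu64, 0xB0Bu64);
    // With normalization, operand order no longer affects the identifier.
    assert_eq!(accumulate(vec![a, b], true), accumulate(vec![b, a], true));
    // Without it, the order still matters.
    assert_ne!(accumulate(vec![a, b], false), accumulate(vec![b, a], false));
}
```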
-struct CSERewriter<'a, 'n, N, C: CSEController> { +struct CSERewriter<'a, 'n, N, C> +where + N: NormalizeEq, + C: CSEController, +{ /// statistics of [`TreeNode`]s node_stats: &'a NodeStats<'n, N>, @@ -386,8 +443,10 @@ struct CSERewriter<'a, 'n, N, C: CSEController> { controller: &'a mut C, } -impl> TreeNodeRewriter - for CSERewriter<'_, '_, N, C> +impl TreeNodeRewriter for CSERewriter<'_, '_, N, C> +where + N: TreeNode + NormalizeEq, + C: CSEController, { type Node = N; @@ -408,13 +467,30 @@ impl> TreeNodeRewriter self.down_index += 1; } - let (node, alias) = - self.common_nodes.entry(node_id).or_insert_with(|| { - let node_alias = self.controller.generate_alias(); - (node, node_alias) - }); - - let rewritten = self.controller.rewrite(node, alias); + // We *must* replace all original nodes with same `node_id`, not just the first + // node which is inserted into the common_nodes. This is because nodes with the same + // `node_id` are semantically equivalent, but not exactly the same. + // + // For example, `a + 1` and `1 + a` are semantically equivalent but not identical. + // In this case, we should replace the common expression `1 + a` with a new variable + // (e.g., `__common_cse_1`). So, `a + 1` and `1 + a` would both be replaced by + // `__common_cse_1`. + // + // The final result would be: + // - `__common_cse_1 as a + 1` + // - `__common_cse_1 as 1 + a` + // + // This way, we can efficiently handle semantically equivalent expressions without + // incorrectly treating them as identical. + let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id) + { + self.controller.rewrite(&node, alias) + } else { + let node_alias = self.controller.generate_alias(); + let rewritten = self.controller.rewrite(&node, &node_alias); + self.common_nodes.insert(node_id, (node, node_alias)); + rewritten + }; return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)); } @@ -441,7 +517,11 @@ pub struct CSE> { controller: C, } -impl> CSE { +impl CSE +where + N: TreeNode + HashNode + Clone + NormalizeEq, + C: CSEController, +{ pub fn new(controller: C) -> Self { Self { random_state: RandomState::new(), @@ -557,6 +637,7 @@ impl> CSE ) -> Result> { let mut found_common = false; let mut node_stats = NodeStats::new(); + let id_arrays_list = nodes_list .iter() .map(|nodes| { @@ -596,7 +677,10 @@ impl> CSE #[cfg(test)] mod test { use crate::alias::AliasGenerator; - use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE}; + use crate::cse::{ + CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq, + Normalizeable, CSE, + }; use crate::tree_node::tests::TestTreeNode; use crate::Result; use std::collections::HashSet; @@ -662,6 +746,18 @@ mod test { } } + impl Normalizeable for TestTreeNode { + fn can_normalize(&self) -> bool { + false + } + } + + impl NormalizeEq for TestTreeNode { + fn normalize_eq(&self, other: &Self) -> bool { + self == other + } + } + #[test] fn id_array_visitor() -> Result<()> { let alias_generator = AliasGenerator::new(); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index c495b5396f53..af54dad79d2e 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -30,7 +30,7 @@ use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, FieldRef}; -use datafusion_common::cse::HashNode; +use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; use datafusion_common::tree_node::{ Transformed, TransformedResult, 
TreeNode, TreeNodeContainer, TreeNodeRecursion, }; @@ -1665,6 +1665,393 @@ impl Expr { } } +impl Normalizeable for Expr { + fn can_normalize(&self) -> bool { + #[allow(clippy::match_like_matches_macro)] + match self { + Expr::BinaryExpr(BinaryExpr { + op: + _op @ (Operator::Plus + | Operator::Multiply + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::Eq + | Operator::NotEq), + .. + }) => true, + _ => false, + } + } +} + +impl NormalizeEq for Expr { + fn normalize_eq(&self, other: &Self) -> bool { + match (self, other) { + ( + Expr::BinaryExpr(BinaryExpr { + left: self_left, + op: self_op, + right: self_right, + }), + Expr::BinaryExpr(BinaryExpr { + left: other_left, + op: other_op, + right: other_right, + }), + ) => { + if self_op != other_op { + return false; + } + + if matches!( + self_op, + Operator::Plus + | Operator::Multiply + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::Eq + | Operator::NotEq + ) { + (self_left.normalize_eq(other_left) + && self_right.normalize_eq(other_right)) + || (self_left.normalize_eq(other_right) + && self_right.normalize_eq(other_left)) + } else { + self_left.normalize_eq(other_left) + && self_right.normalize_eq(other_right) + } + } + ( + Expr::Alias(Alias { + expr: self_expr, + relation: self_relation, + name: self_name, + }), + Expr::Alias(Alias { + expr: other_expr, + relation: other_relation, + name: other_name, + }), + ) => { + self_name == other_name + && self_relation == other_relation + && self_expr.normalize_eq(other_expr) + } + ( + Expr::Like(Like { + negated: self_negated, + expr: self_expr, + pattern: self_pattern, + escape_char: self_escape_char, + case_insensitive: self_case_insensitive, + }), + Expr::Like(Like { + negated: other_negated, + expr: other_expr, + pattern: other_pattern, + escape_char: other_escape_char, + case_insensitive: other_case_insensitive, + }), + ) + | ( + Expr::SimilarTo(Like { + negated: self_negated, + expr: self_expr, + pattern: self_pattern, + escape_char: self_escape_char, + case_insensitive: self_case_insensitive, + }), + Expr::SimilarTo(Like { + negated: other_negated, + expr: other_expr, + pattern: other_pattern, + escape_char: other_escape_char, + case_insensitive: other_case_insensitive, + }), + ) => { + self_negated == other_negated + && self_escape_char == other_escape_char + && self_case_insensitive == other_case_insensitive + && self_expr.normalize_eq(other_expr) + && self_pattern.normalize_eq(other_pattern) + } + (Expr::Not(self_expr), Expr::Not(other_expr)) + | (Expr::IsNull(self_expr), Expr::IsNull(other_expr)) + | (Expr::IsTrue(self_expr), Expr::IsTrue(other_expr)) + | (Expr::IsFalse(self_expr), Expr::IsFalse(other_expr)) + | (Expr::IsUnknown(self_expr), Expr::IsUnknown(other_expr)) + | (Expr::IsNotNull(self_expr), Expr::IsNotNull(other_expr)) + | (Expr::IsNotTrue(self_expr), Expr::IsNotTrue(other_expr)) + | (Expr::IsNotFalse(self_expr), Expr::IsNotFalse(other_expr)) + | (Expr::IsNotUnknown(self_expr), Expr::IsNotUnknown(other_expr)) + | (Expr::Negative(self_expr), Expr::Negative(other_expr)) + | ( + Expr::Unnest(Unnest { expr: self_expr }), + Expr::Unnest(Unnest { expr: other_expr }), + ) => self_expr.normalize_eq(other_expr), + ( + Expr::Between(Between { + expr: self_expr, + negated: self_negated, + low: self_low, + high: self_high, + }), + Expr::Between(Between { + expr: other_expr, + negated: other_negated, + low: other_low, + high: other_high, + }), + ) => { + self_negated == other_negated + && 
self_expr.normalize_eq(other_expr) + && self_low.normalize_eq(other_low) + && self_high.normalize_eq(other_high) + } + ( + Expr::Cast(Cast { + expr: self_expr, + data_type: self_data_type, + }), + Expr::Cast(Cast { + expr: other_expr, + data_type: other_data_type, + }), + ) + | ( + Expr::TryCast(TryCast { + expr: self_expr, + data_type: self_data_type, + }), + Expr::TryCast(TryCast { + expr: other_expr, + data_type: other_data_type, + }), + ) => self_data_type == other_data_type && self_expr.normalize_eq(other_expr), + ( + Expr::ScalarFunction(ScalarFunction { + func: self_func, + args: self_args, + }), + Expr::ScalarFunction(ScalarFunction { + func: other_func, + args: other_args, + }), + ) => { + self_func.name() == other_func.name() + && self_args.len() == other_args.len() + && self_args + .iter() + .zip(other_args.iter()) + .all(|(a, b)| a.normalize_eq(b)) + } + ( + Expr::AggregateFunction(AggregateFunction { + func: self_func, + args: self_args, + distinct: self_distinct, + filter: self_filter, + order_by: self_order_by, + null_treatment: self_null_treatment, + }), + Expr::AggregateFunction(AggregateFunction { + func: other_func, + args: other_args, + distinct: other_distinct, + filter: other_filter, + order_by: other_order_by, + null_treatment: other_null_treatment, + }), + ) => { + self_func.name() == other_func.name() + && self_distinct == other_distinct + && self_null_treatment == other_null_treatment + && self_args.len() == other_args.len() + && self_args + .iter() + .zip(other_args.iter()) + .all(|(a, b)| a.normalize_eq(b)) + && match (self_filter, other_filter) { + (Some(self_filter), Some(other_filter)) => { + self_filter.normalize_eq(other_filter) + } + (None, None) => true, + _ => false, + } + && match (self_order_by, other_order_by) { + (Some(self_order_by), Some(other_order_by)) => self_order_by + .iter() + .zip(other_order_by.iter()) + .all(|(a, b)| { + a.asc == b.asc + && a.nulls_first == b.nulls_first + && a.expr.normalize_eq(&b.expr) + }), + (None, None) => true, + _ => false, + } + } + ( + Expr::WindowFunction(WindowFunction { + fun: self_fun, + args: self_args, + partition_by: self_partition_by, + order_by: self_order_by, + window_frame: self_window_frame, + null_treatment: self_null_treatment, + }), + Expr::WindowFunction(WindowFunction { + fun: other_fun, + args: other_args, + partition_by: other_partition_by, + order_by: other_order_by, + window_frame: other_window_frame, + null_treatment: other_null_treatment, + }), + ) => { + self_fun.name() == other_fun.name() + && self_window_frame == other_window_frame + && self_null_treatment == other_null_treatment + && self_args.len() == other_args.len() + && self_args + .iter() + .zip(other_args.iter()) + .all(|(a, b)| a.normalize_eq(b)) + && self_partition_by + .iter() + .zip(other_partition_by.iter()) + .all(|(a, b)| a.normalize_eq(b)) + && self_order_by + .iter() + .zip(other_order_by.iter()) + .all(|(a, b)| { + a.asc == b.asc + && a.nulls_first == b.nulls_first + && a.expr.normalize_eq(&b.expr) + }) + } + ( + Expr::Exists(Exists { + subquery: self_subquery, + negated: self_negated, + }), + Expr::Exists(Exists { + subquery: other_subquery, + negated: other_negated, + }), + ) => { + self_negated == other_negated + && self_subquery.normalize_eq(other_subquery) + } + ( + Expr::InSubquery(InSubquery { + expr: self_expr, + subquery: self_subquery, + negated: self_negated, + }), + Expr::InSubquery(InSubquery { + expr: other_expr, + subquery: other_subquery, + negated: other_negated, + }), + ) => { + self_negated == 
other_negated + && self_expr.normalize_eq(other_expr) + && self_subquery.normalize_eq(other_subquery) + } + ( + Expr::ScalarSubquery(self_subquery), + Expr::ScalarSubquery(other_subquery), + ) => self_subquery.normalize_eq(other_subquery), + ( + Expr::GroupingSet(GroupingSet::Rollup(self_exprs)), + Expr::GroupingSet(GroupingSet::Rollup(other_exprs)), + ) + | ( + Expr::GroupingSet(GroupingSet::Cube(self_exprs)), + Expr::GroupingSet(GroupingSet::Cube(other_exprs)), + ) => { + self_exprs.len() == other_exprs.len() + && self_exprs + .iter() + .zip(other_exprs.iter()) + .all(|(a, b)| a.normalize_eq(b)) + } + ( + Expr::GroupingSet(GroupingSet::GroupingSets(self_exprs)), + Expr::GroupingSet(GroupingSet::GroupingSets(other_exprs)), + ) => { + self_exprs.len() == other_exprs.len() + && self_exprs.iter().zip(other_exprs.iter()).all(|(a, b)| { + a.len() == b.len() + && a.iter().zip(b.iter()).all(|(x, y)| x.normalize_eq(y)) + }) + } + ( + Expr::InList(InList { + expr: self_expr, + list: self_list, + negated: self_negated, + }), + Expr::InList(InList { + expr: other_expr, + list: other_list, + negated: other_negated, + }), + ) => { + // TODO: normalize_eq for lists, for example `a IN (c1 + c3, c3)` is equal to `a IN (c3, c1 + c3)` + self_negated == other_negated + && self_expr.normalize_eq(other_expr) + && self_list.len() == other_list.len() + && self_list + .iter() + .zip(other_list.iter()) + .all(|(a, b)| a.normalize_eq(b)) + } + ( + Expr::Case(Case { + expr: self_expr, + when_then_expr: self_when_then_expr, + else_expr: self_else_expr, + }), + Expr::Case(Case { + expr: other_expr, + when_then_expr: other_when_then_expr, + else_expr: other_else_expr, + }), + ) => { + // TODO: normalize_eq for when_then_expr + // for example `CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END` is equal to `CASE a WHEN 3 THEN 4 WHEN 1 THEN 2 ELSE 5 END` + self_when_then_expr.len() == other_when_then_expr.len() + && self_when_then_expr + .iter() + .zip(other_when_then_expr.iter()) + .all(|((self_when, self_then), (other_when, other_then))| { + self_when.normalize_eq(other_when) + && self_then.normalize_eq(other_then) + }) + && match (self_expr, other_expr) { + (Some(self_expr), Some(other_expr)) => { + self_expr.normalize_eq(other_expr) + } + (None, None) => true, + (_, _) => false, + } + && match (self_else_expr, other_else_expr) { + (Some(self_else_expr), Some(other_else_expr)) => { + self_else_expr.normalize_eq(other_else_expr) + } + (None, None) => true, + (_, _) => false, + } + } + (_, _) => self == other, + } + } +} + impl HashNode for Expr { /// As it is pretty easy to forget changing this method when `Expr` changes the /// implementation doesn't use wildcard patterns (`..`, `_`) to catch changes diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 31bf4c573444..6c2b923cf6ad 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -45,6 +45,7 @@ use crate::{ }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion_common::cse::{NormalizeEq, Normalizeable}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; @@ -3354,6 +3355,25 @@ pub struct Subquery { pub outer_ref_columns: Vec, } +impl Normalizeable for Subquery { + fn can_normalize(&self) -> bool { + false + } +} + +impl NormalizeEq for Subquery { + fn normalize_eq(&self, other: &Self) -> bool { + // TODO: may be implement NormalizeEq for LogicalPlan? 
+ *self.subquery == *other.subquery + && self.outer_ref_columns.len() == other.outer_ref_columns.len() + && self + .outer_ref_columns + .iter() + .zip(other.outer_ref_columns.iter()) + .all(|(a, b)| a.normalize_eq(b)) + } +} + impl Subquery { pub fn try_from_expr(plan: &Expr) -> Result<&Subquery> { match plan { diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 0ea2d24effbb..e7c9a198f3ad 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -795,8 +795,9 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ - grouping_set, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, - ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility, + grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF, + ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + Volatility, }; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; @@ -1054,8 +1055,9 @@ mod test { .project(vec![lit(1) + col("a"), col("a") + lit(1)])? .build()?; - let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ - \n TableScan: test"; + let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\ + \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; assert_optimized_plan_eq(expected, plan, None); @@ -1412,6 +1414,259 @@ mod test { Ok(()) } + #[test] + fn test_normalize_add_expression() -> Result<()> { + // a + b <=> b + a + let table_scan = test_table_scan()?; + let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ + \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_multi_expression() -> Result<()> { + // a * b <=> b * a + let table_scan = test_table_scan()?; + let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_bitset_and_expression() -> Result<()> { + // a & b <=> b & a + let table_scan = test_table_scan()?; + let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_bitset_or_expression() -> Result<()> { + // a | b <=> b | a + let table_scan = test_table_scan()?; + let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30)); + let plan = 
LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_bitset_xor_expression() -> Result<()> { + // a # b <=> b # a + let table_scan = test_table_scan()?; + let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_eq_expression() -> Result<()> { + // a = b <=> b = a + let table_scan = test_table_scan()?; + let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 AND __common_expr_1\ + \n Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_ne_expression() -> Result<()> { + // a != b <=> b != a + let table_scan = test_table_scan()?; + let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 AND __common_expr_1\ + \n Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_complex_expression() -> Result<()> { + // case1: a + b * c <=> b * c + a + let table_scan = test_table_scan()?; + let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a"))) + .eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 - __common_expr_1 = Int32(30)\ + \n Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1)) + let table_scan = test_table_scan()?; + let expr = (((col("a") + col("b") / col("c")) * col("c")) + / (col("c") * (col("b") / col("c") + col("a"))) + + col("a")) + .eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)\ + \n Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // c2 / (c1 + c3) <=> c2 / (c3 + c1) + let table_scan = test_table_scan()?; + let expr = ((col("b") / (col("a") + col("c"))) + * (col("b") / (col("c") + col("a")))) + .eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ + 
\n Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[derive(Debug)] + pub struct TestUdf { + signature: Signature, + } + + impl TestUdf { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for TestUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "my_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _: &[ColumnarValue]) -> Result { + panic!("not implemented") + } + } + + #[test] + fn test_normalize_inner_binary_expression() -> Result<()> { + // Not(a == b) <=> Not(b == a) + let table_scan = test_table_scan()?; + let expr1 = not(col("a").eq(col("b"))); + let expr2 = not(col("b").eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a\ + \n Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // is_null(a == b) <=> is_null(b == a) + let table_scan = test_table_scan()?; + let expr1 = is_null(col("a").eq(col("b"))); + let expr2 = is_null(col("b").eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL\ + \n Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // a + b between 0 and 10 <=> b + a between 0 and 10 + let table_scan = test_table_scan()?; + let expr1 = (col("a") + col("b")).between(lit(0), lit(10)); + let expr2 = (col("b") + col("a")).between(lit(0), lit(10)); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)\ + \n Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // c between a + b and 10 <=> c between b + a and 10 + let table_scan = test_table_scan()?; + let expr1 = col("c").between(col("a") + col("b"), lit(10)); + let expr2 = col("c").between(col("b") + col("a"), lit(10)); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)\ + \n Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // function call with argument <=> function call with argument + let udf = ScalarUDF::from(TestUdf::new()); + let table_scan = test_table_scan()?; + let expr1 = udf.call(vec![col("a") + col("b")]); + let expr2 = udf.call(vec![col("b") + col("a")]); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? 
+ .build()?; + let expected = "Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)\ + \n Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + Ok(()) + } + /// returns a "random" function that is marked volatile (aka each invocation /// returns a different value) ///
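For reference, a self-contained sketch of the get-or-insert pattern that `CSERewriter::f_down` now uses (see the `common_nodes` handling earlier in this patch): every occurrence that shares a `node_id` is rewritten to the same generated alias, but only the first occurrence is stored, so `a + 1` and `1 + a` both become `__common_expr_1` while the lower projection materializes a single expression. `NodeId`, `Node`, and the alias format below are hypothetical simplifications of the real types.

```rust
use std::collections::HashMap;

type NodeId = u64;
type Node = String;

fn rewrite_occurrence(
    common_nodes: &mut HashMap<NodeId, (Node, String)>,
    next_alias: &mut usize,
    node_id: NodeId,
    node: Node,
) -> String {
    if let Some((_, alias)) = common_nodes.get(&node_id) {
        // A semantically equivalent node was already extracted: reuse its alias.
        alias.clone()
    } else {
        // First occurrence: generate an alias and remember this node so the
        // lower projection can compute the common value once.
        *next_alias += 1;
        let alias = format!("__common_expr_{}", next_alias);
        common_nodes.insert(node_id, (node, alias.clone()));
        alias
    }
}

fn main() {
    let mut common_nodes = HashMap::new();
    let mut next_alias = 0;
    let id: NodeId = 42; // both forms normalize to the same identifier
    let first = rewrite_occurrence(&mut common_nodes, &mut next_alias, id, "a + 1".into());
    let second = rewrite_occurrence(&mut common_nodes, &mut next_alias, id, "1 + a".into());
    assert_eq!(first, "__common_expr_1");
    assert_eq!(second, "__common_expr_1");
    assert_eq!(common_nodes[&id].0, "a + 1"); // only the first form is materialized
}
```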