From 5f0cb49c8b1a47830d80a7add1d3c96d7d5a0025 Mon Sep 17 00:00:00 2001 From: wiedld Date: Thu, 21 Mar 2024 12:40:51 -0700 Subject: [PATCH] fix(9678): short circuiting prevented population of visited stack, for common subexpr elimination optimization (#9685) * test(9678): reproducer of short-circuiting causing expr elimination to error * fix(9678): populate visited stack for short-circuited expressions, during the common-expr elimination optimization * test(9678): reproducer for optimizer error (in common_subexpr_eliminate), as seen in other test case * chore: extract id_array into abstraction, to make it more clear the relationship between the two visitors * refactor: tweak the fix and make code more explicit (JumpMark, node_to_identifier) * fix: get the series_number and curr_id with the correct self.current_idx, before the various incr/decr * chore: remove unneeded conditional check (already done earlier), and add code comments * Refine documentation in common_subexpr_eliminate.rs * chore: cleanup -- fix 1 doc comment and consolidate common-expr-elimination test with other expr test --------- Co-authored-by: Andrew Lamb --- .../optimizer/src/common_subexpr_eliminate.rs | 130 +++++++++++------- datafusion/sqllogictest/test_files/expr.slt | 36 +++++ 2 files changed, 115 insertions(+), 51 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index e73885c6aaef..0c9064d0641f 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -29,8 +29,7 @@ use datafusion_common::tree_node::{ TreeNodeVisitor, }; use datafusion_common::{ - internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef, - DataFusionError, Result, + internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; @@ -42,8 +41,36 @@ use datafusion_expr::{col, Expr, ExprSchemable}; /// - DataType of this expression. type ExprSet = HashMap; -/// Identifier type. Current implementation use describe of an expression (type String) as -/// Identifier. +/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an +/// initial expression walk. +/// +/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove +/// common subexpressions. +/// +/// Elements in this array are created on the walk down the expression tree +/// during `f_down`. Thus element 0 is the root of the expression tree. The +/// tuple contains: +/// - series_number. +/// - Incremented during `f_up`, start from 1. +/// - Thus, items with higher idx have the lower series_number. +/// - [`Identifier`] +/// - Identifier of the expression. If empty (`""`), expr should not be considered for common elimination. +/// +/// # Example +/// An expression like `(a + b)` would have the following `IdArray`: +/// ```text +/// [ +/// (3, "a + b"), +/// (2, "a"), +/// (1, "b") +/// ] +/// ``` +type IdArray = Vec<(usize, Identifier)>; + +/// Identifier for each subexpression. +/// +/// Note that the current implementation uses the `Display` of an expression +/// (a `String`) as `Identifier`. /// /// An identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no /// collision (as low as possible)" @@ -293,8 +320,9 @@ impl CommonSubexprEliminate { agg_exprs.push(expr.alias(&name)); proj_exprs.push(Expr::Column(Column::from_name(name))); } else { - let id = - ExprIdentifierVisitor::<'static>::desc_expr(&expr_rewritten); + let id = ExprIdentifierVisitor::<'static>::expr_identifier( + &expr_rewritten, + ); let out_name = expr_rewritten.to_field(&new_input_schema)?.qualified_name(); agg_exprs.push(expr_rewritten.alias(&id)); @@ -557,15 +585,15 @@ impl ExprMask { /// This visitor implementation use a stack `visit_stack` to track traversal, which /// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called /// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack. -/// And try to pop out a `EnterMark` on leaving a node (`post_visit()`). All `ExprItem` +/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem` /// before the first `EnterMark` is considered to be sub-tree of the leaving node. /// /// This visitor also records identifier in `id_array`. Makes the following traverse /// pass can get the identifier of a node without recalculate it. We assign each node /// in the expr tree a series number, start from 1, maintained by `series_number`. -/// Series number represents the order we left (`post_visit`) a node. Has the property +/// Series number represents the order we left (`f_up()`) a node. Has the property /// that child node's series number always smaller than parent's. While `id_array` is -/// organized in the order we enter (`pre_visit`) a node. `node_count` helps us to +/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to /// get the index of `id_array` for each node. /// /// `Expr` without sub-expr (column, literal etc.) will not have identifier @@ -574,15 +602,15 @@ struct ExprIdentifierVisitor<'a> { // param expr_set: &'a mut ExprSet, /// series number (usize) and identifier. - id_array: &'a mut Vec<(usize, Identifier)>, + id_array: &'a mut IdArray, /// input schema for the node that we're optimizing, so we can determine the correct datatype /// for each subexpression input_schema: DFSchemaRef, // inner states visit_stack: Vec, - /// increased in pre_visit, start from 0. + /// increased in fn_down, start from 0. node_count: usize, - /// increased in post_visit, start from 1. + /// increased in fn_up, start from 1. series_number: usize, /// which expression should be skipped? expr_mask: ExprMask, @@ -593,31 +621,33 @@ enum VisitRecord { /// `usize` is the monotone increasing series number assigned in pre_visit(). /// Starts from 0. Is used to index the identifier array `id_array` in post_visit(). EnterMark(usize), + /// the node's children were skipped => jump to f_up on same node + JumpMark(usize), /// Accumulated identifier of sub expression. ExprItem(Identifier), } impl ExprIdentifierVisitor<'_> { - fn desc_expr(expr: &Expr) -> String { + fn expr_identifier(expr: &Expr) -> Identifier { format!("{expr}") } /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it. - fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> { + fn pop_enter_mark(&mut self) -> (usize, Identifier) { let mut desc = String::new(); while let Some(item) = self.visit_stack.pop() { match item { - VisitRecord::EnterMark(idx) => { - return Some((idx, desc)); + VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => { + return (idx, desc); } - VisitRecord::ExprItem(s) => { - desc.push_str(&s); + VisitRecord::ExprItem(id) => { + desc.push_str(&id); } } } - None + unreachable!("Enter mark should paired with node number"); } } @@ -625,34 +655,39 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { type Node = Expr; fn f_down(&mut self, expr: &Expr) -> Result { + // put placeholder, sets the proper array length + self.id_array.push((0, "".to_string())); + // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(TreeNodeRecursion::Jump); + self.visit_stack + .push(VisitRecord::JumpMark(self.node_count)); + return Ok(TreeNodeRecursion::Jump); // go to f_up } + self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; - // put placeholder - self.id_array.push((0, "".to_string())); + Ok(TreeNodeRecursion::Continue) } fn f_up(&mut self, expr: &Expr) -> Result { self.series_number += 1; - let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else { - return Ok(TreeNodeRecursion::Continue); - }; + let (idx, sub_expr_identifier) = self.pop_enter_mark(); + // skip exprs should not be recognize. if self.expr_mask.ignores(expr) { - self.id_array[idx].0 = self.series_number; - let desc = Self::desc_expr(expr); - self.visit_stack.push(VisitRecord::ExprItem(desc)); + let curr_expr_identifier = Self::expr_identifier(expr); + self.visit_stack + .push(VisitRecord::ExprItem(curr_expr_identifier)); + self.id_array[idx].0 = self.series_number; // leave Identifer as empty "", since will not use as common expr return Ok(TreeNodeRecursion::Continue); } - let mut desc = Self::desc_expr(expr); - desc.push_str(&sub_expr_desc); + let mut desc = Self::expr_identifier(expr); + desc.push_str(&sub_expr_identifier); self.id_array[idx] = (self.series_number, desc.clone()); self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); @@ -693,7 +728,7 @@ fn expr_to_identifier( /// evaluate result of replaced expression. struct CommonSubexprRewriter<'a> { expr_set: &'a ExprSet, - id_array: &'a [(usize, Identifier)], + id_array: &'a IdArray, /// Which identifier is replaced. affected_id: &'a mut BTreeSet, @@ -715,20 +750,26 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { if expr.short_circuits() || is_volatile_expression(&expr)? { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } + + let (series_number, curr_id) = &self.id_array[self.curr_index]; + + // halting conditions if self.curr_index >= self.id_array.len() - || self.max_series_number > self.id_array[self.curr_index].0 + || self.max_series_number > *series_number { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } - let curr_id = &self.id_array[self.curr_index].1; // skip `Expr`s without identifier (empty identifier). if curr_id.is_empty() { - self.curr_index += 1; + self.curr_index += 1; // incr idx for id_array, when not jumping return Ok(Transformed::no(expr)); } + + // lookup previously visited expression match self.expr_set.get(curr_id) { Some((_, counter, _)) => { + // if has a commonly used (a.k.a. 1+ use) expr if *counter > 1 { self.affected_id.insert(curr_id.clone()); @@ -741,23 +782,10 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { )); } - let (series_number, id) = &self.id_array[self.curr_index]; + // incr idx for id_array, when not jumping self.curr_index += 1; - // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. - let expr_set_item = self.expr_set.get(id).ok_or_else(|| { - internal_datafusion_err!("expr_set invalid state") - })?; - if *series_number < self.max_series_number - || id.is_empty() - || expr_set_item.1 <= 1 - { - return Ok(Transformed::new( - expr, - false, - TreeNodeRecursion::Jump, - )); - } + // series_number was the inverse number ordering (when doing f_up) self.max_series_number = *series_number; // step index to skip all sub-node (which has smaller series number). while self.curr_index < self.id_array.len() @@ -771,7 +799,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // `projection_push_down` optimizer use "expr name" to eliminate useless // projections. Ok(Transformed::new( - col(id).alias(expr_name), + col(curr_id).alias(expr_name), true, TreeNodeRecursion::Jump, )) @@ -787,7 +815,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { fn replace_common_expr( expr: Expr, - id_array: &[(usize, Identifier)], + id_array: &IdArray, expr_set: &ExprSet, affected_id: &mut BTreeSet, ) -> Result { diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index d6343f9a3fe8..69f3e439eac9 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -2205,3 +2205,39 @@ false true false true NULL NULL NULL NULL false false true true false false true false + + +############# +## Common Subexpr Eliminate Tests +############# + +statement ok +CREATE TABLE doubles ( + f64 DOUBLE +) as VALUES + (10.1) +; + +# common subexpr with alias +query RRR rowsort +select f64, round(1.0 / f64) as i64_1, acos(round(1.0 / f64)) from doubles; +---- +10.1 0 1.570796326795 + +# common subexpr with coalesce (short-circuited) +query RRR rowsort +select f64, coalesce(1.0 / f64, 0.0), acos(coalesce(1.0 / f64, 0.0)) from doubles; +---- +10.1 0.09900990099 1.471623942989 + +# common subexpr with coalesce (short-circuited) and alias +query RRR rowsort +select f64, coalesce(1.0 / f64, 0.0) as f64_1, acos(coalesce(1.0 / f64, 0.0)) from doubles; +---- +10.1 0.09900990099 1.471623942989 + +# common subexpr with case (short-circuited) +query RRR rowsort +select f64, case when f64 > 0 then 1.0 / f64 else null end, acos(case when f64 > 0 then 1.0 / f64 else null end) from doubles; +---- +10.1 0.09900990099 1.471623942989