Skip to content

Commit

Permalink
fix(9678): short circuiting prevented population of visited stack, fo…
Browse files Browse the repository at this point in the history
…r common subexpr elimination optimization (#9685)

* test(9678): reproducer of short-circuiting causing expr elimination to error

* fix(9678): populate visited stack for short-circuited expressions, during the common-expr elimination optimization

* test(9678): reproducer for optimizer error (in common_subexpr_eliminate), as seen in other test case

* chore: extract id_array into abstraction, to make it more clear the relationship between the two visitors

* refactor: tweak the fix and make code more explicit (JumpMark, node_to_identifier)

* fix: get the series_number and curr_id with the correct self.current_idx, before the various incr/decr

* chore: remove unneeded conditional check (already done earlier), and add code comments

* Refine documentation in common_subexpr_eliminate.rs

* chore: cleanup -- fix 1 doc comment and consolidate common-expr-elimination test with other expr test

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
wiedld and alamb authored Mar 21, 2024
1 parent 6d74025 commit 5f0cb49
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 51 deletions.
130 changes: 79 additions & 51 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ use datafusion_common::tree_node::{
TreeNodeVisitor,
};
use datafusion_common::{
internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef,
DataFusionError, Result,
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window};
Expand All @@ -42,8 +41,36 @@ use datafusion_expr::{col, Expr, ExprSchemable};
/// - DataType of this expression.
type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;

/// Identifier type. Current implementation use describe of an expression (type String) as
/// Identifier.
/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an
/// initial expression walk.
///
/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove
/// common subexpressions.
///
/// Elements in this array are created on the walk down the expression tree
/// during `f_down`. Thus element 0 is the root of the expression tree. The
/// tuple contains:
/// - series_number.
/// - Incremented during `f_up`, start from 1.
/// - Thus, items with higher idx have the lower series_number.
/// - [`Identifier`]
/// - Identifier of the expression. If empty (`""`), expr should not be considered for common elimination.
///
/// # Example
/// An expression like `(a + b)` would have the following `IdArray`:
/// ```text
/// [
/// (3, "a + b"),
/// (2, "a"),
/// (1, "b")
/// ]
/// ```
type IdArray = Vec<(usize, Identifier)>;

/// Identifier for each subexpression.
///
/// Note that the current implementation uses the `Display` of an expression
/// (a `String`) as `Identifier`.
///
/// An identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no
/// collision (as low as possible)"
Expand Down Expand Up @@ -293,8 +320,9 @@ impl CommonSubexprEliminate {
agg_exprs.push(expr.alias(&name));
proj_exprs.push(Expr::Column(Column::from_name(name)));
} else {
let id =
ExprIdentifierVisitor::<'static>::desc_expr(&expr_rewritten);
let id = ExprIdentifierVisitor::<'static>::expr_identifier(
&expr_rewritten,
);
let out_name =
expr_rewritten.to_field(&new_input_schema)?.qualified_name();
agg_exprs.push(expr_rewritten.alias(&id));
Expand Down Expand Up @@ -557,15 +585,15 @@ impl ExprMask {
/// This visitor implementation use a stack `visit_stack` to track traversal, which
/// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called
/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack.
/// And try to pop out a `EnterMark` on leaving a node (`post_visit()`). All `ExprItem`
/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem`
/// before the first `EnterMark` is considered to be sub-tree of the leaving node.
///
/// This visitor also records identifier in `id_array`. Makes the following traverse
/// pass can get the identifier of a node without recalculate it. We assign each node
/// in the expr tree a series number, start from 1, maintained by `series_number`.
/// Series number represents the order we left (`post_visit`) a node. Has the property
/// Series number represents the order we left (`f_up()`) a node. Has the property
/// that child node's series number always smaller than parent's. While `id_array` is
/// organized in the order we enter (`pre_visit`) a node. `node_count` helps us to
/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to
/// get the index of `id_array` for each node.
///
/// `Expr` without sub-expr (column, literal etc.) will not have identifier
Expand All @@ -574,15 +602,15 @@ struct ExprIdentifierVisitor<'a> {
// param
expr_set: &'a mut ExprSet,
/// series number (usize) and identifier.
id_array: &'a mut Vec<(usize, Identifier)>,
id_array: &'a mut IdArray,
/// input schema for the node that we're optimizing, so we can determine the correct datatype
/// for each subexpression
input_schema: DFSchemaRef,
// inner states
visit_stack: Vec<VisitRecord>,
/// increased in pre_visit, start from 0.
/// increased in fn_down, start from 0.
node_count: usize,
/// increased in post_visit, start from 1.
/// increased in fn_up, start from 1.
series_number: usize,
/// which expression should be skipped?
expr_mask: ExprMask,
Expand All @@ -593,66 +621,73 @@ enum VisitRecord {
/// `usize` is the monotone increasing series number assigned in pre_visit().
/// Starts from 0. Is used to index the identifier array `id_array` in post_visit().
EnterMark(usize),
/// the node's children were skipped => jump to f_up on same node
JumpMark(usize),
/// Accumulated identifier of sub expression.
ExprItem(Identifier),
}

impl ExprIdentifierVisitor<'_> {
fn desc_expr(expr: &Expr) -> String {
fn expr_identifier(expr: &Expr) -> Identifier {
format!("{expr}")
}

/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
/// before it.
fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> {
fn pop_enter_mark(&mut self) -> (usize, Identifier) {
let mut desc = String::new();

while let Some(item) = self.visit_stack.pop() {
match item {
VisitRecord::EnterMark(idx) => {
return Some((idx, desc));
VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => {
return (idx, desc);
}
VisitRecord::ExprItem(s) => {
desc.push_str(&s);
VisitRecord::ExprItem(id) => {
desc.push_str(&id);
}
}
}
None
unreachable!("Enter mark should paired with node number");
}
}

impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
type Node = Expr;

fn f_down(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
// put placeholder, sets the proper array length
self.id_array.push((0, "".to_string()));

// related to https://github.com/apache/arrow-datafusion/issues/8814
// If the expr contain volatile expression or is a short-circuit expression, skip it.
if expr.short_circuits() || is_volatile_expression(expr)? {
return Ok(TreeNodeRecursion::Jump);
self.visit_stack
.push(VisitRecord::JumpMark(self.node_count));
return Ok(TreeNodeRecursion::Jump); // go to f_up
}

self.visit_stack
.push(VisitRecord::EnterMark(self.node_count));
self.node_count += 1;
// put placeholder
self.id_array.push((0, "".to_string()));

Ok(TreeNodeRecursion::Continue)
}

fn f_up(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
self.series_number += 1;

let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else {
return Ok(TreeNodeRecursion::Continue);
};
let (idx, sub_expr_identifier) = self.pop_enter_mark();

// skip exprs should not be recognize.
if self.expr_mask.ignores(expr) {
self.id_array[idx].0 = self.series_number;
let desc = Self::desc_expr(expr);
self.visit_stack.push(VisitRecord::ExprItem(desc));
let curr_expr_identifier = Self::expr_identifier(expr);
self.visit_stack
.push(VisitRecord::ExprItem(curr_expr_identifier));
self.id_array[idx].0 = self.series_number; // leave Identifer as empty "", since will not use as common expr
return Ok(TreeNodeRecursion::Continue);
}
let mut desc = Self::desc_expr(expr);
desc.push_str(&sub_expr_desc);
let mut desc = Self::expr_identifier(expr);
desc.push_str(&sub_expr_identifier);

self.id_array[idx] = (self.series_number, desc.clone());
self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
Expand Down Expand Up @@ -693,7 +728,7 @@ fn expr_to_identifier(
/// evaluate result of replaced expression.
struct CommonSubexprRewriter<'a> {
expr_set: &'a ExprSet,
id_array: &'a [(usize, Identifier)],
id_array: &'a IdArray,
/// Which identifier is replaced.
affected_id: &'a mut BTreeSet<Identifier>,

Expand All @@ -715,20 +750,26 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
if expr.short_circuits() || is_volatile_expression(&expr)? {
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}

let (series_number, curr_id) = &self.id_array[self.curr_index];

// halting conditions
if self.curr_index >= self.id_array.len()
|| self.max_series_number > self.id_array[self.curr_index].0
|| self.max_series_number > *series_number
{
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}

let curr_id = &self.id_array[self.curr_index].1;
// skip `Expr`s without identifier (empty identifier).
if curr_id.is_empty() {
self.curr_index += 1;
self.curr_index += 1; // incr idx for id_array, when not jumping
return Ok(Transformed::no(expr));
}

// lookup previously visited expression
match self.expr_set.get(curr_id) {
Some((_, counter, _)) => {
// if has a commonly used (a.k.a. 1+ use) expr
if *counter > 1 {
self.affected_id.insert(curr_id.clone());

Expand All @@ -741,23 +782,10 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
));
}

let (series_number, id) = &self.id_array[self.curr_index];
// incr idx for id_array, when not jumping
self.curr_index += 1;
// Skip sub-node of a replaced tree, or without identifier, or is not repeated expr.
let expr_set_item = self.expr_set.get(id).ok_or_else(|| {
internal_datafusion_err!("expr_set invalid state")
})?;
if *series_number < self.max_series_number
|| id.is_empty()
|| expr_set_item.1 <= 1
{
return Ok(Transformed::new(
expr,
false,
TreeNodeRecursion::Jump,
));
}

// series_number was the inverse number ordering (when doing f_up)
self.max_series_number = *series_number;
// step index to skip all sub-node (which has smaller series number).
while self.curr_index < self.id_array.len()
Expand All @@ -771,7 +799,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
// `projection_push_down` optimizer use "expr name" to eliminate useless
// projections.
Ok(Transformed::new(
col(id).alias(expr_name),
col(curr_id).alias(expr_name),
true,
TreeNodeRecursion::Jump,
))
Expand All @@ -787,7 +815,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {

fn replace_common_expr(
expr: Expr,
id_array: &[(usize, Identifier)],
id_array: &IdArray,
expr_set: &ExprSet,
affected_id: &mut BTreeSet<Identifier>,
) -> Result<Expr> {
Expand Down
36 changes: 36 additions & 0 deletions datafusion/sqllogictest/test_files/expr.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2205,3 +2205,39 @@ false true false true
NULL NULL NULL NULL
false false true true
false false true false


#############
## Common Subexpr Eliminate Tests
#############

statement ok
CREATE TABLE doubles (
f64 DOUBLE
) as VALUES
(10.1)
;

# common subexpr with alias
query RRR rowsort
select f64, round(1.0 / f64) as i64_1, acos(round(1.0 / f64)) from doubles;
----
10.1 0 1.570796326795

# common subexpr with coalesce (short-circuited)
query RRR rowsort
select f64, coalesce(1.0 / f64, 0.0), acos(coalesce(1.0 / f64, 0.0)) from doubles;
----
10.1 0.09900990099 1.471623942989

# common subexpr with coalesce (short-circuited) and alias
query RRR rowsort
select f64, coalesce(1.0 / f64, 0.0) as f64_1, acos(coalesce(1.0 / f64, 0.0)) from doubles;
----
10.1 0.09900990099 1.471623942989

# common subexpr with case (short-circuited)
query RRR rowsort
select f64, case when f64 > 0 then 1.0 / f64 else null end, acos(case when f64 > 0 then 1.0 / f64 else null end) from doubles;
----
10.1 0.09900990099 1.471623942989

0 comments on commit 5f0cb49

Please sign in to comment.