From 049bf0927d8653c91703e1ce02b38c806925ac16 Mon Sep 17 00:00:00 2001 From: wiedld Date: Thu, 28 Mar 2024 12:06:39 -0700 Subject: [PATCH] refactor: use the only reproducible key (expr_identifier) for expr_set, while keeping the (stack-popped) symbol used for alias. --- .../optimizer/src/common_subexpr_eliminate.rs | 69 ++++++++++++++----- 1 file changed, 50 insertions(+), 19 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 5846c695e4d9..aef703bd5b41 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -38,10 +38,15 @@ use datafusion_expr::logical_plan::{ use datafusion_expr::{col, Expr, ExprSchemable}; /// A map from expression's identifier to tuple including +/// +/// key == Identifier created with only the current node (& subtree) +/// +/// values: /// - the expression itself (cloned) /// - counter /// - DataType of this expression. -type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>; +/// - symbol used as the identifier in the alias +type ExprSet = HashMap<Identifier, (Expr, usize, DataType, Identifier)>; /// Identifier for each subexpression. 
/// @@ -278,9 +283,9 @@ impl CommonSubexprEliminate { for id in affected_id { match expr_set.get(&id) { - Some((expr, _, _)) => { + Some((expr, _, _, symbol)) => { // todo: check `nullable` - agg_exprs.push(expr.clone().alias(&id)); + agg_exprs.push(expr.clone().alias(symbol.as_str())); } _ => { return internal_err!("expr_set invalid state"); @@ -455,11 +460,11 @@ fn build_common_expr_project_plan( for id in affected_id { match expr_set.get(&id) { - Some((expr, _, data_type)) => { + Some((expr, _, data_type, symbol)) => { // todo: check `nullable` let field = DFField::new_unqualified(&id, data_type.clone(), true); fields_set.insert(field.name().to_owned()); - project_exprs.push(expr.clone().alias(&id)); + project_exprs.push(expr.clone().alias(symbol.as_str())); } _ => { return internal_err!("expr_set invalid state"); @@ -650,16 +655,16 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { .push(VisitRecord::ExprItem(curr_expr_identifier)); return Ok(TreeNodeRecursion::Continue); } - let mut desc = Self::expr_identifier(expr); - desc.push_str(&sub_expr_identifier); + let curr_expr_identifier = Self::expr_identifier(expr); + let desc = format!("{curr_expr_identifier}{sub_expr_identifier}"); self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); let data_type = expr.get_type(&self.input_schema)?; self.expr_set - .entry(desc) - .or_insert_with(|| (expr.clone(), 0, data_type)) + .entry(curr_expr_identifier) + .or_insert_with(|| (expr.clone(), 0, data_type, desc)) .1 += 1; Ok(TreeNodeRecursion::Continue) } @@ -713,7 +718,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // lookup previously visited expression match self.expr_set.get(curr_id) { - Some((_, counter, _)) => { + Some((_, counter, _, symbol)) => { // if has a commonly used (a.k.a. 
1+ use) expr if *counter > 1 { self.affected_id.insert(curr_id.clone()); @@ -723,7 +728,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // `projection_push_down` optimizer use "expr name" to eliminate useless // projections. Ok(Transformed::new( - col(curr_id).alias(expr_name), + col(symbol).alias(expr_name), true, TreeNodeRecursion::Jump, )) @@ -1026,18 +1031,24 @@ mod test { let expr_set_1 = [ ( "c+a".to_string(), - (col("c") + col("a"), 1, DataType::UInt32), + (col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()), ), ( "b+a".to_string(), - (col("b") + col("a"), 1, DataType::UInt32), + (col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()), ), ] .into_iter() .collect(); let expr_set_2 = [ - ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)), - ("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)), + ( + "c+a".to_string(), + (col("c+a"), 1, DataType::UInt32, "c+a".to_string()), + ), + ( + "b+a".to_string(), + (col("b+a"), 1, DataType::UInt32, "b+a".to_string()), + ), ] .into_iter() .collect(); @@ -1069,11 +1080,21 @@ mod test { let expr_set_1 = [ ( "test1.c+test1.a".to_string(), - (col("test1.c") + col("test1.a"), 1, DataType::UInt32), + ( + col("test1.c") + col("test1.a"), + 1, + DataType::UInt32, + "test1.c+test1.a".to_string(), + ), ), ( "test1.b+test1.a".to_string(), - (col("test1.b") + col("test1.a"), 1, DataType::UInt32), + ( + col("test1.b") + col("test1.a"), + 1, + DataType::UInt32, + "test1.b+test1.a".to_string(), + ), ), ] .into_iter() @@ -1081,11 +1102,21 @@ mod test { let expr_set_2 = [ ( "test1.c+test1.a".to_string(), - (col("test1.c+test1.a"), 1, DataType::UInt32), + ( + col("test1.c+test1.a"), + 1, + DataType::UInt32, + "test1.c+test1.a".to_string(), + ), ), ( "test1.b+test1.a".to_string(), - (col("test1.b+test1.a"), 1, DataType::UInt32), + ( + col("test1.b+test1.a"), + 1, + DataType::UInt32, + "test1.b+test1.a".to_string(), + ), ), ] .into_iter()