diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 64c5b56a40802..79fa3d5ff4ed0 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2245,7 +2245,7 @@ impl DistinctOn { /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() #[non_exhaustive] pub struct Aggregate { diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index cb3b4accf35d0..36b3140320b86 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -17,7 +17,6 @@ //! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions -use std::collections::hash_map::Entry; use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; @@ -35,75 +34,37 @@ use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; use datafusion_expr::{col, Expr, ExprSchemable}; -/// Set of expressions generated by the [`ExprIdentifierVisitor`] -/// and consumed by the [`CommonSubexprRewriter`]. -#[derive(Default)] -struct ExprSet { - /// A map from expression's identifier (stringified expr) to tuple including: - /// - the expression itself (cloned) - /// - counter - /// - DataType of this expression. - /// - symbol used as the identifier in the alias. 
- map: HashMap, -} - -impl ExprSet { - fn expr_identifier(expr: &Expr) -> Identifier { - format!("{expr}") - } - - fn get(&self, key: &Identifier) -> Option<&(Expr, usize, DataType, Identifier)> { - self.map.get(key) - } - - fn entry( - &mut self, - key: Identifier, - ) -> Entry<'_, Identifier, (Expr, usize, DataType, Identifier)> { - self.map.entry(key) - } - - fn populate_expr_set( - &mut self, - expr: &[Expr], - input_schema: DFSchemaRef, - expr_mask: ExprMask, - ) -> Result<()> { - expr.iter().try_for_each(|e| { - self.expr_to_identifier(e, Arc::clone(&input_schema), expr_mask)?; - - Ok(()) - }) - } - - /// Go through an expression tree and generate identifier for every node in this tree. - fn expr_to_identifier( - &mut self, - expr: &Expr, - input_schema: DFSchemaRef, - expr_mask: ExprMask, - ) -> Result<()> { - expr.visit(&mut ExprIdentifierVisitor { - expr_set: self, - input_schema, - visit_stack: vec![], - node_count: 0, - expr_mask, - })?; - - Ok(()) - } -} +/// A map from expression's identifier to tuple including +/// - the expression itself (cloned) +/// - counter +/// - DataType of this expression. +type ExprSet = HashMap; -impl From> for ExprSet { - fn from(entries: Vec<(Identifier, (Expr, usize, DataType, Identifier))>) -> Self { - let mut expr_set = Self::default(); - entries.into_iter().for_each(|(k, v)| { - expr_set.map.insert(k, v); - }); - expr_set - } -} +/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an +/// initial expression walk. +/// +/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove +/// common subexpressions. +/// +/// Elements in this array are created on the walk down the expression tree +/// during `f_down`. Thus element 0 is the root of the expression tree. The +/// tuple contains: +/// - series_number. +/// - Incremented during `f_up`, start from 1. +/// - Thus, items with higher idx have the lower series_number. +/// - [`Identifier`] +/// - Identifier of the expression. 
If empty (`""`), expr should not be considered for common elimination. +/// +/// # Example +/// An expression like `(a + b)` would have the following `IdArray`: +/// ```text +/// [ +/// (3, "a + b"), +/// (2, "a"), +/// (1, "b") +/// ] +/// ``` +type IdArray = Vec<(usize, Identifier)>; /// Identifier for each subexpression. /// @@ -156,16 +117,21 @@ impl CommonSubexprEliminate { fn rewrite_exprs_list( &self, exprs_list: &[&[Expr]], + arrays_list: &[&[Vec<(usize, String)>]], expr_set: &ExprSet, affected_id: &mut BTreeSet, ) -> Result>> { exprs_list .iter() - .map(|exprs| { + .zip(arrays_list.iter()) + .map(|(exprs, arrays)| { exprs .iter() .cloned() - .map(|expr| replace_common_expr(expr, expr_set, affected_id)) + .zip(arrays.iter()) + .map(|(expr, id_array)| { + replace_common_expr(expr, id_array, expr_set, affected_id) + }) .collect::>>() }) .collect::>>() @@ -182,6 +148,7 @@ impl CommonSubexprEliminate { fn rewrite_expr( &self, exprs_list: &[&[Expr]], + arrays_list: &[&[Vec<(usize, String)>]], input: &LogicalPlan, expr_set: &ExprSet, config: &dyn OptimizerConfig, @@ -189,7 +156,7 @@ impl CommonSubexprEliminate { let mut affected_id = BTreeSet::::new(); let rewrite_exprs = - self.rewrite_exprs_list(exprs_list, expr_set, &mut affected_id)?; + self.rewrite_exprs_list(exprs_list, arrays_list, expr_set, &mut affected_id)?; let mut new_input = self .try_optimize(input, config)? @@ -207,7 +174,8 @@ impl CommonSubexprEliminate { config: &dyn OptimizerConfig, ) -> Result { let mut window_exprs = vec![]; - let mut expr_set = ExprSet::default(); + let mut arrays_per_window = vec![]; + let mut expr_set = ExprSet::new(); // Get all window expressions inside the consecutive window operators. // Consecutive window expressions may refer to same complex expression. 
@@ -226,18 +194,30 @@ impl CommonSubexprEliminate { plan = input.as_ref().clone(); let input_schema = Arc::clone(input.schema()); - expr_set.populate_expr_set(&window_expr, input_schema, ExprMask::Normal)?; + let arrays = + to_arrays(&window_expr, input_schema, &mut expr_set, ExprMask::Normal)?; window_exprs.push(window_expr); + arrays_per_window.push(arrays); } let mut window_exprs = window_exprs .iter() .map(|expr| expr.as_slice()) .collect::>(); + let arrays_per_window = arrays_per_window + .iter() + .map(|arrays| arrays.as_slice()) + .collect::>(); - let (mut new_expr, new_input) = - self.rewrite_expr(&window_exprs, &plan, &expr_set, config)?; + assert_eq!(window_exprs.len(), arrays_per_window.len()); + let (mut new_expr, new_input) = self.rewrite_expr( + &window_exprs, + &arrays_per_window, + &plan, + &expr_set, + config, + )?; assert_eq!(window_exprs.len(), new_expr.len()); // Construct consecutive window operator, with their corresponding new window expressions. @@ -274,36 +254,46 @@ impl CommonSubexprEliminate { input, .. } = aggregate; - let mut expr_set = ExprSet::default(); + let mut expr_set = ExprSet::new(); - // build expr_set, with groupby and aggr + // rewrite inputs let input_schema = Arc::clone(input.schema()); - expr_set.populate_expr_set( + let group_arrays = to_arrays( group_expr, Arc::clone(&input_schema), + &mut expr_set, ExprMask::Normal, )?; - expr_set.populate_expr_set(aggr_expr, input_schema, ExprMask::Normal)?; + let aggr_arrays = + to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?; - // rewrite inputs - let (mut new_expr, new_input) = - self.rewrite_expr(&[group_expr, aggr_expr], input, &expr_set, config)?; + let (mut new_expr, new_input) = self.rewrite_expr( + &[group_expr, aggr_expr], + &[&group_arrays, &aggr_arrays], + input, + &expr_set, + config, + )?; // note the reversed pop order. 
let new_aggr_expr = pop_expr(&mut new_expr)?; let new_group_expr = pop_expr(&mut new_expr)?; // create potential projection on top - let mut expr_set = ExprSet::default(); + let mut expr_set = ExprSet::new(); let new_input_schema = Arc::clone(new_input.schema()); - expr_set.populate_expr_set( + let aggr_arrays = to_arrays( &new_aggr_expr, new_input_schema.clone(), + &mut expr_set, ExprMask::NormalAndAggregates, )?; - let mut affected_id = BTreeSet::::new(); - let mut rewritten = - self.rewrite_exprs_list(&[&new_aggr_expr], &expr_set, &mut affected_id)?; + let mut rewritten = self.rewrite_exprs_list( + &[&new_aggr_expr], + &[&aggr_arrays], + &expr_set, + &mut affected_id, + )?; let rewritten = pop_expr(&mut rewritten)?; if affected_id.is_empty() { @@ -323,9 +313,9 @@ impl CommonSubexprEliminate { for id in affected_id { match expr_set.get(&id) { - Some((expr, _, _, symbol)) => { + Some((expr, _, _)) => { // todo: check `nullable` - agg_exprs.push(expr.clone().alias(symbol.as_str())); + agg_exprs.push(expr.clone().alias(&id)); } _ => { return internal_err!("expr_set invalid state"); @@ -343,7 +333,9 @@ impl CommonSubexprEliminate { agg_exprs.push(expr.alias(&name)); proj_exprs.push(Expr::Column(Column::from_name(name))); } else { - let id = ExprSet::expr_identifier(&expr_rewritten); + let id = ExprIdentifierVisitor::<'static>::expr_identifier( + &expr_rewritten, + ); let (qualifier, field) = expr_rewritten.to_field(&new_input_schema)?; let out_name = qualified_name(qualifier.as_ref(), field.name()); @@ -379,13 +371,13 @@ impl CommonSubexprEliminate { let inputs = plan.inputs(); let input = inputs[0]; let input_schema = Arc::clone(input.schema()); - let mut expr_set = ExprSet::default(); + let mut expr_set = ExprSet::new(); // Visit expr list and build expr identifier to occuring count map (`expr_set`). 
- expr_set.populate_expr_set(&expr, input_schema, ExprMask::Normal)?; + let arrays = to_arrays(&expr, input_schema, &mut expr_set, ExprMask::Normal)?; let (mut new_expr, new_input) = - self.rewrite_expr(&[&expr], input, &expr_set, config)?; + self.rewrite_expr(&[&expr], &[&arrays], input, &expr_set, config)?; plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input]) } @@ -471,6 +463,28 @@ fn pop_expr(new_expr: &mut Vec>) -> Result> { .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string())) } +fn to_arrays( + expr: &[Expr], + input_schema: DFSchemaRef, + expr_set: &mut ExprSet, + expr_mask: ExprMask, +) -> Result>> { + expr.iter() + .map(|e| { + let mut id_array = vec![]; + expr_to_identifier( + e, + expr_set, + &mut id_array, + Arc::clone(&input_schema), + expr_mask, + )?; + + Ok(id_array) + }) + .collect::>>() +} + /// Build the "intermediate" projection plan that evaluates the extracted common /// expressions. /// @@ -491,11 +505,11 @@ fn build_common_expr_project_plan( for id in affected_id { match expr_set.get(&id) { - Some((expr, _, data_type, symbol)) => { + Some((expr, _, data_type)) => { // todo: check `nullable` let field = Field::new(&id, data_type.clone(), true); fields_set.insert(field.name().to_owned()); - project_exprs.push(expr.clone().alias(symbol.as_str())); + project_exprs.push(expr.clone().alias(&id)); } _ => { return internal_err!("expr_set invalid state"); @@ -612,6 +626,8 @@ impl ExprMask { struct ExprIdentifierVisitor<'a> { // param expr_set: &'a mut ExprSet, + /// series number (usize) and identifier. + id_array: &'a mut IdArray, /// input schema for the node that we're optimizing, so we can determine the correct datatype /// for each subexpression input_schema: DFSchemaRef, @@ -619,6 +635,8 @@ struct ExprIdentifierVisitor<'a> { visit_stack: Vec, /// increased in fn_down, start from 0. node_count: usize, + /// increased in fn_up, start from 1. + series_number: usize, /// which expression should be skipped? 
expr_mask: ExprMask, } @@ -635,6 +653,10 @@ enum VisitRecord { } impl ExprIdentifierVisitor<'_> { + fn expr_identifier(expr: &Expr) -> Identifier { + format!("{expr}") + } + /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it. fn pop_enter_mark(&mut self) -> (usize, Identifier) { @@ -658,7 +680,10 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { type Node = Expr; fn f_down(&mut self, expr: &Expr) -> Result { - // related to https://github.com/apache/datafusion/issues/8814 + // put placeholder, sets the proper array length + self.id_array.push((0, "".to_string())); + + // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. if expr.short_circuits() || expr.is_volatile()? { self.visit_stack @@ -674,39 +699,70 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { } fn f_up(&mut self, expr: &Expr) -> Result { - let (_idx, sub_expr_identifier) = self.pop_enter_mark(); + self.series_number += 1; + + let (idx, sub_expr_identifier) = self.pop_enter_mark(); // skip exprs should not be recognize. 
if self.expr_mask.ignores(expr) { - let curr_expr_identifier = ExprSet::expr_identifier(expr); + let curr_expr_identifier = Self::expr_identifier(expr); self.visit_stack .push(VisitRecord::ExprItem(curr_expr_identifier)); + self.id_array[idx].0 = self.series_number; // leave Identifer as empty "", since will not use as common expr return Ok(TreeNodeRecursion::Continue); } - let curr_expr_identifier = ExprSet::expr_identifier(expr); - let alias_symbol = format!("{curr_expr_identifier}{sub_expr_identifier}"); + let mut desc = Self::expr_identifier(expr); + desc.push_str(&sub_expr_identifier); - self.visit_stack - .push(VisitRecord::ExprItem(alias_symbol.clone())); + self.id_array[idx] = (self.series_number, desc.clone()); + self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); let data_type = expr.get_type(&self.input_schema)?; self.expr_set - .entry(curr_expr_identifier) - .or_insert_with(|| (expr.clone(), 0, data_type, alias_symbol)) + .entry(desc) + .or_insert_with(|| (expr.clone(), 0, data_type)) .1 += 1; Ok(TreeNodeRecursion::Continue) } } -/// Rewrite expression by common sub-expression with a corresponding temporary -/// column name that will compute the subexpression. -/// -/// `affected_id` is updated with any sub expressions that were replaced +/// Go through an expression tree and generate an identifier for every node in this tree. +fn expr_to_identifier( + expr: &Expr, + expr_set: &mut ExprSet, + id_array: &mut Vec<(usize, Identifier)>, + input_schema: DFSchemaRef, + expr_mask: ExprMask, +) -> Result<()> { + expr.visit(&mut ExprIdentifierVisitor { + expr_set, + id_array, + input_schema, + visit_stack: vec![], + node_count: 0, + series_number: 0, + expr_mask, + })?; + + Ok(()) +} + +/// Rewrite expression by replacing detected common sub-expression with +/// the corresponding temporary column name. That column contains the +/// evaluated result of the replaced expression.
struct CommonSubexprRewriter<'a> { expr_set: &'a ExprSet, + id_array: &'a IdArray, /// Which identifier is replaced. affected_id: &'a mut BTreeSet, + + /// The max series number we have rewritten. Other expression nodes + /// with a smaller series number are already replaced, so nothing + /// should be done with them. + max_series_number: usize, + /// Index of the current node's entry in `id_array`. + curr_index: usize, } impl TreeNodeRewriter for CommonSubexprRewriter<'_> { @@ -720,29 +776,64 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } - let curr_id = &ExprSet::expr_identifier(&expr); + let (series_number, curr_id) = &self.id_array[self.curr_index]; + + // halting conditions (NOTE(review): the index expression above panics before this length check can run) + if self.curr_index >= self.id_array.len() + || self.max_series_number > *series_number + { + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); + } + + // skip `Expr`s without identifier (empty identifier). + if curr_id.is_empty() { + self.curr_index += 1; // incr idx for id_array, when not jumping + return Ok(Transformed::no(expr)); + } // lookup previously visited expression match self.expr_set.get(curr_id) { - Some((_, counter, _, symbol)) => { + Some((_, counter, _)) => { // if has a commonly used (a.k.a. 1+ use) expr if *counter > 1 { self.affected_id.insert(curr_id.clone()); + // This expr tree is finished. + if self.curr_index >= self.id_array.len() { + return Ok(Transformed::new( + expr, + false, + TreeNodeRecursion::Jump, + )); + } + + // incr idx for id_array, when not jumping + self.curr_index += 1; + + // series_number was the inverse number ordering (when doing f_up) + self.max_series_number = *series_number; + // step index to skip all sub-nodes (which have smaller series numbers).
+ while self.curr_index < self.id_array.len() + && *series_number > self.id_array[self.curr_index].0 + { + self.curr_index += 1; + } + let expr_name = expr.display_name()?; // Alias this `Column` expr to it original "expr name", // `projection_push_down` optimizer use "expr name" to eliminate useless // projections. Ok(Transformed::new( - col(symbol).alias(expr_name), + col(curr_id).alias(expr_name), true, TreeNodeRecursion::Jump, )) } else { + self.curr_index += 1; Ok(Transformed::no(expr)) } } - None => Ok(Transformed::no(expr)), + _ => internal_err!("expr_set invalid state"), } } } @@ -751,12 +842,16 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { /// column name, updating `affected_id` with any replaced expressions fn replace_common_expr( expr: Expr, + id_array: &IdArray, expr_set: &ExprSet, affected_id: &mut BTreeSet, ) -> Result { expr.rewrite(&mut CommonSubexprRewriter { expr_set, + id_array, affected_id, + max_series_number: 0, + curr_index: 0, }) .data() } @@ -789,6 +884,74 @@ mod test { assert_eq!(expected, formatted_plan); } + #[test] + fn id_array_visitor() -> Result<()> { + let expr = ((sum(col("a") + lit(1))) - avg(col("c"))) * lit(2); + + let schema = Arc::new(DFSchema::from_unqualifed_fields( + vec![ + Field::new("a", DataType::Int64, false), + Field::new("c", DataType::Int64, false), + ] + .into(), + Default::default(), + )?); + + // skip aggregates + let mut id_array = vec![]; + expr_to_identifier( + &expr, + &mut HashMap::new(), + &mut id_array, + Arc::clone(&schema), + ExprMask::Normal, + )?; + + let expected = vec![ + (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"), + (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"), + (4, ""), + (3, "a + Int32(1)Int32(1)a"), + (1, ""), + (2, ""), + (6, ""), + (5, ""), + (8, "") + ] + .into_iter() + .map(|(number, id)| (number, id.into())) + .collect::>(); + assert_eq!(expected, id_array); + + // include aggregates + let mut id_array = 
vec![]; + expr_to_identifier( + &expr, + &mut HashMap::new(), + &mut id_array, + Arc::clone(&schema), + ExprMask::NormalAndAggregates, + )?; + + let expected = vec![ + (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), + (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), + (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"), + (3, "a + Int32(1)Int32(1)a"), + (1, ""), + (2, ""), + (6, "AVG(c)c"), + (5, ""), + (8, "") + ] + .into_iter() + .map(|(number, id)| (number, id.into())) + .collect::>(); + assert_eq!(expected, id_array); + + Ok(()) + } + #[test] fn tpch_q1_simplified() -> Result<()> { // SQL: @@ -1033,28 +1196,24 @@ mod test { let table_scan = test_table_scan().unwrap(); let affected_id: BTreeSet = ["c+a".to_string(), "b+a".to_string()].into_iter().collect(); - let expr_set_1 = vec![ + let expr_set_1 = [ ( "c+a".to_string(), - (col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()), + (col("c") + col("a"), 1, DataType::UInt32), ), ( "b+a".to_string(), - (col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()), + (col("b") + col("a"), 1, DataType::UInt32), ), ] - .into(); - let expr_set_2 = vec![ - ( - "c+a".to_string(), - (col("c+a"), 1, DataType::UInt32, "c+a".to_string()), - ), - ( - "b+a".to_string(), - (col("b+a"), 1, DataType::UInt32, "b+a".to_string()), - ), + .into_iter() + .collect(); + let expr_set_2 = [ + ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)), + ("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)), ] - .into(); + .into_iter() + .collect(); let project = build_common_expr_project_plan(table_scan, affected_id.clone(), &expr_set_1) .unwrap(); @@ -1080,48 +1239,30 @@ mod test { ["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()] .into_iter() .collect(); - let expr_set_1 = vec![ + let expr_set_1 = [ ( "test1.c+test1.a".to_string(), - ( - col("test1.c") + col("test1.a"), - 1, - DataType::UInt32, - 
"test1.c+test1.a".to_string(), - ), + (col("test1.c") + col("test1.a"), 1, DataType::UInt32), ), ( "test1.b+test1.a".to_string(), - ( - col("test1.b") + col("test1.a"), - 1, - DataType::UInt32, - "test1.b+test1.a".to_string(), - ), + (col("test1.b") + col("test1.a"), 1, DataType::UInt32), ), ] - .into(); - let expr_set_2 = vec![ + .into_iter() + .collect(); + let expr_set_2 = [ ( "test1.c+test1.a".to_string(), - ( - col("test1.c+test1.a"), - 1, - DataType::UInt32, - "test1.c+test1.a".to_string(), - ), + (col("test1.c+test1.a"), 1, DataType::UInt32), ), ( "test1.b+test1.a".to_string(), - ( - col("test1.b+test1.a"), - 1, - DataType::UInt32, - "test1.b+test1.a".to_string(), - ), + (col("test1.b+test1.a"), 1, DataType::UInt32), ), ] - .into(); + .into_iter() + .collect(); let project = build_common_expr_project_plan(join, affected_id.clone(), &expr_set_1) .unwrap();