Skip to content

Commit

Permalink
refactor: use the only reproducible key (expr_identifer) for expr_set…
Browse files Browse the repository at this point in the history
…, while keeping the (stack-popped) symbol used for alias.
  • Loading branch information
wiedld committed Mar 28, 2024
1 parent d59a8de commit 049bf09
Showing 1 changed file with 50 additions and 19 deletions.
69 changes: 50 additions & 19 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,15 @@ use datafusion_expr::logical_plan::{
use datafusion_expr::{col, Expr, ExprSchemable};

/// A map from expression's identifier to tuple including
///
/// key == Identifier created with only the current node (& subtree)
///
/// values:
/// - the expression itself (cloned)
/// - counter
/// - DataType of this expression.
type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;
/// - symbol used as the identifier in the alias
type ExprSet = HashMap<Identifier, (Expr, usize, DataType, Identifier)>;

/// Identifier for each subexpression.
///
Expand Down Expand Up @@ -278,9 +283,9 @@ impl CommonSubexprEliminate {

for id in affected_id {
match expr_set.get(&id) {
Some((expr, _, _)) => {
Some((expr, _, _, symbol)) => {
// todo: check `nullable`
agg_exprs.push(expr.clone().alias(&id));
agg_exprs.push(expr.clone().alias(symbol.as_str()));
}
_ => {
return internal_err!("expr_set invalid state");
Expand Down Expand Up @@ -455,11 +460,11 @@ fn build_common_expr_project_plan(

for id in affected_id {
match expr_set.get(&id) {
Some((expr, _, data_type)) => {
Some((expr, _, data_type, symbol)) => {
// todo: check `nullable`
let field = DFField::new_unqualified(&id, data_type.clone(), true);
fields_set.insert(field.name().to_owned());
project_exprs.push(expr.clone().alias(&id));
project_exprs.push(expr.clone().alias(symbol.as_str()));
}
_ => {
return internal_err!("expr_set invalid state");
Expand Down Expand Up @@ -650,16 +655,16 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
.push(VisitRecord::ExprItem(curr_expr_identifier));
return Ok(TreeNodeRecursion::Continue);
}
let mut desc = Self::expr_identifier(expr);
desc.push_str(&sub_expr_identifier);
let curr_expr_identifier = Self::expr_identifier(expr);
let desc = format!("{curr_expr_identifier}{sub_expr_identifier}");

self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));

let data_type = expr.get_type(&self.input_schema)?;

self.expr_set
.entry(desc)
.or_insert_with(|| (expr.clone(), 0, data_type))
.entry(curr_expr_identifier)
.or_insert_with(|| (expr.clone(), 0, data_type, desc))
.1 += 1;
Ok(TreeNodeRecursion::Continue)
}
Expand Down Expand Up @@ -713,7 +718,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {

// lookup previously visited expression
match self.expr_set.get(curr_id) {
Some((_, counter, _)) => {
Some((_, counter, _, symbol)) => {
// if has a commonly used (a.k.a. 1+ use) expr
if *counter > 1 {
self.affected_id.insert(curr_id.clone());
Expand All @@ -723,7 +728,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
// `projection_push_down` optimizer use "expr name" to eliminate useless
// projections.
Ok(Transformed::new(
col(curr_id).alias(expr_name),
col(symbol).alias(expr_name),
true,
TreeNodeRecursion::Jump,
))
Expand Down Expand Up @@ -1026,18 +1031,24 @@ mod test {
let expr_set_1 = [
(
"c+a".to_string(),
(col("c") + col("a"), 1, DataType::UInt32),
(col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()),
),
(
"b+a".to_string(),
(col("b") + col("a"), 1, DataType::UInt32),
(col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()),
),
]
.into_iter()
.collect();
let expr_set_2 = [
("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)),
(
"c+a".to_string(),
(col("c+a"), 1, DataType::UInt32, "c+a".to_string()),
),
(
"b+a".to_string(),
(col("b+a"), 1, DataType::UInt32, "b+a".to_string()),
),
]
.into_iter()
.collect();
Expand Down Expand Up @@ -1069,23 +1080,43 @@ mod test {
let expr_set_1 = [
(
"test1.c+test1.a".to_string(),
(col("test1.c") + col("test1.a"), 1, DataType::UInt32),
(
col("test1.c") + col("test1.a"),
1,
DataType::UInt32,
"test1.c+test1.a".to_string(),
),
),
(
"test1.b+test1.a".to_string(),
(col("test1.b") + col("test1.a"), 1, DataType::UInt32),
(
col("test1.b") + col("test1.a"),
1,
DataType::UInt32,
"test1.b+test1.a".to_string(),
),
),
]
.into_iter()
.collect();
let expr_set_2 = [
(
"test1.c+test1.a".to_string(),
(col("test1.c+test1.a"), 1, DataType::UInt32),
(
col("test1.c+test1.a"),
1,
DataType::UInt32,
"test1.c+test1.a".to_string(),
),
),
(
"test1.b+test1.a".to_string(),
(col("test1.b+test1.a"), 1, DataType::UInt32),
(
col("test1.b+test1.a"),
1,
DataType::UInt32,
"test1.b+test1.a".to_string(),
),
),
]
.into_iter()
Expand Down

0 comments on commit 049bf09

Please sign in to comment.