From 0b2b4fb2763316e4c946e7bd12e06afc439f1480 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Thu, 3 Oct 2024 07:52:40 -0700 Subject: [PATCH] Support unparsing plans with both Aggregation and Window functions (#12705) * Support unparsing plans with both Aggregation and Window functions (#35) * Fix unparsing for aggregation grouping sets * Add test for grouping set unparsing * Update datafusion/sql/src/unparser/utils.rs Co-authored-by: Jax Liu * Update datafusion/sql/src/unparser/utils.rs Co-authored-by: Jax Liu * Update * More tests --------- Co-authored-by: Jax Liu --- datafusion/sql/src/unparser/plan.rs | 25 ++-- datafusion/sql/src/unparser/utils.rs | 134 ++++++++++++++++------ datafusion/sql/tests/cases/plan_to_sql.rs | 21 ++++ 3 files changed, 134 insertions(+), 46 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index a76e26aa7d98..c4fcbb2d6458 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -38,7 +38,10 @@ use super::{ rewrite_plan_for_sort_on_non_projected_fields, subquery_alias_inner_query_and_columns, TableAliasRewriter, }, - utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant}, + utils::{ + find_agg_node_within_select, find_window_nodes_within_select, + unproject_window_exprs, + }, Unparser, }; @@ -172,13 +175,17 @@ impl Unparser<'_> { p: &Projection, select: &mut SelectBuilder, ) -> Result<()> { - match find_agg_node_within_select(plan, None, true) { - Some(AggVariant::Aggregate(agg)) => { + match ( + find_agg_node_within_select(plan, true), + find_window_nodes_within_select(plan, None, true), + ) { + (Some(agg), window) => { + let window_option = window.as_deref(); let items = p .expr .iter() .map(|proj_expr| { - let unproj = unproject_agg_exprs(proj_expr, agg)?; + let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?; self.select_item_to_sql(&unproj) }) .collect::>>()?; @@ -192,7 +199,7 @@ impl Unparser<'_> { vec![], )); } - Some(AggVariant::Window(window)) => { + (None, Some(window)) => { let items = p .expr .iter() @@ -204,7 +211,7 @@ impl Unparser<'_> { select.projection(items); } - None => { + _ => { let items = p .expr .iter() @@ -287,10 +294,10 @@ impl Unparser<'_> { self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) } LogicalPlan::Filter(filter) => { - if let Some(AggVariant::Aggregate(agg)) = - find_agg_node_within_select(plan, None, select.already_projected()) + if let Some(agg) = + find_agg_node_within_select(plan, select.already_projected()) { - let unprojected = unproject_agg_exprs(&filter.predicate, agg)?; + let unprojected = unproject_agg_exprs(&filter.predicate, agg, None)?; let filter_expr = self.expr_to_sql(&unprojected)?; select.having(Some(filter_expr)); } else { diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index c1b3fe18f7e7..0059aba25738 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -18,58 +18,83 @@ use datafusion_common::{ internal_err, tree_node::{Transformed, TreeNode}, - Result, + Column, DataFusionError, Result, +}; +use datafusion_expr::{ + utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window, }; -use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window}; -/// One of the possible aggregation plans which can be found within a single select query. -pub(crate) enum AggVariant<'a> { - Aggregate(&'a Aggregate), - Window(Vec<&'a Window>), +/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists +/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). +/// If an Aggregate or node is not found prior to this or at all before reaching the end +/// of the tree, None is returned. +pub(crate) fn find_agg_node_within_select( + plan: &LogicalPlan, + already_projected: bool, +) -> Option<&Aggregate> { + // Note that none of the nodes that have a corresponding node can have more + // than 1 input node. E.g. Projection / Filter always have 1 input node. + let input = plan.inputs(); + let input = if input.len() > 1 { + return None; + } else { + input.first()? + }; + // Agg nodes explicitly return immediately with a single node + if let LogicalPlan::Aggregate(agg) = input { + Some(agg) + } else if let LogicalPlan::TableScan(_) = input { + None + } else if let LogicalPlan::Projection(_) = input { + if already_projected { + None + } else { + find_agg_node_within_select(input, true) + } + } else { + find_agg_node_within_select(input, already_projected) + } } -/// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists +/// Recursively searches children of [LogicalPlan] to find Window nodes if exist /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). -/// If an Aggregate or window node is not found prior to this or at all before reaching the end -/// of the tree, None is returned. It is assumed that a Window and Aggregate node cannot both -/// be found in a single select query. -pub(crate) fn find_agg_node_within_select<'a>( +/// If Window node is not found prior to this or at all before reaching the end +/// of the tree, None is returned. +pub(crate) fn find_window_nodes_within_select<'a>( plan: &'a LogicalPlan, - mut prev_windows: Option>, + mut prev_windows: Option>, already_projected: bool, -) -> Option> { - // Note that none of the nodes that have a corresponding agg node can have more +) -> Option> { + // Note that none of the nodes that have a corresponding node can have more // than 1 input node. E.g. Projection / Filter always have 1 input node. let input = plan.inputs(); let input = if input.len() > 1 { - return None; + return prev_windows; } else { input.first()? }; - // Agg nodes explicitly return immediately with a single node // Window nodes accumulate in a vec until encountering a TableScan or 2nd projection match input { - LogicalPlan::Aggregate(agg) => Some(AggVariant::Aggregate(agg)), LogicalPlan::Window(window) => { prev_windows = match &mut prev_windows { - Some(AggVariant::Window(windows)) => { + Some(windows) => { windows.push(window); prev_windows } - _ => Some(AggVariant::Window(vec![window])), + _ => Some(vec![window]), }; - find_agg_node_within_select(input, prev_windows, already_projected) + find_window_nodes_within_select(input, prev_windows, already_projected) } LogicalPlan::Projection(_) => { if already_projected { prev_windows } else { - find_agg_node_within_select(input, prev_windows, true) + find_window_nodes_within_select(input, prev_windows, true) } } LogicalPlan::TableScan(_) => prev_windows, - _ => find_agg_node_within_select(input, prev_windows, already_projected), + _ => find_window_nodes_within_select(input, prev_windows, already_projected), } } @@ -78,22 +103,34 @@ pub(crate) fn find_agg_node_within_select<'a>( /// /// For example, if expr contains the column expr "COUNT(*)" it will be transformed /// into an actual aggregate expression COUNT(*) as identified in the aggregate node. -pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result { +pub(crate) fn unproject_agg_exprs( + expr: &Expr, + agg: &Aggregate, + windows: Option<&[&Window]>, +) -> Result { expr.clone() .transform(|sub_expr| { if let Expr::Column(c) = sub_expr { - // find the column in the agg schema - if let Ok(n) = agg.schema.index_of_column(&c) { - let unprojected_expr = agg - .group_expr - .iter() - .chain(agg.aggr_expr.iter()) - .nth(n) - .unwrap(); + if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { Ok(Transformed::yes(unprojected_expr.clone())) + } else if let Some(mut unprojected_expr) = + windows.and_then(|w| find_window_expr(w, &c.name).cloned()) + { + if let Expr::WindowFunction(func) = &mut unprojected_expr { + // Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected + func.args.iter_mut().try_for_each(|arg| { + if let Expr::Column(c) = arg { + if let Some(expr) = find_agg_expr(agg, c)? { + *arg = expr.clone(); + } + } + Ok::<(), DataFusionError>(()) + })?; + } + Ok(Transformed::yes(unprojected_expr)) } else { internal_err!( - "Tried to unproject agg expr not found in provided Aggregate!" + "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name ) } } else { @@ -112,11 +149,7 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result expr.clone() .transform(|sub_expr| { if let Expr::Column(c) = sub_expr { - if let Some(unproj) = windows - .iter() - .flat_map(|w| w.window_expr.iter()) - .find(|window_expr| window_expr.schema_name().to_string() == c.name) - { + if let Some(unproj) = find_window_expr(windows, &c.name) { Ok(Transformed::yes(unproj.clone())) } else { Ok(Transformed::no(Expr::Column(c))) @@ -127,3 +160,30 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result }) .map(|e| e.data) } + +fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result> { + if let Ok(index) = agg.schema.index_of_column(column) { + if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) { + // For grouping set expr, we must operate by expression list from the grouping set + let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?; + Ok(grouping_expr + .into_iter() + .chain(agg.aggr_expr.iter()) + .nth(index)) + } else { + Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index)) + } + } else { + Ok(None) + } +} + +fn find_window_expr<'a>( + windows: &'a [&'a Window], + column_name: &'a str, +) -> Option<&'a Expr> { + windows + .iter() + .flat_map(|w| w.window_expr.iter()) + .find(|expr| expr.schema_name().to_string() == column_name) +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 49f4720ed137..903d4e28520b 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -149,6 +149,26 @@ fn roundtrip_statement() -> Result<()> { "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3", "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4", "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col", + r#"SELECT id, first_name, + SUM(id) AS total_sum, + SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total + FROM person JOIN orders ON person.id = orders.customer_id GROUP BY id, first_name"#, + r#"SELECT id, first_name, + SUM(id) AS total_sum, + SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total + FROM (SELECT id, first_name from person) person JOIN (SELECT customer_id FROM orders) orders ON person.id = orders.customer_id GROUP BY id, first_name"#, + r#"SELECT id, first_name, last_name, customer_id, SUM(id) AS total_sum + FROM person + JOIN orders ON person.id = orders.customer_id + GROUP BY ROLLUP(id, first_name, last_name, customer_id)"#, + r#"SELECT id, first_name, last_name, + SUM(id) AS total_sum, + COUNT(*) AS total_count, + SUM(id) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total + FROM person + GROUP BY GROUPING SETS ((id, first_name, last_name), (first_name, last_name), (last_name))"#, ]; // For each test sql string, we transform as follows: @@ -164,6 +184,7 @@ fn roundtrip_statement() -> Result<()> { let state = MockSessionState::default() .with_aggregate_function(sum_udaf()) .with_aggregate_function(count_udaf()) + .with_aggregate_function(max_udaf()) .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new(&context);