Skip to content

Commit

Permalink
Support unparsing plans with both Aggregation and Window functions (a…
Browse files Browse the repository at this point in the history
…pache#12705)

* Support unparsing plans with both Aggregation and Window functions (#35)

* Fix unparsing for aggregation grouping sets

* Add test for grouping set unparsing

* Update datafusion/sql/src/unparser/utils.rs

Co-authored-by: Jax Liu <[email protected]>

* Update datafusion/sql/src/unparser/utils.rs

Co-authored-by: Jax Liu <[email protected]>

* Update

* More tests

---------

Co-authored-by: Jax Liu <[email protected]>
  • Loading branch information
sgrebnov and goldmedal authored Oct 3, 2024
1 parent cfd861c commit 0b2b4fb
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 46 deletions.
25 changes: 16 additions & 9 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ use super::{
rewrite_plan_for_sort_on_non_projected_fields,
subquery_alias_inner_query_and_columns, TableAliasRewriter,
},
utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
utils::{
find_agg_node_within_select, find_window_nodes_within_select,
unproject_window_exprs,
},
Unparser,
};

Expand Down Expand Up @@ -172,13 +175,17 @@ impl Unparser<'_> {
p: &Projection,
select: &mut SelectBuilder,
) -> Result<()> {
match find_agg_node_within_select(plan, None, true) {
Some(AggVariant::Aggregate(agg)) => {
match (
find_agg_node_within_select(plan, true),
find_window_nodes_within_select(plan, None, true),
) {
(Some(agg), window) => {
let window_option = window.as_deref();
let items = p
.expr
.iter()
.map(|proj_expr| {
let unproj = unproject_agg_exprs(proj_expr, agg)?;
let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?;
self.select_item_to_sql(&unproj)
})
.collect::<Result<Vec<_>>>()?;
Expand All @@ -192,7 +199,7 @@ impl Unparser<'_> {
vec![],
));
}
Some(AggVariant::Window(window)) => {
(None, Some(window)) => {
let items = p
.expr
.iter()
Expand All @@ -204,7 +211,7 @@ impl Unparser<'_> {

select.projection(items);
}
None => {
_ => {
let items = p
.expr
.iter()
Expand Down Expand Up @@ -287,10 +294,10 @@ impl Unparser<'_> {
self.select_to_sql_recursively(p.input.as_ref(), query, select, relation)
}
LogicalPlan::Filter(filter) => {
if let Some(AggVariant::Aggregate(agg)) =
find_agg_node_within_select(plan, None, select.already_projected())
if let Some(agg) =
find_agg_node_within_select(plan, select.already_projected())
{
let unprojected = unproject_agg_exprs(&filter.predicate, agg)?;
let unprojected = unproject_agg_exprs(&filter.predicate, agg, None)?;
let filter_expr = self.expr_to_sql(&unprojected)?;
select.having(Some(filter_expr));
} else {
Expand Down
134 changes: 97 additions & 37 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,58 +18,83 @@
use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
Result,
Column, DataFusionError, Result,
};
use datafusion_expr::{
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window,
};
use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window};

/// One of the possible aggregation plans which can be found within a single select query.
pub(crate) enum AggVariant<'a> {
Aggregate(&'a Aggregate),
Window(Vec<&'a Window>),
/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
/// If an Aggregate or node is not found prior to this or at all before reaching the end
/// of the tree, None is returned.
pub(crate) fn find_agg_node_within_select(
plan: &LogicalPlan,
already_projected: bool,
) -> Option<&Aggregate> {
// Note that none of the nodes that have a corresponding node can have more
// than 1 input node. E.g. Projection / Filter always have 1 input node.
let input = plan.inputs();
let input = if input.len() > 1 {
return None;
} else {
input.first()?
};
// Agg nodes explicitly return immediately with a single node
if let LogicalPlan::Aggregate(agg) = input {
Some(agg)
} else if let LogicalPlan::TableScan(_) = input {
None
} else if let LogicalPlan::Projection(_) = input {
if already_projected {
None
} else {
find_agg_node_within_select(input, true)
}
} else {
find_agg_node_within_select(input, already_projected)
}
}

/// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists
/// Recursively searches children of [LogicalPlan] to find Window nodes if exist
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
/// If an Aggregate or window node is not found prior to this or at all before reaching the end
/// of the tree, None is returned. It is assumed that a Window and Aggregate node cannot both
/// be found in a single select query.
pub(crate) fn find_agg_node_within_select<'a>(
/// If Window node is not found prior to this or at all before reaching the end
/// of the tree, None is returned.
pub(crate) fn find_window_nodes_within_select<'a>(
plan: &'a LogicalPlan,
mut prev_windows: Option<AggVariant<'a>>,
mut prev_windows: Option<Vec<&'a Window>>,
already_projected: bool,
) -> Option<AggVariant<'a>> {
// Note that none of the nodes that have a corresponding agg node can have more
) -> Option<Vec<&'a Window>> {
// Note that none of the nodes that have a corresponding node can have more
// than 1 input node. E.g. Projection / Filter always have 1 input node.
let input = plan.inputs();
let input = if input.len() > 1 {
return None;
return prev_windows;
} else {
input.first()?
};

// Agg nodes explicitly return immediately with a single node
// Window nodes accumulate in a vec until encountering a TableScan or 2nd projection
match input {
LogicalPlan::Aggregate(agg) => Some(AggVariant::Aggregate(agg)),
LogicalPlan::Window(window) => {
prev_windows = match &mut prev_windows {
Some(AggVariant::Window(windows)) => {
Some(windows) => {
windows.push(window);
prev_windows
}
_ => Some(AggVariant::Window(vec![window])),
_ => Some(vec![window]),
};
find_agg_node_within_select(input, prev_windows, already_projected)
find_window_nodes_within_select(input, prev_windows, already_projected)
}
LogicalPlan::Projection(_) => {
if already_projected {
prev_windows
} else {
find_agg_node_within_select(input, prev_windows, true)
find_window_nodes_within_select(input, prev_windows, true)
}
}
LogicalPlan::TableScan(_) => prev_windows,
_ => find_agg_node_within_select(input, prev_windows, already_projected),
_ => find_window_nodes_within_select(input, prev_windows, already_projected),
}
}

Expand All @@ -78,22 +103,34 @@ pub(crate) fn find_agg_node_within_select<'a>(
///
/// For example, if expr contains the column expr "COUNT(*)" it will be transformed
/// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr> {
pub(crate) fn unproject_agg_exprs(
expr: &Expr,
agg: &Aggregate,
windows: Option<&[&Window]>,
) -> Result<Expr> {
expr.clone()
.transform(|sub_expr| {
if let Expr::Column(c) = sub_expr {
// find the column in the agg schema
if let Ok(n) = agg.schema.index_of_column(&c) {
let unprojected_expr = agg
.group_expr
.iter()
.chain(agg.aggr_expr.iter())
.nth(n)
.unwrap();
if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
Ok(Transformed::yes(unprojected_expr.clone()))
} else if let Some(mut unprojected_expr) =
windows.and_then(|w| find_window_expr(w, &c.name).cloned())
{
if let Expr::WindowFunction(func) = &mut unprojected_expr {
// Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
func.args.iter_mut().try_for_each(|arg| {
if let Expr::Column(c) = arg {
if let Some(expr) = find_agg_expr(agg, c)? {
*arg = expr.clone();
}
}
Ok::<(), DataFusionError>(())
})?;
}
Ok(Transformed::yes(unprojected_expr))
} else {
internal_err!(
"Tried to unproject agg expr not found in provided Aggregate!"
"Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
)
}
} else {
Expand All @@ -112,11 +149,7 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
expr.clone()
.transform(|sub_expr| {
if let Expr::Column(c) = sub_expr {
if let Some(unproj) = windows
.iter()
.flat_map(|w| w.window_expr.iter())
.find(|window_expr| window_expr.schema_name().to_string() == c.name)
{
if let Some(unproj) = find_window_expr(windows, &c.name) {
Ok(Transformed::yes(unproj.clone()))
} else {
Ok(Transformed::no(Expr::Column(c)))
Expand All @@ -127,3 +160,30 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
})
.map(|e| e.data)
}

fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a Expr>> {
if let Ok(index) = agg.schema.index_of_column(column) {
if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) {
// For grouping set expr, we must operate by expression list from the grouping set
let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?;
Ok(grouping_expr
.into_iter()
.chain(agg.aggr_expr.iter())
.nth(index))
} else {
Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index))
}
} else {
Ok(None)
}
}

fn find_window_expr<'a>(
windows: &'a [&'a Window],
column_name: &'a str,
) -> Option<&'a Expr> {
windows
.iter()
.flat_map(|w| w.window_expr.iter())
.find(|expr| expr.schema_name().to_string() == column_name)
}
21 changes: 21 additions & 0 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,26 @@ fn roundtrip_statement() -> Result<()> {
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col",
r#"SELECT id, first_name,
SUM(id) AS total_sum,
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total
FROM person JOIN orders ON person.id = orders.customer_id GROUP BY id, first_name"#,
r#"SELECT id, first_name,
SUM(id) AS total_sum,
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total
FROM (SELECT id, first_name from person) person JOIN (SELECT customer_id FROM orders) orders ON person.id = orders.customer_id GROUP BY id, first_name"#,
r#"SELECT id, first_name, last_name, customer_id, SUM(id) AS total_sum
FROM person
JOIN orders ON person.id = orders.customer_id
GROUP BY ROLLUP(id, first_name, last_name, customer_id)"#,
r#"SELECT id, first_name, last_name,
SUM(id) AS total_sum,
COUNT(*) AS total_count,
SUM(id) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total
FROM person
GROUP BY GROUPING SETS ((id, first_name, last_name), (first_name, last_name), (last_name))"#,
];

// For each test sql string, we transform as follows:
Expand All @@ -164,6 +184,7 @@ fn roundtrip_statement() -> Result<()> {
let state = MockSessionState::default()
.with_aggregate_function(sum_udaf())
.with_aggregate_function(count_udaf())
.with_aggregate_function(max_udaf())
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
let context = MockContextProvider { state };
let sql_to_rel = SqlToRel::new(&context);
Expand Down

0 comments on commit 0b2b4fb

Please sign in to comment.