From ab364d3ab2f7fa1d275c82b468869f5573432f86 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Thu, 5 Sep 2024 07:04:04 +0300 Subject: [PATCH] Fix subquery alias table definition unparsing for SQLite --- datafusion/sql/src/unparser/dialect.rs | 29 ++++++++++++++++++ datafusion/sql/src/unparser/plan.rs | 32 +++++++++++++++++--- datafusion/sql/src/unparser/rewrite.rs | 37 ++++++++++++++++++++++- datafusion/sql/tests/cases/plan_to_sql.rs | 16 ++++++++-- 4 files changed, 106 insertions(+), 8 deletions(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 2a8e61add1d0..d8a4fb254264 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -102,6 +102,12 @@ pub trait Dialect: Send + Sync { fn date32_cast_dtype(&self) -> sqlparser::ast::DataType { sqlparser::ast::DataType::Date } + + /// Does the dialect support specifying column aliases as part of alias table definition? + /// (SELECT col1, col2 from my_table) AS my_table_alias(col1_alias, col2_alias) + fn supports_column_alias_in_table_alias(&self) -> bool { + true + } } /// `IntervalStyle` to use for unparsing @@ -221,6 +227,10 @@ impl Dialect for SqliteDialect { fn date32_cast_dtype(&self) -> sqlparser::ast::DataType { sqlparser::ast::DataType::Text } + + fn supports_column_alias_in_table_alias(&self) -> bool { + false + } } pub struct CustomDialect { @@ -236,6 +246,7 @@ pub struct CustomDialect { timestamp_cast_dtype: ast::DataType, timestamp_tz_cast_dtype: ast::DataType, date32_cast_dtype: sqlparser::ast::DataType, + supports_column_alias_in_table_alias: bool, } impl Default for CustomDialect { @@ -256,6 +267,7 @@ impl Default for CustomDialect { TimezoneInfo::WithTimeZone, ), date32_cast_dtype: sqlparser::ast::DataType::Date, + supports_column_alias_in_table_alias: true, } } } @@ -323,6 +335,10 @@ impl Dialect for CustomDialect { fn date32_cast_dtype(&self) -> sqlparser::ast::DataType { self.date32_cast_dtype.clone() } + + fn supports_column_alias_in_table_alias(&self) -> bool { + self.supports_column_alias_in_table_alias + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -352,6 +368,7 @@ pub struct CustomDialectBuilder { timestamp_cast_dtype: ast::DataType, timestamp_tz_cast_dtype: ast::DataType, date32_cast_dtype: ast::DataType, + supports_column_alias_in_table_alias: bool, } impl Default for CustomDialectBuilder { @@ -378,6 +395,7 @@ impl CustomDialectBuilder { TimezoneInfo::WithTimeZone, ), date32_cast_dtype: sqlparser::ast::DataType::Date, + supports_column_alias_in_table_alias: true, } } @@ -395,6 +413,8 @@ impl CustomDialectBuilder { timestamp_cast_dtype: self.timestamp_cast_dtype, timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype, date32_cast_dtype: self.date32_cast_dtype, + supports_column_alias_in_table_alias: self + .supports_column_alias_in_table_alias, } } @@ -482,4 +502,13 @@ impl CustomDialectBuilder { self.date32_cast_dtype = date32_cast_dtype; self } + + /// Customize the dialect to supports column aliases as part of alias table definition + pub fn with_supports_column_alias_in_table_alias( + mut self, + supports_column_alias_in_table_alias: bool, + ) -> Self { + self.supports_column_alias_in_table_alias = supports_column_alias_in_table_alias; + self + } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 509c5dd52cd4..d736cae86c6e 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{internal_err, not_impl_err, Column, DataFusionError, Result}; +use datafusion_common::{internal_err, not_impl_err, plan_err, Column, DataFusionError, Result}; use datafusion_expr::{ expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, Projection, SortExpr, @@ -30,7 +30,8 @@ use super::{ SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder, }, rewrite::{ - normalize_union_schema, rewrite_plan_for_sort_on_non_projected_fields, + inject_column_aliases, normalize_union_schema, + rewrite_plan_for_sort_on_non_projected_fields, subquery_alias_inner_query_and_columns, }, utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant}, @@ -450,10 +451,31 @@ impl Unparser<'_> { Ok(()) } LogicalPlan::SubqueryAlias(plan_alias) => { - // Handle bottom-up to allocate relation - let (plan, columns) = subquery_alias_inner_query_and_columns(plan_alias); + let (plan, mut columns) = + subquery_alias_inner_query_and_columns(&plan_alias); + + if !columns.is_empty() && !self.dialect.supports_column_alias_in_table_alias() { + // if columns are returned than plan corresponds to a projection + let LogicalPlan::Projection(inner_p) = plan else { + return plan_err!( + "Inner projection for subquery alias is expected" + ); + }; + + // Instead of specifying column aliases as part of the outer table inject them directly into the inner projection + let rewritten_plan = inject_column_aliases(&inner_p, &columns); + columns.clear(); + + self.select_to_sql_recursively( + &rewritten_plan, + query, + select, + relation, + )?; + } else { + self.select_to_sql_recursively(&plan, query, select, relation)?; + } - self.select_to_sql_recursively(plan, query, select, relation)?; relation.alias(Some( self.new_table_alias(plan_alias.alias.table().to_string(), columns), )); diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 2529385849e0..632a729af259 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -24,7 +24,7 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode}, Result, }; -use datafusion_expr::tree_node::transform_sort_vec; +use datafusion_expr::{expr::Alias, tree_node::transform_sort_vec}; use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; @@ -257,6 +257,41 @@ pub(super) fn subquery_alias_inner_query_and_columns( (outer_projections.input.as_ref(), columns) } +/// Injects column aliases into the projection of a logical plan by wrapping `Expr::Column` expressions +/// with `Expr::Alias` using the provided list of aliases. Non-column expressions are left unchanged. +/// +/// Example: +/// - `SELECT col1, col2 FROM table` with aliases `["alias_1", "some_alias_2"]` will be transformed to +/// - `SELECT col1 AS alias_1, col2 AS some_alias_2 FROM table` +pub(super) fn inject_column_aliases( + projection: &datafusion_expr::Projection, + aliases: &Vec, +) -> LogicalPlan { + let mut updated_projection = projection.clone(); + + let new_exprs = projection + .expr + .iter() + .zip(aliases) + .map(|(expr, col_alias)| match expr { + Expr::Column(col) => { + let new_expr = Expr::Alias(Alias { + expr: Box::new(expr.clone()), + relation: col.relation.clone(), + name: col_alias.to_string(), + }); + + new_expr + } + _ => expr.clone(), + }) + .collect::>(); + + updated_projection.expr.clone_from(&new_exprs); + + LogicalPlan::Projection(updated_projection) +} + fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> { match logical_plan { LogicalPlan::Projection(p) => Some(p), diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index cdc7bef06afd..78bab92acac5 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -25,7 +25,7 @@ use datafusion_expr::{col, table_scan}; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect, - MySqlDialect as UnparserMySqlDialect, + MySqlDialect as UnparserMySqlDialect, SqliteDialect, }; use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; @@ -406,7 +406,19 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT c.id FROM (SELECT (CAST(j1.j1_id AS BIGINT) + 1) FROM j1 ORDER BY j1.j1_id ASC NULLS LAST LIMIT 1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - } + }, + TestStatementWithDialect { + sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", + expected: r#"SELECT temp_j.id2 FROM (SELECT j1.j1_id, j1.j1_string FROM j1) AS temp_j (id2, string2)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", + expected: r#"SELECT `temp_j`.`id2` FROM (SELECT `j1`.`j1_id` AS `id2`, `j1`.`j1_string` AS `string2` FROM `j1`) AS `temp_j`"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(SqliteDialect {}), + }, ]; for query in tests {