Skip to content

Commit

Permalink
Fix subquery alias table definition unparsing for SQLite
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrebnov committed Sep 5, 2024
1 parent 91b1d2b commit ab364d3
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 8 deletions.
29 changes: 29 additions & 0 deletions datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ pub trait Dialect: Send + Sync {
fn date32_cast_dtype(&self) -> sqlparser::ast::DataType {
sqlparser::ast::DataType::Date
}

/// Does the dialect support specifying column aliases as part of alias table definition?
/// (SELECT col1, col2 from my_table) AS my_table_alias(col1_alias, col2_alias)
fn supports_column_alias_in_table_alias(&self) -> bool {
true
}
}

/// `IntervalStyle` to use for unparsing
Expand Down Expand Up @@ -221,6 +227,10 @@ impl Dialect for SqliteDialect {
fn date32_cast_dtype(&self) -> sqlparser::ast::DataType {
sqlparser::ast::DataType::Text
}

fn supports_column_alias_in_table_alias(&self) -> bool {
false
}
}

pub struct CustomDialect {
Expand All @@ -236,6 +246,7 @@ pub struct CustomDialect {
timestamp_cast_dtype: ast::DataType,
timestamp_tz_cast_dtype: ast::DataType,
date32_cast_dtype: sqlparser::ast::DataType,
supports_column_alias_in_table_alias: bool,
}

impl Default for CustomDialect {
Expand All @@ -256,6 +267,7 @@ impl Default for CustomDialect {
TimezoneInfo::WithTimeZone,
),
date32_cast_dtype: sqlparser::ast::DataType::Date,
supports_column_alias_in_table_alias: true,
}
}
}
Expand Down Expand Up @@ -323,6 +335,10 @@ impl Dialect for CustomDialect {
fn date32_cast_dtype(&self) -> sqlparser::ast::DataType {
self.date32_cast_dtype.clone()
}

fn supports_column_alias_in_table_alias(&self) -> bool {
self.supports_column_alias_in_table_alias
}
}

/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
Expand Down Expand Up @@ -352,6 +368,7 @@ pub struct CustomDialectBuilder {
timestamp_cast_dtype: ast::DataType,
timestamp_tz_cast_dtype: ast::DataType,
date32_cast_dtype: ast::DataType,
supports_column_alias_in_table_alias: bool,
}

impl Default for CustomDialectBuilder {
Expand All @@ -378,6 +395,7 @@ impl CustomDialectBuilder {
TimezoneInfo::WithTimeZone,
),
date32_cast_dtype: sqlparser::ast::DataType::Date,
supports_column_alias_in_table_alias: true,
}
}

Expand All @@ -395,6 +413,8 @@ impl CustomDialectBuilder {
timestamp_cast_dtype: self.timestamp_cast_dtype,
timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype,
date32_cast_dtype: self.date32_cast_dtype,
supports_column_alias_in_table_alias: self
.supports_column_alias_in_table_alias,
}
}

Expand Down Expand Up @@ -482,4 +502,13 @@ impl CustomDialectBuilder {
self.date32_cast_dtype = date32_cast_dtype;
self
}

/// Customize the dialect to supports column aliases as part of alias table definition
pub fn with_supports_column_alias_in_table_alias(
mut self,
supports_column_alias_in_table_alias: bool,
) -> Self {
self.supports_column_alias_in_table_alias = supports_column_alias_in_table_alias;
self
}
}
32 changes: 27 additions & 5 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_common::{internal_err, not_impl_err, Column, DataFusionError, Result};
use datafusion_common::{internal_err, not_impl_err, plan_err, Column, DataFusionError, Result};
use datafusion_expr::{
expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, Projection,
SortExpr,
Expand All @@ -30,7 +30,8 @@ use super::{
SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
},
rewrite::{
normalize_union_schema, rewrite_plan_for_sort_on_non_projected_fields,
inject_column_aliases, normalize_union_schema,
rewrite_plan_for_sort_on_non_projected_fields,
subquery_alias_inner_query_and_columns,
},
utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
Expand Down Expand Up @@ -450,10 +451,31 @@ impl Unparser<'_> {
Ok(())
}
LogicalPlan::SubqueryAlias(plan_alias) => {
// Handle bottom-up to allocate relation
let (plan, columns) = subquery_alias_inner_query_and_columns(plan_alias);
let (plan, mut columns) =
subquery_alias_inner_query_and_columns(&plan_alias);

if !columns.is_empty() && !self.dialect.supports_column_alias_in_table_alias() {
// if columns are returned than plan corresponds to a projection
let LogicalPlan::Projection(inner_p) = plan else {
return plan_err!(
"Inner projection for subquery alias is expected"
);
};

// Instead of specifying column aliases as part of the outer table inject them directly into the inner projection
let rewritten_plan = inject_column_aliases(&inner_p, &columns);
columns.clear();

self.select_to_sql_recursively(
&rewritten_plan,
query,
select,
relation,
)?;
} else {
self.select_to_sql_recursively(&plan, query, select, relation)?;
}

self.select_to_sql_recursively(plan, query, select, relation)?;
relation.alias(Some(
self.new_table_alias(plan_alias.alias.table().to_string(), columns),
));
Expand Down
37 changes: 36 additions & 1 deletion datafusion/sql/src/unparser/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode},
Result,
};
use datafusion_expr::tree_node::transform_sort_vec;
use datafusion_expr::{expr::Alias, tree_node::transform_sort_vec};
use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr};
use sqlparser::ast::Ident;

Expand Down Expand Up @@ -257,6 +257,41 @@ pub(super) fn subquery_alias_inner_query_and_columns(
(outer_projections.input.as_ref(), columns)
}

/// Injects column aliases into the projection of a logical plan by wrapping `Expr::Column` expressions
/// with `Expr::Alias` using the provided list of aliases. Non-column expressions are left unchanged.
///
/// Example:
/// - `SELECT col1, col2 FROM table` with aliases `["alias_1", "some_alias_2"]` will be transformed to
/// - `SELECT col1 AS alias_1, col2 AS some_alias_2 FROM table`
pub(super) fn inject_column_aliases(
projection: &datafusion_expr::Projection,
aliases: &Vec<Ident>,
) -> LogicalPlan {
let mut updated_projection = projection.clone();

let new_exprs = projection
.expr
.iter()
.zip(aliases)
.map(|(expr, col_alias)| match expr {
Expr::Column(col) => {
let new_expr = Expr::Alias(Alias {
expr: Box::new(expr.clone()),
relation: col.relation.clone(),
name: col_alias.to_string(),
});

new_expr
}
_ => expr.clone(),
})
.collect::<Vec<_>>();

updated_projection.expr.clone_from(&new_exprs);

LogicalPlan::Projection(updated_projection)
}

fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> {
match logical_plan {
LogicalPlan::Projection(p) => Some(p),
Expand Down
16 changes: 14 additions & 2 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use datafusion_expr::{col, table_scan};
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_sql::unparser::dialect::{
DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect,
MySqlDialect as UnparserMySqlDialect,
MySqlDialect as UnparserMySqlDialect, SqliteDialect,
};
use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser};

Expand Down Expand Up @@ -406,7 +406,19 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
expected: r#"SELECT c.id FROM (SELECT (CAST(j1.j1_id AS BIGINT) + 1) FROM j1 ORDER BY j1.j1_id ASC NULLS LAST LIMIT 1) AS c (id)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
}
},
TestStatementWithDialect {
sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)",
expected: r#"SELECT temp_j.id2 FROM (SELECT j1.j1_id, j1.j1_string FROM j1) AS temp_j (id2, string2)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)",
expected: r#"SELECT `temp_j`.`id2` FROM (SELECT `j1`.`j1_id` AS `id2`, `j1`.`j1_string` AS `string2` FROM `j1`) AS `temp_j`"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(SqliteDialect {}),
},
];

for query in tests {
Expand Down

0 comments on commit ab364d3

Please sign in to comment.