Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement prettier SQL unparsing (more human readable) #11186

Merged
merged 17 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 109 additions & 2 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use arrow_array::{Date32Array, Date64Array, PrimitiveArray};
use arrow_schema::DataType;
use sqlparser::ast::Value::SingleQuotedString;
use sqlparser::ast::{
self, Expr as AstExpr, Function, FunctionArg, Ident, Interval, TimezoneInfo,
UnaryOperator,
self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval,
TimezoneInfo, UnaryOperator,
};

use datafusion_common::{
Expand Down Expand Up @@ -101,7 +101,16 @@ pub fn expr_to_unparsed(expr: &Expr) -> Result<Unparsed> {
unparser.expr_to_unparsed(expr)
}

const LOWEST: &BinaryOperator = &BinaryOperator::BitwiseOr;

impl Unparser<'_> {
/// Try to unparse the expression into a more human-readable format
/// by removing unnecessary parentheses.
pub fn pretty_expr_to_sql(&self, expr: &Expr) -> Result<ast::Expr> {
let root_expr = self.expr_to_sql(expr)?;
Ok(self.pretty(root_expr, LOWEST, LOWEST))
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't have an extra method here and would combine it with expr_to_sql. The ast::Expr it produces is logically the same as the input one, just with unnecessary nesting removed. In fact, you could even think about this as serving the same purpose as an optimizer rewrite pass for LogicalPlan - it should produce logically the same thing as the input, just more efficient.


pub fn expr_to_sql(&self, expr: &Expr) -> Result<ast::Expr> {
match expr {
Expr::InList(InList {
Expand Down Expand Up @@ -603,6 +612,60 @@ impl Unparser<'_> {
}
}

/// Given an expression of the form `((a + b) * (c * d))`,
/// the parenthesing is redundant if the precedence of the nested expression is already higher
/// than the surrounding operators' precedence. The above expression would become
/// `(a + b) * c * d`.
///
/// Also note that when fetching the precedence of a nested expression, we ignore other nested
/// expressions, so precedence of expr `(a * (b + c))` equals `*` and not `+`.
fn pretty(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we just have a single expr_to_sql method, then it might make sense to rename this method as well to something like remove_unnecessary_nesting to more accurately describe what it does. Then if we added more "rewrite passes" to make it prettier or achieve some other functional goal, then they would just be added as separate functions after this one.

&self,
expr: ast::Expr,
left_op: &BinaryOperator,
right_op: &BinaryOperator,
) -> ast::Expr {
match expr {
ast::Expr::Nested(nested) => {
let surrounding_precedence = self
.sql_op_precedence(left_op)
.max(self.sql_op_precedence(right_op));

let inner_precedence = self.inner_precedence(&nested);

let not_associative =
matches!(left_op, BinaryOperator::Minus | BinaryOperator::Divide);

if inner_precedence == surrounding_precedence && not_associative {
ast::Expr::Nested(Box::new(self.pretty(*nested, LOWEST, LOWEST)))
} else if inner_precedence >= surrounding_precedence {
self.pretty(*nested, left_op, right_op)
} else {
ast::Expr::Nested(Box::new(self.pretty(*nested, LOWEST, LOWEST)))
}
}
ast::Expr::BinaryOp { left, op, right } => ast::Expr::BinaryOp {
left: Box::new(self.pretty(*left, left_op, &op)),
right: Box::new(self.pretty(*right, &op, right_op)),
op,
},
_ => expr,
}
}

fn inner_precedence(&self, expr: &ast::Expr) -> u8 {
match expr {
ast::Expr::Nested(_) | ast::Expr::Identifier(_) | ast::Expr::Value(_) => 100,
ast::Expr::BinaryOp { op, .. } => self.sql_op_precedence(op),
// closest precedence we currently have to Between is PGLikeMatch
// (https://www.postgresql.org/docs/7.2/sql-precedence.html)
ast::Expr::Between { .. } => {
self.sql_op_precedence(&ast::BinaryOperator::PGLikeMatch)
}
_ => 0,
}
}

pub(super) fn between_op_to_sql(
&self,
expr: ast::Expr,
Expand All @@ -618,6 +681,50 @@ impl Unparser<'_> {
}
}

// TODO: operator precedence should be defined in sqlparser
MohamedAbdeen21 marked this conversation as resolved.
Show resolved Hide resolved
// to avoid the need for sql_to_op and sql_op_precedence
fn sql_op_precedence(&self, op: &BinaryOperator) -> u8 {
match self.sql_to_op(op) {
Ok(op) => op.precedence(),
Err(_) => 0,
}
}

fn sql_to_op(&self, op: &BinaryOperator) -> Result<Operator> {
match op {
ast::BinaryOperator::Eq => Ok(Operator::Eq),
ast::BinaryOperator::NotEq => Ok(Operator::NotEq),
ast::BinaryOperator::Lt => Ok(Operator::Lt),
ast::BinaryOperator::LtEq => Ok(Operator::LtEq),
ast::BinaryOperator::Gt => Ok(Operator::Gt),
ast::BinaryOperator::GtEq => Ok(Operator::GtEq),
ast::BinaryOperator::Plus => Ok(Operator::Plus),
ast::BinaryOperator::Minus => Ok(Operator::Minus),
ast::BinaryOperator::Multiply => Ok(Operator::Multiply),
ast::BinaryOperator::Divide => Ok(Operator::Divide),
ast::BinaryOperator::Modulo => Ok(Operator::Modulo),
ast::BinaryOperator::And => Ok(Operator::And),
ast::BinaryOperator::Or => Ok(Operator::Or),
ast::BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch),
ast::BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch),
ast::BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch),
ast::BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch),
ast::BinaryOperator::PGILikeMatch => Ok(Operator::ILikeMatch),
ast::BinaryOperator::PGNotLikeMatch => Ok(Operator::NotLikeMatch),
ast::BinaryOperator::PGLikeMatch => Ok(Operator::LikeMatch),
ast::BinaryOperator::PGNotILikeMatch => Ok(Operator::NotILikeMatch),
ast::BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd),
ast::BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr),
ast::BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor),
ast::BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight),
ast::BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft),
ast::BinaryOperator::StringConcat => Ok(Operator::StringConcat),
ast::BinaryOperator::AtArrow => Ok(Operator::AtArrow),
ast::BinaryOperator::ArrowAt => Ok(Operator::ArrowAt),
_ => not_impl_err!("unsupported operation: {op:?}"),
MohamedAbdeen21 marked this conversation as resolved.
Show resolved Hide resolved
}
}

fn op_to_sql(&self, op: &Operator) -> Result<ast::BinaryOperator> {
match op {
Operator::Eq => Ok(ast::BinaryOperator::Eq),
Expand Down
74 changes: 62 additions & 12 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,26 +104,26 @@ fn roundtrip_statement() -> Result<()> {
"select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id",
"select id, count(*), first_name from person group by first_name, id",
"select id, sum(age), first_name from person group by first_name, id",
"select id, count(*), first_name
from person
"select id, count(*), first_name
from person
where id!=3 and first_name=='test'
group by first_name, id
group by first_name, id
having count(*)>5 and count(*)<10
order by count(*)",
r#"select id, count("First Name") as count_first_name, "Last Name"
r#"select id, count("First Name") as count_first_name, "Last Name"
from person_quoted_cols
where id!=3 and "First Name"=='test'
group by "Last Name", id
group by "Last Name", id
having count_first_name>5 and count_first_name<10
order by count_first_name, "Last Name""#,
r#"select p.id, count("First Name") as count_first_name,
"Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person)
"Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person)
from (select id, "First Name", "Last Name" from person_quoted_cols) qp
inner join (select * from person) p
on p.id = qp.id
where p.id!=3 and "First Name"=='test' and qp.id in
where p.id!=3 and "First Name"=='test' and qp.id in
(select id from (select id, count(*) from person group by id having count(*) > 0))
group by "Last Name", p.id
group by "Last Name", p.id
having count_first_name>5 and count_first_name<10
order by count_first_name, "Last Name""#,
r#"SELECT j1_string as string FROM j1
Expand All @@ -134,12 +134,12 @@ fn roundtrip_statement() -> Result<()> {
SELECT j2_string as string FROM j2
ORDER BY string DESC
LIMIT 10"#,
"SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
"SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
first_name from person",
r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#,
"SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person",
"SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person",
];

// For each test sql string, we transform as follows:
Expand Down Expand Up @@ -314,3 +314,53 @@ fn test_table_references_in_plan_to_sql() {
"SELECT \"table\".id, \"table\".\"value\" FROM \"table\"",
);
}

#[test]
fn test_pretty_roundtrip() -> Result<()> {
let schema = Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("age", DataType::Utf8, false),
]);

let df_schema = DFSchema::try_from(schema)?;

let context = MockContextProvider::default();
let sql_to_rel = SqlToRel::new(&context);

let unparser = Unparser::default();

let sql_to_pretty_unparse = vec![
("((id < 5) OR (age = 8))", "id < 5 OR age = 8"),
("((id + 5) * (age * 8))", "(id + 5) * age * 8"),
("(3 + (5 * 6) * 3)", "3 + 5 * 6 * 3"),
("((3 * (5 + 6)) * 3)", "3 * (5 + 6) * 3"),
("((3 AND (5 OR 6)) * 3)", "(3 AND (5 OR 6)) * 3"),
("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"),
("((3 + (5 + 6)) + 3)", "3 + 5 + 6 + 3"),
MohamedAbdeen21 marked this conversation as resolved.
Show resolved Hide resolved
(
"((id > 10) || (age BETWEEN 10 AND 20))",
"id > 10 || age BETWEEN 10 AND 20",
),
(
"((id > 10) * (age BETWEEN 10 AND 20))",
"(id > 10) * (age BETWEEN 10 AND 20)",
),
("id - (age - 8)", "id - (age - 8)"),
("((id - age) - 8)", "id - age - 8"),
("(id OR (age - 8))", "id OR age - 8"),
("(id / (age - 8))", "id / (age - 8)"),
("((id / age) * 8)", "id / age * 8"),
];

for (sql, pretty) in sql_to_pretty_unparse.iter() {
let sql_expr = Parser::new(&GenericDialect {})
.try_with_sql(sql)?
.parse_expr()?;
let expr =
sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?;
let round_trip_sql = unparser.pretty_expr_to_sql(&expr)?.to_string();
assert_eq!(pretty.to_string(), round_trip_sql);
alamb marked this conversation as resolved.
Show resolved Hide resolved
}

Ok(())
}
Loading