Skip to content

Commit

Permalink
Support parsing SQL strings to Exprs (#10995)
Browse files Browse the repository at this point in the history
* Draft parse_sql

* Allow string pass

* Complete sql to expr support

* Add examples

* Add unit tests

* Fix format

* Remove async for trivial operation and add parquet demo

* Fix comments

* fix comments

* fix comments

* Fix doc link
  • Loading branch information
xinlifoobar authored Jun 23, 2024
1 parent 08e4e6a commit 6f10dbc
Show file tree
Hide file tree
Showing 9 changed files with 383 additions and 14 deletions.
1 change: 1 addition & 0 deletions datafusion-examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ cargo run --example csv_sql
- [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file
- [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files
- [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution
- [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into Datafusion `Expr`.
- [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from Datafusion `Expr` and `LogicalPlan`
- [`pruning.rs`](examples/pruning.rs): Use pruning to rule out files based on statistics
- [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3
Expand Down
157 changes: 157 additions & 0 deletions datafusion-examples/examples/parse_sql_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::{
assert_batches_eq,
error::Result,
prelude::{ParquetReadOptions, SessionContext},
};
use datafusion_common::DFSchema;
use datafusion_expr::{col, lit};
use datafusion_sql::unparser::Unparser;

/// Entry point of the example: programmatic parsing of SQL expression text
/// into DataFusion logical expressions, using either the
/// [`SessionContext::parse_sql_expr`] API or the [`DataFrame::parse_sql_expr`] API.
///
/// The demos are executed in this order:
///
/// 1. [`simple_session_context_parse_sql_expr_demo`]: parse a simple SQL text
///    into a logical expression against a schema handed to [`SessionContext`].
///
/// 2. [`simple_dataframe_parse_sql_expr_demo`]: parse a simple SQL text into a
///    logical expression against the schema of a [`DataFrame`].
///
/// 3. [`query_parquet_demo`]: query a parquet file with expressions parsed
///    from SQL text by a [`DataFrame`].
///
/// 4. [`round_trip_parse_sql_expr_demo`]: parse a SQL text and turn the result
///    back into SQL with [`Unparser`].
#[tokio::main]
async fn main() -> Result<()> {
    // Each demo is self-contained and creates its own `SessionContext`; the
    // first failing demo aborts the run through `?`.
    simple_session_context_parse_sql_expr_demo()?;
    simple_dataframe_parse_sql_expr_demo().await?;
    query_parquet_demo().await?;
    round_trip_parse_sql_expr_demo().await?;
    Ok(())
}

/// DataFusion can parse a SQL text to a logical expression against a schema
/// supplied to [`SessionContext`].
///
/// # Errors
///
/// Returns an error if the Arrow schema cannot be converted into a
/// [`DFSchema`] or if [`SessionContext::parse_sql_expr`] fails.
///
/// # Panics
///
/// Panics if the parsed expression differs from the hand-built one.
fn simple_session_context_parse_sql_expr_demo() -> Result<()> {
    let sql = "a < 5 OR a = 8";
    // The expected expression uses `i64` literals for the integer constants.
    let expr = col("a").lt(lit(5_i64)).or(col("a").eq(lit(8_i64)));

    // provide type information that `a` is an Int32
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    // Propagate a conversion failure to the caller instead of panicking: this
    // function already returns `Result`.
    let df_schema = DFSchema::try_from(schema)?;
    let ctx = SessionContext::new();

    let parsed_expr = ctx.parse_sql_expr(sql, &df_schema)?;

    assert_eq!(parsed_expr, expr);

    Ok(())
}

/// DataFusion can parse a SQL text to a logical expression using the schema
/// of a [`DataFrame`].
///
/// The `DataFrame` is read from the `alltypes_plain.parquet` test file and the
/// parsed expression is compared with one built through the expression API.
async fn simple_dataframe_parse_sql_expr_demo() -> Result<()> {
    let sql = "int_col < 5 OR double_col = 8.0";

    // Hand-built equivalent of `sql`, assembled one predicate at a time.
    let int_predicate = col("int_col").lt(lit(5_i64));
    let double_predicate = col("double_col").eq(lit(8.0_f64));
    let expected = int_predicate.or(double_predicate);

    let ctx = SessionContext::new();
    let parquet_path = format!(
        "{}/alltypes_plain.parquet",
        datafusion::test_util::parquet_test_data()
    );
    let df = ctx
        .read_parquet(&parquet_path, ParquetReadOptions::default())
        .await?;

    // Column names in `sql` are looked up in the schema of `df`.
    let parsed = df.parse_sql_expr(sql)?;

    assert_eq!(parsed, expected);

    Ok(())
}

/// Queries a parquet file with a [`DataFrame`] pipeline whose projection,
/// filter and aggregate expressions are all parsed from SQL text through
/// [`DataFrame::parse_sql_expr`].
///
/// # Errors
///
/// Returns an error if reading the parquet file, parsing any expression,
/// building the plan or collecting the result fails.
///
/// # Panics
///
/// Panics if the collected batches do not render as the expected table.
async fn query_parquet_demo() -> Result<()> {
    let ctx = SessionContext::new();
    let testdata = datafusion::test_util::parquet_test_data();
    let df = ctx
        .read_parquet(
            &format!("{testdata}/alltypes_plain.parquet"),
            ParquetReadOptions::default(),
        )
        .await?;

    // The pipeline starts from a clone so that `df` itself stays available for
    // the `parse_sql_expr` calls made while the pipeline is being built.
    let df = df
        .clone()
        .select(vec![
            df.parse_sql_expr("int_col")?,
            df.parse_sql_expr("double_col")?,
        ])?
        .filter(df.parse_sql_expr("int_col < 5 OR double_col = 8.0")?)?
        .aggregate(
            vec![df.parse_sql_expr("double_col")?],
            vec![df.parse_sql_expr("SUM(int_col) as sum_int_col")?],
        )?
        // Directly parsing the SQL text into a sort expression is not supported yet, so
        // construct it programmatically
        .sort(vec![col("double_col").sort(false, false)])?
        .limit(0, Some(1))?;

    let result = df.collect().await?;

    // Every row has to be padded to the column widths fixed by the border
    // lines (12 and 22 characters), otherwise the comparison cannot match.
    assert_batches_eq!(
        &[
            "+------------+----------------------+",
            "| double_col | sum(?table?.int_col) |",
            "+------------+----------------------+",
            "| 10.1       | 4                    |",
            "+------------+----------------------+",
        ],
        &result
    );

    Ok(())
}

/// DataFusion can parse a SQL text and convert it back to SQL using [`Unparser`].
///
/// The SQL text is parsed against the schema of a [`DataFrame`] read from the
/// `alltypes_plain.parquet` test file; the unparsed text must equal the input.
async fn round_trip_parse_sql_expr_demo() -> Result<()> {
    let original_sql = "((int_col < 5) OR (double_col = 8))";

    let ctx = SessionContext::new();
    let parquet_path = format!(
        "{}/alltypes_plain.parquet",
        datafusion::test_util::parquet_test_data()
    );
    let df = ctx
        .read_parquet(&parquet_path, ParquetReadOptions::default())
        .await?;

    // SQL text -> logical expression ...
    let parsed = df.parse_sql_expr(original_sql)?;

    // ... and logical expression -> SQL text again.
    let regenerated_sql = Unparser::default().expr_to_sql(&parsed)?.to_string();

    assert_eq!(original_sql, regenerated_sql);

    Ok(())
}
27 changes: 27 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,33 @@ impl DataFrame {
}
}

/// Creates a logical expression from SQL expression text.
///
/// The expression is created and processed against the current schema of
/// this `DataFrame`.
///
/// # Example: Parsing SQL queries
/// ```
/// # use arrow::datatypes::{DataType, Field, Schema};
/// # use datafusion::prelude::*;
/// # use datafusion_common::{DFSchema, Result};
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// // datafusion will parse number as i64 first.
/// let sql = "a > 1 and b in (1, 10)";
/// let expected = col("a").gt(lit(1_i64))
///   .and(col("b").in_list(vec![lit(1_i64), lit(10_i64)], false));
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
/// let expr = df.parse_sql_expr(sql)?;
/// assert_eq!(expected, expr);
/// # Ok(())
/// # }
/// ```
pub fn parse_sql_expr(&self, sql: &str) -> Result<Expr> {
    // Delegate to the session state, handing it this DataFrame's schema.
    self.session_state
        .create_logical_expr(sql, self.schema())
}

/// Consume the DataFrame and produce a physical plan
pub async fn create_physical_plan(self) -> Result<Arc<dyn ExecutionPlan>> {
self.session_state.create_physical_plan(&self.plan).await
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub(crate) fn from_bytes_to_f16(b: &[u8]) -> Option<f16> {
// Copy from arrow-rs
// https://github.com/apache/arrow-rs/blob/198af7a3f4aa20f9bd003209d9f04b0f37bb120e/parquet/src/arrow/buffer/bit_util.rs#L54
// Convert the byte slice to fixed length byte array with the length of N.
pub fn sign_extend_be<const N: usize>(b: &[u8]) -> [u8; N] {
fn sign_extend_be<const N: usize>(b: &[u8]) -> [u8; N] {
assert!(b.len() <= N, "Array too large, expected less than {N}");
let is_negative = (b[0] & 128u8) == 128u8;
let mut result = if is_negative { [255u8; N] } else { [0u8; N] };
Expand Down
26 changes: 26 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,32 @@ impl SessionContext {
self.execute_logical_plan(plan).await
}

/// Creates a logical expression from SQL expression text, using the
/// supplied `df_schema`.
///
/// # Example: Parsing SQL queries
///
/// ```
/// # use arrow::datatypes::{DataType, Field, Schema};
/// # use datafusion::prelude::*;
/// # use datafusion_common::{DFSchema, Result};
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// // datafusion will parse number as i64 first.
/// let sql = "a > 10";
/// let expected = col("a").gt(lit(10_i64));
/// // provide type information that `a` is an Int32
/// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
/// let df_schema = DFSchema::try_from(schema).unwrap();
/// let expr = SessionContext::new()
///  .parse_sql_expr(sql, &df_schema)?;
/// assert_eq!(expected, expr);
/// # Ok(())
/// # }
/// ```
pub fn parse_sql_expr(&self, sql: &str, df_schema: &DFSchema) -> Result<Expr> {
    // Read access to the session state is all that is needed here.
    let state = self.state.read();
    state.create_logical_expr(sql, df_schema)
}

/// Execute the [`LogicalPlan`], return a [`DataFrame`]. This API
/// is not featured limited (so all SQL such as `CREATE TABLE` and
/// `COPY` will be run).
Expand Down
67 changes: 55 additions & 12 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_sql::parser::{DFParser, Statement};
use datafusion_sql::planner::{ContextProvider, ParserOptions, SqlToRel};
use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
use sqlparser::ast::Expr as SQLExpr;
use sqlparser::dialect::dialect_from_str;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -490,6 +491,27 @@ impl SessionState {
Ok(statement)
}

/// Parses a SQL string into a sqlparser-rs AST [`SQLExpr`] using the named
/// `dialect`.
///
/// Returns a planning error if `dialect` is not recognized by
/// `dialect_from_str`, or the parser's error if `sql` cannot be parsed.
///
/// See [`Self::create_logical_expr`] for parsing sql to [`Expr`].
pub fn sql_to_expr(
    &self,
    sql: &str,
    dialect: &str,
) -> datafusion_common::Result<SQLExpr> {
    // Resolve the dialect name first; an unknown name is reported together
    // with the list of supported dialects.
    let Some(parser_dialect) = dialect_from_str(dialect) else {
        return Err(plan_datafusion_err!(
            "Unsupported SQL dialect: {dialect}. Available dialects: \
             Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
             MsSQL, ClickHouse, BigQuery, Ansi."
        ));
    };

    // `?` converts the parser's error type into the DataFusion error type.
    Ok(DFParser::parse_sql_into_expr_with_dialect(
        sql,
        parser_dialect.as_ref(),
    )?)
}

/// Resolve all table references in the SQL statement. Does not include CTE references.
///
/// See [`catalog::resolve_table_references`] for more information.
Expand Down Expand Up @@ -520,10 +542,6 @@ impl SessionState {
tables: HashMap::with_capacity(references.len()),
};

let enable_ident_normalization =
self.config.options().sql_parser.enable_ident_normalization;
let parse_float_as_decimal =
self.config.options().sql_parser.parse_float_as_decimal;
for reference in references {
let resolved = &self.resolve_table_ref(reference);
if let Entry::Vacant(v) = provider.tables.entry(resolved.to_string()) {
Expand All @@ -535,16 +553,19 @@ impl SessionState {
}
}

let query = SqlToRel::new_with_options(
&provider,
ParserOptions {
parse_float_as_decimal,
enable_ident_normalization,
},
);
let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
query.statement_to_plan(statement)
}

/// Builds the [`ParserOptions`] handed to the SQL planner from the
/// `sql_parser` section of this session's configuration.
fn get_parser_options(&self) -> ParserOptions {
    let parser_config = &self.config.options().sql_parser;
    let parse_float_as_decimal = parser_config.parse_float_as_decimal;
    let enable_ident_normalization = parser_config.enable_ident_normalization;

    ParserOptions {
        parse_float_as_decimal,
        enable_ident_normalization,
    }
}

/// Creates a [`LogicalPlan`] from the provided SQL string. This
/// interface will plan any SQL DataFusion supports, including DML
/// like `CREATE TABLE`, and `COPY` (which can write to local
Expand All @@ -567,6 +588,28 @@ impl SessionState {
Ok(plan)
}

/// Creates a datafusion style AST [`Expr`] from a SQL string.
///
/// The string is first parsed with the SQL dialect configured for this
/// session and the resulting AST is then converted into an [`Expr`] using
/// `df_schema`.
///
/// See example on [SessionContext::parse_sql_expr](crate::execution::context::SessionContext::parse_sql_expr)
pub fn create_logical_expr(
    &self,
    sql: &str,
    df_schema: &DFSchema,
) -> datafusion_common::Result<Expr> {
    // Step 1: SQL text -> sqlparser AST, with the session's dialect.
    let configured_dialect = self.config.options().sql_parser.dialect.as_str();
    let parsed = self.sql_to_expr(sql, configured_dialect)?;

    // Step 2: sqlparser AST -> logical expression. The provider is created
    // with an empty table map.
    let context_provider = SessionContextProvider {
        state: self,
        tables: HashMap::new(),
    };
    let planner =
        SqlToRel::new_with_options(&context_provider, self.get_parser_options());
    let mut planner_context = PlannerContext::new();

    planner.sql_to_expr(parsed, df_schema, &mut planner_context)
}

/// Optimizes the logical plan by applying optimizer rules.
pub fn optimize(&self, plan: &LogicalPlan) -> datafusion_common::Result<LogicalPlan> {
if let LogicalPlan::Explain(e) = plan {
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/tests/expr_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use sqlparser::ast::NullTreatment;
/// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan
use std::sync::{Arc, OnceLock};

mod parse_sql_expr;
mod simplification;

#[test]
Expand Down
Loading

0 comments on commit 6f10dbc

Please sign in to comment.