From 4bc322819f28695b6b9593d5bbaea775e7d798b6 Mon Sep 17 00:00:00 2001 From: kamille Date: Tue, 2 Jul 2024 21:42:51 +0800 Subject: [PATCH 1/7] Covert grouping to udaf (#11147) * define Grouping udf and impl AggregateUDFImpl for it. * add `grouping` to default list. * remove the old grouping related codes. * continue to remove codes. * regen pbs in proto. * remove built-in grouping in proto codes. * fix sql it. * Add test + export fn --------- Co-authored-by: Andrew Lamb --- datafusion/expr/src/aggregate_function.rs | 11 +- .../expr/src/type_coercion/aggregates.rs | 1 - .../functions-aggregate/src/grouping.rs | 97 +++++++++++++++++ datafusion/functions-aggregate/src/lib.rs | 3 + .../physical-expr/src/aggregate/build_in.rs | 5 - .../physical-expr/src/aggregate/grouping.rs | 103 ------------------ datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/expressions/mod.rs | 1 - datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 1 - datafusion/proto/src/logical_plan/to_proto.rs | 2 - .../proto/src/physical_plan/to_proto.rs | 10 +- .../tests/cases/roundtrip_logical_plan.rs | 5 +- datafusion/sql/tests/sql_integration.rs | 15 +-- 16 files changed, 118 insertions(+), 146 deletions(-) create mode 100644 datafusion/functions-aggregate/src/grouping.rs delete mode 100644 datafusion/physical-expr/src/aggregate/grouping.rs diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index b17e4294a1ef..760952d94815 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -41,8 +41,6 @@ pub enum AggregateFunction { ArrayAgg, /// N'th value in a group according to some ordering NthValue, - /// Grouping - Grouping, } impl AggregateFunction { @@ -53,7 +51,6 @@ impl AggregateFunction { Max => "MAX", ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", - Grouping => "GROUPING", } } } @@ -73,8 +70,6 @@ impl FromStr for AggregateFunction { "min" => AggregateFunction::Min, "array_agg" => AggregateFunction::ArrayAgg, "nth_value" => AggregateFunction::NthValue, - // other - "grouping" => AggregateFunction::Grouping, _ => { return plan_err!("There is no built-in function named {name}"); } @@ -119,7 +114,6 @@ impl AggregateFunction { coerced_data_types[0].clone(), input_expr_nullable[0], )))), - AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), } } @@ -130,7 +124,6 @@ impl AggregateFunction { match self { AggregateFunction::Max | AggregateFunction::Min => Ok(true), AggregateFunction::ArrayAgg => Ok(false), - AggregateFunction::Grouping => Ok(true), AggregateFunction::NthValue => Ok(true), } } @@ -141,9 +134,7 @@ impl AggregateFunction { pub fn signature(&self) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. 
match self { - AggregateFunction::Grouping | AggregateFunction::ArrayAgg => { - Signature::any(1, Volatility::Immutable) - } + AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable), AggregateFunction::Min | AggregateFunction::Max => { let valid = STRINGS .iter() diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 36a789d5b0ee..0f7464b96b3e 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -102,7 +102,6 @@ pub fn coerce_types( get_min_max_result_type(input_types) } AggregateFunction::NthValue => Ok(input_types.to_vec()), - AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), } } diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs new file mode 100644 index 000000000000..6fb7c3800f4e --- /dev/null +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::fmt; + +use arrow::datatypes::DataType; +use arrow::datatypes::Field; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; + +make_udaf_expr_and_func!( + Grouping, + grouping, + expression, + "Returns 1 if the data is aggregated across the specified column or 0 for not aggregated in the result set.", + grouping_udaf +); + +pub struct Grouping { + signature: Signature, +} + +impl fmt::Debug for Grouping { + fn fmt(&self, f: &mut std::fmt::Formatter) -> fmt::Result { + f.debug_struct("Grouping") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Grouping { + fn default() -> Self { + Self::new() + } +} + +impl Grouping { + /// Create a new GROUPING aggregate function. 
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(1, Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for Grouping {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "grouping"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Int32)
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        Ok(vec![Field::new(
+            format_state_name(args.name, "grouping"),
+            DataType::Int32,
+            true,
+        )])
+    }
+
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        not_impl_err!(
+            "physical plan is not yet implemented for GROUPING aggregate function"
+        )
+    }
+}
diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs
index 063e6000b4c9..fc485a284ab4 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -73,6 +73,7 @@ pub mod approx_percentile_cont_with_weight;
 pub mod average;
 pub mod bit_and_or_xor;
 pub mod bool_and_or;
+pub mod grouping;
 pub mod string_agg;
 use crate::approx_percentile_cont::approx_percentile_cont_udaf;
@@ -102,6 +103,7 @@ pub mod expr_fn {
     pub use super::covariance::covar_samp;
     pub use super::first_last::first_value;
     pub use super::first_last::last_value;
+    pub use super::grouping::grouping;
     pub use super::median::median;
     pub use super::regr::regr_avgx;
     pub use super::regr::regr_avgy;
@@ -154,6 +156,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
         bool_and_or::bool_and_udaf(),
         bool_and_or::bool_or_udaf(),
         average::avg_udaf(),
+        grouping::grouping_udaf(),
     ]
 }
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 169418d2daa0..adbbbd3e631e 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -60,11 +60,6 @@ pub fn create_aggregate_expr(
         .collect::<Result<Vec<_>>>()?;
     let input_phy_exprs = input_phy_exprs.to_vec();
     Ok(match (fun, distinct) {
-        (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new(
-            input_phy_exprs[0].clone(),
-            name,
-            data_type,
-        )),
         (AggregateFunction::ArrayAgg, false) => {
             let expr = input_phy_exprs[0].clone();
             let nullable = expr.nullable(input_schema)?;
diff --git a/datafusion/physical-expr/src/aggregate/grouping.rs b/datafusion/physical-expr/src/aggregate/grouping.rs
deleted file mode 100644
index d43bcd5c7091..000000000000
--- a/datafusion/physical-expr/src/aggregate/grouping.rs
+++ /dev/null
@@ -1,103 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//!
Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::sync::Arc; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; -use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::Accumulator; - -use crate::expressions::format_state_name; - -/// GROUPING aggregate expression -/// Returns the amount of non-null values of the given expression. -#[derive(Debug)] -pub struct Grouping { - name: String, - data_type: DataType, - nullable: bool, - expr: Arc, -} - -impl Grouping { - /// Create a new GROUPING aggregate function. - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for Grouping { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int32, self.nullable)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "grouping"), - DataType::Int32, - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - not_impl_err!( - "physical plan is not yet implemented for GROUPING aggregate function" - ) - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Grouping { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index ca5bf3293442..f0de7446f6f1 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -20,7 +20,6 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; -pub(crate) mod grouping; pub(crate) mod nth_value; #[macro_use] pub(crate) mod min_max; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index b87b6daa64c7..1f2c955ad07e 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -38,7 +38,6 @@ pub use crate::aggregate::array_agg::ArrayAgg; pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::build_in::create_aggregate_expr; -pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; pub use crate::aggregate::stats::StatsType; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7f4d6b9d927e..ce6c0c53c3fc 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -483,7 +483,7 @@ enum AggregateFunction { // APPROX_PERCENTILE_CONT = 14; // APPROX_MEDIAN = 15; // APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; - GROUPING = 17; + // GROUPING = 17; // MEDIAN = 18; // BIT_AND = 19; // BIT_OR = 20; diff --git 
a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 33cd634c4aad..347654e52b73 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -535,7 +535,6 @@ impl serde::Serialize for AggregateFunction { Self::Min => "MIN", Self::Max => "MAX", Self::ArrayAgg => "ARRAY_AGG", - Self::Grouping => "GROUPING", Self::NthValueAgg => "NTH_VALUE_AGG", }; serializer.serialize_str(variant) @@ -551,7 +550,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN", "MAX", "ARRAY_AGG", - "GROUPING", "NTH_VALUE_AGG", ]; @@ -596,7 +594,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN" => Ok(AggregateFunction::Min), "MAX" => Ok(AggregateFunction::Max), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), - "GROUPING" => Ok(AggregateFunction::Grouping), "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 83b8b738c4f4..c74f172482b7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1935,7 +1935,7 @@ pub enum AggregateFunction { /// APPROX_PERCENTILE_CONT = 14; /// APPROX_MEDIAN = 15; /// APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; - Grouping = 17, + /// GROUPING = 17; /// MEDIAN = 18; /// BIT_AND = 19; /// BIT_OR = 20; @@ -1964,7 +1964,6 @@ impl AggregateFunction { AggregateFunction::Min => "MIN", AggregateFunction::Max => "MAX", AggregateFunction::ArrayAgg => "ARRAY_AGG", - AggregateFunction::Grouping => "GROUPING", AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } } @@ -1974,7 +1973,6 @@ impl AggregateFunction { "MIN" => Some(Self::Min), "MAX" => Some(Self::Max), "ARRAY_AGG" => Some(Self::ArrayAgg), - "GROUPING" => Some(Self::Grouping), "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 609cbc1a286b..f4fb69280436 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -145,7 +145,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Min => Self::Min, protobuf::AggregateFunction::Max => Self::Max, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, - protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ccc64119c8a1..7570040a1d08 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -117,7 +117,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Min => Self::Min, AggregateFunction::Max => Self::Max, AggregateFunction::ArrayAgg => Self::ArrayAgg, - AggregateFunction::Grouping => Self::Grouping, AggregateFunction::NthValue => Self::NthValueAgg, } } @@ -378,7 +377,6 @@ pub fn serialize_expr( AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, AggregateFunction::NthValue => { protobuf::AggregateFunction::NthValueAgg } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs 
b/datafusion/proto/src/physical_plan/to_proto.rs index 375361261952..23cdc666e701 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -24,9 +24,9 @@ use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ ArrayAgg, BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, - Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, - NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, - RowNumber, TryCastExpr, WindowShift, + InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, + NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, + TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -244,9 +244,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); let mut distinct = false; - let inner = if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Grouping - } else if aggr_expr.downcast_ref::().is_some() { + let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { distinct = true; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index fe3da3d05854..5fc3a9a8a197 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -43,8 +43,8 @@ use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, - count_distinct, covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, - var_pop, var_sample, + count_distinct, covar_pop, covar_samp, first_value, grouping, median, stddev, + stddev_pop, sum, var_pop, var_sample, }; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -695,6 +695,7 @@ async fn roundtrip_expr_api() -> Result<()> { approx_median(lit(2)), approx_percentile_cont(lit(2), lit(0.5)), approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), + grouping(lit(1)), bit_and(lit(2)), bit_or(lit(2)), bit_xor(lit(2)), diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index ec623a956186..aca0d040bb8d 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -37,10 +37,10 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; -use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::{ approx_median::approx_median_udaf, count::count_udaf, }; +use datafusion_functions_aggregate::{average::avg_udaf, grouping::grouping_udaf}; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; @@ -2693,7 +2693,8 @@ fn logical_plan_with_dialect_and_options( .with_udaf(sum_udaf()) .with_udaf(approx_median_udaf()) .with_udaf(count_udaf()) - .with_udaf(avg_udaf()); + .with_udaf(avg_udaf()) + .with_udaf(grouping_udaf()); let planner = SqlToRel::new_with_options(&context, options); let result = 
DFParser::parse_sql_with_dialect(sql, dialect); @@ -3097,8 +3098,8 @@ fn aggregate_with_rollup() { fn aggregate_with_rollup_with_grouping() { let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), count(*) \ FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), count(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), count(*)]]\ + let expected = "Projection: person.id, person.state, person.age, grouping(person.state), grouping(person.age), grouping(person.state) + grouping(person.age), count(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[grouping(person.state), grouping(person.age), count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3118,9 +3119,9 @@ fn rank_partition_grouping() { from person group by rollup(state, last_name)"; - let expected = "Projection: sum(person.age) AS total_sum, person.state, person.last_name, GROUPING(person.state) + GROUPING(person.last_name) AS x, RANK() PARTITION BY [GROUPING(person.state) + GROUPING(person.last_name), CASE WHEN GROUPING(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank\ - \n WindowAggr: windowExpr=[[RANK() PARTITION BY [GROUPING(person.state) + GROUPING(person.last_name), CASE WHEN GROUPING(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n Aggregate: groupBy=[[ROLLUP (person.state, person.last_name)]], aggr=[[sum(person.age), GROUPING(person.state), GROUPING(person.last_name)]]\ + let expected = "Projection: sum(person.age) AS total_sum, person.state, person.last_name, grouping(person.state) + grouping(person.last_name) AS x, RANK() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank\ + \n WindowAggr: windowExpr=[[RANK() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n Aggregate: groupBy=[[ROLLUP (person.state, person.last_name)]], aggr=[[sum(person.age), grouping(person.state), grouping(person.last_name)]]\ \n TableScan: person"; quick_test(sql, expected); } From 75b9c9bea28757ec66d7e6a88997bd6f30c367f5 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 2 Jul 2024 08:47:47 -0500 Subject: [PATCH 2/7] Make statistics_from_parquet_meta a sync function (#11205) * Make statistics_from_parquet_meta a sync function * Fix version * improve comment * Clippy and fmt --------- Co-authored-by: Andrew Lamb --- .../src/datasource/file_format/parquet.rs | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 5921d8a797ac..27d783cd89b5 100644 --- 
a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -460,14 +460,14 @@ async fn fetch_statistics( metadata_size_hint: Option, ) -> Result { let metadata = fetch_parquet_metadata(store, file, metadata_size_hint).await?; - statistics_from_parquet_meta(&metadata, table_schema).await + statistics_from_parquet_meta_calc(&metadata, table_schema) } /// Convert statistics in [`ParquetMetaData`] into [`Statistics`] using ['StatisticsConverter`] /// /// The statistics are calculated for each column in the table schema /// using the row group statistics in the parquet metadata. -pub async fn statistics_from_parquet_meta( +pub fn statistics_from_parquet_meta_calc( metadata: &ParquetMetaData, table_schema: SchemaRef, ) -> Result { @@ -543,6 +543,21 @@ pub async fn statistics_from_parquet_meta( Ok(statistics) } +/// Deprecated +/// Use [`statistics_from_parquet_meta_calc`] instead. +/// This method was deprecated because it didn't need to be async so a new method was created +/// that exposes a synchronous API. +#[deprecated( + since = "40.0.0", + note = "please use `statistics_from_parquet_meta_calc` instead" +)] +pub async fn statistics_from_parquet_meta( + metadata: &ParquetMetaData, + table_schema: SchemaRef, +) -> Result { + statistics_from_parquet_meta_calc(metadata, table_schema) +} + fn summarize_min_max_null_counts( min_accs: &mut [Option], max_accs: &mut [Option], @@ -1467,7 +1482,7 @@ mod tests { // Fetch statistics for first file let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; - let stats = statistics_from_parquet_meta(&pq_meta, schema.clone()).await?; + let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; assert_eq!(stats.num_rows, Precision::Exact(4)); // column c_dic @@ -1514,7 +1529,7 @@ mod tests { // Fetch statistics for first file let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; - let stats = statistics_from_parquet_meta(&pq_meta, schema.clone()).await?; + let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; // assert_eq!(stats.num_rows, Precision::Exact(3)); // column c1 @@ -1536,7 +1551,7 @@ mod tests { // Fetch statistics for second file let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[1], None).await?; - let stats = statistics_from_parquet_meta(&pq_meta, schema.clone()).await?; + let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; assert_eq!(stats.num_rows, Precision::Exact(3)); // column c1: missing from the file so the table treats all 3 rows as null let c1_stats = &stats.column_statistics[0]; From 43ea68208b0caa9f5d1164c592d5ef17a8c957d9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 2 Jul 2024 14:57:19 +0100 Subject: [PATCH 3/7] Allow user defined SQL planners to be registered (#11208) * Allow user defined SQL planners to be registered * fix clippy, remove unused Default * format --- datafusion/core/src/execution/context/mod.rs | 10 +++ .../core/src/execution/session_state.rs | 24 ++++- datafusion/core/tests/user_defined/mod.rs | 3 + .../user_defined/user_defined_sql_planner.rs | 88 +++++++++++++++++++ datafusion/expr/src/planner.rs | 2 +- datafusion/expr/src/registry.rs | 9 ++ datafusion/functions-array/src/planner.rs | 6 +- 7 files changed, 134 insertions(+), 8 deletions(-) create mode 100644 datafusion/core/tests/user_defined/user_defined_sql_planner.rs diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 
012717a007c2..4685f194fe29 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -60,6 +60,7 @@ use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, + planner::UserDefinedSQLPlanner, Expr, UserDefinedLogicalNode, WindowUDF, }; @@ -1390,6 +1391,15 @@ impl FunctionRegistry for SessionContext { ) -> Result<()> { self.state.write().register_function_rewrite(rewrite) } + + fn register_user_defined_sql_planner( + &mut self, + user_defined_sql_planner: Arc, + ) -> Result<()> { + self.state + .write() + .register_user_defined_sql_planner(user_defined_sql_planner) + } } /// Create a new task context instance from SessionContext diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index ac94ee61fcb2..aa81d77cf682 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -60,6 +60,7 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::FunctionRewrite; +use datafusion_expr::planner::UserDefinedSQLPlanner; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::var_provider::{is_system_variables, VarType}; @@ -99,6 +100,8 @@ pub struct SessionState { session_id: String, /// Responsible for analyzing and rewrite a logical plan before optimization analyzer: Analyzer, + /// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?` + user_defined_sql_planners: Vec>, /// Responsible for optimizing a logical plan optimizer: Optimizer, /// Responsible for optimizing a physical execution plan @@ -231,6 +234,7 @@ impl SessionState { let mut new_self = SessionState { session_id, analyzer: Analyzer::new(), + user_defined_sql_planners: vec![], optimizer: Optimizer::new(), physical_optimizers: PhysicalOptimizer::new(), query_planner: Arc::new(DefaultQueryPlanner {}), @@ -947,16 +951,21 @@ impl SessionState { where S: ContextProvider, { - let query = SqlToRel::new_with_options(provider, self.get_parser_options()); + let mut query = SqlToRel::new_with_options(provider, self.get_parser_options()); + + // custom planners are registered first, so they're run first and take precedence over built-in planners + for planner in self.user_defined_sql_planners.iter() { + query = query.with_user_defined_planner(planner.clone()); + } // register crate of array expressions (if enabled) #[cfg(feature = "array_expressions")] { let array_planner = - Arc::new(functions_array::planner::ArrayFunctionPlanner::default()) as _; + Arc::new(functions_array::planner::ArrayFunctionPlanner) as _; let field_access_planner = - Arc::new(functions_array::planner::FieldAccessPlanner::default()) as _; + Arc::new(functions_array::planner::FieldAccessPlanner) as _; query .with_user_defined_planner(array_planner) @@ -1176,6 +1185,15 @@ impl FunctionRegistry for SessionState { self.analyzer.add_function_rewrite(rewrite); Ok(()) } + + fn register_user_defined_sql_planner( + &mut self, + user_defined_sql_planner: Arc, + ) -> datafusion_common::Result<()> { + self.user_defined_sql_planners + .push(user_defined_sql_planner); + Ok(()) + } } impl OptimizerConfig for SessionState { diff --git a/datafusion/core/tests/user_defined/mod.rs 
b/datafusion/core/tests/user_defined/mod.rs index 6c6d966cc3aa..9b83a9fdd408 100644 --- a/datafusion/core/tests/user_defined/mod.rs +++ b/datafusion/core/tests/user_defined/mod.rs @@ -29,3 +29,6 @@ mod user_defined_window_functions; /// Tests for User Defined Table Functions mod user_defined_table_functions; + +/// Tests for User Defined SQL Planner +mod user_defined_sql_planner; diff --git a/datafusion/core/tests/user_defined/user_defined_sql_planner.rs b/datafusion/core/tests/user_defined/user_defined_sql_planner.rs new file mode 100644 index 000000000000..37df7e0900b4 --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_sql_planner.rs @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::RecordBatch; +use std::sync::Arc; + +use datafusion::common::{assert_batches_eq, DFSchema}; +use datafusion::error::Result; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::Operator; +use datafusion::prelude::*; +use datafusion::sql::sqlparser::ast::BinaryOperator; +use datafusion_expr::planner::{PlannerResult, RawBinaryExpr, UserDefinedSQLPlanner}; +use datafusion_expr::BinaryExpr; + +struct MyCustomPlanner; + +impl UserDefinedSQLPlanner for MyCustomPlanner { + fn plan_binary_op( + &self, + expr: RawBinaryExpr, + _schema: &DFSchema, + ) -> Result> { + match &expr.op { + BinaryOperator::Arrow => { + Ok(PlannerResult::Planned(Expr::BinaryExpr(BinaryExpr { + left: Box::new(expr.left.clone()), + right: Box::new(expr.right.clone()), + op: Operator::StringConcat, + }))) + } + BinaryOperator::LongArrow => { + Ok(PlannerResult::Planned(Expr::BinaryExpr(BinaryExpr { + left: Box::new(expr.left.clone()), + right: Box::new(expr.right.clone()), + op: Operator::Plus, + }))) + } + _ => Ok(PlannerResult::Original(expr)), + } + } +} + +async fn plan_and_collect(sql: &str) -> Result> { + let mut ctx = SessionContext::new(); + ctx.register_user_defined_sql_planner(Arc::new(MyCustomPlanner))?; + ctx.sql(sql).await?.collect().await +} + +#[tokio::test] +async fn test_custom_operators_arrow() { + let actual = plan_and_collect("select 'foo'->'bar';").await.unwrap(); + let expected = [ + "+----------------------------+", + "| Utf8(\"foo\") || Utf8(\"bar\") |", + "+----------------------------+", + "| foobar |", + "+----------------------------+", + ]; + assert_batches_eq!(&expected, &actual); +} + +#[tokio::test] +async fn test_custom_operators_long_arrow() { + let actual = plan_and_collect("select 1->>2;").await.unwrap(); + let expected = [ + "+---------------------+", + "| Int64(1) + Int64(2) |", + "+---------------------+", + "| 3 |", + "+---------------------+", + ]; + assert_batches_eq!(&expected, &actual); +} diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 
1febfbec7ef0..c928ab39194d 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -83,7 +83,7 @@ pub trait ContextProvider { } /// This trait allows users to customize the behavior of the SQL planner -pub trait UserDefinedSQLPlanner { +pub trait UserDefinedSQLPlanner: Send + Sync { /// Plan the binary operation between two expressions, returns OriginalBinaryExpr if not possible fn plan_binary_op( &self, diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 70d0a21a870e..c276fe30f897 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -18,6 +18,7 @@ //! FunctionRegistry trait use crate::expr_rewriter::FunctionRewrite; +use crate::planner::UserDefinedSQLPlanner; use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; use std::collections::HashMap; @@ -108,6 +109,14 @@ pub trait FunctionRegistry { ) -> Result<()> { not_impl_err!("Registering FunctionRewrite") } + + /// Registers a new [`UserDefinedSQLPlanner`] with the registry. + fn register_user_defined_sql_planner( + &mut self, + _user_defined_sql_planner: Arc, + ) -> Result<()> { + not_impl_err!("Registering UserDefinedSQLPlanner") + } } /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs index f33ee56582cf..cfb3e5ed0729 100644 --- a/datafusion/functions-array/src/planner.rs +++ b/datafusion/functions-array/src/planner.rs @@ -31,8 +31,7 @@ use crate::{ make_array::make_array, }; -#[derive(Default)] -pub struct ArrayFunctionPlanner {} +pub struct ArrayFunctionPlanner; impl UserDefinedSQLPlanner for ArrayFunctionPlanner { fn plan_binary_op( @@ -99,8 +98,7 @@ impl UserDefinedSQLPlanner for ArrayFunctionPlanner { } } -#[derive(Default)] -pub struct FieldAccessPlanner {} +pub struct FieldAccessPlanner; impl UserDefinedSQLPlanner for FieldAccessPlanner { fn plan_field_access( From 58f79e143e1a90e5caa59eecc9b36dbdd082a7eb Mon Sep 17 00:00:00 2001 From: Duong Cong Toai <35887761+duongcongtoai@users.noreply.github.com> Date: Tue, 2 Jul 2024 16:10:56 +0200 Subject: [PATCH 4/7] Recursive `unnest` (#11062) * chore: fix map children of unnest * adjust test * remove debug * chore: move test to unnest.slt * chore: rename * add some comment * compile err * more comment * chore: address comment * more coverage * one more scenario --- datafusion/expr/src/expr.rs | 5 + datafusion/expr/src/logical_plan/builder.rs | 1 - datafusion/expr/src/tree_node.rs | 3 +- datafusion/sql/src/select.rs | 98 +++++++----- datafusion/sql/src/utils.rs | 143 ++++++++++++++---- datafusion/sqllogictest/test_files/unnest.slt | 97 ++++++++++++ 6 files changed, 271 insertions(+), 76 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index e7fdc7f4db41..579f5fed578f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -358,6 +358,11 @@ impl Unnest { expr: Box::new(expr), } } + + /// Create a new Unnest expression. 
+ pub fn new_boxed(boxed: Box) -> Self { + Self { expr: boxed } + } } /// Alias expression diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 1f1175088227..cc4348d58c33 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1455,7 +1455,6 @@ pub fn project( _ => projected_expr.push(columnize_expr(normalize_col(e, &plan)?, &plan)?), } } - validate_unique_names("Projections", projected_expr.iter())?; Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index e57d57188743..f1df8609f903 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -136,8 +136,9 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Unnest(_) | Expr::Literal(_) => Transformed::no(self), + Expr::Unnest(Unnest { expr, .. }) => transform_box(expr, &mut f)? + .update_data(|be| Expr::Unnest(Unnest::new_boxed(be))), Expr::Alias(Alias { expr, relation, diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 102b47216e7e..236403e83d74 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -22,9 +22,8 @@ use crate::planner::{ idents_to_table_reference, ContextProvider, PlannerContext, SqlToRel, }; use crate::utils::{ - check_columns_satisfy_exprs, extract_aliases, rebase_expr, - recursive_transform_unnest, resolve_aliases_to_exprs, resolve_columns, - resolve_positions_to_exprs, + check_columns_satisfy_exprs, extract_aliases, rebase_expr, resolve_aliases_to_exprs, + resolve_columns, resolve_positions_to_exprs, transform_bottom_unnest, }; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; @@ -298,46 +297,61 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: LogicalPlan, select_exprs: Vec, ) -> Result { - let mut unnest_columns = vec![]; - // from which column used for projection, before the unnest happen - // including non unnest column and unnest column - let mut inner_projection_exprs = vec![]; - - // expr returned here maybe different from the originals in inner_projection_exprs - // for example: - // - unnest(struct_col) will be transformed into unnest(struct_col).field1, unnest(struct_col).field2 - // - unnest(array_col) will be transformed into unnest(array_col).element - // - unnest(array_col) + 1 will be transformed into unnest(array_col).element +1 - let outer_projection_exprs: Vec = select_exprs - .into_iter() - .map(|expr| { - recursive_transform_unnest( - &input, - &mut unnest_columns, - &mut inner_projection_exprs, - expr, - ) - }) - .collect::>>()? - .into_iter() - .flatten() - .collect(); - - // Do the final projection - if unnest_columns.is_empty() { - LogicalPlanBuilder::from(input) - .project(inner_projection_exprs)? - .build() - } else { - let columns = unnest_columns.into_iter().map(|col| col.into()).collect(); - // Set preserve_nulls to false to ensure compatibility with DuckDB and PostgreSQL - let unnest_options = UnnestOptions::new().with_preserve_nulls(false); - LogicalPlanBuilder::from(input) - .project(inner_projection_exprs)? - .unnest_columns_with_options(columns, unnest_options)? - .project(outer_projection_exprs)? 
-                .build()
+        let mut intermediate_plan = input;
+        let mut intermediate_select_exprs = select_exprs;
+        // Each expr in select_exprs can contain multiple unnest stages
+        // The transformation happens bottom up, one at a time for each iteration
+        // Only exit the loop once no more unnest transformations are found
+        for i in 0.. {
+            let mut unnest_columns = vec![];
+            // the columns used for the projection before the unnest happens,
+            // including both non-unnest columns and unnest columns
+            let mut inner_projection_exprs = vec![];
+
+            // exprs returned here may be different from the originals in inner_projection_exprs
+            // for example:
+            // - unnest(struct_col) will be transformed into unnest(struct_col).field1, unnest(struct_col).field2
+            // - unnest(array_col) will be transformed into unnest(array_col).element
+            // - unnest(array_col) + 1 will be transformed into unnest(array_col).element +1
+            let outer_projection_exprs: Vec<Expr> = intermediate_select_exprs
+                .iter()
+                .map(|expr| {
+                    transform_bottom_unnest(
+                        &intermediate_plan,
+                        &mut unnest_columns,
+                        &mut inner_projection_exprs,
+                        expr,
+                    )
+                })
+                .collect::<Result<Vec<_>>>()?
+                .into_iter()
+                .flatten()
+                .collect();
+
+            // No more unnest is possible
+            if unnest_columns.is_empty() {
+                // The original expr does not contain any unnest
+                if i == 0 {
+                    return LogicalPlanBuilder::from(intermediate_plan)
+                        .project(inner_projection_exprs)?
+                        .build();
+                }
+                break;
+            } else {
+                let columns = unnest_columns.into_iter().map(|col| col.into()).collect();
+                // Set preserve_nulls to false to ensure compatibility with DuckDB and PostgreSQL
+                let unnest_options = UnnestOptions::new().with_preserve_nulls(false);
+                let plan = LogicalPlanBuilder::from(intermediate_plan)
+                    .project(inner_projection_exprs)?
+                    .unnest_columns_with_options(columns, unnest_options)?
+                    .build()?;
+                intermediate_plan = plan;
+                intermediate_select_exprs = outer_projection_exprs;
+            }
         }
+        LogicalPlanBuilder::from(intermediate_plan)
+            .project(intermediate_select_exprs)?
+ .build() } fn plan_selection( diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index bc27d25cf216..2eacbd174fc2 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -22,7 +22,9 @@ use std::collections::HashMap; use arrow_schema::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{ exec_err, internal_err, plan_err, Column, DataFusionError, Result, ScalarValue, }; @@ -267,11 +269,13 @@ pub(crate) fn normalize_ident(id: Ident) -> String { /// - For list column: unnest(col) with type list -> unnest(col) with type list::item /// - For struct column: unnest(struct(field1, field2)) -> unnest(struct).field1, unnest(struct).field2 /// The transformed exprs will be used in the outer projection -pub(crate) fn recursive_transform_unnest( +/// If along the path from root to bottom, there are multiple unnest expressions, the transformation +/// is done only for the bottom expression +pub(crate) fn transform_bottom_unnest( input: &LogicalPlan, unnest_placeholder_columns: &mut Vec, inner_projection_exprs: &mut Vec, - original_expr: Expr, + original_expr: &Expr, ) -> Result> { let mut transform = |unnest_expr: &Expr, expr_in_unnest: &Expr| -> Result> { @@ -298,35 +302,53 @@ pub(crate) fn recursive_transform_unnest( .collect::>(); Ok(expr) }; - // expr transformed maybe either the same, or different from the originals exprs - // for example: - // - unnest(struct_col) will be transformed into unnest(struct_col).field1, unnest(struct_col).field2 + // This transformation is only done for list unnest + // struct unnest is done at the root level, and at the later stage + // because the syntax of TreeNode only support transform into 1 Expr, while + // Unnest struct will be transformed into multiple Exprs + // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102 + // + // The transformation looks like: // - unnest(array_col) will be transformed into unnest(array_col) // - unnest(array_col) + 1 will be transformed into unnest(array_col) + 1 - - // Specifically handle root level unnest expr, this is the only place - // unnest on struct can be handled - if let Expr::Unnest(Unnest { expr: ref arg }) = original_expr { - return transform(&original_expr, arg); - } let Transformed { - data: transformed_expr, - transformed, - tnr: _, - } = original_expr.transform_up(|expr: Expr| { - if let Expr::Unnest(Unnest { expr: ref arg }) = expr { - let (data_type, _) = arg.data_type_and_nullable(input.schema())?; - if let DataType::Struct(_) = data_type { - return internal_err!("unnest on struct can ony be applied at the root level of select expression"); - } - let transformed_exprs = transform(&expr, arg)?; - Ok(Transformed::yes(transformed_exprs[0].clone())) - } else { - Ok(Transformed::no(expr)) + data: transformed_expr, + transformed, + tnr: _, + } = original_expr.clone().transform_up(|expr: Expr| { + let is_root_expr = &expr == original_expr; + // Root expr is transformed separately + if is_root_expr { + return Ok(Transformed::no(expr)); + } + if let Expr::Unnest(Unnest { expr: ref arg }) = expr { + let (data_type, _) = arg.data_type_and_nullable(input.schema())?; + + if let DataType::Struct(_) = data_type { + return internal_err!("unnest on struct can ony be applied at the root 
level of select expression"); } - })?; + + let mut transformed_exprs = transform(&expr, arg)?; + // root_expr.push(transformed_exprs[0].clone()); + Ok(Transformed::new( + transformed_exprs.swap_remove(0), + true, + TreeNodeRecursion::Stop, + )) + } else { + Ok(Transformed::no(expr)) + } + })?; if !transformed { + // Because root expr need to transform separately + // unnest struct is only possible here + // The transformation looks like + // - unnest(struct_col) will be transformed into unnest(struct_col).field1, unnest(struct_col).field2 + if let Expr::Unnest(Unnest { expr: ref arg }) = transformed_expr { + return transform(&transformed_expr, arg); + } + if matches!(&transformed_expr, Expr::Column(_)) { inner_projection_exprs.push(transformed_expr.clone()); Ok(vec![transformed_expr]) @@ -351,12 +373,13 @@ mod tests { use arrow_schema::Fields; use datafusion_common::{DFSchema, Result}; use datafusion_expr::{col, lit, unnest, EmptyRelation, LogicalPlan}; + use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::expr_fn::count; - use crate::utils::{recursive_transform_unnest, resolve_positions_to_exprs}; + use crate::utils::{resolve_positions_to_exprs, transform_bottom_unnest}; #[test] - fn test_recursive_transform_unnest() -> Result<()> { + fn test_transform_bottom_unnest() -> Result<()> { let schema = Schema::new(vec![ Field::new( "struct_col", @@ -390,11 +413,11 @@ mod tests { // unnest(struct_col) let original_expr = unnest(col("struct_col")); - let transformed_exprs = recursive_transform_unnest( + let transformed_exprs = transform_bottom_unnest( &input, &mut unnest_placeholder_columns, &mut inner_projection_exprs, - original_expr, + &original_expr, )?; assert_eq!( transformed_exprs, @@ -413,11 +436,11 @@ mod tests { // unnest(array_col) + 1 let original_expr = unnest(col("array_col")).add(lit(1i64)); - let transformed_exprs = recursive_transform_unnest( + let transformed_exprs = transform_bottom_unnest( &input, &mut unnest_placeholder_columns, &mut inner_projection_exprs, - original_expr, + &original_expr, )?; assert_eq!( unnest_placeholder_columns, @@ -440,6 +463,62 @@ mod tests { ] ); + // a nested structure struct[[]] + let schema = Schema::new(vec![ + Field::new( + "struct_col", // {array_col: [1,2,3]} + ArrowDataType::Struct(Fields::from(vec![Field::new( + "matrix", + ArrowDataType::List(Arc::new(Field::new( + "matrix_row", + ArrowDataType::List(Arc::new(Field::new( + "item", + ArrowDataType::Int64, + true, + ))), + true, + ))), + true, + )])), + false, + ), + Field::new("int_col", ArrowDataType::Int32, false), + ]); + + let dfschema = DFSchema::try_from(schema)?; + + let input = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(dfschema), + }); + + let mut unnest_placeholder_columns = vec![]; + let mut inner_projection_exprs = vec![]; + + // An expr with multiple unnest + let original_expr = unnest(unnest(col("struct_col").field("matrix"))); + let transformed_exprs = transform_bottom_unnest( + &input, + &mut unnest_placeholder_columns, + &mut inner_projection_exprs, + &original_expr, + )?; + // Only the inner most/ bottom most unnest is transformed + assert_eq!( + transformed_exprs, + vec![unnest(col("unnest(struct_col[matrix])"))] + ); + assert_eq!( + unnest_placeholder_columns, + vec!["unnest(struct_col[matrix])"] + ); + assert_eq!( + inner_projection_exprs, + vec![col("struct_col") + .field("matrix") + .alias("unnest(struct_col[matrix])"),] + ); + Ok(()) } diff --git 
a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index 11ad8e0bb843..06733f7b1e40 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -37,6 +37,13 @@ AS VALUES (struct('d', 'e', struct('f')), (struct('x', 'y', [30,40, 50])), null) ; +statement ok +CREATE TABLE recursive_unnest_table +AS VALUES + (struct([1], 'a'), [[[1],[2]],[[1,1]]], [struct([1],[[1,2]])]), + (struct([2], 'b'), [[[3,4],[5]],[[null,6],null,[7,8]]], [struct([2],[[3],[4]])]) +; + ## Basic unnest expression in select list query I select unnest([1,2,3]); @@ -158,6 +165,37 @@ select unnest(column1), column1 from unnest_table; 6 [6] 12 [12] +## unnest as children of other expr +query I? +select unnest(column1) + 1 , column1 from unnest_table; +---- +2 [1, 2, 3] +3 [1, 2, 3] +4 [1, 2, 3] +5 [4, 5] +6 [4, 5] +7 [6] +13 [12] + +## unnest on multiple columns +query II +select unnest(column1), unnest(column2) from unnest_table; +---- +1 7 +2 NULL +3 NULL +4 8 +5 9 +NULL 10 +6 11 +NULL 12 +12 NULL +NULL 42 +NULL NULL + +query error DataFusion error: Error during planning: unnest\(\) can only be applied to array, struct and null +select unnest('foo'); + query ?II select array_remove(column1, 4), unnest(column2), column3 * 10 from unnest_table; ---- @@ -458,5 +496,64 @@ select unnest(column1) from (select * from (values([1,2,3]), ([4,5,6])) limit 1 5 6 +## FIXME: https://github.com/apache/datafusion/issues/11198 +query error DataFusion error: Error during planning: Projections require unique expression names but the expression "UNNEST\(Column\(Column \{ relation: Some\(Bare \{ table: "unnest_table" \}\), name: "column1" \}\)\)" at position 0 and "UNNEST\(Column\(Column \{ relation: Some\(Bare \{ table: "unnest_table" \}\), name: "column1" \}\)\)" at position 1 have the same name. Consider aliasing \("AS"\) one of them. +select unnest(column1), unnest(column1) from unnest_table; + statement ok drop table unnest_table; + +## unnest list followed by unnest struct +query ??? +select unnest(unnest(column3)), column3 from recursive_unnest_table; +---- +[1] [[1, 2]] [{c0: [1], c1: [[1, 2]]}] +[2] [[3], [4]] [{c0: [2], c1: [[3], [4]]}] + +## unnest->field_access->unnest->unnest +query I? +select unnest(unnest(unnest(column3)['c1'])), column3 from recursive_unnest_table; +---- +1 [{c0: [1], c1: [[1, 2]]}] +2 [{c0: [1], c1: [[1, 2]]}] +3 [{c0: [2], c1: [[3], [4]]}] +4 [{c0: [2], c1: [[3], [4]]}] + +## tripple list unnest +query I? 
+select unnest(unnest(unnest(column2))), column2 from recursive_unnest_table; +---- +1 [[[1], [2]], [[1, 1]]] +2 [[[1], [2]], [[1, 1]]] +1 [[[1], [2]], [[1, 1]]] +1 [[[1], [2]], [[1, 1]]] +3 [[[3, 4], [5]], [[, 6], , [7, 8]]] +4 [[[3, 4], [5]], [[, 6], , [7, 8]]] +5 [[[3, 4], [5]], [[, 6], , [7, 8]]] +NULL [[[3, 4], [5]], [[, 6], , [7, 8]]] +6 [[[3, 4], [5]], [[, 6], , [7, 8]]] +7 [[[3, 4], [5]], [[, 6], , [7, 8]]] +8 [[[3, 4], [5]], [[, 6], , [7, 8]]] + + + +query TT +explain select unnest(unnest(unnest(column3)['c1'])), column3 from recursive_unnest_table; +---- +logical_plan +01)Unnest: lists[unnest(unnest(unnest(recursive_unnest_table.column3)[c1]))] structs[] +02)--Projection: unnest(unnest(recursive_unnest_table.column3)[c1]) AS unnest(unnest(unnest(recursive_unnest_table.column3)[c1])), recursive_unnest_table.column3 +03)----Unnest: lists[unnest(unnest(recursive_unnest_table.column3)[c1])] structs[] +04)------Projection: get_field(unnest(recursive_unnest_table.column3), Utf8("c1")) AS unnest(unnest(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3 +05)--------Unnest: lists[unnest(recursive_unnest_table.column3)] structs[] +06)----------Projection: recursive_unnest_table.column3 AS unnest(recursive_unnest_table.column3), recursive_unnest_table.column3 +07)------------TableScan: recursive_unnest_table projection=[column3] +physical_plan +01)UnnestExec +02)--ProjectionExec: expr=[unnest(unnest(recursive_unnest_table.column3)[c1])@0 as unnest(unnest(unnest(recursive_unnest_table.column3)[c1])), column3@1 as column3] +03)----UnnestExec +04)------ProjectionExec: expr=[get_field(unnest(recursive_unnest_table.column3)@0, c1) as unnest(unnest(recursive_unnest_table.column3)[c1]), column3@1 as column3] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)----------UnnestExec +07)------------ProjectionExec: expr=[column3@0 as unnest(recursive_unnest_table.column3), column3@0 as column3] +08)--------------MemoryExec: partitions=1, partition_sizes=[1] \ No newline at end of file From 1840ab5331391d278de55d2bfb9a019c93314427 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 2 Jul 2024 12:22:02 -0400 Subject: [PATCH 5/7] Document how to test examples in user guide, add some more coverage (#11178) * Document testing examples, add some more coverage * add docs --- datafusion/core/src/lib.rs | 41 +++++++++++++++++++++++- docs/source/contributor-guide/testing.md | 25 +++++++++++++++ docs/source/user-guide/expressions.md | 4 ++- 3 files changed, 68 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index d81efaf68ca3..5f0af2d0adb8 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -587,8 +587,47 @@ pub mod test_util; #[cfg(doctest)] doc_comment::doctest!("../../../README.md", readme_example_test); +// Instructions for Documentation Examples +// +// The following commands test the examples from the user guide as part of +// `cargo test --doc` +// +// # Adding new tests: +// +// Simply add code like this to your .md file and ensure your md file is +// included in the lists below. +// +// ```rust +// +// ``` +// +// Note that sometimes it helps to author the doctest as a standalone program +// first, and then copy it into the user guide. +// +// # Debugging Test Failures +// +// Unfortunately, the line numbers reported by doctest do not correspond to the +// line numbers of in the .md files. 
Thus, if a doctest fails, use the name of +// the test to find the relevant file in the list below, and then find the +// example in that file to fix. +// +// For example, if `user_guide_expressions(line 123)` fails, +// go to `docs/source/user-guide/expressions.md` to find the relevant problem. + #[cfg(doctest)] doc_comment::doctest!( "../../../docs/source/user-guide/example-usage.md", - user_guid_example_tests + user_guide_example_usage +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/configs.md", + user_guide_configs +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/expressions.md", + user_guide_expressions ); diff --git a/docs/source/contributor-guide/testing.md b/docs/source/contributor-guide/testing.md index 11f53bcb2a2d..018cc6233c46 100644 --- a/docs/source/contributor-guide/testing.md +++ b/docs/source/contributor-guide/testing.md @@ -49,6 +49,31 @@ You can run these tests individually using `cargo` as normal command such as cargo test -p datafusion --test parquet_exec ``` +## Documentation Examples + +We use Rust [doctest] to verify examples from the documentation are correct and +up-to-date. These tests are run as part of our CI and you can run them them +locally with the following command: + +```shell +cargo test --doc +``` + +### API Documentation Examples + +As with other Rust projects, examples in doc comments in `.rs` files are +automatically checked to ensure they work and evolve along with the code. + +### User Guide Documentation + +Rust example code from the user guide (anything marked with \`\`\`rust) is also +tested in the same way using the [doc_comment] crate. See the end of +[core/src/lib.rs] for more details. + +[doctest]: https://doc.rust-lang.org/rust-by-example/testing/doc_testing.html +[doc_comment]: https://docs.rs/doc-comment/latest/doc_comment +[core/src/lib.rs]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/lib.rs#L583 + ## Benchmarks ### Criterion Benchmarks diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index cae9627210e5..6e693a0e7087 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -26,8 +26,10 @@ available for creating logical expressions. These are documented below. 
Most functions and methods may receive and return an `Expr`, which can be chained together using a fluent-style API: ```rust +use datafusion::prelude::*; // create the expression `(a > 6) AND (b < 7)` -col("a").gt(lit(6)).and(col("b").lt(lit(7))) +col("a").gt(lit(6)).and(col("b").lt(lit(7))); + ``` ::: From a4796fa07893cc4db6b6c0f92dab85d2720a52c4 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 2 Jul 2024 12:39:05 -0400 Subject: [PATCH 6/7] Minor: Move MemoryCatalog*Provider into a module, improve comments (#11183) * Minor: Move MemoryCatalog*Provider into a module, improve comments * fix docs --- .../core/src/catalog/information_schema.rs | 2 +- datafusion/core/src/catalog/listing_schema.rs | 2 +- datafusion/core/src/catalog/memory.rs | 352 ++++++++++++++++++ datafusion/core/src/catalog/mod.rs | 193 +--------- datafusion/core/src/catalog/schema.rs | 155 +------- datafusion/core/src/lib.rs | 2 +- 6 files changed, 375 insertions(+), 331 deletions(-) create mode 100644 datafusion/core/src/catalog/memory.rs diff --git a/datafusion/core/src/catalog/information_schema.rs b/datafusion/core/src/catalog/information_schema.rs index a9d4590a5e28..c953de6d16d3 100644 --- a/datafusion/core/src/catalog/information_schema.rs +++ b/datafusion/core/src/catalog/information_schema.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Implements the SQL [Information Schema] for DataFusion. +//! [`InformationSchemaProvider`] that implements the SQL [Information Schema] for DataFusion. //! //! [Information Schema]: https://en.wikipedia.org/wiki/Information_schema diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index 29f3e4ad8181..373fe788c721 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! listing_schema contains a SchemaProvider that scans ObjectStores for tables automatically +//! [`ListingSchemaProvider`]: [`SchemaProvider`] that scans ObjectStores for tables automatically use std::any::Any; use std::collections::{HashMap, HashSet}; diff --git a/datafusion/core/src/catalog/memory.rs b/datafusion/core/src/catalog/memory.rs new file mode 100644 index 000000000000..3af823913a29 --- /dev/null +++ b/datafusion/core/src/catalog/memory.rs @@ -0,0 +1,352 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`MemoryCatalogProvider`], [`MemoryCatalogProviderList`]: In-memory +//! implementations of [`CatalogProviderList`] and [`CatalogProvider`]. 
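As a usage sketch of the types collected in this new module (assuming the re-exports added to `catalog/mod.rs` later in this patch; the catalog and schema names are illustrative), wiring the in-memory providers into a `SessionContext` looks roughly like a trimmed-down version of the `test_schema_register_listing_table` test further down:

```rust
use std::sync::Arc;

use datafusion::catalog::{
    CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
};
use datafusion::prelude::SessionContext;

fn main() -> datafusion::error::Result<()> {
    // Build the hierarchy bottom-up: tables live in schemas, schemas in catalogs.
    let schema: Arc<dyn SchemaProvider> = Arc::new(MemorySchemaProvider::new());
    // A TableProvider (e.g. the ListingTable used in the test below) would be
    // added here with `schema.register_table(..)`.

    let catalog = MemoryCatalogProvider::new();
    catalog.register_schema("my_schema", schema)?;

    // Make the catalog queryable as `my_catalog.my_schema.<table>`.
    let ctx = SessionContext::new();
    ctx.register_catalog("my_catalog", Arc::new(catalog));
    Ok(())
}
```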
+ +use crate::catalog::schema::SchemaProvider; +use crate::catalog::{CatalogProvider, CatalogProviderList}; +use crate::datasource::TableProvider; +use async_trait::async_trait; +use dashmap::DashMap; +use datafusion_common::{exec_err, DataFusionError}; +use std::any::Any; +use std::sync::Arc; + +/// Simple in-memory list of catalogs +pub struct MemoryCatalogProviderList { + /// Collection of catalogs containing schemas and ultimately TableProviders + pub catalogs: DashMap>, +} + +impl MemoryCatalogProviderList { + /// Instantiates a new `MemoryCatalogProviderList` with an empty collection of catalogs + pub fn new() -> Self { + Self { + catalogs: DashMap::new(), + } + } +} + +impl Default for MemoryCatalogProviderList { + fn default() -> Self { + Self::new() + } +} + +impl CatalogProviderList for MemoryCatalogProviderList { + fn as_any(&self) -> &dyn Any { + self + } + + fn register_catalog( + &self, + name: String, + catalog: Arc, + ) -> Option> { + self.catalogs.insert(name, catalog) + } + + fn catalog_names(&self) -> Vec { + self.catalogs.iter().map(|c| c.key().clone()).collect() + } + + fn catalog(&self, name: &str) -> Option> { + self.catalogs.get(name).map(|c| c.value().clone()) + } +} + +/// Simple in-memory implementation of a catalog. +pub struct MemoryCatalogProvider { + schemas: DashMap>, +} + +impl MemoryCatalogProvider { + /// Instantiates a new MemoryCatalogProvider with an empty collection of schemas. + pub fn new() -> Self { + Self { + schemas: DashMap::new(), + } + } +} + +impl Default for MemoryCatalogProvider { + fn default() -> Self { + Self::new() + } +} + +impl CatalogProvider for MemoryCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + self.schemas.iter().map(|s| s.key().clone()).collect() + } + + fn schema(&self, name: &str) -> Option> { + self.schemas.get(name).map(|s| s.value().clone()) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> datafusion_common::Result>> { + Ok(self.schemas.insert(name.into(), schema)) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> datafusion_common::Result>> { + if let Some(schema) = self.schema(name) { + let table_names = schema.table_names(); + match (table_names.is_empty(), cascade) { + (true, _) | (false, true) => { + let (_, removed) = self.schemas.remove(name).unwrap(); + Ok(Some(removed)) + } + (false, false) => exec_err!( + "Cannot drop schema {} because other tables depend on it: {}", + name, + itertools::join(table_names.iter(), ", ") + ), + } + } else { + Ok(None) + } + } +} + +/// Simple in-memory implementation of a schema. +pub struct MemorySchemaProvider { + tables: DashMap>, +} + +impl MemorySchemaProvider { + /// Instantiates a new MemorySchemaProvider with an empty collection of tables. 
+ pub fn new() -> Self { + Self { + tables: DashMap::new(), + } + } +} + +impl Default for MemorySchemaProvider { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl SchemaProvider for MemorySchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + self.tables + .iter() + .map(|table| table.key().clone()) + .collect() + } + + async fn table( + &self, + name: &str, + ) -> datafusion_common::Result>, DataFusionError> { + Ok(self.tables.get(name).map(|table| table.value().clone())) + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> datafusion_common::Result>> { + if self.table_exist(name.as_str()) { + return exec_err!("The table {name} already exists"); + } + Ok(self.tables.insert(name, table)) + } + + fn deregister_table( + &self, + name: &str, + ) -> datafusion_common::Result>> { + Ok(self.tables.remove(name).map(|(_, table)| table)) + } + + fn table_exist(&self, name: &str) -> bool { + self.tables.contains_key(name) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::catalog::schema::{MemorySchemaProvider, SchemaProvider}; + use crate::catalog::CatalogProvider; + use crate::datasource::empty::EmptyTable; + use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; + use crate::datasource::TableProvider; + use crate::prelude::SessionContext; + use arrow_schema::Schema; + use datafusion_common::assert_batches_eq; + use std::any::Any; + use std::sync::Arc; + + #[test] + fn memory_catalog_dereg_nonempty_schema() { + let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; + + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + let test_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty()))) + as Arc; + schema.register_table("t".into(), test_table).unwrap(); + + cat.register_schema("foo", schema.clone()).unwrap(); + + assert!( + cat.deregister_schema("foo", false).is_err(), + "dropping empty schema without cascade should error" + ); + assert!(cat.deregister_schema("foo", true).unwrap().is_some()); + } + + #[test] + fn memory_catalog_dereg_empty_schema() { + let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; + + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + cat.register_schema("foo", schema).unwrap(); + + assert!(cat.deregister_schema("foo", false).unwrap().is_some()); + } + + #[test] + fn memory_catalog_dereg_missing() { + let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; + assert!(cat.deregister_schema("foo", false).unwrap().is_none()); + } + + #[test] + fn default_register_schema_not_supported() { + // mimic a new CatalogProvider and ensure it does not support registering schemas + struct TestProvider {} + impl CatalogProvider for TestProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + unimplemented!() + } + + fn schema(&self, _name: &str) -> Option> { + unimplemented!() + } + } + + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + let catalog = Arc::new(TestProvider {}); + + match catalog.register_schema("foo", schema) { + Ok(_) => panic!("unexpected OK"), + Err(e) => assert_eq!(e.strip_backtrace(), "This feature is not implemented: Registering new schemas is not supported"), + }; + } + + #[tokio::test] + async fn test_mem_provider() { + let provider = MemorySchemaProvider::new(); + let table_name = "test_table_exist"; + assert!(!provider.table_exist(table_name)); + assert!(provider.deregister_table(table_name).unwrap().is_none()); + let test_table = 
EmptyTable::new(Arc::new(Schema::empty())); + // register table successfully + assert!(provider + .register_table(table_name.to_string(), Arc::new(test_table)) + .unwrap() + .is_none()); + assert!(provider.table_exist(table_name)); + let other_table = EmptyTable::new(Arc::new(Schema::empty())); + let result = + provider.register_table(table_name.to_string(), Arc::new(other_table)); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_schema_register_listing_table() { + let testdata = crate::test_util::parquet_test_data(); + let testdir = if testdata.starts_with('/') { + format!("file://{testdata}") + } else { + format!("file:///{testdata}") + }; + let filename = if testdir.ends_with('/') { + format!("{}{}", testdir, "alltypes_plain.parquet") + } else { + format!("{}/{}", testdir, "alltypes_plain.parquet") + }; + + let table_path = ListingTableUrl::parse(filename).unwrap(); + + let catalog = MemoryCatalogProvider::new(); + let schema = MemorySchemaProvider::new(); + + let ctx = SessionContext::new(); + + let config = ListingTableConfig::new(table_path) + .infer(&ctx.state()) + .await + .unwrap(); + let table = ListingTable::try_new(config).unwrap(); + + schema + .register_table("alltypes_plain".to_string(), Arc::new(table)) + .unwrap(); + + catalog.register_schema("active", Arc::new(schema)).unwrap(); + ctx.register_catalog("cat", Arc::new(catalog)); + + let df = ctx + .sql("SELECT id, bool_col FROM cat.active.alltypes_plain") + .await + .unwrap(); + + let actual = df.collect().await.unwrap(); + + let expected = [ + "+----+----------+", + "| id | bool_col |", + "+----+----------+", + "| 4 | true |", + "| 5 | false |", + "| 6 | true |", + "| 7 | false |", + "| 2 | true |", + "| 3 | false |", + "| 0 | true |", + "| 1 | false |", + "+----+----------+", + ]; + assert_batches_eq!(expected, &actual); + } +} diff --git a/datafusion/core/src/catalog/mod.rs b/datafusion/core/src/catalog/mod.rs index 53b133339924..531adc4b210c 100644 --- a/datafusion/core/src/catalog/mod.rs +++ b/datafusion/core/src/catalog/mod.rs @@ -16,16 +16,30 @@ // under the License. //! Interfaces and default implementations of catalogs and schemas. +//! +//! Traits: +//! * [`CatalogProviderList`]: a collection of `CatalogProvider`s +//! * [`CatalogProvider`]: a collection of [`SchemaProvider`]s (sometimes called a "database" in other systems) +//! * [`SchemaProvider`]: a collection of `TableProvider`s (often called a "schema" in other systems) +//! +//! Implementations +//! * Simple memory based catalog: [`MemoryCatalogProviderList`], [`MemoryCatalogProvider`], [`MemorySchemaProvider`] +//! * Information schema: [`information_schema`] +//! 
* Listing schema: [`listing_schema`] pub mod information_schema; pub mod listing_schema; +mod memory; pub mod schema; +pub use memory::{ + MemoryCatalogProvider, MemoryCatalogProviderList, MemorySchemaProvider, +}; +pub use schema::SchemaProvider; + pub use datafusion_sql::{ResolvedTableReference, TableReference}; -use crate::catalog::schema::SchemaProvider; -use dashmap::DashMap; -use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_common::{not_impl_err, Result}; use std::any::Any; use std::collections::BTreeSet; use std::ops::ControlFlow; @@ -59,49 +73,6 @@ pub trait CatalogProviderList: Sync + Send { #[deprecated(since = "35.0.0", note = "use [`CatalogProviderList`] instead")] pub trait CatalogList: CatalogProviderList {} -/// Simple in-memory list of catalogs -pub struct MemoryCatalogProviderList { - /// Collection of catalogs containing schemas and ultimately TableProviders - pub catalogs: DashMap>, -} - -impl MemoryCatalogProviderList { - /// Instantiates a new `MemoryCatalogProviderList` with an empty collection of catalogs - pub fn new() -> Self { - Self { - catalogs: DashMap::new(), - } - } -} - -impl Default for MemoryCatalogProviderList { - fn default() -> Self { - Self::new() - } -} - -impl CatalogProviderList for MemoryCatalogProviderList { - fn as_any(&self) -> &dyn Any { - self - } - - fn register_catalog( - &self, - name: String, - catalog: Arc, - ) -> Option> { - self.catalogs.insert(name, catalog) - } - - fn catalog_names(&self) -> Vec { - self.catalogs.iter().map(|c| c.key().clone()).collect() - } - - fn catalog(&self, name: &str) -> Option> { - self.catalogs.get(name).map(|c| c.value().clone()) - } -} - /// Represents a catalog, comprising a number of named schemas. /// /// # Catalog Overview @@ -232,71 +203,6 @@ pub trait CatalogProvider: Sync + Send { } } -/// Simple in-memory implementation of a catalog. -pub struct MemoryCatalogProvider { - schemas: DashMap>, -} - -impl MemoryCatalogProvider { - /// Instantiates a new MemoryCatalogProvider with an empty collection of schemas. - pub fn new() -> Self { - Self { - schemas: DashMap::new(), - } - } -} - -impl Default for MemoryCatalogProvider { - fn default() -> Self { - Self::new() - } -} - -impl CatalogProvider for MemoryCatalogProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema_names(&self) -> Vec { - self.schemas.iter().map(|s| s.key().clone()).collect() - } - - fn schema(&self, name: &str) -> Option> { - self.schemas.get(name).map(|s| s.value().clone()) - } - - fn register_schema( - &self, - name: &str, - schema: Arc, - ) -> Result>> { - Ok(self.schemas.insert(name.into(), schema)) - } - - fn deregister_schema( - &self, - name: &str, - cascade: bool, - ) -> Result>> { - if let Some(schema) = self.schema(name) { - let table_names = schema.table_names(); - match (table_names.is_empty(), cascade) { - (true, _) | (false, true) => { - let (_, removed) = self.schemas.remove(name).unwrap(); - Ok(Some(removed)) - } - (false, false) => exec_err!( - "Cannot drop schema {} because other tables depend on it: {}", - name, - itertools::join(table_names.iter(), ", ") - ), - } - } else { - Ok(None) - } - } -} - /// Collects all tables and views referenced in the SQL statement. CTEs are collected separately. /// This can be used to determine which tables need to be in the catalog for a query to be planned. 
/// @@ -476,71 +382,6 @@ pub fn resolve_table_references( #[cfg(test)] mod tests { use super::*; - use crate::catalog::schema::MemorySchemaProvider; - use crate::datasource::empty::EmptyTable; - use crate::datasource::TableProvider; - use arrow::datatypes::Schema; - - #[test] - fn default_register_schema_not_supported() { - // mimic a new CatalogProvider and ensure it does not support registering schemas - struct TestProvider {} - impl CatalogProvider for TestProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema_names(&self) -> Vec { - unimplemented!() - } - - fn schema(&self, _name: &str) -> Option> { - unimplemented!() - } - } - - let schema = Arc::new(MemorySchemaProvider::new()) as Arc; - let catalog = Arc::new(TestProvider {}); - - match catalog.register_schema("foo", schema) { - Ok(_) => panic!("unexpected OK"), - Err(e) => assert_eq!(e.strip_backtrace(), "This feature is not implemented: Registering new schemas is not supported"), - }; - } - - #[test] - fn memory_catalog_dereg_nonempty_schema() { - let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; - - let schema = Arc::new(MemorySchemaProvider::new()) as Arc; - let test_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty()))) - as Arc; - schema.register_table("t".into(), test_table).unwrap(); - - cat.register_schema("foo", schema.clone()).unwrap(); - - assert!( - cat.deregister_schema("foo", false).is_err(), - "dropping empty schema without cascade should error" - ); - assert!(cat.deregister_schema("foo", true).unwrap().is_some()); - } - - #[test] - fn memory_catalog_dereg_empty_schema() { - let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; - - let schema = Arc::new(MemorySchemaProvider::new()) as Arc; - cat.register_schema("foo", schema).unwrap(); - - assert!(cat.deregister_schema("foo", false).unwrap().is_some()); - } - - #[test] - fn memory_catalog_dereg_missing() { - let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; - assert!(cat.deregister_schema("foo", false).unwrap().is_none()); - } #[test] fn resolve_table_references_shadowed_cte() { diff --git a/datafusion/core/src/catalog/schema.rs b/datafusion/core/src/catalog/schema.rs index 8249c3a5330f..7d76b3fa4f19 100644 --- a/datafusion/core/src/catalog/schema.rs +++ b/datafusion/core/src/catalog/schema.rs @@ -19,7 +19,6 @@ //! representing collections of named tables. use async_trait::async_trait; -use dashmap::DashMap; use datafusion_common::{exec_err, DataFusionError}; use std::any::Any; use std::sync::Arc; @@ -27,6 +26,9 @@ use std::sync::Arc; use crate::datasource::TableProvider; use crate::error::Result; +// backwards compatibility +pub use super::MemorySchemaProvider; + /// Represents a schema, comprising a number of named tables. /// /// Please see [`CatalogProvider`] for details of implementing a custom catalog. @@ -80,154 +82,3 @@ pub trait SchemaProvider: Sync + Send { /// Returns true if table exist in the schema provider, false otherwise. fn table_exist(&self, name: &str) -> bool; } - -/// Simple in-memory implementation of a schema. -pub struct MemorySchemaProvider { - tables: DashMap>, -} - -impl MemorySchemaProvider { - /// Instantiates a new MemorySchemaProvider with an empty collection of tables. 
- pub fn new() -> Self { - Self { - tables: DashMap::new(), - } - } -} - -impl Default for MemorySchemaProvider { - fn default() -> Self { - Self::new() - } -} - -#[async_trait] -impl SchemaProvider for MemorySchemaProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn table_names(&self) -> Vec { - self.tables - .iter() - .map(|table| table.key().clone()) - .collect() - } - - async fn table( - &self, - name: &str, - ) -> Result>, DataFusionError> { - Ok(self.tables.get(name).map(|table| table.value().clone())) - } - - fn register_table( - &self, - name: String, - table: Arc, - ) -> Result>> { - if self.table_exist(name.as_str()) { - return exec_err!("The table {name} already exists"); - } - Ok(self.tables.insert(name, table)) - } - - fn deregister_table(&self, name: &str) -> Result>> { - Ok(self.tables.remove(name).map(|(_, table)| table)) - } - - fn table_exist(&self, name: &str) -> bool { - self.tables.contains_key(name) - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::datatypes::Schema; - - use crate::assert_batches_eq; - use crate::catalog::schema::{MemorySchemaProvider, SchemaProvider}; - use crate::catalog::{CatalogProvider, MemoryCatalogProvider}; - use crate::datasource::empty::EmptyTable; - use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; - use crate::prelude::SessionContext; - - #[tokio::test] - async fn test_mem_provider() { - let provider = MemorySchemaProvider::new(); - let table_name = "test_table_exist"; - assert!(!provider.table_exist(table_name)); - assert!(provider.deregister_table(table_name).unwrap().is_none()); - let test_table = EmptyTable::new(Arc::new(Schema::empty())); - // register table successfully - assert!(provider - .register_table(table_name.to_string(), Arc::new(test_table)) - .unwrap() - .is_none()); - assert!(provider.table_exist(table_name)); - let other_table = EmptyTable::new(Arc::new(Schema::empty())); - let result = - provider.register_table(table_name.to_string(), Arc::new(other_table)); - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_schema_register_listing_table() { - let testdata = crate::test_util::parquet_test_data(); - let testdir = if testdata.starts_with('/') { - format!("file://{testdata}") - } else { - format!("file:///{testdata}") - }; - let filename = if testdir.ends_with('/') { - format!("{}{}", testdir, "alltypes_plain.parquet") - } else { - format!("{}/{}", testdir, "alltypes_plain.parquet") - }; - - let table_path = ListingTableUrl::parse(filename).unwrap(); - - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - - let ctx = SessionContext::new(); - - let config = ListingTableConfig::new(table_path) - .infer(&ctx.state()) - .await - .unwrap(); - let table = ListingTable::try_new(config).unwrap(); - - schema - .register_table("alltypes_plain".to_string(), Arc::new(table)) - .unwrap(); - - catalog.register_schema("active", Arc::new(schema)).unwrap(); - ctx.register_catalog("cat", Arc::new(catalog)); - - let df = ctx - .sql("SELECT id, bool_col FROM cat.active.alltypes_plain") - .await - .unwrap(); - - let actual = df.collect().await.unwrap(); - - let expected = [ - "+----+----------+", - "| id | bool_col |", - "+----+----------+", - "| 4 | true |", - "| 5 | false |", - "| 6 | true |", - "| 7 | false |", - "| 2 | true |", - "| 3 | false |", - "| 0 | true |", - "| 1 | false |", - "+----+----------+", - ]; - assert_batches_eq!(expected, &actual); - } -} diff --git a/datafusion/core/src/lib.rs 
b/datafusion/core/src/lib.rs index 5f0af2d0adb8..2810dca46365 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -182,7 +182,7 @@ //! In order to achieve this, DataFusion supports extension at many points: //! //! * read from any datasource ([`TableProvider`]) -//! * define your own catalogs, schemas, and table lists ([`CatalogProvider`]) +//! * define your own catalogs, schemas, and table lists ([`catalog`] and [`CatalogProvider`]) //! * build your own query language or plans ([`LogicalPlanBuilder`]) //! * declare and use user-defined functions ([`ScalarUDF`], and [`AggregateUDF`], [`WindowUDF`]) //! * add custom plan rewrite passes ([`AnalyzerRule`], [`OptimizerRule`] and [`PhysicalOptimizerRule`]) From 4f4cd81de72a858896ac37a51b0e354cb379307c Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 2 Jul 2024 20:55:30 +0200 Subject: [PATCH 7/7] Fix docs wordings (#11226) * Fix "part of Arrow" wording in docs DataFusion is a top level project. * Fix doc reference in into_optimized_plan --- datafusion/core/src/dataframe/mod.rs | 2 +- docs/source/user-guide/introduction.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 8e55da8c3ad0..d0f2852a6e53 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1157,7 +1157,7 @@ impl DataFrame { /// Return the optimized [`LogicalPlan`] represented by this DataFrame. /// /// Note: This method should not be used outside testing -- see - /// [`Self::into_optimized_plan`] for more details. + /// [`Self::into_unoptimized_plan`] for more details. pub fn into_optimized_plan(self) -> Result { // Optimize the plan first for better UX self.session_state.optimize(&self.plan) diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 5a458146c1c0..3a39419236d8 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -22,7 +22,7 @@ DataFusion is a very fast, extensible query engine for building high-quality data-centric systems in [Rust](http://rustlang.org), using the [Apache Arrow](https://arrow.apache.org) in-memory format. -DataFusion is part of the [Apache Arrow](https://arrow.apache.org/) +DataFusion originated as part of the [Apache Arrow](https://arrow.apache.org/) project. DataFusion offers SQL and Dataframe APIs, excellent [performance](https://benchmark.clickhouse.com/), built-in support for CSV, Parquet, JSON, and Avro, [python bindings], extensive customization, a great community, and more.
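As a closing illustration of the cross-reference fixed above, a minimal sketch of the two conversions the `DataFrame` doc comments distinguish (the query and binding names here are illustrative, not taken from the patch):

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let df = ctx.sql("SELECT 1 AS x").await?;

    // The plan exactly as written, before any optimizer passes run;
    // `into_optimized_plan`'s docs now point readers here for details.
    let unoptimized = df.clone().into_unoptimized_plan();

    // The same plan after the session optimizer has rewritten it; per the
    // doc comment this is intended for testing rather than production use.
    let optimized = df.into_optimized_plan()?;

    println!(
        "{}\n---\n{}",
        unoptimized.display_indent(),
        optimized.display_indent()
    );
    Ok(())
}
```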