From a91be04ced3746c673788d5da124c6d30009d9ff Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 18 Aug 2024 20:47:21 +0800 Subject: [PATCH] Window UDF signature check (#12045) * udwf sig Signed-off-by: jayzhan211 * add coerce_types Signed-off-by: jayzhan211 * add doc Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/expr/src/expr_schema.rs | 17 ++++- .../expr/src/type_coercion/functions.rs | 74 +++++++++++++++++-- datafusion/expr/src/udwf.rs | 30 +++++++- datafusion/sqllogictest/test_files/window.slt | 18 +++++ 4 files changed, 130 insertions(+), 9 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index af35b9a9910d..f6489fef14a1 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -22,7 +22,7 @@ use crate::expr::{ }; use crate::type_coercion::binary::get_result_type; use crate::type_coercion::functions::{ - data_types_with_aggregate_udf, data_types_with_scalar_udf, + data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, }; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; @@ -191,6 +191,21 @@ impl ExprSchemable for Expr { })?; Ok(fun.return_type(&new_types, &nullability)?) } + WindowFunctionDefinition::WindowUDF(udwf) => { + let new_types = data_types_with_window_udf(&data_types, udwf) + .map_err(|err| { + plan_datafusion_err!( + "{} {}", + err, + utils::generate_signature_error_msg( + fun.name(), + fun.signature().clone(), + &data_types + ) + ) + })?; + Ok(fun.return_type(&new_types, &nullability)?) 
+ } _ => fun.return_type(&data_types, &nullability), } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 190374b01dd2..b0b14a1a4e6e 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -15,22 +15,21 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature}; +use super::binary::{binary_numeric_coercion, comparison_coercion}; +use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims}; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, plan_err, Result, + exec_err, internal_datafusion_err, internal_err, plan_err, + utils::{coerced_fixed_size_list_to_list, list_ndims}, + Result, }; use datafusion_expr_common::signature::{ ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, }; - -use super::binary::{binary_numeric_coercion, comparison_coercion}; +use std::sync::Arc; /// Performs type coercion for scalar function arguments. /// @@ -66,6 +65,13 @@ pub fn data_types_with_scalar_udf( try_coerce_types(valid_types, current_types, &signature.type_signature) } +/// Performs type coercion for aggregate function arguments. +/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. pub fn data_types_with_aggregate_udf( current_types: &[DataType], func: &AggregateUDF, @@ -95,6 +101,39 @@ pub fn data_types_with_aggregate_udf( try_coerce_types(valid_types, current_types, &signature.type_signature) } +/// Performs type coercion for window function arguments. 
+/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +pub fn data_types_with_window_udf( + current_types: &[DataType], + func: &WindowUDF, +) -> Result<Vec<DataType>> { + let signature = func.signature(); + + if current_types.is_empty() { + if signature.type_signature.supports_zero_argument() { + return Ok(vec![]); + } else { + return plan_err!("{} does not support zero arguments.", func.name()); + } + } + + let valid_types = + get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?; + if valid_types + .iter() + .any(|data_type| data_type == current_types) + { + return Ok(current_types.to_vec()); + } + + try_coerce_types(valid_types, current_types, &signature.type_signature) +} + /// Performs type coercion for function arguments. /// /// Returns the data types to which each argument must be coerced to @@ -205,6 +244,27 @@ fn get_valid_types_with_aggregate_udf( Ok(valid_types) } +fn get_valid_types_with_window_udf( + signature: &TypeSignature, + current_types: &[DataType], + func: &WindowUDF, +) -> Result<Vec<Vec<DataType>>> { + let valid_types = match signature { + TypeSignature::UserDefined => match func.coerce_types(current_types) { + Ok(coerced_types) => vec![coerced_types], + Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), + }, + TypeSignature::OneOf(signatures) => signatures + .iter() + .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok()) + .flatten() + .collect::<Vec<_>>(), + _ => get_valid_types(signature, current_types)?, + }; + + Ok(valid_types) +} + /// Returns a Vec of all possible valid argument types for the given signature. 
fn get_valid_types( signature: &TypeSignature, diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index aa754a57086f..88b3d613cb43 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -27,7 +27,7 @@ use std::{ use arrow::datatypes::DataType; -use datafusion_common::Result; +use datafusion_common::{not_impl_err, Result}; use crate::expr::WindowFunction; use crate::{ @@ -192,6 +192,11 @@ impl WindowUDF { pub fn sort_options(&self) -> Option<SortOptions> { self.inner.sort_options() } + + /// See [`WindowUDFImpl::coerce_types`] for more details. + pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { + self.inner.coerce_types(arg_types) + } } impl<F> From<F> for WindowUDF @@ -353,6 +358,29 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn sort_options(&self) -> Option<SortOptions> { None } + + /// Coerce arguments of a function call to types that the function can evaluate. + /// + /// This function is only called if [`WindowUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most + /// UDWFs should return one of the other variants of `TypeSignature` which handle common + /// cases. + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion. + /// + /// For example, if your function requires a floating point argument, but the user calls + /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]` + /// to ensure the argument was cast to `1::double`. + /// + /// # Parameters + /// * `arg_types`: The types of the arguments this function is called with + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. + fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> { + not_impl_err!("Function {} does not implement coerce_types", self.name()) + } } /// WindowUDF that adds an alias to the underlying function. 
It is better to diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 0bf7a8a1eb1b..ef6746730eb6 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4879,3 +4879,21 @@ SELECT lead(column2, 1.1) OVER (order by column1) FROM t; query error DataFusion error: Execution error: Expected an integer value SELECT nth_value(column2, 1.1) OVER (order by column1) FROM t; + +statement ok +drop table t; + +statement ok +create table t(a int, b int) as values (1, 2) + +query II +select a, row_number() over (order by b) as rn from t; +---- +1 1 + +# RowNumber expects 0 args. +query error +select a, row_number(a) over (order by b) as rn from t; + +statement ok +drop table t;