Skip to content

Commit

Permalink
Window UDF signature check (apache#12045)
Browse files Browse the repository at this point in the history
* udwf sig

Signed-off-by: jayzhan211 <[email protected]>

* add coerce_types

Signed-off-by: jayzhan211 <[email protected]>

* add doc

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Aug 18, 2024
1 parent 950dc73 commit a91be04
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 9 deletions.
17 changes: 16 additions & 1 deletion datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::expr::{
};
use crate::type_coercion::binary::get_result_type;
use crate::type_coercion::functions::{
data_types_with_aggregate_udf, data_types_with_scalar_udf,
data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
};
use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
use arrow::compute::can_cast_types;
Expand Down Expand Up @@ -191,6 +191,21 @@ impl ExprSchemable for Expr {
})?;
Ok(fun.return_type(&new_types, &nullability)?)
}
WindowFunctionDefinition::WindowUDF(udwf) => {
let new_types = data_types_with_window_udf(&data_types, udwf)
.map_err(|err| {
plan_datafusion_err!(
"{} {}",
err,
utils::generate_signature_error_msg(
fun.name(),
fun.signature().clone(),
&data_types
)
)
})?;
Ok(fun.return_type(&new_types, &nullability)?)
}
_ => fun.return_type(&data_types, &nullability),
}
}
Expand Down
74 changes: 67 additions & 7 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,21 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature};
use super::binary::{binary_numeric_coercion, comparison_coercion};
use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err, Result,
exec_err, internal_datafusion_err, internal_err, plan_err,
utils::{coerced_fixed_size_list_to_list, list_ndims},
Result,
};
use datafusion_expr_common::signature::{
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
};

use super::binary::{binary_numeric_coercion, comparison_coercion};
use std::sync::Arc;

/// Performs type coercion for scalar function arguments.
///
Expand Down Expand Up @@ -66,6 +65,13 @@ pub fn data_types_with_scalar_udf(
try_coerce_types(valid_types, current_types, &signature.type_signature)
}

/// Performs type coercion for aggregate function arguments.
///
/// Returns the data types to which each argument must be coerced to
/// match `signature`.
///
/// For more details on coercion in general, please see the
/// [`type_coercion`](crate::type_coercion) module.
pub fn data_types_with_aggregate_udf(
current_types: &[DataType],
func: &AggregateUDF,
Expand Down Expand Up @@ -95,6 +101,39 @@ pub fn data_types_with_aggregate_udf(
try_coerce_types(valid_types, current_types, &signature.type_signature)
}

/// Performs type coercion for window function arguments.
///
/// Returns the data types to which each argument must be coerced to
/// match `signature`.
///
/// For more details on coercion in general, please see the
/// [`type_coercion`](crate::type_coercion) module.
pub fn data_types_with_window_udf(
current_types: &[DataType],
func: &WindowUDF,
) -> Result<Vec<DataType>> {
let signature = func.signature();

if current_types.is_empty() {
if signature.type_signature.supports_zero_argument() {
return Ok(vec![]);
} else {
return plan_err!("{} does not support zero arguments.", func.name());
}
}

let valid_types =
get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?;
if valid_types
.iter()
.any(|data_type| data_type == current_types)
{
return Ok(current_types.to_vec());
}

try_coerce_types(valid_types, current_types, &signature.type_signature)
}

/// Performs type coercion for function arguments.
///
/// Returns the data types to which each argument must be coerced to
Expand Down Expand Up @@ -205,6 +244,27 @@ fn get_valid_types_with_aggregate_udf(
Ok(valid_types)
}

fn get_valid_types_with_window_udf(
signature: &TypeSignature,
current_types: &[DataType],
func: &WindowUDF,
) -> Result<Vec<Vec<DataType>>> {
let valid_types = match signature {
TypeSignature::UserDefined => match func.coerce_types(current_types) {
Ok(coerced_types) => vec![coerced_types],
Err(e) => return exec_err!("User-defined coercion failed with {:?}", e),
},
TypeSignature::OneOf(signatures) => signatures
.iter()
.filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok())
.flatten()
.collect::<Vec<_>>(),
_ => get_valid_types(signature, current_types)?,
};

Ok(valid_types)
}

/// Returns a Vec of all possible valid argument types for the given signature.
fn get_valid_types(
signature: &TypeSignature,
Expand Down
30 changes: 29 additions & 1 deletion datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use std::{

use arrow::datatypes::DataType;

use datafusion_common::Result;
use datafusion_common::{not_impl_err, Result};

use crate::expr::WindowFunction;
use crate::{
Expand Down Expand Up @@ -192,6 +192,11 @@ impl WindowUDF {
pub fn sort_options(&self) -> Option<SortOptions> {
self.inner.sort_options()
}

/// See [`WindowUDFImpl::coerce_types`] for more details.
pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
}

impl<F> From<F> for WindowUDF
Expand Down Expand Up @@ -353,6 +358,29 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
fn sort_options(&self) -> Option<SortOptions> {
None
}

/// Coerce arguments of a function call to types that the function can evaluate.
///
/// This function is only called if [`WindowUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
/// UDWFs should return one of the other variants of `TypeSignature` which handle common
/// cases
///
/// See the [type coercion module](crate::type_coercion)
/// documentation for more details on type coercion
///
/// For example, if your function requires a floating point arguments, but the user calls
/// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
/// to ensure the argument was cast to `1::double`
///
/// # Parameters
/// * `arg_types`: The argument types of the arguments this function with
///
/// # Return value
/// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
/// arguments to these specific types.
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
not_impl_err!("Function {} does not implement coerce_types", self.name())
}
}

/// WindowUDF that adds an alias to the underlying function. It is better to
Expand Down
18 changes: 18 additions & 0 deletions datafusion/sqllogictest/test_files/window.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4879,3 +4879,21 @@ SELECT lead(column2, 1.1) OVER (order by column1) FROM t;

query error DataFusion error: Execution error: Expected an integer value
SELECT nth_value(column2, 1.1) OVER (order by column1) FROM t;

statement ok
drop table t;

statement ok
create table t(a int, b int) as values (1, 2)

query II
select a, row_number() over (order by b) as rn from t;
----
1 1

# RowNumber expect 0 args.
query error
select a, row_number(a) over (order by b) as rn from t;

statement ok
drop table t;

0 comments on commit a91be04

Please sign in to comment.