From 9179aa090ac5fca42964371430653d4d41149740 Mon Sep 17 00:00:00 2001 From: HuSen8891 Date: Tue, 8 Oct 2024 20:28:08 +0800 Subject: [PATCH] Fix: handle NULL input in lead/lag window function --- datafusion/physical-plan/src/windows/mod.rs | 51 ++++++++++++++++--- datafusion/sqllogictest/test_files/window.slt | 50 +++++++++++++++++- 2 files changed, 94 insertions(+), 7 deletions(-) diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 6f7d95bf95f5..e6a773f6b1ea 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -217,6 +217,41 @@ fn get_casted_value( } } +/// Rewrites the NULL expression (1st argument) with an expression +/// which is the same data type as the default value (3rd argument). +/// Also rewrites the return type with the same data type as the +/// default value. +/// +/// If a default value is not provided, or it is NULL the original +/// expression (1st argument) and return type is returned without +/// any modifications. +fn rewrite_null_expr_and_data_type( + args: &[Arc], + expr_type: &DataType, +) -> Result<(Arc, DataType)> { + assert!(!args.is_empty()); + let expr = Arc::clone(&args[0]); + + // The input expression and the return is type is unchanged + // when the input expression is not NULL. + if !expr_type.is_null() { + return Ok((expr, expr_type.clone())); + } + + get_scalar_value_from_args(args, 2)? + .and_then(|value| { + ScalarValue::try_from(value.data_type().clone()) + .map(|sv| { + Ok(( + Arc::new(Literal::new(sv)) as Arc, + value.data_type().clone(), + )) + }) + .ok() + }) + .unwrap_or(Ok((expr, expr_type.clone()))) +} + fn create_built_in_window_expr( fun: &BuiltInWindowFunction, args: &[Arc], @@ -252,15 +287,17 @@ fn create_built_in_window_expr( } } BuiltInWindowFunction::Lag => { - let arg = Arc::clone(&args[0]); + // rewrite NULL expression and the return datatype + let (arg, out_data_type) = + rewrite_null_expr_and_data_type(args, out_data_type)?; let shift_offset = get_scalar_value_from_args(args, 1)? .map(get_signed_integer) .map_or(Ok(None), |v| v.map(Some))?; let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; + get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?; Arc::new(lag( name, - out_data_type.clone(), + default_value.data_type().clone(), arg, shift_offset, default_value, @@ -268,15 +305,17 @@ fn create_built_in_window_expr( )) } BuiltInWindowFunction::Lead => { - let arg = Arc::clone(&args[0]); + // rewrite NULL expression and the return datatype + let (arg, out_data_type) = + rewrite_null_expr_and_data_type(args, out_data_type)?; let shift_offset = get_scalar_value_from_args(args, 1)? .map(get_signed_integer) .map_or(Ok(None), |v| v.map(Some))?; let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; + get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?; Arc::new(lead( name, - out_data_type.clone(), + default_value.data_type().clone(), arg, shift_offset, default_value, diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 79cb91e183db..1b612f921262 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4941,4 +4941,52 @@ NULL statement ok DROP TABLE t; -## end test handle NULL and 0 of NTH_VALUE \ No newline at end of file +## end test handle NULL and 0 of NTH_VALUE + +## test handle NULL of lead + +statement ok +create table t1(v1 int); + +statement ok +insert into t1 values (1); + +query B +SELECT LEAD(NULL, 0, false) OVER () FROM t1; +---- +NULL + +query B +SELECT LAG(NULL, 0, false) OVER () FROM t1; +---- +NULL + +query B +SELECT LEAD(NULL, 1, false) OVER () FROM t1; +---- +false + +query B +SELECT LAG(NULL, 1, false) OVER () FROM t1; +---- +false + +statement ok +insert into t1 values (2); + +query B +SELECT LEAD(NULL, 1, false) OVER () FROM t1; +---- +NULL +false + +query B +SELECT LAG(NULL, 1, false) OVER () FROM t1; +---- +false +NULL + +statement ok +DROP TABLE t1; + +## end test handle NULL of lead \ No newline at end of file