From dee4261a0b9f8775b2eac6c61b0d5c2cd0a61e1a Mon Sep 17 00:00:00 2001 From: HuSen8891 Date: Tue, 8 Oct 2024 20:28:08 +0800 Subject: [PATCH] Fix: handle NULL input in lead/lag window function --- datafusion/physical-plan/src/windows/mod.rs | 40 ++++++++++++--- datafusion/sqllogictest/test_files/window.slt | 50 ++++++++++++++++++- 2 files changed, 81 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 6aafaad0ad779..9e25536c6598e 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -257,15 +257,27 @@ fn create_built_in_window_expr( } } BuiltInWindowFunction::Lag => { - let arg = Arc::clone(&args[0]); + let mut arg = Arc::clone(&args[0]); let shift_offset = get_scalar_value_from_args(args, 1)? .map(get_signed_integer) .map_or(Ok(None), |v| v.map(Some))?; - let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; + // If value is NULL, we use default data type as output data type, no need to cast data type + let default_value = match out_data_type { + DataType::Null => match get_scalar_value_from_args(args, 2)? { + Some(value) => { + let null_value = ScalarValue::try_from(value.data_type())?; + arg = Arc::new(Literal::new(null_value)); + value + } + None => ScalarValue::try_from(DataType::Null)?, + }, + _ => { + get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)? + } + }; Arc::new(lag( name, - out_data_type.clone(), + default_value.data_type().clone(), arg, shift_offset, default_value, @@ -273,15 +285,27 @@ fn create_built_in_window_expr( )) } BuiltInWindowFunction::Lead => { - let arg = Arc::clone(&args[0]); + let mut arg = Arc::clone(&args[0]); let shift_offset = get_scalar_value_from_args(args, 1)? .map(get_signed_integer) .map_or(Ok(None), |v| v.map(Some))?; - let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; + // If value is NULL, we use default data type as output data type, no need to cast data type + let default_value = match out_data_type { + DataType::Null => match get_scalar_value_from_args(args, 2)? { + Some(value) => { + let null_value = ScalarValue::try_from(value.data_type())?; + arg = Arc::new(Literal::new(null_value)); + value + } + None => ScalarValue::try_from(DataType::Null)?, + }, + _ => { + get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)? + } + }; Arc::new(lead( name, - out_data_type.clone(), + default_value.data_type().clone(), arg, shift_offset, default_value, diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index cb6c6a5ace76e..79df1b6a02f52 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4932,4 +4932,52 @@ SELECT v1, NTH_VALUE(v2, 0) OVER (PARTITION BY v1 ORDER BY v2) FROM t; statement ok DROP TABLE t; -## end test handle NULL and 0 of NTH_VALUE \ No newline at end of file +## end test handle NULL and 0 of NTH_VALUE + +## test handle NULL of lead + +statement ok +create table t1(v1 int); + +statement ok +insert into t1 values (1); + +query B +SELECT LEAD(NULL, 0, false) OVER () FROM t1; +---- +NULL + +query B +SELECT LAG(NULL, 0, false) OVER () FROM t1; +---- +NULL + +query B +SELECT LEAD(NULL, 1, false) OVER () FROM t1; +---- +false + +query B +SELECT LAG(NULL, 1, false) OVER () FROM t1; +---- +false + +statement ok +insert into t1 values (2); + +query B +SELECT LEAD(NULL, 1, false) OVER () FROM t1; +---- +NULL +false + +query B +SELECT LAG(NULL, 1, false) OVER () FROM t1; +---- +false +NULL + +statement ok +DROP TABLE t1; + +## end test handle NULL of lead \ No newline at end of file