Skip to content

Commit

Permalink
Fix: handle NULL input in lead/lag window function (#12811)
Browse files Browse the repository at this point in the history
  • Loading branch information
HuSen8891 authored Oct 16, 2024
1 parent db85d07 commit 44127ec
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 7 deletions.
51 changes: 45 additions & 6 deletions datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,41 @@ fn get_casted_value(
}
}

/// Gives a NULL-typed first argument the data type of the default value.
///
/// Returns the expression to evaluate together with the return type to use:
///
/// * When `expr_type` is not the NULL type, the first argument and
///   `expr_type` are returned untouched.
/// * When `expr_type` is the NULL type and a default value (3rd argument)
///   is available, the first argument is replaced by a literal built from
///   `ScalarValue::try_from` on the default value's data type, and that
///   data type becomes the return type.
/// * In every other case (no default value, or no scalar could be built
///   for its data type) the first argument and `expr_type` are returned
///   untouched.
///
/// # Errors
///
/// Propagates any error produced while reading the 3rd argument.
///
/// # Panics
///
/// Panics if `args` is empty.
fn rewrite_null_expr_and_data_type(
    args: &[Arc<dyn PhysicalExpr>],
    expr_type: &DataType,
) -> Result<(Arc<dyn PhysicalExpr>, DataType)> {
    assert!(!args.is_empty());
    let original_expr = Arc::clone(&args[0]);

    // Only a NULL-typed input needs rewriting; anything else already
    // carries a usable type.
    if expr_type.is_null() {
        if let Some(default_value) = get_scalar_value_from_args(args, 2)? {
            let default_type: DataType = default_value.data_type().clone();
            // A failure to build the scalar is deliberately not an error:
            // we simply fall through to the unmodified expression below.
            if let Ok(typed_scalar) = ScalarValue::try_from(default_type.clone()) {
                let typed_expr: Arc<dyn PhysicalExpr> =
                    Arc::new(Literal::new(typed_scalar));
                return Ok((typed_expr, default_type));
            }
        }
    }

    Ok((original_expr, expr_type.clone()))
}

fn create_built_in_window_expr(
fun: &BuiltInWindowFunction,
args: &[Arc<dyn PhysicalExpr>],
Expand Down Expand Up @@ -252,31 +287,35 @@ fn create_built_in_window_expr(
}
}
BuiltInWindowFunction::Lag => {
let arg = Arc::clone(&args[0]);
// rewrite NULL expression and the return datatype
let (arg, out_data_type) =
rewrite_null_expr_and_data_type(args, out_data_type)?;
let shift_offset = get_scalar_value_from_args(args, 1)?
.map(get_signed_integer)
.map_or(Ok(None), |v| v.map(Some))?;
let default_value =
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?;
Arc::new(lag(
name,
out_data_type.clone(),
default_value.data_type().clone(),
arg,
shift_offset,
default_value,
ignore_nulls,
))
}
BuiltInWindowFunction::Lead => {
let arg = Arc::clone(&args[0]);
// rewrite NULL expression and the return datatype
let (arg, out_data_type) =
rewrite_null_expr_and_data_type(args, out_data_type)?;
let shift_offset = get_scalar_value_from_args(args, 1)?
.map(get_signed_integer)
.map_or(Ok(None), |v| v.map(Some))?;
let default_value =
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?;
Arc::new(lead(
name,
out_data_type.clone(),
default_value.data_type().clone(),
arg,
shift_offset,
default_value,
Expand Down
50 changes: 49 additions & 1 deletion datafusion/sqllogictest/test_files/window.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4941,4 +4941,52 @@ NULL
statement ok
DROP TABLE t;

## end test handle NULL and 0 of NTH_VALUE
## end test handle NULL and 0 of NTH_VALUE

## test handle NULL input of lead/lag

statement ok
create table t1(v1 int);

statement ok
insert into t1 values (1);

query B
SELECT LEAD(NULL, 0, false) OVER () FROM t1;
----
NULL

query B
SELECT LAG(NULL, 0, false) OVER () FROM t1;
----
NULL

query B
SELECT LEAD(NULL, 1, false) OVER () FROM t1;
----
false

query B
SELECT LAG(NULL, 1, false) OVER () FROM t1;
----
false

statement ok
insert into t1 values (2);

query B
SELECT LEAD(NULL, 1, false) OVER () FROM t1;
----
NULL
false

query B
SELECT LAG(NULL, 1, false) OVER () FROM t1;
----
false
NULL

statement ok
DROP TABLE t1;

## end test handle NULL input of lead/lag

0 comments on commit 44127ec

Please sign in to comment.