-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix: handle NULL input in lead/lag window function #12811
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @HuSen8891
I have some feeling that we should fix the out_data_type earlier so it will serve both LAG/LEAD
Thanks! handle NULL input in lag window function |
// If value is NULL, we use default data type as output data type, no need to cast data type | ||
let default_value = match out_data_type { | ||
DataType::Null => match get_scalar_value_from_args(args, 2)? { | ||
Some(value) => { | ||
let null_value = ScalarValue::try_from(value.data_type())?; | ||
arg = Arc::new(Literal::new(null_value)); | ||
value | ||
} | ||
None => ScalarValue::try_from(DataType::Null)?, | ||
}, | ||
_ => { | ||
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)? | ||
} | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think concerns got intertwined here making this code a bit hard to read. Tbh this bit of code was hard to understand even in the absence of this fix 😅.
I know because I had some difficult in correctly moving this code for parsing the window function arguments when converting it to a WindowUDF
. See #12857.
From my understanding there are two things happening here,
- When
NULL
is passed as the expression (1st argument) we need to coerce it to be the type of the default value(if one is provided). - Like you rightly pointed out in the comments, there is no need to cast the default value to type of the 1st argument when it is
NULL
.
There is a dependency on the default value(3rd argument) in both of the cases.
We can compute them separately like this,
// If value is NULL, we use default data type as output data type, no need to cast data type | |
let default_value = match out_data_type { | |
DataType::Null => match get_scalar_value_from_args(args, 2)? { | |
Some(value) => { | |
let null_value = ScalarValue::try_from(value.data_type())?; | |
arg = Arc::new(Literal::new(null_value)); | |
value | |
} | |
None => ScalarValue::try_from(DataType::Null)?, | |
}, | |
_ => { | |
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)? | |
} | |
}; | |
let arg = coerce_default_type_if_null(args, out_data_type); | |
let default_value = | |
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; | |
In the first part, we coerce NULL
to DataType
of default value when one is provided,
/// For lead/lag window functions, when the expression (first argument)
/// is NULL, interpret it as a missing value which has the same type as
/// the default value.
///
/// See https://github.com/apache/datafusion/issues/12717
fn coerce_default_type_if_null(
args: &[Arc<dyn PhysicalExpr>],
expr_type: &DataType,
) -> Arc<dyn PhysicalExpr> {
let expr = Arc::clone(&args[0]);
let default_value = get_scalar_value_from_args(args, 2);
if !expr_type.is_null() {
expr
} else {
default_value
.unwrap()
.and_then(|value| {
ScalarValue::try_from(value.data_type().clone())
.map(|sv| Arc::new(Literal::new(sv)) as Arc<dyn PhysicalExpr>)
.ok()
})
.unwrap_or(expr)
}
}
And in the second part we push the null check into get_casted_value
function. This benefits both lead/lag as they are the only users of get_casted_value
. It looks like this,
/// For lead/lag window functions, the types of the 1st and 3rd
/// argument should match. So here we attempt to cast the default
/// value to the same type as the 1st argument.
///
/// This also handles the case where the 1st argument could be
/// a NULL value in which case avoid casting the 3rd argument.
///
/// See https://github.com/apache/datafusion/issues/12717
fn get_casted_value(
default_value: Option<ScalarValue>,
dtype: &DataType,
) -> Result<ScalarValue> {
match (dtype.is_null(), &default_value) {
(true, _) => Ok(default_value.unwrap_or(ScalarValue::Null)),
(false, None) => ScalarValue::try_from(dtype.clone()),
(false, Some(v)) => {
if !v.data_type().is_null() {
v.cast_to(dtype)
} else {
ScalarValue::try_from(dtype.clone())
}
}
}
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After thinking about this for some more time, I found a better way which requires no changes to get_casted_value
.
The key here is to also rewrite the return data type when NULL is passed, along with the NULL expression. This makes it possible to leave the old code for parsing default value unchanged.
1. Parsing window function arguments
BuiltInWindowFunction::Lag => {
let (arg, out_data_type) =
rewrite_null_expr_and_data_type(args, out_data_type); // <- rewrite NULL expression and the return datatype
let shift_offset = get_scalar_value_from_args(args, 1)?
.map(get_signed_integer)
.map_or(Ok(None), |v| v.map(Some))?;
let default_value =
get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?; // <- No changes needed here. Same as earlier. See diff (deleted code).
Arc::new(lag(
name,
default_value.data_type().clone(),
arg,
shift_offset,
default_value,
ignore_nulls,
))
}
2. Rewrite NULL and return data type
fn rewrite_null_expr_and_data_type(
args: &[Arc<dyn PhysicalExpr>],
expr_type: &DataType,
) -> (Arc<dyn PhysicalExpr>, DataType) {
let expr = Arc::clone(&args[0]);
// The input expression and the return is type is unchanged
// when the input expression is not NULL.
if !expr_type.is_null() {
return (expr, expr_type.clone());
}
// Rewrites the NULL expression (1st argument) with an expression
// which is the same data type as the default value (3rd argument).
// Also rewrites the return type with the same data type as the
// default value.
//
// If a default value is not provided, or it is NULL the original
// expression (1st argument) and return type is returned without
// any modifications.
get_scalar_value_from_args(args, 2)
.unwrap()
.and_then(|value| {
ScalarValue::try_from(value.data_type().clone())
.map(|sv| {
(
Arc::new(Literal::new(sv)) as Arc<dyn PhysicalExpr>,
value.data_type().clone(),
)
})
.ok()
})
.unwrap_or((expr, expr_type.clone()))
}
I've verified locally and these code changes will pass all the newly added sqllogictests in this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!Refactor the code as described above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @HuSen8891. This refactor makes it easier to port this fix to #12857.
## end test handle NULL and 0 of NTH_VALUE | ||
|
||
## test handle NULL of lead |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❤️
Thanks @HuSen8891 FYI this is the active PR for converting |
get_scalar_value_from_args(args, 2) | ||
.unwrap() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I should have avoided the use of unwrap()
here which will panic 😅. Instead we should propagate the error up the call stack.
get_scalar_value_from_args(args, 2) | |
.unwrap() | |
get_scalar_value_from_args(args, 2)? |
Inside get_scalar_value_from_args
an error is returned if a Literal
is not found at the specified index,
datafusion/datafusion/physical-plan/src/windows/mod.rs
Lines 173 to 176 in f9e8e07
.downcast_ref::<Literal>() | |
.ok_or_else(|| DataFusionError::NotImplemented( | |
format!("There is only support Literal types for field at idx: {index} in Window Function"), | |
))? |
For that we'll need to change the function signature to return a Result<(Arc<dyn PhysicalExpr>, DataType)>
type.
fn rewrite_null_expr_and_data_type(
args: &[Arc<dyn PhysicalExpr>],
expr_type: &DataType,
) -> Result<(Arc<dyn PhysicalExpr>, DataType)> // <- wrapped in `Result()`
And the callers changed to do an early return in case of any errors.
// rewrite NULL expression and the return datatype
let (arg, out_data_type) =
rewrite_null_expr_and_data_type(args, out_data_type)?; // <- adds `?`
args: &[Arc<dyn PhysicalExpr>], | ||
expr_type: &DataType, | ||
) -> (Arc<dyn PhysicalExpr>, DataType) { | ||
let expr = Arc::clone(&args[0]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lets assert args are non empty before access the first element
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Fixed.
@@ -217,6 +217,41 @@ fn get_casted_value( | |||
} | |||
} | |||
|
|||
fn rewrite_null_expr_and_data_type( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for the naming its little bit abstract however its still good name but to help contributors understand what this method exactly doing lets add comments.
Having comments method increases chances that method can be reusable in future if the similar case shows up
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for the naming its little bit abstract however its still good name but to help contributors understand what this method exactly doing lets add comments.
Having comments method increases chances that method can be reusable in future if the similar case shows up
💯
Arc::new(lag( | ||
name, | ||
out_data_type.clone(), | ||
default_value.data_type().clone(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: This works because 1st and 3rd arguments ends up having the same datatype after parsing the arguments above (for all cases). So passing out_data_type.clone()
(line removed per diff) as an argument here is also correct.
👍
return Ok((expr, expr_type.clone())); | ||
} | ||
|
||
get_scalar_value_from_args(args, 2)? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm thanks @HuSen8891
Thanks @HuSen8891 and @comphead 🙌 |
Thank you @HuSen8891 and @comphead and @jcsherin for the review |
Which issue does this PR close?
Closes #12717
Rationale for this change
What changes are included in this PR?
Are these changes tested?
Are there any user-facing changes?