Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: handle NULL input in lead/lag window function #12811

Merged
merged 1 commit into from
Oct 16, 2024

Conversation

HuSen8891
Copy link
Contributor

Which issue does this PR close?

Closes #12717

Rationale for this change

What changes are included in this PR?

Are these changes tested?

Are there any user-facing changes?

@github-actions github-actions bot added physical-expr Physical Expressions sqllogictest SQL Logic Tests (.slt) labels Oct 8, 2024
Copy link
Contributor

@comphead comphead left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @HuSen8891
I have some feeling that we should fix the out_data_type earlier so it will serve both LAG/LEAD

@HuSen8891
Copy link
Contributor Author

Thanks @HuSen8891 I have some feeling that we should fix the out_data_type earlier so it will serve both LAG/LEAD

Thanks! handle NULL input in lag window function

@HuSen8891 HuSen8891 changed the title Fix: handle NULL input in lead window function Fix: handle NULL input in lead/lag window function Oct 9, 2024
@HuSen8891 HuSen8891 requested a review from comphead October 12, 2024 01:35
Comment on lines 264 to 277
// If value is NULL, we use default data type as output data type, no need to cast data type
let default_value = match out_data_type {
DataType::Null => match get_scalar_value_from_args(args, 2)? {
Some(value) => {
let null_value = ScalarValue::try_from(value.data_type())?;
arg = Arc::new(Literal::new(null_value));
value
}
None => ScalarValue::try_from(DataType::Null)?,
},
_ => {
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?
}
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think concerns got intertwined here making this code a bit hard to read. Tbh this bit of code was hard to understand even in the absence of this fix 😅.

I know because I had some difficult in correctly moving this code for parsing the window function arguments when converting it to a WindowUDF. See #12857.

From my understanding there are two things happening here,

  1. When NULL is passed as the expression (1st argument) we need to coerce it to be the type of the default value(if one is provided).
  2. Like you rightly pointed out in the comments, there is no need to cast the default value to type of the 1st argument when it is NULL.

There is a dependency on the default value(3rd argument) in both of the cases.

We can compute them separately like this,

Suggested change
// If value is NULL, we use default data type as output data type, no need to cast data type
let default_value = match out_data_type {
DataType::Null => match get_scalar_value_from_args(args, 2)? {
Some(value) => {
let null_value = ScalarValue::try_from(value.data_type())?;
arg = Arc::new(Literal::new(null_value));
value
}
None => ScalarValue::try_from(DataType::Null)?,
},
_ => {
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?
}
};
let arg = coerce_default_type_if_null(args, out_data_type);
let default_value =
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;

In the first part, we coerce NULL to DataType of default value when one is provided,

/// For lead/lag window functions, when the expression (first argument)
/// is NULL, interpret it as a missing value which has the same type as
/// the default value.
///
/// See https://github.com/apache/datafusion/issues/12717
fn coerce_default_type_if_null(
    args: &[Arc<dyn PhysicalExpr>],
    expr_type: &DataType,
) -> Arc<dyn PhysicalExpr> {
    let expr = Arc::clone(&args[0]);
    let default_value = get_scalar_value_from_args(args, 2);

    if !expr_type.is_null() {
        expr
    } else {
        default_value
            .unwrap()
            .and_then(|value| {
                ScalarValue::try_from(value.data_type().clone())
                    .map(|sv| Arc::new(Literal::new(sv)) as Arc<dyn PhysicalExpr>)
                    .ok()
            })
            .unwrap_or(expr)
    }
}

And in the second part we push the null check into get_casted_value function. This benefits both lead/lag as they are the only users of get_casted_value. It looks like this,

/// For lead/lag window functions, the types of the 1st and 3rd
/// argument should match. So here we attempt to cast the default
/// value to the same type as the 1st argument.
///
/// This also handles the case where the 1st argument could be
/// a NULL value in which case avoid casting the 3rd argument.
///
/// See https://github.com/apache/datafusion/issues/12717
fn get_casted_value(
    default_value: Option<ScalarValue>,
    dtype: &DataType,
) -> Result<ScalarValue> {
    match (dtype.is_null(), &default_value) {
        (true, _) => Ok(default_value.unwrap_or(ScalarValue::Null)),
        (false, None) => ScalarValue::try_from(dtype.clone()),
        (false, Some(v)) => {
            if !v.data_type().is_null() {
                v.cast_to(dtype)
            } else {
                ScalarValue::try_from(dtype.clone())
            }
        }
    }
}

Copy link
Contributor

@jcsherin jcsherin Oct 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After thinking about this for some more time, I found a better way which requires no changes to get_casted_value.

The key here is to also rewrite the return data type when NULL is passed, along with the NULL expression. This makes it possible to leave the old code for parsing default value unchanged.

1. Parsing window function arguments

        BuiltInWindowFunction::Lag => {
            let (arg, out_data_type) =
                rewrite_null_expr_and_data_type(args, out_data_type); // <- rewrite NULL expression and the return datatype
            let shift_offset = get_scalar_value_from_args(args, 1)?
                .map(get_signed_integer)
                .map_or(Ok(None), |v| v.map(Some))?;
            let default_value =
                get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?; // <- No changes needed here. Same as earlier. See diff (deleted code).

            Arc::new(lag(
                name,
                default_value.data_type().clone(),
                arg,
                shift_offset,
                default_value,
                ignore_nulls,
            ))
        }

2. Rewrite NULL and return data type

fn rewrite_null_expr_and_data_type(
    args: &[Arc<dyn PhysicalExpr>],
    expr_type: &DataType,
) -> (Arc<dyn PhysicalExpr>, DataType) {
    let expr = Arc::clone(&args[0]);

    // The input expression and the return is type is unchanged
    // when the input expression is not NULL.
    if !expr_type.is_null() {
        return (expr, expr_type.clone());
    }

    // Rewrites the NULL expression (1st argument) with an expression
    // which is the same data type as the default value (3rd argument).
    // Also rewrites the return type with the same data type as the
    // default value.
    //
    // If a default value is not provided, or it is NULL the original
    // expression (1st argument) and return type is returned without
    // any modifications.
    get_scalar_value_from_args(args, 2)
        .unwrap()
        .and_then(|value| {
            ScalarValue::try_from(value.data_type().clone())
                .map(|sv| {
                    (
                        Arc::new(Literal::new(sv)) as Arc<dyn PhysicalExpr>,
                        value.data_type().clone(),
                    )
                })
                .ok()
        })
        .unwrap_or((expr, expr_type.clone()))
}

I've verified locally and these code changes will pass all the newly added sqllogictests in this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!Refactor the code as described above.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @HuSen8891. This refactor makes it easier to port this fix to #12857.

## end test handle NULL and 0 of NTH_VALUE

## test handle NULL of lead
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

@jcsherin
Copy link
Contributor

Thanks @HuSen8891

FYI this is the active PR for converting BuiltInWindowFunction::{Lead,Lag} to WindowUDF.

Comment on lines 240 to 241
get_scalar_value_from_args(args, 2)
.unwrap()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I should have avoided the use of unwrap() here which will panic 😅. Instead we should propagate the error up the call stack.

Suggested change
get_scalar_value_from_args(args, 2)
.unwrap()
get_scalar_value_from_args(args, 2)?

Inside get_scalar_value_from_args an error is returned if a Literal is not found at the specified index,

.downcast_ref::<Literal>()
.ok_or_else(|| DataFusionError::NotImplemented(
format!("There is only support Literal types for field at idx: {index} in Window Function"),
))?

For that we'll need to change the function signature to return a Result<(Arc<dyn PhysicalExpr>, DataType)> type.

fn rewrite_null_expr_and_data_type(
    args: &[Arc<dyn PhysicalExpr>],
    expr_type: &DataType,
) -> Result<(Arc<dyn PhysicalExpr>, DataType)> // <- wrapped in `Result()`

And the callers changed to do an early return in case of any errors.

            // rewrite NULL expression and the return datatype
            let (arg, out_data_type) =
                rewrite_null_expr_and_data_type(args, out_data_type)?; // <- adds `?`

args: &[Arc<dyn PhysicalExpr>],
expr_type: &DataType,
) -> (Arc<dyn PhysicalExpr>, DataType) {
let expr = Arc::clone(&args[0]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets assert args are non empty before access the first element

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Fixed.

@@ -217,6 +217,41 @@ fn get_casted_value(
}
}

fn rewrite_null_expr_and_data_type(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the naming its little bit abstract however its still good name but to help contributors understand what this method exactly doing lets add comments.

Having comments method increases chances that method can be reusable in future if the similar case shows up

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the naming its little bit abstract however its still good name but to help contributors understand what this method exactly doing lets add comments.

Having comments method increases chances that method can be reusable in future if the similar case shows up

💯

Arc::new(lag(
name,
out_data_type.clone(),
default_value.data_type().clone(),
Copy link
Contributor

@jcsherin jcsherin Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This works because 1st and 3rd arguments ends up having the same datatype after parsing the arguments above (for all cases). So passing out_data_type.clone() (line removed per diff) as an argument here is also correct.

👍

return Ok((expr, expr_type.clone()));
}

get_scalar_value_from_args(args, 2)?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Copy link
Contributor

@comphead comphead left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm thanks @HuSen8891

@jcsherin
Copy link
Contributor

Thanks @HuSen8891 and @comphead 🙌

@alamb
Copy link
Contributor

alamb commented Oct 16, 2024

Thank you @HuSen8891 and @comphead and @jcsherin for the review

@alamb alamb merged commit 44127ec into apache:main Oct 16, 2024
24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
physical-expr Physical Expressions sqllogictest SQL Logic Tests (.slt)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Incorrect NULL handling in lead window function (SQLancer)
4 participants