Skip to content

Commit

Permalink
Minor: dont panic with bad arguments to round (apache#10899)
Browse files Browse the repository at this point in the history
* Minor: dont panic with bad arguments to round

* Minor: add more safety casts to round func

* Remove panic + add sqllogictest

* clippy

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
2 people authored and findepi committed Jul 16, 2024
1 parent 2262b16 commit 4ea85c2
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 33 deletions.
109 changes: 76 additions & 33 deletions datafusion/functions/src/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ use std::sync::Arc;

use crate::utils::make_scalar_function;

use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array};
use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{Float32, Float64};
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use arrow::datatypes::DataType::{Float32, Float64, Int32};
use datafusion_common::{
exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
Expand Down Expand Up @@ -114,7 +117,11 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Float64 => match decimal_places {
ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
let decimal_places = decimal_places.try_into().unwrap();
let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
exec_datafusion_err!(
"Invalid value for decimal places: {decimal_places}: {e}"
)
})?;

Ok(Arc::new(make_function_scalar_inputs!(
&args[0],
Expand All @@ -128,29 +135,42 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
}
)) as ArrayRef)
}
ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!(
&args[0],
decimal_places,
"value",
"decimal_places",
Float64Array,
Int64Array,
{
|value: f64, decimal_places: i64| {
(value * 10.0_f64.powi(decimal_places.try_into().unwrap()))
.round()
/ 10.0_f64.powi(decimal_places.try_into().unwrap())
ColumnarValue::Array(decimal_places) => {
let options = CastOptions {
safe: false, // raise error if the cast is not possible
..Default::default()
};
let decimal_places = cast_with_options(&decimal_places, &Int32, &options)
.map_err(|e| {
exec_datafusion_err!("Invalid values for decimal places: {e}")
})?;
Ok(Arc::new(make_function_inputs2!(
&args[0],
decimal_places,
"value",
"decimal_places",
Float64Array,
Int32Array,
{
|value: f64, decimal_places: i32| {
(value * 10.0_f64.powi(decimal_places)).round()
/ 10.0_f64.powi(decimal_places)
}
}
}
)) as ArrayRef),
)) as ArrayRef)
}
_ => {
exec_err!("round function requires a scalar or array for decimal_places")
}
},

DataType::Float32 => match decimal_places {
ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
let decimal_places = decimal_places.try_into().unwrap();
let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
exec_datafusion_err!(
"Invalid value for decimal places: {decimal_places}: {e}"
)
})?;

Ok(Arc::new(make_function_scalar_inputs!(
&args[0],
Expand All @@ -164,21 +184,30 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
}
)) as ArrayRef)
}
ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!(
&args[0],
decimal_places,
"value",
"decimal_places",
Float32Array,
Int64Array,
{
|value: f32, decimal_places: i64| {
(value * 10.0_f32.powi(decimal_places.try_into().unwrap()))
.round()
/ 10.0_f32.powi(decimal_places.try_into().unwrap())
ColumnarValue::Array(_) => {
let ColumnarValue::Array(decimal_places) =
decimal_places.cast_to(&Int32, None).map_err(|e| {
exec_datafusion_err!("Invalid values for decimal places: {e}")
})?
else {
panic!("Unexpected result of ColumnarValue::Array.cast")
};

Ok(Arc::new(make_function_inputs2!(
&args[0],
decimal_places,
"value",
"decimal_places",
Float32Array,
Int32Array,
{
|value: f32, decimal_places: i32| {
(value * 10.0_f32.powi(decimal_places)).round()
/ 10.0_f32.powi(decimal_places)
}
}
}
)) as ArrayRef),
)) as ArrayRef)
}
_ => {
exec_err!("round function requires a scalar or array for decimal_places")
}
Expand All @@ -196,6 +225,7 @@ mod test {

use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
use datafusion_common::cast::{as_float32_array, as_float64_array};
use datafusion_common::DataFusionError;

#[test]
fn test_round_f32() {
Expand Down Expand Up @@ -262,4 +292,17 @@ mod test {

assert_eq!(floats, &expected);
}

#[test]
fn test_round_f32_cast_fail() {
let args: Vec<ArrayRef> = vec![
Arc::new(Float64Array::from(vec![125.2345])), // input
Arc::new(Int64Array::from(vec![2147483648])), // decimal_places
];

let result = round(&args);

assert!(result.is_err());
assert!(matches!(result, Err(DataFusionError::Execution { .. })));
}
}
10 changes: 10 additions & 0 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,16 @@ select round(a), round(b), round(c) from small_floats;
0 0 1
1 0 0

# round with too large
# max Int32 is 2147483647
query error DataFusion error: Execution error: Invalid values for decimal places: Cast error: Can't cast value 2147483648 to type Int32
select round(3.14, 2147483648);

# with array
query error DataFusion error: Execution error: Invalid values for decimal places: Cast error: Can't cast value 2147483649 to type Int32
select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 2147483649);


## signum

# signum scalar function
Expand Down

0 comments on commit 4ea85c2

Please sign in to comment.