diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 1bab2953e4f65..71ab7c1b43502 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -20,10 +20,13 @@ use std::sync::Arc; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; +use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array}; +use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Float32, Float64}; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use arrow::datatypes::DataType::{Float32, Float64, Int32}; +use datafusion_common::{ + exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -114,7 +117,11 @@ pub fn round(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Float64 => match decimal_places { ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places = decimal_places.try_into().unwrap(); + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; Ok(Arc::new(make_function_scalar_inputs!( &args[0], @@ -128,21 +135,30 @@ pub fn round(args: &[ArrayRef]) -> Result { } )) as ArrayRef) } - ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float64Array, - Int64Array, - { - |value: f64, decimal_places: i64| { - (value * 10.0_f64.powi(decimal_places.try_into().unwrap())) - .round() - / 10.0_f64.powi(decimal_places.try_into().unwrap()) + ColumnarValue::Array(decimal_places) => { + let options = CastOptions { + safe: false, // raise error if the cast is not possible + ..Default::default() + }; + let decimal_places = cast_with_options(&decimal_places, &Int32, &options) + .map_err(|e| { + exec_datafusion_err!("Invalid values for decimal places: {e}") + })?; + Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float64Array, + Int32Array, + { + |value: f64, decimal_places: i32| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + } } - } - )) as ArrayRef), + )) as ArrayRef) + } _ => { exec_err!("round function requires a scalar or array for decimal_places") } @@ -150,7 +166,11 @@ pub fn round(args: &[ArrayRef]) -> Result { DataType::Float32 => match decimal_places { ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places = decimal_places.try_into().unwrap(); + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; Ok(Arc::new(make_function_scalar_inputs!( &args[0], @@ -164,21 +184,30 @@ pub fn round(args: &[ArrayRef]) -> Result { } )) as ArrayRef) } - ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float32Array, - Int64Array, - { - |value: f32, decimal_places: i64| { - (value * 10.0_f32.powi(decimal_places.try_into().unwrap())) - .round() - / 10.0_f32.powi(decimal_places.try_into().unwrap()) + ColumnarValue::Array(_) => { + let ColumnarValue::Array(decimal_places) = + decimal_places.cast_to(&Int32, None).map_err(|e| { + exec_datafusion_err!("Invalid values for decimal places: {e}") + })? + else { + panic!("Unexpected result of ColumnarValue::Array.cast") + }; + + Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float32Array, + Int32Array, + { + |value: f32, decimal_places: i32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + } } - } - )) as ArrayRef), + )) as ArrayRef) + } _ => { exec_err!("round function requires a scalar or array for decimal_places") } @@ -196,6 +225,7 @@ mod test { use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; use datafusion_common::cast::{as_float32_array, as_float64_array}; + use datafusion_common::DataFusionError; #[test] fn test_round_f32() { @@ -262,4 +292,17 @@ mod test { assert_eq!(floats, &expected); } + + #[test] + fn test_round_f32_cast_fail() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![125.2345])), // input + Arc::new(Int64Array::from(vec![2147483648])), // decimal_places + ]; + + let result = round(&args); + + assert!(result.is_err()); + assert!(matches!(result, Err(DataFusionError::Execution { .. }))); + } } diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index a68d1cc7a7709..e152269812590 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -823,6 +823,16 @@ select round(a), round(b), round(c) from small_floats; 0 0 1 1 0 0 +# round with too large +# max Int32 is 2147483647 +query error DataFusion error: Execution error: Invalid values for decimal places: Cast error: Can't cast value 2147483648 to type Int32 +select round(3.14, 2147483648); + +# with array +query error DataFusion error: Execution error: Invalid values for decimal places: Cast error: Can't cast value 2147483649 to type Int32 +select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 2147483649); + + ## signum # signum scalar function