From bd9a272a9f4dd572a256809bb0615f454b19e913 Mon Sep 17 00:00:00 2001 From: Chih Wang Date: Tue, 12 Mar 2024 21:37:39 +0800 Subject: [PATCH] Port tan, tanh to datafusion-functions (#9535) * Port tan function * Port tanh function * Remove length checking of arg_types in return_type --------- Co-authored-by: Andrew Lamb --- datafusion/expr/src/built_in_function.rs | 13 --- datafusion/expr/src/expr_fn.rs | 4 - datafusion/functions/src/math/mod.rs | 8 +- datafusion/functions/src/math/tan.rs | 98 +++++++++++++++++++ datafusion/functions/src/math/tanh.rs | 98 +++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 2 - .../physical-expr/src/math_expressions.rs | 2 - datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/pbjson.rs | 6 -- datafusion/proto/src/generated/prost.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 6 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 - 12 files changed, 208 insertions(+), 43 deletions(-) create mode 100644 datafusion/functions/src/math/tan.rs create mode 100644 datafusion/functions/src/math/tanh.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 41a71711e202..6435eaee160a 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -100,10 +100,6 @@ pub enum BuiltinScalarFunction { Sinh, /// sqrt Sqrt, - /// tan - Tan, - /// tanh - Tanh, /// trunc Trunc, /// cot @@ -320,8 +316,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Sqrt => Volatility::Immutable, BuiltinScalarFunction::Cbrt => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, - BuiltinScalarFunction::Tan => Volatility::Immutable, - BuiltinScalarFunction::Tanh => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayElement => Volatility::Immutable, BuiltinScalarFunction::ArrayExcept => Volatility::Immutable, @@ -629,8 +623,6 @@ impl BuiltinScalarFunction { 
| BuiltinScalarFunction::Sinh | BuiltinScalarFunction::Sqrt | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Tan - | BuiltinScalarFunction::Tanh | BuiltinScalarFunction::Trunc | BuiltinScalarFunction::Cot => match input_expr_types[0] { Float32 => Ok(Float32), @@ -913,8 +905,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Sin | BuiltinScalarFunction::Sinh | BuiltinScalarFunction::Sqrt - | BuiltinScalarFunction::Tan - | BuiltinScalarFunction::Tanh | BuiltinScalarFunction::Cot => { // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we @@ -964,7 +954,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Sinh | BuiltinScalarFunction::Sqrt | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Tanh | BuiltinScalarFunction::Trunc | BuiltinScalarFunction::Pi ) { @@ -1010,8 +999,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Sin => &["sin"], BuiltinScalarFunction::Sinh => &["sinh"], BuiltinScalarFunction::Sqrt => &["sqrt"], - BuiltinScalarFunction::Tan => &["tan"], - BuiltinScalarFunction::Tanh => &["tanh"], BuiltinScalarFunction::Trunc => &["trunc"], // conditional functions diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index fc094ffaa0fc..12eafa6ccdbc 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -538,11 +538,9 @@ scalar_expr!(Sqrt, sqrt, num, "square root of a number"); scalar_expr!(Cbrt, cbrt, num, "cube root of a number"); scalar_expr!(Sin, sin, num, "sine"); scalar_expr!(Cos, cos, num, "cosine"); -scalar_expr!(Tan, tan, num, "tangent"); scalar_expr!(Cot, cot, num, "cotangent"); scalar_expr!(Sinh, sinh, num, "hyperbolic sine"); scalar_expr!(Cosh, cosh, num, "hyperbolic cosine"); -scalar_expr!(Tanh, tanh, num, "hyperbolic tangent"); scalar_expr!(Atan, atan, num, "inverse tangent"); scalar_expr!(Asinh, asinh, num, "inverse hyperbolic sine"); scalar_expr!(Acosh, acosh, num, 
"inverse hyperbolic cosine"); @@ -1202,11 +1200,9 @@ mod test { test_unary_scalar_expr!(Cbrt, cbrt); test_unary_scalar_expr!(Sin, sin); test_unary_scalar_expr!(Cos, cos); - test_unary_scalar_expr!(Tan, tan); test_unary_scalar_expr!(Cot, cot); test_unary_scalar_expr!(Sinh, sinh); test_unary_scalar_expr!(Cosh, cosh); - test_unary_scalar_expr!(Tanh, tanh); test_unary_scalar_expr!(Atan, atan); test_unary_scalar_expr!(Asinh, asinh); test_unary_scalar_expr!(Acosh, acosh); diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 3741cc2802bb..e7ede6043a59 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -21,12 +21,16 @@ mod abs; mod acos; mod asin; mod nans; +mod tan; +mod tanh; // create UDFs make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(abs::AbsFunc, ABS, abs); make_udf_function!(acos::AcosFunc, ACOS, acos); make_udf_function!(asin::AsinFunc, ASIN, asin); +make_udf_function!(tan::TanFunc, TAN, tan); +make_udf_function!(tanh::TanhFunc, TANH, tanh); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( @@ -45,5 +49,7 @@ export_functions!( asin, num, "returns the arc sine or inverse sine of a number" - ) + ), + (tan, num, "returns the tangent of a number"), + (tanh, num, "returns the hyperbolic tangent of a number") ); diff --git a/datafusion/functions/src/math/tan.rs b/datafusion/functions/src/math/tan.rs new file mode 100644 index 000000000000..ea3e002f8489 --- /dev/null +++ b/datafusion/functions/src/math/tan.rs @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math function: `tan()`. + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use arrow_array::{ArrayRef, Float32Array, Float64Array}; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::Volatility; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature}; + +#[derive(Debug)] +pub struct TanFunc { + signature: Signature, +} + +impl TanFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Float64, DataType::Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TanFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "tan" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float64 => Ok(DataType::Float64), + DataType::Float32 => Ok(DataType::Float32), + + // For other types (possible values null/int), use Float64 + _ => Ok(DataType::Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + self.name(), + Float64Array, + Float64Array, + { f64::tan } + )), + DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + self.name(), + Float32Array, + Float32Array, + { f32::tan } + )), + other => { + return exec_err!( + 
"Unsupported data type {other:?} for function {}", + self.name() + ) + } + }; + Ok(ColumnarValue::Array(arr)) + } +} diff --git a/datafusion/functions/src/math/tanh.rs b/datafusion/functions/src/math/tanh.rs new file mode 100644 index 000000000000..af34681919ab --- /dev/null +++ b/datafusion/functions/src/math/tanh.rs @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math function: `tanh()`. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use arrow_array::{ArrayRef, Float32Array, Float64Array}; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::Volatility; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature}; + +#[derive(Debug)] +pub struct TanhFunc { + signature: Signature, +} + +impl TanhFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Float64, DataType::Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TanhFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "tanh" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float64 => Ok(DataType::Float64), + DataType::Float32 => Ok(DataType::Float32), + + // For other types (possible values null/int), use Float64 + _ => Ok(DataType::Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + self.name(), + Float64Array, + Float64Array, + { f64::tanh } + )), + DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + self.name(), + Float32Array, + Float32Array, + { f32::tanh } + )), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ) + } + }; + Ok(ColumnarValue::Array(arr)) + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 2bce52bf7862..5d13f945692a 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -282,8 +282,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Sinh => 
Arc::new(math_expressions::sinh), BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt), BuiltinScalarFunction::Cbrt => Arc::new(math_expressions::cbrt), - BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan), - BuiltinScalarFunction::Tanh => Arc::new(math_expressions::tanh), BuiltinScalarFunction::Trunc => { Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args)) } diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index a8c115ba3a82..db8855cb5400 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -159,10 +159,8 @@ math_unary_function!("sqrt", sqrt); math_unary_function!("cbrt", cbrt); math_unary_function!("sin", sin); math_unary_function!("cos", cos); -math_unary_function!("tan", tan); math_unary_function!("sinh", sinh); math_unary_function!("cosh", cosh); -math_unary_function!("tanh", tanh); math_unary_function!("asin", asin); math_unary_function!("acos", acos); math_unary_function!("atan", atan); diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 6899fcf7c707..bd4d8b45b152 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -559,7 +559,7 @@ enum ScalarFunction { Signum = 15; Sin = 16; Sqrt = 17; - Tan = 18; + // Tan = 18; Trunc = 19; // 20 was Array // RegexpMatch = 21; @@ -620,7 +620,7 @@ enum ScalarFunction { Atanh = 76; Sinh = 77; Cosh = 78; - Tanh = 79; + // Tanh = 79; Pi = 80; Degrees = 81; Radians = 82; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b915dbffbb6f..aaa0764b1e83 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -23127,7 +23127,6 @@ impl serde::Serialize for ScalarFunction { Self::Signum => "Signum", Self::Sin => "Sin", Self::Sqrt => "Sqrt", - Self::Tan => "Tan", 
Self::Trunc => "Trunc", Self::BitLength => "BitLength", Self::Btrim => "Btrim", @@ -23175,7 +23174,6 @@ impl serde::Serialize for ScalarFunction { Self::Atanh => "Atanh", Self::Sinh => "Sinh", Self::Cosh => "Cosh", - Self::Tanh => "Tanh", Self::Pi => "Pi", Self::Degrees => "Degrees", Self::Radians => "Radians", @@ -23237,7 +23235,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Signum", "Sin", "Sqrt", - "Tan", "Trunc", "BitLength", "Btrim", @@ -23285,7 +23282,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Atanh", "Sinh", "Cosh", - "Tanh", "Pi", "Degrees", "Radians", @@ -23376,7 +23372,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Signum" => Ok(ScalarFunction::Signum), "Sin" => Ok(ScalarFunction::Sin), "Sqrt" => Ok(ScalarFunction::Sqrt), - "Tan" => Ok(ScalarFunction::Tan), "Trunc" => Ok(ScalarFunction::Trunc), "BitLength" => Ok(ScalarFunction::BitLength), "Btrim" => Ok(ScalarFunction::Btrim), @@ -23424,7 +23419,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Atanh" => Ok(ScalarFunction::Atanh), "Sinh" => Ok(ScalarFunction::Sinh), "Cosh" => Ok(ScalarFunction::Cosh), - "Tanh" => Ok(ScalarFunction::Tanh), "Pi" => Ok(ScalarFunction::Pi), "Degrees" => Ok(ScalarFunction::Degrees), "Radians" => Ok(ScalarFunction::Radians), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 6149e02fb7d1..07a0f30a2f68 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2861,7 +2861,7 @@ pub enum ScalarFunction { Signum = 15, Sin = 16, Sqrt = 17, - Tan = 18, + /// Tan = 18; Trunc = 19, /// 20 was Array /// RegexpMatch = 21; @@ -2922,7 +2922,7 @@ pub enum ScalarFunction { Atanh = 76, Sinh = 77, Cosh = 78, - Tanh = 79, + /// Tanh = 79; Pi = 80, Degrees = 81, Radians = 82, @@ -3005,7 +3005,6 @@ impl ScalarFunction { ScalarFunction::Signum => "Signum", ScalarFunction::Sin => "Sin", ScalarFunction::Sqrt => "Sqrt", - ScalarFunction::Tan => "Tan", 
ScalarFunction::Trunc => "Trunc", ScalarFunction::BitLength => "BitLength", ScalarFunction::Btrim => "Btrim", @@ -3053,7 +3052,6 @@ impl ScalarFunction { ScalarFunction::Atanh => "Atanh", ScalarFunction::Sinh => "Sinh", ScalarFunction::Cosh => "Cosh", - ScalarFunction::Tanh => "Tanh", ScalarFunction::Pi => "Pi", ScalarFunction::Degrees => "Degrees", ScalarFunction::Radians => "Radians", @@ -3109,7 +3107,6 @@ impl ScalarFunction { "Signum" => Some(Self::Signum), "Sin" => Some(Self::Sin), "Sqrt" => Some(Self::Sqrt), - "Tan" => Some(Self::Tan), "Trunc" => Some(Self::Trunc), "BitLength" => Some(Self::BitLength), "Btrim" => Some(Self::Btrim), @@ -3157,7 +3154,6 @@ impl ScalarFunction { "Atanh" => Some(Self::Atanh), "Sinh" => Some(Self::Sinh), "Cosh" => Some(Self::Cosh), - "Tanh" => Some(Self::Tanh), "Pi" => Some(Self::Pi), "Degrees" => Some(Self::Degrees), "Radians" => Some(Self::Radians), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 131ac78847bc..e8c84ec12879 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -60,7 +60,7 @@ use datafusion_expr::{ lower, lpad, ltrim, md5, nanvl, now, octet_length, overlay, pi, power, radians, random, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with, strpos, substr, - substr_index, substring, tan, tanh, to_hex, translate, trim, trunc, upper, uuid, + substr_index, substring, to_hex, translate, trim, trunc, upper, uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, @@ -445,12 +445,10 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Cbrt => Self::Cbrt, ScalarFunction::Sin => Self::Sin, ScalarFunction::Cos => Self::Cos, - ScalarFunction::Tan => Self::Tan, 
ScalarFunction::Cot => Self::Cot, ScalarFunction::Atan => Self::Atan, ScalarFunction::Sinh => Self::Sinh, ScalarFunction::Cosh => Self::Cosh, - ScalarFunction::Tanh => Self::Tanh, ScalarFunction::Asinh => Self::Asinh, ScalarFunction::Acosh => Self::Acosh, ScalarFunction::Atanh => Self::Atanh, @@ -1479,11 +1477,9 @@ pub fn parse_expr( ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Tanh => Ok(tanh(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Atanh => { Ok(atanh(parse_expr(&args[0], registry, codec)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6d0d81f61b07..c25e2e1ecd22 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1426,11 +1426,9 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Cbrt => Self::Cbrt, BuiltinScalarFunction::Sin => Self::Sin, BuiltinScalarFunction::Cos => Self::Cos, - BuiltinScalarFunction::Tan => Self::Tan, BuiltinScalarFunction::Cot => Self::Cot, BuiltinScalarFunction::Sinh => Self::Sinh, BuiltinScalarFunction::Cosh => Self::Cosh, - BuiltinScalarFunction::Tanh => Self::Tanh, BuiltinScalarFunction::Atan => Self::Atan, BuiltinScalarFunction::Asinh => Self::Asinh, BuiltinScalarFunction::Acosh => Self::Acosh,