Skip to content

Commit

Permalink
Port tan, tanh to datafusion-functions (#9535)
Browse files Browse the repository at this point in the history
* Port tan function

* Port tanh function

* Remove length checking of arg_types in return_type

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
ongchi and alamb authored Mar 12, 2024
1 parent 263cce0 commit bd9a272
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 43 deletions.
13 changes: 0 additions & 13 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ pub enum BuiltinScalarFunction {
Sinh,
/// sqrt
Sqrt,
/// tan
Tan,
/// tanh
Tanh,
/// trunc
Trunc,
/// cot
Expand Down Expand Up @@ -320,8 +316,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Sqrt => Volatility::Immutable,
BuiltinScalarFunction::Cbrt => Volatility::Immutable,
BuiltinScalarFunction::Cot => Volatility::Immutable,
BuiltinScalarFunction::Tan => Volatility::Immutable,
BuiltinScalarFunction::Tanh => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::ArrayElement => Volatility::Immutable,
BuiltinScalarFunction::ArrayExcept => Volatility::Immutable,
Expand Down Expand Up @@ -629,8 +623,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Sinh
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Cbrt
| BuiltinScalarFunction::Tan
| BuiltinScalarFunction::Tanh
| BuiltinScalarFunction::Trunc
| BuiltinScalarFunction::Cot => match input_expr_types[0] {
Float32 => Ok(Float32),
Expand Down Expand Up @@ -913,8 +905,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Sin
| BuiltinScalarFunction::Sinh
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Tan
| BuiltinScalarFunction::Tanh
| BuiltinScalarFunction::Cot => {
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
Expand Down Expand Up @@ -964,7 +954,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Sinh
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Cbrt
| BuiltinScalarFunction::Tanh
| BuiltinScalarFunction::Trunc
| BuiltinScalarFunction::Pi
) {
Expand Down Expand Up @@ -1010,8 +999,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Sin => &["sin"],
BuiltinScalarFunction::Sinh => &["sinh"],
BuiltinScalarFunction::Sqrt => &["sqrt"],
BuiltinScalarFunction::Tan => &["tan"],
BuiltinScalarFunction::Tanh => &["tanh"],
BuiltinScalarFunction::Trunc => &["trunc"],

// conditional functions
Expand Down
4 changes: 0 additions & 4 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,9 @@ scalar_expr!(Sqrt, sqrt, num, "square root of a number");
scalar_expr!(Cbrt, cbrt, num, "cube root of a number");
scalar_expr!(Sin, sin, num, "sine");
scalar_expr!(Cos, cos, num, "cosine");
scalar_expr!(Tan, tan, num, "tangent");
scalar_expr!(Cot, cot, num, "cotangent");
scalar_expr!(Sinh, sinh, num, "hyperbolic sine");
scalar_expr!(Cosh, cosh, num, "hyperbolic cosine");
scalar_expr!(Tanh, tanh, num, "hyperbolic tangent");
scalar_expr!(Atan, atan, num, "inverse tangent");
scalar_expr!(Asinh, asinh, num, "inverse hyperbolic sine");
scalar_expr!(Acosh, acosh, num, "inverse hyperbolic cosine");
Expand Down Expand Up @@ -1202,11 +1200,9 @@ mod test {
test_unary_scalar_expr!(Cbrt, cbrt);
test_unary_scalar_expr!(Sin, sin);
test_unary_scalar_expr!(Cos, cos);
test_unary_scalar_expr!(Tan, tan);
test_unary_scalar_expr!(Cot, cot);
test_unary_scalar_expr!(Sinh, sinh);
test_unary_scalar_expr!(Cosh, cosh);
test_unary_scalar_expr!(Tanh, tanh);
test_unary_scalar_expr!(Atan, atan);
test_unary_scalar_expr!(Asinh, asinh);
test_unary_scalar_expr!(Acosh, acosh);
Expand Down
8 changes: 7 additions & 1 deletion datafusion/functions/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@ mod abs;
mod acos;
mod asin;
mod nans;
mod tan;
mod tanh;

// create UDFs
make_udf_function!(nans::IsNanFunc, ISNAN, isnan);
make_udf_function!(abs::AbsFunc, ABS, abs);
make_udf_function!(acos::AcosFunc, ACOS, acos);
make_udf_function!(asin::AsinFunc, ASIN, asin);
make_udf_function!(tan::TanFunc, TAN, tan);
make_udf_function!(tanh::TanhFunc, TANH, tanh);

// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
Expand All @@ -45,5 +49,7 @@ export_functions!(
asin,
num,
"returns the arc sine or inverse sine of a number"
)
),
(tan, num, "returns the tangent of a number"),
(tanh, num, "returns the hyperbolic tangent of a number")
);
98 changes: 98 additions & 0 deletions datafusion/functions/src/math/tan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Math function: `tan()`.
use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::DataType;
use arrow_array::{ArrayRef, Float32Array, Float64Array};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::Volatility;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature};

/// Scalar UDF implementing the `tan()` math function.
#[derive(Debug)]
pub struct TanFunc {
    /// Argument signature, built once in [`TanFunc::new`] and handed out
    /// by reference from `ScalarUDFImpl::signature`.
    signature: Signature,
}

impl Default for TanFunc {
    /// Equivalent to [`TanFunc::new`]; provided so the type can be used
    /// wherever a `Default` bound is required.
    fn default() -> Self {
        Self::new()
    }
}

impl TanFunc {
    /// Creates the function with a uniform one-argument signature over
    /// `Float64` / `Float32`, declared as [`Volatility::Immutable`].
    pub fn new() -> Self {
        Self {
            signature: Signature::uniform(
                1,
                // Float64 is listed first, ahead of Float32.
                vec![DataType::Float64, DataType::Float32],
                Volatility::Immutable,
            ),
        }
    }
}

impl ScalarUDFImpl for TanFunc {
    /// Exposes `self` as `&dyn Any` so callers can downcast.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the function is registered.
    fn name(&self) -> &str {
        "tan"
    }

    /// Signature constructed in `TanFunc::new`.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Result type for the given argument types: `Float32` when the first
    /// argument is `Float32`, otherwise `Float64`.
    ///
    /// Indexes `arg_types[0]` directly, so an empty slice panics.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(match arg_types[0] {
            DataType::Float32 => DataType::Float32,
            // Float64 itself, plus any other type (e.g. null / integers).
            _ => DataType::Float64,
        })
    }

    /// Applies `tan` element-wise to the first argument.
    ///
    /// The arguments are first converted with
    /// `ColumnarValue::values_to_arrays`; the result is always returned as
    /// `ColumnarValue::Array`. Any first-argument type other than
    /// `Float32` / `Float64` yields an execution error.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let arrays = ColumnarValue::values_to_arrays(args)?;
        let input = &arrays[0];

        let output: ArrayRef = match input.data_type() {
            DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!(
                input,
                self.name(),
                Float32Array,
                Float32Array,
                { f32::tan }
            )),
            DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!(
                input,
                self.name(),
                Float64Array,
                Float64Array,
                { f64::tan }
            )),
            other => {
                return exec_err!(
                    "Unsupported data type {other:?} for function {}",
                    self.name()
                )
            }
        };

        Ok(ColumnarValue::Array(output))
    }
}
98 changes: 98 additions & 0 deletions datafusion/functions/src/math/tanh.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Math function: `tanh()`.
use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::DataType;
use arrow_array::{ArrayRef, Float32Array, Float64Array};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::Volatility;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature};

/// Scalar UDF implementing the `tanh()` math function.
#[derive(Debug)]
pub struct TanhFunc {
    /// Argument signature, built once in [`TanhFunc::new`] and handed out
    /// by reference from `ScalarUDFImpl::signature`.
    signature: Signature,
}

impl Default for TanhFunc {
    /// Equivalent to [`TanhFunc::new`]; provided so the type can be used
    /// wherever a `Default` bound is required.
    fn default() -> Self {
        Self::new()
    }
}

impl TanhFunc {
    /// Creates the function with a uniform one-argument signature over
    /// `Float64` / `Float32`, declared as [`Volatility::Immutable`].
    pub fn new() -> Self {
        Self {
            signature: Signature::uniform(
                1,
                // Float64 is listed first, ahead of Float32.
                vec![DataType::Float64, DataType::Float32],
                Volatility::Immutable,
            ),
        }
    }
}

impl ScalarUDFImpl for TanhFunc {
    /// Exposes `self` as `&dyn Any` so callers can downcast.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the function is registered.
    fn name(&self) -> &str {
        "tanh"
    }

    /// Signature constructed in `TanhFunc::new`.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Result type for the given argument types: `Float32` when the first
    /// argument is `Float32`, otherwise `Float64`.
    ///
    /// Indexes `arg_types[0]` directly, so an empty slice panics.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(match arg_types[0] {
            DataType::Float32 => DataType::Float32,
            // Float64 itself, plus any other type (e.g. null / integers).
            _ => DataType::Float64,
        })
    }

    /// Applies `tanh` element-wise to the first argument.
    ///
    /// The arguments are first converted with
    /// `ColumnarValue::values_to_arrays`; the result is always returned as
    /// `ColumnarValue::Array`. Any first-argument type other than
    /// `Float32` / `Float64` yields an execution error.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let arrays = ColumnarValue::values_to_arrays(args)?;
        let input = &arrays[0];

        let output: ArrayRef = match input.data_type() {
            DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!(
                input,
                self.name(),
                Float32Array,
                Float32Array,
                { f32::tanh }
            )),
            DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!(
                input,
                self.name(),
                Float64Array,
                Float64Array,
                { f64::tanh }
            )),
            other => {
                return exec_err!(
                    "Unsupported data type {other:?} for function {}",
                    self.name()
                )
            }
        };

        Ok(ColumnarValue::Array(output))
    }
}
2 changes: 0 additions & 2 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,6 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Sinh => Arc::new(math_expressions::sinh),
BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt),
BuiltinScalarFunction::Cbrt => Arc::new(math_expressions::cbrt),
BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan),
BuiltinScalarFunction::Tanh => Arc::new(math_expressions::tanh),
BuiltinScalarFunction::Trunc => {
Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args))
}
Expand Down
2 changes: 0 additions & 2 deletions datafusion/physical-expr/src/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,8 @@ math_unary_function!("sqrt", sqrt);
math_unary_function!("cbrt", cbrt);
math_unary_function!("sin", sin);
math_unary_function!("cos", cos);
math_unary_function!("tan", tan);
math_unary_function!("sinh", sinh);
math_unary_function!("cosh", cosh);
math_unary_function!("tanh", tanh);
math_unary_function!("asin", asin);
math_unary_function!("acos", acos);
math_unary_function!("atan", atan);
Expand Down
4 changes: 2 additions & 2 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ enum ScalarFunction {
Signum = 15;
Sin = 16;
Sqrt = 17;
Tan = 18;
// Tan = 18;
Trunc = 19;
// 20 was Array
// RegexpMatch = 21;
Expand Down Expand Up @@ -620,7 +620,7 @@ enum ScalarFunction {
Atanh = 76;
Sinh = 77;
Cosh = 78;
Tanh = 79;
// Tanh = 79;
Pi = 80;
Degrees = 81;
Radians = 82;
Expand Down
6 changes: 0 additions & 6 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit bd9a272

Please sign in to comment.