Skip to content

Commit

Permalink
feat(function): add least function (#13786)
Browse files Browse the repository at this point in the history
* start adding least fn

* feat(function): add least function

* update function name

* fix scalar smaller function

* add tests

* run Clippy and Fmt

* Generated docs using `./dev/update_function_docs.sh`

* add comment why `descending: false`

* update comment

* Update least.rs

Co-authored-by: Bruce Ritchie <[email protected]>

* Update scalar_functions.md

* run ./dev/update_function_docs.sh to update docs

* merge greatest and least implementation to one

* add header

---------

Co-authored-by: Bruce Ritchie <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
3 people authored Dec 20, 2024
1 parent 9b19d36 commit 667c77a
Show file tree
Hide file tree
Showing 6 changed files with 612 additions and 134 deletions.
183 changes: 49 additions & 134 deletions datafusion/functions/src/core/greatest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,19 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{make_comparator, Array, ArrayRef, BooleanArray};
use crate::core::greatest_least_utils::GreatestLeastOperator;
use arrow::array::{make_comparator, Array, BooleanArray};
use arrow::compute::kernels::cmp;
use arrow::compute::kernels::zip::zip;
use arrow::compute::SortOptions;
use arrow::datatypes::DataType;
use arrow_buffer::BooleanBuffer;
use datafusion_common::{exec_err, plan_err, Result, ScalarValue};
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_doc::Documentation;
use datafusion_expr::binary::type_union_resolution;
use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL;
use datafusion_expr::ColumnarValue;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::{Arc, OnceLock};
use std::sync::OnceLock;

const SORT_OPTIONS: SortOptions = SortOptions {
// We want greatest first
Expand Down Expand Up @@ -57,79 +56,57 @@ impl GreatestFunc {
}
}

fn get_logical_null_count(arr: &dyn Array) -> usize {
arr.logical_nulls()
.map(|n| n.null_count())
.unwrap_or_default()
}
impl GreatestLeastOperator for GreatestFunc {
const NAME: &'static str = "greatest";

/// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array
/// Nulls are always considered smaller than any other value
fn get_larger(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
// Fast path:
// If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel
// - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined.
// - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case
if !lhs.data_type().is_nested()
&& get_logical_null_count(lhs) == 0
&& get_logical_null_count(rhs) == 0
{
return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into());
}
fn keep_scalar<'a>(
lhs: &'a ScalarValue,
rhs: &'a ScalarValue,
) -> Result<&'a ScalarValue> {
if !lhs.data_type().is_nested() {
return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) };
}

let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?;
// If complex type we can't compare directly as we want null values to be smaller
let cmp = make_comparator(
lhs.to_array()?.as_ref(),
rhs.to_array()?.as_ref(),
SORT_OPTIONS,
)?;

if lhs.len() != rhs.len() {
return exec_err!(
"All arrays should have the same length for greatest comparison"
);
if cmp(0, 0).is_ge() {
Ok(lhs)
} else {
Ok(rhs)
}
}

let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge());

// No nulls as we only want to keep the values that are larger, its either true or false
Ok(BooleanArray::new(values, None))
}

/// Return array where the largest value at each index is kept
fn keep_larger(lhs: ArrayRef, rhs: ArrayRef) -> Result<ArrayRef> {
// True for values that we should keep from the left array
let keep_lhs = get_larger(lhs.as_ref(), rhs.as_ref())?;

let larger = zip(&keep_lhs, &lhs, &rhs)?;
/// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array
/// Nulls are always considered smaller than any other value
fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
// Fast path:
// If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel
// - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined.
// - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case
if !lhs.data_type().is_nested()
&& lhs.logical_null_count() == 0
&& rhs.logical_null_count() == 0
{
return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into());
}

Ok(larger)
}
let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?;

fn keep_larger_scalar<'a>(
lhs: &'a ScalarValue,
rhs: &'a ScalarValue,
) -> Result<&'a ScalarValue> {
if !lhs.data_type().is_nested() {
return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) };
}

// If complex type we can't compare directly as we want null values to be smaller
let cmp = make_comparator(
lhs.to_array()?.as_ref(),
rhs.to_array()?.as_ref(),
SORT_OPTIONS,
)?;
if lhs.len() != rhs.len() {
return internal_err!(
"All arrays should have the same length for greatest comparison"
);
}

if cmp(0, 0).is_ge() {
Ok(lhs)
} else {
Ok(rhs)
}
}
let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge());

fn find_coerced_type(data_types: &[DataType]) -> Result<DataType> {
if data_types.is_empty() {
plan_err!("greatest was called without any arguments. It requires at least 1.")
} else if let Some(coerced_type) = type_union_resolution(data_types) {
Ok(coerced_type)
} else {
plan_err!("Cannot find a common type for arguments")
// No nulls as we only want to keep the values that are larger, its either true or false
Ok(BooleanArray::new(values, None))
}
}

Expand All @@ -151,74 +128,12 @@ impl ScalarUDFImpl for GreatestFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.is_empty() {
return exec_err!(
"greatest was called with no arguments. It requires at least 1."
);
}

// Some engines (e.g. SQL Server) allow greatest with single arg, it's a noop
if args.len() == 1 {
return Ok(args[0].clone());
}

// Split to scalars and arrays for later optimization
let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x {
ColumnarValue::Scalar(_) => true,
ColumnarValue::Array(_) => false,
});

let mut arrays_iter = arrays.iter().map(|x| match x {
ColumnarValue::Array(a) => a,
_ => unreachable!(),
});

let first_array = arrays_iter.next();

let mut largest: ArrayRef;

// Optimization: merge all scalars into one to avoid recomputing
if !scalars.is_empty() {
let mut scalars_iter = scalars.iter().map(|x| match x {
ColumnarValue::Scalar(s) => s,
_ => unreachable!(),
});

// We have at least one scalar
let mut largest_scalar = scalars_iter.next().unwrap();

for scalar in scalars_iter {
largest_scalar = keep_larger_scalar(largest_scalar, scalar)?;
}

// If we only have scalars, return the largest one
if arrays.is_empty() {
return Ok(ColumnarValue::Scalar(largest_scalar.clone()));
}

// We have at least one array
let first_array = first_array.unwrap();

// Start with the largest value
largest = keep_larger(
Arc::clone(first_array),
largest_scalar.to_array_of_size(first_array.len())?,
)?;
} else {
// If we only have arrays, start with the first array
// (We must have at least one array)
largest = Arc::clone(first_array.unwrap());
}

for array in arrays_iter {
largest = keep_larger(Arc::clone(array), largest)?;
}

Ok(ColumnarValue::Array(largest))
super::greatest_least_utils::execute_conditional::<Self>(args)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let coerced_type = find_coerced_type(arg_types)?;
let coerced_type =
super::greatest_least_utils::find_coerced_type::<Self>(arg_types)?;

Ok(vec![coerced_type; arg_types.len()])
}
Expand Down
133 changes: 133 additions & 0 deletions datafusion/functions/src/core/greatest_least_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, ArrayRef, BooleanArray};
use arrow::compute::kernels::zip::zip;
use arrow::datatypes::DataType;
use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_expr_common::type_coercion::binary::type_union_resolution;
use std::sync::Arc;

pub(super) trait GreatestLeastOperator {
const NAME: &'static str;

fn keep_scalar<'a>(
lhs: &'a ScalarValue,
rhs: &'a ScalarValue,
) -> Result<&'a ScalarValue>;

/// Return array with true for values that we should keep from the lhs array
fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray>;
}

fn keep_array<Op: GreatestLeastOperator>(
lhs: ArrayRef,
rhs: ArrayRef,
) -> Result<ArrayRef> {
// True for values that we should keep from the left array
let keep_lhs = Op::get_indexes_to_keep(lhs.as_ref(), rhs.as_ref())?;

let result = zip(&keep_lhs, &lhs, &rhs)?;

Ok(result)
}

pub(super) fn execute_conditional<Op: GreatestLeastOperator>(
args: &[ColumnarValue],
) -> Result<ColumnarValue> {
if args.is_empty() {
return internal_err!(
"{} was called with no arguments. It requires at least 1.",
Op::NAME
);
}

// Some engines (e.g. SQL Server) allow greatest/least with single arg, it's a noop
if args.len() == 1 {
return Ok(args[0].clone());
}

// Split to scalars and arrays for later optimization
let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x {
ColumnarValue::Scalar(_) => true,
ColumnarValue::Array(_) => false,
});

let mut arrays_iter = arrays.iter().map(|x| match x {
ColumnarValue::Array(a) => a,
_ => unreachable!(),
});

let first_array = arrays_iter.next();

let mut result: ArrayRef;

// Optimization: merge all scalars into one to avoid recomputing (constant folding)
if !scalars.is_empty() {
let mut scalars_iter = scalars.iter().map(|x| match x {
ColumnarValue::Scalar(s) => s,
_ => unreachable!(),
});

// We have at least one scalar
let mut result_scalar = scalars_iter.next().unwrap();

for scalar in scalars_iter {
result_scalar = Op::keep_scalar(result_scalar, scalar)?;
}

// If we only have scalars, return the one that we should keep (largest/least)
if arrays.is_empty() {
return Ok(ColumnarValue::Scalar(result_scalar.clone()));
}

// We have at least one array
let first_array = first_array.unwrap();

// Start with the result value
result = keep_array::<Op>(
Arc::clone(first_array),
result_scalar.to_array_of_size(first_array.len())?,
)?;
} else {
// If we only have arrays, start with the first array
// (We must have at least one array)
result = Arc::clone(first_array.unwrap());
}

for array in arrays_iter {
result = keep_array::<Op>(Arc::clone(array), result)?;
}

Ok(ColumnarValue::Array(result))
}

pub(super) fn find_coerced_type<Op: GreatestLeastOperator>(
data_types: &[DataType],
) -> Result<DataType> {
if data_types.is_empty() {
plan_err!(
"{} was called without any arguments. It requires at least 1.",
Op::NAME
)
} else if let Some(coerced_type) = type_union_resolution(data_types) {
Ok(coerced_type)
} else {
plan_err!("Cannot find a common type for arguments")
}
}
Loading

0 comments on commit 667c77a

Please sign in to comment.