Skip to content

Commit

Permalink
Convert approx_median to UDAF (#10840)
Browse files Browse the repository at this point in the history
* move tdigest to physical-expr-common

* move approx_percentile_cont_accumulator to function-aggregate

* implement approx_median udaf

* remove approx_median aggregation function

* fix sqllogictests

* add removed type tests

* cargo fmt and clippy

* add logical roundtrip test

* fix dataframe test

* fix test and proto gen

* update lock in datafusion-cli

* fix typo

* fix test and doc

* fix sql_integration

* cargo fmt

* follow the checking style like other udaf

* add comment and modified dependency

* update lock and fmt

* add missing test annotation
  • Loading branch information
goldmedal authored Jun 10, 2024
1 parent e8fdc09 commit 3773fb7
Show file tree
Hide file tree
Showing 26 changed files with 471 additions and 497 deletions.
20 changes: 10 additions & 10 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use datafusion::assert_batches_eq;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::Alias;
use datafusion_expr::ExprSchemable;
use datafusion_functions_aggregate::expr_fn::approx_median;

fn test_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Expand Down Expand Up @@ -342,7 +343,7 @@ async fn test_fn_approx_median() -> Result<()> {

let expected = [
"+-----------------------+",
"| APPROX_MEDIAN(test.b) |",
"| approx_median(test.b) |",
"+-----------------------+",
"| 10 |",
"+-----------------------+",
Expand Down
8 changes: 2 additions & 6 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ pub enum AggregateFunction {
ApproxPercentileCont,
/// Approximate continuous percentile function with weight
ApproxPercentileContWithWeight,
/// ApproxMedian
ApproxMedian,
/// Grouping
Grouping,
/// Bit And
Expand Down Expand Up @@ -112,7 +110,6 @@ impl AggregateFunction {
RegrSXY => "REGR_SXY",
ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
ApproxMedian => "APPROX_MEDIAN",
Grouping => "GROUPING",
BitAnd => "BIT_AND",
BitOr => "BIT_OR",
Expand Down Expand Up @@ -161,7 +158,6 @@ impl FromStr for AggregateFunction {
"regr_sxy" => AggregateFunction::RegrSXY,
// approximate
"approx_distinct" => AggregateFunction::ApproxDistinct,
"approx_median" => AggregateFunction::ApproxMedian,
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
"approx_percentile_cont_with_weight" => {
AggregateFunction::ApproxPercentileContWithWeight
Expand Down Expand Up @@ -234,7 +230,6 @@ impl AggregateFunction {
AggregateFunction::ApproxPercentileContWithWeight => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
Expand Down Expand Up @@ -284,7 +279,8 @@ impl AggregateFunction {
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable)
}
AggregateFunction::Avg | AggregateFunction::ApproxMedian => {

AggregateFunction::Avg => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,18 +284,6 @@ pub fn approx_distinct(expr: Expr) -> Expr {
))
}

/// Calculate an approximation of the median for `expr`.
///
/// Builds an `Expr::AggregateFunction` whose function is the built-in
/// `AggregateFunction::ApproxMedian` variant and whose only argument is `expr`.
pub fn approx_median(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::ApproxMedian,
vec![expr],
// NOTE(review): the trailing `false` and three `None` values are presumably
// the distinct flag, filter, ordering and null treatment; confirm against
// the signature of `AggregateFunction::new`.
false,
None,
None,
None,
))
}

/// Calculate an approximation of the specified `percentile` for `expr`.
pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
Expand Down
10 changes: 0 additions & 10 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,16 +231,6 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxMedian => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(input_types.to_vec())
}
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
AggregateFunction::StringAgg => {
Expand Down
129 changes: 129 additions & 0 deletions datafusion/functions-aggregate/src/approx_median.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the APPROX_MEDIAN aggregate function, which computes an approximate median at runtime during query execution
use std::any::Any;
use std::fmt::Debug;

use arrow::{datatypes::DataType, datatypes::Field};
use arrow_schema::DataType::{Float64, UInt64};

use datafusion_common::{not_impl_err, plan_err, Result};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
use datafusion_physical_expr_common::aggregate::utils::down_cast_any_ref;

use crate::approx_percentile_cont::ApproxPercentileAccumulator;

// Registers `ApproxMedian` with the crate's UDAF helper macro.
// NOTE(review): the macro is defined outside this file; it presumably generates
// the `approx_median(expression)` expression builder (carrying the doc string
// below) and the `approx_median_udaf()` accessor; confirm against the macro
// definition.
make_udaf_expr_and_func!(
ApproxMedian,
approx_median,
expression,
"Computes the approximate median of a set of numbers",
approx_median_udaf
);

/// APPROX_MEDIAN aggregate expression.
///
/// User-defined aggregate function registered under the name `approx_median`.
/// The only state it carries is the argument signature it accepts.
pub struct ApproxMedian {
// Argument signature reported through `AggregateUDFImpl::signature`.
signature: Signature,
}

impl Debug for ApproxMedian {
    /// Renders the function as a debug struct named `ApproxMedian` that lists
    /// its registered name followed by its signature.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut builder = f.debug_struct("ApproxMedian");
        builder.field("name", &self.name());
        builder.field("signature", &self.signature);
        builder.finish()
    }
}

impl Default for ApproxMedian {
fn default() -> Self {
Self::new()
}
}

impl ApproxMedian {
    /// Create a new APPROX_MEDIAN aggregate function.
    ///
    /// The function is given a one-argument uniform signature over the
    /// `NUMERICS` type list, with `Volatility::Immutable`.
    pub fn new() -> Self {
        let signature =
            Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable);
        Self { signature }
    }
}

impl AggregateUDFImpl for ApproxMedian {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Describes the intermediate state columns of this aggregate.
    ///
    /// Six non-nullable fields are produced, each named by combining
    /// `args.name` with a suffix through `format_state_name`: `max_size`
    /// (`UInt64`), then `sum`, `count`, `max` and `min` (all `Float64`), and
    /// finally `centroids`, a list of nullable `Float64` items.
    ///
    /// NOTE(review): this layout presumably has to match the state emitted by
    /// `ApproxPercentileAccumulator`; confirm against that accumulator.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        Ok(vec![
            Field::new(format_state_name(args.name, "max_size"), UInt64, false),
            Field::new(format_state_name(args.name, "sum"), Float64, false),
            Field::new(format_state_name(args.name, "count"), Float64, false),
            Field::new(format_state_name(args.name, "max"), Float64, false),
            Field::new(format_state_name(args.name, "min"), Float64, false),
            Field::new_list(
                format_state_name(args.name, "centroids"),
                Field::new("item", Float64, true),
                false,
            ),
        ])
    }

    /// Name under which the function is registered.
    fn name(&self) -> &str {
        "approx_median"
    }

    /// Argument signature accepted by the function.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result has the same type as the (single) input argument.
    ///
    /// # Errors
    ///
    /// Returns a plan error when no argument type is supplied or when the
    /// first argument type is not numeric. The slice is inspected with
    /// `first()` rather than indexed so that an empty slice yields an error
    /// instead of a panic.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        match arg_types.first() {
            Some(input_type) if input_type.is_numeric() => Ok(input_type.clone()),
            Some(_) => plan_err!("ApproxMedian requires numeric input types"),
            None => plan_err!("ApproxMedian requires exactly one argument"),
        }
    }

    /// Creates the accumulator that evaluates the aggregate.
    ///
    /// The median is computed as the 0.5 percentile by an
    /// `ApproxPercentileAccumulator` built for the input type.
    ///
    /// # Errors
    ///
    /// Returns a not-implemented error for `DISTINCT` aggregations.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if acc_args.is_distinct {
            return not_impl_err!(
                "APPROX_MEDIAN(DISTINCT) aggregations are not available"
            );
        }

        Ok(Box::new(ApproxPercentileAccumulator::new(
            0.5_f64,
            acc_args.input_type.clone(),
        )))
    }
}

impl PartialEq<dyn Any> for ApproxMedian {
    /// Compares against a type-erased value: equal only when `other`
    /// downcasts to `ApproxMedian` and both signatures match.
    fn eq(&self, other: &dyn Any) -> bool {
        match down_cast_any_ref(other).downcast_ref::<Self>() {
            Some(rhs) => self.signature == rhs.signature,
            None => false,
        }
    }
}
Loading

0 comments on commit 3773fb7

Please sign in to comment.