From 1c847b4d28013830df877450a8999f40b73b94d7 Mon Sep 17 00:00:00 2001 From: Vaibhav Rabber Date: Fri, 22 Sep 2023 21:29:35 +0530 Subject: [PATCH] Implement proto serialization for `(Bounded)WindowAggExec`. (#7557) Signed-off-by: Vaibhav --- .../physical-expr/src/aggregate/regr.rs | 6 + .../physical-expr/src/expressions/mod.rs | 2 +- .../physical-expr/src/window/lead_lag.rs | 5 + datafusion/physical-expr/src/window/mod.rs | 1 + datafusion/physical-expr/src/window/ntile.rs | 4 + datafusion/proto/proto/datafusion.proto | 20 +- datafusion/proto/src/generated/pbjson.rs | 272 ++++++++++-- datafusion/proto/src/generated/prost.rs | 45 +- .../proto/src/physical_plan/from_proto.rs | 57 +++ datafusion/proto/src/physical_plan/mod.rs | 267 +++++++++--- .../proto/src/physical_plan/to_proto.rs | 386 +++++++++++++----- 11 files changed, 870 insertions(+), 195 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/regr.rs b/datafusion/physical-expr/src/aggregate/regr.rs index 1b8a5c6f76de..6922cb131cac 100644 --- a/datafusion/physical-expr/src/aggregate/regr.rs +++ b/datafusion/physical-expr/src/aggregate/regr.rs @@ -43,6 +43,12 @@ pub struct Regr { expr_x: Arc, } +impl Regr { + pub fn get_regr_type(&self) -> RegrType { + self.regr_type.clone() + } +} + #[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] pub enum RegrType { diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index cfc8ec2f0728..422aff20edda 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -59,7 +59,7 @@ pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; -pub use crate::aggregate::regr::Regr; +pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::stddev::{Stddev, StddevPop}; pub use crate::aggregate::sum::Sum; diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 16700d1eda8b..f55f1600b9ca 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -46,6 +46,11 @@ impl WindowShift { pub fn get_shift_offset(&self) -> i64 { self.shift_offset } + + /// Get the default_value for window shift expression. + pub fn get_default_value(&self) -> Option { + self.default_value.clone() + } } /// lead() window function diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index b1234d599a06..644edae36c9c 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -31,6 +31,7 @@ pub use aggregate::PlainAggregateWindowExpr; pub use built_in::BuiltInWindowExpr; pub use built_in_window_function_expr::BuiltInWindowFunctionExpr; pub use sliding_aggregate::SlidingAggregateWindowExpr; +pub use window_expr::NthValueKind; pub use window_expr::PartitionBatches; pub use window_expr::PartitionKey; pub use window_expr::PartitionWindowAggStates; diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index 008da422da12..49aac0877ab3 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -41,6 +41,10 @@ impl Ntile { pub fn new(name: String, n: u64) -> Self { Self { name, n } } + + pub fn get_n(&self) -> u64 { + self.n + } } impl BuiltInWindowFunctionExpr for Ntile { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 89e307a2299f..0ebcf2537dda 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1183,7 +1183,11 @@ message PhysicalWindowExprNode { BuiltInWindowFunction built_in_function = 2; // udaf = 3 } - PhysicalExprNode expr = 4; + repeated PhysicalExprNode args = 4; + repeated PhysicalExprNode partition_by = 5; + repeated PhysicalSortExprNode order_by = 6; + WindowFrame window_frame = 7; + string name = 8; } message PhysicalIsNull { @@ -1388,11 +1392,21 @@ enum AggregateMode { SINGLE_PARTITIONED = 4; } +message PartiallySortedPartitionSearchMode { + repeated uint64 columns = 6; +} + message WindowAggExecNode { PhysicalPlanNode input = 1; - repeated PhysicalExprNode window_expr = 2; - repeated string window_expr_name = 3; + repeated PhysicalWindowExprNode window_expr = 2; Schema input_schema = 4; + repeated PhysicalExprNode partition_keys = 5; + // Set optional to `None` for `BoundedWindowAggExec`. + oneof partition_search_mode { + EmptyMessage linear = 7; + PartiallySortedPartitionSearchMode partially_sorted = 8; + EmptyMessage sorted = 9; + } } message MaybeFilter { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 5ae817d1783f..458c1ea6e5eb 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -13834,6 +13834,100 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PartiallySortedPartitionSearchMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.columns.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedPartitionSearchMode", len)?; + if !self.columns.is_empty() { + struct_ser.serialize_field("columns", &self.columns.iter().map(ToString::to_string).collect::>())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "columns", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Columns, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "columns" => Ok(GeneratedField::Columns), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PartiallySortedPartitionSearchMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PartiallySortedPartitionSearchMode") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut columns__ = None; + while let Some(k) = map.next_key()? { + match k { + GeneratedField::Columns => { + if columns__.is_some() { + return Err(serde::de::Error::duplicate_field("columns")); + } + columns__ = + Some(map.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + } + } + Ok(PartiallySortedPartitionSearchMode { + columns: columns__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.PartiallySortedPartitionSearchMode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PartitionMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17515,15 +17609,39 @@ impl serde::Serialize for PhysicalWindowExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.args.is_empty() { + len += 1; + } + if !self.partition_by.is_empty() { + len += 1; + } + if !self.order_by.is_empty() { + len += 1; + } + if self.window_frame.is_some() { + len += 1; + } + if !self.name.is_empty() { len += 1; } if self.window_function.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalWindowExprNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if !self.args.is_empty() { + struct_ser.serialize_field("args", &self.args)?; + } + if !self.partition_by.is_empty() { + struct_ser.serialize_field("partitionBy", &self.partition_by)?; + } + if !self.order_by.is_empty() { + struct_ser.serialize_field("orderBy", &self.order_by)?; + } + if let Some(v) = self.window_frame.as_ref() { + struct_ser.serialize_field("windowFrame", v)?; + } + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; } if let Some(v) = self.window_function.as_ref() { match v { @@ -17549,7 +17667,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "args", + "partition_by", + "partitionBy", + "order_by", + "orderBy", + "window_frame", + "windowFrame", + "name", "aggr_function", "aggrFunction", "built_in_function", @@ -17558,7 +17683,11 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Args, + PartitionBy, + OrderBy, + WindowFrame, + Name, AggrFunction, BuiltInFunction, } @@ -17582,7 +17711,11 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "args" => Ok(GeneratedField::Args), + "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), + "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), + "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), + "name" => Ok(GeneratedField::Name), "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -17604,15 +17737,43 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut args__ = None; + let mut partition_by__ = None; + let mut order_by__ = None; + let mut window_frame__ = None; + let mut name__ = None; let mut window_function__ = None; while let Some(k) = map.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Args => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("args")); } - expr__ = map.next_value()?; + args__ = Some(map.next_value()?); + } + GeneratedField::PartitionBy => { + if partition_by__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionBy")); + } + partition_by__ = Some(map.next_value()?); + } + GeneratedField::OrderBy => { + if order_by__.is_some() { + return Err(serde::de::Error::duplicate_field("orderBy")); + } + order_by__ = Some(map.next_value()?); + } + GeneratedField::WindowFrame => { + if window_frame__.is_some() { + return Err(serde::de::Error::duplicate_field("windowFrame")); + } + window_frame__ = map.next_value()?; + } + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map.next_value()?); } GeneratedField::AggrFunction => { if window_function__.is_some() { @@ -17629,7 +17790,11 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { } } Ok(PhysicalWindowExprNode { - expr: expr__, + args: args__.unwrap_or_default(), + partition_by: partition_by__.unwrap_or_default(), + order_by: order_by__.unwrap_or_default(), + window_frame: window_frame__, + name: name__.unwrap_or_default(), window_function: window_function__, }) } @@ -23374,10 +23539,13 @@ impl serde::Serialize for WindowAggExecNode { if !self.window_expr.is_empty() { len += 1; } - if !self.window_expr_name.is_empty() { + if self.input_schema.is_some() { len += 1; } - if self.input_schema.is_some() { + if !self.partition_keys.is_empty() { + len += 1; + } + if self.partition_search_mode.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.WindowAggExecNode", len)?; @@ -23387,12 +23555,25 @@ impl serde::Serialize for WindowAggExecNode { if !self.window_expr.is_empty() { struct_ser.serialize_field("windowExpr", &self.window_expr)?; } - if !self.window_expr_name.is_empty() { - struct_ser.serialize_field("windowExprName", &self.window_expr_name)?; - } if let Some(v) = self.input_schema.as_ref() { struct_ser.serialize_field("inputSchema", v)?; } + if !self.partition_keys.is_empty() { + struct_ser.serialize_field("partitionKeys", &self.partition_keys)?; + } + if let Some(v) = self.partition_search_mode.as_ref() { + match v { + window_agg_exec_node::PartitionSearchMode::Linear(v) => { + struct_ser.serialize_field("linear", v)?; + } + window_agg_exec_node::PartitionSearchMode::PartiallySorted(v) => { + struct_ser.serialize_field("partiallySorted", v)?; + } + window_agg_exec_node::PartitionSearchMode::Sorted(v) => { + struct_ser.serialize_field("sorted", v)?; + } + } + } struct_ser.end() } } @@ -23406,18 +23587,25 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { "input", "window_expr", "windowExpr", - "window_expr_name", - "windowExprName", "input_schema", "inputSchema", + "partition_keys", + "partitionKeys", + "linear", + "partially_sorted", + "partiallySorted", + "sorted", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, WindowExpr, - WindowExprName, InputSchema, + PartitionKeys, + Linear, + PartiallySorted, + Sorted, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23441,8 +23629,11 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { match value { "input" => Ok(GeneratedField::Input), "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), - "windowExprName" | "window_expr_name" => Ok(GeneratedField::WindowExprName), "inputSchema" | "input_schema" => Ok(GeneratedField::InputSchema), + "partitionKeys" | "partition_keys" => Ok(GeneratedField::PartitionKeys), + "linear" => Ok(GeneratedField::Linear), + "partiallySorted" | "partially_sorted" => Ok(GeneratedField::PartiallySorted), + "sorted" => Ok(GeneratedField::Sorted), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -23464,8 +23655,9 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { { let mut input__ = None; let mut window_expr__ = None; - let mut window_expr_name__ = None; let mut input_schema__ = None; + let mut partition_keys__ = None; + let mut partition_search_mode__ = None; while let Some(k) = map.next_key()? { match k { GeneratedField::Input => { @@ -23480,25 +23672,47 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { } window_expr__ = Some(map.next_value()?); } - GeneratedField::WindowExprName => { - if window_expr_name__.is_some() { - return Err(serde::de::Error::duplicate_field("windowExprName")); - } - window_expr_name__ = Some(map.next_value()?); - } GeneratedField::InputSchema => { if input_schema__.is_some() { return Err(serde::de::Error::duplicate_field("inputSchema")); } input_schema__ = map.next_value()?; } + GeneratedField::PartitionKeys => { + if partition_keys__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionKeys")); + } + partition_keys__ = Some(map.next_value()?); + } + GeneratedField::Linear => { + if partition_search_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("linear")); + } + partition_search_mode__ = map.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::Linear) +; + } + GeneratedField::PartiallySorted => { + if partition_search_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("partiallySorted")); + } + partition_search_mode__ = map.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::PartiallySorted) +; + } + GeneratedField::Sorted => { + if partition_search_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("sorted")); + } + partition_search_mode__ = map.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::Sorted) +; + } } } Ok(WindowAggExecNode { input: input__, window_expr: window_expr__.unwrap_or_default(), - window_expr_name: window_expr_name__.unwrap_or_default(), input_schema: input_schema__, + partition_keys: partition_keys__.unwrap_or_default(), + partition_search_mode: partition_search_mode__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2fbf4d282aec..3382fa17fe58 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1566,7 +1566,7 @@ pub mod physical_expr_node { TryCast(::prost::alloc::boxed::Box), /// window expressions #[prost(message, tag = "15")] - WindowExpr(::prost::alloc::boxed::Box), + WindowExpr(super::PhysicalWindowExprNode), #[prost(message, tag = "16")] ScalarUdf(super::PhysicalScalarUdfNode), #[prost(message, tag = "18")] @@ -1615,8 +1615,16 @@ pub mod physical_aggregate_expr_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalWindowExprNode { - #[prost(message, optional, boxed, tag = "4")] - pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "4")] + pub args: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "5")] + pub partition_by: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "6")] + pub order_by: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "7")] + pub window_frame: ::core::option::Option, + #[prost(string, tag = "8")] + pub name: ::prost::alloc::string::String, #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1, 2")] pub window_function: ::core::option::Option< physical_window_expr_node::WindowFunction, @@ -1938,15 +1946,40 @@ pub struct ProjectionExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PartiallySortedPartitionSearchMode { + #[prost(uint64, repeated, tag = "6")] + pub columns: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct WindowAggExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "2")] - pub window_expr: ::prost::alloc::vec::Vec, - #[prost(string, repeated, tag = "3")] - pub window_expr_name: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + pub window_expr: ::prost::alloc::vec::Vec, #[prost(message, optional, tag = "4")] pub input_schema: ::core::option::Option, + #[prost(message, repeated, tag = "5")] + pub partition_keys: ::prost::alloc::vec::Vec, + /// Set optional to `None` for `BoundedWindowAggExec`. + #[prost(oneof = "window_agg_exec_node::PartitionSearchMode", tags = "7, 8, 9")] + pub partition_search_mode: ::core::option::Option< + window_agg_exec_node::PartitionSearchMode, + >, +} +/// Nested message and enum types in `WindowAggExecNode`. +pub mod window_agg_exec_node { + /// Set optional to `None` for `BoundedWindowAggExec`. + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum PartitionSearchMode { + #[prost(message, tag = "7")] + Linear(super::EmptyMessage), + #[prost(message, tag = "8")] + PartiallySorted(super::PartiallySortedPartitionSearchMode), + #[prost(message, tag = "9")] + Sorted(super::EmptyMessage), + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index e1d70634d884..6665635d3422 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -31,6 +31,8 @@ use datafusion::logical_expr::window_function::WindowFunction; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{in_list, LikeExpr}; use datafusion::physical_plan::expressions::{GetFieldAccessExpr, GetIndexedFieldExpr}; +use datafusion::physical_plan::windows::create_window_expr; +use datafusion::physical_plan::WindowExpr; use datafusion::physical_plan::{ expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, Literal, @@ -84,6 +86,61 @@ pub fn parse_physical_sort_expr( } } +/// Parses a physical window expr from a protobuf. +/// +/// # Arguments +/// +/// * `proto` - Input proto with physical window exprression node. +/// * `name` - Name of the window expression. +/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `input_schema` - The Arrow schema for the input, used for determining expression data types +/// when performing type coercion. +pub fn parse_physical_window_expr( + proto: &protobuf::PhysicalWindowExprNode, + registry: &dyn FunctionRegistry, + input_schema: &Schema, +) -> Result> { + let window_node_expr = proto + .args + .iter() + .map(|e| parse_physical_expr(e, registry, input_schema)) + .collect::>>()?; + + let partition_by = proto + .partition_by + .iter() + .map(|p| parse_physical_expr(p, registry, input_schema)) + .collect::>>()?; + + let order_by = proto + .order_by + .iter() + .map(|o| parse_physical_sort_expr(o, registry, input_schema)) + .collect::>>()?; + + let window_frame = proto + .window_frame + .as_ref() + .map(|wf| wf.clone().try_into()) + .transpose() + .map_err(|e| DataFusionError::Internal(format!("{e}")))? + .ok_or_else(|| { + DataFusionError::Internal( + "Missing required field 'window_frame' in protobuf".to_string(), + ) + })?; + + create_window_expr( + &convert_required!(proto.window_function)?, + proto.name.clone(), + &window_node_expr, + &partition_by, + &order_by, + Arc::new(window_frame), + input_schema, + ) +} + /// Parses a physical expression from a protobuf. /// /// # Arguments diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 249abf8e7073..417b2af31499 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -25,7 +25,6 @@ use datafusion::datasource::file_format::file_compression_type::FileCompressionT use datafusion::datasource::physical_plan::{AvroExec, CsvExec, ParquetExec}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::WindowFrame; use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateMode}; use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; use datafusion::physical_plan::analyze::AnalyzeExec; @@ -44,7 +43,9 @@ use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::union::UnionExec; -use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec}; +use datafusion::physical_plan::windows::{ + BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, +}; use datafusion::physical_plan::{ udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, }; @@ -61,9 +62,11 @@ use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; use crate::protobuf::physical_expr_node::ExprType; use crate::protobuf::physical_plan_node::PhysicalPlanType; use crate::protobuf::repartition_exec_node::PartitionMethod; -use crate::protobuf::{self, PhysicalPlanNode}; +use crate::protobuf::{self, window_agg_exec_node, PhysicalPlanNode}; use crate::{convert_required, into_required}; +use self::from_proto::parse_physical_window_expr; + pub mod from_proto; pub mod to_proto; @@ -288,59 +291,60 @@ impl AsExecutionPlan for PhysicalPlanNode { ) })? .clone(); - let physical_schema: SchemaRef = - SchemaRef::new((&input_schema).try_into()?); + let input_schema: SchemaRef = SchemaRef::new((&input_schema).try_into()?); let physical_window_expr: Vec> = window_agg .window_expr .iter() - .zip(window_agg.window_expr_name.iter()) - .map(|(expr, name)| { - let expr_type = expr.expr_type.as_ref().ok_or_else(|| { - proto_error("Unexpected empty window physical expression") - })?; - - match expr_type { - ExprType::WindowExpr(window_node) => { - let window_node_expr = window_node - .expr - .as_ref() - .map(|e| { - parse_physical_expr( - e.as_ref(), - registry, - &physical_schema, - ) - }) - .transpose()? - .ok_or_else(|| { - proto_error( - "missing window_node expr expression" - .to_string(), - ) - })?; - - Ok(create_window_expr( - &convert_required!(window_node.window_function)?, - name.to_owned(), - &[window_node_expr], - &[], - &[], - Arc::new(WindowFrame::new(false)), - &physical_schema, - )?) - } - _ => internal_err!("Invalid expression for WindowAggrExec"), - } + .map(|window_expr| { + parse_physical_window_expr( + window_expr, + registry, + input_schema.as_ref(), + ) }) .collect::, _>>()?; - //todo fill partition keys and sort keys - Ok(Arc::new(WindowAggExec::try_new( - physical_window_expr, - input, - Arc::new((&input_schema).try_into()?), - vec![], - )?)) + + let partition_keys = window_agg + .partition_keys + .iter() + .map(|expr| { + parse_physical_expr(expr, registry, input.schema().as_ref()) + }) + .collect::>>>()?; + + if let Some(partition_search_mode) = + window_agg.partition_search_mode.as_ref() + { + let partition_search_mode = match partition_search_mode { + window_agg_exec_node::PartitionSearchMode::Linear(_) => { + PartitionSearchMode::Linear + } + window_agg_exec_node::PartitionSearchMode::PartiallySorted( + protobuf::PartiallySortedPartitionSearchMode { columns }, + ) => PartitionSearchMode::PartiallySorted( + columns.iter().map(|c| *c as usize).collect(), + ), + window_agg_exec_node::PartitionSearchMode::Sorted(_) => { + PartitionSearchMode::Sorted + } + }; + + Ok(Arc::new(BoundedWindowAggExec::try_new( + physical_window_expr, + input, + input_schema, + partition_keys, + partition_search_mode, + )?)) + } else { + Ok(Arc::new(WindowAggExec::try_new( + physical_window_expr, + input, + input_schema, + partition_keys, + )?)) + } } PhysicalPlanType::Aggregate(hash_agg) => { let input: Arc = into_physical_plan( @@ -1304,6 +1308,88 @@ impl AsExecutionPlan for PhysicalPlanNode { }, ))), }) + } else if let Some(exec) = plan.downcast_ref::() { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + let input_schema = protobuf::Schema::try_from(exec.input_schema().as_ref())?; + + let window_expr = + exec.window_expr() + .iter() + .map(|e| e.clone().try_into()) + .collect::>>()?; + + let partition_keys = exec + .partition_keys + .iter() + .map(|e| e.clone().try_into()) + .collect::>>()?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Window(Box::new( + protobuf::WindowAggExecNode { + input: Some(Box::new(input)), + window_expr, + input_schema: Some(input_schema), + partition_keys, + partition_search_mode: None, + }, + ))), + }) + } else if let Some(exec) = plan.downcast_ref::() { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + let input_schema = protobuf::Schema::try_from(exec.input_schema().as_ref())?; + + let window_expr = + exec.window_expr() + .iter() + .map(|e| e.clone().try_into()) + .collect::>>()?; + + let partition_keys = exec + .partition_keys + .iter() + .map(|e| e.clone().try_into()) + .collect::>>()?; + + let partition_search_mode = match &exec.partition_search_mode { + PartitionSearchMode::Linear => { + window_agg_exec_node::PartitionSearchMode::Linear( + protobuf::EmptyMessage {}, + ) + } + PartitionSearchMode::PartiallySorted(columns) => { + window_agg_exec_node::PartitionSearchMode::PartiallySorted( + protobuf::PartiallySortedPartitionSearchMode { + columns: columns.iter().map(|c| *c as u64).collect(), + }, + ) + } + PartitionSearchMode::Sorted => { + window_agg_exec_node::PartitionSearchMode::Sorted( + protobuf::EmptyMessage {}, + ) + } + }; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Window(Box::new( + protobuf::WindowAggExecNode { + input: Some(Box::new(input)), + window_expr, + input_schema: Some(input_schema), + partition_keys, + partition_search_mode: Some(partition_search_mode), + }, + ))), + }) } else { let mut buf: Vec = vec![]; match extension_codec.try_encode(plan_clone.clone(), &mut buf) { @@ -1419,12 +1505,18 @@ mod roundtrip_tests { use datafusion::logical_expr::{BuiltinScalarFunction, Volatility}; use datafusion::physical_expr::expressions::GetFieldAccessExpr; use datafusion::physical_expr::expressions::{cast, in_list}; + use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::ScalarFunctionExpr; use datafusion::physical_plan::aggregates::PhysicalGroupBy; use datafusion::physical_plan::analyze::AnalyzeExec; - use datafusion::physical_plan::expressions::{like, BinaryExpr, GetIndexedFieldExpr}; + use datafusion::physical_plan::expressions::{ + like, BinaryExpr, GetIndexedFieldExpr, NthValue, Sum, + }; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::physical_plan::projection::ProjectionExec; + use datafusion::physical_plan::windows::{ + BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, + }; use datafusion::physical_plan::{functions, udaf}; use datafusion::{ arrow::{ @@ -1453,7 +1545,7 @@ mod roundtrip_tests { use datafusion_common::Result; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, - Signature, StateTypeFunction, + Signature, StateTypeFunction, WindowFrame, WindowFrameBound, }; fn roundtrip_test(exec_plan: Arc) -> Result<()> { @@ -1606,6 +1698,77 @@ mod roundtrip_tests { Ok(()) } + #[test] + fn roundtrip_window() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let window_frame = WindowFrame { + units: datafusion_expr::WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(ScalarValue::Int64(None)), + end_bound: WindowFrameBound::CurrentRow, + }; + + let builtin_window_expr = Arc::new(BuiltInWindowExpr::new( + Arc::new(NthValue::first( + "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + col("a", &schema)?, + DataType::Int64, + )), + &[col("b", &schema)?], + &[PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + Arc::new(window_frame), + )); + + let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( + Arc::new(Avg::new( + cast(col("b", &schema)?, &schema, DataType::Float64)?, + "AVG(b)".to_string(), + DataType::Float64, + )), + &[], + &[], + Arc::new(WindowFrame::new(false)), + )); + + let window_frame = WindowFrame { + units: datafusion_expr::WindowFrameUnits::Range, + start_bound: WindowFrameBound::CurrentRow, + end_bound: WindowFrameBound::Preceding(ScalarValue::Int64(None)), + }; + + let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( + Arc::new(Sum::new( + cast(col("a", &schema)?, &schema, DataType::Float64)?, + "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", + DataType::Float64, + )), + &[], + &[], + Arc::new(window_frame), + )); + + let input = Arc::new(EmptyExec::new(false, schema.clone())); + + roundtrip_test(Arc::new(WindowAggExec::try_new( + vec![ + builtin_window_expr, + plain_aggr_window_expr, + sliding_aggr_window_expr, + ], + input, + schema.clone(), + vec![col("b", &schema)?], + )?)) + } + #[test] fn rountrip_aggregate() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index a5b1300360fe..cf3dbe26190a 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -22,13 +22,34 @@ use std::{ sync::Arc, }; -use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr}; -use datafusion::physical_plan::ColumnStatistics; use datafusion::physical_plan::{ expressions::{ - CaseExpr, InListExpr, IsNotNullExpr, IsNullExpr, NegativeExpr, NotExpr, + ApproxDistinct, ApproxMedian, ApproxPercentileCont, + ApproxPercentileContWithWeight, ArrayAgg, Correlation, Covariance, CovariancePop, + DistinctArrayAgg, DistinctBitXor, DistinctSum, FirstValue, Grouping, LastValue, + Median, OrderSensitiveArrayAgg, Regr, RegrType, Stddev, StddevPop, Variance, + VariancePop, + }, + windows::BuiltInWindowExpr, + ColumnStatistics, +}; +use datafusion::{ + physical_expr::window::NthValueKind, + physical_plan::{ + expressions::{ + CaseExpr, CumeDist, InListExpr, IsNotNullExpr, IsNullExpr, NegativeExpr, + NotExpr, NthValue, Ntile, Rank, RankType, RowNumber, WindowShift, + }, + Statistics, + }, +}; +use datafusion::{ + physical_expr::window::SlidingAggregateWindowExpr, + physical_plan::{ + expressions::{CastExpr, TryCastExpr}, + windows::PlainAggregateWindowExpr, + WindowExpr, }, - Statistics, }; use datafusion::datasource::listing::{FileRange, PartitionedFile}; @@ -42,7 +63,7 @@ use datafusion::physical_plan::expressions::{ }; use datafusion::physical_plan::{AggregateExpr, PhysicalExpr}; -use crate::protobuf; +use crate::protobuf::{self, physical_window_expr_node}; use crate::protobuf::{ physical_aggregate_expr_node, PhysicalSortExprNode, PhysicalSortExprNodeCollection, ScalarValue, @@ -58,9 +79,6 @@ impl TryFrom> for protobuf::PhysicalExprNode { type Error = DataFusionError; fn try_from(a: Arc) -> Result { - use datafusion::physical_plan::expressions; - use protobuf::AggregateFunction; - let expressions: Vec = a .expressions() .iter() @@ -74,117 +92,30 @@ impl TryFrom> for protobuf::PhysicalExprNode { .map(|e| e.clone().try_into()) .collect::>>()?; - let mut distinct = false; - let aggr_function = if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Avg.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Sum.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Count.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::BitAnd.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::BitOr.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::BitXor.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::BoolAnd.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::BoolOr.into()) - } else if a.as_any().downcast_ref::().is_some() { - distinct = true; - Ok(AggregateFunction::Count.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Min.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Max.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::ApproxDistinct.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::ArrayAgg.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Variance.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::VariancePop.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::Covariance.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::CovariancePop.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Stddev.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::StddevPop.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::Correlation.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::ApproxPercentileCont.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::ApproxPercentileContWithWeight.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::ApproxMedian.into()) - } else if a.as_any().is::() { - Ok(AggregateFunction::FirstValueAgg.into()) - } else if a.as_any().is::() { - Ok(AggregateFunction::LastValueAgg.into()) - } else { - if let Some(a) = a.as_any().downcast_ref::() { - return Ok(protobuf::PhysicalExprNode { + if let Some(a) = a.as_any().downcast_ref::() { + return Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(a.fun().name.clone())), expr: expressions, ordering_req, - distinct, + distinct: false, }, )), }); - } + } - not_impl_err!("Aggregate function not supported: {a:?}") - }?; + let AggrFn { + inner: aggr_function, + distinct, + } = aggr_expr_to_aggr_fn(a.as_ref())?; Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some( physical_aggregate_expr_node::AggregateFunction::AggrFunction( - aggr_function, + aggr_function as i32, ), ), expr: expressions, @@ -196,6 +127,253 @@ impl TryFrom> for protobuf::PhysicalExprNode { } } +impl TryFrom> for protobuf::PhysicalWindowExprNode { + type Error = DataFusionError; + + fn try_from( + window_expr: Arc, + ) -> std::result::Result { + let expr = window_expr.as_any(); + + let mut args = window_expr.expressions().to_vec(); + let window_frame = window_expr.get_window_frame(); + + let window_function = if let Some(built_in_window_expr) = + expr.downcast_ref::() + { + let expr = built_in_window_expr.get_built_in_func_expr(); + let built_in_fn_expr = expr.as_any(); + + let builtin_fn = if built_in_fn_expr.downcast_ref::().is_some() { + protobuf::BuiltInWindowFunction::RowNumber + } else if let Some(rank_expr) = built_in_fn_expr.downcast_ref::() { + match rank_expr.get_type() { + RankType::Basic => protobuf::BuiltInWindowFunction::Rank, + RankType::Dense => protobuf::BuiltInWindowFunction::DenseRank, + RankType::Percent => protobuf::BuiltInWindowFunction::PercentRank, + } + } else if built_in_fn_expr.downcast_ref::().is_some() { + protobuf::BuiltInWindowFunction::CumeDist + } else if let Some(ntile_expr) = built_in_fn_expr.downcast_ref::() { + args.insert( + 0, + Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( + ntile_expr.get_n() as i64, + )))), + ); + protobuf::BuiltInWindowFunction::Ntile + } else if let Some(window_shift_expr) = + built_in_fn_expr.downcast_ref::() + { + args.insert( + 1, + Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( + window_shift_expr.get_shift_offset(), + )))), + ); + if let Some(default_value) = window_shift_expr.get_default_value() { + args.insert(2, Arc::new(Literal::new(default_value))); + } + if window_shift_expr.get_shift_offset() >= 0 { + protobuf::BuiltInWindowFunction::Lag + } else { + protobuf::BuiltInWindowFunction::Lead + } + } else if let Some(nth_value_expr) = + built_in_fn_expr.downcast_ref::() + { + match nth_value_expr.get_kind() { + NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, + NthValueKind::Last => protobuf::BuiltInWindowFunction::LastValue, + NthValueKind::Nth(n) => { + args.insert( + 1, + Arc::new(Literal::new( + datafusion_common::ScalarValue::Int64(Some(n as i64)), + )), + ); + protobuf::BuiltInWindowFunction::NthValue + } + } + } else { + return not_impl_err!("BuiltIn function not supported: {expr:?}"); + }; + + physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32) + } else if let Some(plain_aggr_window_expr) = + expr.downcast_ref::() + { + let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn( + plain_aggr_window_expr.get_aggregate_expr().as_ref(), + )?; + + if distinct { + // TODO + return not_impl_err!( + "Distinct aggregate functions not supported in window expressions" + ); + } + + if !window_frame.start_bound.is_unbounded() { + return Err(DataFusionError::Internal(format!("Invalid PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); + } + + physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) + } else if let Some(sliding_aggr_window_expr) = + expr.downcast_ref::() + { + let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn( + sliding_aggr_window_expr.get_aggregate_expr().as_ref(), + )?; + + if distinct { + // TODO + return not_impl_err!( + "Distinct aggregate functions not supported in window expressions" + ); + } + + if window_frame.start_bound.is_unbounded() { + return Err(DataFusionError::Internal(format!("Invalid SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); + } + + physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) + } else { + return not_impl_err!("WindowExpr not supported: {window_expr:?}"); + }; + + let args = args + .into_iter() + .map(|e| e.try_into()) + .collect::>>()?; + + let partition_by = window_expr + .partition_by() + .iter() + .map(|p| p.clone().try_into()) + .collect::>>()?; + + let order_by = window_expr + .order_by() + .iter() + .map(|o| o.clone().try_into()) + .collect::>>()?; + + let window_frame: protobuf::WindowFrame = window_frame + .as_ref() + .try_into() + .map_err(|e| DataFusionError::Internal(format!("{e}")))?; + + let name = window_expr.name().to_string(); + + Ok(protobuf::PhysicalWindowExprNode { + args, + partition_by, + order_by, + window_frame: Some(window_frame), + window_function: Some(window_function), + name, + }) + } +} + +struct AggrFn { + inner: protobuf::AggregateFunction, + distinct: bool, +} + +fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { + let aggr_expr = expr.as_any(); + let mut distinct = false; + + let inner = if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Count + } else if aggr_expr.downcast_ref::().is_some() { + distinct = true; + protobuf::AggregateFunction::Count + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Grouping + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::BitAnd + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::BitOr + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::BitXor + } else if aggr_expr.downcast_ref::().is_some() { + distinct = true; + protobuf::AggregateFunction::BitXor + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::BoolAnd + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::BoolOr + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Sum + } else if aggr_expr.downcast_ref::().is_some() { + distinct = true; + protobuf::AggregateFunction::Sum + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::ApproxDistinct + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::ArrayAgg + } else if aggr_expr.downcast_ref::().is_some() { + distinct = true; + protobuf::AggregateFunction::ArrayAgg + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::ArrayAgg + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Min + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Max + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Avg + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Variance + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::VariancePop + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Covariance + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::CovariancePop + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Stddev + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::StddevPop + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Correlation + } else if let Some(regr_expr) = aggr_expr.downcast_ref::() { + match regr_expr.get_regr_type() { + RegrType::Slope => protobuf::AggregateFunction::RegrSlope, + RegrType::Intercept => protobuf::AggregateFunction::RegrIntercept, + RegrType::Count => protobuf::AggregateFunction::RegrCount, + RegrType::R2 => protobuf::AggregateFunction::RegrR2, + RegrType::AvgX => protobuf::AggregateFunction::RegrAvgx, + RegrType::AvgY => protobuf::AggregateFunction::RegrAvgy, + RegrType::SXX => protobuf::AggregateFunction::RegrSxx, + RegrType::SYY => protobuf::AggregateFunction::RegrSyy, + RegrType::SXY => protobuf::AggregateFunction::RegrSxy, + } + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::ApproxPercentileCont + } else if aggr_expr + .downcast_ref::() + .is_some() + { + protobuf::AggregateFunction::ApproxPercentileContWithWeight + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::ApproxMedian + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Median + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::FirstValueAgg + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::LastValueAgg + } else { + return not_impl_err!("Aggregate function not supported: {expr:?}"); + }; + + Ok(AggrFn { inner, distinct }) +} + impl TryFrom> for protobuf::PhysicalExprNode { type Error = DataFusionError;