From e7858ff0ab1c282ab46bd93cabc3dc83db583165 Mon Sep 17 00:00:00 2001 From: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> Date: Fri, 17 May 2024 17:09:55 -0400 Subject: [PATCH] Handle dictionary values in ScalarValue serde (#10563) * Handle dictionary values in ScalarValue serde * Do not panic on failed physical expr decoding (#241) * revert clippy change --- datafusion/proto/proto/datafusion.proto | 6 + datafusion/proto/src/generated/pbjson.rs | 133 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 13 ++ .../proto/src/logical_plan/from_proto.rs | 49 ++++++- datafusion/proto/src/logical_plan/to_proto.rs | 9 +- datafusion/proto/src/physical_plan/mod.rs | 4 +- .../tests/cases/roundtrip_logical_plan.rs | 49 +++++++ 7 files changed, 259 insertions(+), 4 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index fd79345275ab..8d69b0bad5ed 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -797,9 +797,15 @@ message Union{ // Used for List/FixedSizeList/LargeList/Struct message ScalarNestedValue { + message Dictionary { + bytes ipc_message = 1; + bytes arrow_data = 2; + } + bytes ipc_message = 1; bytes arrow_data = 2; Schema schema = 3; + repeated Dictionary dictionaries = 4; } message ScalarTime32Value { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 01d9a6e0dde6..8df0aeb851df 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22605,6 +22605,9 @@ impl serde::Serialize for ScalarNestedValue { if self.schema.is_some() { len += 1; } + if !self.dictionaries.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.ScalarNestedValue", len)?; if !self.ipc_message.is_empty() { #[allow(clippy::needless_borrow)] @@ -22617,6 +22620,9 @@ impl serde::Serialize for ScalarNestedValue { if let Some(v) = self.schema.as_ref() { struct_ser.serialize_field("schema", v)?; } + if !self.dictionaries.is_empty() { + struct_ser.serialize_field("dictionaries", &self.dictionaries)?; + } struct_ser.end() } } @@ -22632,6 +22638,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { "arrow_data", "arrowData", "schema", + "dictionaries", ]; #[allow(clippy::enum_variant_names)] @@ -22639,6 +22646,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { IpcMessage, ArrowData, Schema, + Dictionaries, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22663,6 +22671,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { "ipcMessage" | "ipc_message" => Ok(GeneratedField::IpcMessage), "arrowData" | "arrow_data" => Ok(GeneratedField::ArrowData), "schema" => Ok(GeneratedField::Schema), + "dictionaries" => Ok(GeneratedField::Dictionaries), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22685,6 +22694,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { let mut ipc_message__ = None; let mut arrow_data__ = None; let mut schema__ = None; + let mut dictionaries__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::IpcMessage => { @@ -22709,18 +22719,141 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { } schema__ = map_.next_value()?; } + GeneratedField::Dictionaries => { + if dictionaries__.is_some() { + return Err(serde::de::Error::duplicate_field("dictionaries")); + } + dictionaries__ = Some(map_.next_value()?); + } } } Ok(ScalarNestedValue { ipc_message: ipc_message__.unwrap_or_default(), arrow_data: arrow_data__.unwrap_or_default(), schema: schema__, + dictionaries: dictionaries__.unwrap_or_default(), }) } } deserializer.deserialize_struct("datafusion.ScalarNestedValue", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for scalar_nested_value::Dictionary { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.ipc_message.is_empty() { + len += 1; + } + if !self.arrow_data.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ScalarNestedValue.Dictionary", len)?; + if !self.ipc_message.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; + } + if !self.arrow_data.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for scalar_nested_value::Dictionary { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "ipc_message", + "ipcMessage", + "arrow_data", + "arrowData", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + IpcMessage, + ArrowData, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "ipcMessage" | "ipc_message" => Ok(GeneratedField::IpcMessage), + "arrowData" | "arrow_data" => Ok(GeneratedField::ArrowData), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = scalar_nested_value::Dictionary; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ScalarNestedValue.Dictionary") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut ipc_message__ = None; + let mut arrow_data__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::IpcMessage => { + if ipc_message__.is_some() { + return Err(serde::de::Error::duplicate_field("ipcMessage")); + } + ipc_message__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::ArrowData => { + if arrow_data__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowData")); + } + arrow_data__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + } + } + Ok(scalar_nested_value::Dictionary { + ipc_message: ipc_message__.unwrap_or_default(), + arrow_data: arrow_data__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.ScalarNestedValue.Dictionary", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarTime32Value { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 64e72ba03878..b6b7687e6c00 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1133,6 +1133,19 @@ pub struct ScalarNestedValue { pub arrow_data: ::prost::alloc::vec::Vec, #[prost(message, optional, tag = "3")] pub schema: ::core::option::Option, + #[prost(message, repeated, tag = "4")] + pub dictionaries: ::prost::alloc::vec::Vec, +} +/// Nested message and enum types in `ScalarNestedValue`. +pub mod scalar_nested_value { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct Dictionary { + #[prost(bytes = "vec", tag = "1")] + pub ipc_message: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "2")] + pub arrow_data: ::prost::alloc::vec::Vec, + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 585bcad7f38c..5df8eb59e173 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::sync::Arc; +use arrow::array::ArrayRef; use arrow::{ array::AsArray, buffer::Buffer, @@ -522,6 +524,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { let protobuf::ScalarNestedValue { ipc_message, arrow_data, + dictionaries, schema, } = &v; @@ -548,11 +551,55 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { ) })?; + let dict_by_id: HashMap = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| { + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { + Error::General(format!( + "Error IPC message while deserializing ScalarValue::List dictionary message: {e}" + )) + })?; + let buffer = Buffer::from(arrow_data); + + let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing ScalarValue::List dictionary message" + .to_string(), + ) + })?; + + let id = dict_batch.id(); + + let fields_using_this_dictionary = schema.fields_with_dict_id(id); + let first_field = fields_using_this_dictionary.first().ok_or_else(|| { + Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string()) + })?; + + let values: ArrayRef = match first_field.data_type() { + DataType::Dictionary(_, ref value_type) => { + // Make a fake schema for the dictionary batch. + let value = value_type.as_ref().clone(); + let schema = Schema::new(vec![Field::new("", value, true)]); + // Read a single column + let record_batch = read_record_batch( + &buffer, + dict_batch.data().unwrap(), + Arc::new(schema), + &Default::default(), + None, + &message.version(), + )?; + Ok(record_batch.column(0).clone()) + } + _ => Err(Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string())), + }?; + + Ok((id,values)) + }).collect::>>()?; + let record_batch = read_record_batch( &buffer, ipc_batch, Arc::new(schema), - &Default::default(), + &dict_by_id, None, &message.version(), ) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ecdbde6faf59..52482c890ac9 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1497,7 +1497,7 @@ fn encode_scalar_nested_value( let gen = IpcDataGenerator {}; let mut dict_tracker = DictionaryTracker::new(false); - let (_, encoded_message) = gen + let (encoded_dictionaries, encoded_message) = gen .encoded_batch(&batch, &mut dict_tracker, &Default::default()) .map_err(|e| { Error::General(format!("Error encoding ScalarValue::List as IPC: {e}")) @@ -1508,6 +1508,13 @@ fn encode_scalar_nested_value( let scalar_list_value = protobuf::ScalarNestedValue { ipc_message: encoded_message.ipc_message, arrow_data: encoded_message.arrow_data, + dictionaries: encoded_dictionaries + .into_iter() + .map(|data| protobuf::scalar_nested_value::Dictionary { + ipc_message: data.ipc_message, + arrow_data: data.arrow_data, + }) + .collect(), schema: Some(schema), }; diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 4de0b7c06d45..0515ed5006aa 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -494,9 +494,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { match expr_type { ExprType::AggregateExpr(agg_node) => { let input_phy_expr: Vec> = agg_node.expr.iter() - .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); + .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec)).collect::>>()?; let ordering_req: Vec = agg_node.ordering_req.iter() - .map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); + .map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec)).collect::>>()?; agg_node.aggregate_function.as_ref().map(|func| { match func { AggregateFunction::AggrFunction(i) => { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index b5b0b4c2247a..472d64905b1b 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1104,9 +1104,58 @@ fn round_trip_scalar_values() { ) .build() .unwrap(), + ScalarStructBuilder::new() + .with_scalar( + Field::new("a", DataType::Int32, true), + ScalarValue::from(23i32), + ) + .with_scalar( + Field::new("b", DataType::Boolean, false), + ScalarValue::from(false), + ) + .with_scalar( + Field::new( + "c", + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + false, + ), + ScalarValue::Dictionary( + Box::new(DataType::UInt16), + Box::new("value".into()), + ), + ) + .build() + .unwrap(), + ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Boolean, false), + ]))) + .unwrap(), ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Boolean, false), + Field::new( + "c", + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Binary), + ), + false, + ), + Field::new( + "d", + DataType::new_list( + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Binary), + ), + false, + ), + false, + ), ]))) .unwrap(), ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())),