diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 21bef3c2c98e..01a854ffbdf2 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -42,7 +42,7 @@ use crate::variation_const::{ DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, - UNSIGNED_INTEGER_TYPE_VARIATION_REF, + UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, }; #[allow(deprecated)] use crate::variation_const::{ @@ -1432,6 +1432,7 @@ fn from_substrait_type( r#type::Kind::Binary(binary) => match binary.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary), LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::BinaryView), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), @@ -1442,6 +1443,7 @@ fn from_substrait_type( r#type::Kind::String(string) => match string.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8), LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeUtf8), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8View), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), @@ -1759,6 +1761,7 @@ fn from_substrait_literal( Some(LiteralType::String(s)) => match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8(Some(s.clone())), LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeUtf8(Some(s.clone())), + VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8View(Some(s.clone())), others => { return substrait_err!("Unknown type variation reference {others}"); } @@ -1768,6 +1771,7 @@ fn from_substrait_literal( LARGE_CONTAINER_TYPE_VARIATION_REF => { ScalarValue::LargeBinary(Some(b.clone())) } + VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::BinaryView(Some(b.clone())), others => { return substrait_err!("Unknown type variation reference {others}"); } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e71cf04cd341..f323ae146600 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -37,7 +37,7 @@ use crate::variation_const::{ DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, - UNSIGNED_INTEGER_TYPE_VARIATION_REF, + UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, }; use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::common::{ @@ -1450,6 +1450,12 @@ fn to_substrait_type( nullability, })), }), + DataType::BinaryView => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), DataType::Utf8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::String(r#type::String { type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, @@ -1462,6 +1468,12 @@ fn to_substrait_type( nullability, })), }), + DataType::Utf8View => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), DataType::List(inner) => { let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable(), extensions)?; @@ -1902,6 +1914,10 @@ fn to_substrait_literal( LiteralType::Binary(b.clone()), LARGE_CONTAINER_TYPE_VARIATION_REF, ), + ScalarValue::BinaryView(Some(b)) => ( + LiteralType::Binary(b.clone()), + VIEW_CONTAINER_TYPE_VARIATION_REF, + ), ScalarValue::FixedSizeBinary(_, Some(b)) => ( LiteralType::FixedBinary(b.clone()), DEFAULT_TYPE_VARIATION_REF, @@ -1914,6 +1930,10 @@ fn to_substrait_literal( LiteralType::String(s.clone()), LARGE_CONTAINER_TYPE_VARIATION_REF, ), + ScalarValue::Utf8View(Some(s)) => ( + LiteralType::String(s.clone()), + VIEW_CONTAINER_TYPE_VARIATION_REF, + ), ScalarValue::Decimal128(v, p, s) if v.is_some() => ( LiteralType::Decimal(Decimal { value: v.unwrap().to_le_bytes().to_vec(), @@ -2335,8 +2355,10 @@ mod test { round_trip_type(DataType::Binary)?; round_trip_type(DataType::FixedSizeBinary(10))?; round_trip_type(DataType::LargeBinary)?; + round_trip_type(DataType::BinaryView)?; round_trip_type(DataType::Utf8)?; round_trip_type(DataType::LargeUtf8)?; + round_trip_type(DataType::Utf8View)?; round_trip_type(DataType::Decimal128(10, 2))?; round_trip_type(DataType::Decimal256(30, 2))?; diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index 5a8b888ef1cc..a8f8ce048e0f 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -37,6 +37,11 @@ use substrait::proto::{ expression::MaskExpression, read_rel::ReadType, rel::RelType, Rel, }; +use crate::variation_const::{ + DEFAULT_CONTAINER_TYPE_VARIATION_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; + /// Convert Substrait Rel to DataFusion ExecutionPlan #[async_recursion] pub async fn from_substrait_rel( @@ -177,7 +182,27 @@ fn to_field(name: &String, r#type: &Type) -> Result { } Kind::String(string) => { nullable = is_nullable(string.nullability); - Ok(DataType::Utf8) + match string.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeUtf8), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8View), + _ => substrait_err!( + "Invalid type variation found for substrait string type class: {}", + string.type_variation_reference + ), + } + } + Kind::Binary(binary) => { + nullable = is_nullable(binary.nullability); + match binary.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::BinaryView), + _ => substrait_err!( + "Invalid type variation found for substrait binary type class: {}", + binary.type_variation_reference + ), + } } _ => substrait_err!( "Unsupported kind: {:?} in the type with name {}", diff --git a/datafusion/substrait/src/physical_plan/producer.rs b/datafusion/substrait/src/physical_plan/producer.rs index 57fe68c4a780..7279785ae873 100644 --- a/datafusion/substrait/src/physical_plan/producer.rs +++ b/datafusion/substrait/src/physical_plan/producer.rs @@ -23,7 +23,7 @@ use std::collections::HashMap; use substrait::proto::expression::mask_expression::{StructItem, StructSelect}; use substrait::proto::expression::MaskExpression; use substrait::proto::r#type::{ - Boolean, Fp64, Kind, Nullability, String as SubstraitString, Struct, I64, + Binary, Boolean, Fp64, Kind, Nullability, String as SubstraitString, Struct, I64, }; use substrait::proto::read_rel::local_files::file_or_files::ParquetReadOptions; use substrait::proto::read_rel::local_files::file_or_files::{FileFormat, PathType}; @@ -35,6 +35,11 @@ use substrait::proto::ReadRel; use substrait::proto::Rel; use substrait::proto::{extensions, NamedStruct, Type}; +use crate::variation_const::{ + DEFAULT_CONTAINER_TYPE_VARIATION_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; + /// Convert DataFusion ExecutionPlan to Substrait Rel pub fn to_substrait_rel( plan: &dyn ExecutionPlan, @@ -155,7 +160,37 @@ fn to_substrait_type(data_type: &DataType, nullable: bool) -> Result { }), DataType::Utf8 => Ok(Type { kind: Some(Kind::String(SubstraitString { - type_variation_reference: 0, + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::LargeUtf8 => Ok(Type { + kind: Some(Kind::String(SubstraitString { + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Utf8View => Ok(Type { + kind: Some(Kind::String(SubstraitString { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Binary => Ok(Type { + kind: Some(Kind::Binary(Binary { + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::LargeBinary => Ok(Type { + kind: Some(Kind::Binary(Binary { + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::BinaryView => Ok(Type { + kind: Some(Kind::Binary(Binary { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, nullability, })), }), diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index 1525da764509..a3e76389d510 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -52,6 +52,7 @@ pub const DATE_32_TYPE_VARIATION_REF: u32 = 0; pub const DATE_64_TYPE_VARIATION_REF: u32 = 1; pub const DEFAULT_CONTAINER_TYPE_VARIATION_REF: u32 = 0; pub const LARGE_CONTAINER_TYPE_VARIATION_REF: u32 = 1; +pub const VIEW_CONTAINER_TYPE_VARIATION_REF: u32 = 2; pub const DECIMAL_128_TYPE_VARIATION_REF: u32 = 0; pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 083a589fce26..98daac65e1cf 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -716,8 +716,10 @@ async fn all_type_literal() -> Result<()> { date32_col = arrow_cast('2020-01-01', 'Date32') AND binary_col = arrow_cast('binary', 'Binary') AND large_binary_col = arrow_cast('large_binary', 'LargeBinary') AND + view_binary_col = arrow_cast('binary_view', 'BinaryView') AND utf8_col = arrow_cast('utf8', 'Utf8') AND - large_utf8_col = arrow_cast('large_utf8', 'LargeUtf8');", + large_utf8_col = arrow_cast('large_utf8', 'LargeUtf8') AND + view_utf8_col = arrow_cast('utf8_view', 'Utf8View');", ) .await } @@ -1231,9 +1233,11 @@ async fn create_all_type_context() -> Result { Field::new("date64_col", DataType::Date64, true), Field::new("binary_col", DataType::Binary, true), Field::new("large_binary_col", DataType::LargeBinary, true), + Field::new("view_binary_col", DataType::BinaryView, true), Field::new("fixed_size_binary_col", DataType::FixedSizeBinary(42), true), Field::new("utf8_col", DataType::Utf8, true), Field::new("large_utf8_col", DataType::LargeUtf8, true), + Field::new("view_utf8_col", DataType::Utf8View, true), Field::new_list("list_col", Field::new("item", DataType::Int64, true), true), Field::new_list( "large_list_col",