diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index c4c92be1525d3..b78f32e0ac486 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -21,7 +21,6 @@ use std::collections::HashMap; use std::fs::File; use std::io::prelude::*; use std::io::BufReader; -use std::str::FromStr; use crate::cli_context::CliSessionContext; use crate::helper::split_from_semicolon; @@ -35,6 +34,7 @@ use crate::{ use datafusion::common::instant::Instant; use datafusion::common::plan_datafusion_err; +use datafusion::config::ConfigFileType; use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; @@ -42,7 +42,6 @@ use datafusion::physical_plan::{collect, execute_stream, ExecutionPlanProperties use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; -use datafusion::common::FileType; use datafusion::sql::sqlparser; use rustyline::error::ReadlineError; use rustyline::Editor; @@ -291,6 +290,15 @@ impl AdjustedPrintOptions { } } +fn config_file_type_from_str(ext: &str) -> Option { + match ext.to_lowercase().as_str() { + "csv" => Some(ConfigFileType::CSV), + "json" => Some(ConfigFileType::JSON), + "parquet" => Some(ConfigFileType::PARQUET), + _ => None, + } +} + async fn create_plan( ctx: &mut dyn CliSessionContext, statement: Statement, @@ -302,7 +310,7 @@ async fn create_plan( // will raise Configuration errors. if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { // To support custom formats, treat error as None - let format = FileType::from_str(&cmd.file_type).ok(); + let format = config_file_type_from_str(&cmd.file_type); register_object_store_and_config_extensions( ctx, &cmd.location, @@ -313,13 +321,13 @@ async fn create_plan( } if let LogicalPlan::Copy(copy_to) = &mut plan { - let format: FileType = (©_to.format_options).into(); + let format = config_file_type_from_str(©_to.file_type.get_ext()); register_object_store_and_config_extensions( ctx, ©_to.output_url, ©_to.options, - Some(format), + format, ) .await?; } @@ -357,7 +365,7 @@ pub(crate) async fn register_object_store_and_config_extensions( ctx: &dyn CliSessionContext, location: &String, options: &HashMap, - format: Option, + format: Option, ) -> Result<()> { // Parse the location URL to extract the scheme and other components let table_path = ListingTableUrl::parse(location)?; @@ -374,7 +382,7 @@ pub(crate) async fn register_object_store_and_config_extensions( // Clone and modify the default table options based on the provided options let mut table_options = ctx.session_state().default_table_options().clone(); if let Some(format) = format { - table_options.set_file_format(format); + table_options.set_config_format(format); } table_options.alter_with_string_hash_map(options)?; @@ -392,7 +400,6 @@ pub(crate) async fn register_object_store_and_config_extensions( mod tests { use super::*; - use datafusion::common::config::FormatOptions; use datafusion::common::plan_err; use datafusion::prelude::SessionContext; @@ -403,7 +410,7 @@ mod tests { let plan = ctx.state().create_logical_plan(sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { - let format = FileType::from_str(&cmd.file_type).ok(); + let format = config_file_type_from_str(&cmd.file_type); register_object_store_and_config_extensions( &ctx, &cmd.location, @@ -429,12 +436,12 @@ mod tests { let plan = ctx.state().create_logical_plan(sql).await?; if let 
LogicalPlan::Copy(cmd) = &plan { - let format: FileType = (&cmd.format_options).into(); + let format = config_file_type_from_str(&cmd.file_type.get_ext()); register_object_store_and_config_extensions( &ctx, &cmd.output_url, &cmd.options, - Some(format), + format, ) .await?; } else { @@ -484,7 +491,7 @@ mod tests { let mut plan = create_plan(&mut ctx, statement).await?; if let LogicalPlan::Copy(copy_to) = &mut plan { assert_eq!(copy_to.output_url, location); - assert!(matches!(copy_to.format_options, FormatOptions::PARQUET(_))); + assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); ctx.runtime_env() .object_store_registry .get_store(&Url::parse(©_to.output_url).unwrap())?; diff --git a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs b/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs index 4d71ed7589121..e75ba5dd5328a 100644 --- a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs +++ b/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs @@ -20,10 +20,10 @@ use std::sync::Arc; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; -use datafusion_common::{FileType, GetExt}; use object_store::aws::AmazonS3Builder; use url::Url; @@ -54,7 +54,7 @@ async fn main() -> Result<()> { let path = format!("s3://{bucket_name}/test_data/"); let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::PARQUET.get_ext()); + .with_file_extension(ParquetFormat::default().get_ext()); ctx.register_listing_table("test", &path, listing_options, None, None) .await?; @@ -79,7 +79,7 @@ async fn main() -> Result<()> { let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::PARQUET.get_ext()); + .with_file_extension(ParquetFormat::default().get_ext()); ctx.register_listing_table("test2", &out_path, listing_options, None, None) .await?; diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 47da14574c5db..b90aeffb07695 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -24,7 +24,7 @@ use std::str::FromStr; use crate::error::_config_err; use crate::parsers::CompressionTypeVariant; -use crate::{DataFusionError, FileType, Result}; +use crate::{DataFusionError, Result}; /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used @@ -1116,6 +1116,16 @@ macro_rules! extensions_options { } } +/// These file types have special built in behavior for configuration. +/// Use TableOptions::Extensions for configuring other file types. +#[derive(Debug, Clone)] +pub enum ConfigFileType { + CSV, + #[cfg(feature = "parquet")] + PARQUET, + JSON, +} + /// Represents the configuration options available for handling different table formats within a data processing application. /// This struct encompasses options for various file formats including CSV, Parquet, and JSON, allowing for flexible configuration /// of parsing and writing behaviors specific to each format. Additionally, it supports extending functionality through custom extensions. 
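For context on the new ConfigFileType above, here is a minimal sketch (based on the option keys exercised by this patch's own tests; the delimiter value is illustrative) of how table options are now scoped to a built-in format before string options are applied:

    use std::collections::HashMap;

    use datafusion_common::config::{ConfigFileType, TableOptions};
    use datafusion_common::Result;

    fn csv_table_options() -> Result<TableOptions> {
        let mut opts = TableOptions::new();
        // Previously: opts.set_file_format(FileType::CSV).
        // ConfigFileType only selects which built-in options the "format.*" keys target.
        opts.set_config_format(ConfigFileType::CSV);

        let mut overrides = HashMap::new();
        overrides.insert("format.delimiter".to_owned(), ";".to_owned());
        opts.alter_with_string_hash_map(&overrides)?;
        Ok(opts)
    }
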
@@ -1134,7 +1144,7 @@ pub struct TableOptions { /// The current file format that the table operations should assume. This option allows /// for dynamic switching between the supported file types (e.g., CSV, Parquet, JSON). - pub current_format: Option, + pub current_format: Option, /// Optional extensions that can be used to extend or customize the behavior of the table /// options. Extensions can be registered using `Extensions::insert` and might include @@ -1152,10 +1162,9 @@ impl ConfigField for TableOptions { if let Some(file_type) = &self.current_format { match file_type { #[cfg(feature = "parquet")] - FileType::PARQUET => self.parquet.visit(v, "format", ""), - FileType::CSV => self.csv.visit(v, "format", ""), - FileType::JSON => self.json.visit(v, "format", ""), - _ => {} + ConfigFileType::PARQUET => self.parquet.visit(v, "format", ""), + ConfigFileType::CSV => self.csv.visit(v, "format", ""), + ConfigFileType::JSON => self.json.visit(v, "format", ""), } } else { self.csv.visit(v, "csv", ""); @@ -1188,12 +1197,9 @@ impl ConfigField for TableOptions { match key { "format" => match format { #[cfg(feature = "parquet")] - FileType::PARQUET => self.parquet.set(rem, value), - FileType::CSV => self.csv.set(rem, value), - FileType::JSON => self.json.set(rem, value), - _ => { - _config_err!("Config value \"{key}\" is not supported on {}", format) - } + ConfigFileType::PARQUET => self.parquet.set(rem, value), + ConfigFileType::CSV => self.csv.set(rem, value), + ConfigFileType::JSON => self.json.set(rem, value), }, _ => _config_err!("Config value \"{key}\" not found on TableOptions"), } @@ -1210,15 +1216,6 @@ impl TableOptions { Self::default() } - /// Sets the file format for the table. - /// - /// # Parameters - /// - /// * `format`: The file format to use (e.g., CSV, Parquet). - pub fn set_file_format(&mut self, format: FileType) { - self.current_format = Some(format); - } - /// Creates a new `TableOptions` instance initialized with settings from a given session config. /// /// # Parameters @@ -1249,6 +1246,15 @@ impl TableOptions { clone } + /// Sets the file format for the table. + /// + /// # Parameters + /// + /// * `format`: The file format to use (e.g., CSV, Parquet). + pub fn set_config_format(&mut self, format: ConfigFileType) { + self.current_format = Some(format); + } + /// Sets the extensions for this `TableOptions` instance. /// /// # Parameters @@ -1673,6 +1679,8 @@ config_namespace! 
{ } } +pub trait FormatOptionsExt: Display {} + #[derive(Debug, Clone, PartialEq)] #[allow(clippy::large_enum_variant)] pub enum FormatOptions { @@ -1698,28 +1706,15 @@ impl Display for FormatOptions { } } -impl From for FormatOptions { - fn from(value: FileType) -> Self { - match value { - FileType::ARROW => FormatOptions::ARROW, - FileType::AVRO => FormatOptions::AVRO, - #[cfg(feature = "parquet")] - FileType::PARQUET => FormatOptions::PARQUET(TableParquetOptions::default()), - FileType::CSV => FormatOptions::CSV(CsvOptions::default()), - FileType::JSON => FormatOptions::JSON(JsonOptions::default()), - } - } -} - #[cfg(test)] mod tests { use std::any::Any; use std::collections::HashMap; use crate::config::{ - ConfigEntry, ConfigExtension, ExtensionOptions, Extensions, TableOptions, + ConfigEntry, ConfigExtension, ConfigFileType, ExtensionOptions, Extensions, + TableOptions, }; - use crate::FileType; #[derive(Default, Debug, Clone)] pub struct TestExtensionConfig { @@ -1777,7 +1772,7 @@ mod tests { let mut extension = Extensions::new(); extension.insert(TestExtensionConfig::default()); let mut table_config = TableOptions::new().with_extensions(extension); - table_config.set_file_format(FileType::CSV); + table_config.set_config_format(ConfigFileType::CSV); table_config.set("format.delimiter", ";").unwrap(); assert_eq!(table_config.csv.delimiter, b';'); table_config.set("test.bootstrap.servers", "asd").unwrap(); @@ -1794,7 +1789,7 @@ mod tests { #[test] fn csv_u8_table_options() { let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::CSV); + table_config.set_config_format(ConfigFileType::CSV); table_config.set("format.delimiter", ";").unwrap(); assert_eq!(table_config.csv.delimiter as char, ';'); table_config.set("format.escape", "\"").unwrap(); @@ -1807,7 +1802,7 @@ mod tests { #[test] fn parquet_table_options() { let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config .set("format.bloom_filter_enabled::col1", "true") .unwrap(); @@ -1821,7 +1816,7 @@ mod tests { #[test] fn parquet_table_options_config_entry() { let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config .set("format.bloom_filter_enabled::col1", "true") .unwrap(); @@ -1835,7 +1830,7 @@ mod tests { #[test] fn parquet_table_options_config_metadata_entry() { let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config.set("format.metadata::key1", "").unwrap(); table_config.set("format.metadata::key2", "value2").unwrap(); table_config diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs index fc0bb74456450..2648f72897982 100644 --- a/datafusion/common/src/file_options/file_type.rs +++ b/datafusion/common/src/file_options/file_type.rs @@ -17,11 +17,8 @@ //! 
File type abstraction -use std::fmt::{self, Display}; -use std::str::FromStr; - -use crate::config::FormatOptions; -use crate::error::{DataFusionError, Result}; +use std::any::Any; +use std::fmt::Display; /// The default file extension of arrow files pub const DEFAULT_ARROW_EXTENSION: &str = ".arrow"; @@ -40,107 +37,10 @@ pub trait GetExt { fn get_ext(&self) -> String; } -/// Readable file type -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum FileType { - /// Apache Arrow file - ARROW, - /// Apache Avro file - AVRO, - /// Apache Parquet file - #[cfg(feature = "parquet")] - PARQUET, - /// CSV file - CSV, - /// JSON file - JSON, -} - -impl From<&FormatOptions> for FileType { - fn from(value: &FormatOptions) -> Self { - match value { - FormatOptions::CSV(_) => FileType::CSV, - FormatOptions::JSON(_) => FileType::JSON, - #[cfg(feature = "parquet")] - FormatOptions::PARQUET(_) => FileType::PARQUET, - FormatOptions::AVRO => FileType::AVRO, - FormatOptions::ARROW => FileType::ARROW, - } - } -} - -impl GetExt for FileType { - fn get_ext(&self) -> String { - match self { - FileType::ARROW => DEFAULT_ARROW_EXTENSION.to_owned(), - FileType::AVRO => DEFAULT_AVRO_EXTENSION.to_owned(), - #[cfg(feature = "parquet")] - FileType::PARQUET => DEFAULT_PARQUET_EXTENSION.to_owned(), - FileType::CSV => DEFAULT_CSV_EXTENSION.to_owned(), - FileType::JSON => DEFAULT_JSON_EXTENSION.to_owned(), - } - } -} - -impl Display for FileType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let out = match self { - FileType::CSV => "csv", - FileType::JSON => "json", - #[cfg(feature = "parquet")] - FileType::PARQUET => "parquet", - FileType::AVRO => "avro", - FileType::ARROW => "arrow", - }; - write!(f, "{}", out) - } -} - -impl FromStr for FileType { - type Err = DataFusionError; - - fn from_str(s: &str) -> Result { - let s = s.to_uppercase(); - match s.as_str() { - "ARROW" => Ok(FileType::ARROW), - "AVRO" => Ok(FileType::AVRO), - #[cfg(feature = "parquet")] - "PARQUET" => Ok(FileType::PARQUET), - "CSV" => Ok(FileType::CSV), - "JSON" | "NDJSON" => Ok(FileType::JSON), - _ => Err(DataFusionError::NotImplemented(format!( - "Unknown FileType: {s}" - ))), - } - } -} - -#[cfg(test)] -#[cfg(feature = "parquet")] -mod tests { - use std::str::FromStr; - - use crate::error::DataFusionError; - use crate::FileType; - - #[test] - fn from_str() { - for (ext, file_type) in [ - ("csv", FileType::CSV), - ("CSV", FileType::CSV), - ("json", FileType::JSON), - ("JSON", FileType::JSON), - ("avro", FileType::AVRO), - ("AVRO", FileType::AVRO), - ("parquet", FileType::PARQUET), - ("PARQUET", FileType::PARQUET), - ] { - assert_eq!(FileType::from_str(ext).unwrap(), file_type); - } - - assert!(matches!( - FileType::from_str("Unknown"), - Err(DataFusionError::NotImplemented(_)) - )); - } +/// Defines the functionality needed for logical planning for +/// a type of file which will be read or written to storage. +pub trait FileType: GetExt + Display + Send + Sync { + /// Returns the table source as [`Any`] so that it can be + /// downcast to a specific implementation. 
+ fn as_any(&self) -> &dyn Any; } diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index 59040b4290b07..77781457d0d2d 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -32,10 +32,10 @@ mod tests { use super::parquet_writer::ParquetWriterOptions; use crate::{ - config::TableOptions, + config::{ConfigFileType, TableOptions}, file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, parsers::CompressionTypeVariant, - FileType, Result, + Result, }; use parquet::{ @@ -76,7 +76,7 @@ mod tests { option_map.insert("format.bloom_filter_ndv".to_owned(), "123".to_owned()); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -181,7 +181,7 @@ mod tests { ); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -284,7 +284,7 @@ mod tests { option_map.insert("format.delimiter".to_owned(), ";".to_owned()); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::CSV); + table_config.set_config_format(ConfigFileType::CSV); table_config.alter_with_string_hash_map(&option_map)?; let csv_options = CsvWriterOptions::try_from(&table_config.csv)?; @@ -306,7 +306,7 @@ mod tests { option_map.insert("format.compression".to_owned(), "gzip".to_owned()); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::JSON); + table_config.set_config_format(ConfigFileType::JSON); table_config.alter_with_string_hash_map(&option_map)?; let json_options = JsonWriterOptions::try_from(&table_config.json)?; diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index e64acd0bfefe7..c275152642f0e 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -54,8 +54,8 @@ pub use error::{ SharedResult, }; pub use file_options::file_type::{ - FileType, GetExt, DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, - DEFAULT_CSV_EXTENSION, DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, + GetExt, DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, + DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, }; pub use functional_dependencies::{ aggregate_functional_dependencies, get_required_group_by_exprs_indices, diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 86e510969b33a..8e55da8c3ad07 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -26,6 +26,9 @@ use std::sync::Arc; use crate::arrow::record_batch::RecordBatch; use crate::arrow::util::pretty; +use crate::datasource::file_format::csv::CsvFormatFactory; +use crate::datasource::file_format::format_as_file_type; +use crate::datasource::file_format::json::JsonFormatFactory; use crate::datasource::{provider_as_source, MemTable, TableProvider}; use crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; @@ -44,7 +47,7 @@ use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; use arrow::datatypes::{DataType, Field}; use 
arrow_schema::{Schema, SchemaRef}; -use datafusion_common::config::{CsvOptions, FormatOptions, JsonOptions}; +use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions, }; @@ -1329,13 +1332,19 @@ impl DataFrame { "Overwrites are not implemented for DataFrame::write_csv.".to_owned(), )); } - let props = writer_options - .unwrap_or_else(|| self.session_state.default_table_options().csv); + + let format = if let Some(csv_opts) = writer_options { + Arc::new(CsvFormatFactory::new_with_options(csv_opts)) + } else { + Arc::new(CsvFormatFactory::new()) + }; + + let file_type = format_as_file_type(format); let plan = LogicalPlanBuilder::copy_to( self.plan, path.into(), - FormatOptions::CSV(props), + file_type, HashMap::new(), options.partition_by, )? @@ -1384,13 +1393,18 @@ impl DataFrame { )); } - let props = writer_options - .unwrap_or_else(|| self.session_state.default_table_options().json); + let format = if let Some(json_opts) = writer_options { + Arc::new(JsonFormatFactory::new_with_options(json_opts)) + } else { + Arc::new(JsonFormatFactory::new()) + }; + + let file_type = format_as_file_type(format); let plan = LogicalPlanBuilder::copy_to( self.plan, path.into(), - FormatOptions::JSON(props), + file_type, Default::default(), options.partition_by, )? diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index 0ec46df0ae5d3..1abb550f5c98c 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -15,11 +15,17 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + +use crate::datasource::file_format::{ + format_as_file_type, parquet::ParquetFormatFactory, +}; + use super::{ DataFrame, DataFrameWriteOptions, DataFusionError, LogicalPlanBuilder, RecordBatch, }; -use datafusion_common::config::{FormatOptions, TableParquetOptions}; +use datafusion_common::config::TableParquetOptions; impl DataFrame { /// Execute the `DataFrame` and write the results to Parquet file(s). @@ -57,13 +63,18 @@ impl DataFrame { )); } - let props = writer_options - .unwrap_or_else(|| self.session_state.default_table_options().parquet); + let format = if let Some(parquet_opts) = writer_options { + Arc::new(ParquetFormatFactory::new_with_options(parquet_opts)) + } else { + Arc::new(ParquetFormatFactory::new()) + }; + + let file_type = format_as_file_type(format); let plan = LogicalPlanBuilder::copy_to( self.plan, path.into(), - FormatOptions::PARQUET(props), + file_type, Default::default(), options.partition_by, )? 
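The DataFrame write_csv/write_json/write_parquet hunks above all follow the same migration: wrap the writer options in a *FormatFactory, convert it with format_as_file_type, and hand the resulting FileType to LogicalPlanBuilder::copy_to in place of a FormatOptions variant. A rough sketch of that pattern from the caller's side (assuming the re-export paths introduced in this diff and the parquet feature enabled; the default writer options are illustrative):

    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion::common::config::TableParquetOptions;
    use datafusion::datasource::file_format::format_as_file_type;
    use datafusion::datasource::file_format::parquet::ParquetFormatFactory;
    use datafusion::error::Result;
    use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};

    fn copy_plan_to_parquet(input: LogicalPlan, output_url: String) -> Result<LogicalPlan> {
        // Writer options now travel inside a FileFormatFactory rather than FormatOptions::PARQUET(..).
        let factory = Arc::new(ParquetFormatFactory::new_with_options(
            TableParquetOptions::default(),
        ));
        // format_as_file_type wraps the factory in the logical-planning FileType abstraction.
        let file_type = format_as_file_type(factory);

        LogicalPlanBuilder::copy_to(
            input,
            output_url,
            file_type,
            HashMap::new(), // statement-level format options, e.g. from OPTIONS (...)
            vec![],         // partition_by columns
        )?
        .build()
    }
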
diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 8c6790541597a..478a11d7e76e9 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -21,12 +21,14 @@ use std::any::Any; use std::borrow::Cow; +use std::collections::HashMap; use std::fmt::{self, Debug}; use std::sync::Arc; use super::file_compression_type::FileCompressionType; use super::write::demux::start_demuxer_task; use super::write::{create_writer, SharedBuffer}; +use super::FileFormatFactory; use crate::datasource::file_format::FileFormat; use crate::datasource::physical_plan::{ ArrowExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, @@ -40,7 +42,10 @@ use arrow::ipc::reader::FileReader; use arrow::ipc::writer::IpcWriteOptions; use arrow::ipc::{root_as_message, CompressionType}; use arrow_schema::{ArrowError, Schema, SchemaRef}; -use datafusion_common::{not_impl_err, DataFusionError, Statistics}; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::{ + not_impl_err, DataFusionError, GetExt, Statistics, DEFAULT_ARROW_EXTENSION, +}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; @@ -61,6 +66,38 @@ const INITIAL_BUFFER_BYTES: usize = 1048576; /// If the buffered Arrow data exceeds this size, it is flushed to object store const BUFFER_FLUSH_BYTES: usize = 1024000; +#[derive(Default)] +/// Factory struct used to create [ArrowFormat] +pub struct ArrowFormatFactory; + +impl ArrowFormatFactory { + /// Creates an instance of [ArrowFormatFactory] + pub fn new() -> Self { + Self {} + } +} + +impl FileFormatFactory for ArrowFormatFactory { + fn create( + &self, + _state: &SessionState, + _format_options: &HashMap, + ) -> Result> { + Ok(Arc::new(ArrowFormat)) + } + + fn default(&self) -> Arc { + Arc::new(ArrowFormat) + } +} + +impl GetExt for ArrowFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_ARROW_EXTENSION[1..].to_string() + } +} + /// Arrow `FileFormat` implementation. #[derive(Default, Debug)] pub struct ArrowFormat; @@ -71,6 +108,23 @@ impl FileFormat for ArrowFormat { self } + fn get_ext(&self) -> String { + ArrowFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + match file_compression_type.get_variant() { + CompressionTypeVariant::UNCOMPRESSED => Ok(ext), + _ => Err(DataFusionError::Internal( + "Arrow FileFormat does not support compression.".into(), + )), + } + } + async fn infer_schema( &self, _state: &SessionState, diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 7b2c26a2c4f9b..f4f9adcba7ed8 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -18,15 +18,22 @@ //! 
[`AvroFormat`] Apache Avro [`FileFormat`] abstractions use std::any::Any; +use std::collections::HashMap; use std::sync::Arc; use arrow::datatypes::Schema; use arrow::datatypes::SchemaRef; use async_trait::async_trait; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::DataFusionError; +use datafusion_common::GetExt; +use datafusion_common::DEFAULT_AVRO_EXTENSION; use datafusion_physical_expr::PhysicalExpr; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use super::file_compression_type::FileCompressionType; use super::FileFormat; +use super::FileFormatFactory; use crate::datasource::avro_to_arrow::read_avro_schema_from_reader; use crate::datasource::physical_plan::{AvroExec, FileScanConfig}; use crate::error::Result; @@ -34,6 +41,38 @@ use crate::execution::context::SessionState; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; +#[derive(Default)] +/// Factory struct used to create [AvroFormat] +pub struct AvroFormatFactory; + +impl AvroFormatFactory { + /// Creates an instance of [AvroFormatFactory] + pub fn new() -> Self { + Self {} + } +} + +impl FileFormatFactory for AvroFormatFactory { + fn create( + &self, + _state: &SessionState, + _format_options: &HashMap, + ) -> Result> { + Ok(Arc::new(AvroFormat)) + } + + fn default(&self) -> Arc { + Arc::new(AvroFormat) + } +} + +impl GetExt for AvroFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_AVRO_EXTENSION[1..].to_string() + } +} + /// Avro `FileFormat` implementation. #[derive(Default, Debug)] pub struct AvroFormat; @@ -44,6 +83,23 @@ impl FileFormat for AvroFormat { self } + fn get_ext(&self) -> String { + AvroFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + match file_compression_type.get_variant() { + CompressionTypeVariant::UNCOMPRESSED => Ok(ext), + _ => Err(DataFusionError::Internal( + "Avro FileFormat does not support compression.".into(), + )), + } + } + async fn infer_schema( &self, _state: &SessionState, diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 2139b35621f22..92cb11e2b47a4 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -18,12 +18,12 @@ //! 
[`CsvFormat`], Comma Separated Value (CSV) [`FileFormat`] abstractions use std::any::Any; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug}; use std::sync::Arc; use super::write::orchestration::stateless_multipart_put; -use super::FileFormat; +use super::{FileFormat, FileFormatFactory}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::physical_plan::{ @@ -40,9 +40,11 @@ use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use arrow::datatypes::SchemaRef; use arrow::datatypes::{DataType, Field, Fields, Schema}; -use datafusion_common::config::CsvOptions; +use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions}; use datafusion_common::file_options::csv_writer::CsvWriterOptions; -use datafusion_common::{exec_err, not_impl_err, DataFusionError}; +use datafusion_common::{ + exec_err, not_impl_err, DataFusionError, GetExt, DEFAULT_CSV_EXTENSION, +}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; @@ -53,6 +55,63 @@ use futures::stream::BoxStream; use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; +#[derive(Default)] +/// Factory struct used to create [CsvFormatFactory] +pub struct CsvFormatFactory { + options: Option, +} + +impl CsvFormatFactory { + /// Creates an instance of [CsvFormatFactory] + pub fn new() -> Self { + Self { options: None } + } + + /// Creates an instance of [CsvFormatFactory] with customized default options + pub fn new_with_options(options: CsvOptions) -> Self { + Self { + options: Some(options), + } + } +} + +impl FileFormatFactory for CsvFormatFactory { + fn create( + &self, + state: &SessionState, + format_options: &HashMap, + ) -> Result> { + let csv_options = match &self.options { + None => { + let mut table_options = state.default_table_options(); + table_options.set_config_format(ConfigFileType::CSV); + table_options.alter_with_string_hash_map(format_options)?; + table_options.csv + } + Some(csv_options) => { + let mut csv_options = csv_options.clone(); + for (k, v) in format_options { + csv_options.set(k, v)?; + } + csv_options + } + }; + + Ok(Arc::new(CsvFormat::default().with_options(csv_options))) + } + + fn default(&self) -> Arc { + Arc::new(CsvFormat::default()) + } +} + +impl GetExt for CsvFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_CSV_EXTENSION[1..].to_string() + } +} + /// Character Separated Value `FileFormat` implementation. 
#[derive(Debug, Default)] pub struct CsvFormat { @@ -206,6 +265,18 @@ impl FileFormat for CsvFormat { self } + fn get_ext(&self) -> String { + CsvFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + Ok(format!("{}{}", ext, file_compression_type.get_ext())) + } + async fn infer_schema( &self, state: &SessionState, @@ -558,7 +629,6 @@ mod tests { use datafusion_common::cast::as_string_array; use datafusion_common::internal_err; use datafusion_common::stats::Precision; - use datafusion_common::{FileType, GetExt}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::{col, lit}; @@ -1060,9 +1130,9 @@ mod tests { .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); let ctx = SessionContext::new_with_config(config); - let file_format = CsvFormat::default().with_has_header(false); - let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::CSV.get_ext()); + let file_format = Arc::new(CsvFormat::default().with_has_header(false)); + let listing_options = ListingOptions::new(file_format.clone()) + .with_file_extension(file_format.get_ext()); ctx.register_listing_table( "empty", "tests/data/empty_files/all_empty/", @@ -1113,9 +1183,9 @@ mod tests { .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); let ctx = SessionContext::new_with_config(config); - let file_format = CsvFormat::default().with_has_header(false); - let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::CSV.get_ext()); + let file_format = Arc::new(CsvFormat::default().with_has_header(false)); + let listing_options = ListingOptions::new(file_format.clone()) + .with_file_extension(file_format.get_ext()); ctx.register_listing_table( "empty", "tests/data/empty_files/some_empty", diff --git a/datafusion/core/src/datasource/file_format/file_compression_type.rs b/datafusion/core/src/datasource/file_format/file_compression_type.rs index c1fbe352d37bf..a054094822d01 100644 --- a/datafusion/core/src/datasource/file_format/file_compression_type.rs +++ b/datafusion/core/src/datasource/file_format/file_compression_type.rs @@ -22,7 +22,7 @@ use std::str::FromStr; use crate::error::{DataFusionError, Result}; use datafusion_common::parsers::CompressionTypeVariant::{self, *}; -use datafusion_common::{FileType, GetExt}; +use datafusion_common::GetExt; #[cfg(feature = "compression")] use async_compression::tokio::bufread::{ @@ -112,6 +112,11 @@ impl FileCompressionType { variant: UNCOMPRESSED, }; + /// Read only access to self.variant + pub fn get_variant(&self) -> &CompressionTypeVariant { + &self.variant + } + /// The file is compressed or not pub const fn is_compressed(&self) -> bool { self.variant.is_compressed() @@ -245,90 +250,16 @@ pub trait FileTypeExt { fn get_ext_with_compression(&self, c: FileCompressionType) -> Result; } -impl FileTypeExt for FileType { - fn get_ext_with_compression(&self, c: FileCompressionType) -> Result { - let ext = self.get_ext(); - - match self { - FileType::JSON | FileType::CSV => Ok(format!("{}{}", ext, c.get_ext())), - FileType::AVRO | FileType::ARROW => match c.variant { - UNCOMPRESSED => Ok(ext), - _ => Err(DataFusionError::Internal( - "FileCompressionType can be specified for CSV/JSON FileType.".into(), - )), - }, - #[cfg(feature = "parquet")] - FileType::PARQUET => match c.variant { - UNCOMPRESSED => Ok(ext), - _ => 
Err(DataFusionError::Internal( - "FileCompressionType can be specified for CSV/JSON FileType.".into(), - )), - }, - } - } -} - #[cfg(test)] mod tests { use std::str::FromStr; - use crate::datasource::file_format::file_compression_type::{ - FileCompressionType, FileTypeExt, - }; + use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::error::DataFusionError; - use datafusion_common::file_options::file_type::FileType; - use bytes::Bytes; use futures::StreamExt; - #[test] - fn get_ext_with_compression() { - for (file_type, compression, extension) in [ - (FileType::CSV, FileCompressionType::UNCOMPRESSED, ".csv"), - (FileType::CSV, FileCompressionType::GZIP, ".csv.gz"), - (FileType::CSV, FileCompressionType::XZ, ".csv.xz"), - (FileType::CSV, FileCompressionType::BZIP2, ".csv.bz2"), - (FileType::CSV, FileCompressionType::ZSTD, ".csv.zst"), - (FileType::JSON, FileCompressionType::UNCOMPRESSED, ".json"), - (FileType::JSON, FileCompressionType::GZIP, ".json.gz"), - (FileType::JSON, FileCompressionType::XZ, ".json.xz"), - (FileType::JSON, FileCompressionType::BZIP2, ".json.bz2"), - (FileType::JSON, FileCompressionType::ZSTD, ".json.zst"), - ] { - assert_eq!( - file_type.get_ext_with_compression(compression).unwrap(), - extension - ); - } - - let mut ty_ext_tuple = vec![]; - ty_ext_tuple.push((FileType::AVRO, ".avro")); - #[cfg(feature = "parquet")] - ty_ext_tuple.push((FileType::PARQUET, ".parquet")); - - // Cannot specify compression for these file types - for (file_type, extension) in ty_ext_tuple { - assert_eq!( - file_type - .get_ext_with_compression(FileCompressionType::UNCOMPRESSED) - .unwrap(), - extension - ); - for compression in [ - FileCompressionType::GZIP, - FileCompressionType::XZ, - FileCompressionType::BZIP2, - FileCompressionType::ZSTD, - ] { - assert!(matches!( - file_type.get_ext_with_compression(compression), - Err(DataFusionError::Internal(_)) - )); - } - } - } - #[test] fn from_str() { for (ext, compression_type) in [ diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index d5347c498c71f..007b084f504dd 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -18,13 +18,14 @@ //! 
[`JsonFormat`]: Line delimited JSON [`FileFormat`] abstractions use std::any::Any; +use std::collections::HashMap; use std::fmt; use std::fmt::Debug; use std::io::BufReader; use std::sync::Arc; use super::write::orchestration::stateless_multipart_put; -use super::{FileFormat, FileScanConfig}; +use super::{FileFormat, FileFormatFactory, FileScanConfig}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::physical_plan::FileGroupDisplay; @@ -41,9 +42,9 @@ use arrow::datatypes::SchemaRef; use arrow::json; use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; use arrow_array::RecordBatch; -use datafusion_common::config::JsonOptions; +use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions}; use datafusion_common::file_options::json_writer::JsonWriterOptions; -use datafusion_common::not_impl_err; +use datafusion_common::{not_impl_err, GetExt, DEFAULT_JSON_EXTENSION}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; @@ -53,6 +54,63 @@ use async_trait::async_trait; use bytes::{Buf, Bytes}; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +#[derive(Default)] +/// Factory struct used to create [JsonFormat] +pub struct JsonFormatFactory { + options: Option, +} + +impl JsonFormatFactory { + /// Creates an instance of [JsonFormatFactory] + pub fn new() -> Self { + Self { options: None } + } + + /// Creates an instance of [JsonFormatFactory] with customized default options + pub fn new_with_options(options: JsonOptions) -> Self { + Self { + options: Some(options), + } + } +} + +impl FileFormatFactory for JsonFormatFactory { + fn create( + &self, + state: &SessionState, + format_options: &HashMap, + ) -> Result> { + let json_options = match &self.options { + None => { + let mut table_options = state.default_table_options(); + table_options.set_config_format(ConfigFileType::JSON); + table_options.alter_with_string_hash_map(format_options)?; + table_options.json + } + Some(json_options) => { + let mut json_options = json_options.clone(); + for (k, v) in format_options { + json_options.set(k, v)?; + } + json_options + } + }; + + Ok(Arc::new(JsonFormat::default().with_options(json_options))) + } + + fn default(&self) -> Arc { + Arc::new(JsonFormat::default()) + } +} + +impl GetExt for JsonFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_JSON_EXTENSION[1..].to_string() + } +} + /// New line delimited JSON `FileFormat` implementation. 
#[derive(Debug, Default)] pub struct JsonFormat { @@ -95,6 +153,18 @@ impl FileFormat for JsonFormat { self } + fn get_ext(&self) -> String { + JsonFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + Ok(format!("{}{}", ext, file_compression_type.get_ext())) + } + async fn infer_schema( &self, _state: &SessionState, diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 9462cde436103..1aa93a106aff0 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -32,7 +32,8 @@ pub mod parquet; pub mod write; use std::any::Any; -use std::fmt; +use std::collections::HashMap; +use std::fmt::{self, Display}; use std::sync::Arc; use crate::arrow::datatypes::SchemaRef; @@ -41,12 +42,29 @@ use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::{ExecutionPlan, Statistics}; -use datafusion_common::not_impl_err; +use datafusion_common::file_options::file_type::FileType; +use datafusion_common::{internal_err, not_impl_err, GetExt}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use async_trait::async_trait; +use file_compression_type::FileCompressionType; use object_store::{ObjectMeta, ObjectStore}; +/// Factory for creating [`FileFormat`] instances based on session and command level options +/// +/// Users can provide their own `FileFormatFactory` to support arbitrary file formats +pub trait FileFormatFactory: Sync + Send + GetExt { + /// Initialize a [FileFormat] and configure based on session and command level options + fn create( + &self, + state: &SessionState, + format_options: &HashMap, + ) -> Result>; + + /// Initialize a [FileFormat] with all options set to default values + fn default(&self) -> Arc; +} + /// This trait abstracts all the file format specific implementations /// from the [`TableProvider`]. This helps code re-utilization across /// providers that support the same file formats. @@ -58,6 +76,15 @@ pub trait FileFormat: Send + Sync + fmt::Debug { /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; + /// Returns the extension for this FileFormat, e.g. "file.csv" -> csv + fn get_ext(&self) -> String; + + /// Returns the extension for this FileFormat when compressed, e.g. "file.csv.gz" -> csv + fn get_ext_with_compression( + &self, + _file_compression_type: &FileCompressionType, + ) -> Result; + /// Infer the common schema of the provided objects. The objects will usually /// be analysed up to a given number of records or files (as specified in the /// format config) then give the estimated common schema. This might fail if @@ -106,6 +133,67 @@ pub trait FileFormat: Send + Sync + fmt::Debug { } } +/// A container of [FileFormatFactory] which also implements [FileType]. +/// This enables converting a dyn FileFormat to a dyn FileType. +/// The former trait is a superset of the latter trait, which includes execution time +/// relevant methods. [FileType] is only used in logical planning and only implements +/// the subset of methods required during logical planning. 
+pub struct DefaultFileType { + file_format_factory: Arc, +} + +impl DefaultFileType { + /// Constructs a [DefaultFileType] wrapper from a [FileFormatFactory] + pub fn new(file_format_factory: Arc) -> Self { + Self { + file_format_factory, + } + } +} + +impl FileType for DefaultFileType { + fn as_any(&self) -> &dyn Any { + self + } +} + +impl Display for DefaultFileType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.file_format_factory.default().fmt(f) + } +} + +impl GetExt for DefaultFileType { + fn get_ext(&self) -> String { + self.file_format_factory.get_ext() + } +} + +/// Converts a [FileFormatFactory] to a [FileType] +pub fn format_as_file_type( + file_format_factory: Arc, +) -> Arc { + Arc::new(DefaultFileType { + file_format_factory, + }) +} + +/// Converts a [FileType] to a [FileFormatFactory]. +/// Returns an error if the [FileType] cannot be +/// downcasted to a [DefaultFileType]. +pub fn file_type_to_format( + file_type: &Arc, +) -> datafusion_common::Result> { + match file_type + .as_ref() + .as_any() + .downcast_ref::() + { + Some(source) => Ok(source.file_format_factory.clone()), + _ => internal_err!("FileType was not DefaultFileType"), + } +} + #[cfg(test)] pub(crate) mod test_util { use std::ops::Range; diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 4204593eba96d..44c9cc4ec4a9d 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use super::write::demux::start_demuxer_task; use super::write::{create_writer, SharedBuffer}; -use super::{FileFormat, FileScanConfig}; +use super::{FileFormat, FileFormatFactory, FileScanConfig}; use crate::arrow::array::RecordBatch; use crate::arrow::datatypes::{Fields, Schema, SchemaRef}; use crate::datasource::file_format::file_compression_type::FileCompressionType; @@ -39,11 +39,13 @@ use crate::physical_plan::{ }; use arrow::compute::sum; -use datafusion_common::config::TableParquetOptions; +use datafusion_common::config::{ConfigField, ConfigFileType, TableParquetOptions}; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ - exec_err, internal_datafusion_err, not_impl_err, DataFusionError, + exec_err, internal_datafusion_err, not_impl_err, DataFusionError, GetExt, + DEFAULT_PARQUET_EXTENSION, }; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; @@ -53,6 +55,7 @@ use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; use bytes::{BufMut, BytesMut}; +use hashbrown::HashMap; use log::debug; use object_store::buffered::BufWriter; use parquet::arrow::arrow_writer::{ @@ -75,7 +78,6 @@ use crate::datasource::physical_plan::parquet::{ ParquetExecBuilder, StatisticsConverter, }; use futures::{StreamExt, TryStreamExt}; -use hashbrown::HashMap; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; @@ -87,6 +89,65 @@ const INITIAL_BUFFER_BYTES: usize = 1048576; /// this size, it is flushed to object store const BUFFER_FLUSH_BYTES: usize = 1024000; +#[derive(Default)] +/// Factory struct used to create [ParquetFormat] +pub struct ParquetFormatFactory { + options: Option, +} + +impl ParquetFormatFactory { + /// Creates an instance of [ParquetFormatFactory] + pub fn new() -> Self { + Self { options: None } + } + + /// 
Creates an instance of [ParquetFormatFactory] with customized default options + pub fn new_with_options(options: TableParquetOptions) -> Self { + Self { + options: Some(options), + } + } +} + +impl FileFormatFactory for ParquetFormatFactory { + fn create( + &self, + state: &SessionState, + format_options: &std::collections::HashMap, + ) -> Result> { + let parquet_options = match &self.options { + None => { + let mut table_options = state.default_table_options(); + table_options.set_config_format(ConfigFileType::PARQUET); + table_options.alter_with_string_hash_map(format_options)?; + table_options.parquet + } + Some(parquet_options) => { + let mut parquet_options = parquet_options.clone(); + for (k, v) in format_options { + parquet_options.set(k, v)?; + } + parquet_options + } + }; + + Ok(Arc::new( + ParquetFormat::default().with_options(parquet_options), + )) + } + + fn default(&self) -> Arc { + Arc::new(ParquetFormat::default()) + } +} + +impl GetExt for ParquetFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_PARQUET_EXTENSION[1..].to_string() + } +} + /// The Apache Parquet `FileFormat` implementation #[derive(Debug, Default)] pub struct ParquetFormat { @@ -188,6 +249,23 @@ impl FileFormat for ParquetFormat { self } + fn get_ext(&self) -> String { + ParquetFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + match file_compression_type.get_variant() { + CompressionTypeVariant::UNCOMPRESSED => Ok(ext), + _ => Err(DataFusionError::Internal( + "Parquet FileFormat does not support compression.".into(), + )), + } + } + async fn infer_schema( &self, state: &SessionState, diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 7f5e80c4988a5..74aca82b3ee6b 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -24,20 +24,11 @@ use std::{any::Any, sync::Arc}; use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; use super::PartitionedFile; -#[cfg(feature = "parquet")] -use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::{ create_ordering, get_statistics_with_limit, TableProvider, TableType, }; use crate::datasource::{ - file_format::{ - arrow::ArrowFormat, - avro::AvroFormat, - csv::CsvFormat, - file_compression_type::{FileCompressionType, FileTypeExt}, - json::JsonFormat, - FileFormat, - }, + file_format::{file_compression_type::FileCompressionType, FileFormat}, listing::ListingTableUrl, physical_plan::{FileScanConfig, FileSinkConfig}, }; @@ -51,7 +42,8 @@ use crate::{ use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; use arrow_schema::Schema; use datafusion_common::{ - internal_err, plan_err, project_schema, Constraints, FileType, SchemaExt, ToDFSchema, + config_datafusion_err, internal_err, plan_err, project_schema, Constraints, + SchemaExt, ToDFSchema, }; use datafusion_execution::cache::cache_manager::FileStatisticsCache; use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; @@ -119,9 +111,7 @@ impl ListingTableConfig { } } - fn infer_file_type(path: &str) -> Result<(FileType, String)> { - let err_msg = format!("Unable to infer file type from path: {path}"); - + fn infer_file_extension(path: &str) -> Result { let mut exts = path.rsplit('.'); let mut splitted = exts.next().unwrap_or(""); @@ -133,14 
+123,7 @@ impl ListingTableConfig { splitted = exts.next().unwrap_or(""); } - let file_type = FileType::from_str(splitted) - .map_err(|_| DataFusionError::Internal(err_msg.to_owned()))?; - - let ext = file_type - .get_ext_with_compression(file_compression_type.to_owned()) - .map_err(|_| DataFusionError::Internal(err_msg))?; - - Ok((file_type, ext)) + Ok(splitted.to_string()) } /// Infer `ListingOptions` based on `table_path` suffix. @@ -161,25 +144,15 @@ impl ListingTableConfig { .await .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; - let (file_type, file_extension) = - ListingTableConfig::infer_file_type(file.location.as_ref())?; + let file_extension = + ListingTableConfig::infer_file_extension(file.location.as_ref())?; - let mut table_options = state.default_table_options(); - table_options.set_file_format(file_type.clone()); - let file_format: Arc = match file_type { - FileType::CSV => { - Arc::new(CsvFormat::default().with_options(table_options.csv)) - } - #[cfg(feature = "parquet")] - FileType::PARQUET => { - Arc::new(ParquetFormat::default().with_options(table_options.parquet)) - } - FileType::AVRO => Arc::new(AvroFormat), - FileType::JSON => { - Arc::new(JsonFormat::default().with_options(table_options.json)) - } - FileType::ARROW => Arc::new(ArrowFormat), - }; + let file_format = state + .get_file_format_factory(&file_extension) + .ok_or(config_datafusion_err!( + "No file_format found with extension {file_extension}" + ))? + .create(state, &HashMap::new())?; let listing_options = ListingOptions::new(file_format) .with_file_extension(file_extension) @@ -1060,6 +1033,10 @@ impl ListingTable { #[cfg(test)] mod tests { use super::*; + use crate::datasource::file_format::avro::AvroFormat; + use crate::datasource::file_format::csv::CsvFormat; + use crate::datasource::file_format::json::JsonFormat; + use crate::datasource::file_format::parquet::ParquetFormat; #[cfg(feature = "parquet")] use crate::datasource::{provider_as_source, MemTable}; use crate::execution::options::ArrowReadOptions; @@ -1073,7 +1050,7 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow_schema::SortOptions; use datafusion_common::stats::Precision; - use datafusion_common::{assert_contains, GetExt, ScalarValue}; + use datafusion_common::{assert_contains, ScalarValue}; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::ExecutionPlanProperties; @@ -1104,6 +1081,8 @@ mod tests { #[cfg(feature = "parquet")] #[tokio::test] async fn load_table_stats_by_default() -> Result<()> { + use crate::datasource::file_format::parquet::ParquetFormat; + let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); let table_path = ListingTableUrl::parse(filename).unwrap(); @@ -1128,6 +1107,8 @@ mod tests { #[cfg(feature = "parquet")] #[tokio::test] async fn load_table_stats_when_no_stats() -> Result<()> { + use crate::datasource::file_format::parquet::ParquetFormat; + let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); let table_path = ListingTableUrl::parse(filename).unwrap(); @@ -1162,7 +1143,10 @@ mod tests { let options = ListingOptions::new(Arc::new(ParquetFormat::default())); let schema = options.infer_schema(&state, &table_path).await.unwrap(); - use crate::physical_plan::expressions::col as physical_col; + use crate::{ + datasource::file_format::parquet::ParquetFormat, + 
physical_plan::expressions::col as physical_col, + }; use std::ops::Add; // (file_sort_order, expected_result) @@ -1253,7 +1237,7 @@ mod tests { register_test_store(&ctx, &[(&path, 100)]); let opt = ListingOptions::new(Arc::new(AvroFormat {})) - .with_file_extension(FileType::AVRO.get_ext()) + .with_file_extension(AvroFormat.get_ext()) .with_table_partition_cols(vec![(String::from("p1"), DataType::Utf8)]) .with_target_partitions(4); @@ -1516,7 +1500,7 @@ mod tests { "10".into(), ); helper_test_append_new_files_to_table( - FileType::JSON, + JsonFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1534,7 +1518,7 @@ mod tests { "10".into(), ); helper_test_append_new_files_to_table( - FileType::CSV, + CsvFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1552,7 +1536,7 @@ mod tests { "10".into(), ); helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1570,7 +1554,7 @@ mod tests { "20".into(), ); helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 1, @@ -1756,7 +1740,7 @@ mod tests { ); config_map.insert("datafusion.execution.batch_size".into(), "1".into()); helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1774,7 +1758,7 @@ mod tests { "zstd".into(), ); let e = helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1787,7 +1771,7 @@ mod tests { } async fn helper_test_append_new_files_to_table( - file_type: FileType, + file_type_ext: String, file_compression_type: FileCompressionType, session_config_map: Option>, expected_n_files_per_insert: usize, @@ -1824,8 +1808,8 @@ mod tests { // Register appropriate table depending on file_type we want to test let tmp_dir = TempDir::new()?; - match file_type { - FileType::CSV => { + match file_type_ext.as_str() { + "csv" => { session_ctx .register_csv( "t", @@ -1836,7 +1820,7 @@ mod tests { ) .await?; } - FileType::JSON => { + "json" => { session_ctx .register_json( "t", @@ -1847,7 +1831,7 @@ mod tests { ) .await?; } - FileType::PARQUET => { + "parquet" => { session_ctx .register_parquet( "t", @@ -1856,7 +1840,7 @@ mod tests { ) .await?; } - FileType::AVRO => { + "avro" => { session_ctx .register_avro( "t", @@ -1865,7 +1849,7 @@ mod tests { ) .await?; } - FileType::ARROW => { + "arrow" => { session_ctx .register_arrow( "t", @@ -1874,6 +1858,7 @@ mod tests { ) .await?; } + _ => panic!("Unrecognized file extension {file_type_ext}"), } // Create and register the source table with the provided schema and inserted data diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 6e47498243955..1d4d08481895b 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -18,14 +18,8 @@ //! 
Factory for creating ListingTables with default options use std::path::Path; -use std::str::FromStr; use std::sync::Arc; -#[cfg(feature = "parquet")] -use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::file_format::{ - arrow::ArrowFormat, avro::AvroFormat, csv::CsvFormat, json::JsonFormat, FileFormat, -}; use crate::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; @@ -34,8 +28,8 @@ use crate::datasource::TableProvider; use crate::execution::context::SessionState; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::Result; -use datafusion_common::{arrow_datafusion_err, DataFusionError, FileType}; +use datafusion_common::{arrow_datafusion_err, DataFusionError}; +use datafusion_common::{config_datafusion_err, Result}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; @@ -58,28 +52,15 @@ impl TableProviderFactory for ListingTableFactory { state: &SessionState, cmd: &CreateExternalTable, ) -> Result> { - let file_type = FileType::from_str(cmd.file_type.as_str()).map_err(|_| { - DataFusionError::Execution(format!("Unknown FileType {}", cmd.file_type)) - })?; - let mut table_options = state.default_table_options(); - table_options.set_file_format(file_type.clone()); - table_options.alter_with_string_hash_map(&cmd.options)?; + let file_format = state + .get_file_format_factory(cmd.file_type.as_str()) + .ok_or(config_datafusion_err!( + "Unable to create table with format {}! Could not find FileFormat.", + cmd.file_type + ))? + .create(state, &cmd.options)?; let file_extension = get_extension(cmd.location.as_str()); - let file_format: Arc = match file_type { - FileType::CSV => { - Arc::new(CsvFormat::default().with_options(table_options.csv)) - } - #[cfg(feature = "parquet")] - FileType::PARQUET => { - Arc::new(ParquetFormat::default().with_options(table_options.parquet)) - } - FileType::AVRO => Arc::new(AvroFormat), - FileType::JSON => { - Arc::new(JsonFormat::default().with_options(table_options.json)) - } - FileType::ARROW => Arc::new(ArrowFormat), - }; let (provided_schema, table_partition_cols) = if cmd.schema.fields().is_empty() { ( @@ -166,7 +147,9 @@ mod tests { use std::collections::HashMap; use super::*; - use crate::execution::context::SessionContext; + use crate::{ + datasource::file_format::csv::CsvFormat, execution::context::SessionContext, + }; use datafusion_common::{Constraints, DFSchema, TableReference}; diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index c06c630c45d14..327fbd976e877 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -534,13 +534,13 @@ mod tests { use super::*; use crate::dataframe::DataFrameWriteOptions; + use crate::datasource::file_format::csv::CsvFormat; use crate::prelude::*; use crate::test::{partitioned_csv_config, partitioned_file_groups}; use crate::{scalar::ScalarValue, test_util::aggr_test_schema}; use arrow::datatypes::*; use datafusion_common::test_util::arrow_test_data; - use datafusion_common::FileType; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; @@ -561,6 +561,8 @@ mod tests { async fn csv_exec_with_projection( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); @@ -572,7 
+574,7 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -627,6 +629,8 @@ mod tests { async fn csv_exec_with_mixed_order_projection( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); let session_ctx = SessionContext::new_with_config(cfg); let task_ctx = session_ctx.task_ctx(); @@ -639,7 +643,7 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -694,6 +698,8 @@ mod tests { async fn csv_exec_with_limit( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); let session_ctx = SessionContext::new_with_config(cfg); let task_ctx = session_ctx.task_ctx(); @@ -706,7 +712,7 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -759,6 +765,8 @@ mod tests { async fn csv_exec_with_missing_column( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema_with_missing_col(); @@ -770,7 +778,7 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -813,6 +821,8 @@ mod tests { async fn csv_exec_with_partition( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); @@ -824,7 +834,7 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -929,7 +939,7 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), ) diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index e97554a791bda..c051b5d9b57d9 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -384,7 +384,6 @@ mod tests { use super::*; use crate::dataframe::DataFrameWriteOptions; - use crate::datasource::file_format::file_compression_type::FileTypeExt; use crate::datasource::file_format::{json::JsonFormat, FileFormat}; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; @@ -397,7 +396,6 @@ mod tests { use arrow::array::Array; use arrow::datatypes::{Field, SchemaBuilder}; use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array}; - use datafusion_common::FileType; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; use rstest::*; @@ -419,7 +417,7 @@ mod tests { TEST_DATA_BASE, filename, 1, - FileType::JSON, + Arc::new(JsonFormat::default()), file_compression_type.to_owned(), work_dir, ) @@ -453,7 +451,7 @@ mod tests { TEST_DATA_BASE, filename, 1, - FileType::JSON, + Arc::new(JsonFormat::default()), 
file_compression_type.to_owned(), tmp_dir.path(), ) @@ -472,8 +470,8 @@ mod tests { let path_buf = Path::new(url.path()).join(path); let path = path_buf.to_str().unwrap(); - let ext = FileType::JSON - .get_ext_with_compression(file_compression_type.to_owned()) + let ext = JsonFormat::default() + .get_ext_with_compression(&file_compression_type) .unwrap(); let read_options = NdJsonReadOptions::default() @@ -904,8 +902,8 @@ mod tests { let url: &Url = store_url.as_ref(); let path_buf = Path::new(url.path()).join(path); let path = path_buf.to_str().unwrap(); - let ext = FileType::JSON - .get_ext_with_compression(file_compression_type.to_owned()) + let ext = JsonFormat::default() + .get_ext_with_compression(&file_compression_type) .unwrap(); let read_option = NdJsonReadOptions::default() diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index ea7faac08c1cd..9d5c64719e759 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -799,7 +799,7 @@ mod tests { use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; use arrow_schema::Fields; - use datafusion_common::{assert_contains, FileType, GetExt, ScalarValue}; + use datafusion_common::{assert_contains, ScalarValue}; use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::planner::logical2physical; use datafusion_physical_plan::ExecutionPlanProperties; @@ -1994,7 +1994,7 @@ mod tests { // Configure listing options let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::PARQUET.get_ext()); + .with_file_extension(ParquetFormat::default().get_ext()); // execute a simple query and write the results to parquet let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index d2bac134b54aa..2b7867e72046c 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -25,6 +25,13 @@ use crate::catalog::{ MemoryCatalogProviderList, }; use crate::datasource::cte_worktable::CteWorkTable; +use crate::datasource::file_format::arrow::ArrowFormatFactory; +use crate::datasource::file_format::avro::AvroFormatFactory; +use crate::datasource::file_format::csv::CsvFormatFactory; +use crate::datasource::file_format::json::JsonFormatFactory; +#[cfg(feature = "parquet")] +use crate::datasource::file_format::parquet::ParquetFormatFactory; +use crate::datasource::file_format::{format_as_file_type, FileFormatFactory}; use crate::datasource::function::{TableFunction, TableFunctionImpl}; use crate::datasource::provider::{DefaultTableFactory, TableProviderFactory}; use crate::datasource::provider_as_source; @@ -41,10 +48,11 @@ use chrono::{DateTime, Utc}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::tree_node::TreeNode; use datafusion_common::{ - not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, ResolvedTableReference, - TableReference, + config_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, + 
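With FileType::JSON and friends removed, the extension-plus-compression lookup used by these tests now hangs off the format itself. A minimal hedged sketch (GZIP is the standard FileCompressionType constant; the exact extension string is whatever the format and compression combination produces):

use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::json::JsonFormat;
use datafusion::datasource::file_format::FileFormat;
use datafusion::error::Result;

// Ask the JSON format for the extension to use for gzip-compressed files,
// replacing the old FileType::JSON.get_ext_with_compression call.
fn compressed_json_ext() -> Result<String> {
    JsonFormat::default().get_ext_with_compression(&FileCompressionType::GZIP)
}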
ResolvedTableReference, TableReference, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::object_store::ObjectStoreUrl; @@ -109,6 +117,8 @@ pub struct SessionState { window_functions: HashMap>, /// Deserializer registry for extensions. serializer_registry: Arc, + /// Holds registered external FileFormat implementations + file_formats: HashMap>, /// Session configuration config: SessionConfig, /// Table options @@ -230,6 +240,7 @@ impl SessionState { aggregate_functions: HashMap::new(), window_functions: HashMap::new(), serializer_registry: Arc::new(EmptySerializerRegistry), + file_formats: HashMap::new(), table_options: TableOptions::default_from_session_config(config.options()), config, execution_props: ExecutionProps::new(), @@ -238,6 +249,37 @@ impl SessionState { function_factory: None, }; + #[cfg(feature = "parquet")] + if let Err(e) = + new_self.register_file_format(Arc::new(ParquetFormatFactory::new()), false) + { + log::info!("Unable to register default ParquetFormat: {e}") + }; + + if let Err(e) = + new_self.register_file_format(Arc::new(JsonFormatFactory::new()), false) + { + log::info!("Unable to register default JsonFormat: {e}") + }; + + if let Err(e) = + new_self.register_file_format(Arc::new(CsvFormatFactory::new()), false) + { + log::info!("Unable to register default CsvFormat: {e}") + }; + + if let Err(e) = + new_self.register_file_format(Arc::new(ArrowFormatFactory::new()), false) + { + log::info!("Unable to register default ArrowFormat: {e}") + }; + + if let Err(e) = + new_self.register_file_format(Arc::new(AvroFormatFactory::new()), false) + { + log::info!("Unable to register default AvroFormat: {e}") + }; + // register built in functions functions::register_all(&mut new_self) .expect("can not register built in functions"); @@ -811,6 +853,31 @@ impl SessionState { self.table_options.extensions.insert(extension) } + /// Adds or updates a [FileFormatFactory] which can be used with COPY TO or CREATE EXTERNAL TABLE statements for reading + /// and writing files of custom formats. + pub fn register_file_format( + &mut self, + file_format: Arc, + overwrite: bool, + ) -> Result<(), DataFusionError> { + let ext = file_format.get_ext().to_lowercase(); + match (self.file_formats.entry(ext.clone()), overwrite){ + (Entry::Vacant(e), _) => {e.insert(file_format);}, + (Entry::Occupied(mut e), true) => {e.insert(file_format);}, + (Entry::Occupied(_), false) => return config_err!("File type already registered for extension {ext}. Set overwrite to true to replace this extension."), + }; + Ok(()) + } + + /// Retrieves a [FileFormatFactory] based on file extension which has been registered + /// via SessionContext::register_file_format. Extensions are not case sensitive. 
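The registration API added here is the same one the constructor uses for the built-in formats, and user code can call it too. A minimal sketch, assuming a mutable SessionState is available (for instance one constructed directly rather than borrowed from a SessionContext):

use std::sync::Arc;

use datafusion::datasource::file_format::csv::CsvFormatFactory;
use datafusion::error::Result;
use datafusion::execution::context::SessionState;

// "csv" is already claimed by the default CsvFormatFactory registered during construction,
// so a second registration fails unless `overwrite` is set.
fn swap_csv_factory(state: &mut SessionState) -> Result<()> {
    assert!(state
        .register_file_format(Arc::new(CsvFormatFactory::new()), false)
        .is_err());
    state.register_file_format(Arc::new(CsvFormatFactory::new()), true)
}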
+ pub fn get_file_format_factory( + &self, + ext: &str, + ) -> Option> { + self.file_formats.get(&ext.to_lowercase()).cloned() + } + /// Get a new TaskContext to run in this session pub fn task_ctx(&self) -> Arc { Arc::new(TaskContext::from(self)) @@ -967,6 +1034,16 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { fn udwf_names(&self) -> Vec { self.state.window_functions().keys().cloned().collect() } + + fn get_file_type(&self, ext: &str) -> datafusion_common::Result> { + self.state + .file_formats + .get(&ext.to_lowercase()) + .ok_or(plan_datafusion_err!( + "There is no registered file format with ext {ext}" + )) + .map(|file_type| format_as_file_type(file_type.clone())) + } } impl FunctionRegistry for SessionState { diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 15f7555575e8c..5b8501baaad8b 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -22,13 +22,7 @@ use std::collections::HashMap; use std::fmt::Write; use std::sync::Arc; -use crate::datasource::file_format::arrow::ArrowFormat; -use crate::datasource::file_format::avro::AvroFormat; -use crate::datasource::file_format::csv::CsvFormat; -use crate::datasource::file_format::json::JsonFormat; -#[cfg(feature = "parquet")] -use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::file_format::FileFormat; +use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::FileSinkConfig; use crate::datasource::source_as_provider; @@ -74,11 +68,10 @@ use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; use arrow_array::builder::StringBuilder; use arrow_array::RecordBatch; -use datafusion_common::config::FormatOptions; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, - FileType, ScalarValue, + ScalarValue, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ @@ -764,7 +757,7 @@ impl DefaultPhysicalPlanner { LogicalPlan::Copy(CopyTo { input, output_url, - format_options, + file_type, partition_by, options: source_option_tuples, }) => { @@ -791,32 +784,9 @@ impl DefaultPhysicalPlanner { table_partition_cols, overwrite: false, }; - let mut table_options = session_state.default_table_options(); - let sink_format: Arc = match format_options { - FormatOptions::CSV(options) => { - table_options.csv = options.clone(); - table_options.set_file_format(FileType::CSV); - table_options.alter_with_string_hash_map(source_option_tuples)?; - Arc::new(CsvFormat::default().with_options(table_options.csv)) - } - FormatOptions::JSON(options) => { - table_options.json = options.clone(); - table_options.set_file_format(FileType::JSON); - table_options.alter_with_string_hash_map(source_option_tuples)?; - Arc::new(JsonFormat::default().with_options(table_options.json)) - } - #[cfg(feature = "parquet")] - FormatOptions::PARQUET(options) => { - table_options.parquet = options.clone(); - table_options.set_file_format(FileType::PARQUET); - table_options.alter_with_string_hash_map(source_option_tuples)?; - Arc::new( - ParquetFormat::default().with_options(table_options.parquet), - ) - } - FormatOptions::AVRO => Arc::new(AvroFormat {}), - FormatOptions::ARROW => Arc::new(ArrowFormat {}), - }; + + let sink_format = file_type_to_format(file_type)? 
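format_as_file_type and file_type_to_format are the two halves of the bridge between the logical plan (which now stores an opaque Arc of dyn FileType) and the physical planner (which needs the factory back to call create with the COPY options). A hedged round-trip sketch, assuming the default parquet feature is enabled:

use std::sync::Arc;

use datafusion::common::GetExt;
use datafusion::datasource::file_format::parquet::ParquetFormatFactory;
use datafusion::datasource::file_format::{file_type_to_format, format_as_file_type};
use datafusion::error::Result;

// Wrap a factory as a plan-level FileType, then recover the factory again, as the
// physical planner does when it builds the COPY sink.
fn bridge_roundtrip() -> Result<()> {
    let file_type = format_as_file_type(Arc::new(ParquetFormatFactory::new()));
    assert_eq!(file_type.get_ext(), "parquet");
    let factory = file_type_to_format(&file_type)?;
    assert_eq!(factory.get_ext(), "parquet");
    Ok(())
}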
+ .create(session_state, source_option_tuples)?; sink_format .create_writer_physical_plan(input_exec, session_state, config, None) diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 2515b8a4e0164..e8550a79cb0e0 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -24,9 +24,9 @@ use std::io::{BufReader, BufWriter}; use std::path::Path; use std::sync::Arc; -use crate::datasource::file_format::file_compression_type::{ - FileCompressionType, FileTypeExt, -}; +use crate::datasource::file_format::csv::CsvFormat; +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::FileFormat; use crate::datasource::listing::PartitionedFile; use crate::datasource::object_store::ObjectStoreUrl; use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; @@ -40,7 +40,7 @@ use crate::test_util::{aggr_test_schema, arrow_test_data}; use arrow::array::{self, Array, ArrayRef, Decimal128Builder, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, FileType, Statistics}; +use datafusion_common::{DataFusionError, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalSortExpr}; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; @@ -87,7 +87,7 @@ pub fn scan_partitioned_csv(partitions: usize, work_dir: &Path) -> Result, file_compression_type: FileCompressionType, work_dir: &Path, ) -> Result>> { @@ -120,9 +120,8 @@ pub fn partitioned_file_groups( let filename = format!( "partition-{}{}", i, - file_type - .to_owned() - .get_ext_with_compression(file_compression_type.to_owned()) + file_format + .get_ext_with_compression(&file_compression_type) .unwrap() ); let filename = work_dir.join(filename); @@ -167,7 +166,7 @@ pub fn partitioned_file_groups( for (i, line) in f.lines().enumerate() { let line = line.unwrap(); - if i == 0 && file_type == FileType::CSV { + if i == 0 && file_format.get_ext() == CsvFormat::default().get_ext() { // write header to all partitions for w in writers.iter_mut() { w.write_all(line.as_bytes()).unwrap(); diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 4564a8c71fcd7..f87151efd88b5 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -49,8 +49,8 @@ use crate::{ }; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; -use datafusion_common::config::FormatOptions; use datafusion_common::display::ToStringifiedPlan; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -272,14 +272,14 @@ impl LogicalPlanBuilder { pub fn copy_to( input: LogicalPlan, output_url: String, - format_options: FormatOptions, + file_type: Arc, options: HashMap, partition_by: Vec, ) -> Result { Ok(Self::from(LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url, - format_options, + file_type, options, partition_by, }))) diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 707cff8ab5f12..81fd03555abb7 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ 
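With the builder change above, programmatic COPY plans carry the format as an Arc of dyn FileType rather than FormatOptions. A hedged sketch (the output path and the empty options/partitioning are placeholders):

use std::collections::HashMap;
use std::sync::Arc;

use datafusion::datasource::file_format::csv::CsvFormatFactory;
use datafusion::datasource::file_format::format_as_file_type;
use datafusion::error::Result;
use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};

// Build `COPY <input> TO 'out/data.csv'` as a logical plan.
fn copy_plan(input: LogicalPlan) -> Result<LogicalPlan> {
    let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new()));
    LogicalPlanBuilder::copy_to(
        input,
        "out/data.csv".to_string(), // output_url
        file_type,
        HashMap::new(), // per-statement options
        vec![],         // partition_by columns
    )?
    .build()
}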
b/datafusion/expr/src/logical_plan/display.rs @@ -425,7 +425,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { LogicalPlan::Copy(CopyTo { input: _, output_url, - format_options, + file_type, partition_by: _, options, }) => { @@ -437,7 +437,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { json!({ "Node Type": "CopyTo", "Output URL": output_url, - "Format Options": format!("{}", format_options), + "File Type": format!("{}", file_type.get_ext()), "Options": op_str }) } diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 13f3759ab8c06..c9eef9bd34cc0 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -21,7 +21,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::config::FormatOptions; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DFSchemaRef, TableReference}; use crate::LogicalPlan; @@ -35,8 +35,8 @@ pub struct CopyTo { pub output_url: String, /// Determines which, if any, columns should be used for hive-style partitioned writes pub partition_by: Vec, - /// File format options. - pub format_options: FormatOptions, + /// File type trait + pub file_type: Arc, /// SQL Options that can affect the formats pub options: HashMap, } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6e7efaf39e3e2..31f830a6a13df 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -857,13 +857,13 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { input: _, output_url, - format_options, + file_type, options, partition_by, }) => Ok(LogicalPlan::Copy(CopyTo { input: Arc::new(inputs.swap_remove(0)), output_url: output_url.clone(), - format_options: format_options.clone(), + file_type: file_type.clone(), options: options.clone(), partition_by: partition_by.clone(), })), @@ -1729,7 +1729,7 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { input: _, output_url, - format_options, + file_type, options, .. 
}) => { @@ -1739,7 +1739,7 @@ impl LogicalPlan { .collect::>() .join(", "); - write!(f, "CopyTo: format={format_options} output_url={output_url} options: ({op_str})") + write!(f, "CopyTo: format={} output_url={output_url} options: ({op_str})", file_type.get_ext()) } LogicalPlan::Ddl(ddl) => { write!(f, "{}", ddl.display()) diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 86c0cffd80a16..a47906f203221 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -256,14 +256,14 @@ impl TreeNode for LogicalPlan { input, output_url, partition_by, - format_options, + file_type, options, }) => rewrite_arc(input, f)?.update_data(|input| { LogicalPlan::Copy(CopyTo { input, output_url, partition_by, - format_options, + file_type, options, }) }), diff --git a/datafusion/proto/gen/src/main.rs b/datafusion/proto/gen/src/main.rs index 22c16eacb0938..d38a41a01ac23 100644 --- a/datafusion/proto/gen/src/main.rs +++ b/datafusion/proto/gen/src/main.rs @@ -29,6 +29,7 @@ fn main() -> Result<(), String> { let descriptor_path = proto_dir.join("proto/proto_descriptor.bin"); prost_build::Config::new() + .protoc_arg("--experimental_allow_proto3_optional") .file_descriptor_set_path(&descriptor_path) .out_dir(out_dir) .compile_well_known_types() diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 2e7005a4cb137..f2594ba103404 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -251,13 +251,7 @@ message DistinctOnNode { message CopyToNode { LogicalPlanNode input = 1; string output_url = 2; - oneof format_options { - datafusion_common.CsvOptions csv = 8; - datafusion_common.JsonOptions json = 9; - datafusion_common.TableParquetOptions parquet = 10; - datafusion_common.AvroOptions avro = 11; - datafusion_common.ArrowOptions arrow = 12; - } + bytes file_type = 3; repeated string partition_by = 7; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8fdc5d2e4db28..e8fbe954428a0 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -2536,10 +2536,10 @@ impl serde::Serialize for CopyToNode { if !self.output_url.is_empty() { len += 1; } - if !self.partition_by.is_empty() { + if !self.file_type.is_empty() { len += 1; } - if self.format_options.is_some() { + if !self.partition_by.is_empty() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.CopyToNode", len)?; @@ -2549,28 +2549,13 @@ impl serde::Serialize for CopyToNode { if !self.output_url.is_empty() { struct_ser.serialize_field("outputUrl", &self.output_url)?; } + if !self.file_type.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("fileType", pbjson::private::base64::encode(&self.file_type).as_str())?; + } if !self.partition_by.is_empty() { struct_ser.serialize_field("partitionBy", &self.partition_by)?; } - if let Some(v) = self.format_options.as_ref() { - match v { - copy_to_node::FormatOptions::Csv(v) => { - struct_ser.serialize_field("csv", v)?; - } - copy_to_node::FormatOptions::Json(v) => { - struct_ser.serialize_field("json", v)?; - } - copy_to_node::FormatOptions::Parquet(v) => { - struct_ser.serialize_field("parquet", v)?; - } - copy_to_node::FormatOptions::Avro(v) => { - struct_ser.serialize_field("avro", v)?; - } - copy_to_node::FormatOptions::Arrow(v) => { - struct_ser.serialize_field("arrow", v)?; - 
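Because the plan node now only knows the FileType, both display paths above print the format's extension instead of the full option set. A small hedged sketch of what that looks like, assuming `plan` is a LogicalPlan::Copy built with a csv file type:

use datafusion::logical_expr::LogicalPlan;

// The single-line display is now e.g. "CopyTo: format=csv output_url=out/data.csv options: ()".
fn assert_copy_display(plan: &LogicalPlan) {
    assert!(format!("{}", plan.display()).starts_with("CopyTo: format=csv"));
}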
} - } - } struct_ser.end() } } @@ -2584,25 +2569,18 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { "input", "output_url", "outputUrl", + "file_type", + "fileType", "partition_by", "partitionBy", - "csv", - "json", - "parquet", - "avro", - "arrow", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, OutputUrl, + FileType, PartitionBy, - Csv, - Json, - Parquet, - Avro, - Arrow, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2626,12 +2604,8 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { match value { "input" => Ok(GeneratedField::Input), "outputUrl" | "output_url" => Ok(GeneratedField::OutputUrl), + "fileType" | "file_type" => Ok(GeneratedField::FileType), "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), - "csv" => Ok(GeneratedField::Csv), - "json" => Ok(GeneratedField::Json), - "parquet" => Ok(GeneratedField::Parquet), - "avro" => Ok(GeneratedField::Avro), - "arrow" => Ok(GeneratedField::Arrow), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2653,8 +2627,8 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { { let mut input__ = None; let mut output_url__ = None; + let mut file_type__ = None; let mut partition_by__ = None; - let mut format_options__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { @@ -2669,54 +2643,27 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { } output_url__ = Some(map_.next_value()?); } + GeneratedField::FileType => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fileType")); + } + file_type__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } GeneratedField::PartitionBy => { if partition_by__.is_some() { return Err(serde::de::Error::duplicate_field("partitionBy")); } partition_by__ = Some(map_.next_value()?); } - GeneratedField::Csv => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("csv")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Csv) -; - } - GeneratedField::Json => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("json")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Json) -; - } - GeneratedField::Parquet => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("parquet")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Parquet) -; - } - GeneratedField::Avro => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("avro")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Avro) -; - } - GeneratedField::Arrow => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("arrow")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Arrow) -; - } } } Ok(CopyToNode { input: input__, output_url: output_url__.unwrap_or_default(), + file_type: file_type__.unwrap_or_default(), partition_by: partition_by__.unwrap_or_default(), - format_options: format_options__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 036b7cff9b03c..93bf6c0602276 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -411,27 
+411,10 @@ pub struct CopyToNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(string, tag = "2")] pub output_url: ::prost::alloc::string::String, + #[prost(bytes = "vec", tag = "3")] + pub file_type: ::prost::alloc::vec::Vec, #[prost(string, repeated, tag = "7")] pub partition_by: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - #[prost(oneof = "copy_to_node::FormatOptions", tags = "8, 9, 10, 11, 12")] - pub format_options: ::core::option::Option, -} -/// Nested message and enum types in `CopyToNode`. -pub mod copy_to_node { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum FormatOptions { - #[prost(message, tag = "8")] - Csv(super::super::datafusion_common::CsvOptions), - #[prost(message, tag = "9")] - Json(super::super::datafusion_common::JsonOptions), - #[prost(message, tag = "10")] - Parquet(super::super::datafusion_common::TableParquetOptions), - #[prost(message, tag = "11")] - Avro(super::super::datafusion_common::AvroOptions), - #[prost(message, tag = "12")] - Arrow(super::super::datafusion_common::ArrowOptions), - } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs new file mode 100644 index 0000000000000..31102b728ec95 --- /dev/null +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -0,0 +1,399 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion::{ + datasource::file_format::{ + arrow::ArrowFormatFactory, csv::CsvFormatFactory, json::JsonFormatFactory, + parquet::ParquetFormatFactory, FileFormatFactory, + }, + prelude::SessionContext, +}; +use datafusion_common::not_impl_err; + +use super::LogicalExtensionCodec; + +#[derive(Debug)] +pub struct CsvLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. 
+impl LogicalExtensionCodec for CsvLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: arrow::datatypes::SchemaRef, + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result< + std::sync::Arc, + > { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _node: std::sync::Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(CsvFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &[u8], + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } + + fn try_decode_udf( + &self, + name: &str, + __buf: &[u8], + ) -> datafusion_common::Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf( + &self, + __node: &datafusion_expr::ScalarUDF, + __buf: &mut Vec, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} + +#[derive(Debug)] +pub struct JsonLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. +impl LogicalExtensionCodec for JsonLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: arrow::datatypes::SchemaRef, + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result< + std::sync::Arc, + > { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _node: std::sync::Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(JsonFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &[u8], + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } + + fn try_decode_udf( + &self, + name: &str, + __buf: &[u8], + ) -> datafusion_common::Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf( + &self, + __node: &datafusion_expr::ScalarUDF, + __buf: &mut Vec, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} + +#[derive(Debug)] +pub struct ParquetLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. 
+impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: arrow::datatypes::SchemaRef, + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result< + std::sync::Arc, + > { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _node: std::sync::Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(ParquetFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &[u8], + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } + + fn try_decode_udf( + &self, + name: &str, + __buf: &[u8], + ) -> datafusion_common::Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf( + &self, + __node: &datafusion_expr::ScalarUDF, + __buf: &mut Vec, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} + +#[derive(Debug)] +pub struct ArrowLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. +impl LogicalExtensionCodec for ArrowLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: arrow::datatypes::SchemaRef, + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result< + std::sync::Arc, + > { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _node: std::sync::Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(ArrowFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &[u8], + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } + + fn try_decode_udf( + &self, + name: &str, + __buf: &[u8], + ) -> datafusion_common::Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf( + &self, + __node: &datafusion_expr::ScalarUDF, + __buf: &mut Vec, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} + +#[derive(Debug)] +pub struct AvroLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. 
+impl LogicalExtensionCodec for AvroLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: arrow::datatypes::SchemaRef, + _cts: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result< + std::sync::Arc, + > { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _node: std::sync::Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(ArrowFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &[u8], + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } + + fn try_decode_udf( + &self, + name: &str, + __buf: &[u8], + ) -> datafusion_common::Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf( + &self, + __node: &datafusion_expr::ScalarUDF, + __buf: &mut Vec, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index ef37150a35db5..cdb9d5260a0f2 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -33,6 +33,9 @@ use crate::protobuf::{proto_error, FromProtoError, ToProtoError}; use arrow::datatypes::{DataType, Schema, SchemaRef}; #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::{ + file_type_to_format, format_as_file_type, FileFormatFactory, +}; use datafusion::{ datasource::{ file_format::{avro::AvroFormat, csv::CsvFormat, FileFormat}, @@ -43,6 +46,7 @@ use datafusion::{ datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ context, internal_datafusion_err, internal_err, not_impl_err, DataFusionError, Result, TableReference, @@ -64,6 +68,7 @@ use prost::Message; use self::to_proto::serialize_expr; +pub mod file_formats; pub mod from_proto; pub mod to_proto; @@ -114,6 +119,22 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { buf: &mut Vec, ) -> Result<()>; + fn try_decode_file_format( + &self, + _buf: &[u8], + _ctx: &SessionContext, + ) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided for file format") + } + + fn try_encode_file_format( + &self, + _buf: &[u8], + _node: Arc, + ) -> Result<()> { + Ok(()) + } + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") } @@ -829,12 +850,16 @@ impl AsLogicalPlan for LogicalPlanNode { let input: LogicalPlan = into_logical_plan!(copy.input, ctx, extension_codec)?; + let file_type: Arc = format_as_file_type( + extension_codec.try_decode_file_format(©.file_type, ctx)?, + ); + Ok(datafusion_expr::LogicalPlan::Copy( datafusion_expr::dml::CopyTo { input: Arc::new(input), output_url: copy.output_url.clone(), partition_by: copy.partition_by.clone(), - 
format_options: convert_required!(copy.format_options)?, + file_type, options: Default::default(), }, )) @@ -1609,7 +1634,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Copy(dml::CopyTo { input, output_url, - format_options, + file_type, partition_by, .. }) => { @@ -1618,12 +1643,16 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec, )?; + let buf = Vec::new(); + extension_codec + .try_encode_file_format(&buf, file_type_to_format(file_type)?)?; + Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( protobuf::CopyToNode { input: Some(Box::new(input)), output_url: output_url.to_string(), - format_options: Some(format_options.try_into()?), + file_type: buf, partition_by: partition_by.clone(), }, ))), diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index b636c77641c7c..7783c15611854 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -41,14 +41,13 @@ use datafusion::physical_plan::expressions::{ }; use datafusion::physical_plan::windows::{create_window_expr, schema_add_window_field}; use datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; -use datafusion_common::config::FormatOptions; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_proto_common::common::proto_error; use crate::convert_required; use crate::logical_plan::{self}; +use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::{self, copy_to_node}; use super::PhysicalExtensionCodec; @@ -653,22 +652,3 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { }) } } - -impl TryFrom<©_to_node::FormatOptions> for FormatOptions { - type Error = DataFusionError; - fn try_from(value: ©_to_node::FormatOptions) -> Result { - Ok(match value { - copy_to_node::FormatOptions::Csv(options) => { - FormatOptions::CSV(options.try_into()?) - } - copy_to_node::FormatOptions::Json(options) => { - FormatOptions::JSON(options.try_into()?) - } - copy_to_node::FormatOptions::Parquet(options) => { - FormatOptions::PARQUET(options.try_into()?) - } - copy_to_node::FormatOptions::Avro(_) => FormatOptions::AVRO, - copy_to_node::FormatOptions::Arrow(_) => FormatOptions::ARROW, - }) - } -} diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index c02b59d06230f..8583900e9fa70 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -39,12 +39,11 @@ use datafusion::{ }, physical_plan::expressions::LikeExpr, }; -use datafusion_common::config::FormatOptions; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use crate::protobuf::{ - self, copy_to_node, physical_aggregate_expr_node, physical_window_expr_node, - PhysicalSortExprNode, PhysicalSortExprNodeCollection, + self, physical_aggregate_expr_node, physical_window_expr_node, PhysicalSortExprNode, + PhysicalSortExprNodeCollection, }; use super::PhysicalExtensionCodec; @@ -728,26 +727,3 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { }) } } - -impl TryFrom<&FormatOptions> for copy_to_node::FormatOptions { - type Error = DataFusionError; - fn try_from(value: &FormatOptions) -> std::result::Result { - Ok(match value { - FormatOptions::CSV(options) => { - copy_to_node::FormatOptions::Csv(options.try_into()?) 
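Serializing a COPY plan therefore has to go through a LogicalExtensionCodec: the encode path writes whatever bytes the codec produces into CopyToNode.file_type, and the decode path asks the codec for a FileFormatFactory to rebuild the Arc of dyn FileType. A hedged sketch using the CSV placeholder codec introduced in this PR:

use datafusion::error::Result;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::SessionContext;
use datafusion_proto::bytes::{
    logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes_with_extension_codec,
};
use datafusion_proto::logical_plan::file_formats::CsvLogicalExtensionCodec;

// Round-trip a COPY plan through protobuf bytes; the codec supplies the factory that
// backs the opaque `file_type` field on the way back in.
fn roundtrip_copy(plan: &LogicalPlan, ctx: &SessionContext) -> Result<LogicalPlan> {
    let codec = CsvLogicalExtensionCodec {};
    let bytes = logical_plan_to_bytes_with_extension_codec(plan, &codec)?;
    logical_plan_from_bytes_with_extension_codec(&bytes, ctx, &codec)
}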
- } - FormatOptions::JSON(options) => { - copy_to_node::FormatOptions::Json(options.try_into()?) - } - FormatOptions::PARQUET(options) => { - copy_to_node::FormatOptions::Parquet(options.try_into()?) - } - FormatOptions::AVRO => { - copy_to_node::FormatOptions::Avro(protobuf::AvroOptions {}) - } - FormatOptions::ARROW => { - copy_to_node::FormatOptions::Arrow(protobuf::ArrowOptions {}) - } - }) - } -} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 510ebe9a98019..d54078b72bb72 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -26,6 +26,13 @@ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; +use datafusion::datasource::file_format::arrow::ArrowFormatFactory; +use datafusion::datasource::file_format::csv::CsvFormatFactory; +use datafusion::datasource::file_format::format_as_file_type; +use datafusion::datasource::file_format::parquet::ParquetFormatFactory; +use datafusion_proto::logical_plan::file_formats::{ + ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec, ParquetLogicalExtensionCodec, +}; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -41,11 +48,11 @@ use datafusion::functions_aggregate::expr_fn::{ }; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; -use datafusion_common::config::{FormatOptions, TableOptions}; +use datafusion_common::config::TableOptions; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, DFSchemaRef, - DataFusionError, FileType, Result, ScalarValue, + DataFusionError, Result, ScalarValue, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ @@ -326,20 +333,20 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { let ctx = SessionContext::new(); let input = create_csv_scan(&ctx).await?; - let mut table_options = ctx.copied_table_options(); - table_options.set_file_format(FileType::CSV); - table_options.set("format.delimiter", ";")?; + let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new())); let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.csv".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], - format_options: FormatOptions::CSV(table_options.csv.clone()), + file_type, options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let codec = CsvLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); Ok(()) @@ -364,26 +371,27 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { parquet_format.global.dictionary_page_size_limit = 444; parquet_format.global.max_row_group_size = 555; + let file_type = format_as_file_type(Arc::new(ParquetFormatFactory::new())); + let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.parquet".to_string(), - format_options: FormatOptions::PARQUET(parquet_format.clone()), + file_type, partition_by: vec!["a".to_string(), 
"b".to_string(), "c".to_string()], options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let codec = ParquetLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); match logical_round_trip { LogicalPlan::Copy(copy_to) => { assert_eq!("test.parquet", copy_to.output_url); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); - assert_eq!( - copy_to.format_options, - FormatOptions::PARQUET(parquet_format) - ); + assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); } _ => panic!(), } @@ -396,22 +404,26 @@ async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> { let input = create_csv_scan(&ctx).await?; + let file_type = format_as_file_type(Arc::new(ArrowFormatFactory::new())); + let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.arrow".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], - format_options: FormatOptions::ARROW, + file_type, options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let codec = ArrowLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); match logical_round_trip { LogicalPlan::Copy(copy_to) => { assert_eq!("test.arrow", copy_to.output_url); - assert_eq!(FormatOptions::ARROW, copy_to.format_options); + assert_eq!("arrow".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); } _ => panic!(), @@ -437,22 +449,26 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { csv_format.time_format = Some("HH:mm:ss".to_string()); csv_format.null_value = Some("NIL".to_string()); + let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new())); + let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.csv".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], - format_options: FormatOptions::CSV(csv_format.clone()), + file_type, options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let codec = CsvLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); match logical_round_trip { LogicalPlan::Copy(copy_to) => { assert_eq!("test.csv", copy_to.output_url); - assert_eq!(FormatOptions::CSV(csv_format), copy_to.format_options); + assert_eq!("csv".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); } _ => panic!(), diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 30f95170a34fd..63ef86446aaf4 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use std::vec; use arrow_schema::*; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ 
field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError, }; @@ -48,6 +49,11 @@ use crate::utils::make_decimal_type; pub trait ContextProvider { /// Getter for a datasource fn get_table_source(&self, name: TableReference) -> Result>; + + fn get_file_type(&self, _ext: &str) -> Result> { + not_impl_err!("Registered file types are not supported") + } + /// Getter for a table function fn get_table_function_source( &self, diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index cb492b390c764..518972545a484 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -34,8 +34,8 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ exec_err, not_impl_err, plan_datafusion_err, plan_err, schema_err, unqualified_field_not_found, Column, Constraints, DFSchema, DFSchemaRef, - DataFusionError, FileType, Result, ScalarValue, SchemaError, SchemaReference, - TableReference, ToDFSchema, + DataFusionError, Result, ScalarValue, SchemaError, SchemaReference, TableReference, + ToDFSchema, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; @@ -899,31 +899,35 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - let file_type = if let Some(file_type) = statement.stored_as { - FileType::from_str(&file_type).map_err(|_| { - DataFusionError::Configuration(format!("Unknown FileType {}", file_type)) - })? + let maybe_file_type = if let Some(stored_as) = &statement.stored_as { + if let Ok(ext_file_type) = self.context_provider.get_file_type(stored_as) { + Some(ext_file_type) + } else { + None + } } else { - let e = || { - DataFusionError::Configuration( - "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." - .to_string(), - ) - }; - // try to infer file format from file extension - let extension: &str = &Path::new(&statement.target) - .extension() - .ok_or_else(e)? - .to_str() - .ok_or_else(e)? - .to_lowercase(); - - FileType::from_str(extension).map_err(|e| { - DataFusionError::Configuration(format!( - "{}. Use STORED AS to define file format.", - e - )) - })? + None + }; + + let file_type = match maybe_file_type { + Some(ft) => ft, + None => { + let e = || { + DataFusionError::Configuration( + "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." + .to_string(), + ) + }; + // try to infer file format from file extension + let extension: &str = &Path::new(&statement.target) + .extension() + .ok_or_else(e)? + .to_str() + .ok_or_else(e)? + .to_lowercase(); + + self.context_provider.get_file_type(extension)? + } }; let partition_by = statement @@ -938,7 +942,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: statement.target, - format_options: file_type.into(), + file_type, partition_by, options, })) diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 893678d6b3742..f5caaefb3ea08 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -15,16 +15,39 @@ // specific language governing permissions and limitations // under the License. 
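At the SQL level, the effect of this planner change is that COPY resolves its format through ContextProvider::get_file_type, either from STORED AS or from the target's extension. A hedged end-to-end sketch, assuming a registered table named source_table and the default format factories:

use datafusion::error::Result;
use datafusion::prelude::*;

// The first statement resolves "PARQUET" via the registered factories; the second infers
// csv from the ".csv" extension of the target path.
async fn copy_examples(ctx: &SessionContext) -> Result<()> {
    ctx.sql("COPY source_table TO 'out/data.parquet' STORED AS PARQUET")
        .await?
        .collect()
        .await?;
    ctx.sql("COPY source_table TO 'out/data.csv'")
        .await?
        .collect()
        .await?;
    Ok(())
}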
+use std::any::Any; #[cfg(test)] use std::collections::HashMap; +use std::fmt::Display; use std::{sync::Arc, vec}; use arrow_schema::*; use datafusion_common::config::ConfigOptions; -use datafusion_common::{plan_err, Result, TableReference}; +use datafusion_common::file_options::file_type::FileType; +use datafusion_common::{plan_err, GetExt, Result, TableReference}; use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; use datafusion_sql::planner::ContextProvider; +struct MockCsvType {} + +impl GetExt for MockCsvType { + fn get_ext(&self) -> String { + "csv".to_string() + } +} + +impl FileType for MockCsvType { + fn as_any(&self) -> &dyn Any { + self + } +} + +impl Display for MockCsvType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.get_ext()) + } +} + #[derive(Default)] pub(crate) struct MockContextProvider { options: ConfigOptions, @@ -191,6 +214,13 @@ impl ContextProvider for MockContextProvider { &self.options } + fn get_file_type( + &self, + _ext: &str, + ) -> Result> { + Ok(Arc::new(MockCsvType {})) + } + fn create_cte_work_table( &self, _name: &str,