Skip to content

Commit

Permalink
Fix FormatOptions::CSV propagation (apache#10912)
Browse files Browse the repository at this point in the history
* Fix sink output schema being passed in to `FileSinkExec` where input schema was expected

* Propagate CSV options (quote, double quote, and escape) through protos

* Add test for double quotes

* Test quote escape when double quotes are disabled

* regen

---------

Co-authored-by: svranesevic <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
3 people authored and findepi committed Jul 16, 2024
1 parent b26d5c0 commit 814e5b9
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 7 deletions.
9 changes: 9 additions & 0 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,7 @@ config_namespace! {
pub delimiter: u8, default = b','
pub quote: u8, default = b'"'
pub escape: Option<u8>, default = None
pub double_quote: Option<bool>, default = None
pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED
pub schema_infer_max_rec: usize, default = 100
pub date_format: Option<String>, default = None
Expand Down Expand Up @@ -1631,6 +1632,13 @@ impl CsvOptions {
self
}

/// Set true to indicate that the CSV quotes should be doubled.
/// - default to true
pub fn with_double_quote(mut self, double_quote: bool) -> Self {
self.double_quote = Some(double_quote);
self
}

/// Set a `CompressionTypeVariant` of CSV
/// - defaults to `CompressionTypeVariant::UNCOMPRESSED`
pub fn with_file_compression_type(
Expand Down Expand Up @@ -1675,6 +1683,7 @@ pub enum FormatOptions {
AVRO,
ARROW,
}

impl Display for FormatOptions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let out = match self {
Expand Down
6 changes: 6 additions & 0 deletions datafusion/common/src/file_options/csv_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions {
if let Some(v) = &value.null_value {
builder = builder.with_null(v.into())
}
if let Some(v) = &value.escape {
builder = builder.with_escape(*v)
}
if let Some(v) = &value.double_quote {
builder = builder.with_double_quote(*v)
}
Ok(CsvWriterOptions {
writer_options: builder,
compression: value.compression,
Expand Down
5 changes: 5 additions & 0 deletions datafusion/core/tests/data/double_quote.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
c1,c2
id0,"""value0"""
id1,"""value1"""
id2,"""value2"""
id3,"""value3"""
7 changes: 7 additions & 0 deletions datafusion/proto-common/proto/datafusion_common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,12 @@ message CsvWriterOptions {
string time_format = 7;
// Optional value to represent null
string null_value = 8;
// Optional quote. Defaults to `b'"'`
string quote = 9;
// Optional escape. Defaults to `'\\'`
string escape = 10;
// Optional flag whether to double quotes, instead of escaping. Defaults to `true`
bool double_quote = 11;
}

// Options controlling CSV format
Expand All @@ -402,6 +408,7 @@ message CsvOptions {
string time_format = 11; // Optional time format
string null_value = 12; // Optional representation of null value
bytes comment = 13; // Optional comment character as a byte
bytes double_quote = 14; // Indicates if quotes are doubled
}

// Options controlling CSV format
Expand Down
26 changes: 25 additions & 1 deletion datafusion/proto-common/src/from_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,7 @@ impl TryFrom<&protobuf::CsvOptions> for CsvOptions {
delimiter: proto_opts.delimiter[0],
quote: proto_opts.quote[0],
escape: proto_opts.escape.first().copied(),
double_quote: proto_opts.has_header.first().map(|h| *h != 0),
compression: proto_opts.compression().into(),
schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize,
date_format: (!proto_opts.date_format.is_empty())
Expand Down Expand Up @@ -1091,11 +1092,34 @@ pub(crate) fn csv_writer_options_from_proto(
return Err(proto_error("Error parsing CSV Delimiter"));
}
}
if !writer_options.quote.is_empty() {
if let Some(quote) = writer_options.quote.chars().next() {
if quote.is_ascii() {
builder = builder.with_quote(quote as u8);
} else {
return Err(proto_error("CSV Quote is not ASCII"));
}
} else {
return Err(proto_error("Error parsing CSV Quote"));
}
}
if !writer_options.escape.is_empty() {
if let Some(escape) = writer_options.escape.chars().next() {
if escape.is_ascii() {
builder = builder.with_escape(escape as u8);
} else {
return Err(proto_error("CSV Escape is not ASCII"));
}
} else {
return Err(proto_error("Error parsing CSV Escape"));
}
}
Ok(builder
.with_header(writer_options.has_header)
.with_date_format(writer_options.date_format.clone())
.with_datetime_format(writer_options.datetime_format.clone())
.with_timestamp_format(writer_options.timestamp_format.clone())
.with_time_format(writer_options.time_format.clone())
.with_null(writer_options.null_value.clone()))
.with_null(writer_options.null_value.clone())
.with_double_quote(writer_options.double_quote))
}
73 changes: 73 additions & 0 deletions datafusion/proto-common/src/generated/pbjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,9 @@ impl serde::Serialize for CsvOptions {
if !self.comment.is_empty() {
len += 1;
}
if !self.double_quote.is_empty() {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?;
if !self.has_header.is_empty() {
#[allow(clippy::needless_borrow)]
Expand Down Expand Up @@ -1929,6 +1932,10 @@ impl serde::Serialize for CsvOptions {
#[allow(clippy::needless_borrow)]
struct_ser.serialize_field("comment", pbjson::private::base64::encode(&self.comment).as_str())?;
}
if !self.double_quote.is_empty() {
#[allow(clippy::needless_borrow)]
struct_ser.serialize_field("doubleQuote", pbjson::private::base64::encode(&self.double_quote).as_str())?;
}
struct_ser.end()
}
}
Expand Down Expand Up @@ -1960,6 +1967,8 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
"null_value",
"nullValue",
"comment",
"double_quote",
"doubleQuote",
];

#[allow(clippy::enum_variant_names)]
Expand All @@ -1977,6 +1986,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
TimeFormat,
NullValue,
Comment,
DoubleQuote,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
Expand Down Expand Up @@ -2011,6 +2021,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
"timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat),
"nullValue" | "null_value" => Ok(GeneratedField::NullValue),
"comment" => Ok(GeneratedField::Comment),
"doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
Expand Down Expand Up @@ -2043,6 +2054,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
let mut time_format__ = None;
let mut null_value__ = None;
let mut comment__ = None;
let mut double_quote__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::HasHeader => {
Expand Down Expand Up @@ -2135,6 +2147,14 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
;
}
GeneratedField::DoubleQuote => {
if double_quote__.is_some() {
return Err(serde::de::Error::duplicate_field("doubleQuote"));
}
double_quote__ =
Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
;
}
}
}
Ok(CsvOptions {
Expand All @@ -2151,6 +2171,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
time_format: time_format__.unwrap_or_default(),
null_value: null_value__.unwrap_or_default(),
comment: comment__.unwrap_or_default(),
double_quote: double_quote__.unwrap_or_default(),
})
}
}
Expand Down Expand Up @@ -2189,6 +2210,15 @@ impl serde::Serialize for CsvWriterOptions {
if !self.null_value.is_empty() {
len += 1;
}
if !self.quote.is_empty() {
len += 1;
}
if !self.escape.is_empty() {
len += 1;
}
if self.double_quote {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvWriterOptions", len)?;
if self.compression != 0 {
let v = CompressionTypeVariant::try_from(self.compression)
Expand Down Expand Up @@ -2216,6 +2246,15 @@ impl serde::Serialize for CsvWriterOptions {
if !self.null_value.is_empty() {
struct_ser.serialize_field("nullValue", &self.null_value)?;
}
if !self.quote.is_empty() {
struct_ser.serialize_field("quote", &self.quote)?;
}
if !self.escape.is_empty() {
struct_ser.serialize_field("escape", &self.escape)?;
}
if self.double_quote {
struct_ser.serialize_field("doubleQuote", &self.double_quote)?;
}
struct_ser.end()
}
}
Expand All @@ -2240,6 +2279,10 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
"timeFormat",
"null_value",
"nullValue",
"quote",
"escape",
"double_quote",
"doubleQuote",
];

#[allow(clippy::enum_variant_names)]
Expand All @@ -2252,6 +2295,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
TimestampFormat,
TimeFormat,
NullValue,
Quote,
Escape,
DoubleQuote,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
Expand Down Expand Up @@ -2281,6 +2327,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
"timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat),
"timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat),
"nullValue" | "null_value" => Ok(GeneratedField::NullValue),
"quote" => Ok(GeneratedField::Quote),
"escape" => Ok(GeneratedField::Escape),
"doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
Expand Down Expand Up @@ -2308,6 +2357,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
let mut timestamp_format__ = None;
let mut time_format__ = None;
let mut null_value__ = None;
let mut quote__ = None;
let mut escape__ = None;
let mut double_quote__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Compression => {
Expand Down Expand Up @@ -2358,6 +2410,24 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
}
null_value__ = Some(map_.next_value()?);
}
GeneratedField::Quote => {
if quote__.is_some() {
return Err(serde::de::Error::duplicate_field("quote"));
}
quote__ = Some(map_.next_value()?);
}
GeneratedField::Escape => {
if escape__.is_some() {
return Err(serde::de::Error::duplicate_field("escape"));
}
escape__ = Some(map_.next_value()?);
}
GeneratedField::DoubleQuote => {
if double_quote__.is_some() {
return Err(serde::de::Error::duplicate_field("doubleQuote"));
}
double_quote__ = Some(map_.next_value()?);
}
}
}
Ok(CsvWriterOptions {
Expand All @@ -2369,6 +2439,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
timestamp_format: timestamp_format__.unwrap_or_default(),
time_format: time_format__.unwrap_or_default(),
null_value: null_value__.unwrap_or_default(),
quote: quote__.unwrap_or_default(),
escape: escape__.unwrap_or_default(),
double_quote: double_quote__.unwrap_or_default(),
})
}
}
Expand Down
12 changes: 12 additions & 0 deletions datafusion/proto-common/src/generated/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,15 @@ pub struct CsvWriterOptions {
/// Optional value to represent null
#[prost(string, tag = "8")]
pub null_value: ::prost::alloc::string::String,
/// Optional quote. Defaults to `b'"'`
#[prost(string, tag = "9")]
pub quote: ::prost::alloc::string::String,
/// Optional escape. Defaults to `'\\'`
#[prost(string, tag = "10")]
pub escape: ::prost::alloc::string::String,
/// Optional flag whether to double quote instead of escaping. Defaults to `true`
#[prost(bool, tag = "11")]
pub double_quote: bool,
}
/// Options controlling CSV format
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down Expand Up @@ -619,6 +628,9 @@ pub struct CsvOptions {
/// Optional comment character as a byte
#[prost(bytes = "vec", tag = "13")]
pub comment: ::prost::alloc::vec::Vec<u8>,
/// Indicates if quotes are doubled
#[prost(bytes = "vec", tag = "14")]
pub double_quote: ::prost::alloc::vec::Vec<u8>,
}
/// Options controlling CSV format
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down
4 changes: 4 additions & 0 deletions datafusion/proto-common/src/to_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,7 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions {
delimiter: vec![opts.delimiter],
quote: vec![opts.quote],
escape: opts.escape.map_or_else(Vec::new, |e| vec![e]),
double_quote: opts.double_quote.map_or_else(Vec::new, |h| vec![h as u8]),
compression: compression.into(),
schema_infer_max_rec: opts.schema_infer_max_rec as u64,
date_format: opts.date_format.clone().unwrap_or_default(),
Expand Down Expand Up @@ -1022,5 +1023,8 @@ pub(crate) fn csv_writer_options_to_proto(
timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(),
time_format: csv_options.time_format().unwrap_or("").to_owned(),
null_value: csv_options.null().to_owned(),
quote: (csv_options.quote() as char).to_string(),
escape: (csv_options.escape() as char).to_string(),
double_quote: csv_options.double_quote(),
}
}
12 changes: 12 additions & 0 deletions datafusion/proto/src/generated/datafusion_proto_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,15 @@ pub struct CsvWriterOptions {
/// Optional value to represent null
#[prost(string, tag = "8")]
pub null_value: ::prost::alloc::string::String,
/// Optional quote. Defaults to `b'"'`
#[prost(string, tag = "9")]
pub quote: ::prost::alloc::string::String,
/// Optional escape. Defaults to `'\\'`
#[prost(string, tag = "10")]
pub escape: ::prost::alloc::string::String,
/// Optional flag whether to double quotes, instead of escaping. Defaults to `true`
#[prost(bool, tag = "11")]
pub double_quote: bool,
}
/// Options controlling CSV format
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down Expand Up @@ -619,6 +628,9 @@ pub struct CsvOptions {
/// Optional comment character as a byte
#[prost(bytes = "vec", tag = "13")]
pub comment: ::prost::alloc::vec::Vec<u8>,
/// Indicates if quotes are doubled
#[prost(bytes = "vec", tag = "14")]
pub double_quote: ::prost::alloc::vec::Vec<u8>,
}
/// Options controlling CSV format
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down
Loading

0 comments on commit 814e5b9

Please sign in to comment.