Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FormatOptions::CSV propagation #10912

Merged
merged 6 commits into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1559,6 +1559,7 @@ config_namespace! {
pub delimiter: u8, default = b','
pub quote: u8, default = b'"'
pub escape: Option<u8>, default = None
pub double_quote: Option<bool>, default = None
pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED
pub schema_infer_max_rec: usize, default = 100
pub date_format: Option<String>, default = None
Expand Down Expand Up @@ -1624,6 +1625,13 @@ impl CsvOptions {
self
}

/// Set true to indicate that the CSV quotes should be doubled.
/// - default to true
pub fn with_double_quote(mut self, double_quote: bool) -> Self {
self.double_quote = Some(double_quote);
self
}

/// Set a `CompressionTypeVariant` of CSV
/// - defaults to `CompressionTypeVariant::UNCOMPRESSED`
pub fn with_file_compression_type(
Expand Down Expand Up @@ -1668,6 +1676,7 @@ pub enum FormatOptions {
AVRO,
ARROW,
}

impl Display for FormatOptions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let out = match self {
Expand Down
6 changes: 6 additions & 0 deletions datafusion/common/src/file_options/csv_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions {
if let Some(v) = &value.null_value {
builder = builder.with_null(v.into())
}
if let Some(v) = &value.escape {
builder = builder.with_escape(*v)
}
if let Some(v) = &value.double_quote {
builder = builder.with_double_quote(*v)
}
Ok(CsvWriterOptions {
writer_options: builder,
compression: value.compression,
Expand Down
5 changes: 5 additions & 0 deletions datafusion/core/tests/data/double_quote.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
c1,c2
id0,"""value0"""
id1,"""value1"""
id2,"""value2"""
id3,"""value3"""
7 changes: 7 additions & 0 deletions datafusion/proto-common/proto/datafusion_common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,12 @@ message CsvWriterOptions {
string time_format = 7;
// Optional value to represent null
string null_value = 8;
// Optional quote. Defaults to `b'"'`
string quote = 9;
// Optional escape. Defaults to `'\\'`
string escape = 10;
// Optional flag whether to double quotes, instead of escaping. Defaults to `true`
bool double_quote = 11;
}

// Options controlling CSV format
Expand All @@ -398,6 +404,7 @@ message CsvOptions {
string time_format = 11; // Optional time format
string null_value = 12; // Optional representation of null value
bytes comment = 13; // Optional comment character as a byte
bytes double_quote = 14; // Indicates if quotes are doubled
}

// Options controlling CSV format
Expand Down
26 changes: 25 additions & 1 deletion datafusion/proto-common/src/from_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,7 @@ impl TryFrom<&protobuf::CsvOptions> for CsvOptions {
delimiter: proto_opts.delimiter[0],
quote: proto_opts.quote[0],
escape: proto_opts.escape.first().copied(),
// Decode from the message's own `double_quote` bytes (empty => `None`),
// mirroring how the encoder writes this field; reading `has_header` here
// would silently replace the setting with the header flag.
double_quote: proto_opts.double_quote.first().map(|h| *h != 0),
compression: proto_opts.compression().into(),
schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize,
date_format: (!proto_opts.date_format.is_empty())
Expand Down Expand Up @@ -1087,11 +1088,34 @@ pub(crate) fn csv_writer_options_from_proto(
return Err(proto_error("Error parsing CSV Delimiter"));
}
}
if !writer_options.quote.is_empty() {
if let Some(quote) = writer_options.quote.chars().next() {
if quote.is_ascii() {
builder = builder.with_quote(quote as u8);
} else {
return Err(proto_error("CSV Quote is not ASCII"));
}
} else {
return Err(proto_error("Error parsing CSV Quote"));
}
}
if !writer_options.escape.is_empty() {
if let Some(escape) = writer_options.escape.chars().next() {
if escape.is_ascii() {
builder = builder.with_escape(escape as u8);
} else {
return Err(proto_error("CSV Escape is not ASCII"));
}
} else {
return Err(proto_error("Error parsing CSV Escape"));
}
}
Ok(builder
.with_header(writer_options.has_header)
.with_date_format(writer_options.date_format.clone())
.with_datetime_format(writer_options.datetime_format.clone())
.with_timestamp_format(writer_options.timestamp_format.clone())
.with_time_format(writer_options.time_format.clone())
.with_null(writer_options.null_value.clone()))
.with_null(writer_options.null_value.clone())
.with_double_quote(writer_options.double_quote))
}
73 changes: 73 additions & 0 deletions datafusion/proto-common/src/generated/pbjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1853,6 +1853,9 @@ impl serde::Serialize for CsvOptions {
if !self.comment.is_empty() {
len += 1;
}
if !self.double_quote.is_empty() {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?;
if !self.has_header.is_empty() {
#[allow(clippy::needless_borrow)]
Expand Down Expand Up @@ -1901,6 +1904,10 @@ impl serde::Serialize for CsvOptions {
#[allow(clippy::needless_borrow)]
struct_ser.serialize_field("comment", pbjson::private::base64::encode(&self.comment).as_str())?;
}
if !self.double_quote.is_empty() {
#[allow(clippy::needless_borrow)]
struct_ser.serialize_field("doubleQuote", pbjson::private::base64::encode(&self.double_quote).as_str())?;
}
struct_ser.end()
}
}
Expand Down Expand Up @@ -1932,6 +1939,8 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
"null_value",
"nullValue",
"comment",
"double_quote",
"doubleQuote",
];

#[allow(clippy::enum_variant_names)]
Expand All @@ -1949,6 +1958,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
TimeFormat,
NullValue,
Comment,
DoubleQuote,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
Expand Down Expand Up @@ -1983,6 +1993,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
"timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat),
"nullValue" | "null_value" => Ok(GeneratedField::NullValue),
"comment" => Ok(GeneratedField::Comment),
"doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
Expand Down Expand Up @@ -2015,6 +2026,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
let mut time_format__ = None;
let mut null_value__ = None;
let mut comment__ = None;
let mut double_quote__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::HasHeader => {
Expand Down Expand Up @@ -2107,6 +2119,14 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
;
}
GeneratedField::DoubleQuote => {
if double_quote__.is_some() {
return Err(serde::de::Error::duplicate_field("doubleQuote"));
}
double_quote__ =
Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
;
}
}
}
Ok(CsvOptions {
Expand All @@ -2123,6 +2143,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
time_format: time_format__.unwrap_or_default(),
null_value: null_value__.unwrap_or_default(),
comment: comment__.unwrap_or_default(),
double_quote: double_quote__.unwrap_or_default(),
})
}
}
Expand Down Expand Up @@ -2161,6 +2182,15 @@ impl serde::Serialize for CsvWriterOptions {
if !self.null_value.is_empty() {
len += 1;
}
if !self.quote.is_empty() {
len += 1;
}
if !self.escape.is_empty() {
len += 1;
}
if self.double_quote {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvWriterOptions", len)?;
if self.compression != 0 {
let v = CompressionTypeVariant::try_from(self.compression)
Expand Down Expand Up @@ -2188,6 +2218,15 @@ impl serde::Serialize for CsvWriterOptions {
if !self.null_value.is_empty() {
struct_ser.serialize_field("nullValue", &self.null_value)?;
}
if !self.quote.is_empty() {
struct_ser.serialize_field("quote", &self.quote)?;
}
if !self.escape.is_empty() {
struct_ser.serialize_field("escape", &self.escape)?;
}
if self.double_quote {
struct_ser.serialize_field("doubleQuote", &self.double_quote)?;
}
struct_ser.end()
}
}
Expand All @@ -2212,6 +2251,10 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
"timeFormat",
"null_value",
"nullValue",
"quote",
"escape",
"double_quote",
"doubleQuote",
];

#[allow(clippy::enum_variant_names)]
Expand All @@ -2224,6 +2267,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
TimestampFormat,
TimeFormat,
NullValue,
Quote,
Escape,
DoubleQuote,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
Expand Down Expand Up @@ -2253,6 +2299,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
"timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat),
"timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat),
"nullValue" | "null_value" => Ok(GeneratedField::NullValue),
"quote" => Ok(GeneratedField::Quote),
"escape" => Ok(GeneratedField::Escape),
"doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
Expand Down Expand Up @@ -2280,6 +2329,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
let mut timestamp_format__ = None;
let mut time_format__ = None;
let mut null_value__ = None;
let mut quote__ = None;
let mut escape__ = None;
let mut double_quote__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Compression => {
Expand Down Expand Up @@ -2330,6 +2382,24 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
}
null_value__ = Some(map_.next_value()?);
}
GeneratedField::Quote => {
if quote__.is_some() {
return Err(serde::de::Error::duplicate_field("quote"));
}
quote__ = Some(map_.next_value()?);
}
GeneratedField::Escape => {
if escape__.is_some() {
return Err(serde::de::Error::duplicate_field("escape"));
}
escape__ = Some(map_.next_value()?);
}
GeneratedField::DoubleQuote => {
if double_quote__.is_some() {
return Err(serde::de::Error::duplicate_field("doubleQuote"));
}
double_quote__ = Some(map_.next_value()?);
}
}
}
Ok(CsvWriterOptions {
Expand All @@ -2341,6 +2411,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
timestamp_format: timestamp_format__.unwrap_or_default(),
time_format: time_format__.unwrap_or_default(),
null_value: null_value__.unwrap_or_default(),
quote: quote__.unwrap_or_default(),
escape: escape__.unwrap_or_default(),
double_quote: double_quote__.unwrap_or_default(),
})
}
}
Expand Down
12 changes: 12 additions & 0 deletions datafusion/proto-common/src/generated/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,15 @@ pub struct CsvWriterOptions {
/// Optional value to represent null
#[prost(string, tag = "8")]
pub null_value: ::prost::alloc::string::String,
/// Optional quote. Defaults to `b'"'`
#[prost(string, tag = "9")]
pub quote: ::prost::alloc::string::String,
/// Optional escape. Defaults to `'\\'`
#[prost(string, tag = "10")]
pub escape: ::prost::alloc::string::String,
/// Optional flag whether to double quote instead of escaping. Defaults to `true`
#[prost(bool, tag = "11")]
pub double_quote: bool,
}
/// Options controlling CSV format
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down Expand Up @@ -611,6 +620,9 @@ pub struct CsvOptions {
/// Optional comment character as a byte
#[prost(bytes = "vec", tag = "13")]
pub comment: ::prost::alloc::vec::Vec<u8>,
/// Indicates if quotes are doubled
#[prost(bytes = "vec", tag = "14")]
pub double_quote: ::prost::alloc::vec::Vec<u8>,
}
/// Options controlling CSV format
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down
4 changes: 4 additions & 0 deletions datafusion/proto-common/src/to_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,7 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions {
delimiter: vec![opts.delimiter],
quote: vec![opts.quote],
escape: opts.escape.map_or_else(Vec::new, |e| vec![e]),
double_quote: opts.double_quote.map_or_else(Vec::new, |h| vec![h as u8]),
compression: compression.into(),
schema_infer_max_rec: opts.schema_infer_max_rec as u64,
date_format: opts.date_format.clone().unwrap_or_default(),
Expand Down Expand Up @@ -1012,5 +1013,8 @@ pub(crate) fn csv_writer_options_to_proto(
timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(),
time_format: csv_options.time_format().unwrap_or("").to_owned(),
null_value: csv_options.null().to_owned(),
quote: (csv_options.quote() as char).to_string(),
escape: (csv_options.escape() as char).to_string(),
double_quote: csv_options.double_quote(),
}
}
12 changes: 6 additions & 6 deletions datafusion/proto/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
.as_ref()
.ok_or_else(|| proto_error("Missing required field in protobuf"))?
.try_into()?;
let sink_schema = convert_required!(sink.sink_schema)?;
let sink_schema = input.schema();
let sort_order = sink
.sort_order
.as_ref()
Expand All @@ -1024,7 +1024,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
Ok(Arc::new(DataSinkExec::new(
input,
Arc::new(data_sink),
Arc::new(sink_schema),
sink_schema,
sort_order,
)))
}
Expand All @@ -1037,7 +1037,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
.as_ref()
.ok_or_else(|| proto_error("Missing required field in protobuf"))?
.try_into()?;
let sink_schema = convert_required!(sink.sink_schema)?;
let sink_schema = input.schema();
let sort_order = sink
.sort_order
.as_ref()
Expand All @@ -1054,7 +1054,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
Ok(Arc::new(DataSinkExec::new(
input,
Arc::new(data_sink),
Arc::new(sink_schema),
sink_schema,
sort_order,
)))
}
Expand All @@ -1067,7 +1067,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
.as_ref()
.ok_or_else(|| proto_error("Missing required field in protobuf"))?
.try_into()?;
let sink_schema = convert_required!(sink.sink_schema)?;
let sink_schema = input.schema();
let sort_order = sink
.sort_order
.as_ref()
Expand All @@ -1084,7 +1084,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
Ok(Arc::new(DataSinkExec::new(
input,
Arc::new(data_sink),
Arc::new(sink_schema),
sink_schema,
sort_order,
)))
}
Expand Down
Loading
Loading