Skip to content

Commit

Permalink
Propagate CSV options (quote, double quote, and escape) through protos
Browse files Browse the repository at this point in the history
  • Loading branch information
svranesevic committed Jun 14, 2024
1 parent 65ad9b6 commit 75a9b4c
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 1 deletion.
1 change: 1 addition & 0 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1559,6 +1559,7 @@ config_namespace! {
pub delimiter: u8, default = b','
pub quote: u8, default = b'"'
pub escape: Option<u8>, default = None
pub double_quote: Option<bool>, default = None
pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED
pub schema_infer_max_rec: usize, default = 100
pub date_format: Option<String>, default = None
Expand Down
6 changes: 6 additions & 0 deletions datafusion/common/src/file_options/csv_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions {
if let Some(v) = &value.null_value {
builder = builder.with_null(v.into())
}
if let Some(v) = &value.escape {
builder = builder.with_escape(*v)
}
if let Some(v) = &value.double_quote {
builder = builder.with_double_quote(*v)
}
Ok(CsvWriterOptions {
writer_options: builder,
compression: value.compression,
Expand Down
7 changes: 7 additions & 0 deletions datafusion/proto-common/proto/datafusion_common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,12 @@ message CsvWriterOptions {
string time_format = 7;
// Optional value to represent null
string null_value = 8;
// Optional quote. Defaults to `b'"'`
string quote = 9;
// Optional escape. Defaults to `'\\'`
string escape = 10;
// Optional flag whether to double quotes, instead of escaping. Defaults to `true`
bool double_quote = 11;
}

// Options controlling CSV format
Expand All @@ -398,6 +404,7 @@ message CsvOptions {
string time_format = 11; // Optional time format
string null_value = 12; // Optional representation of null value
bytes comment = 13; // Optional comment character as a byte
bytes double_quote = 14; // Indicates if quotes are doubled
}

// Options controlling CSV format
Expand Down
26 changes: 25 additions & 1 deletion datafusion/proto-common/src/from_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,7 @@ impl TryFrom<&protobuf::CsvOptions> for CsvOptions {
delimiter: proto_opts.delimiter[0],
quote: proto_opts.quote[0],
escape: proto_opts.escape.first().copied(),
double_quote: proto_opts.has_header.first().map(|h| *h != 0),
compression: proto_opts.compression().into(),
schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize,
date_format: (!proto_opts.date_format.is_empty())
Expand Down Expand Up @@ -1087,11 +1088,34 @@ pub(crate) fn csv_writer_options_from_proto(
return Err(proto_error("Error parsing CSV Delimiter"));
}
}
if !writer_options.quote.is_empty() {
if let Some(quote) = writer_options.quote.chars().next() {
if quote.is_ascii() {
builder = builder.with_quote(quote as u8);
} else {
return Err(proto_error("CSV Quote is not ASCII"));
}
} else {
return Err(proto_error("Error parsing CSV Quote"));
}
}
if !writer_options.escape.is_empty() {
if let Some(escape) = writer_options.escape.chars().next() {
if escape.is_ascii() {
builder = builder.with_escape(escape as u8);
} else {
return Err(proto_error("CSV Escape is not ASCII"));
}
} else {
return Err(proto_error("Error parsing CSV Escape"));
}
}
Ok(builder
.with_header(writer_options.has_header)
.with_date_format(writer_options.date_format.clone())
.with_datetime_format(writer_options.datetime_format.clone())
.with_timestamp_format(writer_options.timestamp_format.clone())
.with_time_format(writer_options.time_format.clone())
.with_null(writer_options.null_value.clone()))
.with_null(writer_options.null_value.clone())
.with_double_quote(writer_options.double_quote))
}
73 changes: 73 additions & 0 deletions datafusion/proto-common/src/generated/pbjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1853,6 +1853,9 @@ impl serde::Serialize for CsvOptions {
if !self.comment.is_empty() {
len += 1;
}
if !self.double_quote.is_empty() {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?;
if !self.has_header.is_empty() {
#[allow(clippy::needless_borrow)]
Expand Down Expand Up @@ -1901,6 +1904,10 @@ impl serde::Serialize for CsvOptions {
#[allow(clippy::needless_borrow)]
struct_ser.serialize_field("comment", pbjson::private::base64::encode(&self.comment).as_str())?;
}
if !self.double_quote.is_empty() {
#[allow(clippy::needless_borrow)]
struct_ser.serialize_field("doubleQuote", pbjson::private::base64::encode(&self.double_quote).as_str())?;
}
struct_ser.end()
}
}
Expand Down Expand Up @@ -1932,6 +1939,8 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
"null_value",
"nullValue",
"comment",
"double_quote",
"doubleQuote",
];

#[allow(clippy::enum_variant_names)]
Expand All @@ -1949,6 +1958,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
TimeFormat,
NullValue,
Comment,
DoubleQuote,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
Expand Down Expand Up @@ -1983,6 +1993,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
"timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat),
"nullValue" | "null_value" => Ok(GeneratedField::NullValue),
"comment" => Ok(GeneratedField::Comment),
"doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
Expand Down Expand Up @@ -2015,6 +2026,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
let mut time_format__ = None;
let mut null_value__ = None;
let mut comment__ = None;
let mut double_quote__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::HasHeader => {
Expand Down Expand Up @@ -2107,6 +2119,14 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
;
}
GeneratedField::DoubleQuote => {
if double_quote__.is_some() {
return Err(serde::de::Error::duplicate_field("doubleQuote"));
}
double_quote__ =
Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
;
}
}
}
Ok(CsvOptions {
Expand All @@ -2123,6 +2143,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
time_format: time_format__.unwrap_or_default(),
null_value: null_value__.unwrap_or_default(),
comment: comment__.unwrap_or_default(),
double_quote: double_quote__.unwrap_or_default(),
})
}
}
Expand Down Expand Up @@ -2161,6 +2182,15 @@ impl serde::Serialize for CsvWriterOptions {
if !self.null_value.is_empty() {
len += 1;
}
if !self.quote.is_empty() {
len += 1;
}
if !self.escape.is_empty() {
len += 1;
}
if self.double_quote {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvWriterOptions", len)?;
if self.compression != 0 {
let v = CompressionTypeVariant::try_from(self.compression)
Expand Down Expand Up @@ -2188,6 +2218,15 @@ impl serde::Serialize for CsvWriterOptions {
if !self.null_value.is_empty() {
struct_ser.serialize_field("nullValue", &self.null_value)?;
}
if !self.quote.is_empty() {
struct_ser.serialize_field("quote", &self.quote)?;
}
if !self.escape.is_empty() {
struct_ser.serialize_field("escape", &self.escape)?;
}
if self.double_quote {
struct_ser.serialize_field("doubleQuote", &self.double_quote)?;
}
struct_ser.end()
}
}
Expand All @@ -2212,6 +2251,10 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
"timeFormat",
"null_value",
"nullValue",
"quote",
"escape",
"double_quote",
"doubleQuote",
];

#[allow(clippy::enum_variant_names)]
Expand All @@ -2224,6 +2267,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
TimestampFormat,
TimeFormat,
NullValue,
Quote,
Escape,
DoubleQuote,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
Expand Down Expand Up @@ -2253,6 +2299,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
"timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat),
"timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat),
"nullValue" | "null_value" => Ok(GeneratedField::NullValue),
"quote" => Ok(GeneratedField::Quote),
"escape" => Ok(GeneratedField::Escape),
"doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
Expand Down Expand Up @@ -2280,6 +2329,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
let mut timestamp_format__ = None;
let mut time_format__ = None;
let mut null_value__ = None;
let mut quote__ = None;
let mut escape__ = None;
let mut double_quote__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Compression => {
Expand Down Expand Up @@ -2330,6 +2382,24 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
}
null_value__ = Some(map_.next_value()?);
}
GeneratedField::Quote => {
if quote__.is_some() {
return Err(serde::de::Error::duplicate_field("quote"));
}
quote__ = Some(map_.next_value()?);
}
GeneratedField::Escape => {
if escape__.is_some() {
return Err(serde::de::Error::duplicate_field("escape"));
}
escape__ = Some(map_.next_value()?);
}
GeneratedField::DoubleQuote => {
if double_quote__.is_some() {
return Err(serde::de::Error::duplicate_field("doubleQuote"));
}
double_quote__ = Some(map_.next_value()?);
}
}
}
Ok(CsvWriterOptions {
Expand All @@ -2341,6 +2411,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
timestamp_format: timestamp_format__.unwrap_or_default(),
time_format: time_format__.unwrap_or_default(),
null_value: null_value__.unwrap_or_default(),
quote: quote__.unwrap_or_default(),
escape: escape__.unwrap_or_default(),
double_quote: double_quote__.unwrap_or_default(),
})
}
}
Expand Down
12 changes: 12 additions & 0 deletions datafusion/proto-common/src/generated/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,15 @@ pub struct CsvWriterOptions {
/// Optional value to represent null
#[prost(string, tag = "8")]
pub null_value: ::prost::alloc::string::String,
/// Optional quote. Defaults to `b'"'`
#[prost(string, tag = "9")]
pub quote: ::prost::alloc::string::String,
/// Optional escape. Defaults to `'\\'`
#[prost(string, tag = "10")]
pub escape: ::prost::alloc::string::String,
/// Optional flag whether to double quote instead of escaping. Defaults to `true`
#[prost(bool, tag = "11")]
pub double_quote: bool,
}
/// Options controlling CSV format
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down Expand Up @@ -611,6 +620,9 @@ pub struct CsvOptions {
/// Optional comment character as a byte
#[prost(bytes = "vec", tag = "13")]
pub comment: ::prost::alloc::vec::Vec<u8>,
/// Indicates if quotes are doubled
#[prost(bytes = "vec", tag = "14")]
pub double_quote: ::prost::alloc::vec::Vec<u8>,
}
/// Options controlling CSV format
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down
4 changes: 4 additions & 0 deletions datafusion/proto-common/src/to_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,7 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions {
delimiter: vec![opts.delimiter],
quote: vec![opts.quote],
escape: opts.escape.map_or_else(Vec::new, |e| vec![e]),
double_quote: opts.double_quote.map_or_else(Vec::new, |h| vec![h as u8]),
compression: compression.into(),
schema_infer_max_rec: opts.schema_infer_max_rec as u64,
date_format: opts.date_format.clone().unwrap_or_default(),
Expand Down Expand Up @@ -1012,5 +1013,8 @@ pub(crate) fn csv_writer_options_to_proto(
timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(),
time_format: csv_options.time_format().unwrap_or("").to_owned(),
null_value: csv_options.null().to_owned(),
quote: (csv_options.quote() as char).to_string(),
escape: (csv_options.escape() as char).to_string(),
double_quote: csv_options.double_quote(),
}
}

0 comments on commit 75a9b4c

Please sign in to comment.