Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track parquet writer encoding memory usage on MemoryPool #11345

Merged
merged 10 commits into from
Jul 10, 2024
161 changes: 153 additions & 8 deletions datafusion/core/src/datasource/file_format/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use datafusion_common::{
DEFAULT_PARQUET_EXTENSION,
};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
Expand Down Expand Up @@ -749,14 +750,19 @@ impl DataSink for ParquetSink {
parquet_props.writer_options().clone(),
)
.await?;
let mut reservation =
MemoryConsumer::new(format!("ParquetSink[{}]", path))
.register(context.memory_pool());
file_write_tasks.spawn(async move {
while let Some(batch) = rx.recv().await {
writer.write(&batch).await?;
reservation.try_resize(writer.memory_size())?;
}
let file_metadata = writer
.close()
.await
.map_err(DataFusionError::ParquetError)?;
drop(reservation);
wiedld marked this conversation as resolved.
Show resolved Hide resolved
Ok((path, file_metadata))
});
} else {
Expand All @@ -771,13 +777,15 @@ impl DataSink for ParquetSink {
let schema = self.get_writer_schema();
let props = parquet_props.clone();
let parallel_options_clone = parallel_options.clone();
let pool = Arc::clone(context.memory_pool());
file_write_tasks.spawn(async move {
let file_metadata = output_single_parquet_file_parallelized(
writer,
rx,
schema,
props.writer_options(),
parallel_options_clone,
pool,
)
.await?;
Ok((path, file_metadata))
Expand Down Expand Up @@ -818,14 +826,16 @@ impl DataSink for ParquetSink {
async fn column_serializer_task(
mut rx: Receiver<ArrowLeafColumn>,
mut writer: ArrowColumnWriter,
) -> Result<ArrowColumnWriter> {
mut reservation: MemoryReservation,
) -> Result<(ArrowColumnWriter, MemoryReservation)> {
while let Some(col) = rx.recv().await {
writer.write(&col)?;
reservation.try_resize(writer.memory_size())?;
}
Ok(writer)
Ok((writer, reservation))
wiedld marked this conversation as resolved.
Show resolved Hide resolved
}

type ColumnWriterTask = SpawnedTask<Result<ArrowColumnWriter>>;
type ColumnWriterTask = SpawnedTask<Result<(ArrowColumnWriter, MemoryReservation)>>;
type ColSender = Sender<ArrowLeafColumn>;

/// Spawns a parallel serialization task for each column
Expand All @@ -835,6 +845,7 @@ fn spawn_column_parallel_row_group_writer(
schema: Arc<Schema>,
parquet_props: Arc<WriterProperties>,
max_buffer_size: usize,
pool: &Arc<dyn MemoryPool>,
) -> Result<(Vec<ColumnWriterTask>, Vec<ColSender>)> {
let schema_desc = arrow_to_parquet_schema(&schema)?;
let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?;
Expand All @@ -848,7 +859,13 @@ fn spawn_column_parallel_row_group_writer(
mpsc::channel::<ArrowLeafColumn>(max_buffer_size);
col_array_channels.push(send_array);

let task = SpawnedTask::spawn(column_serializer_task(recieve_array, writer));
let reservation =
MemoryConsumer::new("ParquetSink(ArrowColumnWriter)").register(pool);
wiedld marked this conversation as resolved.
Show resolved Hide resolved
let task = SpawnedTask::spawn(column_serializer_task(
recieve_array,
writer,
reservation,
));
col_writer_tasks.push(task);
}

Expand All @@ -864,7 +881,8 @@ struct ParallelParquetWriterOptions {

/// This is the return type of calling [ArrowColumnWriter].close() on each column
/// i.e. the Vec of encoded columns which can be appended to a row group
type RBStreamSerializeResult = Result<(Vec<ArrowColumnChunk>, usize)>;
type RBStreamSerializeResult =
Result<(Vec<(ArrowColumnChunk, MemoryReservation)>, usize)>;

/// Sends the ArrowArrays in passed [RecordBatch] through the channels to their respective
/// parallel column serializers.
Expand Down Expand Up @@ -900,8 +918,11 @@ fn spawn_rg_join_and_finalize_task(
let num_cols = column_writer_tasks.len();
let mut finalized_rg = Vec::with_capacity(num_cols);
for task in column_writer_tasks.into_iter() {
let writer = task.join_unwind().await?;
finalized_rg.push(writer.close()?);
let (writer, mut reservation) = task.join_unwind().await?;
let encoded_size = writer.get_estimated_total_bytes();
let data = writer.close()?;
reservation.try_resize(encoded_size)?;
finalized_rg.push((data, reservation));
}

Ok((finalized_rg, rg_rows))
Expand All @@ -922,6 +943,7 @@ fn spawn_parquet_parallel_serialization_task(
schema: Arc<Schema>,
writer_props: Arc<WriterProperties>,
parallel_options: ParallelParquetWriterOptions,
pool: Arc<dyn MemoryPool>,
) -> SpawnedTask<Result<(), DataFusionError>> {
SpawnedTask::spawn(async move {
let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream;
Expand All @@ -931,6 +953,7 @@ fn spawn_parquet_parallel_serialization_task(
schema.clone(),
writer_props.clone(),
max_buffer_rb,
&pool,
)?;
let mut current_rg_rows = 0;

Expand Down Expand Up @@ -973,6 +996,7 @@ fn spawn_parquet_parallel_serialization_task(
schema.clone(),
writer_props.clone(),
max_buffer_rb,
&pool,
)?;
}
}
Expand Down Expand Up @@ -1002,9 +1026,13 @@ async fn concatenate_parallel_row_groups(
schema: Arc<Schema>,
writer_props: Arc<WriterProperties>,
mut object_store_writer: Box<dyn AsyncWrite + Send + Unpin>,
pool: Arc<dyn MemoryPool>,
) -> Result<FileMetaData> {
let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES);

let mut file_reservation =
MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool);
wiedld marked this conversation as resolved.
Show resolved Hide resolved

let schema_desc = arrow_to_parquet_schema(schema.as_ref())?;
let mut parquet_writer = SerializedFileWriter::new(
merged_buff.clone(),
Expand All @@ -1013,27 +1041,38 @@ async fn concatenate_parallel_row_groups(
)?;

while let Some(task) = serialize_rx.recv().await {
let mut rg_reservation =
MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(&pool);
wiedld marked this conversation as resolved.
Show resolved Hide resolved

let result = task.join_unwind().await;
let mut rg_out = parquet_writer.next_row_group()?;
let (serialized_columns, _cnt) = result?;
for chunk in serialized_columns {
for (chunk, col_reservation) in serialized_columns {
wiedld marked this conversation as resolved.
Show resolved Hide resolved
chunk.append_to_row_group(&mut rg_out)?;
let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap();
if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
object_store_writer
.write_all(buff_to_flush.as_slice())
.await?;
rg_reservation.shrink(buff_to_flush.len());
wiedld marked this conversation as resolved.
Show resolved Hide resolved
wiedld marked this conversation as resolved.
Show resolved Hide resolved
buff_to_flush.clear();
}
rg_reservation.grow(col_reservation.size());
drop(col_reservation);
}
rg_out.close()?;

// bytes remaining. Could be unflushed data, or rg metadata passed to the file writer on rg_out.close()
let remaining_bytes = rg_reservation.free();
file_reservation.grow(remaining_bytes);
}

let file_metadata = parquet_writer.close()?;
let final_buff = merged_buff.buffer.try_lock().unwrap();

object_store_writer.write_all(final_buff.as_slice()).await?;
object_store_writer.shutdown().await?;
file_reservation.free();

Ok(file_metadata)
}
Expand All @@ -1048,6 +1087,7 @@ async fn output_single_parquet_file_parallelized(
output_schema: Arc<Schema>,
parquet_props: &WriterProperties,
parallel_options: ParallelParquetWriterOptions,
pool: Arc<dyn MemoryPool>,
) -> Result<FileMetaData> {
let max_rowgroups = parallel_options.max_parallel_row_groups;
// Buffer size of this channel limits maximum number of RowGroups being worked on in parallel
Expand All @@ -1061,12 +1101,14 @@ async fn output_single_parquet_file_parallelized(
output_schema.clone(),
arc_props.clone(),
parallel_options,
Arc::clone(&pool),
);
let file_metadata = concatenate_parallel_row_groups(
serialize_rx,
output_schema.clone(),
arc_props.clone(),
object_store_writer,
pool,
)
.await?;

Expand Down Expand Up @@ -1158,8 +1200,10 @@ mod tests {
use super::super::test_util::scan_format;
use crate::datasource::listing::{ListingTableUrl, PartitionedFile};
use crate::physical_plan::collect;
use crate::test_util::bounded_stream;
use std::fmt::{Display, Formatter};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

use super::*;

Expand Down Expand Up @@ -2177,4 +2221,105 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn parquet_sink_write_memory_reservation() -> Result<()> {
async fn test_memory_reservation(global: ParquetOptions) -> Result<()> {
let field_a = Field::new("a", DataType::Utf8, false);
let field_b = Field::new("b", DataType::Utf8, false);
let schema = Arc::new(Schema::new(vec![field_a, field_b]));
let object_store_url = ObjectStoreUrl::local_filesystem();

let file_sink_config = FileSinkConfig {
object_store_url: object_store_url.clone(),
file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
table_paths: vec![ListingTableUrl::parse("file:///")?],
output_schema: schema.clone(),
table_partition_cols: vec![],
overwrite: true,
keep_partition_by_columns: false,
};
let parquet_sink = Arc::new(ParquetSink::new(
file_sink_config,
TableParquetOptions {
key_value_metadata: std::collections::HashMap::from([
("my-data".to_string(), Some("stuff".to_string())),
("my-data-bool-key".to_string(), None),
]),
global,
..Default::default()
},
));

// create data
let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"]));
let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"]));
let batch =
RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap();

// create task context
let task_context = build_ctx(object_store_url.as_ref());
assert_eq!(
task_context.memory_pool().reserved(),
0,
"no bytes are reserved yet"
);

let mut write_task = parquet_sink.write_all(
Box::pin(RecordBatchStreamAdapter::new(
schema,
bounded_stream(batch, 1000),
)),
&task_context,
);

// incrementally poll and check for memory reservation
wiedld marked this conversation as resolved.
Show resolved Hide resolved
let mut reserved_bytes = 0;
while futures::poll!(&mut write_task).is_pending() {
reserved_bytes += task_context.memory_pool().reserved();
tokio::time::sleep(Duration::from_micros(1)).await;
}
assert!(
reserved_bytes > 0,
"should have bytes reserved during write"
);
assert_eq!(
task_context.memory_pool().reserved(),
0,
"no leaking byte reservation"
);

Ok(())
}

let write_opts = ParquetOptions {
allow_single_file_parallelism: false,
..Default::default()
};
test_memory_reservation(write_opts)
.await
.expect("should track for non-parallel writes");

let row_parallel_write_opts = ParquetOptions {
allow_single_file_parallelism: true,
maximum_parallel_row_group_writers: 10,
maximum_buffered_record_batches_per_stream: 1,
..Default::default()
};
test_memory_reservation(row_parallel_write_opts)
.await
.expect("should track for row-parallel writes");

let col_parallel_write_opts = ParquetOptions {
allow_single_file_parallelism: true,
maximum_parallel_row_group_writers: 1,
maximum_buffered_record_batches_per_stream: 2,
..Default::default()
};
test_memory_reservation(col_parallel_write_opts)
.await
.expect("should track for column-parallel writes");

Ok(())
}
}
36 changes: 36 additions & 0 deletions datafusion/core/src/test_util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,39 @@ pub fn register_unbounded_file_with_ordering(
ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?;
Ok(())
}

struct BoundedStream {
limit: usize,
count: usize,
batch: RecordBatch,
}

impl Stream for BoundedStream {
type Item = Result<RecordBatch>;

fn poll_next(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.count >= self.limit {
return Poll::Ready(None);
}
self.count += 1;
Poll::Ready(Some(Ok(self.batch.clone())))
}
}

impl RecordBatchStream for BoundedStream {
fn schema(&self) -> SchemaRef {
self.batch.schema()
}
}

/// Creates an bounded stream for testing purposes.
pub fn bounded_stream(batch: RecordBatch, limit: usize) -> SendableRecordBatchStream {
Box::pin(BoundedStream {
count: 0,
limit,
batch,
})
}
Loading