Skip to content

Commit

Permalink
Track parquet writer encoding memory usage on MemoryPool (#11345)
Browse files Browse the repository at this point in the history
* feat(11344): track memory used for non-parallel writes

* feat(11344): track memory usage during parallel writes

* test(11344): create bounded stream for testing

* test(11344): test ParquetSink memory reservation

* feat(11344): track bytes in file writer

* refactor(11344): tweak the ordering to add col bytes to rg_reservation, before selecting shrinking for data bytes flushed

* refactor: move each col_reservation and rg_reservation to match the parallelized call stack for col vs rg

* test(11344): add memory_limit enforcement test for parquet sink

* chore: cleanup to remove unnecessary reservation management steps

* fix: fix CI test failure due to file extension rename
  • Loading branch information
wiedld authored Jul 10, 2024
1 parent 585504a commit 6038f4c
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 10 deletions.
165 changes: 155 additions & 10 deletions datafusion/core/src/datasource/file_format/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use datafusion_common::{
DEFAULT_PARQUET_EXTENSION,
};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
Expand Down Expand Up @@ -749,9 +750,13 @@ impl DataSink for ParquetSink {
parquet_props.writer_options().clone(),
)
.await?;
let mut reservation =
MemoryConsumer::new(format!("ParquetSink[{}]", path))
.register(context.memory_pool());
file_write_tasks.spawn(async move {
while let Some(batch) = rx.recv().await {
writer.write(&batch).await?;
reservation.try_resize(writer.memory_size())?;
}
let file_metadata = writer
.close()
Expand All @@ -771,13 +776,15 @@ impl DataSink for ParquetSink {
let schema = self.get_writer_schema();
let props = parquet_props.clone();
let parallel_options_clone = parallel_options.clone();
let pool = Arc::clone(context.memory_pool());
file_write_tasks.spawn(async move {
let file_metadata = output_single_parquet_file_parallelized(
writer,
rx,
schema,
props.writer_options(),
parallel_options_clone,
pool,
)
.await?;
Ok((path, file_metadata))
Expand Down Expand Up @@ -818,14 +825,16 @@ impl DataSink for ParquetSink {
async fn column_serializer_task(
mut rx: Receiver<ArrowLeafColumn>,
mut writer: ArrowColumnWriter,
) -> Result<ArrowColumnWriter> {
mut reservation: MemoryReservation,
) -> Result<(ArrowColumnWriter, MemoryReservation)> {
while let Some(col) = rx.recv().await {
writer.write(&col)?;
reservation.try_resize(writer.memory_size())?;
}
Ok(writer)
Ok((writer, reservation))
}

type ColumnWriterTask = SpawnedTask<Result<ArrowColumnWriter>>;
type ColumnWriterTask = SpawnedTask<Result<(ArrowColumnWriter, MemoryReservation)>>;
type ColSender = Sender<ArrowLeafColumn>;

/// Spawns a parallel serialization task for each column
Expand All @@ -835,6 +844,7 @@ fn spawn_column_parallel_row_group_writer(
schema: Arc<Schema>,
parquet_props: Arc<WriterProperties>,
max_buffer_size: usize,
pool: &Arc<dyn MemoryPool>,
) -> Result<(Vec<ColumnWriterTask>, Vec<ColSender>)> {
let schema_desc = arrow_to_parquet_schema(&schema)?;
let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?;
Expand All @@ -848,7 +858,13 @@ fn spawn_column_parallel_row_group_writer(
mpsc::channel::<ArrowLeafColumn>(max_buffer_size);
col_array_channels.push(send_array);

let task = SpawnedTask::spawn(column_serializer_task(recieve_array, writer));
let reservation =
MemoryConsumer::new("ParquetSink(ArrowColumnWriter)").register(pool);
let task = SpawnedTask::spawn(column_serializer_task(
recieve_array,
writer,
reservation,
));
col_writer_tasks.push(task);
}

Expand All @@ -864,7 +880,7 @@ struct ParallelParquetWriterOptions {

/// This is the return type of calling [ArrowColumnWriter].close() on each column
/// i.e. the Vec of encoded columns which can be appended to a row group
type RBStreamSerializeResult = Result<(Vec<ArrowColumnChunk>, usize)>;
type RBStreamSerializeResult = Result<(Vec<ArrowColumnChunk>, MemoryReservation, usize)>;

/// Sends the ArrowArrays in passed [RecordBatch] through the channels to their respective
/// parallel column serializers.
Expand Down Expand Up @@ -895,16 +911,22 @@ async fn send_arrays_to_col_writers(
fn spawn_rg_join_and_finalize_task(
column_writer_tasks: Vec<ColumnWriterTask>,
rg_rows: usize,
pool: &Arc<dyn MemoryPool>,
) -> SpawnedTask<RBStreamSerializeResult> {
let mut rg_reservation =
MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool);

SpawnedTask::spawn(async move {
let num_cols = column_writer_tasks.len();
let mut finalized_rg = Vec::with_capacity(num_cols);
for task in column_writer_tasks.into_iter() {
let writer = task.join_unwind().await?;
let (writer, _col_reservation) = task.join_unwind().await?;
let encoded_size = writer.get_estimated_total_bytes();
rg_reservation.grow(encoded_size);
finalized_rg.push(writer.close()?);
}

Ok((finalized_rg, rg_rows))
Ok((finalized_rg, rg_reservation, rg_rows))
})
}

Expand All @@ -922,6 +944,7 @@ fn spawn_parquet_parallel_serialization_task(
schema: Arc<Schema>,
writer_props: Arc<WriterProperties>,
parallel_options: ParallelParquetWriterOptions,
pool: Arc<dyn MemoryPool>,
) -> SpawnedTask<Result<(), DataFusionError>> {
SpawnedTask::spawn(async move {
let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream;
Expand All @@ -931,6 +954,7 @@ fn spawn_parquet_parallel_serialization_task(
schema.clone(),
writer_props.clone(),
max_buffer_rb,
&pool,
)?;
let mut current_rg_rows = 0;

Expand All @@ -957,6 +981,7 @@ fn spawn_parquet_parallel_serialization_task(
let finalize_rg_task = spawn_rg_join_and_finalize_task(
column_writer_handles,
max_row_group_rows,
&pool,
);

serialize_tx.send(finalize_rg_task).await.map_err(|_| {
Expand All @@ -973,6 +998,7 @@ fn spawn_parquet_parallel_serialization_task(
schema.clone(),
writer_props.clone(),
max_buffer_rb,
&pool,
)?;
}
}
Expand All @@ -981,8 +1007,11 @@ fn spawn_parquet_parallel_serialization_task(
drop(col_array_channels);
// Handle leftover rows as final rowgroup, which may be smaller than max_row_group_rows
if current_rg_rows > 0 {
let finalize_rg_task =
spawn_rg_join_and_finalize_task(column_writer_handles, current_rg_rows);
let finalize_rg_task = spawn_rg_join_and_finalize_task(
column_writer_handles,
current_rg_rows,
&pool,
);

serialize_tx.send(finalize_rg_task).await.map_err(|_| {
DataFusionError::Internal(
Expand All @@ -1002,9 +1031,13 @@ async fn concatenate_parallel_row_groups(
schema: Arc<Schema>,
writer_props: Arc<WriterProperties>,
mut object_store_writer: Box<dyn AsyncWrite + Send + Unpin>,
pool: Arc<dyn MemoryPool>,
) -> Result<FileMetaData> {
let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES);

let mut file_reservation =
MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool);

let schema_desc = arrow_to_parquet_schema(schema.as_ref())?;
let mut parquet_writer = SerializedFileWriter::new(
merged_buff.clone(),
Expand All @@ -1015,15 +1048,20 @@ async fn concatenate_parallel_row_groups(
while let Some(task) = serialize_rx.recv().await {
let result = task.join_unwind().await;
let mut rg_out = parquet_writer.next_row_group()?;
let (serialized_columns, _cnt) = result?;
let (serialized_columns, mut rg_reservation, _cnt) = result?;
for chunk in serialized_columns {
chunk.append_to_row_group(&mut rg_out)?;
rg_reservation.free();

let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap();
file_reservation.try_resize(buff_to_flush.len())?;

if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
object_store_writer
.write_all(buff_to_flush.as_slice())
.await?;
buff_to_flush.clear();
file_reservation.try_resize(buff_to_flush.len())?; // will set to zero
}
}
rg_out.close()?;
Expand All @@ -1034,6 +1072,7 @@ async fn concatenate_parallel_row_groups(

object_store_writer.write_all(final_buff.as_slice()).await?;
object_store_writer.shutdown().await?;
file_reservation.free();

Ok(file_metadata)
}
Expand All @@ -1048,6 +1087,7 @@ async fn output_single_parquet_file_parallelized(
output_schema: Arc<Schema>,
parquet_props: &WriterProperties,
parallel_options: ParallelParquetWriterOptions,
pool: Arc<dyn MemoryPool>,
) -> Result<FileMetaData> {
let max_rowgroups = parallel_options.max_parallel_row_groups;
// Buffer size of this channel limits maximum number of RowGroups being worked on in parallel
Expand All @@ -1061,12 +1101,14 @@ async fn output_single_parquet_file_parallelized(
output_schema.clone(),
arc_props.clone(),
parallel_options,
Arc::clone(&pool),
);
let file_metadata = concatenate_parallel_row_groups(
serialize_rx,
output_schema.clone(),
arc_props.clone(),
object_store_writer,
pool,
)
.await?;

Expand Down Expand Up @@ -1158,8 +1200,10 @@ mod tests {
use super::super::test_util::scan_format;
use crate::datasource::listing::{ListingTableUrl, PartitionedFile};
use crate::physical_plan::collect;
use crate::test_util::bounded_stream;
use std::fmt::{Display, Formatter};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

use super::*;

Expand Down Expand Up @@ -2177,4 +2221,105 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn parquet_sink_write_memory_reservation() -> Result<()> {
async fn test_memory_reservation(global: ParquetOptions) -> Result<()> {
let field_a = Field::new("a", DataType::Utf8, false);
let field_b = Field::new("b", DataType::Utf8, false);
let schema = Arc::new(Schema::new(vec![field_a, field_b]));
let object_store_url = ObjectStoreUrl::local_filesystem();

let file_sink_config = FileSinkConfig {
object_store_url: object_store_url.clone(),
file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
table_paths: vec![ListingTableUrl::parse("file:///")?],
output_schema: schema.clone(),
table_partition_cols: vec![],
overwrite: true,
keep_partition_by_columns: false,
};
let parquet_sink = Arc::new(ParquetSink::new(
file_sink_config,
TableParquetOptions {
key_value_metadata: std::collections::HashMap::from([
("my-data".to_string(), Some("stuff".to_string())),
("my-data-bool-key".to_string(), None),
]),
global,
..Default::default()
},
));

// create data
let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"]));
let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"]));
let batch =
RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap();

// create task context
let task_context = build_ctx(object_store_url.as_ref());
assert_eq!(
task_context.memory_pool().reserved(),
0,
"no bytes are reserved yet"
);

let mut write_task = parquet_sink.write_all(
Box::pin(RecordBatchStreamAdapter::new(
schema,
bounded_stream(batch, 1000),
)),
&task_context,
);

// incrementally poll and check for memory reservation
let mut reserved_bytes = 0;
while futures::poll!(&mut write_task).is_pending() {
reserved_bytes += task_context.memory_pool().reserved();
tokio::time::sleep(Duration::from_micros(1)).await;
}
assert!(
reserved_bytes > 0,
"should have bytes reserved during write"
);
assert_eq!(
task_context.memory_pool().reserved(),
0,
"no leaking byte reservation"
);

Ok(())
}

let write_opts = ParquetOptions {
allow_single_file_parallelism: false,
..Default::default()
};
test_memory_reservation(write_opts)
.await
.expect("should track for non-parallel writes");

let row_parallel_write_opts = ParquetOptions {
allow_single_file_parallelism: true,
maximum_parallel_row_group_writers: 10,
maximum_buffered_record_batches_per_stream: 1,
..Default::default()
};
test_memory_reservation(row_parallel_write_opts)
.await
.expect("should track for row-parallel writes");

let col_parallel_write_opts = ParquetOptions {
allow_single_file_parallelism: true,
maximum_parallel_row_group_writers: 1,
maximum_buffered_record_batches_per_stream: 2,
..Default::default()
};
test_memory_reservation(col_parallel_write_opts)
.await
.expect("should track for column-parallel writes");

Ok(())
}
}
36 changes: 36 additions & 0 deletions datafusion/core/src/test_util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,39 @@ pub fn register_unbounded_file_with_ordering(
ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?;
Ok(())
}

struct BoundedStream {
limit: usize,
count: usize,
batch: RecordBatch,
}

impl Stream for BoundedStream {
type Item = Result<RecordBatch>;

fn poll_next(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.count >= self.limit {
return Poll::Ready(None);
}
self.count += 1;
Poll::Ready(Some(Ok(self.batch.clone())))
}
}

impl RecordBatchStream for BoundedStream {
fn schema(&self) -> SchemaRef {
self.batch.schema()
}
}

/// Creates an bounded stream for testing purposes.
pub fn bounded_stream(batch: RecordBatch, limit: usize) -> SendableRecordBatchStream {
Box::pin(BoundedStream {
count: 0,
limit,
batch,
})
}
Loading

0 comments on commit 6038f4c

Please sign in to comment.