Simplify streaming_merge function parameters (#12719)
* simplify streaming_merge function parameters

* revert test change

* change StreamingMergeConfig into builder pattern
mertak-synnada authored Oct 4, 2024
1 parent 642a812 commit 31cbc43
Showing 7 changed files with 177 additions and 110 deletions.
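Every call site in the diffs below follows the same migration: the seven positional arguments of the old streaming_merge function become named with_* methods on the new builder. A minimal before/after sketch of that pattern (the streams, schema, exprs, metrics, batch_size, fetch, and reservation bindings are placeholders for whatever each call site already has in scope):

    // Before: everything passed positionally, in a fixed order.
    let merged = streaming_merge(
        streams,     // Vec<SendableRecordBatchStream>
        schema,      // SchemaRef
        &exprs,      // &[PhysicalSortExpr]
        metrics,     // BaselineMetrics
        batch_size,  // usize
        fetch,       // Option<usize>
        reservation, // MemoryReservation
    )?;

    // After: each parameter set by name; optional ones such as
    // fetch can simply be omitted.
    let merged = StreamingMergeBuilder::new()
        .with_streams(streams)
        .with_schema(schema)
        .with_expressions(&exprs)
        .with_metrics(metrics)
        .with_batch_size(batch_size)
        .with_fetch(fetch)
        .with_reservation(reservation)
        .build()?;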
@@ -29,7 +29,7 @@ mod sp_repartition_fuzz_tests {
metrics::{BaselineMetrics, ExecutionPlanMetricsSet},
repartition::RepartitionExec,
sorts::sort_preserving_merge::SortPreservingMergeExec,
sorts::streaming_merge::streaming_merge,
sorts::streaming_merge::StreamingMergeBuilder,
stream::RecordBatchStreamAdapter,
ExecutionPlan, Partitioning,
};
@@ -246,15 +246,14 @@ mod sp_repartition_fuzz_tests {
MemoryConsumer::new("test".to_string()).register(context.memory_pool());

// Internally SortPreservingMergeExec uses this function for merging.
let res = streaming_merge(
streams,
schema,
&exprs,
BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
1,
None,
mem_reservation,
)?;
let res = StreamingMergeBuilder::new()
.with_streams(streams)
.with_schema(schema)
.with_expressions(&exprs)
.with_metrics(BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0))
.with_batch_size(1)
.with_reservation(mem_reservation)
.build()?;
let res = collect(res).await?;
// Contains the merged result.
let res = concat_batches(&res[0].schema(), &res)?;
19 changes: 9 additions & 10 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -29,7 +29,7 @@ use crate::aggregates::{
};
use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
use crate::sorts::sort::sort_batch;
use crate::sorts::streaming_merge;
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::spill::{read_spill_as_stream, spill_record_batch_by_size};
use crate::stream::RecordBatchStreamAdapter;
use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr};
@@ -1001,15 +1001,14 @@ impl GroupedHashAggregateStream {
streams.push(stream);
}
self.spill_state.is_stream_merging = true;
self.input = streaming_merge(
streams,
schema,
&self.spill_state.spill_expr,
self.baseline_metrics.clone(),
self.batch_size,
None,
self.reservation.new_empty(),
)?;
self.input = StreamingMergeBuilder::new()
.with_streams(streams)
.with_schema(schema)
.with_expressions(&self.spill_state.spill_expr)
.with_metrics(self.baseline_metrics.clone())
.with_batch_size(self.batch_size)
.with_reservation(self.reservation.new_empty())
.build()?;
self.input_done = false;
self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
Ok(())
20 changes: 10 additions & 10 deletions datafusion/physical-plan/src/repartition/mod.rs
@@ -34,7 +34,7 @@ use crate::metrics::BaselineMetrics;
use crate::repartition::distributor_channels::{
channels, partition_aware_channels, DistributionReceiver, DistributionSender,
};
use crate::sorts::streaming_merge;
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::stream::RecordBatchStreamAdapter;
use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};

@@ -637,15 +637,15 @@ impl ExecutionPlan for RepartitionExec {
let merge_reservation =
MemoryConsumer::new(format!("{}[Merge {partition}]", name))
.register(context.memory_pool());
streaming_merge(
input_streams,
schema_captured,
&sort_exprs,
BaselineMetrics::new(&metrics, partition),
context.session_config().batch_size(),
fetch,
merge_reservation,
)
StreamingMergeBuilder::new()
.with_streams(input_streams)
.with_schema(schema_captured)
.with_expressions(&sort_exprs)
.with_metrics(BaselineMetrics::new(&metrics, partition))
.with_batch_size(context.session_config().batch_size())
.with_fetch(fetch)
.with_reservation(merge_reservation)
.build()
} else {
Ok(Box::pin(RepartitionStream {
num_input_partitions,
1 change: 0 additions & 1 deletion datafusion/physical-plan/src/sorts/mod.rs
@@ -28,4 +28,3 @@ mod stream;
pub mod streaming_merge;

pub use index::RowIndex;
pub(crate) use streaming_merge::streaming_merge;
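With this crate-internal re-export removed, callers now import the builder through the public module path instead, e.g. use crate::sorts::streaming_merge::StreamingMergeBuilder; — exactly the import change visible in each of the other files in this commit.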
38 changes: 19 additions & 19 deletions datafusion/physical-plan/src/sorts/sort.rs
@@ -30,7 +30,7 @@ use crate::limit::LimitStream;
use crate::metrics::{
BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
};
use crate::sorts::streaming_merge::streaming_merge;
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::spill::{read_spill_as_stream, spill_record_batches};
use crate::stream::RecordBatchStreamAdapter;
use crate::topk::TopK;
@@ -342,15 +342,15 @@ impl ExternalSorter {
streams.push(stream);
}

streaming_merge(
streams,
Arc::clone(&self.schema),
&self.expr,
self.metrics.baseline.clone(),
self.batch_size,
self.fetch,
self.reservation.new_empty(),
)
StreamingMergeBuilder::new()
.with_streams(streams)
.with_schema(Arc::clone(&self.schema))
.with_expressions(&self.expr)
.with_metrics(self.metrics.baseline.clone())
.with_batch_size(self.batch_size)
.with_fetch(self.fetch)
.with_reservation(self.reservation.new_empty())
.build()
} else {
self.in_mem_sort_stream(self.metrics.baseline.clone())
}
@@ -534,15 +534,15 @@ impl ExternalSorter {
})
.collect::<Result<_>>()?;

streaming_merge(
streams,
Arc::clone(&self.schema),
&self.expr,
metrics,
self.batch_size,
self.fetch,
self.merge_reservation.new_empty(),
)
StreamingMergeBuilder::new()
.with_streams(streams)
.with_schema(Arc::clone(&self.schema))
.with_expressions(&self.expr)
.with_metrics(metrics)
.with_batch_size(self.batch_size)
.with_fetch(self.fetch)
.with_reservation(self.merge_reservation.new_empty())
.build()
}

/// Sorts a single `RecordBatch` into a single stream.
39 changes: 19 additions & 20 deletions datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -24,7 +24,7 @@ use crate::common::spawn_buffered;
use crate::expressions::PhysicalSortExpr;
use crate::limit::LimitStream;
use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use crate::sorts::streaming_merge;
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::{
DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
@@ -273,15 +273,15 @@ impl ExecutionPlan for SortPreservingMergeExec {

debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");

let result = streaming_merge(
receivers,
schema,
&self.expr,
BaselineMetrics::new(&self.metrics, partition),
context.session_config().batch_size(),
self.fetch,
reservation,
)?;
let result = StreamingMergeBuilder::new()
.with_streams(receivers)
.with_schema(schema)
.with_expressions(&self.expr)
.with_metrics(BaselineMetrics::new(&self.metrics, partition))
.with_batch_size(context.session_config().batch_size())
.with_fetch(self.fetch)
.with_reservation(reservation)
.build()?;

debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");

@@ -960,16 +960,15 @@ mod tests {
MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);

let fetch = None;
let merge_stream = streaming_merge(
streams,
batches.schema(),
sort.as_slice(),
BaselineMetrics::new(&metrics, 0),
task_ctx.session_config().batch_size(),
fetch,
reservation,
)
.unwrap();
let merge_stream = StreamingMergeBuilder::new()
.with_streams(streams)
.with_schema(batches.schema())
.with_expressions(sort.as_slice())
.with_metrics(BaselineMetrics::new(&metrics, 0))
.with_batch_size(task_ctx.session_config().batch_size())
.with_fetch(fetch)
.with_reservation(reservation)
.build()?;

let mut merged = common::collect(merge_stream).await.unwrap();

151 changes: 111 additions & 40 deletions datafusion/physical-plan/src/sorts/streaming_merge.rs
@@ -49,49 +49,120 @@ macro_rules! merge_helper {
}};
}

/// Perform a streaming merge of [`SendableRecordBatchStream`] based on provided sort expressions
/// while preserving order.
pub fn streaming_merge(
#[derive(Default)]
pub struct StreamingMergeBuilder<'a> {
streams: Vec<SendableRecordBatchStream>,
schema: SchemaRef,
expressions: &[PhysicalSortExpr],
metrics: BaselineMetrics,
batch_size: usize,
schema: Option<SchemaRef>,
expressions: &'a [PhysicalSortExpr],
metrics: Option<BaselineMetrics>,
batch_size: Option<usize>,
fetch: Option<usize>,
reservation: MemoryReservation,
) -> Result<SendableRecordBatchStream> {
// If there are no sort expressions, preserving the order
// doesn't mean anything (and result in infinite loops)
if expressions.is_empty() {
return internal_err!("Sort expressions cannot be empty for streaming merge");
reservation: Option<MemoryReservation>,
}

impl<'a> StreamingMergeBuilder<'a> {
pub fn new() -> Self {
Self::default()
}
// Special case single column comparisons with optimized cursor implementations
if expressions.len() == 1 {
let sort = expressions[0].clone();
let data_type = sort.expr.data_type(schema.as_ref())?;
downcast_primitive! {
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
_ => {}
}

pub fn with_streams(mut self, streams: Vec<SendableRecordBatchStream>) -> Self {
self.streams = streams;
self
}

let streams = RowCursorStream::try_new(
schema.as_ref(),
expressions,
streams,
reservation.new_empty(),
)?;

Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
schema,
metrics,
batch_size,
fetch,
reservation,
)))
pub fn with_schema(mut self, schema: SchemaRef) -> Self {
self.schema = Some(schema);
self
}

pub fn with_expressions(mut self, expressions: &'a [PhysicalSortExpr]) -> Self {
self.expressions = expressions;
self
}

pub fn with_metrics(mut self, metrics: BaselineMetrics) -> Self {
self.metrics = Some(metrics);
self
}

pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = Some(batch_size);
self
}

pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
self.fetch = fetch;
self
}

pub fn with_reservation(mut self, reservation: MemoryReservation) -> Self {
self.reservation = Some(reservation);
self
}

pub fn build(self) -> Result<SendableRecordBatchStream> {
let Self {
streams,
schema,
metrics,
batch_size,
reservation,
fetch,
expressions,
} = self;

// Early return if streams or expressions are empty
let checks = [
(
streams.is_empty(),
"Streams cannot be empty for streaming merge",
),
(
expressions.is_empty(),
"Sort expressions cannot be empty for streaming merge",
),
];

if let Some((_, error_message)) = checks.iter().find(|(condition, _)| *condition)
{
return internal_err!("{}", error_message);
}

// Unwrapping mandatory fields
let schema = schema.expect("Schema cannot be empty for streaming merge");
let metrics = metrics.expect("Metrics cannot be empty for streaming merge");
let batch_size =
batch_size.expect("Batch size cannot be empty for streaming merge");
let reservation =
reservation.expect("Reservation cannot be empty for streaming merge");

// Special case single column comparisons with optimized cursor implementations
if expressions.len() == 1 {
let sort = expressions[0].clone();
let data_type = sort.expr.data_type(schema.as_ref())?;
downcast_primitive! {
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
_ => {}
}
}

let streams = RowCursorStream::try_new(
schema.as_ref(),
expressions,
streams,
reservation.new_empty(),
)?;
Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
schema,
metrics,
batch_size,
fetch,
reservation,
)))
}
}
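As the new build() shows, validation happens in two tiers: an empty streams or expressions value (both default to empty) produces an internal_err! Result, while a schema, metrics, batch size, or reservation that was never set panics through expect. A hedged sketch of the Result-returning path (the streams binding is a placeholder for some non-empty Vec<SendableRecordBatchStream>):

    // Sketch: expressions were never set, so they default to an empty
    // slice and build() returns an error before any expect can panic.
    let result = StreamingMergeBuilder::new()
        .with_streams(streams)
        .build();
    assert!(result.is_err()); // "Sort expressions cannot be empty for streaming merge"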
