Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve constant values across union operations #13805

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
22 changes: 20 additions & 2 deletions datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::{
};

use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::JoinType;
use datafusion_common::{JoinType, ScalarValue};
use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;

/// A structure representing an expression known to be constant in a physical execution plan.
Expand Down Expand Up @@ -62,11 +62,15 @@ pub struct ConstExpr {
/// Does the constant have the same value across all partitions? See
/// struct docs for more details
across_partitions: bool,
/// The value of the constant expression
value: Option<ScalarValue>,
}

impl PartialEq for ConstExpr {
fn eq(&self, other: &Self) -> bool {
self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
self.across_partitions == other.across_partitions
&& self.expr.eq(&other.expr)
&& self.value == other.value
}
}

Expand All @@ -80,9 +84,15 @@ impl ConstExpr {
expr,
// By default, assume constant expressions are not same across partitions.
across_partitions: false,
value: None,
}
}

pub fn with_value(mut self, value: ScalarValue) -> Self {
self.value = Some(value);
self
}

/// Set the `across_partitions` flag
///
/// See struct docs for more details
Expand All @@ -106,6 +116,10 @@ impl ConstExpr {
self.expr
}

/// Returns a reference to the recorded constant value, or `None` when no
/// value has been attached via [`Self::with_value`].
pub fn value(&self) -> Option<&ScalarValue> {
    match self.value {
        Some(ref v) => Some(v),
        None => None,
    }
}

pub fn map<F>(&self, f: F) -> Option<Self>
where
F: Fn(&Arc<dyn PhysicalExpr>) -> Option<Arc<dyn PhysicalExpr>>,
Expand All @@ -114,6 +128,7 @@ impl ConstExpr {
maybe_expr.map(|expr| Self {
expr,
across_partitions: self.across_partitions,
value: self.value.clone(),
})
}

Expand Down Expand Up @@ -152,6 +167,9 @@ impl Display for ConstExpr {
if self.across_partitions {
write!(f, "(across_partitions)")?;
}
if let Some(value) = self.value.as_ref() {
write!(f, "({})", value)?;
}
Ok(())
}
}
Expand Down
127 changes: 88 additions & 39 deletions datafusion/physical-expr/src/equivalence/properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,30 +293,33 @@ impl EquivalenceProperties {
mut self,
constants: impl IntoIterator<Item = ConstExpr>,
) -> Self {
let (const_exprs, across_partition_flags): (
Vec<Arc<dyn PhysicalExpr>>,
Vec<bool>,
) = constants
let normalized_constants = constants
.into_iter()
.map(|const_expr| {
let across_partitions = const_expr.across_partitions();
let expr = const_expr.owned_expr();
(expr, across_partitions)
.filter_map(|c| {
let across_partitions = c.across_partitions();
let value = c.value().cloned();
let expr = c.owned_expr();
let normalized_expr = self.eq_group.normalize_expr(expr);

if const_exprs_contains(&self.constants, &normalized_expr) {
return None;
}

let mut const_expr = ConstExpr::from(normalized_expr)
.with_across_partitions(across_partitions);

if let Some(value) = value {
const_expr = const_expr.with_value(value);
}

Some(const_expr)
})
.unzip();
for (expr, across_partitions) in self
.eq_group
.normalize_exprs(const_exprs)
.into_iter()
.zip(across_partition_flags)
{
if !const_exprs_contains(&self.constants, &expr) {
let const_expr =
ConstExpr::from(expr).with_across_partitions(across_partitions);
self.constants.push(const_expr);
}
}
.collect::<Vec<_>>();

// Add all new normalized constants
self.constants.extend(normalized_constants);

// Discover any new orderings based on the constants
for ordering in self.normalized_oeq_class().iter() {
if let Err(e) = self.discover_new_orderings(&ordering[0].expr) {
log::debug!("error discovering new orderings: {e}");
Expand Down Expand Up @@ -1819,7 +1822,7 @@ impl Hash for ExprWrapper {
/// *all* output partitions, that is the same as being true for all *input*
/// partitions
fn calculate_union_binary(
mut lhs: EquivalenceProperties,
lhs: EquivalenceProperties,
mut rhs: EquivalenceProperties,
) -> Result<EquivalenceProperties> {
// Harmonize the schema of the rhs with the schema of the lhs (which is the accumulator schema):
Expand All @@ -1828,26 +1831,37 @@ fn calculate_union_binary(
}

// First, calculate valid constants for the union. An expression is constant
// at the output of the union if it is constant in both sides.
let constants: Vec<_> = lhs
// at the output of the union if it is constant in both sides with matching values.
let constants = lhs
.constants()
.iter()
.filter(|const_expr| const_exprs_contains(rhs.constants(), const_expr.expr()))
.map(|const_expr| {
// TODO: When both sides have a constant column, and the actual
// constant value is the same, then the output properties could
// reflect the constant is valid across all partitions. However we
// don't track the actual value that the ConstExpr takes on, so we
// can't determine that yet
ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false)
})
.collect();
.filter_map(|lhs_const| {
// Find matching constant expression in RHS
rhs.constants()
.iter()
.find(|rhs_const| rhs_const.expr().eq(lhs_const.expr()))
.map(|rhs_const| {
let mut const_expr = ConstExpr::new(Arc::clone(lhs_const.expr()));

// remove any constants that are shared in both outputs (avoid double counting them)
for c in &constants {
lhs = lhs.remove_constant(c);
rhs = rhs.remove_constant(c);
}
// If both sides are constant across partitions, set across_partitions=true
if lhs_const.across_partitions() && rhs_const.across_partitions() {
const_expr = const_expr.with_across_partitions(true);
}

// If both sides have matching constant values, preserve the value and set across_partitions=true
if let (Some(lhs_val), Some(rhs_val)) =
(lhs_const.value(), rhs_const.value())
{
if lhs_val == rhs_val {
const_expr = const_expr
.with_across_partitions(true)
.with_value(lhs_val.clone());
}
}
const_expr
})
})
.collect::<Vec<_>>();

// Next, calculate valid orderings for the union by searching for prefixes
// in both sides.
Expand Down Expand Up @@ -2113,6 +2127,7 @@ mod tests {

use arrow::datatypes::{DataType, Field, Schema};
use arrow_schema::{Fields, TimeUnit};
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;

#[test]
Expand Down Expand Up @@ -3651,4 +3666,38 @@ mod tests {

sort_expr
}

#[test]
fn test_union_constant_value_preservation() -> Result<()> {
    // Two-column schema; only column "a" participates in the assertions.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]));
    let a_expr = col("a", &schema)?;
    let ten = ScalarValue::Int32(Some(10));

    // Both union inputs pin `a` to the same literal value (10).
    let lhs = EquivalenceProperties::new(Arc::clone(&schema)).with_constants(
        vec![ConstExpr::new(Arc::clone(&a_expr)).with_value(ten.clone())],
    );
    let rhs = EquivalenceProperties::new(Arc::clone(&schema)).with_constants(
        vec![ConstExpr::new(Arc::clone(&a_expr)).with_value(ten.clone())],
    );

    // Union the two sides' equivalence properties.
    let union_props = calculate_union(vec![lhs, rhs], schema)?;

    // `a` must remain constant, valid across all partitions, with value 10.
    let const_a = &union_props.constants()[0];
    assert!(const_a.expr().eq(&a_expr));
    assert!(const_a.across_partitions());
    assert_eq!(const_a.value(), Some(&ten));

    Ok(())
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to create an end-to-end .slt test that shows this behavior?

For example, a EXPLAIN PLAN where a Sort is optimized away after the constant value is propagated through the union?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! I have one in my mind. Let me add it

Copy link
Contributor

@berkaysynnada berkaysynnada Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @alamb, I tried it but after thinking more, we actually need one more step in planner to experience an end-to-end difference. Now we have the knowledge, but we are not using it. 2 possible optimizations are which come to my mind now:
Let's assume we have:

# Constant value tracking across union
query TT
explain
SELECT * FROM(
(
    SELECT * FROM aggregate_test_100 WHERE c1='a'
)
UNION ALL
(
    SELECT * FROM aggregate_test_100 WHERE c1='a'
))
ORDER BY c1
----
+   physical_plan
+   01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST]
+   02)--UnionExec
+   03)----CoalesceBatchesExec: target_batch_size=2
+   04)------FilterExec: c1@0 = a
+   05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+   06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true
+   07)----CoalesceBatchesExec: target_batch_size=2
+   08)------FilterExec: c1@0 = a
+   09)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+   10)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true
  1. At the top of the plan, we see an SPM. However, it can have a CoalescePartitionsExec instead. That would improve the performance for sure.
  2. For the same query without an order by but with another outer filter, we will see another filter. However, we can actually remove that. This is another optimization, but can be observed pretty rarely rather than 1st one.

2nd one could be not really realistic, but the first one could be implemented without much effort with a few changes in replace_with_order_preserving_variants scope.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you take a look at the first check @gokselk? It should take a few line changes in the plan_with_order_preserving_variants() function. It should first look at the order requirements, and if they are matched, then it would try to convert CoalescePartitionsExec to SortPreservingMergeExec. But before that conversion, you can check the across_partitions flag of the input constants, and if it is true, you can leave the CoalescePartitionsExec as is.

Copy link
Author

@gokselk gokselk Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you take a look at the first check @gokselk? It should take a few line changes in plan_with_order_preserving_variants() function. It should first look the order requirements, and if they are matched, then it would try to convert CoalescePartitionExec to SortPreservingMergeExec. But before that conversion, you can check across_partitions flag of the input constants, and if it is true, you can left the CoalescePartitionsExec as is.

I've made changes to FilterExec for value extraction and added an initial SLT file. The query now shows CoalescePartitionExec in the output, so I think your suggested changes to plan_with_order_preserving_variants() might not be needed anymore. However, I'd appreciate your review to confirm this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears that I broke some ORDER BY queries in my recent commits. I will investigate this further.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To add more context, some tests are failing non-deterministically, which is why I didn't notice it beforehand.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what the actual situation is w.r.t. those tests, but I'd advise to take a look at whether they were underspecified in the first place (i.e. the query itself may not be specifying a concrete output ordering, which could make the test flaky).

Do failing queries have top level ORDER BY clauses? If so, it is probably a bug that was introduced. Otherwise, maybe they were flaky in the first place.

30 changes: 23 additions & 7 deletions datafusion/physical-plan/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use datafusion_common::{
use datafusion_execution::TaskContext;
use datafusion_expr::Operator;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::BinaryExpr;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
use datafusion_physical_expr::intervals::utils::check_support;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{
Expand Down Expand Up @@ -218,13 +218,29 @@ impl FilterExec {
if binary.op() == &Operator::Eq {
// Filter evaluates to single value for all partitions
if input_eqs.is_expr_constant(binary.left()) {
res_constants.push(
ConstExpr::from(binary.right()).with_across_partitions(true),
)
// When left side is constant, extract value from right side if it's a literal
let (expr, lit) = (
binary.right(),
binary.right().as_any().downcast_ref::<Literal>(),
);
let mut const_expr =
ConstExpr::from(expr).with_across_partitions(true);
if let Some(lit) = lit {
const_expr = const_expr.with_value(lit.value().clone());
}
res_constants.push(const_expr);
} else if input_eqs.is_expr_constant(binary.right()) {
res_constants.push(
ConstExpr::from(binary.left()).with_across_partitions(true),
)
// When right side is constant, extract value from left side if it's a literal
let (expr, lit) = (
binary.left(),
binary.left().as_any().downcast_ref::<Literal>(),
);
let mut const_expr =
ConstExpr::from(expr).with_across_partitions(true);
if let Some(lit) = lit {
const_expr = const_expr.with_value(lit.value().clone());
}
res_constants.push(const_expr);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
statement ok
CREATE EXTERNAL TABLE aggregate_test_100 (
c1 VARCHAR NOT NULL,
c2 TINYINT NOT NULL,
c3 SMALLINT NOT NULL,
c4 SMALLINT,
c5 INT,
c6 BIGINT NOT NULL,
c7 SMALLINT NOT NULL,
c8 INT NOT NULL,
c9 BIGINT UNSIGNED NOT NULL,
c10 VARCHAR NOT NULL,
c11 FLOAT NOT NULL,
c12 DOUBLE NOT NULL,
c13 VARCHAR NOT NULL
)
STORED AS CSV
LOCATION '../../testing/data/csv/aggregate_test_100.csv'
OPTIONS ('format.has_header' 'true');

statement ok
set datafusion.explain.physical_plan_only = true;

statement ok
set datafusion.execution.batch_size = 2;

# Constant value tracking across union
query TT
explain
SELECT * FROM(
(
SELECT * FROM aggregate_test_100 WHERE c1='a'
)
UNION ALL
(
SELECT * FROM aggregate_test_100 WHERE c1='a'
))
ORDER BY c1
----
physical_plan
01)CoalescePartitionsExec
02)--UnionExec
03)----CoalesceBatchesExec: target_batch_size=2
04)------FilterExec: c1@0 = a
05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true
07)----CoalesceBatchesExec: target_batch_size=2
08)------FilterExec: c1@0 = a
09)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
10)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true
Loading