Skip to content

Commit

Permalink
Preserve ordering equivalencies on with_reorder (#13770)
Browse files Browse the repository at this point in the history
* Preserve ordering equivalencies on `with_reorder`

* Add assertions

* Return early if filtered_exprs is empty

* Add clarifying comment

* Refactor

* Add comprehensive test case

* Add comment for exprs_equal

* Cargo fmt

* Clippy fix

* Update properties.rs

* Update exprs_equal and add tests

* Update properties.rs

---------

Co-authored-by: berkaysynnada <[email protected]>
  • Loading branch information
gokselk and berkaysynnada authored Dec 20, 2024
1 parent f3b1141 commit b0d7cd0
Show file tree
Hide file tree
Showing 2 changed files with 494 additions and 4 deletions.
211 changes: 210 additions & 1 deletion datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,59 @@ impl EquivalenceGroup {
JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(),
}
}

/// Checks whether two expressions are equal, either directly or through the
/// equivalence classes of this group.
///
/// `left` and `right` are considered equal when any of the following holds:
/// - they compare equal with `eq`;
/// - one of them resolves to an equivalence class that contains the other;
/// - both are non-leaf nodes of the same concrete type with the same number
///   of children, every pair of corresponding children is equal by this same
///   definition, and the two nodes agree on everything other than their
///   children (for instance, the operator of a binary expression).
///
/// For example, given the classes `(a = x)` and `(b = y)`, `a + b` is equal
/// to `x + y`, but it is equal to neither `x + a` nor `x - y`.
pub fn exprs_equal(
    &self,
    left: &Arc<dyn PhysicalExpr>,
    right: &Arc<dyn PhysicalExpr>,
) -> bool {
    // Fast path: the expressions are identical as written.
    if left.eq(right) {
        return true;
    }

    // Equivalence through a shared class. The lookup is attempted from both
    // sides because only one of the two expressions may resolve to a class.
    if let Some(left_class) = self.get_equivalence_class(left) {
        if left_class.contains(right) {
            return true;
        }
    }
    if let Some(right_class) = self.get_equivalence_class(right) {
        if right_class.contains(left) {
            return true;
        }
    }

    let left_children = left.children();
    let right_children = right.children();

    // A leaf that matched neither directly nor through a class cannot be
    // equal to anything else.
    if left_children.is_empty() || right_children.is_empty() {
        return false;
    }

    // Structural comparison only makes sense between nodes of the same
    // concrete type and arity.
    if left.as_any().type_id() != right.as_any().type_id()
        || left_children.len() != right_children.len()
    {
        return false;
    }

    // Every pair of corresponding children must itself be equivalent.
    let children_equal = left_children
        .iter()
        .zip(right_children.iter())
        .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child));
    if !children_equal {
        return false;
    }

    // Matching type and equivalent children are not sufficient: nodes of the
    // same type can still differ in their own state (e.g. `a + b` vs `a - b`
    // are both binary expressions). Swap `left`'s children into `right` and
    // require the result to be identical to `left`, so that only the children
    // are allowed to differ. This relies on `with_new_children` carrying over
    // the node's non-child state; if the rebuild fails we conservatively
    // report the expressions as not equal.
    let substituted_children: Vec<Arc<dyn PhysicalExpr>> =
        left_children.into_iter().map(Arc::clone).collect();
    match Arc::clone(right).with_new_children(substituted_children) {
        Ok(rebuilt) => left.eq(&rebuilt),
        Err(_) => false,
    }
}
}

impl Display for EquivalenceGroup {
Expand All @@ -647,9 +700,10 @@ mod tests {

use super::*;
use crate::equivalence::tests::create_test_params;
use crate::expressions::{lit, Literal};
use crate::expressions::{lit, BinaryExpr, Literal};

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;

#[test]
fn test_bridge_groups() -> Result<()> {
Expand Down Expand Up @@ -777,4 +831,159 @@ mod tests {
assert!(!cls1.contains_any(&cls3));
assert!(!cls2.contains_any(&cls3));
}

/// Exercises `EquivalenceGroup::exprs_equal` on columns, literals and
/// (nested) binary expressions against a group holding the classes
/// `(a = x)` and `(b = y)`.
#[test]
fn test_exprs_equal() -> Result<()> {
    /// A single comparison scenario together with its expected outcome.
    struct TestCase {
        left: Arc<dyn PhysicalExpr>,
        right: Arc<dyn PhysicalExpr>,
        expected: bool,
        description: &'static str,
    }

    /// Builds a column expression with the given name and index.
    fn column(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    }

    /// Builds an `Int32` literal expression.
    fn int32(value: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(value))))
    }

    /// Builds the binary expression `lhs <op> rhs`.
    fn binary(
        lhs: &Arc<dyn PhysicalExpr>,
        op: Operator,
        rhs: &Arc<dyn PhysicalExpr>,
    ) -> Arc<dyn PhysicalExpr> {
        Arc::new(BinaryExpr::new(Arc::clone(lhs), op, Arc::clone(rhs)))
    }

    /// Bundles two expressions with the expected verdict and a description.
    fn case(
        left: Arc<dyn PhysicalExpr>,
        right: Arc<dyn PhysicalExpr>,
        expected: bool,
        description: &'static str,
    ) -> TestCase {
        TestCase {
            left,
            right,
            expected,
            description,
        }
    }

    // Columns under test
    let a = column("a", 0);
    let b = column("b", 1);
    let x = column("x", 2);
    let y = column("y", 3);

    // Literals under test
    let one = int32(1);
    let two = int32(2);

    // Equivalence group with the classes (a = x) and (b = y)
    let group = EquivalenceGroup::new(vec![
        EquivalenceClass::new(vec![Arc::clone(&a), Arc::clone(&x)]),
        EquivalenceClass::new(vec![Arc::clone(&b), Arc::clone(&y)]),
    ]);

    let cases = vec![
        // Direct equality
        case(
            Arc::clone(&a),
            Arc::clone(&a),
            true,
            "Same column should be equal",
        ),
        // Equality through equivalence classes
        case(
            Arc::clone(&a),
            Arc::clone(&x),
            true,
            "Columns in same equivalence class should be equal",
        ),
        case(
            Arc::clone(&b),
            Arc::clone(&y),
            true,
            "Columns in same equivalence class should be equal",
        ),
        case(
            Arc::clone(&a),
            Arc::clone(&b),
            false,
            "Columns in different equivalence classes should not be equal",
        ),
        // Literals
        case(
            Arc::clone(&one),
            Arc::clone(&one),
            true,
            "Same literal should be equal",
        ),
        case(
            Arc::clone(&one),
            Arc::clone(&two),
            false,
            "Different literals should not be equal",
        ),
        // Composite expressions
        case(
            binary(&a, Operator::Plus, &b),
            binary(&x, Operator::Plus, &y),
            true,
            "Binary expressions with equivalent operands should be equal",
        ),
        case(
            binary(&a, Operator::Plus, &b),
            binary(&x, Operator::Plus, &a),
            false,
            "Binary expressions with non-equivalent operands should not be equal",
        ),
        case(
            binary(&a, Operator::Plus, &one),
            binary(&x, Operator::Plus, &one),
            true,
            "Binary expressions with equivalent column and same literal should be equal",
        ),
        case(
            binary(&binary(&a, Operator::Plus, &b), Operator::Multiply, &one),
            binary(&binary(&x, Operator::Plus, &y), Operator::Multiply, &one),
            true,
            "Nested binary expressions with equivalent operands should be equal",
        ),
    ];

    for scenario in &cases {
        let actual = group.exprs_equal(&scenario.left, &scenario.right);
        assert_eq!(
            actual, scenario.expected,
            "{}: Failed comparing {:?} and {:?}, expected {}, got {}",
            scenario.description,
            scenario.left,
            scenario.right,
            scenario.expected,
            actual
        );
    }

    Ok(())
}
}
Loading

0 comments on commit b0d7cd0

Please sign in to comment.