diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index cc26d12fb029..03b3c7761ac6 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -626,6 +626,59 @@ impl EquivalenceGroup { JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), } } + + /// Checks if two expressions are equal either directly or through equivalence classes. + /// For complex expressions (e.g. a + b), checks that the expression trees are structurally + /// identical and their leaf nodes are equivalent either directly or through equivalence classes. + pub fn exprs_equal( + &self, + left: &Arc, + right: &Arc, + ) -> bool { + // Direct equality check + if left.eq(right) { + return true; + } + + // Check if expressions are equivalent through equivalence classes + // We need to check both directions since expressions might be in different classes + if let Some(left_class) = self.get_equivalence_class(left) { + if left_class.contains(right) { + return true; + } + } + if let Some(right_class) = self.get_equivalence_class(right) { + if right_class.contains(left) { + return true; + } + } + + // For non-leaf nodes, check structural equality + let left_children = left.children(); + let right_children = right.children(); + + // If either expression is a leaf node and we haven't found equality yet, + // they must be different + if left_children.is_empty() || right_children.is_empty() { + return false; + } + + // Type equality check through reflection + if left.as_any().type_id() != right.as_any().type_id() { + return false; + } + + // Check if the number of children is the same + if left_children.len() != right_children.len() { + return false; + } + + // Check if all children are equal + left_children + .into_iter() + .zip(right_children) + .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child)) + } } impl Display for EquivalenceGroup { @@ -647,9 +700,10 @@ mod tests { use super::*; use crate::equivalence::tests::create_test_params; - use crate::expressions::{lit, Literal}; + use crate::expressions::{lit, BinaryExpr, Literal}; use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::Operator; #[test] fn test_bridge_groups() -> Result<()> { @@ -777,4 +831,159 @@ mod tests { assert!(!cls1.contains_any(&cls3)); assert!(!cls2.contains_any(&cls3)); } + + #[test] + fn test_exprs_equal() -> Result<()> { + struct TestCase { + left: Arc, + right: Arc, + expected: bool, + description: &'static str, + } + + // Create test columns + let col_a = Arc::new(Column::new("a", 0)) as Arc; + let col_b = Arc::new(Column::new("b", 1)) as Arc; + let col_x = Arc::new(Column::new("x", 2)) as Arc; + let col_y = Arc::new(Column::new("y", 3)) as Arc; + + // Create test literals + let lit_1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let lit_2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + + // Create equivalence group with classes (a = x) and (b = y) + let eq_group = EquivalenceGroup::new(vec![ + EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]), + EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]), + ]); + + let test_cases = vec![ + // Basic equality tests + TestCase { + left: Arc::clone(&col_a), + right: Arc::clone(&col_a), + expected: true, + description: "Same column should be equal", + }, + // Equivalence class tests + TestCase { + left: Arc::clone(&col_a), + right: Arc::clone(&col_x), + expected: true, + description: "Columns in same equivalence class should be equal", + }, + TestCase { + left: Arc::clone(&col_b), + right: Arc::clone(&col_y), + expected: true, + description: "Columns in same equivalence class should be equal", + }, + TestCase { + left: Arc::clone(&col_a), + right: Arc::clone(&col_b), + expected: false, + description: + "Columns in different equivalence classes should not be equal", + }, + // Literal tests + TestCase { + left: Arc::clone(&lit_1), + right: Arc::clone(&lit_1), + expected: true, + description: "Same literal should be equal", + }, + TestCase { + left: Arc::clone(&lit_1), + right: Arc::clone(&lit_2), + expected: false, + description: "Different literals should not be equal", + }, + // Complex expression tests + TestCase { + left: Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + Operator::Plus, + Arc::clone(&col_b), + )) as Arc, + right: Arc::new(BinaryExpr::new( + Arc::clone(&col_x), + Operator::Plus, + Arc::clone(&col_y), + )) as Arc, + expected: true, + description: + "Binary expressions with equivalent operands should be equal", + }, + TestCase { + left: Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + Operator::Plus, + Arc::clone(&col_b), + )) as Arc, + right: Arc::new(BinaryExpr::new( + Arc::clone(&col_x), + Operator::Plus, + Arc::clone(&col_a), + )) as Arc, + expected: false, + description: + "Binary expressions with non-equivalent operands should not be equal", + }, + TestCase { + left: Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + Operator::Plus, + Arc::clone(&lit_1), + )) as Arc, + right: Arc::new(BinaryExpr::new( + Arc::clone(&col_x), + Operator::Plus, + Arc::clone(&lit_1), + )) as Arc, + expected: true, + description: "Binary expressions with equivalent column and same literal should be equal", + }, + TestCase { + left: Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + Operator::Plus, + Arc::clone(&col_b), + )), + Operator::Multiply, + Arc::clone(&lit_1), + )) as Arc, + right: Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::clone(&col_x), + Operator::Plus, + Arc::clone(&col_y), + )), + Operator::Multiply, + Arc::clone(&lit_1), + )) as Arc, + expected: true, + description: "Nested binary expressions with equivalent operands should be equal", + }, + ]; + + for TestCase { + left, + right, + expected, + description, + } in test_cases + { + let actual = eq_group.exprs_equal(&left, &right); + assert_eq!( + actual, expected, + "{}: Failed comparing {:?} and {:?}, expected {}, got {}", + description, left, right, expected, actual + ); + } + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index d4814fb4d780..f019b2e570ff 100755 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -use std::fmt; use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::iter::Peekable; use std::slice::Iter; use std::sync::Arc; +use std::{fmt, mem}; use super::ordering::collapse_lex_ordering; use crate::equivalence::class::const_exprs_contains; @@ -412,12 +412,51 @@ impl EquivalenceProperties { /// Updates the ordering equivalence group within assuming that the table /// is re-sorted according to the argument `sort_exprs`. Note that constants /// and equivalence classes are unchanged as they are unaffected by a re-sort. + /// If the given ordering is already satisfied, the function does nothing. pub fn with_reorder(mut self, sort_exprs: LexOrdering) -> Self { - // TODO: In some cases, existing ordering equivalences may still be valid add this analysis. - self.oeq_class = OrderingEquivalenceClass::new(vec![sort_exprs]); + // Filter out constant expressions as they don't affect ordering + let filtered_exprs = LexOrdering::new( + sort_exprs + .into_iter() + .filter(|expr| !self.is_expr_constant(&expr.expr)) + .collect(), + ); + + if filtered_exprs.is_empty() { + return self; + } + + let mut new_orderings = vec![filtered_exprs.clone()]; + + // Preserve valid suffixes from existing orderings + let orderings = mem::take(&mut self.oeq_class.orderings); + for existing in orderings { + if self.is_prefix_of(&filtered_exprs, &existing) { + let mut extended = filtered_exprs.clone(); + extended.extend(existing.into_iter().skip(filtered_exprs.len())); + new_orderings.push(extended); + } + } + + self.oeq_class = OrderingEquivalenceClass::new(new_orderings); self } + /// Checks if the new ordering matches a prefix of the existing ordering + /// (considering expression equivalences) + fn is_prefix_of(&self, new_order: &LexOrdering, existing: &LexOrdering) -> bool { + // Check if new order is longer than existing - can't be a prefix + if new_order.len() > existing.len() { + return false; + } + + // Check if new order matches existing prefix (considering equivalences) + new_order.iter().zip(existing).all(|(new, existing)| { + self.eq_group.exprs_equal(&new.expr, &existing.expr) + && new.options == existing.options + }) + } + /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the /// equivalence group and the ordering equivalence class within. /// @@ -3852,4 +3891,246 @@ mod tests { Ok(()) } + + #[test] + fn test_with_reorder_constant_filtering() -> Result<()> { + let schema = create_test_schema()?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + // Setup constant columns + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + eq_properties = eq_properties.with_constants([ConstExpr::from(&col_a)]); + + let sort_exprs = LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::clone(&col_b), + options: SortOptions::default(), + }, + ]); + + let result = eq_properties.with_reorder(sort_exprs); + + // Should only contain b since a is constant + assert_eq!(result.oeq_class().len(), 1); + assert_eq!(result.oeq_class().orderings[0].len(), 1); + assert!(result.oeq_class().orderings[0][0].expr.eq(&col_b)); + + Ok(()) + } + + #[test] + fn test_with_reorder_preserve_suffix() -> Result<()> { + let schema = create_test_schema()?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + + let asc = SortOptions::default(); + let desc = SortOptions { + descending: true, + nulls_first: true, + }; + + // Initial ordering: [a ASC, b DESC, c ASC] + eq_properties.add_new_orderings([LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: asc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_b), + options: desc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_c), + options: asc, + }, + ])]); + + // New ordering: [a ASC] + let new_order = LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: asc, + }]); + + let result = eq_properties.with_reorder(new_order); + + // Should only contain [a ASC, b DESC, c ASC] + assert_eq!(result.oeq_class().len(), 1); + assert_eq!(result.oeq_class().orderings[0].len(), 3); + assert!(result.oeq_class().orderings[0][0].expr.eq(&col_a)); + assert!(result.oeq_class().orderings[0][0].options.eq(&asc)); + assert!(result.oeq_class().orderings[0][1].expr.eq(&col_b)); + assert!(result.oeq_class().orderings[0][1].options.eq(&desc)); + assert!(result.oeq_class().orderings[0][2].expr.eq(&col_c)); + assert!(result.oeq_class().orderings[0][2].options.eq(&asc)); + + Ok(()) + } + + #[test] + fn test_with_reorder_equivalent_expressions() -> Result<()> { + let schema = create_test_schema()?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + + // Make a and b equivalent + eq_properties.add_equal_conditions(&col_a, &col_b)?; + + let asc = SortOptions::default(); + + // Initial ordering: [a ASC, c ASC] + eq_properties.add_new_orderings([LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: asc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_c), + options: asc, + }, + ])]); + + // New ordering: [b ASC] + let new_order = LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_b), + options: asc, + }]); + + let result = eq_properties.with_reorder(new_order); + + // Should only contain [b ASC, c ASC] + assert_eq!(result.oeq_class().len(), 1); + + // Verify orderings + assert_eq!(result.oeq_class().orderings[0].len(), 2); + assert!(result.oeq_class().orderings[0][0].expr.eq(&col_b)); + assert!(result.oeq_class().orderings[0][0].options.eq(&asc)); + assert!(result.oeq_class().orderings[0][1].expr.eq(&col_c)); + assert!(result.oeq_class().orderings[0][1].options.eq(&asc)); + + Ok(()) + } + + #[test] + fn test_with_reorder_incompatible_prefix() -> Result<()> { + let schema = create_test_schema()?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + + let asc = SortOptions::default(); + let desc = SortOptions { + descending: true, + nulls_first: true, + }; + + // Initial ordering: [a ASC, b DESC] + eq_properties.add_new_orderings([LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: asc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_b), + options: desc, + }, + ])]); + + // New ordering: [a DESC] + let new_order = LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: desc, + }]); + + let result = eq_properties.with_reorder(new_order.clone()); + + // Should only contain the new ordering since options don't match + assert_eq!(result.oeq_class().len(), 1); + assert_eq!(result.oeq_class().orderings[0], new_order); + + Ok(()) + } + + #[test] + fn test_with_reorder_comprehensive() -> Result<()> { + let schema = create_test_schema()?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + let col_d = col("d", &schema)?; + let col_e = col("e", &schema)?; + + let asc = SortOptions::default(); + + // Constants: c is constant + eq_properties = eq_properties.with_constants([ConstExpr::from(&col_c)]); + + // Equality: b = d + eq_properties.add_equal_conditions(&col_b, &col_d)?; + + // Orderings: [d ASC, a ASC], [e ASC] + eq_properties.add_new_orderings([ + LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_d), + options: asc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: asc, + }, + ]), + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_e), + options: asc, + }]), + ]); + + // Initial ordering: [b ASC, c ASC] + let new_order = LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_b), + options: asc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_c), + options: asc, + }, + ]); + + let result = eq_properties.with_reorder(new_order); + + // Should preserve the original [d ASC, a ASC] ordering + assert_eq!(result.oeq_class().len(), 1); + let ordering = &result.oeq_class().orderings[0]; + assert_eq!(ordering.len(), 2); + + // First expression should be either b or d (they're equivalent) + assert!( + ordering[0].expr.eq(&col_b) || ordering[0].expr.eq(&col_d), + "Expected b or d as first expression, got {:?}", + ordering[0].expr + ); + assert!(ordering[0].options.eq(&asc)); + + // Second expression should be a + assert!(ordering[1].expr.eq(&col_a)); + assert!(ordering[1].options.eq(&asc)); + + Ok(()) + } }