diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5b5bca75ddb0..61e002ece98b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -21,7 +21,7 @@ use std::borrow::Cow; use std::collections::HashSet; use std::ops::Not; -use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier}; +use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; @@ -175,7 +175,6 @@ impl ExprSimplifier { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); - let mut inlist_simplifier = InListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); if self.canonicalize { @@ -190,8 +189,6 @@ impl ExprSimplifier { .data()? .rewrite(&mut simplifier) .data()? - .rewrite(&mut inlist_simplifier) - .data()? .rewrite(&mut guarantee_rewriter) .data()? 
// run both passes twice to try an minimize simplifications that we missed @@ -1452,13 +1449,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Operator::Or, right, }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { - let left = as_inlist(left.as_ref()); - let right = as_inlist(right.as_ref()); - - let lhs = left.unwrap(); - let rhs = right.unwrap(); - let lhs = lhs.into_owned(); - let rhs = rhs.into_owned(); + let lhs = to_inlist(*left).unwrap(); + let rhs = to_inlist(*right).unwrap(); let mut seen: HashSet = HashSet::new(); let list = lhs .list @@ -1473,7 +1465,123 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { negated: false, }; - return Ok(Transformed::yes(Expr::InList(merged_inlist))); + Transformed::yes(Expr::InList(merged_inlist)) + } + + // Simplify expressions that are guaranteed to be true or false to a literal boolean expression + // + // Rules: + // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists + // Intersection: + // 1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false` + // 2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)` + // 3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)` + // Union: + // 4. `a not in (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)` + // # This rule is handled by `or_in_list_simplifier.rs` + // 5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)` + // If one of the expressions is `IN` and another one is `NOT IN`, then we apply set difference (except) on the `IN` expression + // 6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false` + // 7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5` + // 8. 
`a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + false, + false, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_intersection(l1, l2, false).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + true, + true, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_union(l1, l2, true).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + false, + true, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_except(l1, l2).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + true, + false, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_except(l2, l1).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Or, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + true, + true, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_intersection(l1, l2, true).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } } // no additional rewrites possible @@ -1482,6 +1590,22 @@ impl<'a, S: SimplifyInfo> 
TreeNodeRewriter for Simplifier<'a, S> { } } +// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121 +fn are_inlist_and_eq_and_match_neg( + left: &Expr, + right: &Expr, + is_left_neg: bool, + is_right_neg: bool, +) -> bool { + match (left, right) { + (Expr::InList(l), Expr::InList(r)) => { + l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg + } + _ => false, + } +} + +// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121 fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { let left = as_inlist(left); let right = as_inlist(right); @@ -1519,6 +1643,78 @@ fn as_inlist(expr: &Expr) -> Option> { } } +fn to_inlist(expr: Expr) -> Option { + match expr { + Expr::InList(inlist) => Some(inlist), + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) => match (left.as_ref(), right.as_ref()) { + (Expr::Column(_), Expr::Literal(_)) => Some(InList { + expr: left, + list: vec![*right], + negated: false, + }), + (Expr::Literal(_), Expr::Column(_)) => Some(InList { + expr: right, + list: vec![*left], + negated: false, + }), + _ => None, + }, + _ => None, + } +} + +/// Return the union of two inlist expressions +/// maintaining the order of the elements in the two lists +fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { + // extend the list in l1 with the elements in l2 that are not already in l1 + let l1_items: HashSet<_> = l1.list.iter().collect(); + + // keep all l2 items that do not also appear in l1 + let keep_l2: Vec<_> = l2 + .list + .into_iter() + .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) }) + .collect(); + + l1.list.extend(keep_l2); + l1.negated = negated; + Ok(Expr::InList(l1)) +} + +/// Return the intersection of two inlist expressions +/// maintaining the order of the elements in the two lists +fn inlist_intersection(mut l1: InList, l2: InList, 
negated: bool) -> Result { + let l2_items = l2.list.iter().collect::>(); + + // remove all items from l1 that are not in l2 + l1.list.retain(|e| l2_items.contains(e)); + + // e in () is always false + // e not in () is always true + if l1.list.is_empty() { + return Ok(lit(negated)); + } + Ok(Expr::InList(l1)) +} + +/// Return the all items in l1 that are not in l2 +/// maintaining the order of the elements in the two lists +fn inlist_except(mut l1: InList, l2: InList) -> Result { + let l2_items = l2.list.iter().collect::>(); + + // keep only items from l1 that are not in l2 + l1.list.retain(|e| !l2_items.contains(e)); + + if l1.list.is_empty() { + return Ok(lit(false)); + } + Ok(Expr::InList(l1)) +} + #[cfg(test)] mod tests { use std::{ diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 5d1cf27827a9..9dcb8ed15563 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -19,12 +19,10 @@ use super::THRESHOLD_INLINE_INLIST; -use std::collections::HashSet; - use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_expr::expr::InList; -use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; +use datafusion_expr::Expr; pub(super) struct ShortenInListSimplifier {} @@ -97,121 +95,3 @@ impl TreeNodeRewriter for ShortenInListSimplifier { Ok(Transformed::no(expr)) } } - -pub(super) struct InListSimplifier {} - -impl InListSimplifier { - pub(super) fn new() -> Self { - Self {} - } -} - -impl TreeNodeRewriter for InListSimplifier { - type Node = Expr; - - fn f_up(&mut self, expr: Expr) -> Result> { - // Simplify expressions that is guaranteed to be true or false to a literal boolean expression - // - // Rules: - // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists - // 
Intersection: - // 1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false` - // 2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)` - // 3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)` - // Union: - // 4. `a not int (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)` - // # This rule is handled by `or_in_list_simplifier.rs` - // 5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)` - // If one of the expressions is `IN` and another one is `NOT IN`, then we apply exception on `In` expression - // 6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false` - // 7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5` - // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr.clone() { - match (*left, op, *right) { - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && !l1.negated && !l2.negated => - { - return inlist_intersection(l1, l2, false).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && l2.negated => - { - return inlist_union(l1, l2, true).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && !l1.negated && l2.negated => - { - return inlist_except(l1, l2).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && !l2.negated => - { - return inlist_except(l2, l1).map(Transformed::yes); - } - (Expr::InList(l1), Operator::Or, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && l2.negated => - { - return inlist_intersection(l1, l2, true).map(Transformed::yes); - } - (left, op, right) => { - // put the expression back together - return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { - left: Box::new(left), - op, - right: Box::new(right), - }))); - } - } - } - - Ok(Transformed::no(expr)) - } -} - -/// Return the union of two inlist expressions -/// 
maintaining the order of the elements in the two lists -fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { - // extend the list in l1 with the elements in l2 that are not already in l1 - let l1_items: HashSet<_> = l1.list.iter().collect(); - - // keep all l2 items that do not also appear in l1 - let keep_l2: Vec<_> = l2 - .list - .into_iter() - .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) }) - .collect(); - - l1.list.extend(keep_l2); - l1.negated = negated; - Ok(Expr::InList(l1)) -} - -/// Return the intersection of two inlist expressions -/// maintaining the order of the elements in the two lists -fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> Result { - let l2_items = l2.list.iter().collect::>(); - - // remove all items from l1 that are not in l2 - l1.list.retain(|e| l2_items.contains(e)); - - // e in () is always false - // e not in () is always true - if l1.list.is_empty() { - return Ok(lit(negated)); - } - Ok(Expr::InList(l1)) -} - -/// Return the all items in l1 that are not in l2 -/// maintaining the order of the elements in the two lists -fn inlist_except(mut l1: InList, l2: InList) -> Result { - let l2_items = l2.list.iter().collect::>(); - - // keep only items from l1 that are not in l2 - l1.list.retain(|e| !l2_items.contains(e)); - - if l1.list.is_empty() { - return Ok(lit(false)); - } - Ok(Expr::InList(l1)) -} diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index 4c9254beef6b..33c9ff7c3eed 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -781,4 +781,4 @@ logical_plan EmptyRelation physical_plan EmptyExec statement ok -drop table t; +drop table t; \ No newline at end of file