Skip to content

Commit

Permalink
planner: extend predicate simplification for subquery and nested expr…
Browse files Browse the repository at this point in the history
…essions (#58261)

close #58171
  • Loading branch information
ghazalfamilyusa authored Dec 20, 2024
1 parent 3e28938 commit e53ec59
Show file tree
Hide file tree
Showing 18 changed files with 694 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@
],
"Plan": [
"TableReader 10.00 root partition:p0 data:Selection",
"└─Selection 10.00 cop[tikv] or(eq(test_partition.t1.a, 1), 0)",
"└─Selection 10.00 cop[tikv] eq(test_partition.t1.a, 1)",
" └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
],
"IndexPlan": [
Expand All @@ -737,9 +737,8 @@
],
"Plan": [
"Sort 10000.00 root test_partition.t1.id, test_partition.t1.a",
"└─TableReader 10000.00 root partition:all data:Selection",
" └─Selection 10000.00 cop[tikv] or(eq(test_partition.t1.a, 1), 1)",
" └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
"└─TableReader 10000.00 root partition:all data:TableFullScan",
" └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
],
"IndexPlan": [
"Sort 10000.00 root test_partition_1.t1.id, test_partition_1.t1.a",
Expand Down Expand Up @@ -1697,7 +1696,7 @@
],
"Plan": [
"TableReader 3323.33 root partition:p0 data:Selection",
"└─Selection 3323.33 cop[tikv] or(le(test_partition.t1.a, 3), 0)",
"└─Selection 3323.33 cop[tikv] le(test_partition.t1.a, 3)",
" └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
],
"IndexPlan": [
Expand All @@ -1722,9 +1721,8 @@
],
"Plan": [
"Sort 10000.00 root test_partition.t1.id, test_partition.t1.a",
"└─TableReader 10000.00 root partition:all data:Selection",
" └─Selection 10000.00 cop[tikv] or(eq(test_partition.t1.a, 3), 1)",
" └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
"└─TableReader 10000.00 root partition:all data:TableFullScan",
" └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
],
"IndexPlan": [
"Sort 10000.00 root test_partition_1.t1.id, test_partition_1.t1.a",
Expand Down Expand Up @@ -1895,12 +1893,12 @@
"Result": null,
"Plan": [
"TableReader 10.00 root partition:dual data:Selection",
"└─Selection 10.00 cop[tikv] eq(test_partition.t1.a, 100), or(le(test_partition.t1.b, 200), 1)",
"└─Selection 10.00 cop[tikv] eq(test_partition.t1.a, 100)",
" └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
],
"IndexPlan": [
"IndexReader 100.00 root partition:dual index:IndexRangeScan",
"└─IndexRangeScan 100.00 cop[tikv] table:t1, index:a(a, b, id) range:[100 NULL,100 +inf], keep order:false, stats:pseudo"
"IndexReader 10.00 root partition:dual index:IndexRangeScan",
"└─IndexRangeScan 10.00 cop[tikv] table:t1, index:a(a, b, id) range:[100,100], keep order:false, stats:pseudo"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
"Projection 12.49 root test.t2.a2, test.t2.b2, test.t2.c2, test.t1.a1, test.t1.b1, test.t1.c1",
"└─HashJoin 12.49 root inner join, equal:[eq(test.t1.a1, test.t2.a2)]",
" ├─TableReader(Build) 9.99 root data:Selection",
" │ └─Selection 9.99 cop[tikv] not(isnull(test.t1.a1)), or(0, eq(test.t1.b1, 5))",
" │ └─Selection 9.99 cop[tikv] eq(test.t1.b1, 5), not(isnull(test.t1.a1))",
" │ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo",
" └─TableReader(Probe) 9990.00 root data:Selection",
" └─Selection 9990.00 cop[tikv] not(isnull(test.t2.a2))",
Expand Down
63 changes: 63 additions & 0 deletions pkg/planner/core/rule_decorrelate.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,71 @@ func (*DecorrelateSolver) aggDefaultValueMap(agg *logicalop.LogicalAggregation)
return defaultValueMap
}

// pruneRedundantApply removes an Apply operator whose only consumer is a parent
// Selection that does not filter any rows from its source.
//
// Example: SELECT 1 FROM t1 AS tab WHERE 1 = 1 OR (EXISTS(SELECT 1 FROM t2 WHERE a2 = a1))
// The WHERE clause always evaluates to true, so the subquery can be removed entirely.
// The plan for it is a Selection with an always-true condition on top of an Apply.
// When this pattern is detected, both the Selection and the Apply are dropped and the
// left child of the Apply is returned. For the example above, that is a table scan on t1.
//
// A chain of Apply operators directly below the Selection is skipped as a whole, but
// only while every Apply in the chain is a LeftOuterJoin or LeftOuterSemiJoin. The
// first nested Apply with another join type is kept in the returned plan, because the
// pattern above only covers these two join types.
//
// It returns the replacement plan and true when the pattern matched, and (nil, false)
// otherwise.
func pruneRedundantApply(p base.LogicalPlan) (base.LogicalPlan, bool) {
	// Only LeftOuterJoin and LeftOuterSemiJoin Apply operators match the pattern.
	prunable := func(a *logicalop.LogicalApply) bool {
		return a.JoinType == logicalop.LeftOuterJoin || a.JoinType == logicalop.LeftOuterSemiJoin
	}

	// The pattern is rooted at a LogicalSelection ...
	selection, ok := p.(*logicalop.LogicalSelection)
	if !ok {
		return nil, false
	}

	// ... whose child is a LogicalApply of a suitable join type.
	apply, ok := selection.Children()[0].(*logicalop.LogicalApply)
	if !ok || !prunable(apply) {
		return nil, false
	}

	// The Selection is a "true selection" when simplification leaves either no
	// predicate at all or a single predicate classified as a true constant.
	simplified := applyPredicateSimplification(p.SCtx(), selection.Conditions)
	switch len(simplified) {
	case 0:
		// Nothing left to evaluate: every row passes.
	case 1:
		if _, predType := FindPredicateType(p.SCtx(), simplified[0]); predType != truePredicate {
			return nil, false
		}
	default:
		return nil, false
	}

	// Walk down the left spine of consecutive prunable Apply operators and return
	// the first plan that is not one of them.
	for {
		child := apply.Children()[0]
		next, ok := child.(*logicalop.LogicalApply)
		if !ok || !prunable(next) {
			return child, true
		}
		apply = next
	}
}

// Optimize implements base.LogicalOptRule.<0th> interface.
func (s *DecorrelateSolver) Optimize(ctx context.Context, p base.LogicalPlan, opt *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, bool, error) {
if optimizedPlan, planChanged := pruneRedundantApply(p); planChanged {
return optimizedPlan, planChanged, nil
}

planChanged := false
if apply, ok := p.(*logicalop.LogicalApply); ok {
outerPlan := apply.Children()[0]
Expand Down
141 changes: 127 additions & 14 deletions pkg/planner/core/rule_predicate_simplification.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
// 1. in-list and not equal list intersection.
// 2. Drop OR predicates if they are empty for this pattern: P AND (P1 OR P2 ... OR Pn)
// Pi is removed if P & Pi is false/empty.
// 3. Simplify predicates with logical constants (True/False).
type PredicateSimplification struct {
}

Expand All @@ -44,16 +45,50 @@ const (
lessThanOrEqualPredicate
greaterThanOrEqualPredicate
orPredicate
andPredicate
scalarPredicate
falsePredicate
truePredicate
otherPredicate
)

func findPredicateType(expr expression.Expression) (*expression.Column, predicateType) {
// logicalConstant classifies cond as a constant truth value.
//
// It returns truePredicate or falsePredicate when cond is an *expression.Constant
// whose value converts to a boolean without error. It returns otherPredicate for
// everything else: non-constant expressions, constants that
// expression.MaybeOverOptimized4PlanCache reports, and values ToBool cannot convert.
func logicalConstant(bc base.PlanContext, cond expression.Expression) predicateType {
	stmtCtx := bc.GetSessionVars().StmtCtx
	constant, isConst := cond.(*expression.Constant)
	if !isConst || expression.MaybeOverOptimized4PlanCache(bc.GetExprCtx(), []expression.Expression{constant}) {
		return otherPredicate
	}

	truth, err := constant.Value.ToBool(stmtCtx.TypeCtxOrDefault())
	switch {
	case err != nil:
		return otherPredicate
	case truth == 0:
		return falsePredicate
	default:
		return truePredicate
	}
}

// FindPredicateType determines the type of predicate represented by a given expression.
// It analyzes the provided expression and returns a column (if applicable) and a corresponding predicate type.
// The function handles different expression types, including constants, scalar functions, and their specific cases:
// - Logical operators (`OR` and `AND`).
// - Comparison operators (`EQ`, `NE`, `LT`, `GT`, `LE`, `GE`).
// - IN predicates with a list of constants.
// If the expression doesn't match any of these recognized patterns, it returns an `otherPredicate` type.
func FindPredicateType(bc base.PlanContext, expr expression.Expression) (*expression.Column, predicateType) {
switch v := expr.(type) {
case *expression.Constant:
return nil, logicalConstant(bc, expr)
case *expression.ScalarFunction:
if v.FuncName.L == ast.LogicOr {
return nil, orPredicate
}
if v.FuncName.L == ast.LogicAnd {
return nil, andPredicate
}
args := v.GetArgs()
if len(args) == 0 {
return nil, otherPredicate
Expand Down Expand Up @@ -103,8 +138,8 @@ func (*PredicateSimplification) Optimize(_ context.Context, p base.LogicalPlan,
// updateInPredicate applies intersection of an in list with <> value. It returns updated In list and a flag for
// a special case if an element in the inlist is not removed to keep the list not empty.
func updateInPredicate(ctx base.PlanContext, inPredicate expression.Expression, notEQPredicate expression.Expression) (expression.Expression, bool) {
_, inPredicateType := findPredicateType(inPredicate)
_, notEQPredicateType := findPredicateType(notEQPredicate)
_, inPredicateType := FindPredicateType(ctx, inPredicate)
_, notEQPredicateType := FindPredicateType(ctx, notEQPredicate)
if inPredicateType != inListPredicate || notEQPredicateType != notEqualPredicate {
return inPredicate, true
}
Expand Down Expand Up @@ -149,7 +184,8 @@ func splitCNF(conditions []expression.Expression) []expression.Expression {
}

func applyPredicateSimplification(sctx base.PlanContext, predicates []expression.Expression) []expression.Expression {
simplifiedPredicate := mergeInAndNotEQLists(sctx, predicates)
simplifiedPredicate := shortCircuitLogicalConstants(sctx, predicates)
simplifiedPredicate = mergeInAndNotEQLists(sctx, simplifiedPredicate)
pruneEmptyORBranches(sctx, simplifiedPredicate)
simplifiedPredicate = splitCNF(simplifiedPredicate)
return simplifiedPredicate
Expand All @@ -165,8 +201,8 @@ func mergeInAndNotEQLists(sctx base.PlanContext, predicates []expression.Express
for j := i + 1; j < len(predicates); j++ {
ithPredicate := predicates[i]
jthPredicate := predicates[j]
iCol, iType := findPredicateType(ithPredicate)
jCol, jType := findPredicateType(jthPredicate)
iCol, iType := FindPredicateType(sctx, ithPredicate)
jCol, jType := FindPredicateType(sctx, jthPredicate)
if iCol == jCol {
if iType == notEqualPredicate && jType == inListPredicate {
predicates[j], specialCase = updateInPredicate(sctx, jthPredicate, ithPredicate)
Expand Down Expand Up @@ -206,8 +242,8 @@ func unsatisfiableExpression(ctx base.PlanContext, p expression.Expression) bool
func unsatisfiable(ctx base.PlanContext, p1, p2 expression.Expression) bool {
var equalPred expression.Expression
var otherPred expression.Expression
col1, p1Type := findPredicateType(p1)
col2, p2Type := findPredicateType(p2)
col1, p1Type := FindPredicateType(ctx, p1)
col2, p2Type := FindPredicateType(ctx, p2)
if col1 != col2 || col1 == nil {
return false
}
Expand Down Expand Up @@ -247,17 +283,17 @@ func comparisonPred(predType predicateType) predicateType {
// It is applied for this pattern: P AND (P1 OR P2 ... OR Pn)
// Pi is removed if P & Pi is false/empty.
func updateOrPredicate(ctx base.PlanContext, orPredicateList expression.Expression, scalarPredicatePtr expression.Expression) expression.Expression {
_, orPredicateType := findPredicateType(orPredicateList)
_, scalarPredicateType := findPredicateType(scalarPredicatePtr)
_, orPredicateType := FindPredicateType(ctx, orPredicateList)
_, scalarPredicateType := FindPredicateType(ctx, scalarPredicatePtr)
scalarPredicateType = comparisonPred(scalarPredicateType)
if orPredicateType != orPredicate || scalarPredicateType != scalarPredicate {
return orPredicateList
}
v := orPredicateList.(*expression.ScalarFunction)
firstCondition := v.GetArgs()[0]
secondCondition := v.GetArgs()[1]
_, firstConditionType := findPredicateType(firstCondition)
_, secondConditionType := findPredicateType(secondCondition)
_, firstConditionType := FindPredicateType(ctx, firstCondition)
_, secondConditionType := FindPredicateType(ctx, secondCondition)
emptyFirst := false
emptySecond := false
if comparisonPred(firstConditionType) == scalarPredicate {
Expand Down Expand Up @@ -296,8 +332,8 @@ func pruneEmptyORBranches(sctx base.PlanContext, predicates []expression.Express
for j := i + 1; j < len(predicates); j++ {
ithPredicate := predicates[i]
jthPredicate := predicates[j]
_, iType := findPredicateType(ithPredicate)
_, jType := findPredicateType(jthPredicate)
_, iType := FindPredicateType(sctx, ithPredicate)
_, jType := FindPredicateType(sctx, jthPredicate)
iType = comparisonPred(iType)
jType = comparisonPred(jType)
if iType == scalarPredicate && jType == orPredicate {
Expand All @@ -311,6 +347,83 @@ func pruneEmptyORBranches(sctx base.PlanContext, predicates []expression.Express
}
}

// shortCircuitANDORLogicalConstants simplifies a binary logical AND/OR predicate by
// short-circuiting on constant true/false operands.
//
// predicate is expected to be a two-argument scalar function; orCase selects OR
// semantics (true) or AND semantics (false). Both operands are first simplified
// through processCondition, then:
//   - if an operand is the absorbing constant (TRUE for OR, FALSE for AND), that
//     operand is the result;
//   - otherwise, if an operand is the identity constant (FALSE for OR, TRUE for AND),
//     the other operand is the result;
//   - otherwise the function is rebuilt only if an operand changed.
//
// It returns the resulting expression and whether it differs from predicate. If
// predicate is not a two-argument scalar function, it is returned unchanged with false.
func shortCircuitANDORLogicalConstants(sctx base.PlanContext, predicate expression.Expression, orCase bool) (expression.Expression, bool) {
	fn, ok := predicate.(*expression.ScalarFunction)
	if !ok {
		// Not a function call: nothing to short-circuit.
		return predicate, false
	}
	args := fn.GetArgs()
	if len(args) != 2 {
		return predicate, false
	}

	// Recursively simplify both operands before inspecting them.
	first, firstType := processCondition(sctx, args[0])
	second, secondType := processCondition(sctx, args[1])

	// absorbing decides the whole expression on its own; identity leaves the
	// other operand as the result.
	absorbing, identity := falsePredicate, truePredicate
	if orCase {
		absorbing, identity = truePredicate, falsePredicate
	}

	switch {
	case firstType == absorbing:
		return first, true
	case secondType == absorbing:
		return second, true
	case firstType == identity:
		return second, true
	case secondType == identity:
		return first, true
	case first != args[0] || second != args[1]:
		// No constant at this level, but a nested operand was simplified.
		rebuilt := expression.NewFunctionInternal(sctx.GetExprCtx(), fn.FuncName.L, fn.GetStaticType(), first, second)
		return rebuilt, true
	default:
		return predicate, false
	}
}

// processCondition simplifies a single predicate for the logical AND/OR short-circuit
// pass. OR and AND predicates are handed to shortCircuitANDORLogicalConstants; any
// other predicate is left untouched. When a simplification was applied, the statement
// is marked to skip the plan cache. It returns the (possibly rewritten) condition
// together with its predicate type, recomputed after the rewrite.
func processCondition(sctx base.PlanContext, condition expression.Expression) (expression.Expression, predicateType) {
	simplified := false

	switch _, kind := FindPredicateType(sctx, condition); kind {
	case orPredicate:
		condition, simplified = shortCircuitANDORLogicalConstants(sctx, condition, true)
	case andPredicate:
		condition, simplified = shortCircuitANDORLogicalConstants(sctx, condition, false)
	}

	if simplified {
		sctx.GetSessionVars().StmtCtx.SetSkipPlanCache("True/False predicate simplification is triggered")
	}

	// The rewrite may have changed what kind of predicate this is.
	_, resultKind := FindPredicateType(sctx, condition)
	return condition, resultKind
}

// shortCircuitLogicalConstants simplifies every predicate in the list through
// processCondition and applies short-circuit logic to the list as a whole:
//   - a predicate that simplifies to false makes it the only element of the result;
//   - a predicate that simplifies to true is dropped;
//   - any other predicate is kept, in its simplified form and original order.
func shortCircuitLogicalConstants(sctx base.PlanContext, predicates []expression.Expression) []expression.Expression {
	kept := make([]expression.Expression, 0, len(predicates))

	for i := range predicates {
		simplified, kind := processCondition(sctx, predicates[i])
		switch kind {
		case falsePredicate:
			// One false predicate decides the result; the rest is discarded.
			return []expression.Expression{simplified}
		case truePredicate:
			continue
		default:
			kept = append(kept, simplified)
		}
	}

	return kept
}

// Name implements base.LogicalOptRule.<1st> interface.
func (*PredicateSimplification) Name() string {
return "predicate_simplification"
Expand Down
8 changes: 4 additions & 4 deletions tests/integrationtest/r/executor/merge_join.result
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ id estRows task access object operator info
MergeJoin 12.50 root inner join, left key:executor__merge_join.t.c1, right key:executor__merge_join.t1.c1
├─Sort(Build) 10.00 root executor__merge_join.t1.c1
│ └─TableReader 10.00 root data:Selection
│ └─Selection 10.00 cop[tikv] not(isnull(executor__merge_join.t1.c1)), or(eq(executor__merge_join.t1.c1, 3), 0)
│ └─Selection 10.00 cop[tikv] eq(executor__merge_join.t1.c1, 3), not(isnull(executor__merge_join.t1.c1))
│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
└─Sort(Probe) 10.00 root executor__merge_join.t.c1
└─TableReader 10.00 root data:Selection
└─Selection 10.00 cop[tikv] not(isnull(executor__merge_join.t.c1)), or(eq(executor__merge_join.t.c1, 3), 0)
└─Selection 10.00 cop[tikv] eq(executor__merge_join.t.c1, 3), not(isnull(executor__merge_join.t.c1))
└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
select /*+ TIDB_SMJ(t) */ * from t left outer join t1 on t.c1 = t1.c1 where t1.c1 = 3 or false;
c1 c2 c1 c2
Expand Down Expand Up @@ -325,12 +325,12 @@ Shuffle 12.50 root execution info: concurrency:4, data sources:[TableReader Tab
├─Sort(Build) 10.00 root executor__merge_join.t1.c1
│ └─ShuffleReceiver 10.00 root
│ └─TableReader 10.00 root data:Selection
│ └─Selection 10.00 cop[tikv] not(isnull(executor__merge_join.t1.c1)), or(eq(executor__merge_join.t1.c1, 3), 0)
│ └─Selection 10.00 cop[tikv] eq(executor__merge_join.t1.c1, 3), not(isnull(executor__merge_join.t1.c1))
│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
└─Sort(Probe) 10.00 root executor__merge_join.t.c1
└─ShuffleReceiver 10.00 root
└─TableReader 10.00 root data:Selection
└─Selection 10.00 cop[tikv] not(isnull(executor__merge_join.t.c1)), or(eq(executor__merge_join.t.c1, 3), 0)
└─Selection 10.00 cop[tikv] eq(executor__merge_join.t.c1, 3), not(isnull(executor__merge_join.t.c1))
└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
select /*+ TIDB_SMJ(t) */ * from t left outer join t1 on t.c1 = t1.c1 where t1.c1 = 3 or false;
c1 c2 c1 c2
Expand Down
Loading

0 comments on commit e53ec59

Please sign in to comment.