Skip to content

Commit

Permalink
planner: add parent GE refs of group and prepare the global GE map fo…
Browse files Browse the repository at this point in the history
…r detecting group merge case. (#58382)

ref #51664
  • Loading branch information
AilinKid authored Dec 23, 2024
1 parent 985609a commit 078f5ee
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 26 deletions.
2 changes: 1 addition & 1 deletion pkg/planner/cascades/memo/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ go_test(
],
embed = [":memo"],
flaky = True,
shard_count = 7,
shard_count = 9,
deps = [
"//pkg/expression",
"//pkg/planner/cascades/base",
Expand Down
61 changes: 59 additions & 2 deletions pkg/planner/cascades/memo/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@ type Group struct {
// logicalExpressions indicates the logical equiv classes for this group.
logicalExpressions *list.List

// Operand2FirstExpr is used to locate to the first same type logical expression
// in list above instead of traverse them all.
// Operand2FirstExpr is used to locate to the first same type logicalExpression list.
Operand2FirstExpr map[pattern.Operand]*list.Element

// hash2GroupExpr is used to de-duplication in the list.
hash2GroupExpr *hashmap.Map[*GroupExpression, *list.Element]

// hash2ParentGroupExpr is reverted pointer back from current Group to parent referred GEs.
// uint64 means *list.Element's addr, list.Element means the pos in global memo group expression's list.
hash2ParentGroupExpr *hashmap.Map[*GroupExpression, struct{}]

// logicalProp indicates the logical property.
logicalProp *property.LogicalProperty

Expand Down Expand Up @@ -80,10 +83,12 @@ func (g *Group) Insert(e *GroupExpression) bool {
if e == nil {
return false
}
// first: judge the e's existence from the hash map.
// GroupExpressions hash should be initialized within Init(xxx) method.
if _, ok := g.hash2GroupExpr.Get(e); ok {
return false
}
// second: insert it into the logicalExpressions list and maintain the Operand2FirstExpr
operand := pattern.GetOperand(e.LogicalPlan)
var newEquiv *list.Element
mark, ok := g.Operand2FirstExpr[operand]
Expand All @@ -95,11 +100,40 @@ func (g *Group) Insert(e *GroupExpression) bool {
newEquiv = g.logicalExpressions.PushBack(e)
g.Operand2FirstExpr[operand] = newEquiv
}
// third: insert the list.element into the map.
g.hash2GroupExpr.Put(e, newEquiv)
e.group = g
return true
}

// Delete an existing Group expression
func (g *Group) Delete(e *GroupExpression) {
// first: del it from map and get its list element if any.
existElem, ok := g.hash2GroupExpr.Get(e)
if !ok {
// not exist at all.
return
}
g.hash2GroupExpr.Remove(e)
// second: maintain the Operand2FirstExpr
operand := pattern.GetOperand(existElem.Value.(*GroupExpression).LogicalPlan)
if g.Operand2FirstExpr[operand] == existElem {
// The target GroupExpr is the first Element of the same Operand.
// We need to change the FirstExpr to the next Expr, or delete the FirstExpr.
nextElem := existElem.Next()
if nextElem != nil && pattern.GetOperand(nextElem.Value.(*GroupExpression).LogicalPlan) == operand {
// next elem still the same operand, just move it forward.
g.Operand2FirstExpr[operand] = nextElem
} else {
// There is no more same GE of the Operand, so we should delete the FirstExpr of this Operand.
delete(g.Operand2FirstExpr, operand)
}
}
// third: just remove this element from logicalExpression list.
g.logicalExpressions.Remove(existElem)
e.group = nil
}

// GetGroupID gets the group id.
func (g *Group) GetGroupID() GroupID {
return g.groupID
Expand Down Expand Up @@ -162,6 +196,20 @@ func (g *Group) ForEachGE(f func(ge *GroupExpression) bool) {
}
}

// removeParentGEs remove the current Group's parent GE ref which is pointed to parent.
func (g *Group) removeParentGEs(parent *GroupExpression) {
_, ok := g.hash2ParentGroupExpr.Get(parent)
intest.Assert(ok)
g.hash2ParentGroupExpr.Remove(parent)
}

// addParentGEs is used to maintain the reverted parent pointer from Group to parent GEs.
func (g *Group) addParentGEs(parent *GroupExpression) {
_, ok := g.hash2ParentGroupExpr.Get(parent)
intest.Assert(!ok)
g.hash2ParentGroupExpr.Put(parent, struct{}{})
}

// NewGroup creates a new Group with given logical prop.
func NewGroup(prop *property.LogicalProperty) *Group {
g := &Group{
Expand All @@ -177,6 +225,15 @@ func NewGroup(prop *property.LogicalProperty) *Group {
return t.GetHash64()
},
),
hash2ParentGroupExpr: hashmap.New[*GroupExpression, struct{}](
4,
func(a, b *GroupExpression) bool {
return a.Equals(b)
},
func(t *GroupExpression) uint64 {
return t.GetHash64()
},
),
}
return g
}
141 changes: 141 additions & 0 deletions pkg/planner/cascades/memo/group_and_expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
package memo

import (
"container/list"
"testing"

"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/planner/core/operator/logicalop"
"github.com/pingcap/tidb/pkg/util/mock"
"github.com/stretchr/testify/require"
"github.com/zyedidia/generic/hashmap"
)
Expand Down Expand Up @@ -90,6 +93,50 @@ func TestGroupExpressionHashCollision(t *testing.T) {
require.Equal(t, res.Value.(*GroupExpression).Inputs[1].groupID, GroupID(1))
}

func TestGroupExpressionDelete(t *testing.T) {
hasher1 := base.NewHashEqualer()
hasher2 := base.NewHashEqualer()
child1 := &Group{groupID: 1}
child2 := &Group{groupID: 2}
a := &GroupExpression{
Inputs: []*Group{child1, child2},
LogicalPlan: &logicalop.LogicalProjection{Exprs: []expression.Expression{expression.NewOne()}},
}
b := &GroupExpression{
// root group should change the hash.
Inputs: []*Group{child2, child1},
LogicalPlan: &logicalop.LogicalProjection{Exprs: []expression.Expression{expression.NewOne()}},
}
a.Hash64(hasher1)
a.hash64 = hasher1.Sum64()
b.Hash64(hasher2)
b.hash64 = hasher2.Sum64()
root := NewGroup(nil)
root.groupID = 3
require.True(t, root.Insert(a))
require.True(t, root.Insert(b))
require.Equal(t, root.logicalExpressions.Len(), 2)

mock := &GroupExpression{
Inputs: []*Group{child1},
LogicalPlan: &logicalop.LogicalProjection{Exprs: []expression.Expression{expression.NewOne()}},
}
hasher1.Reset()
mock.Hash64(hasher1)
mock.hash64 = hasher1.Sum64()

root.Delete(mock)
require.Equal(t, root.logicalExpressions.Len(), 2)

root.Delete(a)
require.Equal(t, root.logicalExpressions.Len(), 1)
require.Equal(t, root.GetLogicalExpressions().Front().Value.(*GroupExpression), b)

root.Delete(b)
require.Equal(t, root.logicalExpressions.Len(), 0)
require.Equal(t, root.GetLogicalExpressions().Len(), 0)
}

func TestGroupHashEquals(t *testing.T) {
hasher1 := base.NewHashEqualer()
hasher2 := base.NewHashEqualer()
Expand Down Expand Up @@ -141,3 +188,97 @@ func TestGroupExpressionHashEquals(t *testing.T) {
require.False(t, a.Equals(b))
require.False(t, a.Equals(&b))
}

func TestGroupParentGERefs(t *testing.T) {
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/planner/cascades/memo/MockPlanSkipMemoDeriveStats", `return(true)`))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/planner/cascades/memo/MockPlanSkipMemoDeriveStats"))
}()
col1 := &expression.Column{
UniqueID: 1,
}
col2 := &expression.Column{
UniqueID: 2,
}
ctx := mock.NewContext()
t1 := logicalop.DataSource{}.Init(ctx, 0)
t1.SetSchema(expression.NewSchema(col1))
t2 := logicalop.DataSource{}.Init(ctx, 0)
t2.SetSchema(expression.NewSchema(col2))
join := logicalop.LogicalJoin{}.Init(ctx, 0)
join.SetSchema(expression.NewSchema(col1, col2))
join.SetChildren(t1, t2)

mm := NewMemo()
mm.Init(join)
require.Equal(t, 3, mm.GetGroups().Len())
require.Equal(t, 3, len(mm.GetGroupID2Group()))

require.Equal(t, mm.rootGroup.hash2ParentGroupExpr.Size(), 0)
require.Equal(t, mm.rootGroup.hash2GroupExpr.Size(), 1)
var (
j, j1, j2 *GroupExpression
elem *list.Element
)
mm.rootGroup.hash2GroupExpr.Each(func(key *GroupExpression, val *list.Element) {
j = key
elem = val
})
require.NotNil(t, elem)
require.NotNil(t, j)
require.Equal(t, elem.Value.(*GroupExpression), j)
require.Equal(t, mm.rootGroup.logicalExpressions.Len(), 1)
require.Equal(t, mm.rootGroup.logicalExpressions.Front(), elem)
require.True(t, j.LogicalPlan.Equals(join))

// left child group
leftGroup := j.Inputs[0]
require.Equal(t, leftGroup.hash2ParentGroupExpr.Size(), 1)
ge, ok := leftGroup.hash2ParentGroupExpr.Get(j)
require.True(t, ok)
require.NotNil(t, ge)
require.Equal(t, leftGroup.hash2GroupExpr.Size(), 1)
leftGroup.hash2GroupExpr.Each(func(key *GroupExpression, val *list.Element) {
j1 = key
elem = val
})
require.NotNil(t, elem)
require.NotNil(t, j1)
require.Equal(t, elem.Value.(*GroupExpression), j1)
require.True(t, j1.LogicalPlan.Equals(t1))

// right child group
rightGroup := j.Inputs[1]
require.Equal(t, rightGroup.hash2ParentGroupExpr.Size(), 1)
ge, ok = rightGroup.hash2ParentGroupExpr.Get(j)
require.True(t, ok)
require.NotNil(t, ge)
require.Equal(t, rightGroup.hash2GroupExpr.Size(), 1)
rightGroup.hash2GroupExpr.Each(func(key *GroupExpression, val *list.Element) {
j2 = key
elem = val
})
require.NotNil(t, elem)
require.NotNil(t, j2)
require.Equal(t, elem.Value.(*GroupExpression), j2)
require.True(t, j2.LogicalPlan.Equals(t2))

// assert global memo
require.Equal(t, mm.groups.Len(), 3)
require.Equal(t, mm.hash2GlobalGroupExpr.Size(), 3)
found := [3]bool{}
mm.hash2GlobalGroupExpr.Each(func(key *GroupExpression, val *GroupExpression) {
if key.Equals(j) {
found[0] = true
}
if key.Equals(j1) {
found[1] = true
}
if key.Equals(j2) {
found[2] = true
}
})
require.True(t, found[0])
require.True(t, found[1])
require.True(t, found[2])
}
12 changes: 0 additions & 12 deletions pkg/planner/cascades/memo/group_expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,6 @@ func (e *GroupExpression) Equals(other any) bool {
return true
}

// NewGroupExpression creates a new GroupExpression with the given logical plan and children.
func NewGroupExpression(lp base.LogicalPlan, inputs []*Group) *GroupExpression {
return &GroupExpression{
group: nil,
Inputs: inputs,
LogicalPlan: lp,
hash64: 0,
// todo: add rule set length
mask: bitset.New(1),
}
}

// Init initializes the GroupExpression with the given group and hasher.
func (e *GroupExpression) Init(h base2.Hasher) {
e.Hash64(h)
Expand Down
Loading

0 comments on commit 078f5ee

Please sign in to comment.