Skip to content

Commit

Permalink
executor: fix goroutine leak when exceed quota in hash agg (#58078)
Browse files Browse the repository at this point in the history
close #58004
  • Loading branch information
xzhangxian1008 authored Dec 23, 2024
1 parent d0ea9e5 commit 985609a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
22 changes: 22 additions & 0 deletions pkg/executor/aggregate/agg_hash_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ import (
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/disk"
"github.com/pingcap/tidb/pkg/util/hack"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/memory"
"github.com/pingcap/tidb/pkg/util/set"
"go.uber.org/zap"
)

// HashAggInput indicates the input of hash agg exec.
Expand Down Expand Up @@ -154,6 +156,8 @@ type HashAggExec struct {
spillHelper *parallelHashAggSpillHelper
// isChildDrained indicates whether the all data from child has been taken out.
isChildDrained bool

invalidMemoryUsageForTrackingTest bool
}

// Close implements the Executor Close interface.
Expand Down Expand Up @@ -204,6 +208,10 @@ func (e *HashAggExec) Close() error {
channel.Clear(e.finalOutputCh)
e.executed.Store(false)
if e.memTracker != nil {
if e.memTracker.BytesConsumed() < 0 {
logutil.BgLogger().Warn("Memory tracker's counter is invalid", zap.Int64("counter", e.memTracker.BytesConsumed()))
e.invalidMemoryUsageForTrackingTest = true
}
e.memTracker.ReplaceBytesUsed(0)
}
e.parallelExecValid = false
Expand Down Expand Up @@ -289,6 +297,8 @@ func (e *HashAggExec) initForUnparallelExec() {
}

func (e *HashAggExec) initPartialWorkers(partialConcurrency int, finalConcurrency int, ctx sessionctx.Context) {
memUsage := int64(0)

for i := 0; i < partialConcurrency; i++ {
partialResultsMap := make([]aggfuncs.AggPartialResultMapper, finalConcurrency)
for i := 0; i < finalConcurrency; i++ {
Expand Down Expand Up @@ -316,6 +326,8 @@ func (e *HashAggExec) initPartialWorkers(partialConcurrency int, finalConcurrenc
inflightChunkSync: e.inflightChunkSync,
}

memUsage += e.partialWorkers[i].chk.MemoryUsage()

e.partialWorkers[i].partialResultNumInRow = e.partialWorkers[i].getPartialResultSliceLenConsiderByteAlign()
for j := 0; j < finalConcurrency; j++ {
e.partialWorkers[i].BInMaps[j] = 0
Expand All @@ -332,8 +344,11 @@ func (e *HashAggExec) initPartialWorkers(partialConcurrency int, finalConcurrenc
chk: chunk.New(e.Children(0).RetFieldTypes(), 0, e.MaxChunkSize()),
giveBackCh: e.partialWorkers[i].inputCh,
}
memUsage += input.chk.MemoryUsage()
e.inputCh <- input
}

e.memTracker.Consume(memUsage)
}

func (e *HashAggExec) initFinalWorkers(finalConcurrency int) {
Expand Down Expand Up @@ -442,6 +457,7 @@ func (e *HashAggExec) fetchChildData(ctx context.Context, waitGroup *sync.WaitGr
ok bool
err error
)

defer func() {
if r := recover(); r != nil {
recoveryHashAgg(e.finalOutputCh, r)
Expand Down Expand Up @@ -494,6 +510,7 @@ func (e *HashAggExec) fetchChildData(ctx context.Context, waitGroup *sync.WaitGr
input.giveBackCh <- chk

if hasError := e.spillIfNeed(); hasError {
e.memTracker.Consume(-mSize)
return
}
}
Expand Down Expand Up @@ -857,3 +874,8 @@ func (e *HashAggExec) IsSpillTriggeredForTest() bool {
}
return false
}

// IsInvalidMemoryUsageTrackingForTest is for test
func (e *HashAggExec) IsInvalidMemoryUsageTrackingForTest() bool {
return e.invalidMemoryUsageForTrackingTest
}
6 changes: 6 additions & 0 deletions pkg/executor/aggregate/agg_spill_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ func generateResult(t *testing.T, ctx *mock.Context, dataSource *testutil.MockDa
resultRows = append(resultRows, chk.GetRow(i))
}
}
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
aggExec.Close()

require.False(t, aggExec.IsSpillTriggeredForTest())
Expand Down Expand Up @@ -315,6 +316,7 @@ func executeCorrecResultTest(t *testing.T, ctx *mock.Context, aggExec *aggregate
resultRows = append(resultRows, chk.GetRow(i))
}
}
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
aggExec.Close()

require.True(t, aggExec.IsSpillTriggeredForTest())
Expand Down Expand Up @@ -351,6 +353,7 @@ func fallBackActionTest(t *testing.T) {
}
chk.Reset()
}
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
aggExec.Close()
require.Less(t, 0, newRootExceedAction.GetTriggeredNum())
}
Expand All @@ -373,6 +376,7 @@ func randomFailTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggE
go func() {
time.Sleep(time.Duration(rand.Int31n(300)) * time.Millisecond)
once.Do(func() {
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
aggExec.Close()
})
goRoutineWaiter.Done()
Expand All @@ -382,6 +386,7 @@ func randomFailTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggE
err := aggExec.Next(tmpCtx, chk)
if err != nil {
once.Do(func() {
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
err = aggExec.Close()
require.Equal(t, nil, err)
})
Expand All @@ -393,6 +398,7 @@ func randomFailTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggE
chk.Reset()
}
once.Do(func() {
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
aggExec.Close()
})
}
Expand Down

0 comments on commit 985609a

Please sign in to comment.