Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: prevent cursor read from being cancelled by GC #39950

Merged
merged 8 commits into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions ddl/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1584,13 +1584,11 @@ func TestLogAndShowSlowLog(t *testing.T) {
}

func TestReportingMinStartTimestamp(t *testing.T) {
_, dom := testkit.CreateMockStoreAndDomainWithSchemaLease(t, dbTestLease)
store, dom := testkit.CreateMockStoreAndDomainWithSchemaLease(t, dbTestLease)
tk := testkit.NewTestKit(t, store)
se := tk.Session()

infoSyncer := dom.InfoSyncer()
sm := &testkit.MockSessionManager{
PS: make([]*util.ProcessInfo, 0),
}
infoSyncer.SetSessionManager(sm)
beforeTS := oracle.GoTimeToTS(time.Now())
infoSyncer.ReportMinStartTS(dom.Store())
afterTS := oracle.GoTimeToTS(time.Now())
Expand All @@ -1599,13 +1597,21 @@ func TestReportingMinStartTimestamp(t *testing.T) {
now := time.Now()
validTS := oracle.GoTimeToLowerLimitStartTS(now.Add(time.Minute), tikv.MaxTxnTimeUse)
lowerLimit := oracle.GoTimeToLowerLimitStartTS(now, tikv.MaxTxnTimeUse)
sm := se.GetSessionManager().(*testkit.MockSessionManager)
sm.PS = []*util.ProcessInfo{
{CurTxnStartTS: 0},
{CurTxnStartTS: math.MaxUint64},
{CurTxnStartTS: lowerLimit},
{CurTxnStartTS: validTS},
}
infoSyncer.SetSessionManager(sm)
infoSyncer.ReportMinStartTS(dom.Store())
require.Equal(t, validTS, infoSyncer.GetMinStartTS())

unhold := se.GetSessionVars().HoldTS(validTS - 1)
infoSyncer.ReportMinStartTS(dom.Store())
require.Equal(t, validTS-1, infoSyncer.GetMinStartTS())

unhold()
infoSyncer.ReportMinStartTS(dom.Store())
require.Equal(t, validTS, infoSyncer.GetMinStartTS())
}
Expand Down
16 changes: 2 additions & 14 deletions domain/infosync/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,8 +689,6 @@ func (is *InfoSyncer) ReportMinStartTS(store kv.Storage) {
if sm == nil {
return
}
pl := sm.ShowProcessList()
innerSessionStartTSList := sm.GetInternalSessionStartTSList()

// Calculate the lower limit of the start timestamp to avoid extremely old transaction delaying GC.
currentVer, err := store.CurrentVersion(kv.GlobalTxnScope)
Expand All @@ -704,18 +702,8 @@ func (is *InfoSyncer) ReportMinStartTS(store kv.Storage) {
minStartTS := oracle.GoTimeToTS(now)
logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("initial minStartTS", minStartTS),
zap.Uint64("StartTSLowerLimit", startTSLowerLimit))
for _, info := range pl {
if info.CurTxnStartTS > startTSLowerLimit && info.CurTxnStartTS < minStartTS {
minStartTS = info.CurTxnStartTS
}
}

for _, innerTS := range innerSessionStartTSList {
logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("Internal Session Transaction StartTS", innerTS))
kv.PrintLongTimeInternalTxn(now, innerTS, false)
if innerTS > startTSLowerLimit && innerTS < minStartTS {
minStartTS = innerTS
}
if ts := sm.GetMinStartTS(startTSLowerLimit); ts > startTSLowerLimit && ts < minStartTS {
minStartTS = ts
}

is.minStartTS = kv.GetMinInnerTxnStartTS(now, startTSLowerLimit, minStartTS)
Expand Down
1 change: 1 addition & 0 deletions server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ go_test(
"//util/plancodec",
"//util/resourcegrouptag",
"//util/rowcodec",
"//util/sqlexec",
"//util/topsql",
"//util/topsql/collector",
"//util/topsql/collector/mock",
Expand Down
10 changes: 9 additions & 1 deletion server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,9 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
if rs == nil {
return false, cc.writeOK(ctx)
}
if result, ok := rs.(*tidbResultSet); ok {
// since there are multiple implementations of ResultSet (the rs might be wrapped), we have to unwrap the rs before
// casting it to *tidbResultSet.
if result, ok := unwrapResultSet(rs).(*tidbResultSet); ok {
if planCacheStmt, ok := prepStmt.(*plannercore.PlanCacheStmt); ok {
result.preparedStmt = planCacheStmt
}
Expand All @@ -278,6 +280,12 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
if useCursor {
cc.initResultEncoder(ctx)
defer cc.rsEncoder.clean()
// fix https://github.com/pingcap/tidb/issues/39447. we need to hold the start-ts here because the process info
// will be set to sleep after fetch returned.
if pi := cc.ctx.ShowProcess(); pi != nil && pi.CurTxnStartTS > 0 {
unhold := cc.ctx.GetSessionVars().HoldTS(pi.CurTxnStartTS)
rs = &rsWithHooks{ResultSet: rs, onClosed: unhold}
}
stmt.StoreResultSet(rs)
if err = cc.writeColumnInfo(rs.Columns()); err != nil {
return false, err
Expand Down
89 changes: 89 additions & 0 deletions server/conn_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
package server

import (
"context"
"encoding/binary"
"testing"

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -251,3 +254,89 @@ func TestParseStmtFetchCmd(t *testing.T) {
require.Equal(t, tc.err, err)
}
}

func TestCursorReadHoldTS(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
srv := CreateMockServer(t, store)
srv.SetDomain(dom)
defer srv.Close()

appendUint32 := binary.LittleEndian.AppendUint32
ctx := context.Background()
c := CreateMockConn(t, srv)
tk := testkit.NewTestKitWithSession(t, store, c.Context().Session)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int primary key)")
tk.MustExec("insert into t values (1), (2), (3), (4), (5), (6), (7), (8)")
tk.MustQuery("select count(*) from t").Check(testkit.Rows("8"))

stmt, _, _, err := c.Context().Prepare("select * from t")
require.NoError(t, err)
require.Zero(t, tk.Session().GetSessionVars().GetMinProtectedTS(0))

// should hold ts after executing stmt with cursor
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
)))
ts := tk.Session().GetSessionVars().GetMinProtectedTS(0)
require.Positive(t, ts)
// should unhold ts when result set exhausted
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.Equal(t, ts, tk.Session().GetSessionVars().GetMinProtectedTS(0))
require.Equal(t, ts, srv.GetMinStartTS(0))
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.Equal(t, ts, tk.Session().GetSessionVars().GetMinProtectedTS(0))
require.Equal(t, ts, srv.GetMinStartTS(0))
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.Zero(t, tk.Session().GetSessionVars().GetMinProtectedTS(0))

// should hold ts after executing stmt with cursor
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
)))
require.Positive(t, tk.Session().GetSessionVars().GetMinProtectedTS(0))
// should unhold ts when stmt reset
require.NoError(t, c.Dispatch(ctx, appendUint32([]byte{mysql.ComStmtReset}, uint32(stmt.ID()))))
require.Zero(t, tk.Session().GetSessionVars().GetMinProtectedTS(0))

// should hold ts after executing stmt with cursor
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
)))
require.Positive(t, tk.Session().GetSessionVars().GetMinProtectedTS(0))
// should unhold ts when stmt closed
require.NoError(t, c.Dispatch(ctx, appendUint32([]byte{mysql.ComStmtClose}, uint32(stmt.ID()))))
require.Zero(t, tk.Session().GetSessionVars().GetMinProtectedTS(0))

// create another 2 stmts and execute them
stmt1, _, _, err := c.Context().Prepare("select * from t")
require.NoError(t, err)
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt1.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
)))
ts1 := tk.Session().GetSessionVars().GetMinProtectedTS(0)
require.Positive(t, ts1)
stmt2, _, _, err := c.Context().Prepare("select * from t")
require.NoError(t, err)
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt2.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
)))
ts2 := tk.Session().GetSessionVars().GetMinProtectedTS(ts1)
require.Positive(t, ts2)

require.Less(t, ts1, ts2)
require.Equal(t, ts1, srv.GetMinStartTS(0))
require.Equal(t, ts2, srv.GetMinStartTS(ts1))
require.Zero(t, srv.GetMinStartTS(ts2))

// should unhold all when session closed
c.Close()
require.Zero(t, tk.Session().GetSessionVars().GetMinProtectedTS(0))
require.Zero(t, srv.GetMinStartTS(0))
}
40 changes: 40 additions & 0 deletions server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,46 @@ func (trs *tidbResultSet) Columns() []*ColumnInfo {
return trs.columns
}

// rsWithHooks wraps a ResultSet with some hooks (currently only onClosed).
type rsWithHooks struct {
ResultSet
onClosed func()
}

// Close implements ResultSet#Close
func (rs *rsWithHooks) Close() error {
closed := rs.IsClosed()
err := rs.ResultSet.Close()
if !closed && rs.onClosed != nil {
rs.onClosed()
}
return err
}

// OnFetchReturned implements fetchNotifier#OnFetchReturned
func (rs *rsWithHooks) OnFetchReturned() {
if impl, ok := rs.ResultSet.(fetchNotifier); ok {
impl.OnFetchReturned()
}
}

// Unwrap returns the underlying result set
func (rs *rsWithHooks) Unwrap() ResultSet {
return rs.ResultSet
}

// unwrapResultSet likes errors.Cause but for ResultSet
func unwrapResultSet(rs ResultSet) ResultSet {
var unRS ResultSet
if u, ok := rs.(interface{ Unwrap() ResultSet }); ok {
unRS = u.Unwrap()
}
if unRS == nil {
return rs
}
return unwrapResultSet(unRS)
}

func convertColumnInfo(fld *ast.ResultField) (ci *ColumnInfo) {
ci = &ColumnInfo{
Name: fld.ColumnAsName.O,
Expand Down
25 changes: 25 additions & 0 deletions server/driver_tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/sqlexec"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -95,3 +96,27 @@ func TestConvertColumnInfo(t *testing.T) {
colInfo = convertColumnInfo(&resultField)
require.Equal(t, uint32(4), colInfo.ColumnLength)
}

func TestRSWithHooks(t *testing.T) {
closeCount := 0
rs := &rsWithHooks{
ResultSet: &tidbResultSet{recordSet: new(sqlexec.SimpleRecordSet)},
onClosed: func() { closeCount++ },
}
require.Equal(t, 0, closeCount)
rs.Close()
require.Equal(t, 1, closeCount)
rs.Close()
require.Equal(t, 1, closeCount)
}

func TestUnwrapRS(t *testing.T) {
var nilRS ResultSet
require.Nil(t, unwrapResultSet(nilRS))
rs0 := new(tidbResultSet)
rs1 := &rsWithHooks{ResultSet: rs0}
rs2 := &rsWithHooks{ResultSet: rs1}
for _, rs := range []ResultSet{rs0, rs1, rs2} {
require.Equal(t, rs0, unwrapResultSet(rs))
}
}
40 changes: 40 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,43 @@ func (s *Server) KillNonFlashbackClusterConn() {
s.Kill(id, false)
}
}

// GetMinStartTS implements SessionManager interface.
func (s *Server) GetMinStartTS(lowerBound uint64) (ts uint64) {
// sys processes
zyguan marked this conversation as resolved.
Show resolved Hide resolved
if s.dom != nil {
for _, pi := range s.dom.SysProcTracker().GetSysProcessList() {
if pi != nil && pi.CurTxnStartTS > lowerBound && (pi.CurTxnStartTS < ts || ts == 0) {
ts = pi.CurTxnStartTS
}
}
}
// user sessions
func() {
s.rwlock.RLock()
defer s.rwlock.RUnlock()
for _, client := range s.clients {
pi := client.ctx.ShowProcess()
// start ts of current process
if pi != nil && pi.CurTxnStartTS > lowerBound && (pi.CurTxnStartTS < ts || ts == 0) {
ts = pi.CurTxnStartTS
}
// min protected timestamp of current session
if minTS := client.ctx.GetSessionVars().GetMinProtectedTS(lowerBound); minTS > lowerBound && (minTS < ts || ts == 0) {
ts = minTS
}
}
}()
// internal sessions
func() {
s.sessionMapMutex.Lock()
defer s.sessionMapMutex.Unlock()
analyzeProcID := util.GetAutoAnalyzeProcID(s.ServerID)
jackysp marked this conversation as resolved.
Show resolved Hide resolved
for se := range s.internalSessions {
if thisTS, processInfoID := session.GetStartTSFromSession(se); processInfoID != analyzeProcID && thisTS > lowerBound && (thisTS < ts || ts == 0) {
ts = thisTS
}
}
}()
return
}
50 changes: 50 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,12 @@ type SessionVars struct {

// Resource group name
ResourceGroupName string

// protectedTSList holds a list of timestamps that should delay GC.
protectedTSList struct {
zyguan marked this conversation as resolved.
Show resolved Hide resolved
sync.Mutex
items map[uint64]int
}
}

// GetNewChunkWithCapacity Attempt to request memory from the chunk pool
Expand Down Expand Up @@ -3157,3 +3163,47 @@ func (s *SessionVars) GetRelatedTableForMDL() *sync.Map {
func (s *SessionVars) EnableForceInlineCTE() bool {
return s.enableForceInlineCTE
}

// HoldTS holds the timestamp to prevent its data from being GCed.
func (s *SessionVars) HoldTS(ts uint64) (unhold func()) {
s.protectedTSList.Lock()
if s.protectedTSList.items == nil {
s.protectedTSList.items = map[uint64]int{}
}
s.protectedTSList.items[ts] += 1
s.protectedTSList.Unlock()
var once sync.Once
return func() {
once.Do(func() {
s.protectedTSList.Lock()
if s.protectedTSList.items != nil {
if s.protectedTSList.items[ts] > 1 {
s.protectedTSList.items[ts] -= 1
} else {
delete(s.protectedTSList.items, ts)
}
}
s.protectedTSList.Unlock()
})
}
}

// GetMinProtectedTS returns the minimum protected timestamps.
func (s *SessionVars) GetMinProtectedTS(lowerBound uint64) (ts uint64) {
s.protectedTSList.Lock()
for k, v := range s.protectedTSList.items {
if v > 0 && k > lowerBound && (k < ts || ts == 0) {
ts = k
}
}
s.protectedTSList.Unlock()
return
}

// GetProtectedTSCount returns the number of protected timestamps (mainly used for test).
func (s *SessionVars) GetProtectedTSCount() (count int) {
s.protectedTSList.Lock()
count = len(s.protectedTSList.items)
s.protectedTSList.Unlock()
return
}
Loading