Skip to content

Commit

Permalink
planner: support pushing down predicates to memory tables in prepared…
Browse files Browse the repository at this point in the history
… mode (#40262) (#42640)

close #39605
  • Loading branch information
ti-chi-bot authored Mar 29, 2023
1 parent 44d7fbe commit 8b20dfa
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 4 deletions.
30 changes: 30 additions & 0 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -1429,3 +1429,33 @@ func PropagateType(evalType types.EvalType, args ...Expression) {
}
}
}

// Args2Expressions4Test converts these values to an expression list.
// This conversion is incomplete, so only use for test.
func Args2Expressions4Test(args ...interface{}) []Expression {
exprs := make([]Expression, len(args))
for i, v := range args {
d := types.NewDatum(v)
var ft *types.FieldType
switch d.Kind() {
case types.KindNull:
ft = types.NewFieldType(mysql.TypeNull)
case types.KindInt64:
ft = types.NewFieldType(mysql.TypeLong)
case types.KindUint64:
ft = types.NewFieldType(mysql.TypeLong)
ft.AddFlag(mysql.UnsignedFlag)
case types.KindFloat64:
ft = types.NewFieldType(mysql.TypeDouble)
case types.KindString:
ft = types.NewFieldType(mysql.TypeVarString)
case types.KindMysqlTime:
ft = types.NewFieldType(mysql.TypeTimestamp)
default:
exprs[i] = nil
continue
}
exprs[i] = &Constant{Value: d, RetType: ft}
}
return exprs
}
16 changes: 12 additions & 4 deletions planner/core/memtable_predicate_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ func (helper extractHelper) extractColInConsExpr(extractCols map[int64]*types.Fi
results := make([]types.Datum, 0, len(args[1:]))
for _, arg := range args[1:] {
constant, ok := arg.(*expression.Constant)
if !ok || constant.DeferredExpr != nil || constant.ParamMarker != nil {
if !ok || constant.DeferredExpr != nil {
return "", nil
}
results = append(results, constant.Value)
v := constant.Value
if constant.ParamMarker != nil {
v = constant.ParamMarker.GetUserVar()
}
results = append(results, v)
}
return name.ColName.L, results
}
Expand Down Expand Up @@ -117,10 +121,14 @@ func (helper extractHelper) extractColBinaryOpConsExpr(extractCols map[int64]*ty
// SELECT * FROM t1 WHERE c='rhs'
// SELECT * FROM t1 WHERE 'lhs'=c
constant, ok := args[1-colIdx].(*expression.Constant)
if !ok || constant.DeferredExpr != nil || constant.ParamMarker != nil {
if !ok || constant.DeferredExpr != nil {
return "", nil
}
return name.ColName.L, []types.Datum{constant.Value}
v := constant.Value
if constant.ParamMarker != nil {
v = constant.ParamMarker.GetUserVar()
}
return name.ColName.L, []types.Datum{v}
}

// extract the OR expression, e.g:
Expand Down
113 changes: 113 additions & 0 deletions planner/core/memtable_predicate_extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package core_test

import (
"context"
"fmt"
"regexp"
"sort"
"testing"
Expand All @@ -24,12 +25,16 @@ import (
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/planner"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/hint"
"github.com/pingcap/tidb/util/set"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)

func getLogicalMemTable(t *testing.T, dom *domain.Domain, se session.Session, parser *parser.Parser, sql string) *plannercore.LogicalMemTable {
Expand Down Expand Up @@ -1742,3 +1747,111 @@ func TestTikvRegionStatusExtractor(t *testing.T) {
require.Equal(t, ca.tableIDs, tableids)
}
}

func TestExtractorInPreparedStmt(t *testing.T) {
store, dom, clean := testkit.CreateMockStoreAndDomain(t)
defer clean()
tk := testkit.NewTestKit(t, store)

var cases = []struct {
prepared string
userVars []interface{}
params []interface{}
checker func(extractor plannercore.MemTablePredicateExtractor)
}{
{
prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id = ?",
userVars: []interface{}{1},
params: []interface{}{1},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.TiKVRegionStatusExtractor)
tableids := rse.GetTablesID()
slices.Sort(tableids)
require.Equal(t, []int64{1}, tableids)
},
},
{
prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id = ? or table_id = ?",
userVars: []interface{}{1, 2},
params: []interface{}{1, 2},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.TiKVRegionStatusExtractor)
tableids := rse.GetTablesID()
slices.Sort(tableids)
require.Equal(t, []int64{1, 2}, tableids)
},
},
{
prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id in (?,?)",
userVars: []interface{}{1, 2},
params: []interface{}{1, 2},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.TiKVRegionStatusExtractor)
tableids := rse.GetTablesID()
slices.Sort(tableids)
require.Equal(t, []int64{1, 2}, tableids)
},
},
{
prepared: "select * from information_schema.COLUMNS where table_name like ?",
userVars: []interface{}{`"a%"`},
params: []interface{}{"a%"},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.ColumnsTableExtractor)
require.EqualValues(t, []string{"a%"}, rse.TableNamePatterns)
},
},
{
prepared: "select * from information_schema.tidb_hot_regions_history where update_time>=?",
userVars: []interface{}{"cast('2019-10-10 10:10:10' as datetime)"},
params: []interface{}{func() types.Time {
tk.Session().GetSessionVars().StmtCtx.TimeZone = time.Local
tt, err := types.ParseTimestamp(tk.Session().GetSessionVars().StmtCtx, "2019-10-10 10:10:10")
require.NoError(t, err)
return tt
}()},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.HotRegionsHistoryTableExtractor)
require.Equal(t, timestamp(t, "2019-10-10 10:10:10"), rse.StartTime)
},
},
}

// text protocol
parser := parser.New()
for _, ca := range cases {
tk.MustExec(fmt.Sprintf("prepare stmt from '%s'", ca.prepared))
setStmt := "set "
exec := "execute stmt using "
for i, uv := range ca.userVars {
name := fmt.Sprintf("@a%d", i)
setStmt += fmt.Sprintf("%s=%v", name, uv)
exec += name
if i != len(ca.userVars)-1 {
setStmt += ","
exec += ","
}
}
tk.MustExec(setStmt)
stmt, err := parser.ParseOneStmt(exec, "", "")
require.NoError(t, err)
plan, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), stmt.(*ast.ExecuteStmt), dom.InfoSchema())
require.NoError(t, err)
extractor := plan.(*plannercore.PhysicalMemTable).Extractor
ca.checker(extractor)
}

// binary protocol
for _, ca := range cases {
id, _, _, err := tk.Session().PrepareStmt(ca.prepared)
require.NoError(t, err)
execStmt := &ast.ExecuteStmt{
BinaryArgs: types.MakeDatums(ca.params...),
ExecID: id,
}
plan, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), execStmt, dom.InfoSchema())
require.NoError(t, err)
extractor := plan.(*plannercore.PhysicalMemTable).Extractor
ca.checker(extractor)
}
}

0 comments on commit 8b20dfa

Please sign in to comment.