Skip to content

Commit

Permalink
impl cast as array
Browse files Browse the repository at this point in the history
Signed-off-by: xiongjiwei <[email protected]>
  • Loading branch information
xiongjiwei committed Dec 23, 2022
1 parent 74a2864 commit 7e67a11
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 60 deletions.
15 changes: 15 additions & 0 deletions errors.toml
Original file line number Diff line number Diff line change
Expand Up @@ -1696,6 +1696,21 @@ error = '''
Invalid data type for JSON data in argument %d to function %s; a JSON string or JSON type is required.
'''

["expression:3752"]
error = '''
Value is out of range for expression index '%s' at row %d
'''

["expression:3903"]
error = '''
Invalid JSON value for CAST for expression index '%s'
'''

["expression:3907"]
error = '''
Data too long for expression index '%s'
'''

["expression:8128"]
error = '''
Invalid TABLESAMPLE: %s
Expand Down
34 changes: 30 additions & 4 deletions executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func (e *InsertValues) evalRow(ctx context.Context, list []expression.Expression
e.evalBuffer.SetDatum(offset, val1)
}
// Row may lack of generated column, autoIncrement column, empty column here.
return e.fillRow(ctx, row, hasValue)
return e.fillRow(ctx, row, hasValue, rowIdx)
}

var emptyRow chunk.Row
Expand Down Expand Up @@ -422,7 +422,7 @@ func (e *InsertValues) fastEvalRow(ctx context.Context, list []expression.Expres
offset := e.insertColumns[i].Offset
row[offset], hasValue[offset] = val1, true
}
return e.fillRow(ctx, row, hasValue)
return e.fillRow(ctx, row, hasValue, rowIdx)
}

// setValueForRefColumn set some default values for the row to eval the row value with other columns,
Expand Down Expand Up @@ -562,7 +562,7 @@ func (e *InsertValues) getRow(ctx context.Context, vals []types.Datum) ([]types.
hasValue[offset] = true
}

return e.fillRow(ctx, row, hasValue)
return e.fillRow(ctx, row, hasValue, 0)
}

// getColDefaultValue gets the column default value.
Expand Down Expand Up @@ -647,7 +647,7 @@ func (e *InsertValues) fillColValue(ctx context.Context, datum types.Datum, idx
// `insert|replace values` can guarantee consecutive autoID in a batch.
// Other statements like `insert select from` don't guarantee consecutive autoID.
// https://dev.mysql.com/doc/refman/8.0/en/innodb-auto-increment-handling.html
func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue []bool) ([]types.Datum, error) {
func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue []bool, rowIdx int) ([]types.Datum, error) {
gCols := make([]*table.Column, 0)
tCols := e.Table.Cols()
if e.hasExtraHandle {
Expand Down Expand Up @@ -693,6 +693,9 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue
for i, gCol := range gCols {
colIdx := gCol.ColumnInfo.Offset
val, err := e.GenExprs[i].Eval(chunk.MutRowFromDatums(row).ToRow())
if err != nil && gCol.FieldType.IsArray() {
return nil, completeError(tbl, gCol.Offset, rowIdx, err)
}
if e.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) != nil {
return nil, err
}
Expand All @@ -708,6 +711,29 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue
return row, nil
}

func completeError(tbl *model.TableInfo, offset int, rowIdx int, err error) error {
name := "expression_index"
for _, idx := range tbl.Indices {
for _, column := range idx.Columns {
if column.Offset == offset {
name = idx.Name.O
break
}
}
}

if expression.ErrInvalidJSONForFuncIndex.Equal(err) {
return expression.ErrInvalidJSONForFuncIndex.GenWithStackByArgs(name)
}
if types.ErrOverflow.Equal(err) {
return expression.ErrDataOutOfRangeFuncIndex.GenWithStackByArgs(name, rowIdx+1)
}
if types.ErrDataTooLong.Equal(err) {
return expression.ErrFuncIndexDataIsTooLong.GenWithStackByArgs(name)
}
return err
}

// isAutoNull can help judge whether a datum is AutoIncrement Null quickly.
// This used to help lazyFillAutoIncrement to find consecutive N datum backwards for batch autoID alloc.
func (e *InsertValues) isAutoNull(ctx context.Context, d types.Datum, col *table.Column) bool {
Expand Down
87 changes: 53 additions & 34 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -469,13 +470,23 @@ func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJS

arrayVals := make([]any, 0, len(b.args))
f := convertJSON2Tp(b.tp.ArrayType())
originVal := b.ctx.GetSessionVars().StmtCtx.OverflowAsWarning
b.ctx.GetSessionVars().StmtCtx.OverflowAsWarning = false
if f == nil {
return types.BinaryJSON{}, false, ErrNotSupportedYet.GenWithStackByArgs("CAS-ing JSON to the target type")
}
sc := b.ctx.GetSessionVars().StmtCtx
originalOverflowAsWarning := sc.OverflowAsWarning
originIgnoreTruncate := sc.IgnoreTruncate
originTruncateAsWarning := sc.TruncateAsWarning
sc.OverflowAsWarning = false
sc.IgnoreTruncate = false
sc.TruncateAsWarning = false
defer func() {
b.ctx.GetSessionVars().StmtCtx.OverflowAsWarning = originVal
sc.OverflowAsWarning = originalOverflowAsWarning
sc.IgnoreTruncate = originIgnoreTruncate
sc.TruncateAsWarning = originTruncateAsWarning
}()
for i := 0; i < val.GetElemCount(); i++ {
item, err := f(b, val.ArrayGetElem(i))
item, err := f(sc, val.ArrayGetElem(i))
if err != nil {
return types.BinaryJSON{}, false, err
}
Expand All @@ -484,43 +495,51 @@ func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJS
return types.CreateBinaryJSON(arrayVals), false, nil
}

func convertJSON2Tp(tp *types.FieldType) func(*castJSONAsArrayFunctionSig, types.BinaryJSON) (any, error) {
func convertJSON2Tp(tp *types.FieldType) func(*stmtctx.StatementContext, types.BinaryJSON) (any, error) {
switch tp.EvalType() {
case types.ETString:
return func(b *castJSONAsArrayFunctionSig, item types.BinaryJSON) (any, error) {
return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) {
if item.TypeCode != types.JSONTypeCodeString {
return nil, errIncorrectArgs
return nil, ErrInvalidJSONForFuncIndex
}
return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, b.ctx.GetSessionVars().StmtCtx, false)
return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, sc, false)
}
default:
return func(b *castJSONAsArrayFunctionSig, item types.BinaryJSON) (any, error) {
switch tp.EvalType() {
case types.ETInt:
if item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 {
return nil, errIncorrectArgs
}
case types.ETReal, types.ETDecimal:
if item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 && item.TypeCode != types.JSONTypeCodeFloat64 {
return nil, errIncorrectArgs
}
case types.ETDatetime:
if item.TypeCode != types.JSONTypeCodeDatetime {
return nil, errIncorrectArgs
}
case types.ETTimestamp:
if item.TypeCode != types.JSONTypeCodeTimestamp {
return nil, errIncorrectArgs
}
case types.ETDuration:
if item.TypeCode != types.JSONTypeCodeDate {
return nil, errIncorrectArgs
}
case types.ETInt:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) {
if item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 {
return nil, ErrInvalidJSONForFuncIndex
}
d := types.NewJSONDatum(item)
to, err := d.ConvertTo(b.ctx.GetSessionVars().StmtCtx, tp)
return to.GetValue(), err
return types.ConvertJSONToInt(sc, item, mysql.HasUnsignedFlag(tp.GetFlag()), tp.GetType())
}
case types.ETReal, types.ETDecimal:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) {
if item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 && item.TypeCode != types.JSONTypeCodeFloat64 {
return nil, ErrInvalidJSONForFuncIndex
}
return types.ConvertJSONToFloat(sc, item)
}
case types.ETDatetime:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) {
if (tp.GetType() == mysql.TypeDatetime && item.TypeCode != types.JSONTypeCodeDatetime) || (tp.GetType() == mysql.TypeDate && item.TypeCode != types.JSONTypeCodeDate) {
return nil, ErrInvalidJSONForFuncIndex
}
res := item.GetTime()
res.SetType(tp.GetType())
if tp.GetType() == mysql.TypeDate {
// Truncate hh:mm:ss part if the type is Date.
res.SetCoreTime(types.FromDate(res.Year(), res.Month(), res.Day(), 0, 0, 0, 0))
}
return res, nil
}
case types.ETDuration:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) {
if item.TypeCode != types.JSONTypeCodeDuration {
return nil, ErrInvalidJSONForFuncIndex
}
return item.GetDuration(), nil
}
default:
return nil
}
}

Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1661,7 +1661,7 @@ func TestCastArrayFunc(t *testing.T) {
[]interface{}{int64(-1), 2.1, int64(3)},
[]interface{}{int64(-1), 2.1, int64(3)},
types.NewFieldTypeBuilder().SetType(mysql.TypeDouble).SetCharset(charset.CharsetBin).SetCollate(charset.CharsetBin).SetArray(true).BuildP(),
false,
true,
true,
},
}
Expand Down
3 changes: 3 additions & 0 deletions expression/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ var (
ErrInternal = dbterror.ClassOptimizer.NewStd(mysql.ErrInternal)
ErrNoDB = dbterror.ClassOptimizer.NewStd(mysql.ErrNoDB)
ErrNotSupportedYet = dbterror.ClassExpression.NewStd(mysql.ErrNotSupportedYet)
ErrInvalidJSONForFuncIndex = dbterror.ClassExpression.NewStd(mysql.ErrInvalidJSONValueForFuncIndex)
ErrDataOutOfRangeFuncIndex = dbterror.ClassExpression.NewStd(mysql.ErrDataOutOfRangeFunctionalIndex)
ErrFuncIndexDataIsTooLong = dbterror.ClassExpression.NewStd(mysql.ErrFunctionalIndexDataIsTooLong)

// All the un-exported errors are defined here:
errFunctionNotExists = dbterror.ClassExpression.NewStd(mysql.ErrSpDoesNotExist)
Expand Down
Loading

0 comments on commit 7e67a11

Please sign in to comment.