Skip to content

Commit

Permalink
stmtctx, *: change TypeCtx field to a private field (pingcap#47742)
Browse files Browse the repository at this point in the history
  • Loading branch information
YangKeao authored Oct 19, 2023
1 parent 6166594 commit 90bd2dd
Show file tree
Hide file tree
Showing 38 changed files with 144 additions and 127 deletions.
2 changes: 1 addition & 1 deletion pkg/ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
return str, false, err
}
// For other kind of fields (e.g. INT), we supply its integer as string value.
value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx)
value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx())
if err != nil {
return nil, false, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue
if err != nil && gCol.FieldType.IsArray() {
return nil, completeError(tbl, gCol.Offset, rowIdx, err)
}
if e.Ctx().GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) != nil {
if e.Ctx().GetSessionVars().StmtCtx.HandleTruncate(err) != nil {
return nil, err
}
row[colIdx], err = table.CastValue(e.Ctx(), val, gCol.ToInfo(), false, false)
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/aggregation/avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func (af *avgFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEv
func (af *avgFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row chunk.Row) (err error) {
switch af.Mode {
case Partial1Mode, CompleteMode:
err = af.updateSum(sc.TypeCtx, evalCtx, row)
err = af.updateSum(sc.TypeCtx(), evalCtx, row)
case Partial2Mode, FinalMode:
err = af.updateAvg(sc.TypeCtx, evalCtx, row)
err = af.updateAvg(sc.TypeCtx(), evalCtx, row)
case DedupMode:
panic("DedupMode is not supported now.")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/aggregation/sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type sumFunction struct {

// Update implements Aggregation interface.
func (sf *sumFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row chunk.Row) error {
return sf.updateSum(sc.TypeCtx, evalCtx, row)
return sf.updateSum(sc.TypeCtx(), evalCtx, row)
}

// GetResult implements Aggregation interface.
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ func (s *builtinArithmeticDivideDecimalSig) evalDecimal(row chunk.Row) (*types.M
return c, true, handleDivisionByZeroError(s.ctx)
} else if err == types.ErrTruncated {
sc := s.ctx.GetSessionVars().StmtCtx
err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
} else if err == nil {
_, frac := c.PrecisionAndFrac()
if frac < s.baseBuiltinFunc.tp.GetDecimal() {
Expand Down Expand Up @@ -846,7 +846,7 @@ func (s *builtinArithmeticIntDivideDecimalSig) evalInt(row chunk.Row) (ret int64
return 0, true, handleDivisionByZeroError(s.ctx)
}
if err == types.ErrTruncated {
err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
}
if err == types.ErrOverflow {
newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_arithmetic_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (b *builtinArithmeticDivideDecimalSig) vecEvalDecimal(input *chunk.Chunk, r
result.SetNull(i, true)
continue
} else if err == types.ErrTruncated {
if err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", to)); err != nil {
if err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", to)); err != nil {
return err
}
} else if err == nil {
Expand Down Expand Up @@ -617,7 +617,7 @@ func (b *builtinArithmeticIntDivideDecimalSig) vecEvalInt(input *chunk.Chunk, re
continue
}
if err == types.ErrTruncated {
err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
} else if err == types.ErrOverflow {
newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)
err = sc.HandleOverflow(newErr, newErr)
Expand Down
44 changes: 22 additions & 22 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ func convertJSON2Tp(evalType types.EvalType) func(*stmtctx.StatementContext, typ
if item.TypeCode != types.JSONTypeCodeString {
return nil, ErrInvalidJSONForFuncIndex
}
return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, sc.TypeCtx, false)
return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, sc.TypeCtx(), false)
}
case types.ETInt:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) {
Expand All @@ -552,7 +552,7 @@ func convertJSON2Tp(evalType types.EvalType) func(*stmtctx.StatementContext, typ
if item.TypeCode != types.JSONTypeCodeFloat64 && item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 {
return nil, ErrInvalidJSONForFuncIndex
}
return types.ConvertJSONToFloat(sc.TypeCtx, item)
return types.ConvertJSONToFloat(sc.TypeCtx(), item)
}
case types.ETDatetime:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) {
Expand Down Expand Up @@ -730,7 +730,7 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul
if tp.GetType() == mysql.TypeYear && res == "0" {
res = "0000"
}
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -790,7 +790,7 @@ func (b *builtinCastIntAsDurationSig) evalDuration(row chunk.Row) (res types.Dur
err = b.ctx.GetSessionVars().StmtCtx.HandleOverflow(err, err)
}
if types.ErrTruncatedWrongVal.Equal(err) {
err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err)
err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err)
}
return res, true, err
}
Expand Down Expand Up @@ -1045,7 +1045,7 @@ func (b *builtinCastRealAsStringSig) evalString(row chunk.Row) (res string, isNu
// If we strconv.FormatFloat the value with 64bits, the result is incorrect!
bits = 32
}
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1102,7 +1102,7 @@ func (b *builtinCastRealAsDurationSig) evalDuration(row chunk.Row) (res types.Du
res, _, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, strconv.FormatFloat(val, 'f', -1, 64), b.tp.GetDecimal())
if err != nil {
if types.ErrTruncatedWrongVal.Equal(err) {
err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err)
err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err)
// ErrTruncatedWrongVal needs to be considered NULL.
return res, true, err
}
Expand Down Expand Up @@ -1191,7 +1191,7 @@ func (b *builtinCastDecimalAsStringSig) evalString(row chunk.Row) (res string, i
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1279,7 +1279,7 @@ func (b *builtinCastDecimalAsDurationSig) evalDuration(row chunk.Row) (res types
}
res, _, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, string(val.ToString()), b.tp.GetDecimal())
if types.ErrTruncatedWrongVal.Equal(err) {
err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err)
err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err)
// ErrTruncatedWrongVal needs to be considered NULL.
return res, true, err
}
Expand All @@ -1301,7 +1301,7 @@ func (b *builtinCastStringAsStringSig) evalString(row chunk.Row) (res string, is
if isNull || err != nil {
return res, isNull, err
}
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1366,7 +1366,7 @@ func (b *builtinCastStringAsIntSig) evalInt(row chunk.Row) (res int64, isNull bo
var ures uint64
sc := b.ctx.GetSessionVars().StmtCtx
if !isNegative {
ures, err = types.StrToUint(sc.TypeCtx, val, true)
ures, err = types.StrToUint(sc.TypeCtx(), val, true)
res = int64(ures)

if err == nil && !mysql.HasUnsignedFlag(b.tp.GetFlag()) && ures > uint64(math.MaxInt64) {
Expand All @@ -1375,7 +1375,7 @@ func (b *builtinCastStringAsIntSig) evalInt(row chunk.Row) (res int64, isNull bo
} else if b.inUnion && mysql.HasUnsignedFlag(b.tp.GetFlag()) {
res = 0
} else {
res, err = types.StrToInt(sc.TypeCtx, val, true)
res, err = types.StrToInt(sc.TypeCtx(), val, true)
if err == nil && mysql.HasUnsignedFlag(b.tp.GetFlag()) {
// If overflow, don't append this warnings
sc.AppendWarning(types.ErrCastNegIntAsUnsigned)
Expand Down Expand Up @@ -1411,7 +1411,7 @@ func (b *builtinCastStringAsRealSig) evalReal(row chunk.Row) (res float64, isNul
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.StrToFloat(sc.TypeCtx, val, true)
res, err = types.StrToFloat(sc.TypeCtx(), val, true)
if err != nil {
return 0, false, err
}
Expand Down Expand Up @@ -1449,7 +1449,7 @@ func (b *builtinCastStringAsDecimalSig) evalDecimal(row chunk.Row) (res *types.M
if err == types.ErrTruncated {
err = types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", []byte(val))
}
err = sc.TypeCtx.HandleTruncate(err)
err = sc.HandleTruncate(err)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1506,7 +1506,7 @@ func (b *builtinCastStringAsDurationSig) evalDuration(row chunk.Row) (res types.
res, isNull, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, val, b.tp.GetDecimal())
if types.ErrTruncatedWrongVal.Equal(err) {
sc := b.ctx.GetSessionVars().StmtCtx
err = sc.TypeCtx.HandleTruncate(err)
err = sc.HandleTruncate(err)
}
return res, isNull, err
}
Expand Down Expand Up @@ -1619,7 +1619,7 @@ func (b *builtinCastTimeAsStringSig) evalString(row chunk.Row) (res string, isNu
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1752,7 +1752,7 @@ func (b *builtinCastDurationAsStringSig) evalString(row chunk.Row) (res string,
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1854,7 +1854,7 @@ func (b *builtinCastJSONAsRealSig) evalReal(row chunk.Row) (res float64, isNull
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ConvertJSONToFloat(sc.TypeCtx, val)
res, err = types.ConvertJSONToFloat(sc.TypeCtx(), val)
return
}

Expand All @@ -1874,7 +1874,7 @@ func (b *builtinCastJSONAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyD
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ConvertJSONToDecimal(sc.TypeCtx, val)
res, err = types.ConvertJSONToDecimal(sc.TypeCtx(), val)
if err != nil {
return res, false, err
}
Expand All @@ -1897,7 +1897,7 @@ func (b *builtinCastJSONAsStringSig) evalString(row chunk.Row) (res string, isNu
if isNull || err != nil {
return res, isNull, err
}
s, err := types.ProduceStrWithSpecifiedTp(val.String(), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false)
s, err := types.ProduceStrWithSpecifiedTp(val.String(), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1960,7 +1960,7 @@ func (b *builtinCastJSONAsTimeSig) evalTime(row chunk.Row) (res types.Time, isNu
return res, isNull, err
default:
err = types.ErrTruncatedWrongVal.GenWithStackByArgs(types.TypeStr(b.tp.GetType()), val.String())
return res, true, b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err)
return res, true, b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err)
}
}

Expand Down Expand Up @@ -2002,12 +2002,12 @@ func (b *builtinCastJSONAsDurationSig) evalDuration(row chunk.Row) (res types.Du
res, _, err = types.ParseDuration(stmtCtx, s, b.tp.GetDecimal())
if types.ErrTruncatedWrongVal.Equal(err) {
sc := b.ctx.GetSessionVars().StmtCtx
err = sc.TypeCtx.HandleTruncate(err)
err = sc.HandleTruncate(err)
}
return res, isNull, err
default:
err = types.ErrTruncatedWrongVal.GenWithStackByArgs("TIME", val.String())
return res, true, stmtCtx.TypeCtx.HandleTruncate(err)
return res, true, stmtCtx.HandleTruncate(err)
}
}

Expand Down
Loading

0 comments on commit 90bd2dd

Please sign in to comment.