diff --git a/ddl/db_integration_test.go b/ddl/db_integration_test.go index 442b797a54dbc..cc9cc657fdc6f 100644 --- a/ddl/db_integration_test.go +++ b/ddl/db_integration_test.go @@ -2373,6 +2373,43 @@ func TestSqlFunctionsInGeneratedColumns(t *testing.T) { tk.MustExec("create table t (a int, b int as ((a)))") } +func TestSchemaNameAndTableNameInGeneratedExpr(t *testing.T) { + store := testkit.CreateMockStore(t, mockstore.WithDDLChecker()) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("create database if not exists test") + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + + tk.MustExec("create table t(a int, b int as (lower(test.t.a)))") + tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" + + " `a` int(11) DEFAULT NULL,\n" + + " `b` int(11) GENERATED ALWAYS AS (lower(`a`)) VIRTUAL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + tk.MustExec("drop table t") + tk.MustExec("create table t(a int)") + tk.MustExec("alter table t add column b int as (lower(test.t.a))") + tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" + + " `a` int(11) DEFAULT NULL,\n" + + " `b` int(11) GENERATED ALWAYS AS (lower(`a`)) VIRTUAL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + tk.MustGetErrCode("alter table t add index idx((lower(test.t1.a)))", errno.ErrBadField) + + tk.MustExec("drop table t") + tk.MustGetErrCode("create table t(a int, b int as (lower(test1.t.a)))", errno.ErrWrongDBName) + + tk.MustExec("create table t(a int)") + tk.MustGetErrCode("alter table t add column b int as (lower(test.t1.a))", errno.ErrWrongTableName) + + tk.MustExec("alter table t add column c int") + tk.MustGetErrCode("alter table t modify column c int as (test.t1.a + 1) stored", errno.ErrWrongTableName) + + tk.MustExec("alter table t add column d int as (lower(test.T.a))") + tk.MustExec("alter table t add column e int as (lower(Test.t.a))") +} + func TestParserIssue284(t *testing.T) { store := testkit.CreateMockStore(t, mockstore.WithDDLChecker()) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 2ae8a46134d4b..7cd65b47b170a 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -1082,7 +1082,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o var sb strings.Builder restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation + format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) for _, v := range colDef.Options { @@ -1142,7 +1142,10 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o } col.GeneratedExprString = sb.String() col.GeneratedStored = v.Stored - _, dependColNames := findDependedColumnNames(colDef) + _, dependColNames, err := findDependedColumnNames(model.NewCIStr(""), model.NewCIStr(""), colDef) + if err != nil { + return nil, nil, errors.Trace(err) + } col.Dependences = dependColNames case ast.ColumnOptionCollate: if field_types.HasCharset(colDef.Tp) { @@ -1561,7 +1564,7 @@ func IsAutoRandomColumnID(tblInfo *model.TableInfo, colID int64) bool { return false } -func checkGeneratedColumn(ctx sessionctx.Context, colDefs []*ast.ColumnDef) error { +func checkGeneratedColumn(ctx sessionctx.Context, schemaName model.CIStr, tableName model.CIStr, colDefs []*ast.ColumnDef) error { var colName2Generation = make(map[string]columnGenerationInDDL, len(colDefs)) var exists bool var autoIncrementColumn string @@ -1576,7 +1579,10 @@ func checkGeneratedColumn(ctx sessionctx.Context, colDefs []*ast.ColumnDef) erro if containsColumnOption(colDef, ast.ColumnOptionAutoIncrement) { exists, autoIncrementColumn = true, colDef.Name.Name.L } - generated, depCols := findDependedColumnNames(colDef) + generated, depCols, err := findDependedColumnNames(schemaName, tableName, colDef) + if err != nil { + return errors.Trace(err) + } if !generated { colName2Generation[colDef.Name.Name.L] = columnGenerationInDDL{ position: i, @@ -2094,7 +2100,7 @@ func CheckTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo func checkTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) { // All of these rely on the AST structure of expressions, which were // lost in the model (got serialized into strings). - if err := checkGeneratedColumn(ctx, s.Cols); err != nil { + if err := checkGeneratedColumn(ctx, s.Table.Schema, tbInfo.Name, s.Cols); err != nil { return errors.Trace(err) } @@ -3672,7 +3678,10 @@ func CreateNewColumn(ctx sessionctx.Context, ti ast.Ident, schema *model.DBInfo, return nil, dbterror.ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs("Adding generated stored column through ALTER TABLE") } - _, dependColNames := findDependedColumnNames(specNewColumn) + _, dependColNames, err := findDependedColumnNames(schema.Name, t.Meta().Name, specNewColumn) + if err != nil { + return nil, errors.Trace(err) + } if !ctx.GetSessionVars().EnableAutoIncrementInGenerated { if err := checkAutoIncrementRef(specNewColumn.Name.Name.L, dependColNames, t.Meta()); err != nil { return nil, errors.Trace(err) @@ -4485,7 +4494,7 @@ func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.Col func processColumnOptions(ctx sessionctx.Context, col *table.Column, options []*ast.ColumnOption) error { var sb strings.Builder restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation + format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutSchemaName restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) var hasDefaultValue, setOnUpdateNow bool @@ -4827,7 +4836,7 @@ func GetModifiableColumnJob( } // As same with MySQL, we don't support modifying the stored status for generated columns. - if err = checkModifyGeneratedColumn(sctx, t, col, newCol, specNewColumn, spec.Position); err != nil { + if err = checkModifyGeneratedColumn(sctx, schema.Name, t, col, newCol, specNewColumn, spec.Position); err != nil { return nil, errors.Trace(err) } @@ -6306,7 +6315,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as var sb strings.Builder restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation + format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) sb.Reset() err := idxPart.Expr.Restore(restoreCtx) diff --git a/ddl/generated_column.go b/ddl/generated_column.go index 678d803edf521..6c5e99dab2cd6 100644 --- a/ddl/generated_column.go +++ b/ddl/generated_column.go @@ -122,13 +122,19 @@ func findPositionRelativeColumn(cols []*table.Column, pos *ast.ColumnPosition) ( // findDependedColumnNames returns a set of string, which indicates // the names of the columns that are depended by colDef. -func findDependedColumnNames(colDef *ast.ColumnDef) (generated bool, colsMap map[string]struct{}) { +func findDependedColumnNames(schemaName model.CIStr, tableName model.CIStr, colDef *ast.ColumnDef) (generated bool, colsMap map[string]struct{}, err error) { colsMap = make(map[string]struct{}) for _, option := range colDef.Options { if option.Tp == ast.ColumnOptionGenerated { generated = true colNames := FindColumnNamesInExpr(option.Expr) for _, depCol := range colNames { + if depCol.Schema.L != "" && schemaName.L != "" && depCol.Schema.L != schemaName.L { + return false, nil, dbterror.ErrWrongDBName.GenWithStackByArgs(depCol.Schema.O) + } + if depCol.Table.L != "" && tableName.L != "" && depCol.Table.L != tableName.L { + return false, nil, dbterror.ErrWrongTableName.GenWithStackByArgs(depCol.Table.O) + } colsMap[depCol.Name.L] = struct{}{} } break @@ -192,7 +198,7 @@ func (c *generatedColumnChecker) Leave(inNode ast.Node) (node ast.Node, ok bool) // 3. check if the modified expr contains non-deterministic functions // 4. check whether new column refers to any auto-increment columns. // 5. check if the new column is indexed or stored -func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol, newCol *table.Column, newColDef *ast.ColumnDef, pos *ast.ColumnPosition) error { +func checkModifyGeneratedColumn(sctx sessionctx.Context, schemaName model.CIStr, tbl table.Table, oldCol, newCol *table.Column, newColDef *ast.ColumnDef, pos *ast.ColumnPosition) error { // rule 1. oldColIsStored := !oldCol.IsGenerated() || oldCol.GeneratedStored newColIsStored := !newCol.IsGenerated() || newCol.GeneratedStored @@ -252,7 +258,10 @@ func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol } // rule 4. - _, dependColNames := findDependedColumnNames(newColDef) + _, dependColNames, err := findDependedColumnNames(schemaName, tbl.Meta().Name, newColDef) + if err != nil { + return errors.Trace(err) + } if !sctx.GetSessionVars().EnableAutoIncrementInGenerated { if err := checkAutoIncrementRef(newColDef.Name.Name.L, dependColNames, tbl.Meta()); err != nil { return errors.Trace(err) diff --git a/parser/ast/ddl_test.go b/parser/ast/ddl_test.go index 156a66398426f..dbbb212037db8 100644 --- a/parser/ast/ddl_test.go +++ b/parser/ast/ddl_test.go @@ -248,6 +248,21 @@ func TestDDLColumnOptionRestore(t *testing.T) { runNodeRestoreTest(t, testCases, "CREATE TABLE child (id INT %s)", extractNodeFunc) } +func TestGeneratedRestore(t *testing.T) { + testCases := []NodeRestoreTestCase{ + {"generated always as(id + 1)", "GENERATED ALWAYS AS(`id`+1) VIRTUAL"}, + {"generated always as(id + 1) virtual", "GENERATED ALWAYS AS(`id`+1) VIRTUAL"}, + {"generated always as(id + 1) stored", "GENERATED ALWAYS AS(`id`+1) STORED"}, + {"generated always as(lower(id)) stored", "GENERATED ALWAYS AS(LOWER(`id`)) STORED"}, + {"generated always as(lower(child.id)) stored", "GENERATED ALWAYS AS(LOWER(`id`)) STORED"}, + } + extractNodeFunc := func(node Node) Node { + return node.(*CreateTableStmt).Cols[0].Options[0] + } + runNodeRestoreTestWithFlagsStmtChange(t, testCases, "CREATE TABLE child (id INT %s)", extractNodeFunc, + format.DefaultRestoreFlags|format.RestoreWithoutSchemaName|format.RestoreWithoutTableName) +} + func TestDDLColumnDefRestore(t *testing.T) { testCases := []NodeRestoreTestCase{ // for type diff --git a/parser/ast/dml.go b/parser/ast/dml.go index c711da90d123f..4e97ae8d95882 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -288,15 +288,17 @@ func (*TableName) resultSet() {} // Restore implements Node interface. func (n *TableName) restoreName(ctx *format.RestoreCtx) { - // restore db name - if n.Schema.String() != "" { - ctx.WriteName(n.Schema.String()) - ctx.WritePlain(".") - } else if ctx.DefaultDB != "" { - // Try CTE, for a CTE table name, we shouldn't write the database name. - if !ctx.IsCTETableName(n.Name.L) { - ctx.WriteName(ctx.DefaultDB) + if !ctx.Flags.HasWithoutSchemaNameFlag() { + // restore db name + if n.Schema.String() != "" { + ctx.WriteName(n.Schema.String()) ctx.WritePlain(".") + } else if ctx.DefaultDB != "" { + // Try CTE, for a CTE table name, we shouldn't write the database name. + if !ctx.IsCTETableName(n.Name.L) { + ctx.WriteName(ctx.DefaultDB) + ctx.WritePlain(".") + } } } // restore table name diff --git a/parser/ast/expressions.go b/parser/ast/expressions.go index 270c46218af61..6bca49f4d2a7d 100644 --- a/parser/ast/expressions.go +++ b/parser/ast/expressions.go @@ -512,11 +512,11 @@ type ColumnName struct { // Restore implements Node interface. func (n *ColumnName) Restore(ctx *format.RestoreCtx) error { - if n.Schema.O != "" && !ctx.IsCTETableName(n.Table.L) { + if n.Schema.O != "" && !ctx.IsCTETableName(n.Table.L) && !ctx.Flags.HasWithoutSchemaNameFlag() { ctx.WriteName(n.Schema.O) ctx.WritePlain(".") } - if n.Table.O != "" { + if n.Table.O != "" && !ctx.Flags.HasWithoutTableNameFlag() { ctx.WriteName(n.Table.O) ctx.WritePlain(".") } diff --git a/parser/format/format.go b/parser/format/format.go index 284d4dff4e9df..a60c8d7b6589d 100644 --- a/parser/format/format.go +++ b/parser/format/format.go @@ -236,6 +236,8 @@ const ( RestoreTiDBSpecialComment SkipPlacementRuleForRestore RestoreWithTTLEnableOff + RestoreWithoutSchemaName + RestoreWithoutTableName ) const ( @@ -247,6 +249,16 @@ func (rf RestoreFlags) has(flag RestoreFlags) bool { return rf&flag != 0 } +// HasWithoutSchemaNameFlag returns a boolean indicating when `rf` has `RestoreWithoutSchemaName` flag. +func (rf RestoreFlags) HasWithoutSchemaNameFlag() bool { + return rf.has(RestoreWithoutSchemaName) +} + +// HasWithoutTableNameFlag returns a boolean indicating when `rf` has `RestoreWithoutTableName` flag. +func (rf RestoreFlags) HasWithoutTableNameFlag() bool { + return rf.has(RestoreWithoutTableName) +} + // HasStringSingleQuotesFlag returns a boolean indicating when `rf` has `RestoreStringSingleQuotes` flag. func (rf RestoreFlags) HasStringSingleQuotesFlag() bool { return rf.has(RestoreStringSingleQuotes)