diff --git a/migrator.go b/migrator.go index 30f5e97..f64efc9 100644 --- a/migrator.go +++ b/migrator.go @@ -9,6 +9,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" ) type Migrator struct { @@ -19,12 +20,46 @@ func (m Migrator) GetTables() (tableList []string, err error) { return tableList, m.DB.Raw("SELECT table_name FROM INFORMATION_SCHEMA.tables WHERE table_catalog = ?", m.CurrentDatabase()).Scan(&tableList).Error } +func getTableSchemaName(schema *schema.Schema) string { + //return the schema name if it is explicitly provided in the table name + //otherwise return a sql wildcard -> use any table_schema + if schema == nil || !strings.Contains(schema.Table, ".") { + return "" + } + _, schemaName, _ := splitFullQualifiedName(schema.Table) + return schemaName +} + +func splitFullQualifiedName(name string) (string, string, string) { + nameParts := strings.Split(name, ".") + if len(nameParts) == 1 { //[table_name] + return "", "", nameParts[0] + } else if len(nameParts) == 2 { //[table_schema].[table_name] + return "", nameParts[0], nameParts[1] + } else if len(nameParts) == 3 { //[table_catalog].[table_schema].[table_name] + return nameParts[0], nameParts[1], nameParts[2] + } + return "", "", "" +} + +func getFullQualifiedTableName(stmt *gorm.Statement) string { + fullQualifiedTableName := stmt.Table + if schemaName := getTableSchemaName(stmt.Schema); schemaName != "" { + fullQualifiedTableName = schemaName + "." + fullQualifiedTableName + } + return fullQualifiedTableName +} + func (m Migrator) HasTable(value interface{}) bool { var count int m.RunWithValue(value, func(stmt *gorm.Statement) error { + schemaName := getTableSchemaName(stmt.Schema) + if schemaName == "" { + schemaName = "%" + } return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", - stmt.Table, m.CurrentDatabase(), + "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ? and table_schema like ? AND table_type = ?", + stmt.Table, m.CurrentDatabase(), schemaName, "BASE TABLE", ).Row().Scan(&count) }) return count > 0 @@ -40,7 +75,7 @@ func (m Migrator) DropTable(values ...interface{}) error { Parent string } var constraints []constraint - err := tx.Raw("SELECT name, OBJECT_NAME(parent_object_id) as parent FROM sys.foreign_keys WHERE referenced_object_id = object_id(?)", stmt.Table).Scan(&constraints).Error + err := tx.Raw("SELECT name, OBJECT_NAME(parent_object_id) as parent FROM sys.foreign_keys WHERE referenced_object_id = object_id(?)", getFullQualifiedTableName(stmt)).Scan(&constraints).Error for _, c := range constraints { if err == nil { @@ -150,7 +185,7 @@ var defaultValueTrimRegexp = regexp.MustCompile("^\\('?([^']*)'?\\)$") func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { - rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() + rows, err := m.DB.Session(&gorm.Session{}).Table(getFullQualifiedTableName(stmt)).Limit(1).Rows() if err != nil { return err } @@ -259,7 +294,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { return m.DB.Raw( "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", - name, stmt.Table, + name, getFullQualifiedTableName(stmt), ).Row().Scan(&count) }) return count > 0 @@ -285,9 +320,17 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { name = chk.Name } + tableCatalog, schema, tableName := splitFullQualifiedName(table) + if tableCatalog == "" { + tableCatalog = m.CurrentDatabase() + } + if schema == "" { + schema = "%" + } + return m.DB.Raw( - `SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND T.Name = ? AND I.TABLE_CATALOG = ?;`, - name, table, m.CurrentDatabase(), + `SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND I.TABLE_NAME = ? AND I.TABLE_SCHEMA like ? AND I.TABLE_CATALOG = ?;`, + name, tableName, schema, tableCatalog, ).Row().Scan(&count) }) return count > 0 @@ -297,3 +340,8 @@ func (m Migrator) CurrentDatabase() (name string) { m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name) return } + +func (m Migrator) DefaultSchema() (name string) { + m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name) + return +} diff --git a/migrator_test.go b/migrator_test.go new file mode 100644 index 0000000..3daeedb --- /dev/null +++ b/migrator_test.go @@ -0,0 +1,98 @@ +package sqlserver_test + +import ( + "os" + "testing" + + "gorm.io/driver/sqlserver" + "gorm.io/gorm" +) + +var sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + +func init() { + if dbDSN := os.Getenv("GORM_DSN"); dbDSN != "" { + sqlserverDSN = dbDSN + } +} + +type Testtable struct { + Test uint64 `gorm:"index"` +} + +type Testtable2 struct { + Test uint64 `gorm:"index"` + Test2 uint64 +} + +func (*Testtable2) TableName() string { return "testtables" } + +type Testtable3 struct { + Test3 uint64 +} + +func (*Testtable3) TableName() string { return "testschema1.Testtables" } + +type Testtable4 struct { + Test4 uint64 +} + +func (*Testtable4) TableName() string { return "testschema2.Testtables" } + +type Testtable5 struct { + Test4 uint64 + Test5 uint64 `gorm:"index"` +} + +func (*Testtable5) TableName() string { return "testschema2.Testtables" } + +func TestAutomigrateTablesWithoutDefaultSchema(t *testing.T) { + db, err := gorm.Open(sqlserver.Open(sqlserverDSN)) + if err != nil { + t.Error(err) + } + + if tx := db.Exec("create schema testschema1"); tx.Error != nil { + t.Error("couldn't create schema testschema1", tx.Error) + } + if tx := db.Exec("create schema testschema2"); tx.Error != nil { + t.Error("couldn't create schema testschema2", tx.Error) + } + + if err = db.AutoMigrate(&Testtable{}); err != nil { + t.Error("couldn't create a table at user default schema", err) + } + if err = db.AutoMigrate(&Testtable2{}); err != nil { + t.Error("couldn't update a table at user default schema", err) + } + if err = db.AutoMigrate(&Testtable3{}); err != nil { + t.Error("couldn't create a table at schema testschema1", err) + } + if err = db.AutoMigrate(&Testtable4{}); err != nil { + t.Error("couldn't create a table at schema testschema2", err) + } + if err = db.AutoMigrate(&Testtable5{}); err != nil { + t.Error("couldn't update a table at schema testschema2", err) + } + + if tx := db.Exec("drop table testtables"); tx.Error != nil { + t.Error("couldn't drop table testtable at user default schema", tx.Error) + } + + if tx := db.Exec("drop table testschema1.testtables"); tx.Error != nil { + t.Error("couldn't drop table testschema1.testtable", tx.Error) + } + + if tx := db.Exec("drop table testschema2.testtables"); tx.Error != nil { + t.Error("couldn't drop table testschema2.testtable", tx.Error) + } + + if tx := db.Exec("drop schema testschema1"); tx.Error != nil { + t.Error("couldn't drop schema testschema1", tx.Error) + } + + if tx := db.Exec("drop schema testschema2"); tx.Error != nil { + t.Error("couldn't drop schema testschema2", tx.Error) + } + +}