diff --git a/dialect/pgdialect/alter_table.go b/dialect/pgdialect/alter_table.go index af103fe86..192d9138f 100644 --- a/dialect/pgdialect/alter_table.go +++ b/dialect/pgdialect/alter_table.go @@ -18,6 +18,8 @@ type Migrator struct { db *bun.DB } +var _ sqlschema.Migrator = (*Migrator)(nil) + func (m *Migrator) RenameTable(ctx context.Context, oldName, newName string) error { query := fmt.Sprintf("ALTER TABLE %s RENAME TO %s", oldName, newName) _, err := m.db.ExecContext(ctx, query) @@ -26,3 +28,28 @@ func (m *Migrator) RenameTable(ctx context.Context, oldName, newName string) err } return nil } + +func (m *Migrator) AddContraint(ctx context.Context, fk sqlschema.FK, name string) error { + q := m.db.NewRaw( + "ALTER TABLE ?.? ADD CONSTRAINT ? FOREIGN KEY (?) REFERENCES ?.? (?)", + bun.Safe(fk.From.Schema), bun.Safe(fk.From.Table), bun.Safe(name), + bun.Safe(fk.From.Column.String()), + bun.Safe(fk.To.Schema), bun.Safe(fk.To.Table), + bun.Safe(fk.To.Column.String()), + ) + if _, err := q.Exec(ctx); err != nil { + return err + } + return nil +} + +func (m *Migrator) DropContraint(ctx context.Context, schema, table, name string) error { + q := m.db.NewRaw( + "ALTER TABLE ?.? DROP CONSTRAINT ?", + bun.Safe(schema), bun.Safe(table), bun.Safe(name), + ) + if _, err := q.Exec(ctx); err != nil { + return err + } + return nil +} diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index c3ad08565..ddc9d70a5 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -1784,7 +1784,7 @@ func mustResetModel(tb testing.TB, ctx context.Context, db *bun.DB, models ...in func mustDropTableOnCleanup(tb testing.TB, ctx context.Context, db *bun.DB, models ...interface{}) { tb.Cleanup(func() { for _, model := range models { - drop := db.NewDropTable().IfExists().Model(model) + drop := db.NewDropTable().IfExists().Cascade().Model(model) _, err := drop.Exec(ctx) require.NoError(tb, err, "must drop table: %q", drop.GetTableName()) } diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index d1da1ea34..b721ab644 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -170,6 +170,7 @@ func TestAutoMigrator_Run(t *testing.T) { }{ {testRenameTable}, {testCreateDropTable}, + {testAlterForeignKeys}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -260,6 +261,87 @@ func testCreateDropTable(t *testing.T, db *bun.DB) { require.Equal(t, "createme", tables[0].Name) } +func testAlterForeignKeys(t *testing.T, db *bun.DB) { + // Initial state -- each thing has one owner + type OwnerExclusive struct { + bun.BaseModel `bun:"owners"` + ID int64 `bun:",pk"` + } + + type ThingExclusive struct { + bun.BaseModel `bun:"things"` + ID int64 `bun:",pk"` + OwnerID int64 `bun:",notnull"` + + Owner *OwnerExclusive `bun:"rel:belongs-to,join:owner_id=id"` + } + + // Change -- each thing has multiple owners + + type ThingCommon struct { + bun.BaseModel `bun:"things"` + ID int64 `bun:",pk"` + } + + type OwnerCommon struct { + bun.BaseModel `bun:"owners"` + ID int64 `bun:",pk"` + Things []*ThingCommon `bun:"m2m:things_to_owners,join:Owner=Thing"` + } + + type ThingsToOwner struct { + OwnerID int64 `bun:",notnull"` + Owner *OwnerCommon `bun:"rel:belongs-to,join:owner_id=id"` + ThingID int64 `bun:",notnull"` + Thing *ThingCommon `bun:"rel:belongs-to,join:thing_id=id"` + } + + // Arrange + ctx := context.Background() + dbInspector, err := sqlschema.NewInspector(db) + if err != nil { + t.Skip(err) + } + db.RegisterModel((*ThingsToOwner)(nil)) + + mustCreateTableWithFKs(t, ctx, db, + (*OwnerExclusive)(nil), + (*ThingExclusive)(nil), + ) + mustDropTableOnCleanup(t, ctx, db, (*ThingsToOwner)(nil)) + + m, err := migrate.NewAutoMigrator(db, + migrate.WithTableNameAuto(migrationsTable), + migrate.WithLocksTableNameAuto(migrationLocksTable), + migrate.WithModel((*ThingCommon)(nil)), + migrate.WithModel((*OwnerCommon)(nil)), + migrate.WithModel((*ThingsToOwner)(nil)), + ) + require.NoError(t, err) + + // Act + err = m.Run(ctx) + require.NoError(t, err) + + // Assert + state, err := dbInspector.Inspect(ctx) + require.NoError(t, err) + + defaultSchema := db.Dialect().DefaultSchema() + require.Contains(t, state.FKs, sqlschema.FK{ + From: sqlschema.C(defaultSchema, "things_to_owners", "owner_id"), + To: sqlschema.C(defaultSchema, "owners", "id"), + }) + require.Contains(t, state.FKs, sqlschema.FK{ + From: sqlschema.C(defaultSchema, "things_to_owners", "thing_id"), + To: sqlschema.C(defaultSchema, "things", "id"), + }) + require.NotContains(t, state.FKs, sqlschema.FK{ + From: sqlschema.C(defaultSchema, "things", "owner_id"), + To: sqlschema.C(defaultSchema, "owners", "id"), + }) +} + func TestDetector_Diff(t *testing.T) { type Journal struct { ISBN string `bun:"isbn,pk"` @@ -419,16 +501,20 @@ func TestDetector_Diff(t *testing.T) { }, want: []migrate.Operation{ &migrate.AddForeignKey{ + SourceSchema: dialect.DefaultSchema(), SourceTable: "users", SourceColumns: []string{"pet_kind", "pet_name"}, + TargetSchema: dialect.DefaultSchema(), TargetTable: "pets", - TargetColums: []string{"kind", "nickname"}, + TargetColumns: []string{"kind", "nickname"}, }, &migrate.AddForeignKey{ + SourceSchema: dialect.DefaultSchema(), SourceTable: "users", SourceColumns: []string{"friend"}, + TargetSchema: dialect.DefaultSchema(), TargetTable: "users", - TargetColums: []string{"username"}, + TargetColumns: []string{"username"}, }, }, }, @@ -447,10 +533,12 @@ func TestDetector_Diff(t *testing.T) { Model: &Owner{}, }, &migrate.AddForeignKey{ + SourceSchema: dialect.DefaultSchema(), SourceTable: "things", SourceColumns: []string{"owner_id"}, + TargetSchema: dialect.DefaultSchema(), TargetTable: "owners", - TargetColums: []string{"id"}, + TargetColumns: []string{"id"}, }, }, }, diff --git a/migrate/auto.go b/migrate/auto.go index 677ce6787..37b295382 100644 --- a/migrate/auto.go +++ b/migrate/auto.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/uptrace/bun" "github.com/uptrace/bun/migrate/sqlschema" @@ -172,7 +173,7 @@ func Diff(got, want sqlschema.State) Changeset { } type detector struct { - changes Changeset + changes Changeset } func newDetector() *detector { @@ -225,10 +226,12 @@ AddedLoop: for fk /*, fkName */ := range want.FKs { if _, ok := got.FKs[fk]; !ok { d.changes.Add(&AddForeignKey{ + SourceSchema: fk.From.Schema, SourceTable: fk.From.Table, SourceColumns: fk.From.Column.Split(), + TargetSchema: fk.To.Schema, TargetTable: fk.To.Table, - TargetColums: fk.To.Column.Split(), + TargetColumns: fk.To.Column.Split(), }) } } @@ -420,27 +423,34 @@ func trimSchema(name string) string { } type AddForeignKey struct { + SourceSchema string SourceTable string SourceColumns []string + TargetSchema string TargetTable string - TargetColums []string + TargetColumns []string } var _ Operation = (*AddForeignKey)(nil) func (op AddForeignKey) String() string { - return fmt.Sprintf("AddForeignKey %s(%s) references %s(%s)", - op.SourceTable, strings.Join(op.SourceColumns, ","), - op.TargetTable, strings.Join(op.TargetColums, ","), + return fmt.Sprintf("AddForeignKey %s.%s(%s) references %s.%s(%s)", + op.SourceSchema, op.SourceTable, strings.Join(op.SourceColumns, ","), + op.SourceTable, op.TargetTable, strings.Join(op.TargetColumns, ","), ) } func (op *AddForeignKey) Func(m sqlschema.Migrator) MigrationFunc { - return nil + return func(ctx context.Context, db *bun.DB) error { + return m.AddContraint(ctx, sqlschema.FK{ + From: sqlschema.C(op.SourceSchema, op.SourceTable, op.SourceColumns...), + To: sqlschema.C(op.TargetSchema, op.TargetTable, op.TargetColumns...), + }, "dummy_name_"+fmt.Sprint(time.Now().UnixNano())) + } } func (op *AddForeignKey) GetReverse() Operation { - return nil + return &noop{} // TODO: unless the WithFKNameFunc is specified, we cannot know what the constraint is called } type DropForeignKey struct { @@ -456,11 +466,13 @@ func (op *DropForeignKey) String() string { } func (op *DropForeignKey) Func(m sqlschema.Migrator) MigrationFunc { - return nil + return func(ctx context.Context, db *bun.DB) error { + return m.DropContraint(ctx, op.Schema, op.Table, op.ConstraintName) + } } func (op *DropForeignKey) GetReverse() Operation { - return nil + return &noop{} // TODO: store "OldFK" to recreate it } // sqlschema utils ------------------------------------------------------------ diff --git a/migrate/sqlschema/migrator.go b/migrate/sqlschema/migrator.go index 41b481f77..564e42a96 100644 --- a/migrate/sqlschema/migrator.go +++ b/migrate/sqlschema/migrator.go @@ -17,6 +17,8 @@ type Migrator interface { RenameTable(ctx context.Context, oldName, newName string) error CreateTable(ctx context.Context, model interface{}) error DropTable(ctx context.Context, schema, table string) error + AddContraint(ctx context.Context, fk FK, name string) error + DropContraint(ctx context.Context, schema, table, name string) error } // Migrator is a dialect-agnostic wrapper for sqlschema.Dialect diff --git a/migrate/sqlschema/state.go b/migrate/sqlschema/state.go index c57ea36d0..f48190c0f 100644 --- a/migrate/sqlschema/state.go +++ b/migrate/sqlschema/state.go @@ -92,9 +92,13 @@ func newComposite(columns ...string) composite { return composite(strings.Join(columns, ",")) } +func (c composite) String() string { + return string(c) +} + // Split returns a slice of column names that make up the composite. func (c composite) Split() []string { - return strings.Split(string(c), ",") + return strings.Split(c.String(), ",") } // Contains checks that a composite column contains every part of another composite. @@ -146,12 +150,6 @@ func (c cFQN) T() tFQN { // - depends on C{"A", "B", "C"} // - depends on C{"X", "Y", "Z"} // - depends on T{"A", "B"} and T{"X", "Y"} -// -// FIXME: current design does not allow for one column referencing multiple columns. Or does it? Think again. -// Consider: -// -// CONSTRAINT fk_customers FOREIGN KEY (customer_id) REFERENCES customers(id) -// CONSTRAINT fk_orders FOREIGN KEY (customer_id) REFERENCES orders(customer_id) type FK struct { From cFQN // From is the referencing column. To cFQN // To is the referenced column. @@ -188,6 +186,11 @@ func (fk *FK) DependsC(c cFQN) (bool, *cFQN) { // RefMap helps detecting modified FK relations. // It starts with an initial state and provides methods to update and delete // foreign key relations based on the column or table they depend on. +// +// Note: this is only important/necessary if we want to rename FKs instead of re-creating them. +// Most of the time it wouldn't make a difference, but there may be cases in which re-creating FKs could be costly +// and renaming them would be preferred. For that we could provided an options like WithRenameFKs(true) and +// WithRenameFKFunc(func(sqlschema.FK) string) to allow customizing the FK naming convention. type RefMap map[FK]*FK // deleted is a special value that RefMap uses to denote a deleted FK constraint.