Skip to content

Commit

Permalink
feat: migrate FKs
Browse files Browse the repository at this point in the history
  • Loading branch information
bevzzz committed Oct 27, 2024
1 parent a918dc4 commit 4c1dfdb
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 21 deletions.
27 changes: 27 additions & 0 deletions dialect/pgdialect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ type Migrator struct {
db *bun.DB
}

var _ sqlschema.Migrator = (*Migrator)(nil)

func (m *Migrator) RenameTable(ctx context.Context, oldName, newName string) error {
query := fmt.Sprintf("ALTER TABLE %s RENAME TO %s", oldName, newName)
_, err := m.db.ExecContext(ctx, query)
Expand All @@ -26,3 +28,28 @@ func (m *Migrator) RenameTable(ctx context.Context, oldName, newName string) err
}
return nil
}

func (m *Migrator) AddContraint(ctx context.Context, fk sqlschema.FK, name string) error {
q := m.db.NewRaw(
"ALTER TABLE ?.? ADD CONSTRAINT ? FOREIGN KEY (?) REFERENCES ?.? (?)",
bun.Safe(fk.From.Schema), bun.Safe(fk.From.Table), bun.Safe(name),
bun.Safe(fk.From.Column.String()),
bun.Safe(fk.To.Schema), bun.Safe(fk.To.Table),
bun.Safe(fk.To.Column.String()),
)
if _, err := q.Exec(ctx); err != nil {
return err
}
return nil
}

func (m *Migrator) DropContraint(ctx context.Context, schema, table, name string) error {
q := m.db.NewRaw(
"ALTER TABLE ?.? DROP CONSTRAINT ?",
bun.Safe(schema), bun.Safe(table), bun.Safe(name),
)
if _, err := q.Exec(ctx); err != nil {
return err
}
return nil
}
2 changes: 1 addition & 1 deletion internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1784,7 +1784,7 @@ func mustResetModel(tb testing.TB, ctx context.Context, db *bun.DB, models ...in
func mustDropTableOnCleanup(tb testing.TB, ctx context.Context, db *bun.DB, models ...interface{}) {
tb.Cleanup(func() {
for _, model := range models {
drop := db.NewDropTable().IfExists().Model(model)
drop := db.NewDropTable().IfExists().Cascade().Model(model)
_, err := drop.Exec(ctx)
require.NoError(tb, err, "must drop table: %q", drop.GetTableName())
}
Expand Down
94 changes: 91 additions & 3 deletions internal/dbtest/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ func TestAutoMigrator_Run(t *testing.T) {
}{
{testRenameTable},
{testCreateDropTable},
{testAlterForeignKeys},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -260,6 +261,87 @@ func testCreateDropTable(t *testing.T, db *bun.DB) {
require.Equal(t, "createme", tables[0].Name)
}

func testAlterForeignKeys(t *testing.T, db *bun.DB) {
// Initial state -- each thing has one owner
type OwnerExclusive struct {
bun.BaseModel `bun:"owners"`
ID int64 `bun:",pk"`
}

type ThingExclusive struct {
bun.BaseModel `bun:"things"`
ID int64 `bun:",pk"`
OwnerID int64 `bun:",notnull"`

Owner *OwnerExclusive `bun:"rel:belongs-to,join:owner_id=id"`
}

// Change -- each thing has multiple owners

type ThingCommon struct {
bun.BaseModel `bun:"things"`
ID int64 `bun:",pk"`
}

type OwnerCommon struct {
bun.BaseModel `bun:"owners"`
ID int64 `bun:",pk"`
Things []*ThingCommon `bun:"m2m:things_to_owners,join:Owner=Thing"`
}

type ThingsToOwner struct {
OwnerID int64 `bun:",notnull"`
Owner *OwnerCommon `bun:"rel:belongs-to,join:owner_id=id"`
ThingID int64 `bun:",notnull"`
Thing *ThingCommon `bun:"rel:belongs-to,join:thing_id=id"`
}

// Arrange
ctx := context.Background()
dbInspector, err := sqlschema.NewInspector(db)
if err != nil {
t.Skip(err)
}
db.RegisterModel((*ThingsToOwner)(nil))

mustCreateTableWithFKs(t, ctx, db,
(*OwnerExclusive)(nil),
(*ThingExclusive)(nil),
)
mustDropTableOnCleanup(t, ctx, db, (*ThingsToOwner)(nil))

m, err := migrate.NewAutoMigrator(db,
migrate.WithTableNameAuto(migrationsTable),
migrate.WithLocksTableNameAuto(migrationLocksTable),
migrate.WithModel((*ThingCommon)(nil)),
migrate.WithModel((*OwnerCommon)(nil)),
migrate.WithModel((*ThingsToOwner)(nil)),
)
require.NoError(t, err)

// Act
err = m.Run(ctx)
require.NoError(t, err)

// Assert
state, err := dbInspector.Inspect(ctx)
require.NoError(t, err)

defaultSchema := db.Dialect().DefaultSchema()
require.Contains(t, state.FKs, sqlschema.FK{
From: sqlschema.C(defaultSchema, "things_to_owners", "owner_id"),
To: sqlschema.C(defaultSchema, "owners", "id"),
})
require.Contains(t, state.FKs, sqlschema.FK{
From: sqlschema.C(defaultSchema, "things_to_owners", "thing_id"),
To: sqlschema.C(defaultSchema, "things", "id"),
})
require.NotContains(t, state.FKs, sqlschema.FK{
From: sqlschema.C(defaultSchema, "things", "owner_id"),
To: sqlschema.C(defaultSchema, "owners", "id"),
})
}

func TestDetector_Diff(t *testing.T) {
type Journal struct {
ISBN string `bun:"isbn,pk"`
Expand Down Expand Up @@ -419,16 +501,20 @@ func TestDetector_Diff(t *testing.T) {
},
want: []migrate.Operation{
&migrate.AddForeignKey{
SourceSchema: dialect.DefaultSchema(),
SourceTable: "users",
SourceColumns: []string{"pet_kind", "pet_name"},
TargetSchema: dialect.DefaultSchema(),
TargetTable: "pets",
TargetColums: []string{"kind", "nickname"},
TargetColumns: []string{"kind", "nickname"},
},
&migrate.AddForeignKey{
SourceSchema: dialect.DefaultSchema(),
SourceTable: "users",
SourceColumns: []string{"friend"},
TargetSchema: dialect.DefaultSchema(),
TargetTable: "users",
TargetColums: []string{"username"},
TargetColumns: []string{"username"},
},
},
},
Expand All @@ -447,10 +533,12 @@ func TestDetector_Diff(t *testing.T) {
Model: &Owner{},
},
&migrate.AddForeignKey{
SourceSchema: dialect.DefaultSchema(),
SourceTable: "things",
SourceColumns: []string{"owner_id"},
TargetSchema: dialect.DefaultSchema(),
TargetTable: "owners",
TargetColums: []string{"id"},
TargetColumns: []string{"id"},
},
},
},
Expand Down
32 changes: 22 additions & 10 deletions migrate/auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strings"
"time"

"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate/sqlschema"
Expand Down Expand Up @@ -172,7 +173,7 @@ func Diff(got, want sqlschema.State) Changeset {
}

type detector struct {
changes Changeset
changes Changeset
}

func newDetector() *detector {
Expand Down Expand Up @@ -225,10 +226,12 @@ AddedLoop:
for fk /*, fkName */ := range want.FKs {
if _, ok := got.FKs[fk]; !ok {
d.changes.Add(&AddForeignKey{
SourceSchema: fk.From.Schema,
SourceTable: fk.From.Table,
SourceColumns: fk.From.Column.Split(),
TargetSchema: fk.To.Schema,
TargetTable: fk.To.Table,
TargetColums: fk.To.Column.Split(),
TargetColumns: fk.To.Column.Split(),
})
}
}
Expand Down Expand Up @@ -420,27 +423,34 @@ func trimSchema(name string) string {
}

type AddForeignKey struct {
SourceSchema string
SourceTable string
SourceColumns []string
TargetSchema string
TargetTable string
TargetColums []string
TargetColumns []string
}

var _ Operation = (*AddForeignKey)(nil)

func (op AddForeignKey) String() string {
return fmt.Sprintf("AddForeignKey %s(%s) references %s(%s)",
op.SourceTable, strings.Join(op.SourceColumns, ","),
op.TargetTable, strings.Join(op.TargetColums, ","),
return fmt.Sprintf("AddForeignKey %s.%s(%s) references %s.%s(%s)",
op.SourceSchema, op.SourceTable, strings.Join(op.SourceColumns, ","),
op.SourceTable, op.TargetTable, strings.Join(op.TargetColumns, ","),
)
}

func (op *AddForeignKey) Func(m sqlschema.Migrator) MigrationFunc {
return nil
return func(ctx context.Context, db *bun.DB) error {
return m.AddContraint(ctx, sqlschema.FK{
From: sqlschema.C(op.SourceSchema, op.SourceTable, op.SourceColumns...),
To: sqlschema.C(op.TargetSchema, op.TargetTable, op.TargetColumns...),
}, "dummy_name_"+fmt.Sprint(time.Now().UnixNano()))
}
}

func (op *AddForeignKey) GetReverse() Operation {
return nil
return &noop{} // TODO: unless the WithFKNameFunc is specified, we cannot know what the constraint is called
}

type DropForeignKey struct {
Expand All @@ -456,11 +466,13 @@ func (op *DropForeignKey) String() string {
}

func (op *DropForeignKey) Func(m sqlschema.Migrator) MigrationFunc {
return nil
return func(ctx context.Context, db *bun.DB) error {
return m.DropContraint(ctx, op.Schema, op.Table, op.ConstraintName)
}
}

func (op *DropForeignKey) GetReverse() Operation {
return nil
return &noop{} // TODO: store "OldFK" to recreate it
}

// sqlschema utils ------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions migrate/sqlschema/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ type Migrator interface {
RenameTable(ctx context.Context, oldName, newName string) error
CreateTable(ctx context.Context, model interface{}) error
DropTable(ctx context.Context, schema, table string) error
AddContraint(ctx context.Context, fk FK, name string) error
DropContraint(ctx context.Context, schema, table, name string) error
}

// Migrator is a dialect-agnostic wrapper for sqlschema.Dialect
Expand Down
17 changes: 10 additions & 7 deletions migrate/sqlschema/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,13 @@ func newComposite(columns ...string) composite {
return composite(strings.Join(columns, ","))
}

func (c composite) String() string {
return string(c)
}

// Split returns a slice of column names that make up the composite.
func (c composite) Split() []string {
return strings.Split(string(c), ",")
return strings.Split(c.String(), ",")
}

// Contains checks that a composite column contains every part of another composite.
Expand Down Expand Up @@ -146,12 +150,6 @@ func (c cFQN) T() tFQN {
// - depends on C{"A", "B", "C"}
// - depends on C{"X", "Y", "Z"}
// - depends on T{"A", "B"} and T{"X", "Y"}
//
// FIXME: current design does not allow for one column referencing multiple columns. Or does it? Think again.
// Consider:
//
// CONSTRAINT fk_customers FOREIGN KEY (customer_id) REFERENCES customers(id)
// CONSTRAINT fk_orders FOREIGN KEY (customer_id) REFERENCES orders(customer_id)
type FK struct {
From cFQN // From is the referencing column.
To cFQN // To is the referenced column.
Expand Down Expand Up @@ -188,6 +186,11 @@ func (fk *FK) DependsC(c cFQN) (bool, *cFQN) {
// RefMap helps detecting modified FK relations.
// It starts with an initial state and provides methods to update and delete
// foreign key relations based on the column or table they depend on.
//
// Note: this is only important/necessary if we want to rename FKs instead of re-creating them.
// Most of the time it wouldn't make a difference, but there may be cases in which re-creating FKs could be costly
// and renaming them would be preferred. For that we could provided an options like WithRenameFKs(true) and
// WithRenameFKFunc(func(sqlschema.FK) string) to allow customizing the FK naming convention.
type RefMap map[FK]*FK

// deleted is a special value that RefMap uses to denote a deleted FK constraint.
Expand Down

0 comments on commit 4c1dfdb

Please # to comment.