Skip to content

Commit

Permalink
feat: "skipupdate" model field tag (#565)
Browse files Browse the repository at this point in the history
* feat: "skipupdate" model field tag
  • Loading branch information
funvit authored Jun 28, 2022
1 parent 31b2cc4 commit 9288294
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 0 deletions.
44 changes: 44 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"runtime"
"strings"
"testing"
"time"

"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
Expand Down Expand Up @@ -266,6 +267,7 @@ func TestDB(t *testing.T) {
{testBinaryData},
{testUpsert},
{testMultiUpdate},
{testUpdateWithSkipupdateTag},
{testTxScanAndCount},
{testEmbedModelValue},
{testEmbedModelPointer},
Expand Down Expand Up @@ -1247,6 +1249,48 @@ func testMultiUpdate(t *testing.T, db *bun.DB) {
require.NoError(t, err)
}

func testUpdateWithSkipupdateTag(t *testing.T, db *bun.DB) {
type Model struct {
ID int64 `bun:",pk,autoincrement"`
Name string
CreatedAt time.Time `bun:",skipupdate"`
}

ctx := context.Background()

err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

createdAt := time.Now().Truncate(time.Minute).UTC()

model := &Model{ID: 1, Name: "foo", CreatedAt: createdAt}

_, err = db.NewInsert().Model(model).Exec(ctx)
require.NoError(t, err)
require.NotZero(t, model.CreatedAt)

//
// update field with tag "skipupdate"
//
model.CreatedAt = model.CreatedAt.Add(2 * time.Minute)
_, err = db.NewUpdate().Model(model).WherePK().Exec(ctx)
require.NoError(t, err)

//
// check
//
model_ := new(Model)
model_.ID = model.ID
err = db.NewSelect().Model(model_).WherePK().Scan(ctx)
require.NoError(t, err, "select")
require.NotEmpty(t, model_)
require.Equal(t, model.ID, model_.ID)
require.Equal(t, model.Name, model_.Name)
require.Equal(t, createdAt.UTC(), model_.CreatedAt.UTC())

require.NotEqual(t, model.CreatedAt.UTC(), model_.CreatedAt.UTC())
}

func testTxScanAndCount(t *testing.T, db *bun.DB) {
type Model struct {
ID int64 `bun:",pk,autoincrement"`
Expand Down
45 changes: 45 additions & 0 deletions internal/dbtest/pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,3 +646,48 @@ func TestPostgresHStoreQuote(t *testing.T) {
require.NoError(t, err)
require.Equal(t, wanted, m)
}

func TestPostgresSkipupdateField(t *testing.T) {
type Model struct {
ID int64 `bun:",pk,autoincrement"`
Name string
CreatedAt time.Time `bun:",skipupdate"`
}

ctx := context.Background()

db := pg(t)
defer db.Close()

err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

createdAt := time.Now().Truncate(time.Minute).UTC()

model := &Model{ID: 1, Name: "foo", CreatedAt: createdAt}

_, err = db.NewInsert().Model(model).Exec(ctx)
require.NoError(t, err)
require.NotZero(t, model.CreatedAt)

//
// update field with tag "skipupdate"
//
model.CreatedAt = model.CreatedAt.Add(2 * time.Minute)
_, err = db.NewUpdate().Model(model).WherePK().Exec(ctx)
require.NoError(t, err)

//
// check
//
model_ := new(Model)
model_.ID = model.ID
err = db.NewSelect().Model(model_).WherePK().Scan(ctx)
require.NoError(t, err, "select")
require.NotEmpty(t, model_)
require.Equal(t, model.ID, model_.ID)
require.Equal(t, model.Name, model_.Name)
require.Equal(t, createdAt.UTC(), model_.CreatedAt.UTC())

require.NotEqual(t, model.CreatedAt.UTC(), model_.CreatedAt.UTC())
}
4 changes: 4 additions & 0 deletions query_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ func (q *UpdateQuery) appendSetStruct(
isTemplate := fmter.IsNop()
pos := len(b)
for _, f := range fields {
if f.SkipUpdate() {
continue
}

app, hasValue := q.modelValues[f.Name]

if !hasValue && q.omitZero && f.HasZeroValue(model.strct) {
Expand Down
4 changes: 4 additions & 0 deletions schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ func (f *Field) ScanValue(strct reflect.Value, src interface{}) error {
return f.ScanWithCheck(fv, src)
}

func (f *Field) SkipUpdate() bool {
return f.Tag.HasOption("skipupdate")
}

func indexEqual(ind1, ind2 []int) bool {
if len(ind1) != len(ind2) {
return false
Expand Down
1 change: 1 addition & 0 deletions schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ func isKnownFieldOption(name string) bool {
"unique",
"soft_delete",
"scanonly",
"skipupdate",

"pk",
"autoincrement",
Expand Down

0 comments on commit 9288294

Please # to comment.