diff --git a/roast/tester.go b/roast/tester.go index 44149b7..e27fd3d 100644 --- a/roast/tester.go +++ b/roast/tester.go @@ -23,6 +23,7 @@ import ( "github.com/256dpi/fire" "github.com/256dpi/fire/axe" "github.com/256dpi/fire/coal" + "github.com/256dpi/fire/stick" ) // TODO: Support filters and sorters. @@ -160,13 +161,18 @@ func (t *Tester) Invalidate() { } // List will list the provided model and validate the response if requested. -func (t *Tester) List(tt *testing.T, model coal.Model, response []coal.Model) Result { +func (t *Tester) List(tt *testing.T, model coal.Model, response []coal.Model, hooks ...func(res, cmp coal.Model)) Result { // list resources res, doc, err := t.DataClient.List(model) assert.NoError(tt, err) // validate response if err == nil && response != nil { + for i := range response { + for _, cb := range hooks { + cb(res[i], response[i]) + } + } assert.Equal(tt, response, res) } @@ -195,13 +201,16 @@ func (t *Tester) ListError(tt *testing.T, model coal.Model, expected error) Resu } // Find will find the provided model and validate the response if requested. -func (t *Tester) Find(tt *testing.T, model coal.Model, response coal.Model) Result { +func (t *Tester) Find(tt *testing.T, model coal.Model, response coal.Model, hooks ...func(res, cmp coal.Model)) Result { // find resource res, doc, err := t.DataClient.Find(model) assert.NoError(tt, err) // validate response if err == nil && response != nil { + for _, cb := range hooks { + cb(res, response) + } assert.Equal(tt, response, res) } @@ -231,7 +240,7 @@ func (t *Tester) FindError(tt *testing.T, model coal.Model, expected error) Resu // Create will create the provided model and validate the response and result // if requested. -func (t *Tester) Create(tt *testing.T, model, response, result coal.Model) Result { +func (t *Tester) Create(tt *testing.T, model, response, result coal.Model, hooks ...func(res, cmp coal.Model)) Result { // create resource res, doc, err := t.DataClient.Create(model) assert.NoError(tt, err) @@ -239,6 +248,9 @@ func (t *Tester) Create(tt *testing.T, model, response, result coal.Model) Resul // validate response if err == nil && response != nil { response.GetBase().DocID = res.ID() + for _, cb := range hooks { + cb(res, response) + } assert.Equal(tt, response, res) } @@ -247,6 +259,9 @@ func (t *Tester) Create(tt *testing.T, model, response, result coal.Model) Resul result.GetBase().DocID = res.ID() actual := coal.GetMeta(model).Make() t.Tester.Fetch(actual, res.ID()) + for _, cb := range hooks { + cb(actual, result) + } assert.Equal(tt, result, actual) } @@ -276,13 +291,16 @@ func (t *Tester) CreateError(tt *testing.T, model coal.Model, expected error) Re // Update will update the provided model and validate the response and result // if requested. -func (t *Tester) Update(tt *testing.T, model, response, result coal.Model) Result { +func (t *Tester) Update(tt *testing.T, model, response, result coal.Model, hooks ...func(res, cmp coal.Model)) Result { // update resource res, doc, err := t.DataClient.Update(model) assert.NoError(tt, err) // validate response if err == nil && response != nil { + for _, cb := range hooks { + cb(res, response) + } assert.Equal(tt, response, res) } @@ -290,6 +308,9 @@ func (t *Tester) Update(tt *testing.T, model, response, result coal.Model) Resul if err == nil && result != nil { actual := coal.GetMeta(model).Make() t.Tester.Fetch(actual, model.ID()) + for _, cb := range hooks { + cb(actual, result) + } assert.Equal(tt, result, actual) } @@ -318,7 +339,7 @@ func (t *Tester) UpdateError(tt *testing.T, model coal.Model, expected error) Re } // Delete will delete the provided model and validate the result. -func (t *Tester) Delete(tt *testing.T, model, result coal.Model) Result { +func (t *Tester) Delete(tt *testing.T, model, result coal.Model, hooks ...func(res, cmp coal.Model)) Result { // delete resource err := t.DataClient.Delete(model) assert.NoError(tt, err) @@ -327,6 +348,9 @@ func (t *Tester) Delete(tt *testing.T, model, result coal.Model) Result { if err == nil && result != nil { actual := coal.GetMeta(model).Make() t.Tester.Fetch(actual, model.ID()) + for _, fn := range hooks { + fn(actual, result) + } assert.Equal(tt, result, actual) } @@ -478,3 +502,21 @@ func (t *Tester) Await(tt *testing.T, timeout time.Duration, fns ...func()) int return n } + +// CopyHook will copy fields from the response to the result. +func CopyHook(fields ...string) func(res, cmp coal.Model) { + return func(res, cmp coal.Model) { + for _, field := range fields { + stick.MustSet(cmp, field, stick.MustGet(res, field)) + } + } +} + +// SkipHook will copy fields from the result to the response. +func SkipHook(fields ...string) func(res, cmp coal.Model) { + return func(res, cmp coal.Model) { + for _, field := range fields { + stick.MustSet(res, field, stick.MustGet(cmp, field)) + } + } +} diff --git a/roast/tester_test.go b/roast/tester_test.go index 930fc0d..d183c7e 100644 --- a/roast/tester_test.go +++ b/roast/tester_test.go @@ -44,6 +44,22 @@ func TestTester(t *testing.T) { tt.Delete(t, post, nil) } +func TestTesterHooks(t *testing.T) { + post1 := &fooModel{String: "String"} + post2 := &fooModel{} + + CopyHook("String")(post1, post2) + assert.Equal(t, "String", post1.String) + assert.Equal(t, "String", post2.String) + + post1 = &fooModel{String: "String"} + post2 = &fooModel{} + + SkipHook("String")(post1, post2) + assert.Zero(t, post1.String) + assert.Zero(t, post2.String) +} + func TestTesterErrors(t *testing.T) { tt := NewTester(Config{ Models: models.All(),