diff --git a/components/cqrs/command_bus.go b/components/cqrs/command_bus.go index 7dd414a52..8de7123a7 100644 --- a/components/cqrs/command_bus.go +++ b/components/cqrs/command_bus.go @@ -1,6 +1,8 @@ package cqrs import ( + "context" + "github.com/ThreeDotsLabs/watermill/message" ) @@ -30,11 +32,13 @@ func NewCommandBus( } // Send sends command to the command bus. -func (c CommandBus) Send(cmd interface{}) error { +func (c CommandBus) Send(ctx context.Context, cmd interface{}) error { msg, err := c.marshaler.Marshal(cmd) if err != nil { return err } + msg.SetContext(ctx) + return c.publisher.Publish(c.topic, msg) } diff --git a/components/cqrs/command_bus_test.go b/components/cqrs/command_bus_test.go new file mode 100644 index 000000000..ee651eed9 --- /dev/null +++ b/components/cqrs/command_bus_test.go @@ -0,0 +1,50 @@ +package cqrs + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ThreeDotsLabs/watermill/message" +) + +type publisherStub struct { + messages map[string]message.Messages + + mu sync.Mutex +} + +func newPublisherStub() *publisherStub { + return &publisherStub{ + messages: make(map[string]message.Messages), + } +} + +func (*publisherStub) Close() error { + return nil +} + +func (p *publisherStub) Publish(topic string, messages ...*message.Message) error { + p.mu.Lock() + defer p.mu.Unlock() + + p.messages[topic] = append(p.messages[topic], messages...) + + return nil +} + +func TestCommandBus_Send_ContextPropagation(t *testing.T) { + publisher := newPublisherStub() + + commandBus := NewCommandBus(publisher, "whatever", JSONMarshaler{}) + + ctx := context.WithValue(context.Background(), "key", "value") + + err := commandBus.Send(ctx, "message") + require.NoError(t, err) + + assert.Equal(t, ctx, publisher.messages["whatever"][0].Context()) +} diff --git a/components/cqrs/command_processor.go b/components/cqrs/command_processor.go index 85aba0fd2..93e4fd1a7 100644 --- a/components/cqrs/command_processor.go +++ b/components/cqrs/command_processor.go @@ -1,6 +1,7 @@ package cqrs import ( + "context" "fmt" "github.com/pkg/errors" @@ -15,7 +16,7 @@ import ( // In contrast to EvenHandler, every Command must have only one CommandHandler. type CommandHandler interface { NewCommand() interface{} - Handle(cmd interface{}) error + Handle(ctx context.Context, cmd interface{}) error } // CommandProcessor determines which CommandHandler should handle the command received from the command bus. @@ -120,7 +121,7 @@ func (p CommandProcessor) RouterHandlerFunc(handler CommandHandler) (message.Han return nil, err } - if err := handler.Handle(cmd); err != nil { + if err := handler.Handle(msg.Context(), cmd); err != nil { return nil, err } diff --git a/components/cqrs/command_processor_test.go b/components/cqrs/command_processor_test.go index 081218eec..310b0fbce 100644 --- a/components/cqrs/command_processor_test.go +++ b/components/cqrs/command_processor_test.go @@ -1,6 +1,7 @@ package cqrs_test import ( + "context" "testing" "github.com/pkg/errors" @@ -19,7 +20,7 @@ func (nonPointerCommandHandler) NewCommand() interface{} { return TestCommand{} } -func (nonPointerCommandHandler) Handle(cmd interface{}) error { +func (nonPointerCommandHandler) Handle(ctx context.Context, cmd interface{}) error { panic("not implemented") } diff --git a/components/cqrs/cqrs_test.go b/components/cqrs/cqrs_test.go index a04771f78..b84b22ce1 100644 --- a/components/cqrs/cqrs_test.go +++ b/components/cqrs/cqrs_test.go @@ -1,15 +1,17 @@ package cqrs_test import ( + "context" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ThreeDotsLabs/watermill" "github.com/ThreeDotsLabs/watermill/components/cqrs" "github.com/ThreeDotsLabs/watermill/message" "github.com/ThreeDotsLabs/watermill/message/infrastructure/gochannel" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // TestCQRS is functional test of CQRS command handler and event handler. @@ -22,23 +24,23 @@ func TestCQRS(t *testing.T) { router, cqrsFacade := createRouterAndFacade(ts, t, captureCommandHandler, captureEventHandler) pointerCmd := &TestCommand{ID: watermill.NewULID()} - require.NoError(t, cqrsFacade.CommandBus().Send(pointerCmd)) + require.NoError(t, cqrsFacade.CommandBus().Send(context.Background(), pointerCmd)) assert.EqualValues(t, []interface{}{pointerCmd}, captureCommandHandler.HandledCommands()) captureCommandHandler.Reset() nonPointerCmd := TestCommand{ID: watermill.NewULID()} - require.NoError(t, cqrsFacade.CommandBus().Send(nonPointerCmd)) + require.NoError(t, cqrsFacade.CommandBus().Send(context.Background(), nonPointerCmd)) // command is always unmarshaled to pointer value assert.EqualValues(t, []interface{}{&nonPointerCmd}, captureCommandHandler.HandledCommands()) captureCommandHandler.Reset() pointerEvent := &TestEvent{ID: watermill.NewULID()} - require.NoError(t, cqrsFacade.EventBus().Publish(pointerEvent)) + require.NoError(t, cqrsFacade.EventBus().Publish(context.Background(), pointerEvent)) assert.EqualValues(t, []interface{}{pointerEvent}, captureEventHandler.HandledEvents()) captureEventHandler.Reset() nonPointerEvent := TestEvent{ID: watermill.NewULID()} - require.NoError(t, cqrsFacade.EventBus().Publish(nonPointerEvent)) + require.NoError(t, cqrsFacade.EventBus().Publish(context.Background(), nonPointerEvent)) // event is always unmarshaled to pointer value assert.EqualValues(t, []interface{}{&nonPointerEvent}, captureEventHandler.HandledEvents()) captureEventHandler.Reset() @@ -126,7 +128,7 @@ func (CaptureCommandHandler) NewCommand() interface{} { return &TestCommand{} } -func (h *CaptureCommandHandler) Handle(cmd interface{}) error { +func (h *CaptureCommandHandler) Handle(ctx context.Context, cmd interface{}) error { h.handledCommands = append(h.handledCommands, cmd.(*TestCommand)) return nil } @@ -152,7 +154,7 @@ func (CaptureEventHandler) NewEvent() interface{} { return &TestEvent{} } -func (h *CaptureEventHandler) Handle(cmd interface{}) error { - h.handledEvents = append(h.handledEvents, cmd.(*TestEvent)) +func (h *CaptureEventHandler) Handle(ctx context.Context, event interface{}) error { + h.handledEvents = append(h.handledEvents, event.(*TestEvent)) return nil } diff --git a/components/cqrs/event_bus.go b/components/cqrs/event_bus.go index 10140de90..84c270c0b 100644 --- a/components/cqrs/event_bus.go +++ b/components/cqrs/event_bus.go @@ -1,6 +1,8 @@ package cqrs import ( + "context" + "github.com/ThreeDotsLabs/watermill/message" ) @@ -30,11 +32,13 @@ func NewEventBus( } // Send sends command to the event bus. -func (c EventBus) Publish(event interface{}) error { +func (c EventBus) Publish(ctx context.Context, event interface{}) error { msg, err := c.marshaler.Marshal(event) if err != nil { return err } + msg.SetContext(ctx) + return c.publisher.Publish(c.topic, msg) } diff --git a/components/cqrs/event_bus_test.go b/components/cqrs/event_bus_test.go new file mode 100644 index 000000000..405b93f04 --- /dev/null +++ b/components/cqrs/event_bus_test.go @@ -0,0 +1,22 @@ +package cqrs + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEventBus_Send_ContextPropagation(t *testing.T) { + publisher := newPublisherStub() + + eventBus := NewEventBus(publisher, "whatever", JSONMarshaler{}) + + ctx := context.WithValue(context.Background(), "key", "value") + + err := eventBus.Publish(ctx, "message") + require.NoError(t, err) + + assert.Equal(t, ctx, publisher.messages["whatever"][0].Context()) +} diff --git a/components/cqrs/event_processor.go b/components/cqrs/event_processor.go index 540697c80..02a30c678 100644 --- a/components/cqrs/event_processor.go +++ b/components/cqrs/event_processor.go @@ -1,6 +1,7 @@ package cqrs import ( + "context" "fmt" "github.com/pkg/errors" @@ -16,7 +17,7 @@ import ( // In contrast to CommandHandler, every Event can have multiple EventHandlers. type EventHandler interface { NewEvent() interface{} - Handle(event interface{}) error + Handle(ctx context.Context, event interface{}) error } // EventProcessor determines which EventHandler should handle event received from event bus. @@ -115,7 +116,7 @@ func (p EventProcessor) RouterHandlerFunc(handler EventHandler) (message.Handler return nil, err } - if err := handler.Handle(event); err != nil { + if err := handler.Handle(msg.Context(), event); err != nil { return nil, err } diff --git a/components/cqrs/event_processor_test.go b/components/cqrs/event_processor_test.go index 4f983a569..4e4321f47 100644 --- a/components/cqrs/event_processor_test.go +++ b/components/cqrs/event_processor_test.go @@ -1,6 +1,7 @@ package cqrs_test import ( + "context" "testing" "github.com/pkg/errors" @@ -19,7 +20,7 @@ func (nonPointerEventProcessor) NewEvent() interface{} { return TestEvent{} } -func (nonPointerEventProcessor) Handle(cmd interface{}) error { +func (nonPointerEventProcessor) Handle(ctx context.Context, cmd interface{}) error { panic("not implemented") } @@ -47,7 +48,7 @@ func (duplicateTestEventHandler1) NewEvent() interface{} { return &TestEvent{} } -func (h *duplicateTestEventHandler1) Handle(cmd interface{}) error { return nil } +func (h *duplicateTestEventHandler1) Handle(ctx context.Context, event interface{}) error { return nil } type duplicateTestEventHandler2 struct{} @@ -55,7 +56,7 @@ func (duplicateTestEventHandler2) NewEvent() interface{} { return &TestEvent{} } -func (h *duplicateTestEventHandler2) Handle(cmd interface{}) error { return nil } +func (h *duplicateTestEventHandler2) Handle(ctx context.Context, event interface{}) error { return nil } func TestEventProcessor_multiple_same_event_handlers(t *testing.T) { ts := NewTestServices() diff --git a/go.mod b/go.mod index 02815bfd7..961e7f661 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/cenkalti/backoff v2.1.1+incompatible github.com/go-chi/chi v4.0.2+incompatible github.com/gogo/protobuf v1.2.1 + github.com/golang/protobuf v1.2.1-0.20190205222052-c823c79ea157 github.com/google/uuid v1.1.1 github.com/hashicorp/go-multierror v1.0.0 github.com/nats-io/go-nats v1.7.2 // indirect diff --git a/go.sum b/go.sum index b1a139804..35a37f064 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfb github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.2.1-0.20190205222052-c823c79ea157 h1:SdQMHsZ18/XZCHuwt3IF+dvHgYTO2XMWZjv3XBKQqAI= +github.com/golang/protobuf v1.2.1-0.20190205222052-c823c79ea157/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=