diff --git a/docs/content/docs/getting-started/googlecloud/main.go b/docs/content/docs/getting-started/googlecloud/main.go index 70bdd8373..1d07da5a6 100644 --- a/docs/content/docs/getting-started/googlecloud/main.go +++ b/docs/content/docs/getting-started/googlecloud/main.go @@ -31,7 +31,7 @@ func main() { panic(err) } - // Subscribe will create the subscription. Only messages that are sent after the subscription is created may be received. + // Subscribe will create the subscription. Only messages that are sent after the subscription is created may be received. messages, err := subscriber.Subscribe("example.topic") if err != nil { panic(err) diff --git a/docs/content/docs/message.md b/docs/content/docs/message.md index 4ebbae313..87576b58f 100644 --- a/docs/content/docs/message.md +++ b/docs/content/docs/message.md @@ -16,7 +16,7 @@ When message is processed, you should send [`Ack()`]({{< ref "#ack" >}}) or [`Na `Acks` and `Nacks` are processed by Subscribers (in default implementations subscribers are waiting for `Ack` or `Nack`). {{% render-md %}} -{{% load-snippet-partial file="content/src-link/message/message.go" first_line_contains="type Message struct {" last_line_contains="ackSentType ackType" padding_after="2" %}} +{{% load-snippet-partial file="content/src-link/message/message.go" first_line_contains="type Message struct {" last_line_contains="ctx context.Context" padding_after="2" %}} {{% /render-md %}} ### Ack @@ -40,3 +40,12 @@ When message is processed, you should send [`Ack()`]({{< ref "#ack" >}}) or [`Na {{% load-snippet-partial file="content/docs/message/receiving-ack.go" first_line_contains="select {" last_line_contains="}" padding_after="0" %}} {{% /render-md %}} + +### Context + +Message contains the standard library context, just like an HTTP request. + +{{% render-md %}} +{{% load-snippet-partial file="content/src-link/message/message.go" first_line_contains="// Context" last_line_contains="func (m *Message) SetContext" padding_after="2" %}} +{{% /render-md %}} + diff --git a/message/infrastructure/gochannel/pubsub.go b/message/infrastructure/gochannel/pubsub.go index 9bf9948c7..86d02a7a0 100644 --- a/message/infrastructure/gochannel/pubsub.go +++ b/message/infrastructure/gochannel/pubsub.go @@ -1,10 +1,12 @@ package gochannel import ( + "context" "sync" "time" "github.com/pkg/errors" + "github.com/satori/go.uuid" "github.com/ThreeDotsLabs/watermill" @@ -30,7 +32,8 @@ type GoChannel struct { logger watermill.LoggerAdapter - closed bool + closed bool + closing chan struct{} } func NewGoChannel(buffer int64, logger watermill.LoggerAdapter, sendTimeout time.Duration) message.PubSub { @@ -41,6 +44,8 @@ func NewGoChannel(buffer int64, logger watermill.LoggerAdapter, sendTimeout time subscribers: make(map[string][]*subscriber), subscribersLock: &sync.RWMutex{}, logger: logger, + + closing: make(chan struct{}), } } @@ -72,29 +77,50 @@ func (g *GoChannel) sendMessage(topic string, message *message.Message) error { } for _, s := range subscribers { - subscriberLogFields := messageLogFields.Add(watermill.LogFields{ - "subscriber_uuid": s.uuid, - }) - - SendToSubscriber: - for { - select { - case s.outputChannel <- message: - select { - case <-message.Acked(): - g.logger.Trace("message sent", subscriberLogFields) - break SendToSubscriber - case <-message.Nacked(): - g.logger.Trace("nack received, resending message", subscriberLogFields) - - // message have nack already sent, we need fresh message - message = message.Copy() - - continue SendToSubscriber - } - case <-time.After(g.sendTimeout): - return errors.Errorf("sending message %s timeouted after %s", message.UUID, g.sendTimeout) - } + if err := g.sendMessageToSubscriber(message, s, messageLogFields); err != nil { + return err + } + } + + return nil +} + +func (g *GoChannel) sendMessageToSubscriber(msg *message.Message, s *subscriber, messageLogFields watermill.LogFields) error { + subscriberLogFields := messageLogFields.Add(watermill.LogFields{ + "subscriber_uuid": s.uuid, + }) + +SendToSubscriber: + for { + // copy the message to prevent ack/nack propagation to other consumers + // also allows to make retries on a fresh copy of the original message + msgToSend := msg.Copy() + + ctx, cancelCtx := context.WithCancel(context.Background()) + msgToSend.SetContext(ctx) + defer cancelCtx() + + select { + case s.outputChannel <- msgToSend: + g.logger.Trace("Sent message to subscriber", subscriberLogFields) + case <-time.After(g.sendTimeout): + return errors.Errorf("Sending message %s timeouted after %s", msgToSend.UUID, g.sendTimeout) + case <-g.closing: + g.logger.Trace("Closing, message discarded", subscriberLogFields) + return nil + } + + select { + case <-msgToSend.Acked(): + g.logger.Trace("Message acked", subscriberLogFields) + break SendToSubscriber + case <-msgToSend.Nacked(): + g.logger.Trace("Nack received, resending message", subscriberLogFields) + + continue SendToSubscriber + case <-g.closing: + g.logger.Trace("Closing, message discarded", subscriberLogFields) + return nil } } @@ -123,13 +149,14 @@ func (g *GoChannel) Subscribe(topic string) (chan *message.Message, error) { } func (g *GoChannel) Close() error { - g.subscribersLock.Lock() - defer g.subscribersLock.Unlock() - if g.closed { return nil } g.closed = true + close(g.closing) + + g.subscribersLock.Lock() + defer g.subscribersLock.Unlock() for _, topicSubscribers := range g.subscribers { for _, subscriber := range topicSubscribers { diff --git a/message/infrastructure/kafka/subscriber.go b/message/infrastructure/kafka/subscriber.go index 57844bff6..8a4be32e7 100644 --- a/message/infrastructure/kafka/subscriber.go +++ b/message/infrastructure/kafka/subscriber.go @@ -395,6 +395,10 @@ func (h messageHandler) processMessage( return errors.Wrap(err, "message unmarshal failed") } + ctx, cancelCtx := context.WithCancel(context.Background()) + msg.SetContext(ctx) + defer cancelCtx() + receivedMsgLogFields = receivedMsgLogFields.Add(watermill.LogFields{ "message_uuid": msg.UUID, }) diff --git a/message/infrastructure/nats/subscriber.go b/message/infrastructure/nats/subscriber.go index a23e2ede8..231adbbc8 100644 --- a/message/infrastructure/nats/subscriber.go +++ b/message/infrastructure/nats/subscriber.go @@ -1,6 +1,7 @@ package nats import ( + "context" "sync" "time" @@ -229,6 +230,10 @@ func (s *StreamingSubscriber) processMessage(m *stan.Msg, output chan *message.M return } + ctx, cancelCtx := context.WithCancel(context.Background()) + msg.SetContext(ctx) + defer cancelCtx() + messageLogFields := logFields.Add(watermill.LogFields{"message_uuid": msg.UUID}) s.logger.Trace("Unmarshaled message", messageLogFields) diff --git a/message/infrastructure/test_pubsub.go b/message/infrastructure/test_pubsub.go index 6edcb456f..a2d98ad11 100644 --- a/message/infrastructure/test_pubsub.go +++ b/message/infrastructure/test_pubsub.go @@ -1,6 +1,7 @@ package infrastructure import ( + "context" "fmt" "log" "sync" @@ -113,6 +114,11 @@ func TestPubSub( t.Parallel() topicTest(t, pubSubConstructor(t)) }) + + t.Run("messageCtx", func(t *testing.T) { + t.Parallel() + testMessageCtx(t, pubSubConstructor(t)) + }) } var stressTestTestsCount = 20 @@ -559,6 +565,73 @@ func topicTest(t *testing.T, pubSub message.PubSub) { assert.Equal(t, messagesConsumedTopic2.IDs()[0], topic2Msg.UUID) } +func testMessageCtx(t *testing.T, pubSub message.PubSub) { + defer pubSub.Close() + + topic := testTopicName() + + messages, err := pubSub.Subscribe(topic) + require.NoError(t, err) + + go func() { + msg := message.NewMessage(uuid.NewV4().String(), nil) + + // ensuring that context is not propagated via pub/sub + ctx, ctxCancel := context.WithCancel(context.Background()) + ctxCancel() + msg.SetContext(ctx) + + require.NoError(t, pubSub.Publish(topic, msg)) + require.NoError(t, pubSub.Publish(topic, msg)) + }() + + select { + case msg := <-messages: + ctx := msg.Context() + + select { + case <-ctx.Done(): + t.Fatal("context should not be canceled") + default: + // ok + } + + require.NoError(t, msg.Ack()) + + select { + case <-ctx.Done(): + // ok + case <-time.After(defaultTimeout): + t.Fatal("context should be canceled after Ack") + } + case <-time.After(defaultTimeout): + t.Fatal("no message received") + } + + select { + case msg := <-messages: + ctx := msg.Context() + + select { + case <-ctx.Done(): + t.Fatal("context should not be canceled") + default: + // ok + } + + go require.NoError(t, pubSub.Close()) + + select { + case <-ctx.Done(): + // ok + case <-time.After(defaultTimeout): + t.Fatal("context should be canceled after pubSub.Close()") + } + case <-time.After(defaultTimeout): + t.Fatal("no message received") + } +} + func assertConsumerGroupReceivedMessages( t *testing.T, pubSubConstructor ConsumerGroupPubSubConstructor, diff --git a/message/message.go b/message/message.go index baba34323..a5b98eafb 100644 --- a/message/message.go +++ b/message/message.go @@ -2,6 +2,7 @@ package message import ( "bytes" + "context" "sync" "github.com/pkg/errors" @@ -43,6 +44,8 @@ type Message struct { ackMutex sync.Mutex ackSentType ackType + + ctx context.Context } func NewMessage(uuid string, payload Payload) *Message { @@ -158,6 +161,23 @@ func (m *Message) Nacked() <-chan struct{} { return m.noAck } +// Context returns the message's context. To change the context, use +// SetContext. +// +// The returned context is always non-nil; it defaults to the +// background context. +func (m Message) Context() context.Context { + if m.ctx != nil { + return m.ctx + } + return context.Background() +} + +// SetContext sets provided context to the message. +func (m *Message) SetContext(ctx context.Context) { + m.ctx = ctx +} + // Copy copies all message without Acks/Nacks. func (m Message) Copy() *Message { msg := NewMessage(m.UUID, m.Payload)