Skip to content

Commit

Permalink
Context support (#24)
Browse files Browse the repository at this point in the history
* [docs] added Google Cloud Pub/Sub docs

* update getting started go.mod's

* fixed gochannel test

* [docs] added arrow for collapse toggle

* small tests improvment

* added context to the message

* added extra check for message context test

* added context docs

* Language fixes etc

* Fix typos and go fmt
  • Loading branch information
roblaszczak authored and maclav3 committed Dec 18, 2018
1 parent a0a3acd commit ecdc76c
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 29 deletions.
2 changes: 1 addition & 1 deletion docs/content/docs/getting-started/googlecloud/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func main() {
panic(err)
}

// Subscribe will create the subscription. Only messages that are sent after the subscription is created may be received.
// Subscribe will create the subscription. Only messages that are sent after the subscription is created may be received.
messages, err := subscriber.Subscribe("example.topic")
if err != nil {
panic(err)
Expand Down
11 changes: 10 additions & 1 deletion docs/content/docs/message.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ When message is processed, you should send [`Ack()`]({{< ref "#ack" >}}) or [`Na
`Acks` and `Nacks` are processed by Subscribers (in default implementations subscribers are waiting for `Ack` or `Nack`).

{{% render-md %}}
{{% load-snippet-partial file="content/src-link/message/message.go" first_line_contains="type Message struct {" last_line_contains="ackSentType ackType" padding_after="2" %}}
{{% load-snippet-partial file="content/src-link/message/message.go" first_line_contains="type Message struct {" last_line_contains="ctx context.Context" padding_after="2" %}}
{{% /render-md %}}

### Ack
Expand All @@ -40,3 +40,12 @@ When message is processed, you should send [`Ack()`]({{< ref "#ack" >}}) or [`Na
{{% load-snippet-partial file="content/docs/message/receiving-ack.go" first_line_contains="select {" last_line_contains="}" padding_after="0" %}}
{{% /render-md %}}


### Context

Message contains the standard library context, just like an HTTP request.

{{% render-md %}}
{{% load-snippet-partial file="content/src-link/message/message.go" first_line_contains="// Context" last_line_contains="func (m *Message) SetContext" padding_after="2" %}}
{{% /render-md %}}

81 changes: 54 additions & 27 deletions message/infrastructure/gochannel/pubsub.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package gochannel

import (
"context"
"sync"
"time"

"github.com/pkg/errors"

"github.com/satori/go.uuid"

"github.com/ThreeDotsLabs/watermill"
Expand All @@ -30,7 +32,8 @@ type GoChannel struct {

logger watermill.LoggerAdapter

closed bool
closed bool
closing chan struct{}
}

func NewGoChannel(buffer int64, logger watermill.LoggerAdapter, sendTimeout time.Duration) message.PubSub {
Expand All @@ -41,6 +44,8 @@ func NewGoChannel(buffer int64, logger watermill.LoggerAdapter, sendTimeout time
subscribers: make(map[string][]*subscriber),
subscribersLock: &sync.RWMutex{},
logger: logger,

closing: make(chan struct{}),
}
}

Expand Down Expand Up @@ -72,29 +77,50 @@ func (g *GoChannel) sendMessage(topic string, message *message.Message) error {
}

for _, s := range subscribers {
subscriberLogFields := messageLogFields.Add(watermill.LogFields{
"subscriber_uuid": s.uuid,
})

SendToSubscriber:
for {
select {
case s.outputChannel <- message:
select {
case <-message.Acked():
g.logger.Trace("message sent", subscriberLogFields)
break SendToSubscriber
case <-message.Nacked():
g.logger.Trace("nack received, resending message", subscriberLogFields)

// message have nack already sent, we need fresh message
message = message.Copy()

continue SendToSubscriber
}
case <-time.After(g.sendTimeout):
return errors.Errorf("sending message %s timeouted after %s", message.UUID, g.sendTimeout)
}
if err := g.sendMessageToSubscriber(message, s, messageLogFields); err != nil {
return err
}
}

return nil
}

func (g *GoChannel) sendMessageToSubscriber(msg *message.Message, s *subscriber, messageLogFields watermill.LogFields) error {
subscriberLogFields := messageLogFields.Add(watermill.LogFields{
"subscriber_uuid": s.uuid,
})

SendToSubscriber:
for {
// copy the message to prevent ack/nack propagation to other consumers
// also allows to make retries on a fresh copy of the original message
msgToSend := msg.Copy()

ctx, cancelCtx := context.WithCancel(context.Background())
msgToSend.SetContext(ctx)
defer cancelCtx()

select {
case s.outputChannel <- msgToSend:
g.logger.Trace("Sent message to subscriber", subscriberLogFields)
case <-time.After(g.sendTimeout):
return errors.Errorf("Sending message %s timeouted after %s", msgToSend.UUID, g.sendTimeout)
case <-g.closing:
g.logger.Trace("Closing, message discarded", subscriberLogFields)
return nil
}

select {
case <-msgToSend.Acked():
g.logger.Trace("Message acked", subscriberLogFields)
break SendToSubscriber
case <-msgToSend.Nacked():
g.logger.Trace("Nack received, resending message", subscriberLogFields)

continue SendToSubscriber
case <-g.closing:
g.logger.Trace("Closing, message discarded", subscriberLogFields)
return nil
}
}

Expand Down Expand Up @@ -123,13 +149,14 @@ func (g *GoChannel) Subscribe(topic string) (chan *message.Message, error) {
}

func (g *GoChannel) Close() error {
g.subscribersLock.Lock()
defer g.subscribersLock.Unlock()

if g.closed {
return nil
}
g.closed = true
close(g.closing)

g.subscribersLock.Lock()
defer g.subscribersLock.Unlock()

for _, topicSubscribers := range g.subscribers {
for _, subscriber := range topicSubscribers {
Expand Down
4 changes: 4 additions & 0 deletions message/infrastructure/kafka/subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,10 @@ func (h messageHandler) processMessage(
return errors.Wrap(err, "message unmarshal failed")
}

ctx, cancelCtx := context.WithCancel(context.Background())
msg.SetContext(ctx)
defer cancelCtx()

receivedMsgLogFields = receivedMsgLogFields.Add(watermill.LogFields{
"message_uuid": msg.UUID,
})
Expand Down
5 changes: 5 additions & 0 deletions message/infrastructure/nats/subscriber.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package nats

import (
"context"
"sync"
"time"

Expand Down Expand Up @@ -229,6 +230,10 @@ func (s *StreamingSubscriber) processMessage(m *stan.Msg, output chan *message.M
return
}

ctx, cancelCtx := context.WithCancel(context.Background())
msg.SetContext(ctx)
defer cancelCtx()

messageLogFields := logFields.Add(watermill.LogFields{"message_uuid": msg.UUID})
s.logger.Trace("Unmarshaled message", messageLogFields)

Expand Down
73 changes: 73 additions & 0 deletions message/infrastructure/test_pubsub.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package infrastructure

import (
"context"
"fmt"
"log"
"sync"
Expand Down Expand Up @@ -113,6 +114,11 @@ func TestPubSub(
t.Parallel()
topicTest(t, pubSubConstructor(t))
})

t.Run("messageCtx", func(t *testing.T) {
t.Parallel()
testMessageCtx(t, pubSubConstructor(t))
})
}

var stressTestTestsCount = 20
Expand Down Expand Up @@ -559,6 +565,73 @@ func topicTest(t *testing.T, pubSub message.PubSub) {
assert.Equal(t, messagesConsumedTopic2.IDs()[0], topic2Msg.UUID)
}

func testMessageCtx(t *testing.T, pubSub message.PubSub) {
defer pubSub.Close()

topic := testTopicName()

messages, err := pubSub.Subscribe(topic)
require.NoError(t, err)

go func() {
msg := message.NewMessage(uuid.NewV4().String(), nil)

// ensuring that context is not propagated via pub/sub
ctx, ctxCancel := context.WithCancel(context.Background())
ctxCancel()
msg.SetContext(ctx)

require.NoError(t, pubSub.Publish(topic, msg))
require.NoError(t, pubSub.Publish(topic, msg))
}()

select {
case msg := <-messages:
ctx := msg.Context()

select {
case <-ctx.Done():
t.Fatal("context should not be canceled")
default:
// ok
}

require.NoError(t, msg.Ack())

select {
case <-ctx.Done():
// ok
case <-time.After(defaultTimeout):
t.Fatal("context should be canceled after Ack")
}
case <-time.After(defaultTimeout):
t.Fatal("no message received")
}

select {
case msg := <-messages:
ctx := msg.Context()

select {
case <-ctx.Done():
t.Fatal("context should not be canceled")
default:
// ok
}

go require.NoError(t, pubSub.Close())

select {
case <-ctx.Done():
// ok
case <-time.After(defaultTimeout):
t.Fatal("context should be canceled after pubSub.Close()")
}
case <-time.After(defaultTimeout):
t.Fatal("no message received")
}
}

func assertConsumerGroupReceivedMessages(
t *testing.T,
pubSubConstructor ConsumerGroupPubSubConstructor,
Expand Down
20 changes: 20 additions & 0 deletions message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package message

import (
"bytes"
"context"
"sync"

"github.com/pkg/errors"
Expand Down Expand Up @@ -43,6 +44,8 @@ type Message struct {

ackMutex sync.Mutex
ackSentType ackType

ctx context.Context
}

func NewMessage(uuid string, payload Payload) *Message {
Expand Down Expand Up @@ -158,6 +161,23 @@ func (m *Message) Nacked() <-chan struct{} {
return m.noAck
}

// Context returns the message's context. To change the context, use
// SetContext.
//
// The returned context is always non-nil; it defaults to the
// background context.
func (m Message) Context() context.Context {
if m.ctx != nil {
return m.ctx
}
return context.Background()
}

// SetContext sets provided context to the message.
func (m *Message) SetContext(ctx context.Context) {
m.ctx = ctx
}

// Copy copies all message without Acks/Nacks.
func (m Message) Copy() *Message {
msg := NewMessage(m.UUID, m.Payload)
Expand Down

0 comments on commit ecdc76c

Please # to comment.