diff --git a/docs/content/docs/getting-started/googlecloud/main.go b/docs/content/docs/getting-started/googlecloud/main.go index 65c0dec3d..6c6222e29 100644 --- a/docs/content/docs/getting-started/googlecloud/main.go +++ b/docs/content/docs/getting-started/googlecloud/main.go @@ -5,11 +5,9 @@ import ( "context" "log" - "github.com/ThreeDotsLabs/watermill/message/infrastructure/googlecloud" - "github.com/ThreeDotsLabs/watermill" - "github.com/ThreeDotsLabs/watermill/message" + "github.com/ThreeDotsLabs/watermill/message/infrastructure/googlecloud" ) func main() { @@ -37,7 +35,7 @@ func main() { go process(messages) - publisher, err := googlecloud.NewPublisher(context.Background(), googlecloud.PublisherConfig{ + publisher, err := googlecloud.NewPublisher(googlecloud.PublisherConfig{ ProjectID: "test-project", }) if err != nil { diff --git a/message/infrastructure/googlecloud/publisher.go b/message/infrastructure/googlecloud/publisher.go index 29e1f61e6..ba978d7b0 100644 --- a/message/infrastructure/googlecloud/publisher.go +++ b/message/infrastructure/googlecloud/publisher.go @@ -3,13 +3,13 @@ package googlecloud import ( "context" "sync" - - "github.com/ThreeDotsLabs/watermill" + "time" "cloud.google.com/go/pubsub" "github.com/pkg/errors" "google.golang.org/api/option" + "github.com/ThreeDotsLabs/watermill" "github.com/ThreeDotsLabs/watermill/message" ) @@ -21,8 +21,6 @@ var ( ) type Publisher struct { - ctx context.Context - topics map[string]*pubsub.Topic topicsLock sync.RWMutex closed bool @@ -39,6 +37,11 @@ type PublisherConfig struct { // Otherwise, trying to subscribe to non-existent subscription results in `ErrTopicDoesNotExist`. DoNotCreateTopicIfMissing bool + // ConnectTimeout defines the timeout for connecting to Pub/Sub + ConnectTimeout time.Duration + // PublishTimeout defines the timeout for publishing messages. + PublishTimeout time.Duration + // Settings for cloud.google.com/go/pubsub client library. PublishSettings *pubsub.PublishSettings ClientOptions []option.ClientOption @@ -52,21 +55,28 @@ func (c *PublisherConfig) setDefaults() { if c.Marshaler == nil { c.Marshaler = DefaultMarshalerUnmarshaler{} } - + if c.ConnectTimeout == 0 { + c.ConnectTimeout = time.Second * 10 + } + if c.PublishTimeout == 0 { + c.PublishTimeout = time.Second * 5 + } if c.Logger == nil { c.Logger = watermill.NopLogger{} } } -func NewPublisher(ctx context.Context, config PublisherConfig) (*Publisher, error) { +func NewPublisher(config PublisherConfig) (*Publisher, error) { config.setDefaults() pub := &Publisher{ - ctx: ctx, topics: map[string]*pubsub.Topic{}, config: config, } + ctx, cancel := context.WithTimeout(context.Background(), config.ConnectTimeout) + defer cancel() + var err error pub.client, err = pubsub.NewClient(ctx, config.ProjectID, config.ClientOptions...) if err != nil { @@ -88,7 +98,8 @@ func (p *Publisher) Publish(topic string, messages ...*message.Message) error { return ErrPublisherClosed } - ctx := p.ctx + ctx, cancel := context.WithTimeout(context.Background(), p.config.PublishTimeout) + defer cancel() t, err := p.topic(ctx, topic) if err != nil { diff --git a/message/infrastructure/googlecloud/pubsub_bench_test.go b/message/infrastructure/googlecloud/pubsub_bench_test.go index 32d1aab06..e64ac543c 100644 --- a/message/infrastructure/googlecloud/pubsub_bench_test.go +++ b/message/infrastructure/googlecloud/pubsub_bench_test.go @@ -3,6 +3,7 @@ package googlecloud_test import ( "context" "testing" + "time" "github.com/ThreeDotsLabs/watermill" "github.com/ThreeDotsLabs/watermill/message" @@ -14,10 +15,15 @@ import ( func BenchmarkSubscriber(b *testing.B) { infrastructure.BenchSubscriber(b, func(n int) message.PubSub { - ctx := context.Background() logger := watermill.NopLogger{} - publisher, err := googlecloud.NewPublisher(ctx, googlecloud.PublisherConfig{}) + publisher, err := googlecloud.NewPublisher(googlecloud.PublisherConfig{}) + if err != nil { + panic(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() subscriber, err := googlecloud.NewSubscriber( ctx, diff --git a/message/infrastructure/googlecloud/pubsub_test.go b/message/infrastructure/googlecloud/pubsub_test.go index 4a72a50fa..e1c154b36 100644 --- a/message/infrastructure/googlecloud/pubsub_test.go +++ b/message/infrastructure/googlecloud/pubsub_test.go @@ -22,9 +22,7 @@ import ( func newPubSub(t *testing.T, marshaler googlecloud.MarshalerUnmarshaler, subscriptionName googlecloud.SubscriptionNameFn) message.PubSub { logger := watermill.NewStdLogger(true, true) - ctx := context.Background() publisher, err := googlecloud.NewPublisher( - ctx, googlecloud.PublisherConfig{ Marshaler: marshaler, Logger: logger, @@ -32,6 +30,9 @@ func newPubSub(t *testing.T, marshaler googlecloud.MarshalerUnmarshaler, subscri ) require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + subscriber, err := googlecloud.NewSubscriber( ctx, googlecloud.SubscriberConfig{ @@ -73,7 +74,9 @@ func TestPublishSubscribe(t *testing.T) { } func TestSubscriberUnexpectedTopicForSubscription(t *testing.T) { - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + rand.Seed(time.Now().Unix()) testNumber := rand.Int() logger := watermill.NewStdLogger(true, true) @@ -97,7 +100,7 @@ func TestSubscriberUnexpectedTopicForSubscription(t *testing.T) { howManyMessages := 100 - messagesTopic1, err := sub1.Subscribe(context.Background(), topic1) + messagesTopic1, err := sub1.Subscribe(ctx, topic1) require.NoError(t, err) allMessagesReceived := make(chan struct{}) @@ -112,7 +115,7 @@ func TestSubscriberUnexpectedTopicForSubscription(t *testing.T) { } }() - produceMessages(t, ctx, topic1, howManyMessages) + produceMessages(t, topic1, howManyMessages) select { case <-allMessagesReceived: @@ -121,12 +124,12 @@ func TestSubscriberUnexpectedTopicForSubscription(t *testing.T) { t.Fatal("Test timed out") } - _, err = sub2.Subscribe(context.Background(), topic2) + _, err = sub2.Subscribe(ctx, topic2) require.Equal(t, googlecloud.ErrUnexpectedTopic, errors.Cause(err)) } -func produceMessages(t *testing.T, ctx context.Context, topic string, howMany int) { - pub, err := googlecloud.NewPublisher(ctx, googlecloud.PublisherConfig{}) +func produceMessages(t *testing.T, topic string, howMany int) { + pub, err := googlecloud.NewPublisher(googlecloud.PublisherConfig{}) require.NoError(t, err) defer pub.Close() diff --git a/message/infrastructure/googlecloud/subscriber.go b/message/infrastructure/googlecloud/subscriber.go index d516126d6..90310d8a7 100644 --- a/message/infrastructure/googlecloud/subscriber.go +++ b/message/infrastructure/googlecloud/subscriber.go @@ -4,13 +4,13 @@ import ( "context" "fmt" "sync" - - "google.golang.org/grpc" - "google.golang.org/grpc/codes" + "time" "cloud.google.com/go/pubsub" "github.com/pkg/errors" "google.golang.org/api/option" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" "github.com/ThreeDotsLabs/watermill" "github.com/ThreeDotsLabs/watermill/message" @@ -65,6 +65,9 @@ type SubscriberConfig struct { // Otherwise, trying to create a subscription on non-existent topic results in `ErrTopicDoesNotExist`. DoNotCreateTopicIfMissing bool + // InitializeTimeout defines the timeout for initializing topics. + InitializeTimeout time.Duration + // Settings for cloud.google.com/go/pubsub client library. ReceiveSettings pubsub.ReceiveSettings SubscriptionConfig pubsub.SubscriptionConfig @@ -93,6 +96,9 @@ func (c *SubscriberConfig) setDefaults() { if c.GenerateSubscriptionName == nil { c.GenerateSubscriptionName = TopicSubscriptionName } + if c.InitializeTimeout == 0 { + c.InitializeTimeout = time.Second * 10 + } if c.Unmarshaler == nil { c.Unmarshaler = DefaultMarshalerUnmarshaler{} } @@ -184,7 +190,7 @@ func (s *Subscriber) Subscribe(ctx context.Context, topic string) (<-chan *messa } func (s *Subscriber) SubscribeInitialize(topic string) (err error) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), s.config.InitializeTimeout) defer cancel() subscriptionName := s.config.GenerateSubscriptionName(topic)