diff --git a/internal/grpcsync/pubsub.go b/internal/grpcsync/pubsub.go new file mode 100644 index 000000000000..f58b5ffa6b1e --- /dev/null +++ b/internal/grpcsync/pubsub.go @@ -0,0 +1,136 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpcsync + +import ( + "context" + "sync" +) + +// Subscriber represents an entity that is subscribed to messages published on +// a PubSub. It wraps the callback to be invoked by the PubSub when a new +// message is published. +type Subscriber interface { + // OnMessage is invoked when a new message is published. Implementations + // must not block in this method. + OnMessage(msg interface{}) +} + +// PubSub is a simple one-to-many publish-subscribe system that supports +// messages of arbitrary type. It guarantees that messages are delivered in +// the same order in which they were published. +// +// Publisher invokes the Publish() method to publish new messages, while +// subscribers interested in receiving these messages register a callback +// via the Subscribe() method. +// +// Once a PubSub is stopped, no more messages can be published, and +// it is guaranteed that no more subscriber callback will be invoked. +type PubSub struct { + cs *CallbackSerializer + cancel context.CancelFunc + + // Access to the below fields are guarded by this mutex. + mu sync.Mutex + msg interface{} + subscribers map[Subscriber]bool + stopped bool +} + +// NewPubSub returns a new PubSub instance. +func NewPubSub() *PubSub { + ctx, cancel := context.WithCancel(context.Background()) + return &PubSub{ + cs: NewCallbackSerializer(ctx), + cancel: cancel, + subscribers: map[Subscriber]bool{}, + } +} + +// Subscribe registers the provided Subscriber to the PubSub. +// +// If the PubSub contains a previously published message, the Subscriber's +// OnMessage() callback will be invoked asynchronously with the existing +// message to begin with, and subsequently for every newly published message. +// +// The caller is responsible for invoking the returned cancel function to +// unsubscribe itself from the PubSub. +func (ps *PubSub) Subscribe(sub Subscriber) (cancel func()) { + ps.mu.Lock() + defer ps.mu.Unlock() + + if ps.stopped { + return func() {} + } + + ps.subscribers[sub] = true + + if ps.msg != nil { + msg := ps.msg + ps.cs.Schedule(func(context.Context) { + ps.mu.Lock() + defer ps.mu.Unlock() + if !ps.subscribers[sub] { + return + } + sub.OnMessage(msg) + }) + } + + return func() { + ps.mu.Lock() + defer ps.mu.Unlock() + delete(ps.subscribers, sub) + } +} + +// Publish publishes the provided message to the PubSub, and invokes +// callbacks registered by subscribers asynchronously. +func (ps *PubSub) Publish(msg interface{}) { + ps.mu.Lock() + defer ps.mu.Unlock() + + if ps.stopped { + return + } + + ps.msg = msg + for sub := range ps.subscribers { + s := sub + ps.cs.Schedule(func(context.Context) { + ps.mu.Lock() + defer ps.mu.Unlock() + if !ps.subscribers[s] { + return + } + s.OnMessage(msg) + }) + } +} + +// Stop shuts down the PubSub and releases any resources allocated by it. +// It is guaranteed that no subscriber callbacks would be invoked once this +// method returns. +func (ps *PubSub) Stop() { + ps.mu.Lock() + defer ps.mu.Unlock() + ps.stopped = true + + ps.cancel() +} diff --git a/internal/grpcsync/pubsub_test.go b/internal/grpcsync/pubsub_test.go new file mode 100644 index 000000000000..9aebf3593a5b --- /dev/null +++ b/internal/grpcsync/pubsub_test.go @@ -0,0 +1,211 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpcsync + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +type testSubscriber struct { + mu sync.Mutex + msgs []int + onMsgCh chan struct{} +} + +func newTestSubscriber(chSize int) *testSubscriber { + return &testSubscriber{onMsgCh: make(chan struct{}, chSize)} +} + +func (ts *testSubscriber) OnMessage(msg interface{}) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.msgs = append(ts.msgs, msg.(int)) + select { + case ts.onMsgCh <- struct{}{}: + default: + } +} + +func (ts *testSubscriber) receivedMsgs() []int { + ts.mu.Lock() + defer ts.mu.Unlock() + + msgs := make([]int, len(ts.msgs)) + copy(msgs, ts.msgs) + + return msgs +} + +func (s) TestPubSub_PublishNoMsg(t *testing.T) { + pubsub := NewPubSub() + defer pubsub.Stop() + + ts := newTestSubscriber(1) + pubsub.Subscribe(ts) + + select { + case <-ts.onMsgCh: + t.Fatalf("Subscriber callback invoked when no message was published") + case <-time.After(defaultTestShortTimeout): + } +} + +func (s) TestPubSub_PublishMsgs_RegisterSubs_And_Stop(t *testing.T) { + pubsub := NewPubSub() + + const numPublished = 10 + + ts1 := newTestSubscriber(numPublished) + pubsub.Subscribe(ts1) + wantMsgs1 := []int{} + + var wg sync.WaitGroup + wg.Add(2) + // Publish ten messages on the pubsub and ensure that they are received in order by the subscriber. + go func() { + for i := 0; i < numPublished; i++ { + pubsub.Publish(i) + wantMsgs1 = append(wantMsgs1, i) + } + wg.Done() + }() + + isTimeout := false + go func() { + for i := 0; i < numPublished; i++ { + select { + case <-ts1.onMsgCh: + case <-time.After(defaultTestTimeout): + isTimeout = true + } + } + wg.Done() + }() + + wg.Wait() + if isTimeout { + t.Fatalf("Timeout when expecting the onMessage() callback to be invoked") + } + if gotMsgs1 := ts1.receivedMsgs(); !cmp.Equal(gotMsgs1, wantMsgs1) { + t.Fatalf("Received messages is %v, want %v", gotMsgs1, wantMsgs1) + } + + // Register another subscriber and ensure that it receives the last published message. + ts2 := newTestSubscriber(numPublished) + pubsub.Subscribe(ts2) + wantMsgs2 := wantMsgs1[len(wantMsgs1)-1:] + + select { + case <-ts2.onMsgCh: + case <-time.After(defaultTestShortTimeout): + t.Fatalf("Timeout when expecting the onMessage() callback to be invoked") + } + if gotMsgs2 := ts2.receivedMsgs(); !cmp.Equal(gotMsgs2, wantMsgs2) { + t.Fatalf("Received messages is %v, want %v", gotMsgs2, wantMsgs2) + } + + wg.Add(3) + // Publish ten messages on the pubsub and ensure that they are received in order by the subscribers. + go func() { + for i := 0; i < numPublished; i++ { + pubsub.Publish(i) + wantMsgs1 = append(wantMsgs1, i) + wantMsgs2 = append(wantMsgs2, i) + } + wg.Done() + }() + errCh := make(chan error, 1) + go func() { + for i := 0; i < numPublished; i++ { + select { + case <-ts1.onMsgCh: + case <-time.After(defaultTestTimeout): + errCh <- fmt.Errorf("") + } + } + wg.Done() + }() + go func() { + for i := 0; i < numPublished; i++ { + select { + case <-ts2.onMsgCh: + case <-time.After(defaultTestTimeout): + errCh <- fmt.Errorf("") + } + } + wg.Done() + }() + wg.Wait() + select { + case <-errCh: + t.Fatalf("Timeout when expecting the onMessage() callback to be invoked") + default: + } + if gotMsgs1 := ts1.receivedMsgs(); !cmp.Equal(gotMsgs1, wantMsgs1) { + t.Fatalf("Received messages is %v, want %v", gotMsgs1, wantMsgs1) + } + if gotMsgs2 := ts2.receivedMsgs(); !cmp.Equal(gotMsgs2, wantMsgs2) { + t.Fatalf("Received messages is %v, want %v", gotMsgs2, wantMsgs2) + } + + pubsub.Stop() + + go func() { + pubsub.Publish(99) + }() + // Ensure that the subscriber callback is not invoked as instantiated + // pubsub has already closed. + select { + case <-ts1.onMsgCh: + t.Fatalf("The callback was invoked after pubsub being stopped") + case <-ts2.onMsgCh: + t.Fatalf("The callback was invoked after pubsub being stopped") + case <-time.After(defaultTestShortTimeout): + } +} + +func (s) TestPubSub_PublishMsgs_BeforeRegisterSub(t *testing.T) { + pubsub := NewPubSub() + defer pubsub.Stop() + + const numPublished = 3 + for i := 0; i < numPublished; i++ { + pubsub.Publish(i) + } + + ts := newTestSubscriber(numPublished) + pubsub.Subscribe(ts) + + wantMsgs := []int{numPublished - 1} + // Ensure that the subscriber callback is invoked with a previously + // published message. + select { + case <-ts.onMsgCh: + if gotMsgs := ts.receivedMsgs(); !cmp.Equal(gotMsgs, wantMsgs) { + t.Fatalf("Received messages is %v, want %v", gotMsgs, wantMsgs) + } + case <-time.After(defaultTestShortTimeout): + t.Fatalf("Timeout when expecting the onMessage() callback to be invoked") + } +}