Skip to content

Commit

Permalink
internal/grpcsync: Provide an internal-only pub-sub type API (#6167)
Browse files Browse the repository at this point in the history
Co-authored-by: Easwar Swaminathan <easwars@google.com>
  • Loading branch information
my4-dev and easwars authored Jun 30, 2023
1 parent 620a118 commit 51042db
Show file tree
Hide file tree
Showing 2 changed files with 347 additions and 0 deletions.
136 changes: 136 additions & 0 deletions internal/grpcsync/pubsub.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package grpcsync

import (
"context"
"sync"
)

// Subscriber represents an entity that is subscribed to messages published on
// a PubSub. It wraps the callback to be invoked by the PubSub when a new
// message is published.
type Subscriber interface {
// OnMessage is invoked when a new message is published. Implementations
// must not block in this method.
OnMessage(msg interface{})
}

// PubSub is a simple one-to-many publish-subscribe system that supports
// messages of arbitrary type. It guarantees that messages are delivered in
// the same order in which they were published.
//
// Publisher invokes the Publish() method to publish new messages, while
// subscribers interested in receiving these messages register a callback
// via the Subscribe() method.
//
// Once a PubSub is stopped, no more messages can be published, and
// it is guaranteed that no more subscriber callback will be invoked.
type PubSub struct {
cs *CallbackSerializer
cancel context.CancelFunc

// Access to the below fields are guarded by this mutex.
mu sync.Mutex
msg interface{}
subscribers map[Subscriber]bool
stopped bool
}

// NewPubSub returns a new PubSub instance.
func NewPubSub() *PubSub {
ctx, cancel := context.WithCancel(context.Background())
return &PubSub{
cs: NewCallbackSerializer(ctx),
cancel: cancel,
subscribers: map[Subscriber]bool{},
}
}

// Subscribe registers the provided Subscriber to the PubSub.
//
// If the PubSub contains a previously published message, the Subscriber's
// OnMessage() callback will be invoked asynchronously with the existing
// message to begin with, and subsequently for every newly published message.
//
// The caller is responsible for invoking the returned cancel function to
// unsubscribe itself from the PubSub.
func (ps *PubSub) Subscribe(sub Subscriber) (cancel func()) {
ps.mu.Lock()
defer ps.mu.Unlock()

if ps.stopped {
return func() {}
}

ps.subscribers[sub] = true

if ps.msg != nil {
msg := ps.msg
ps.cs.Schedule(func(context.Context) {
ps.mu.Lock()
defer ps.mu.Unlock()
if !ps.subscribers[sub] {
return
}
sub.OnMessage(msg)
})
}

return func() {
ps.mu.Lock()
defer ps.mu.Unlock()
delete(ps.subscribers, sub)
}
}

// Publish publishes the provided message to the PubSub, and invokes
// callbacks registered by subscribers asynchronously.
func (ps *PubSub) Publish(msg interface{}) {
ps.mu.Lock()
defer ps.mu.Unlock()

if ps.stopped {
return
}

ps.msg = msg
for sub := range ps.subscribers {
s := sub
ps.cs.Schedule(func(context.Context) {
ps.mu.Lock()
defer ps.mu.Unlock()
if !ps.subscribers[s] {
return
}
s.OnMessage(msg)
})
}
}

// Stop shuts down the PubSub and releases any resources allocated by it.
// It is guaranteed that no subscriber callbacks would be invoked once this
// method returns.
func (ps *PubSub) Stop() {
ps.mu.Lock()
defer ps.mu.Unlock()
ps.stopped = true

ps.cancel()
}
211 changes: 211 additions & 0 deletions internal/grpcsync/pubsub_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package grpcsync

import (
"fmt"
"sync"
"testing"
"time"

"github.com/google/go-cmp/cmp"
)

type testSubscriber struct {
mu sync.Mutex
msgs []int
onMsgCh chan struct{}
}

func newTestSubscriber(chSize int) *testSubscriber {
return &testSubscriber{onMsgCh: make(chan struct{}, chSize)}
}

func (ts *testSubscriber) OnMessage(msg interface{}) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.msgs = append(ts.msgs, msg.(int))
select {
case ts.onMsgCh <- struct{}{}:
default:
}
}

func (ts *testSubscriber) receivedMsgs() []int {
ts.mu.Lock()
defer ts.mu.Unlock()

msgs := make([]int, len(ts.msgs))
copy(msgs, ts.msgs)

return msgs
}

func (s) TestPubSub_PublishNoMsg(t *testing.T) {
pubsub := NewPubSub()
defer pubsub.Stop()

ts := newTestSubscriber(1)
pubsub.Subscribe(ts)

select {
case <-ts.onMsgCh:
t.Fatalf("Subscriber callback invoked when no message was published")
case <-time.After(defaultTestShortTimeout):
}
}

func (s) TestPubSub_PublishMsgs_RegisterSubs_And_Stop(t *testing.T) {
pubsub := NewPubSub()

const numPublished = 10

ts1 := newTestSubscriber(numPublished)
pubsub.Subscribe(ts1)
wantMsgs1 := []int{}

var wg sync.WaitGroup
wg.Add(2)
// Publish ten messages on the pubsub and ensure that they are received in order by the subscriber.
go func() {
for i := 0; i < numPublished; i++ {
pubsub.Publish(i)
wantMsgs1 = append(wantMsgs1, i)
}
wg.Done()
}()

isTimeout := false
go func() {
for i := 0; i < numPublished; i++ {
select {
case <-ts1.onMsgCh:
case <-time.After(defaultTestTimeout):
isTimeout = true
}
}
wg.Done()
}()

wg.Wait()
if isTimeout {
t.Fatalf("Timeout when expecting the onMessage() callback to be invoked")
}
if gotMsgs1 := ts1.receivedMsgs(); !cmp.Equal(gotMsgs1, wantMsgs1) {
t.Fatalf("Received messages is %v, want %v", gotMsgs1, wantMsgs1)
}

// Register another subscriber and ensure that it receives the last published message.
ts2 := newTestSubscriber(numPublished)
pubsub.Subscribe(ts2)
wantMsgs2 := wantMsgs1[len(wantMsgs1)-1:]

select {
case <-ts2.onMsgCh:
case <-time.After(defaultTestShortTimeout):
t.Fatalf("Timeout when expecting the onMessage() callback to be invoked")
}
if gotMsgs2 := ts2.receivedMsgs(); !cmp.Equal(gotMsgs2, wantMsgs2) {
t.Fatalf("Received messages is %v, want %v", gotMsgs2, wantMsgs2)
}

wg.Add(3)
// Publish ten messages on the pubsub and ensure that they are received in order by the subscribers.
go func() {
for i := 0; i < numPublished; i++ {
pubsub.Publish(i)
wantMsgs1 = append(wantMsgs1, i)
wantMsgs2 = append(wantMsgs2, i)
}
wg.Done()
}()
errCh := make(chan error, 1)
go func() {
for i := 0; i < numPublished; i++ {
select {
case <-ts1.onMsgCh:
case <-time.After(defaultTestTimeout):
errCh <- fmt.Errorf("")
}
}
wg.Done()
}()
go func() {
for i := 0; i < numPublished; i++ {
select {
case <-ts2.onMsgCh:
case <-time.After(defaultTestTimeout):
errCh <- fmt.Errorf("")
}
}
wg.Done()
}()
wg.Wait()
select {
case <-errCh:
t.Fatalf("Timeout when expecting the onMessage() callback to be invoked")
default:
}
if gotMsgs1 := ts1.receivedMsgs(); !cmp.Equal(gotMsgs1, wantMsgs1) {
t.Fatalf("Received messages is %v, want %v", gotMsgs1, wantMsgs1)
}
if gotMsgs2 := ts2.receivedMsgs(); !cmp.Equal(gotMsgs2, wantMsgs2) {
t.Fatalf("Received messages is %v, want %v", gotMsgs2, wantMsgs2)
}

pubsub.Stop()

go func() {
pubsub.Publish(99)
}()
// Ensure that the subscriber callback is not invoked as instantiated
// pubsub has already closed.
select {
case <-ts1.onMsgCh:
t.Fatalf("The callback was invoked after pubsub being stopped")
case <-ts2.onMsgCh:
t.Fatalf("The callback was invoked after pubsub being stopped")
case <-time.After(defaultTestShortTimeout):
}
}

func (s) TestPubSub_PublishMsgs_BeforeRegisterSub(t *testing.T) {
pubsub := NewPubSub()
defer pubsub.Stop()

const numPublished = 3
for i := 0; i < numPublished; i++ {
pubsub.Publish(i)
}

ts := newTestSubscriber(numPublished)
pubsub.Subscribe(ts)

wantMsgs := []int{numPublished - 1}
// Ensure that the subscriber callback is invoked with a previously
// published message.
select {
case <-ts.onMsgCh:
if gotMsgs := ts.receivedMsgs(); !cmp.Equal(gotMsgs, wantMsgs) {
t.Fatalf("Received messages is %v, want %v", gotMsgs, wantMsgs)
}
case <-time.After(defaultTestShortTimeout):
t.Fatalf("Timeout when expecting the onMessage() callback to be invoked")
}
}

0 comments on commit 51042db

Please # to comment.