sharding: Fix data race in shard rate limiter
apricotbucket28 committed Feb 25, 2025
1 parent 84c71ee commit c066ac3
Showing 3 changed files with 120 additions and 44 deletions.
1 change: 1 addition & 0 deletions sharding/shard_rate_limiter.go
@@ -18,6 +18,7 @@ type RateLimiter interface {
 	WaitBucket(ctx context.Context, shardID int) error
 
 	// UnlockBucket unlocks the given shardID bucket.
+	// If WaitBucket fails, UnlockBucket should not be called.
 	UnlockBucket(shardID int)
 }

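The added comment pins down the calling contract: a successful WaitBucket leaves the shard's bucket locked, and UnlockBucket releases it and starts the next identify window, so it must only be called after WaitBucket returned nil. A minimal caller sketch of that contract (a hypothetical helper, not part of this commit, assumed to sit in the same sharding package so only context is imported):

```go
// identifyShard is a hypothetical sketch: wait for an identify slot,
// identify, and unlock only on success, as the new doc comment requires.
func identifyShard(ctx context.Context, r RateLimiter, shardID int) error {
	if err := r.WaitBucket(ctx, shardID); err != nil {
		return err // WaitBucket failed: do not call UnlockBucket
	}
	// ... send the identify payload for shardID here ...
	r.UnlockBucket(shardID) // releases the bucket and starts the next identify window
	return nil
}
```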
82 changes: 38 additions & 44 deletions sharding/shard_rate_limiter_impl.go
@@ -9,6 +9,9 @@ import (
 	"github.com/sasha-s/go-csync"
 )
 
+// identifyWait is the duration to wait in between identifying shards.
+var identifyWait = 5 * time.Second
+
 var _ RateLimiter = (*rateLimiterImpl)(nil)
 
 // NewRateLimiter creates a new default RateLimiter with the given RateLimiterConfigOpt(s).
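identifyWait becomes a package-level variable rather than a constant so the tests added below can shorten it in init(); the 5-second default matches the gateway's identify rate limit of roughly one identify per bucket every 5 seconds (stated here from the gateway docs, not from this diff).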
@@ -18,7 +21,7 @@ func NewRateLimiter(opts ...RateLimiterConfigOpt) RateLimiter {
 	config.Logger = config.Logger.With(slog.String("name", "sharding_rate_limiter"))
 
 	return &rateLimiterImpl{
-		buckets: map[int]*bucket{},
+		buckets: make(map[int]*bucket),
 		config:  *config,
 	}
 }
@@ -31,6 +34,7 @@ type rateLimiterImpl struct {
 }
 
 func (r *rateLimiterImpl) Close(ctx context.Context) {
+	r.config.Logger.Debug("closing shard rate limiter")
 	var wg sync.WaitGroup
 	r.mu.Lock()
 
@@ -45,74 +49,64 @@ func (r *rateLimiterImpl) Close(ctx context.Context) {
 			b.mu.Unlock()
 		}()
 	}
+	wg.Wait()
 }
 
-func (r *rateLimiterImpl) getBucket(shardID int, create bool) *bucket {
+func (r *rateLimiterImpl) getBucket(shardID int) *bucket {
 	r.config.Logger.Debug("locking shard rate limiter")
 	r.mu.Lock()
-	defer func() {
-		r.config.Logger.Debug("unlocking shard rate limiter")
-		r.mu.Unlock()
-	}()
+	defer r.mu.Unlock()
+
 	key := ShardMaxConcurrencyKey(shardID, r.config.MaxConcurrency)
-	b, ok := r.buckets[key]
-	if !ok {
-		if !create {
-			return nil
-		}
-
-		b = &bucket{
-			Key: key,
-		}
-		r.buckets[key] = b
+	if b, ok := r.buckets[key]; ok {
+		return b
 	}
+
+	b := &bucket{
+		key: key,
+	}
+	r.buckets[key] = b
 	return b
 }
 
 func (r *rateLimiterImpl) WaitBucket(ctx context.Context, shardID int) error {
-	b := r.getBucket(shardID, true)
-	r.config.Logger.Debug("locking shard bucket", slog.Int("key", b.Key), slog.Time("reset", b.Reset))
+	key := ShardMaxConcurrencyKey(shardID, r.config.MaxConcurrency)
+	r.config.Logger.Debug("locking shard bucket", slog.Int("key", key))
+
+	b := r.getBucket(shardID)
 	if err := b.mu.CLock(ctx); err != nil {
 		return err
 	}
 
-	var until time.Time
 	now := time.Now()
-
-	if b.Reset.After(now) {
-		until = b.Reset
+	if b.reset.Before(now) {
+		return nil
 	}
 
-	if until.After(now) {
-		if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
-			return context.DeadlineExceeded
-		}
+	if deadline, ok := ctx.Deadline(); ok && b.reset.After(deadline) {
+		return context.DeadlineExceeded
+	}
 
-		select {
-		case <-ctx.Done():
-			b.mu.Unlock()
-			return ctx.Err()
-		case <-time.After(until.Sub(now)):
-		}
+	select {
+	case <-ctx.Done():
+		b.mu.Unlock()
+		return ctx.Err()
+	case <-time.After(b.reset.Sub(now)):
+		return nil
 	}
-	return nil
 }
 
 func (r *rateLimiterImpl) UnlockBucket(shardID int) {
-	b := r.getBucket(shardID, false)
-	if b == nil {
-		return
-	}
-	defer func() {
-		r.config.Logger.Debug("unlocking shard bucket", slog.Int("key", b.Key), slog.Time("reset", b.Reset))
-		b.mu.Unlock()
-	}()
+	b := r.getBucket(shardID)
 
-	b.Reset = time.Now().Add(5 * time.Second)
+	b.reset = time.Now().Add(identifyWait)
+	r.config.Logger.Debug("unlocking shard bucket", slog.Int("key", b.key), slog.Time("reset", b.reset))
+	b.mu.Unlock()
 }
 
+// bucket represents a rate-limiting bucket for a shard group.
 type bucket struct {
 	mu    csync.Mutex
-	Key   int
-	Reset time.Time
+	key   int
+	reset time.Time
 }
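What the large hunk fixes: the old WaitBucket read b.Key and b.Reset for its debug log before acquiring the bucket mutex, while UnlockBucket wrote Reset with that mutex held (it stays locked from the previous successful WaitBucket), so one goroutine could read the field while another wrote it. The new version logs only the key, which is computed from the shard ID, reads b.reset only after CLock succeeds, unexports the bucket fields, replaces the hard-coded 5 seconds with identifyWait, and makes Close block on wg.Wait() until every bucket has been locked. Below is a self-contained sketch of the same class of race, not code from this repository, in the shape the Go race detector reports when run with go run -race:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// bucket mirrors the shape of the type above: a time.Time guarded by a mutex.
type bucket struct {
	mu    sync.Mutex
	reset time.Time
}

func main() {
	b := &bucket{}
	var wg sync.WaitGroup
	wg.Add(2)

	// Writer: like UnlockBucket, updates reset while holding the lock.
	go func() {
		defer wg.Done()
		b.mu.Lock()
		b.reset = time.Now().Add(5 * time.Second)
		b.mu.Unlock()
	}()

	// Reader: like the old pre-lock Debug call, reads reset without the lock.
	go func() {
		defer wg.Done()
		fmt.Println("reset:", b.reset)
	}()

	wg.Wait()
}
```

In the fixed WaitBucket the only work done before locking is computing the key from the shard ID, so there is no shared state left to race on.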
81 changes: 81 additions & 0 deletions sharding/shard_rate_limiter_impl_test.go
@@ -0,0 +1,81 @@
+package sharding
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func init() {
+	identifyWait = 100 * time.Millisecond
+}
+
+func TestShardRateLimiterImpl(t *testing.T) {
+	t.Parallel()
+
+	r := NewRateLimiter()
+
+	start := time.Now()
+
+	var wg sync.WaitGroup
+	for i := 0; i < 3; i++ {
+		shardID := i
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			err := r.WaitBucket(context.Background(), shardID)
+			assert.NoError(t, err)
+			r.UnlockBucket(shardID)
+		}()
+	}
+	wg.Wait()
+
+	expected := start.Add(200 * time.Millisecond)
+	assert.WithinDuration(t, expected, time.Now(), 10*time.Millisecond)
+}
+
+func TestShardRateLimiterImpl_WithMaxConcurrency(t *testing.T) {
+	t.Parallel()
+
+	r := NewRateLimiter(WithMaxConcurrency(3))
+
+	start := time.Now()
+
+	var wg sync.WaitGroup
+	for i := 0; i < 6; i++ {
+		shardID := i
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			err := r.WaitBucket(context.Background(), shardID)
+			assert.NoError(t, err)
+			r.UnlockBucket(shardID)
+		}()
+	}
+	wg.Wait()
+
+	expected := start.Add(100 * time.Millisecond)
+	assert.WithinDuration(t, expected, time.Now(), 10*time.Millisecond)
+}
+
+func TestShardRateLimiterImpl_WaitBucketWithTimeout(t *testing.T) {
+	t.Parallel()
+
+	r := NewRateLimiter()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+	defer cancel()
+
+	err := r.WaitBucket(ctx, 0)
+	assert.NoError(t, err)
+
+	err = r.WaitBucket(ctx, 0)
+	if assert.Error(t, err) {
+		assert.Equal(t, context.DeadlineExceeded, err)
+	}
+
+	r.UnlockBucket(0)
+}
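A few notes on the tests: identifyWait is overridden to 100ms in init() to keep the suite fast, which is why it was made a variable. In the first test the default max concurrency of 1 puts all three shards in one bucket, so the last identify completes about two identifyWait periods (200ms) after the start; with WithMaxConcurrency(3) the six shards spread across three buckets (keyed, as the timing implies, by the shard ID modulo the max concurrency via ShardMaxConcurrencyKey), so each bucket identifies twice and the whole run takes roughly one period (100ms). In the timeout test the first WaitBucket succeeds and leaves the bucket mutex held, so the second call cannot acquire it within the 10ms deadline and returns context.DeadlineExceeded; the closing UnlockBucket releases the lock taken by the first call. Running these tests under the race detector (go test -race) is the natural way to confirm the original unsynchronized read is gone.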
