From c066ac3c1369ccb0d8af48235de25a35d7085546 Mon Sep 17 00:00:00 2001
From: apricotbucket28
Date: Tue, 25 Feb 2025 14:03:21 -0300
Subject: [PATCH] sharding: Fix data race in shard rate limiter

---
 sharding/shard_rate_limiter.go           |  1 +
 sharding/shard_rate_limiter_impl.go      | 82 +++++++++++-------------
 sharding/shard_rate_limiter_impl_test.go | 81 ++++++++++++++++++++++++
 3 files changed, 120 insertions(+), 44 deletions(-)
 create mode 100644 sharding/shard_rate_limiter_impl_test.go

diff --git a/sharding/shard_rate_limiter.go b/sharding/shard_rate_limiter.go
index da21a281..ae81e2a7 100644
--- a/sharding/shard_rate_limiter.go
+++ b/sharding/shard_rate_limiter.go
@@ -18,6 +18,7 @@ type RateLimiter interface {
 	WaitBucket(ctx context.Context, shardID int) error
 
 	// UnlockBucket unlocks the given shardID bucket.
+	// If WaitBucket fails, UnlockBucket should not be called.
 	UnlockBucket(shardID int)
 }
 
diff --git a/sharding/shard_rate_limiter_impl.go b/sharding/shard_rate_limiter_impl.go
index afc0c3be..fa343329 100644
--- a/sharding/shard_rate_limiter_impl.go
+++ b/sharding/shard_rate_limiter_impl.go
@@ -9,6 +9,9 @@ import (
 	"github.com/sasha-s/go-csync"
 )
 
+// identifyWait is the duration to wait in between identifying shards.
+var identifyWait = 5 * time.Second
+
 var _ RateLimiter = (*rateLimiterImpl)(nil)
 
 // NewRateLimiter creates a new default RateLimiter with the given RateLimiterConfigOpt(s).
@@ -18,7 +21,7 @@ func NewRateLimiter(opts ...RateLimiterConfigOpt) RateLimiter {
 	config.Logger = config.Logger.With(slog.String("name", "sharding_rate_limiter"))
 
 	return &rateLimiterImpl{
-		buckets: map[int]*bucket{},
+		buckets: make(map[int]*bucket),
 		config:  *config,
 	}
 }
@@ -31,6 +34,7 @@ type rateLimiterImpl struct {
 }
 
 func (r *rateLimiterImpl) Close(ctx context.Context) {
+	r.config.Logger.Debug("closing shard rate limiter")
 	var wg sync.WaitGroup
 
 	r.mu.Lock()
@@ -45,74 +49,64 @@ func (r *rateLimiterImpl) Close(ctx context.Context) {
 			b.mu.Unlock()
 		}()
 	}
+	wg.Wait()
 }
 
-func (r *rateLimiterImpl) getBucket(shardID int, create bool) *bucket {
+func (r *rateLimiterImpl) getBucket(shardID int) *bucket {
 	r.config.Logger.Debug("locking shard rate limiter")
 	r.mu.Lock()
-	defer func() {
-		r.config.Logger.Debug("unlocking shard rate limiter")
-		r.mu.Unlock()
-	}()
+	defer r.mu.Unlock()
+
 	key := ShardMaxConcurrencyKey(shardID, r.config.MaxConcurrency)
-	b, ok := r.buckets[key]
-	if !ok {
-		if !create {
-			return nil
-		}
-
-		b = &bucket{
-			Key: key,
-		}
-		r.buckets[key] = b
+	if b, ok := r.buckets[key]; ok {
+		return b
 	}
+
+	b := &bucket{
+		key: key,
+	}
+	r.buckets[key] = b
 	return b
 }
 
 func (r *rateLimiterImpl) WaitBucket(ctx context.Context, shardID int) error {
-	b := r.getBucket(shardID, true)
-	r.config.Logger.Debug("locking shard bucket", slog.Int("key", b.Key), slog.Time("reset", b.Reset))
+	key := ShardMaxConcurrencyKey(shardID, r.config.MaxConcurrency)
+	r.config.Logger.Debug("locking shard bucket", slog.Int("key", key))
+
+	b := r.getBucket(shardID)
 	if err := b.mu.CLock(ctx); err != nil {
 		return err
 	}
 
-	var until time.Time
 	now := time.Now()
-
-	if b.Reset.After(now) {
-		until = b.Reset
+	if b.reset.Before(now) {
+		return nil
 	}
 
-	if until.After(now) {
-		if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
-			return context.DeadlineExceeded
-		}
+	if deadline, ok := ctx.Deadline(); ok && b.reset.After(deadline) {
+		b.mu.Unlock()
+		return context.DeadlineExceeded
+	}
 
-		select {
-		case <-ctx.Done():
-			b.mu.Unlock()
-			return ctx.Err()
-		case <-time.After(until.Sub(now)):
-		}
+	select {
+	case <-ctx.Done():
+		b.mu.Unlock()
+		return ctx.Err()
+	case <-time.After(b.reset.Sub(now)):
+		return nil
 	}
-	return nil
 }
 
 func (r *rateLimiterImpl) UnlockBucket(shardID int) {
-	b := r.getBucket(shardID, false)
-	if b == nil {
-		return
-	}
-	defer func() {
-		r.config.Logger.Debug("unlocking shard bucket", slog.Int("key", b.Key), slog.Time("reset", b.Reset))
-		b.mu.Unlock()
-	}()
+	b := r.getBucket(shardID)
 
-	b.Reset = time.Now().Add(5 * time.Second)
+	b.reset = time.Now().Add(identifyWait)
+	r.config.Logger.Debug("unlocking shard bucket", slog.Int("key", b.key), slog.Time("reset", b.reset))
+	b.mu.Unlock()
 }
 
+// bucket represents a rate-limiting bucket for a shard group.
 type bucket struct {
 	mu    csync.Mutex
-	Key   int
-	Reset time.Time
+	key   int
+	reset time.Time
 }
diff --git a/sharding/shard_rate_limiter_impl_test.go b/sharding/shard_rate_limiter_impl_test.go
new file mode 100644
index 00000000..b5eefa37
--- /dev/null
+++ b/sharding/shard_rate_limiter_impl_test.go
@@ -0,0 +1,81 @@
+package sharding
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func init() {
+	identifyWait = 100 * time.Millisecond
+}
+
+func TestShardRateLimiterImpl(t *testing.T) {
+	t.Parallel()
+
+	r := NewRateLimiter()
+
+	start := time.Now()
+
+	var wg sync.WaitGroup
+	for i := 0; i < 3; i++ {
+		shardID := i
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			err := r.WaitBucket(context.Background(), shardID)
+			assert.NoError(t, err)
+			r.UnlockBucket(shardID)
+		}()
+	}
+	wg.Wait()
+
+	expected := start.Add(200 * time.Millisecond)
+	assert.WithinDuration(t, expected, time.Now(), 10*time.Millisecond)
+}
+
+func TestShardRateLimiterImpl_WithMaxConcurrency(t *testing.T) {
+	t.Parallel()
+
+	r := NewRateLimiter(WithMaxConcurrency(3))
+
+	start := time.Now()
+
+	var wg sync.WaitGroup
+	for i := 0; i < 6; i++ {
+		shardID := i
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			err := r.WaitBucket(context.Background(), shardID)
+			assert.NoError(t, err)
+			r.UnlockBucket(shardID)
+		}()
+	}
+	wg.Wait()
+
+	expected := start.Add(100 * time.Millisecond)
+	assert.WithinDuration(t, expected, time.Now(), 10*time.Millisecond)
+}
+
+func TestShardRateLimiterImpl_WaitBucketWithTimeout(t *testing.T) {
+	t.Parallel()
+
+	r := NewRateLimiter()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+	defer cancel()
+
+	err := r.WaitBucket(ctx, 0)
+	assert.NoError(t, err)
+
+	err = r.WaitBucket(ctx, 0)
+	if assert.Error(t, err) {
+		assert.Equal(t, context.DeadlineExceeded, err)
+	}
+
+	r.UnlockBucket(0)
+}
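
Note on the race this patch removes: the old WaitBucket logged slog.Time("reset", b.Reset) before acquiring b.mu, while the old UnlockBucket wrote b.Reset while holding b.mu. An unguarded read concurrent with a guarded write is precisely what Go's race detector reports; the new code logs only the independently computed key before locking, and only touches b.reset under b.mu. The following is a standalone sketch of that old pattern, reconstructed for illustration and not code from this repository; run it with `go run -race` to see the report.

package main

import (
	"fmt"
	"sync"
	"time"
)

// bucket mirrors the shape of the rate limiter's bucket type.
type bucket struct {
	mu    sync.Mutex
	reset time.Time
}

func main() {
	b := &bucket{}
	var wg sync.WaitGroup
	wg.Add(2)

	// Mirrors the old WaitBucket: the debug log read the reset time
	// before acquiring b.mu, i.e. an unsynchronized read.
	go func() {
		defer wg.Done()
		fmt.Println("reset:", b.reset) // read without holding b.mu
		b.mu.Lock()
		defer b.mu.Unlock()
	}()

	// Mirrors the old UnlockBucket: the write happened under b.mu,
	// but nothing ordered it against the unguarded read above.
	go func() {
		defer wg.Done()
		b.mu.Lock()
		defer b.mu.Unlock()
		b.reset = time.Now().Add(5 * time.Second) // write under b.mu
	}()

	wg.Wait()
}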
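Note on the caller contract introduced by the new doc comment: WaitBucket now releases (or never acquires) the bucket lock on every error path, so a failed WaitBucket must not be followed by UnlockBucket, while every successful WaitBucket must be paired with exactly one UnlockBucket. A minimal caller-side sketch follows; it is not code from this repository, identifyShard is a hypothetical stand-in for sending the gateway Identify payload, and the import path and Close method are assumed from disgo's module layout.

package main

import (
	"context"
	"fmt"

	"github.com/disgoorg/disgo/sharding"
)

// identifyShard is a hypothetical stand-in for identifying one shard.
func identifyShard(shardID int) {
	fmt.Println("identifying shard", shardID)
}

func identifyShards(ctx context.Context, r sharding.RateLimiter, shardIDs []int) error {
	for _, shardID := range shardIDs {
		if err := r.WaitBucket(ctx, shardID); err != nil {
			// WaitBucket failed and did not leave the bucket locked,
			// so UnlockBucket must not be called on this path.
			return err
		}
		identifyShard(shardID)
		// A successful WaitBucket is paired with one UnlockBucket,
		// which stamps the bucket's next reset time and releases it
		// for the next shard sharing the same bucket.
		r.UnlockBucket(shardID)
	}
	return nil
}

func main() {
	r := sharding.NewRateLimiter()
	defer r.Close(context.Background())

	if err := identifyShards(context.Background(), r, []int{0, 1, 2}); err != nil {
		fmt.Println("identify failed:", err)
	}
}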