diff --git a/redis/pool.go b/redis/pool.go index 4a468cf0..c6c56f43 100644 --- a/redis/pool.go +++ b/redis/pool.go @@ -387,6 +387,7 @@ func (p *Pool) waitVacantConn(ctx context.Context) (waited time.Duration, err er // because `select` picks a random `case` if several of them are "ready". select { case <-ctx.Done(): + p.ch <- struct{}{} return 0, ctx.Err() default: } diff --git a/redis/pool_test.go b/redis/pool_test.go index a8efa515..a8a8c2ff 100644 --- a/redis/pool_test.go +++ b/redis/pool_test.go @@ -825,6 +825,30 @@ func TestWaitPoolGetContext(t *testing.T) { defer c.Close() } +func TestWaitPoolGetContextIssue520(t *testing.T) { + d := poolDialer{t: t} + p := &redis.Pool{ + MaxIdle: 1, + MaxActive: 1, + Dial: d.dial, + Wait: true, + } + defer p.Close() + ctx1, _ := context.WithTimeout(context.Background(), 1*time.Nanosecond) + c, err := p.GetContext(ctx1) + if err != context.DeadlineExceeded { + t.Fatalf("GetContext returned %v", err) + } + defer c.Close() + + ctx2, _ := context.WithCancel(context.Background()) + c2, err := p.GetContext(ctx2) + if err != nil { + t.Fatalf("Get context returned %v", err) + } + defer c2.Close() +} + func TestWaitPoolGetContextWithDialContext(t *testing.T) { d := poolDialer{t: t} p := &redis.Pool{