Skip to content

feat: support connection lifetime for single client #727

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
15 changes: 15 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ func (c *singleClient) Do(ctx context.Context, cmd Completed) (resp RedisResult)
attempts := 1
retry:
resp = c.conn.Do(ctx, cmd)
if resp.Error() == errConnExpired {
goto retry
}
if c.retry && cmd.IsReadOnly() && c.isRetryable(resp.Error(), ctx) {
shouldRetry := c.retryHandler.WaitOrSkipRetry(
ctx, attempts, cmd, resp.Error(),
Expand Down Expand Up @@ -86,6 +89,9 @@ func (c *singleClient) DoMulti(ctx context.Context, multi ...Completed) (resps [
attempts := 1
retry:
resps = c.conn.DoMulti(ctx, multi...).s
if resps[0].Error() == errConnExpired {
goto retry
}
if c.retry && allReadOnly(multi) {
for i, resp := range resps {
if c.isRetryable(resp.Error(), ctx) {
Expand Down Expand Up @@ -114,6 +120,9 @@ func (c *singleClient) DoMultiCache(ctx context.Context, multi ...CacheableTTL)
attempts := 1
retry:
resps = c.conn.DoMultiCache(ctx, multi...).s
if resps[0].Error() == errConnExpired {
goto retry
}
if c.retry {
for i, resp := range resps {
if c.isRetryable(resp.Error(), ctx) {
Expand All @@ -139,6 +148,9 @@ func (c *singleClient) DoCache(ctx context.Context, cmd Cacheable, ttl time.Dura
attempts := 1
retry:
resp = c.conn.DoCache(ctx, cmd, ttl)
if resp.Error() == errConnExpired {
goto retry
}
if c.retry && c.isRetryable(resp.Error(), ctx) {
shouldRetry := c.retryHandler.WaitOrSkipRetry(ctx, attempts, Completed(cmd), resp.Error())
if shouldRetry {
Expand All @@ -156,6 +168,9 @@ func (c *singleClient) Receive(ctx context.Context, subscribe Completed, fn func
attempts := 1
retry:
err = c.conn.Receive(ctx, subscribe, fn)
if err == errConnExpired {
goto retry
}
if c.retry {
if _, ok := err.(*RedisError); !ok && c.isRetryable(err, ctx) {
shouldRetry := c.retryHandler.WaitOrSkipRetry(ctx, attempts, subscribe, err)
Expand Down
16 changes: 16 additions & 0 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,8 @@ type mockWire struct {
VersionFn func() int
ErrorFn func() error
CloseFn func()
StopTimerFn func() bool
ResetTimerFn func() bool

CleanSubscriptionsFn func()
SetPubSubHooksFn func(hooks PubSubHooks) <-chan error
Expand Down Expand Up @@ -1205,6 +1207,20 @@ func (m *mockWire) SetOnCloseHook(fn func(error)) {
}
}

func (m *mockWire) StopTimer() bool {
if m.StopTimerFn != nil {
return m.StopTimerFn()
}
return true
}

func (m *mockWire) ResetTimer() bool {
if m.ResetTimerFn != nil {
return m.ResetTimerFn()
}
return true
}

func (m *mockWire) Info() map[string]RedisMessage {
if m.InfoFn != nil {
return m.InfoFn()
Expand Down
33 changes: 32 additions & 1 deletion pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ type wire interface {
CleanSubscriptions()
SetPubSubHooks(hooks PubSubHooks) <-chan error
SetOnCloseHook(fn func(error))
StopTimer() bool
ResetTimer() bool
}

var _ wire = (*pipe)(nil)
Expand Down Expand Up @@ -88,6 +90,8 @@ type pipe struct {
bgState int32
r2ps bool // identify this pipe is used for resp2 pubsub or not
noNoDelay bool
lftm time.Duration // lifetime
lftmTimer *time.Timer // lifetime timer
optIn bool
}

Expand Down Expand Up @@ -328,6 +332,10 @@ func _newPipe(ctx context.Context, connFn func(context.Context) (net.Conn, error
p.backgroundPing()
}
}
if option.ConnLifetime > 0 {
p.lftm = option.ConnLifetime
p.lftmTimer = time.AfterFunc(option.ConnLifetime, p.expired)
}
return p, nil
}

Expand All @@ -344,6 +352,7 @@ func (p *pipe) _exit(err error) {
p.error.CompareAndSwap(nil, &errs{error: err})
atomic.CompareAndSwapInt32(&p.state, 1, 2) // stop accepting new requests
_ = p.conn.Close() // force both read & write goroutine to exit
p.StopTimer()
p.clhks.Load().(func(error))(err)
}

Expand Down Expand Up @@ -1633,6 +1642,25 @@ func (p *pipe) Close() {
p.r2mu.Unlock()
}

func (p *pipe) StopTimer() bool {
if p.lftmTimer == nil {
return true
}
return p.lftmTimer.Stop()
}

func (p *pipe) ResetTimer() bool {
if p.lftmTimer == nil || p.Error() != nil {
return true
}
return p.lftmTimer.Reset(p.lftm)
}

func (p *pipe) expired() {
p.error.CompareAndSwap(nil, errExpired)
p.Close()
}

type pshks struct {
hooks PubSubHooks
close chan error
Expand Down Expand Up @@ -1672,6 +1700,9 @@ const (
)

var cacheMark = &(RedisMessage{})
var errClosing = &errs{error: ErrClosing}
var (
errClosing = &errs{error: ErrClosing}
errExpired = &errs{error: errConnExpired}
)

type errs struct{ error }
60 changes: 60 additions & 0 deletions pipe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2773,6 +2773,66 @@ func TestOnInvalidations(t *testing.T) {
}
}

func TestConnLifetime(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())

t.Run("Enabled ConnLifetime", func(t *testing.T) {
p, _, _, closeConn := setup(t, ClientOption{
ConnLifetime: 50 * time.Millisecond,
})
defer closeConn()

if p.Error() != nil {
t.Fatalf("unexpected error %v", p.Error())
}
time.Sleep(60 * time.Millisecond)
if p.Error() != errConnExpired {
t.Fatalf("unexpected error, expected: %v, got: %v", errConnExpired, p.Error())
}
})

t.Run("Disabled ConnLifetime", func(t *testing.T) {
p, _, _, closeConn := setup(t, ClientOption{})
defer closeConn()

time.Sleep(60 * time.Millisecond)
if p.Error() != nil {
t.Fatalf("unexpected error %v", p.Error())
}
})

t.Run("StopTimer", func(t *testing.T) {
p, _, _, closeConn := setup(t, ClientOption{
ConnLifetime: 50 * time.Millisecond,
})
defer closeConn()

p.StopTimer()
time.Sleep(60 * time.Millisecond)
if p.Error() != nil {
t.Fatalf("unexpected error %v", p.Error())
}
})

t.Run("ResetTimer", func(t *testing.T) {
p, _, _, closeConn := setup(t, ClientOption{
ConnLifetime: 50 * time.Millisecond,
})
defer closeConn()

time.Sleep(20 * time.Millisecond)
p.ResetTimer()
time.Sleep(40 * time.Millisecond)
if p.Error() != nil {
t.Fatalf("unexpected error %v", p.Error())
}
time.Sleep(20 * time.Millisecond)
if p.Error() != errConnExpired {
t.Fatalf("unexpected error, expected: %v, got: %v", errConnExpired, p.Error())
}
})
}

func TestMultiHalfErr(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
p, mock, _, closeConn := setup(t, ClientOption{})
Expand Down
4 changes: 3 additions & 1 deletion pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,15 @@ retry:
// allowing others to make wires concurrently instead of waiting in line
p.cond.L.Unlock()
v = p.make(ctx)
v.StopTimer()
return v
}

i := len(p.list) - 1
v = p.list[i]
p.list[i] = nil
p.list = p.list[:i]
if v.Error() != nil {
if !v.StopTimer() || v.Error() != nil {
p.size--
v.Close()
goto retry
Expand All @@ -102,6 +103,7 @@ func (p *pool) Store(v wire) {
if !p.down && v.Error() == nil {
p.list = append(p.list, v)
p.startTimerIfNeeded()
v.ResetTimer()
} else {
p.size--
v.Close()
Expand Down
77 changes: 77 additions & 0 deletions pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,83 @@ func TestPoolWithIdleTTL(t *testing.T) {
})
}

func TestPoolWithConnLifetime(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
setup := func(wires []wire) *pool {
var count int32
return newPool(len(wires), dead, 0, 0, func(ctx context.Context) wire {
idx := atomic.AddInt32(&count, 1) - 1
return wires[idx]
})
}

t.Run("Reuse without expired connections", func(t *testing.T) {
stopTimerCall := 0
wires := []wire{
&mockWire{},
&mockWire{
StopTimerFn: func() bool {
stopTimerCall++
return false
}, // connection lifetime timer is already fired
},
}
conn := make([]wire, 0, len(wires))
pool := setup(wires)
for i := 0; i < len(wires); i++ {
conn = append(conn, pool.Acquire(context.Background()))
}
for i := 0; i < len(conn); i++ {
pool.Store(conn[i])
}

if stopTimerCall != 1 {
t.Errorf("StopTimer must be called when making wire")
}

pool.cond.L.Lock()
if pool.size != 2 {
t.Errorf("size must be equal to 2, actual: %d", pool.size)
}
if len(pool.list) != 2 {
t.Errorf("list len must equal to 2, actual: %d", len(pool.list))
}
pool.cond.L.Unlock()

// stop timer failed, so drop the expired connection
pool.Store(pool.Acquire(context.Background()))

if stopTimerCall != 2 {
t.Errorf("StopTimer must be called when acquiring from pool")
}

pool.cond.L.Lock()
if pool.size != 1 {
t.Errorf("size must be equal to 1, actual: %d", pool.size)
}
if len(pool.list) != 1 {
t.Errorf("list len must equal to 1, actual: %d", len(pool.list))
}
pool.cond.L.Unlock()
})

t.Run("Reset timer when storing to pool", func(t *testing.T) {
call := false
w := &mockWire{
ResetTimerFn: func() bool {
call = true
return true
},
}
pool := setup([]wire{w})
pool.Store(pool.Acquire(context.Background()))

if !call {
t.Error("ResetTimer must be called when storing")
}
})
}

func TestPoolWithAcquireCtx(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
setup := func(size int, delay time.Duration) *pool {
Expand Down
9 changes: 9 additions & 0 deletions rueidis.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ type ClientOption struct {
// This default is ClientOption.Dialer.KeepAlive * (9+1), where 9 is the default of tcp_keepalive_probes on Linux.
ConnWriteTimeout time.Duration

// ConnLiftime is lifetime for each connection. If specified,
// connections will close after passing lifetime. Note that the connection which dedicated client and blocking use is not closed.
ConnLifetime time.Duration

// MaxFlushDelay when greater than zero pauses pipeline write loop for some time (not larger than MaxFlushDelay)
// after each flushing of data to the connection. This gives pipeline a chance to collect more commands to send
// to Redis. Adding this delay increases latency, reduces throughput – but in most cases may significantly reduce
Expand Down Expand Up @@ -505,3 +509,8 @@ func dial(ctx context.Context, dst string, opt *ClientOption) (conn net.Conn, er
}

const redisErrMsgCommandNotAllow = "command is not allowed"

var (
// errConnExpired means wrong connection that ClientOption.ConnLifetime had passed since connecting
errConnExpired = errors.New("connection is expired")
)