Skip to content

Commit

Permalink
feat: Add TLS Handshake timeout support (#530)
Browse files Browse the repository at this point in the history
* feat: Add TLS Handshake timeout support

Add support for configuring a timeout for TLS Handshake call via
DialTLSHandshakeTimeout DialOption. If no option is specified then the
default timeout is 10 seconds.

Also:
* Add a default connect timeout of 30 seconds matching that of net/http.

Fixes #509
  • Loading branch information
stevenh authored Nov 19, 2020
1 parent 0b0ad3d commit c4a82d6
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 13 deletions.
57 changes: 44 additions & 13 deletions redis/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,27 @@ type DialOption struct {
}

type dialOptions struct {
readTimeout time.Duration
writeTimeout time.Duration
dialer *net.Dialer
dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
db int
username string
password string
clientName string
useTLS bool
skipVerify bool
tlsConfig *tls.Config
readTimeout time.Duration
writeTimeout time.Duration
tlsHandshakeTimeout time.Duration
dialer *net.Dialer
dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
db int
username string
password string
clientName string
useTLS bool
skipVerify bool
tlsConfig *tls.Config
}

// DialTLSHandshakeTimeout specifies the maximum amount of time waiting to
// wait for a TLS handshake. Zero means no timeout.
// If no DialTLSHandshakeTimeout option is specified then the default is 30 seconds.
func DialTLSHandshakeTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.tlsHandshakeTimeout = d
}}
}

// DialReadTimeout specifies the timeout for reading a single command reply.
Expand All @@ -104,6 +114,7 @@ func DialWriteTimeout(d time.Duration) DialOption {

// DialConnectTimeout specifies the timeout for connecting to the Redis server when
// no DialNetDial option is specified.
// If no DialConnectTimeout option is specified then the default is 30 seconds.
func DialConnectTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.dialer.Timeout = d
Expand Down Expand Up @@ -201,13 +212,21 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
return DialContext(context.Background(), network, address, options...)
}

type tlsHandshakeTimeoutError struct{}

func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "TLS handshake timeout" }

// DialContext connects to the Redis server at the given network and
// address using the specified options and context.
func DialContext(ctx context.Context, network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{
dialer: &net.Dialer{
Timeout: time.Second * 30,
KeepAlive: time.Minute * 5,
},
tlsHandshakeTimeout: time.Second * 10,
}
for _, option := range options {
option.f(&do)
Expand Down Expand Up @@ -238,10 +257,22 @@ func DialContext(ctx context.Context, network, address string, options ...DialOp
}

tlsConn := tls.Client(netConn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
netConn.Close()
errc := make(chan error, 2) // buffered so we don't block timeout or Handshake
if d := do.tlsHandshakeTimeout; d != 0 {
timer := time.AfterFunc(d, func() {
errc <- tlsHandshakeTimeoutError{}
})
defer timer.Stop()
}
go func() {
errc <- tlsConn.Handshake()
}()
if err := <-errc; err != nil {
// Timeout or Handshake error.
netConn.Close() // nolint: errcheck
return nil, err
}

netConn = tlsConn
}

Expand Down
39 changes: 39 additions & 0 deletions redis/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,45 @@ func TestDialUseTLS(t *testing.T) {
checkPingPong(t, &buf, c)
}

type blockedReader struct {
ch chan struct{}
}

func (b blockedReader) Read(p []byte) (n int, err error) {
<-b.ch
return 0, nil
}

func dialTestBlockedConn(ch chan struct{}, w io.Writer) redis.DialOption {
return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
return &testConn{Reader: blockedReader{ch: ch}, Writer: w}, nil
})
}

func TestDialTLSHandshakeTimeout(t *testing.T) {
var buf bytes.Buffer
ch := make(chan struct{})
var err error
go func() {
_, err = redis.Dial("tcp", "example.com:6379",
redis.DialTLSConfig(&clientTLSConfig),
redis.DialTLSHandshakeTimeout(time.Millisecond),
dialTestBlockedConn(ch, &buf),
redis.DialUseTLS(true))
close(ch)
}()
select {
case <-time.After(time.Second):
t.Fatal("dial didn't timeout")
case <-ch:
if err == nil {
t.Fatal("dial didn't error")
} else if err.Error() != "TLS handshake timeout" {
t.Fatal("dial unexpected error:", err)
}
}
}

func TestDialTLSSKipVerify(t *testing.T) {
var buf bytes.Buffer
c, err := redis.Dial("tcp", "example.com:6379",
Expand Down

0 comments on commit c4a82d6

Please # to comment.