diff --git a/conn.go b/conn.go index 2ac19e9a..091e47ce 100644 --- a/conn.go +++ b/conn.go @@ -446,23 +446,24 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er maskBytes(key, 0, buf[6:]) } - d := 1000 * time.Hour - if !deadline.IsZero() { - d = time.Until(deadline) + if deadline.IsZero() { + // No timeout for zero time. + <-c.mu + } else { + d := time.Until(deadline) if d < 0 { return errWriteTimeout } - } - - select { - case <-c.mu: - default: - timer := time.NewTimer(d) select { case <-c.mu: - timer.Stop() - case <-timer.C: - return errWriteTimeout + default: + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } } } diff --git a/conn_test.go b/conn_test.go index f0c29c39..1f2afcfb 100644 --- a/conn_test.go +++ b/conn_test.go @@ -148,6 +148,22 @@ func TestFraming(t *testing.T) { } } +func TestWriteControlDeadline(t *testing.T) { + t.Parallel() + message := []byte("hello") + var connBuf bytes.Buffer + c := newTestConn(nil, &connBuf, true) + if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil { + t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil { + t.Errorf("WriteControl(..., future deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil { + t.Errorf("WriteControl(..., past deadline) = nil, want timeout error") + } +} + func TestConcurrencyWriteControl(t *testing.T) { const message = "this is a ping/pong messsage" loop := 10