Skip to content

Commit

Permalink
Rework timeout middleware to use http.TimeoutHandler implementation (fix
Browse files Browse the repository at this point in the history
  • Loading branch information
aldas authored Mar 8, 2021
1 parent 5622ecc commit d6127fe
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 97 deletions.
94 changes: 56 additions & 38 deletions middleware/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ package middleware

import (
"context"
"fmt"
"github.com/labstack/echo/v4"
"net/http"
"time"
)

Expand All @@ -14,24 +14,31 @@ type (
TimeoutConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// ErrorHandler defines a function which is executed for a timeout
// It can be used to define a custom timeout error
ErrorHandler TimeoutErrorHandlerWithContext

// ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
// It can be used to define a custom timeout error message
ErrorMessage string

// OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
// request timeouted and we already had sent the error code (503) and message response to the client.
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
OnTimeoutRouteErrorHandler func(err error, c echo.Context)

// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable
Timeout time.Duration
}

// TimeoutErrorHandlerWithContext is an error handler that is used with the timeout middleware so we can
// handle the error as we see fit
TimeoutErrorHandlerWithContext func(error, echo.Context) error
)

var (
// DefaultTimeoutConfig is the default Timeout middleware config.
DefaultTimeoutConfig = TimeoutConfig{
Skipper: DefaultSkipper,
Timeout: 0,
ErrorHandler: nil,
ErrorMessage: "",
}
)

Expand All @@ -55,39 +62,50 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
return next(c)
}

ctx, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
defer cancel()

// this does a deep clone of the context, wondering if there is a better way to do this?
c.SetRequest(c.Request().Clone(ctx))

done := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
err, ok := r.(error)
if !ok {
err = fmt.Errorf("panic recovered in timeout middleware: %v", r)
}
c.Logger().Error(err)
done <- err
}
}()

// This goroutine will keep running even if this middleware times out and
// will be stopped when ctx.Done() is called down the next(c) call chain
done <- next(c)
}()
handlerWrapper := echoHandlerFuncWrapper{
ctx: c,
handler: next,
errChan: make(chan error, 1),
errHandler: config.OnTimeoutRouteErrorHandler,
}
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
handler.ServeHTTP(c.Response().Writer, c.Request())

select {
case <-ctx.Done():
if config.ErrorHandler != nil {
return config.ErrorHandler(ctx.Err(), c)
}
return ctx.Err()
case err := <-done:
case err := <-handlerWrapper.errChan:
return err
default:
return nil
}
}
}
}

type echoHandlerFuncWrapper struct {
ctx echo.Context
handler echo.HandlerFunc
errHandler func(err error, c echo.Context)
errChan chan error
}

func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// replace writer with TimeoutHandler custom one. This will guarantee that
// `writes by h to its ResponseWriter will return ErrHandlerTimeout.`
originalWriter := t.ctx.Response().Writer
t.ctx.Response().Writer = rw

err := t.handler(t.ctx)
if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded {
if err != nil && t.errHandler != nil {
t.errHandler(err, t.ctx)
}
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
}
// we restore original writer only for cases we did not timeout. On timeout we have already sent response to client
// and should not anymore send additional headers/data
// so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body
t.ctx.Response().Writer = originalWriter
if err != nil {
t.errChan <- err
}
}
175 changes: 116 additions & 59 deletions middleware/timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package middleware

import (
"context"
"errors"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
Expand All @@ -22,6 +21,7 @@ func TestTimeoutSkipper(t *testing.T) {
Skipper: func(context echo.Context) bool {
return true
},
Timeout: 1 * time.Nanosecond,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
Expand All @@ -31,18 +31,17 @@ func TestTimeoutSkipper(t *testing.T) {
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
time.Sleep(25 * time.Microsecond)
return errors.New("response from handler")
})(c)

assert.NoError(t, err)
// if not skipped we would have not returned error due context timeout logic
assert.EqualError(t, err, "response from handler")
}

func TestTimeoutWithTimeout0(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: 0,
})
m := Timeout()

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
Expand All @@ -58,10 +57,11 @@ func TestTimeoutWithTimeout0(t *testing.T) {
assert.NoError(t, err)
}

func TestTimeoutIsCancelable(t *testing.T) {
func TestTimeoutErrorOutInHandler(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Minute,
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: 50 * time.Millisecond,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
Expand All @@ -70,59 +70,22 @@ func TestTimeoutIsCancelable(t *testing.T) {
e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)

assert.NoError(t, err)
}

func TestTimeoutErrorOutInHandler(t *testing.T) {
t.Parallel()
m := Timeout()

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
return errors.New("err")
})(c)

assert.Error(t, err)
}

func TestTimeoutTimesOutAfterPredefinedTimeoutWithErrorHandler(t *testing.T) {
func TestTimeoutOnTimeoutRouteErrorHandler(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Second,
ErrorHandler: func(err error, e echo.Context) error {
assert.EqualError(t, err, context.DeadlineExceeded.Error())
return errors.New("err")
},
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
time.Sleep(time.Minute)
return nil
})(c)

assert.EqualError(t, err, errors.New("err").Error())
}

func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) {
t.Parallel()
actualErrChan := make(chan error, 1)
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Second,
Timeout: 1 * time.Millisecond,
OnTimeoutRouteErrorHandler: func(err error, c echo.Context) {
actualErrChan <- err
},
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
Expand All @@ -131,12 +94,16 @@ func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) {
e := echo.New()
c := e.NewContext(req, rec)

stopChan := make(chan struct{}, 0)
err := m(func(c echo.Context) error {
time.Sleep(time.Minute)
return nil
<-stopChan
return errors.New("error in route after timeout")
})(c)
stopChan <- struct{}{}
assert.NoError(t, err)

assert.EqualError(t, err, context.DeadlineExceeded.Error())
actualErr := <-actualErrChan
assert.EqualError(t, actualErr, "error in route after timeout")
}

func TestTimeoutTestRequestClone(t *testing.T) {
Expand All @@ -148,7 +115,7 @@ func TestTimeoutTestRequestClone(t *testing.T) {

m := TimeoutWithConfig(TimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: time.Second,
Timeout: 1 * time.Second,
})

e := echo.New()
Expand Down Expand Up @@ -178,8 +145,63 @@ func TestTimeoutTestRequestClone(t *testing.T) {

func TestTimeoutRecoversPanic(t *testing.T) {
t.Parallel()
e := echo.New()
e.Use(Recover()) // recover middleware will handler our panic
e.Use(TimeoutWithConfig(TimeoutConfig{
Timeout: 50 * time.Millisecond,
}))

e.GET("/", func(c echo.Context) error {
panic("panic!!!")
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

assert.NotPanics(t, func() {
e.ServeHTTP(rec, req)
})
}

func TestTimeoutDataRace(t *testing.T) {
t.Parallel()

timeout := 1 * time.Millisecond
m := TimeoutWithConfig(TimeoutConfig{
Timeout: timeout,
ErrorMessage: "Timeout! change me",
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable
time.Sleep(timeout) // timeout and handler execution time difference is close to zero
return c.String(http.StatusOK, "Hello, World!")
})(c)

assert.NoError(t, err)

if rec.Code == http.StatusServiceUnavailable {
assert.Equal(t, "Timeout! change me", rec.Body.String())
} else {
assert.Equal(t, "Hello, World!", rec.Body.String())
}
}

func TestTimeoutWithErrorMessage(t *testing.T) {
t.Parallel()

timeout := 1 * time.Millisecond
m := TimeoutWithConfig(TimeoutConfig{
Timeout: 25 * time.Millisecond,
Timeout: timeout,
ErrorMessage: "Timeout! change me",
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
Expand All @@ -188,9 +210,44 @@ func TestTimeoutRecoversPanic(t *testing.T) {
e := echo.New()
c := e.NewContext(req, rec)

stopChan := make(chan struct{}, 0)
err := m(func(c echo.Context) error {
panic("panic in handler")
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable
<-stopChan
return c.String(http.StatusOK, "Hello, World!")
})(c)
stopChan <- struct{}{}

assert.NoError(t, err)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
assert.Equal(t, "Timeout! change me", rec.Body.String())
}

func TestTimeoutWithDefaultErrorMessage(t *testing.T) {
t.Parallel()

assert.Error(t, err, "panic recovered in timeout middleware: panic in handler")
timeout := 1 * time.Millisecond
m := TimeoutWithConfig(TimeoutConfig{
Timeout: timeout,
ErrorMessage: "",
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

stopChan := make(chan struct{}, 0)
err := m(func(c echo.Context) error {
<-stopChan
return c.String(http.StatusOK, "Hello, World!")
})(c)
stopChan <- struct{}{}

assert.NoError(t, err)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
assert.Equal(t, `<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>`, rec.Body.String())
}

0 comments on commit d6127fe

Please # to comment.