From d6127fe316464d9a2ae0245682a84fc189f6a676 Mon Sep 17 00:00:00 2001 From: Martti T Date: Mon, 8 Mar 2021 03:13:22 +0200 Subject: [PATCH] Rework timeout middleware to use http.TimeoutHandler implementation (fix #1761) (#1801) --- middleware/timeout.go | 94 ++++++++++++-------- middleware/timeout_test.go | 175 ++++++++++++++++++++++++------------- 2 files changed, 172 insertions(+), 97 deletions(-) diff --git a/middleware/timeout.go b/middleware/timeout.go index 4be557f76..68f464e40 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -4,8 +4,8 @@ package middleware import ( "context" - "fmt" "github.com/labstack/echo/v4" + "net/http" "time" ) @@ -14,16 +14,23 @@ type ( TimeoutConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // ErrorHandler defines a function which is executed for a timeout - // It can be used to define a custom timeout error - ErrorHandler TimeoutErrorHandlerWithContext + + // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code + // It can be used to define a custom timeout error message + ErrorMessage string + + // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after + // request timeouted and we already had sent the error code (503) and message response to the client. + // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer + // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()` + OnTimeoutRouteErrorHandler func(err error, c echo.Context) + // Timeout configures a timeout for the middleware, defaults to 0 for no timeout + // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) + // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output + // difference over 500microseconds (0.5millisecond) response seems to be reliable Timeout time.Duration } - - // TimeoutErrorHandlerWithContext is an error handler that is used with the timeout middleware so we can - // handle the error as we see fit - TimeoutErrorHandlerWithContext func(error, echo.Context) error ) var ( @@ -31,7 +38,7 @@ var ( DefaultTimeoutConfig = TimeoutConfig{ Skipper: DefaultSkipper, Timeout: 0, - ErrorHandler: nil, + ErrorMessage: "", } ) @@ -55,39 +62,50 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { return next(c) } - ctx, cancel := context.WithTimeout(c.Request().Context(), config.Timeout) - defer cancel() - - // this does a deep clone of the context, wondering if there is a better way to do this? - c.SetRequest(c.Request().Clone(ctx)) - - done := make(chan error, 1) - go func() { - defer func() { - if r := recover(); r != nil { - err, ok := r.(error) - if !ok { - err = fmt.Errorf("panic recovered in timeout middleware: %v", r) - } - c.Logger().Error(err) - done <- err - } - }() - - // This goroutine will keep running even if this middleware times out and - // will be stopped when ctx.Done() is called down the next(c) call chain - done <- next(c) - }() + handlerWrapper := echoHandlerFuncWrapper{ + ctx: c, + handler: next, + errChan: make(chan error, 1), + errHandler: config.OnTimeoutRouteErrorHandler, + } + handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage) + handler.ServeHTTP(c.Response().Writer, c.Request()) select { - case <-ctx.Done(): - if config.ErrorHandler != nil { - return config.ErrorHandler(ctx.Err(), c) - } - return ctx.Err() - case err := <-done: + case err := <-handlerWrapper.errChan: return err + default: + return nil } } } } + +type echoHandlerFuncWrapper struct { + ctx echo.Context + handler echo.HandlerFunc + errHandler func(err error, c echo.Context) + errChan chan error +} + +func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + // replace writer with TimeoutHandler custom one. This will guarantee that + // `writes by h to its ResponseWriter will return ErrHandlerTimeout.` + originalWriter := t.ctx.Response().Writer + t.ctx.Response().Writer = rw + + err := t.handler(t.ctx) + if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded { + if err != nil && t.errHandler != nil { + t.errHandler(err, t.ctx) + } + return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers + } + // we restore original writer only for cases we did not timeout. On timeout we have already sent response to client + // and should not anymore send additional headers/data + // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body + t.ctx.Response().Writer = originalWriter + if err != nil { + t.errChan <- err + } +} diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index faecc4c53..af4c62647 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -3,7 +3,6 @@ package middleware import ( - "context" "errors" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -22,6 +21,7 @@ func TestTimeoutSkipper(t *testing.T) { Skipper: func(context echo.Context) bool { return true }, + Timeout: 1 * time.Nanosecond, }) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -31,18 +31,17 @@ func TestTimeoutSkipper(t *testing.T) { c := e.NewContext(req, rec) err := m(func(c echo.Context) error { - assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) - return nil + time.Sleep(25 * time.Microsecond) + return errors.New("response from handler") })(c) - assert.NoError(t, err) + // if not skipped we would have not returned error due context timeout logic + assert.EqualError(t, err, "response from handler") } func TestTimeoutWithTimeout0(t *testing.T) { t.Parallel() - m := TimeoutWithConfig(TimeoutConfig{ - Timeout: 0, - }) + m := Timeout() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -58,10 +57,11 @@ func TestTimeoutWithTimeout0(t *testing.T) { assert.NoError(t, err) } -func TestTimeoutIsCancelable(t *testing.T) { +func TestTimeoutErrorOutInHandler(t *testing.T) { t.Parallel() m := TimeoutWithConfig(TimeoutConfig{ - Timeout: time.Minute, + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 50 * time.Millisecond, }) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -70,24 +70,6 @@ func TestTimeoutIsCancelable(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { - assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) - return nil - })(c) - - assert.NoError(t, err) -} - -func TestTimeoutErrorOutInHandler(t *testing.T) { - t.Parallel() - m := Timeout() - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - e := echo.New() - c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { return errors.New("err") })(c) @@ -95,34 +77,15 @@ func TestTimeoutErrorOutInHandler(t *testing.T) { assert.Error(t, err) } -func TestTimeoutTimesOutAfterPredefinedTimeoutWithErrorHandler(t *testing.T) { +func TestTimeoutOnTimeoutRouteErrorHandler(t *testing.T) { t.Parallel() - m := TimeoutWithConfig(TimeoutConfig{ - Timeout: time.Second, - ErrorHandler: func(err error, e echo.Context) error { - assert.EqualError(t, err, context.DeadlineExceeded.Error()) - return errors.New("err") - }, - }) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - e := echo.New() - c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { - time.Sleep(time.Minute) - return nil - })(c) - - assert.EqualError(t, err, errors.New("err").Error()) -} - -func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) { - t.Parallel() + actualErrChan := make(chan error, 1) m := TimeoutWithConfig(TimeoutConfig{ - Timeout: time.Second, + Timeout: 1 * time.Millisecond, + OnTimeoutRouteErrorHandler: func(err error, c echo.Context) { + actualErrChan <- err + }, }) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -131,12 +94,16 @@ func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) + stopChan := make(chan struct{}, 0) err := m(func(c echo.Context) error { - time.Sleep(time.Minute) - return nil + <-stopChan + return errors.New("error in route after timeout") })(c) + stopChan <- struct{}{} + assert.NoError(t, err) - assert.EqualError(t, err, context.DeadlineExceeded.Error()) + actualErr := <-actualErrChan + assert.EqualError(t, actualErr, "error in route after timeout") } func TestTimeoutTestRequestClone(t *testing.T) { @@ -148,7 +115,7 @@ func TestTimeoutTestRequestClone(t *testing.T) { m := TimeoutWithConfig(TimeoutConfig{ // Timeout has to be defined or the whole flow for timeout middleware will be skipped - Timeout: time.Second, + Timeout: 1 * time.Second, }) e := echo.New() @@ -178,8 +145,63 @@ func TestTimeoutTestRequestClone(t *testing.T) { func TestTimeoutRecoversPanic(t *testing.T) { t.Parallel() + e := echo.New() + e.Use(Recover()) // recover middleware will handler our panic + e.Use(TimeoutWithConfig(TimeoutConfig{ + Timeout: 50 * time.Millisecond, + })) + + e.GET("/", func(c echo.Context) error { + panic("panic!!!") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + assert.NotPanics(t, func() { + e.ServeHTTP(rec, req) + }) +} + +func TestTimeoutDataRace(t *testing.T) { + t.Parallel() + + timeout := 1 * time.Millisecond + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: timeout, + ErrorMessage: "Timeout! change me", + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) + // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output + // difference over 500microseconds (0.5millisecond) response seems to be reliable + time.Sleep(timeout) // timeout and handler execution time difference is close to zero + return c.String(http.StatusOK, "Hello, World!") + })(c) + + assert.NoError(t, err) + + if rec.Code == http.StatusServiceUnavailable { + assert.Equal(t, "Timeout! change me", rec.Body.String()) + } else { + assert.Equal(t, "Hello, World!", rec.Body.String()) + } +} + +func TestTimeoutWithErrorMessage(t *testing.T) { + t.Parallel() + + timeout := 1 * time.Millisecond m := TimeoutWithConfig(TimeoutConfig{ - Timeout: 25 * time.Millisecond, + Timeout: timeout, + ErrorMessage: "Timeout! change me", }) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -188,9 +210,44 @@ func TestTimeoutRecoversPanic(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) + stopChan := make(chan struct{}, 0) err := m(func(c echo.Context) error { - panic("panic in handler") + // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) + // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output + // difference over 500microseconds (0.5millisecond) response seems to be reliable + <-stopChan + return c.String(http.StatusOK, "Hello, World!") })(c) + stopChan <- struct{}{} + + assert.NoError(t, err) + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) + assert.Equal(t, "Timeout! change me", rec.Body.String()) +} + +func TestTimeoutWithDefaultErrorMessage(t *testing.T) { + t.Parallel() - assert.Error(t, err, "panic recovered in timeout middleware: panic in handler") + timeout := 1 * time.Millisecond + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: timeout, + ErrorMessage: "", + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + stopChan := make(chan struct{}, 0) + err := m(func(c echo.Context) error { + <-stopChan + return c.String(http.StatusOK, "Hello, World!") + })(c) + stopChan <- struct{}{} + + assert.NoError(t, err) + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) + assert.Equal(t, `Timeout

Timeout

`, rec.Body.String()) }