diff --git a/message/router/middleware/timeout.go b/message/router/middleware/timeout.go new file mode 100644 index 000000000..c30b5c30d --- /dev/null +++ b/message/router/middleware/timeout.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "context" + "time" + + "github.com/ThreeDotsLabs/watermill/message" +) + +func Timeout(timeout time.Duration) func(message.HandlerFunc) message.HandlerFunc { + return func(h message.HandlerFunc) message.HandlerFunc { + return func(msg *message.Message) ([]*message.Message, error) { + ctx, cancel := context.WithTimeout(msg.Context(), timeout) + defer func() { + cancel() + }() + + msg.SetContext(ctx) + return h(msg) + } + } +} diff --git a/message/router/middleware/timeout_test.go b/message/router/middleware/timeout_test.go new file mode 100644 index 000000000..98172e520 --- /dev/null +++ b/message/router/middleware/timeout_test.go @@ -0,0 +1,30 @@ +package middleware_test + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ThreeDotsLabs/watermill/message" + "github.com/ThreeDotsLabs/watermill/message/router/middleware" +) + +func TestTimeout(t *testing.T) { + timeout := middleware.Timeout(time.Millisecond * 10) + + h := timeout(func(msg *message.Message) ([]*message.Message, error) { + delay := time.After(time.Millisecond * 100) + + select { + case <-msg.Context().Done(): + return nil, nil + case <-delay: + return nil, errors.New("timeout did not occur") + } + }) + + _, err := h(message.NewMessage("any-uuid", nil)) + require.NoError(t, err) +}