From e6774ce05582811638ebbf39f054fd1f5cd97e16 Mon Sep 17 00:00:00 2001 From: Carsten Dietrich Date: Thu, 2 May 2024 08:31:03 +0200 Subject: [PATCH 1/2] feat(core/auth/oidc): Use BadRequest response for state mismatches, enhance logging with context --- core/auth/oauth/oidc.go | 25 +++++++++---------- core/auth/oauth/oidc_test.go | 4 ++-- core/requestlogger/logger_test.go | 13 +++++----- framework/module.go | 1 + framework/web/result.go | 40 +++++++++++++++++++++++++++---- 5 files changed, 58 insertions(+), 25 deletions(-) diff --git a/core/auth/oauth/oidc.go b/core/auth/oauth/oidc.go index 745669fc4..4d8d5b26c 100644 --- a/core/auth/oauth/oidc.go +++ b/core/auth/oauth/oidc.go @@ -443,12 +443,12 @@ func (i *openIDIdentifier) Authenticate(ctx context.Context, request *web.Reques authConfig, err := i.config(request) if err != nil { - return i.responder.ServerError(err) + return i.responder.ServerErrorWithContext(ctx, err) } u, err := url.Parse(authConfig.AuthCodeURL(state, options...)) if err != nil { - return i.responder.ServerError(err) + return i.responder.ServerErrorWithContext(ctx, err) } return i.responder.URLRedirect(u) @@ -464,25 +464,26 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r } } - return i.responder.ServerError(fmt.Errorf("OpenID Connect error: %q (%q)", errString, errDetails)) + return i.responder.ServerErrorWithContext(ctx, fmt.Errorf("OpenID Connect error: %q (%q)", errString, errDetails)) } queryState, err := request.Query1("state") if err != nil { - return i.responder.ServerError(errors.New("no state in request")) + return i.responder.BadRequestWithContext(ctx, errors.New("no state in request")) } + if !i.validateSessionCode(request, queryState) { - return i.responder.ServerError(errors.New("state mismatch")) + return i.responder.BadRequestWithContext(ctx, errors.New("state mismatch")) } code, err := request.Query1("code") if err != nil { - return i.responder.ServerError(err) + return i.responder.BadRequestWithContext(ctx, fmt.Errorf("%w: code", err)) } oauthConfig, err := i.config(request) if err != nil { - return i.responder.ServerError(err) + return i.responder.ServerErrorWithContext(ctx, err) } options := make([]oauth2.AuthCodeOption, 0) @@ -495,13 +496,13 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r oauth2Token, err := oauthConfig.Exchange(ctx, code, options...) if err != nil { - return i.responder.ServerError(err) + return i.responder.ServerErrorWithContext(ctx, err) } // Extract the ID Token from OAuth2 token. rawIDToken, ok := oauth2Token.Extra("id_token").(string) if !ok { - return i.responder.ServerError(errors.New("claim id_token missing")) + return i.responder.ServerErrorWithContext(ctx, errors.New("claim id_token missing")) } verifierConfig := &oidc.Config{ClientID: i.oauth2Config.ClientID} @@ -513,7 +514,7 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r // Parse and verify ID Token payload. idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { - return i.responder.ServerError(err) + return i.responder.ServerErrorWithContext(ctx, err) } var ( @@ -523,7 +524,7 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r ) if err := idToken.Claims(&tempIDTokenClaims); err != nil { - return i.responder.ServerError(err) + return i.responder.ServerErrorWithContext(ctx, err) } for k, v := range i.oidcConfig.Claims.IDToken { idTokenClaims[k] = tempIDTokenClaims[v] @@ -561,7 +562,7 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r identity, err := i.Identify(ctx, request) if err != nil { i.Logout(ctx, request) - return i.responder.ServerError(err) + return i.responder.ServerErrorWithContext(ctx, err) } i.eventRouter.Dispatch(ctx, &auth.WebLoginEvent{Broker: i.broker, Request: request, Identity: identity}) diff --git a/core/auth/oauth/oidc_test.go b/core/auth/oauth/oidc_test.go index 1074349cf..0cb8d9b8f 100644 --- a/core/auth/oauth/oidc_test.go +++ b/core/auth/oauth/oidc_test.go @@ -83,12 +83,12 @@ func TestParallelStateRaceConditions(t *testing.T) { request.URL.RawQuery = url.Values{"state": []string{state2}}.Encode() resp = identifier.Callback(context.Background(), web.CreateRequest(request, session), nil) errResp = resp.(*web.ServerErrorResponse) - assert.EqualError(t, errResp.Error, "query value not found") + assert.EqualError(t, errResp.Error, "query value not found: code") request.URL.RawQuery = url.Values{"state": []string{state1}}.Encode() resp = identifier.Callback(context.Background(), web.CreateRequest(request, session), nil) errResp = resp.(*web.ServerErrorResponse) - assert.EqualError(t, errResp.Error, "query value not found") + assert.EqualError(t, errResp.Error, "query value not found: code") request.URL.RawQuery = url.Values{"state": []string{state1}}.Encode() resp = identifier.Callback(context.Background(), web.CreateRequest(request, session), nil) diff --git a/core/requestlogger/logger_test.go b/core/requestlogger/logger_test.go index 290e6f370..bf0f022e1 100644 --- a/core/requestlogger/logger_test.go +++ b/core/requestlogger/logger_test.go @@ -30,12 +30,13 @@ func TestLogger(t *testing.T) { request.Request().Header.Set("Referer", "https://example.com/") responder := new(web.Responder).Inject(&web.Router{}, flamingo.NullLogger{}, &struct { - Engine flamingo.TemplateEngine "inject:\",optional\"" - Debug bool "inject:\"config:flamingo.debug.mode\"" - TemplateForbidden string "inject:\"config:flamingo.template.err403\"" - TemplateNotFound string "inject:\"config:flamingo.template.err404\"" - TemplateUnavailable string "inject:\"config:flamingo.template.err503\"" - TemplateErrorWithCode string "inject:\"config:flamingo.template.errWithCode\"" + Engine flamingo.TemplateEngine `inject:",optional"` + Debug bool `inject:"config:flamingo.debug.mode"` + TemplateBadRequest string `inject:"config:flamingo.template.err400"` + TemplateForbidden string `inject:"config:flamingo.template.err403"` + TemplateNotFound string `inject:"config:flamingo.template.err404"` + TemplateUnavailable string `inject:"config:flamingo.template.err503"` + TemplateErrorWithCode string `inject:"config:flamingo.template.errWithCode"` }{}) tests := []struct { diff --git a/framework/module.go b/framework/module.go index af10bddbd..7a28191e4 100644 --- a/framework/module.go +++ b/framework/module.go @@ -104,6 +104,7 @@ flamingo: { path?: string } template: { + err400: string | *"error/400" err403: string | *"error/403" err404: string | *"error/404" errWithCode: string | *"error/withCode" diff --git a/framework/web/result.go b/framework/web/result.go index 911a462b7..8affd3d50 100644 --- a/framework/web/result.go +++ b/framework/web/result.go @@ -32,6 +32,7 @@ type ( debug bool templateForbidden string + templateBadRequest string templateNotFound string templateUnavailable string templateErrorWithCode string @@ -127,6 +128,7 @@ const ( func (r *Responder) Inject(router *Router, logger flamingo.Logger, cfg *struct { Engine flamingo.TemplateEngine `inject:",optional"` Debug bool `inject:"config:flamingo.debug.mode"` + TemplateBadRequest string `inject:"config:flamingo.template.err400"` TemplateForbidden string `inject:"config:flamingo.template.err403"` TemplateNotFound string `inject:"config:flamingo.template.err404"` TemplateUnavailable string `inject:"config:flamingo.template.err503"` @@ -135,6 +137,7 @@ func (r *Responder) Inject(router *Router, logger flamingo.Logger, cfg *struct { r.engine = cfg.Engine r.router = router r.templateForbidden = cfg.TemplateForbidden + r.templateBadRequest = cfg.TemplateBadRequest r.templateNotFound = cfg.TemplateNotFound r.templateUnavailable = cfg.TemplateUnavailable r.templateErrorWithCode = cfg.TemplateErrorWithCode @@ -429,10 +432,15 @@ func (r *Responder) ServerErrorWithCodeAndTemplate(err error, tpl string, status // ServerError creates a 500 error response func (r *Responder) ServerError(err error) *ServerErrorResponse { + return r.ServerErrorWithContext(context.Background(), err) +} + +// ServerErrorWithContext creates a 500 error response and uses the provided context for enhanced logging +func (r *Responder) ServerErrorWithContext(ctx context.Context, err error) *ServerErrorResponse { if errors.Is(err, context.Canceled) { - r.getLogger().Debug(fmt.Sprintf("%+v\n", err)) + r.getLogger().WithContext(ctx).Debug(fmt.Sprintf("%+v\n", err)) } else { - r.getLogger().Error(fmt.Sprintf("%+v\n", err)) + r.getLogger().WithContext(ctx).Error(fmt.Sprintf("%+v\n", err)) } return r.ServerErrorWithCodeAndTemplate(err, r.templateErrorWithCode, http.StatusInternalServerError) @@ -440,25 +448,47 @@ func (r *Responder) ServerError(err error) *ServerErrorResponse { // Unavailable creates a 503 error response func (r *Responder) Unavailable(err error) *ServerErrorResponse { - r.getLogger().Error(fmt.Sprintf("%+v\n", err)) + return r.UnavailableWithContext(context.Background(), err) +} + +// UnavailableWithContext creates a 503 error response and uses the provided context for enhanced logging +func (r *Responder) UnavailableWithContext(ctx context.Context, err error) *ServerErrorResponse { + r.getLogger().WithContext(ctx).Error(fmt.Sprintf("%+v\n", err)) return r.ServerErrorWithCodeAndTemplate(err, r.templateUnavailable, http.StatusServiceUnavailable) } // NotFound creates a 404 error response func (r *Responder) NotFound(err error) *ServerErrorResponse { - r.getLogger().Warn(err) + return r.NotFoundWithContext(context.Background(), err) +} + +// NotFoundWithContext creates a 404 error response and uses the provided context for enhanced logging +func (r *Responder) NotFoundWithContext(ctx context.Context, err error) *ServerErrorResponse { + r.getLogger().WithContext(ctx).Warn(err) return r.ServerErrorWithCodeAndTemplate(err, r.templateNotFound, http.StatusNotFound) } // Forbidden creates a 403 error response func (r *Responder) Forbidden(err error) *ServerErrorResponse { - r.getLogger().Warn(err) + return r.ForbiddenWithContext(context.Background(), err) +} + +// ForbiddenWithContext creates a 403 error response and uses the provided context for enhanced logging +func (r *Responder) ForbiddenWithContext(ctx context.Context, err error) *ServerErrorResponse { + r.getLogger().WithContext(ctx).Warn(err) return r.ServerErrorWithCodeAndTemplate(err, r.templateForbidden, http.StatusForbidden) } +// BadRequestWithContext creates a 400 error response and uses the provided context for enhanced logging +func (r *Responder) BadRequestWithContext(ctx context.Context, err error) *ServerErrorResponse { + r.getLogger().WithContext(ctx).Info(err) + + return r.ServerErrorWithCodeAndTemplate(err, r.templateForbidden, http.StatusBadRequest) +} + // SetNoCache helper func (r *ServerErrorResponse) SetNoCache() *ServerErrorResponse { r.Response.SetNoCache() From b67fe075470b009bf3b882eb3290f64fa44d0679 Mon Sep 17 00:00:00 2001 From: Carsten Dietrich Date: Thu, 2 May 2024 08:44:06 +0200 Subject: [PATCH 2/2] lint: avoid dynamic errors --- core/auth/oauth/oidc.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/core/auth/oauth/oidc.go b/core/auth/oauth/oidc.go index 4d8d5b26c..c923a8d7c 100644 --- a/core/auth/oauth/oidc.go +++ b/core/auth/oauth/oidc.go @@ -109,6 +109,11 @@ var ( return ok } + + errNoStateInRequest = errors.New("no state in request") + errStateMismatch = errors.New("state mismatch") + errNoIDTokenClaim = errors.New("claim id_token missing") + errGeneric = errors.New("OpenID Connect error") ) func oidcFactory(cfg config.Map) (auth.RequestIdentifier, error) { @@ -464,16 +469,16 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r } } - return i.responder.ServerErrorWithContext(ctx, fmt.Errorf("OpenID Connect error: %q (%q)", errString, errDetails)) + return i.responder.ServerErrorWithContext(ctx, fmt.Errorf("%w: %q (%q)", errGeneric, errString, errDetails)) } queryState, err := request.Query1("state") if err != nil { - return i.responder.BadRequestWithContext(ctx, errors.New("no state in request")) + return i.responder.BadRequestWithContext(ctx, errNoStateInRequest) } if !i.validateSessionCode(request, queryState) { - return i.responder.BadRequestWithContext(ctx, errors.New("state mismatch")) + return i.responder.BadRequestWithContext(ctx, errStateMismatch) } code, err := request.Query1("code") @@ -502,7 +507,7 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r // Extract the ID Token from OAuth2 token. rawIDToken, ok := oauth2Token.Extra("id_token").(string) if !ok { - return i.responder.ServerErrorWithContext(ctx, errors.New("claim id_token missing")) + return i.responder.ServerErrorWithContext(ctx, errNoIDTokenClaim) } verifierConfig := &oidc.Config{ClientID: i.oauth2Config.ClientID}