Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feat(auth/oidc): Use BadRequest response for state mismatches, enhance logging with context #403

Merged
merged 2 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions core/auth/oauth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ var (

return ok
}

errNoStateInRequest = errors.New("no state in request")
errStateMismatch = errors.New("state mismatch")
errNoIDTokenClaim = errors.New("claim id_token missing")
errGeneric = errors.New("OpenID Connect error")
IvanMaidurov marked this conversation as resolved.
Show resolved Hide resolved
)

func oidcFactory(cfg config.Map) (auth.RequestIdentifier, error) {
Expand Down Expand Up @@ -443,12 +448,12 @@ func (i *openIDIdentifier) Authenticate(ctx context.Context, request *web.Reques

authConfig, err := i.config(request)
if err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

u, err := url.Parse(authConfig.AuthCodeURL(state, options...))
if err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

return i.responder.URLRedirect(u)
Expand All @@ -464,25 +469,26 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r
}
}

return i.responder.ServerError(fmt.Errorf("OpenID Connect error: %q (%q)", errString, errDetails))
return i.responder.ServerErrorWithContext(ctx, fmt.Errorf("%w: %q (%q)", errGeneric, errString, errDetails))
}

queryState, err := request.Query1("state")
if err != nil {
return i.responder.ServerError(errors.New("no state in request"))
return i.responder.BadRequestWithContext(ctx, errNoStateInRequest)
}

if !i.validateSessionCode(request, queryState) {
return i.responder.ServerError(errors.New("state mismatch"))
return i.responder.BadRequestWithContext(ctx, errStateMismatch)
}

code, err := request.Query1("code")
if err != nil {
return i.responder.ServerError(err)
return i.responder.BadRequestWithContext(ctx, fmt.Errorf("%w: code", err))
}

oauthConfig, err := i.config(request)
if err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

options := make([]oauth2.AuthCodeOption, 0)
Expand All @@ -495,13 +501,13 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r

oauth2Token, err := oauthConfig.Exchange(ctx, code, options...)
if err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

// Extract the ID Token from OAuth2 token.
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return i.responder.ServerError(errors.New("claim id_token missing"))
return i.responder.ServerErrorWithContext(ctx, errNoIDTokenClaim)
}

verifierConfig := &oidc.Config{ClientID: i.oauth2Config.ClientID}
Expand All @@ -513,7 +519,7 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r
// Parse and verify ID Token payload.
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

var (
Expand All @@ -523,7 +529,7 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r
)

if err := idToken.Claims(&tempIDTokenClaims); err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}
for k, v := range i.oidcConfig.Claims.IDToken {
idTokenClaims[k] = tempIDTokenClaims[v]
Expand Down Expand Up @@ -561,7 +567,7 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r
identity, err := i.Identify(ctx, request)
if err != nil {
i.Logout(ctx, request)
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

i.eventRouter.Dispatch(ctx, &auth.WebLoginEvent{Broker: i.broker, Request: request, Identity: identity})
Expand Down
4 changes: 2 additions & 2 deletions core/auth/oauth/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ func TestParallelStateRaceConditions(t *testing.T) {
request.URL.RawQuery = url.Values{"state": []string{state2}}.Encode()
resp = identifier.Callback(context.Background(), web.CreateRequest(request, session), nil)
errResp = resp.(*web.ServerErrorResponse)
assert.EqualError(t, errResp.Error, "query value not found")
assert.EqualError(t, errResp.Error, "query value not found: code")

request.URL.RawQuery = url.Values{"state": []string{state1}}.Encode()
resp = identifier.Callback(context.Background(), web.CreateRequest(request, session), nil)
errResp = resp.(*web.ServerErrorResponse)
assert.EqualError(t, errResp.Error, "query value not found")
assert.EqualError(t, errResp.Error, "query value not found: code")

request.URL.RawQuery = url.Values{"state": []string{state1}}.Encode()
resp = identifier.Callback(context.Background(), web.CreateRequest(request, session), nil)
Expand Down
13 changes: 7 additions & 6 deletions core/requestlogger/logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ func TestLogger(t *testing.T) {
request.Request().Header.Set("Referer", "https://example.com/")

responder := new(web.Responder).Inject(&web.Router{}, flamingo.NullLogger{}, &struct {
Engine flamingo.TemplateEngine "inject:\",optional\""
Debug bool "inject:\"config:flamingo.debug.mode\""
TemplateForbidden string "inject:\"config:flamingo.template.err403\""
TemplateNotFound string "inject:\"config:flamingo.template.err404\""
TemplateUnavailable string "inject:\"config:flamingo.template.err503\""
TemplateErrorWithCode string "inject:\"config:flamingo.template.errWithCode\""
Engine flamingo.TemplateEngine `inject:",optional"`
Debug bool `inject:"config:flamingo.debug.mode"`
TemplateBadRequest string `inject:"config:flamingo.template.err400"`
TemplateForbidden string `inject:"config:flamingo.template.err403"`
TemplateNotFound string `inject:"config:flamingo.template.err404"`
TemplateUnavailable string `inject:"config:flamingo.template.err503"`
TemplateErrorWithCode string `inject:"config:flamingo.template.errWithCode"`
}{})

tests := []struct {
Expand Down
1 change: 1 addition & 0 deletions framework/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ flamingo: {
path?: string
}
template: {
err400: string | *"error/400"
err403: string | *"error/403"
err404: string | *"error/404"
errWithCode: string | *"error/withCode"
Expand Down
40 changes: 35 additions & 5 deletions framework/web/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type (
debug bool

templateForbidden string
templateBadRequest string
templateNotFound string
templateUnavailable string
templateErrorWithCode string
Expand Down Expand Up @@ -127,6 +128,7 @@ const (
func (r *Responder) Inject(router *Router, logger flamingo.Logger, cfg *struct {
Engine flamingo.TemplateEngine `inject:",optional"`
Debug bool `inject:"config:flamingo.debug.mode"`
TemplateBadRequest string `inject:"config:flamingo.template.err400"`
TemplateForbidden string `inject:"config:flamingo.template.err403"`
TemplateNotFound string `inject:"config:flamingo.template.err404"`
TemplateUnavailable string `inject:"config:flamingo.template.err503"`
Expand All @@ -135,6 +137,7 @@ func (r *Responder) Inject(router *Router, logger flamingo.Logger, cfg *struct {
r.engine = cfg.Engine
r.router = router
r.templateForbidden = cfg.TemplateForbidden
r.templateBadRequest = cfg.TemplateBadRequest
r.templateNotFound = cfg.TemplateNotFound
r.templateUnavailable = cfg.TemplateUnavailable
r.templateErrorWithCode = cfg.TemplateErrorWithCode
Expand Down Expand Up @@ -429,36 +432,63 @@ func (r *Responder) ServerErrorWithCodeAndTemplate(err error, tpl string, status

// ServerError creates a 500 error response
func (r *Responder) ServerError(err error) *ServerErrorResponse {
return r.ServerErrorWithContext(context.Background(), err)
}

// ServerErrorWithContext creates a 500 error response and uses the provided context for enhanced logging
func (r *Responder) ServerErrorWithContext(ctx context.Context, err error) *ServerErrorResponse {
if errors.Is(err, context.Canceled) {
r.getLogger().Debug(fmt.Sprintf("%+v\n", err))
r.getLogger().WithContext(ctx).Debug(fmt.Sprintf("%+v\n", err))
} else {
r.getLogger().Error(fmt.Sprintf("%+v\n", err))
r.getLogger().WithContext(ctx).Error(fmt.Sprintf("%+v\n", err))
}

return r.ServerErrorWithCodeAndTemplate(err, r.templateErrorWithCode, http.StatusInternalServerError)
}

// Unavailable creates a 503 error response
func (r *Responder) Unavailable(err error) *ServerErrorResponse {
r.getLogger().Error(fmt.Sprintf("%+v\n", err))
return r.UnavailableWithContext(context.Background(), err)
}

// UnavailableWithContext creates a 503 error response and uses the provided context for enhanced logging
func (r *Responder) UnavailableWithContext(ctx context.Context, err error) *ServerErrorResponse {
r.getLogger().WithContext(ctx).Error(fmt.Sprintf("%+v\n", err))

return r.ServerErrorWithCodeAndTemplate(err, r.templateUnavailable, http.StatusServiceUnavailable)
}

// NotFound creates a 404 error response
func (r *Responder) NotFound(err error) *ServerErrorResponse {
r.getLogger().Warn(err)
return r.NotFoundWithContext(context.Background(), err)
}

// NotFoundWithContext creates a 404 error response and uses the provided context for enhanced logging
func (r *Responder) NotFoundWithContext(ctx context.Context, err error) *ServerErrorResponse {
r.getLogger().WithContext(ctx).Warn(err)

return r.ServerErrorWithCodeAndTemplate(err, r.templateNotFound, http.StatusNotFound)
}

// Forbidden creates a 403 error response
func (r *Responder) Forbidden(err error) *ServerErrorResponse {
r.getLogger().Warn(err)
return r.ForbiddenWithContext(context.Background(), err)
}

// ForbiddenWithContext creates a 403 error response and uses the provided context for enhanced logging
func (r *Responder) ForbiddenWithContext(ctx context.Context, err error) *ServerErrorResponse {
r.getLogger().WithContext(ctx).Warn(err)

return r.ServerErrorWithCodeAndTemplate(err, r.templateForbidden, http.StatusForbidden)
}

// BadRequestWithContext creates a 400 error response and uses the provided context for enhanced logging
func (r *Responder) BadRequestWithContext(ctx context.Context, err error) *ServerErrorResponse {
r.getLogger().WithContext(ctx).Info(err)

return r.ServerErrorWithCodeAndTemplate(err, r.templateForbidden, http.StatusBadRequest)
}

// SetNoCache helper
func (r *ServerErrorResponse) SetNoCache() *ServerErrorResponse {
r.Response.SetNoCache()
Expand Down