diff --git a/cmd/auth.go b/cmd/auth.go index e161ab6da..2d368ba7d 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -1,6 +1,8 @@ package main import ( + "encoding/base64" + "encoding/json" "errors" "net/http" "net/mail" @@ -28,6 +30,11 @@ type loginTpl struct { Error string } +type oidcState struct { + Nonce string `json:"nonce"` + Next string `json:"next"` +} + var oidcProviders = map[string]bool{ "google.com": true, "microsoftonline.com": true, @@ -107,7 +114,18 @@ func handleOIDCLogin(c echo.Context) error { next = uriAdmin } - return c.Redirect(http.StatusFound, app.auth.GetOIDCAuthURL(next, nonce.Value)) + state := oidcState{ + Nonce: nonce.Value, + Next: next, + } + + stateJSON, err := json.Marshal(state) + if err != nil { + app.log.Printf("error marshalling OIDC state: %v", err) + return echo.NewHTTPError(http.StatusInternalServerError, app.i18n.T("globals.messages.internalError")) + } + + return c.Redirect(http.StatusFound, app.auth.GetOIDCAuthURL(base64.URLEncoding.EncodeToString(stateJSON), nonce.Value)) } // handleOIDCFinish receives the redirect callback from the OIDC provider and completes the handshake. @@ -125,6 +143,21 @@ func handleOIDCFinish(c echo.Context) error { return renderLoginPage(c, err) } + // Validate the state. + var state oidcState + stateB, err := base64.URLEncoding.DecodeString(c.QueryParam("state")) + if err != nil { + app.log.Printf("error decoding OIDC state: %v", err) + return echo.NewHTTPError(http.StatusInternalServerError, app.i18n.T("globals.messages.internalError")) + } + if err := json.Unmarshal(stateB, &state); err != nil { + app.log.Printf("error unmarshalling OIDC state: %v", err) + return echo.NewHTTPError(http.StatusInternalServerError, app.i18n.T("globals.messages.internalError")) + } + if state.Nonce != nonce.Value { + return renderLoginPage(c, echo.NewHTTPError(http.StatusUnauthorized, app.i18n.T("users.invalidRequest"))) + } + // Validate e-mail from the claim. email := strings.TrimSpace(claims.Email) if email == "" { @@ -153,7 +186,7 @@ func handleOIDCFinish(c echo.Context) error { return renderLoginPage(c, err) } - return c.Redirect(http.StatusFound, utils.SanitizeURI(c.QueryParam("state"))) + return c.Redirect(http.StatusFound, utils.SanitizeURI(state.Next)) } // renderLoginPage renders the login page and handles the login form. diff --git a/cmd/init.go b/cmd/init.go index 632f0ac16..8b5eb899c 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -434,9 +434,6 @@ func initConstants() *constants { c.BounceSendgridEnabled = ko.Bool("bounce.sendgrid_enabled") c.BouncePostmarkEnabled = ko.Bool("bounce.postmark.enabled") c.BounceForwardemailEnabled = ko.Bool("bounce.forwardemail.enabled") - - fmt.Println(c.BounceForwardemailEnabled) - c.HasLegacyUser = ko.Exists("app.admin_username") || ko.Exists("app.admin_password") b := md5.Sum([]byte(time.Now().String()))