diff --git a/internal/server/server.go b/internal/server/server.go index 5db6150..ec0bb7d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,12 +15,12 @@ import ( const OAUTHPath = "/_oauth" type Server struct { - http.Handler + Config + oauthHandlers map[string]OAuthHandler sessionCookieHandler stateHandler whitelist.Whitelist - oauthHandlers map[string]OAuthHandler - config Config + http.Handler } type OAuthHandler interface { @@ -54,7 +54,7 @@ func New(config Config, l *slog.Logger) *Server { } } s := Server{ - config: config, + Config: config, oauthHandlers: oauthHandlers, sessionCookieHandler: sessionCookieHandler{ SecureCookie: !config.InsecureCookie, @@ -108,7 +108,7 @@ func (s *Server) authHandler(l *slog.Logger) http.HandlerFunc { } - if _, ok := s.config.Domains.getDomain(r.URL); !ok { + if _, ok := s.Domains.getDomain(r.URL); !ok { l.Warn("host doesn't match any configured domains", "host", r.URL.Host) http.Error(w, "Not authorized", http.StatusUnauthorized) return @@ -129,25 +129,15 @@ func (s *Server) redirectToAuth(w http.ResponseWriter, r *http.Request, l *slog. http.Error(w, "Internal server error", http.StatusInternalServerError) } - domain, ok := s.config.Domains.getDomain(r.URL) + domain, ok := s.Domains.getDomain(r.URL) if !ok { l.Error("invalid target host", "host", r.URL.Host) http.Error(w, "Invalid target host", http.StatusUnauthorized) return } - h, ok := s.oauthHandlers[domain] - if !ok { - l.Error("invalid target domain", "domain", domain) - http.Error(w, "Invalid target domain", http.StatusUnauthorized) - return - } - // Redirect user to Google to select the account to be used to authenticate the request - authCodeURL := h.AuthCodeURL(encodedState, - oauth2.SetAuthURLParam("prompt", "select_account"), - ) - + authCodeURL := s.oauthHandlers[domain].AuthCodeURL(encodedState, oauth2.SetAuthURLParam("prompt", "select_account")) l.Debug("redirecting ...", "authCodeURL", authCodeURL) http.Redirect(w, r, authCodeURL, http.StatusTemporaryRedirect) } @@ -168,12 +158,12 @@ func (s *Server) authCallbackHandler(l *slog.Logger) http.HandlerFunc { } // we already validated the host vs the domain during the redirect + // since the state matches, we can trust the request to be valid u, _ := url.Parse(redirectURL) - domain, _ := s.config.Domains.getDomain(u) - h, _ := s.oauthHandlers[domain] + domain, _ := s.Domains.getDomain(u) // Use the "code" in the response to determine the user's email address. - user, err := h.GetUserEmailAddress(r.FormValue("code")) + user, err := s.oauthHandlers[domain].GetUserEmailAddress(r.FormValue("code")) if err != nil { l.Error("failed to log in to google", "err", err) http.Error(w, "oauth2 failed", http.StatusBadGateway) @@ -191,11 +181,11 @@ func (s *Server) authCallbackHandler(l *slog.Logger) http.HandlerFunc { // GetUserEmailAddress successful. Add session cookie and redirect the user to the final destination. sc := sessionCookie{ Email: user, - Expiry: time.Now().Add(s.config.Expiry), + Expiry: time.Now().Add(s.Config.Expiry), } s.sessionCookieHandler.saveCookie(sc) - http.SetCookie(w, s.makeCookie(sc.encode(s.config.Secret), domain)) + http.SetCookie(w, s.makeCookie(sc.encode(s.Config.Secret), domain)) l.Info("user logged in. redirecting ...", "user", user, "url", redirectURL) http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect) } @@ -213,7 +203,7 @@ func (s *Server) logoutHandler(l *slog.Logger) http.HandlerFunc { } // get the domain for the target - domain, _ := s.config.Domains.getDomain(r.URL) + domain, _ := s.Domains.getDomain(r.URL) // Write a blank session cookie to override the current valid one. http.SetCookie(w, s.makeCookie("", domain)) @@ -229,7 +219,7 @@ func (s *Server) makeCookie(value, domain string) *http.Cookie { Value: value, Domain: domain, Path: "/", - Secure: !s.config.InsecureCookie, + Secure: !s.InsecureCookie, HttpOnly: true, } } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 36152b3..0e113a5 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -62,7 +62,7 @@ func TestServer_authHandler(t *testing.T) { name: "valid domain with user info", args: args{ host: "user:password@www.example.com", - cookie: s.makeSessionCookie("foo@example.com", config.Secret), + cookie: s.makeSessionCookie("foo@example.com", Config.Secret), }, want: http.StatusOK, user: "foo@example.com",