Skip to content

Commit

Permalink
refactor: better way to generate domain-specific auth URL
Browse files Browse the repository at this point in the history
  • Loading branch information
clambin committed Apr 20, 2024
1 parent 89fd13f commit 82f73d5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 25 deletions.
38 changes: 14 additions & 24 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ import (
const OAUTHPath = "/_oauth"

type Server struct {
http.Handler
Config
oauthHandlers map[string]OAuthHandler
sessionCookieHandler
stateHandler
whitelist.Whitelist
oauthHandlers map[string]OAuthHandler
config Config
http.Handler
}

type OAuthHandler interface {
Expand Down Expand Up @@ -54,7 +54,7 @@ func New(config Config, l *slog.Logger) *Server {
}
}
s := Server{
config: config,
Config: config,
oauthHandlers: oauthHandlers,
sessionCookieHandler: sessionCookieHandler{
SecureCookie: !config.InsecureCookie,
Expand Down Expand Up @@ -108,7 +108,7 @@ func (s *Server) authHandler(l *slog.Logger) http.HandlerFunc {

}

if _, ok := s.config.Domains.getDomain(r.URL); !ok {
if _, ok := s.Domains.getDomain(r.URL); !ok {
l.Warn("host doesn't match any configured domains", "host", r.URL.Host)
http.Error(w, "Not authorized", http.StatusUnauthorized)
return
Expand All @@ -129,25 +129,15 @@ func (s *Server) redirectToAuth(w http.ResponseWriter, r *http.Request, l *slog.
http.Error(w, "Internal server error", http.StatusInternalServerError)
}

domain, ok := s.config.Domains.getDomain(r.URL)
domain, ok := s.Domains.getDomain(r.URL)
if !ok {
l.Error("invalid target host", "host", r.URL.Host)
http.Error(w, "Invalid target host", http.StatusUnauthorized)
return
}

h, ok := s.oauthHandlers[domain]
if !ok {
l.Error("invalid target domain", "domain", domain)
http.Error(w, "Invalid target domain", http.StatusUnauthorized)
return
}

// Redirect user to Google to select the account to be used to authenticate the request
authCodeURL := h.AuthCodeURL(encodedState,
oauth2.SetAuthURLParam("prompt", "select_account"),
)

authCodeURL := s.oauthHandlers[domain].AuthCodeURL(encodedState, oauth2.SetAuthURLParam("prompt", "select_account"))
l.Debug("redirecting ...", "authCodeURL", authCodeURL)
http.Redirect(w, r, authCodeURL, http.StatusTemporaryRedirect)
}
Expand All @@ -168,12 +158,12 @@ func (s *Server) authCallbackHandler(l *slog.Logger) http.HandlerFunc {
}

// we already validated the host vs the domain during the redirect
// since the state matches, we can trust the request to be valid
u, _ := url.Parse(redirectURL)
domain, _ := s.config.Domains.getDomain(u)
h, _ := s.oauthHandlers[domain]
domain, _ := s.Domains.getDomain(u)

// Use the "code" in the response to determine the user's email address.
user, err := h.GetUserEmailAddress(r.FormValue("code"))
user, err := s.oauthHandlers[domain].GetUserEmailAddress(r.FormValue("code"))
if err != nil {
l.Error("failed to log in to google", "err", err)
http.Error(w, "oauth2 failed", http.StatusBadGateway)
Expand All @@ -191,11 +181,11 @@ func (s *Server) authCallbackHandler(l *slog.Logger) http.HandlerFunc {
// GetUserEmailAddress successful. Add session cookie and redirect the user to the final destination.
sc := sessionCookie{
Email: user,
Expiry: time.Now().Add(s.config.Expiry),
Expiry: time.Now().Add(s.Config.Expiry),
}
s.sessionCookieHandler.saveCookie(sc)

http.SetCookie(w, s.makeCookie(sc.encode(s.config.Secret), domain))
http.SetCookie(w, s.makeCookie(sc.encode(s.Config.Secret), domain))
l.Info("user logged in. redirecting ...", "user", user, "url", redirectURL)
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
}
Expand All @@ -213,7 +203,7 @@ func (s *Server) logoutHandler(l *slog.Logger) http.HandlerFunc {
}

// get the domain for the target
domain, _ := s.config.Domains.getDomain(r.URL)
domain, _ := s.Domains.getDomain(r.URL)

// Write a blank session cookie to override the current valid one.
http.SetCookie(w, s.makeCookie("", domain))
Expand All @@ -229,7 +219,7 @@ func (s *Server) makeCookie(value, domain string) *http.Cookie {
Value: value,
Domain: domain,
Path: "/",
Secure: !s.config.InsecureCookie,
Secure: !s.InsecureCookie,
HttpOnly: true,
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestServer_authHandler(t *testing.T) {
name: "valid domain with user info",
args: args{
host: "user:password@www.example.com",
cookie: s.makeSessionCookie("foo@example.com", config.Secret),
cookie: s.makeSessionCookie("foo@example.com", Config.Secret),
},
want: http.StatusOK,
user: "foo@example.com",
Expand Down

0 comments on commit 82f73d5

Please # to comment.