diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 198ab28..14ff526 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -4,7 +4,7 @@ on: push: branches: - main - - redirect + - provider jobs: test: diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5bbcc2d..6208ac7 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -4,7 +4,7 @@ on: push: branches-ignore: - main - - redirect + - provider jobs: test: diff --git a/README.md b/README.md index c7173a6..7cff787 100644 --- a/README.md +++ b/README.md @@ -40,18 +40,18 @@ For traefik-simple-auth, a valid cookie: If an incoming request does not contain a valid session cookie, the user needs to be authenticated: -* We forward the user to Google's login page, so the user can be authenticated; -* When the user has logged in, Google sends the request back to traefik-simple-auth, specifically to the address `/_oauth`; +* We forward the user to the auth provider's login page, so the user can be authenticated; +* When the user has logged in, the provider sends the request back to traefik-simple-auth, specifically to the address `/_oauth`; * This routes the request to traefik-simple-auth's authCallback handler; * The handler uses the request to retrieve the authenticated user's email address and see if it is part of the `users` whitelist; * If so, it creates a new session cookie, and redirects the user to the original destination, with the session cookie; * This results in the request being sent back to traefik-simple-auth, with the session cookie, so it passes and the request is sent to the final destination. Given the asynchronous nature of the handshake during the authentication, traefik-simple-auth needs to validate the request -received from Google, to protect against cross-site request forgery (CSRF). The approach is as follows: +received from the auth provider, to protect against cross-site request forgery (CSRF). The approach is as follows: -* When the authCallback handler forwards the user to Google, it passes a random 'state', that it associates with the original request (i.e. the final destination) -* When Google sends the request back to traefik-simple-auth, it passes the same 'state' with the request. +* When the authCallback handler forwards the user to the auth provider, it passes a random 'state', that it associates with the original request (i.e. the final destination) +* When the auth provider sends the request back to traefik-simple-auth, it passes the same 'state' with the request. * traefik-simple-auth only keeps the state (with the final destination) for 5 minutes, which should be ample time for the user to log in. ## Installation @@ -59,7 +59,7 @@ received from Google, to protect against cross-site request forgery (CSRF). The Container images are available on [ghcr.io](https://ghcr.io/clambin/traefik-simple-auth). ## Configuration -### Google +### Using Google as auth provider Head to https://console.developers.google.com and create a new project. Create new Credentials and select OAuth Client ID with "web application" as its application type. @@ -78,7 +78,7 @@ Note the Client ID and Client Secret as you will need to configure these for tra ### Traefik #### Middleware -With your Google credentials defined, set up a `forward-auth` middleware. This causes Traefik to forward each incoming +With your auth credentials defined, set up a `forward-auth` middleware. This causes Traefik to forward each incoming request for a router configured with this middleware for authentication. In Kubernetes, this can be done with the following manifest: @@ -101,8 +101,8 @@ This created a new middleware `traefik-simple-auth` that forwards incoming reque #### Ingress -To authenticate a user, traefik-simple-auth redirects the user to their Google login page. Upon successful login, Google -forwards the request to the redirectURL (as configured in section Google). You will therefore need an ingress to route +To authenticate a user, traefik-simple-auth redirects the user to the auth provider's login page. Upon successful login, +the provider forwards the request to the redirectURL (as configured in section Google). You therefore need an ingress to route the request to traefik-simple-auth: ``` @@ -127,7 +127,7 @@ spec: number: 8080 ``` -This forwards the Google request back to traefik-simple-auth. +This forwards the request request back to traefik-simple-auth. ### Running traefik-simple-auth @@ -135,14 +135,14 @@ traefik-simple-auth supports the following command-line arguments: ``` Usage: - -addr string + -addr string The address to listen on for HTTP requests (default ":8080") -auth-prefix string prefix to construct the authRedirect URL from the domain (default "auth") -client-id string - Google OAuth Client ID + OAuth2 Client ID -client-secret string - Google OAuth Client Secret + OAuth2 Client Secret -debug Enable debug mode -domains string @@ -153,10 +153,13 @@ Usage: Use insecure cookies -prom string The address to listen on for Prometheus scrape requests (default ":9090") + -provider string + The OAuth2 provider to use (default "google") -secret string Secret to use for authentication -users string Comma-separated list of usernames to login + ``` #### Option details @@ -169,6 +172,10 @@ Usage: Prefix used to construct the authorization URL from the domain. +- `provider` + + The auth provider to use. Currently, only "google" and "github" are supported. + - `client-id` The (hex-encoded) Google Client ID, found in the Google Credentials configuration. diff --git a/cmd/traefik-simple-auth/traefik-simple-auth.go b/cmd/traefik-simple-auth/traefik-simple-auth.go index 126afa9..182f568 100644 --- a/cmd/traefik-simple-auth/traefik-simple-auth.go +++ b/cmd/traefik-simple-auth/traefik-simple-auth.go @@ -25,8 +25,9 @@ var ( authPrefix = flag.String("auth-prefix", "auth", "prefix to construct the authRedirect URL from the domain") domains = flag.String("domains", "", "Comma-separated list of domains to allow access") users = flag.String("users", "", "Comma-separated list of usernames to login") - clientId = flag.String("client-id", "", "Google OAuth Client ID") - clientSecret = flag.String("client-secret", "", "Google OAuth Client Secret") + provider = flag.String("provider", "google", "The OAuth2 provider to use") + clientId = flag.String("client-id", "", "OAuth2 Client ID") + clientSecret = flag.String("client-secret", "", "OAuth2 Client Secret") secret = flag.String("secret", "", "Secret to use for authentication") version string = "change-me" @@ -75,6 +76,7 @@ func getConfiguration(l *slog.Logger) server.Config { InsecureCookie: *insecure, Domains: strings.Split(*domains, ","), Users: strings.Split(*users, ","), + Provider: *provider, ClientID: *clientId, ClientSecret: *clientSecret, AuthPrefix: *authPrefix, diff --git a/go.mod b/go.mod index 5c9bd87..ed86f99 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ toolchain go1.22.2 require ( github.com/clambin/go-common/cache v0.4.0 github.com/clambin/go-common/http v0.4.3 - github.com/clambin/go-common/set v0.4.3 github.com/clambin/go-common/testutils v0.1.0 github.com/prometheus/client_golang v1.19.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index faaec0e..f8cd651 100644 --- a/go.sum +++ b/go.sum @@ -6,14 +6,10 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/clambin/go-common/cache v0.3.0 h1:nLG0wDWnLxCAP6ih5A8zVCInN7YGBWlonBtNTPvJUA4= -github.com/clambin/go-common/cache v0.3.0/go.mod h1:kk8t1usGhQjSOdjE5YKEwWBJigZThxKIl1wYP7hujeE= github.com/clambin/go-common/cache v0.4.0 h1:PjyWbQye8pHDIDomRfRWsaCeCD4gVjs7ITQJoopUe0E= github.com/clambin/go-common/cache v0.4.0/go.mod h1:kk8t1usGhQjSOdjE5YKEwWBJigZThxKIl1wYP7hujeE= github.com/clambin/go-common/http v0.4.3 h1:XRXi7rE4lPGpK4cALfM8ADcVDadgYfpzluM/4irn1E0= github.com/clambin/go-common/http v0.4.3/go.mod h1:g2LMIgauEx/3wAIYNxrjM2AiKWNbODNlUXUinrWkbPY= -github.com/clambin/go-common/set v0.4.3 h1:Sm9lkAJsh82j40RDpfQIziHyHjwr07+KsQF6vgCVXm4= -github.com/clambin/go-common/set v0.4.3/go.mod h1:Q5GpBoM7B7abNV2Wzys+wQMInBHMoHyh/h0Cn2OmY4A= github.com/clambin/go-common/testutils v0.1.0 h1:/nGWaOCIhW+Ew1c2NU7GLY/YPb8dp9SV8+MTgWksAgk= github.com/clambin/go-common/testutils v0.1.0/go.mod h1:bV0j8D4zhNkleCeluFKLBeLQ0L/dqkxbaR/joLn8kzg= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= diff --git a/internal/server/server.go b/internal/server/server.go index ec0bb7d..50b7dce 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -5,7 +5,6 @@ import ( "github.com/clambin/traefik-simple-auth/pkg/oauth" "github.com/clambin/traefik-simple-auth/pkg/whitelist" "golang.org/x/oauth2" - "golang.org/x/oauth2/google" "log/slog" "net/http" "net/url" @@ -16,22 +15,18 @@ const OAUTHPath = "/_oauth" type Server struct { Config - oauthHandlers map[string]OAuthHandler + oauthHandlers map[string]oauth.Handler sessionCookieHandler stateHandler whitelist.Whitelist http.Handler } -type OAuthHandler interface { - AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string - GetUserEmailAddress(code string) (string, error) -} - type Config struct { Expiry time.Duration Secret []byte InsecureCookie bool + Provider string Domains Domains Users []string ClientID string @@ -40,17 +35,11 @@ type Config struct { } func New(config Config, l *slog.Logger) *Server { - oauthHandlers := make(map[string]OAuthHandler) + oauthHandlers := make(map[string]oauth.Handler) for _, domain := range config.Domains { - oauthHandlers[domain] = oauth.Handler{ - HTTPClient: http.DefaultClient, - Config: oauth2.Config{ - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - Endpoint: google.Endpoint, - RedirectURL: makeAuthURL(config.AuthPrefix, domain), - Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"}, - }, + var err error + if oauthHandlers[domain], err = oauth.NewHandler(config.Provider, config.ClientID, config.ClientSecret, makeAuthURL(config.AuthPrefix, domain, OAUTHPath), l.With("oauth", config.Provider)); err != nil { + panic("unknown provider: " + config.Provider) } } s := Server{ @@ -63,7 +52,7 @@ func New(config Config, l *slog.Logger) *Server { sessions: cache.New[string, sessionCookie](config.Expiry, time.Minute), }, stateHandler: stateHandler{ - // 5 minutes should be enough for the user to log in to Google + // 5 minutes should be enough for the user to log in cache: cache.New[string, string](5*time.Minute, time.Minute), }, Whitelist: whitelist.New(config.Users), @@ -85,7 +74,7 @@ func (s *Server) authHandler(l *slog.Logger) http.HandlerFunc { c, err := r.Cookie(sessionCookieName) if err != nil || c.Value == "" { - // Client doesn't have a valid cookie. Redirect to Google to authenticate the user. + // Client doesn't have a valid cookie. Redirect to oauth provider to authenticate the user. // When the user is authenticated, authCallbackHandler generates a new valid cookie. l.Debug("no cookie found, redirecting ...") s.redirectToAuth(w, r, l) @@ -93,14 +82,14 @@ func (s *Server) authHandler(l *slog.Logger) http.HandlerFunc { } session, err := s.getSessionCookie(c) if err != nil { - // Client has an invalid cookie. Redirect to Google to authenticate the user. + // Client has an invalid cookie. Redirect to oauth provider to authenticate the user. // When the user is authenticated, authCallbackHandler generates a new valid cookie. l.Warn("invalid cookie. redirecting ...", "err", err) s.redirectToAuth(w, r, l) return } if session.expired() { - // Client has an expired cookie. Redirect to Google to authenticate the user. + // Client has an expired cookie. Redirect to oauth provider to authenticate the user. // When the user is authenticated, authCallbackHandler generates a new valid cookie. l.Debug("expired cookie. redirecting ...") s.redirectToAuth(w, r, l) @@ -136,7 +125,7 @@ func (s *Server) redirectToAuth(w http.ResponseWriter, r *http.Request, l *slog. return } - // Redirect user to Google to select the account to be used to authenticate the request + // Redirect user to oauth provider to select the account to be used to authenticate the request authCodeURL := s.oauthHandlers[domain].AuthCodeURL(encodedState, oauth2.SetAuthURLParam("prompt", "select_account")) l.Debug("redirecting ...", "authCodeURL", authCodeURL) http.Redirect(w, r, authCodeURL, http.StatusTemporaryRedirect) @@ -165,7 +154,7 @@ func (s *Server) authCallbackHandler(l *slog.Logger) http.HandlerFunc { // Use the "code" in the response to determine the user's email address. user, err := s.oauthHandlers[domain].GetUserEmailAddress(r.FormValue("code")) if err != nil { - l.Error("failed to log in to google", "err", err) + l.Error("failed to log in", "err", err) http.Error(w, "oauth2 failed", http.StatusBadGateway) return } @@ -225,7 +214,7 @@ func (s *Server) makeCookie(value, domain string) *http.Cookie { } // makeAuthURL returns the auth URL for a given domain -func makeAuthURL(authPrefix, domain string) string { +func makeAuthURL(authPrefix, domain, OAUTHPath string) string { var dot string if domain != "" && domain[0] != '.' { dot = "." diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 0e113a5..1094bee 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "errors" + "github.com/clambin/traefik-simple-auth/pkg/oauth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -79,10 +80,11 @@ func TestServer_authHandler(t *testing.T) { } config := Config{ - Domains: Domains{"example.com"}, - Secret: []byte("secret"), - Expiry: time.Hour, - Users: []string{"foo@example.com"}, + Domains: Domains{"example.com"}, + Secret: []byte("secret"), + Expiry: time.Hour, + Users: []string{"foo@example.com"}, + Provider: "google", } s := New(config, slog.Default()) @@ -137,10 +139,11 @@ func Benchmark_authHandler(b *testing.B) { func TestServer_authHandler_expiry(t *testing.T) { config := Config{ - Expiry: 500 * time.Millisecond, - Secret: []byte("secret"), - Domains: []string{"example.com"}, - Users: []string{"foo@example.com"}, + Expiry: 500 * time.Millisecond, + Secret: []byte("secret"), + Domains: []string{"example.com"}, + Users: []string{"foo@example.com"}, + Provider: "google", } s := New(config, slog.Default()) sc := sessionCookie{Email: "foo@example.com", Expiry: time.Now().Add(config.Expiry)} @@ -156,7 +159,6 @@ func TestServer_authHandler_expiry(t *testing.T) { } func TestServer_redirectToAuth(t *testing.T) { - tests := []struct { name string target string @@ -187,6 +189,7 @@ func TestServer_redirectToAuth(t *testing.T) { ClientSecret: "secret", Domains: Domains{"example.com", ".example.org"}, AuthPrefix: "auth", + Provider: "google", } s := New(config, slog.Default()) @@ -218,9 +221,10 @@ func TestServer_redirectToAuth(t *testing.T) { func TestServer_LogoutHandler(t *testing.T) { config := Config{ - Secret: []byte("secret"), - Domains: Domains{"example.com"}, - Expiry: time.Hour, + Secret: []byte("secret"), + Domains: Domains{"example.com"}, + Expiry: time.Hour, + Provider: "google", } s := New(config, slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))) @@ -284,7 +288,7 @@ func TestServer_AuthCallbackHandler(t *testing.T) { t.Parallel() l := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) - s := New(Config{Users: []string{"foo@example.com"}, Domains: Domains{"example.com"}}, l) + s := New(Config{Users: []string{"foo@example.com"}, Domains: Domains{"example.com"}, Provider: "google"}, l) s.oauthHandlers["example.com"] = &fakeOauthHandler{email: tt.oauthUser, err: tt.oauthErr} state := tt.state @@ -318,7 +322,7 @@ func makeHTTPRequest(method, host, uri string) *http.Request { return req } -var _ OAuthHandler = fakeOauthHandler{} +var _ oauth.Handler = fakeOauthHandler{} type fakeOauthHandler struct { email string diff --git a/pkg/oauth/github.go b/pkg/oauth/github.go new file mode 100644 index 0000000..7359ec9 --- /dev/null +++ b/pkg/oauth/github.go @@ -0,0 +1,95 @@ +package oauth + +import ( + "encoding/json" + "fmt" + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" + "log/slog" + "net/http" +) + +type GitHubHandler struct { + BaseHandler +} + +func NewGitHubHandler(clientID, clientSecret, authURL string, logger *slog.Logger) *GitHubHandler { + return &GitHubHandler{ + BaseHandler: BaseHandler{ + HTTPClient: http.DefaultClient, + Config: oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: github.Endpoint, + RedirectURL: authURL, + Scopes: []string{"user:email", "read:user"}, + }, + Logger: logger, + }, + } +} + +func (h GitHubHandler) GetUserEmailAddress(code string) (string, error) { + // Use code to get token and get user info from GitHub + token, err := h.getAccessToken(code) + if err != nil { + return "", err + } + + email, err := h.getAddress(token) + if email != "" && err == nil { + return email, nil + } + h.Logger.Debug("No email address found. Using user public profile instead", "err", err) + return h.getAddressFromProfile(token) +} + +func (h GitHubHandler) getAddress(token *oauth2.Token) (string, error) { + resp, err := h.do("https://api.github.com/user/emails", token) + if err != nil { + return "", fmt.Errorf("failed to get user info: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + var users []struct { + Email string `json:"email"` + Primary bool `json:"primary"` + } + + if err = json.NewDecoder(resp.Body).Decode(&users); err != nil { + return "", err + } + + if len(users) == 0 { + return "", fmt.Errorf("no email addresses found") + } + + for _, user := range users { + if user.Primary { + return user.Email, nil + } + } + // fallback in case no primary email: return the first one + h.Logger.Warn("No primary email address found. Defaulting to first email address instead.", "email", users[0].Email) + return users[0].Email, nil +} + +func (h GitHubHandler) getAddressFromProfile(token *oauth2.Token) (string, error) { + resp, err := h.do("https://api.github.com/user", token) + defer func() { _ = resp.Body.Close() }() + + var user struct { + Email string `json:"email"` + } + + err = json.NewDecoder(resp.Body).Decode(&user) + return user.Email, err +} + +func (h GitHubHandler) do(url string, token *oauth2.Token) (*http.Response, error) { + req, _ := http.NewRequest(http.MethodGet, url, nil) + token.SetAuthHeader(req) + req.Header.Set("Accept", "application/vnd.github+json") + + return h.HTTPClient.Do(req) +} diff --git a/pkg/oauth/github_test.go b/pkg/oauth/github_test.go new file mode 100644 index 0000000..9e9d5c3 --- /dev/null +++ b/pkg/oauth/github_test.go @@ -0,0 +1,68 @@ +package oauth + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "io" + "log/slog" + "net/http" + "strings" + "testing" +) + +func TestGitHubHandler_GetUserEmailAddress(t *testing.T) { + tests := []struct { + name string + emailResponse string + userResponse string + }{ + { + name: "primary email", + emailResponse: `[ {"email":"bar@example.com","primary":false}, {"email":"foo@example.com","primary":true} ]`, + }, + { + name: "no primary email", + emailResponse: `[ {"email":"foo@example.com","primary":false}, {"email":"bar@example.com","primary":false} ]`, + }, + { + name: "no emails", + emailResponse: `[ ]`, + userResponse: `{"email":"foo@example.com"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + s := oauthServer{ + roundTrip: func(r *http.Request) (*http.Response, error) { + var resp http.Response + switch r.URL.Path { + case "/login/oauth/access_token": + resp.StatusCode = http.StatusOK + resp.Body = io.NopCloser(strings.NewReader(`{"access_token":"123456789"}`)) + case "/user/emails": + resp.StatusCode = http.StatusOK + resp.Body = io.NopCloser(strings.NewReader(tt.emailResponse)) + case "/user": + resp.StatusCode = http.StatusOK + resp.Body = io.NopCloser(strings.NewReader(tt.userResponse)) + default: + resp.StatusCode = http.StatusNotFound + resp.Body = io.NopCloser(strings.NewReader(`{"path":"` + r.URL.Path + `"}`)) + } + return &resp, nil + }, + } + + h, _ := NewHandler("github", "1234", "1234567", "https://auth.example.com/_oauth", slog.Default()) + h.(*GitHubHandler).HTTPClient = &http.Client{Transport: s} + + user, err := h.GetUserEmailAddress("abcd1234") + require.NoError(t, err) + assert.Equal(t, "foo@example.com", user) + + }) + } +} diff --git a/pkg/oauth/google.go b/pkg/oauth/google.go new file mode 100644 index 0000000..f296a09 --- /dev/null +++ b/pkg/oauth/google.go @@ -0,0 +1,55 @@ +package oauth + +import ( + "encoding/json" + "fmt" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "log/slog" + "net/http" +) + +type GoogleHandler struct { + BaseHandler +} + +func NewGoogleHandler(clientID, clientSecret, authURL string, logger *slog.Logger) *GoogleHandler { + return &GoogleHandler{ + BaseHandler: BaseHandler{ + HTTPClient: http.DefaultClient, + Config: oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: google.Endpoint, + RedirectURL: authURL, + Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"}, + }, + Logger: logger, + }, + } +} + +const googleUserURL = "https://openidconnect.googleapis.com/v1/userinfo" + +func (h GoogleHandler) GetUserEmailAddress(code string) (string, error) { + // Use code to get token and get user info from Google. + token, err := h.getAccessToken(code) + if err != nil { + return "", err + } + + req, _ := http.NewRequest(http.MethodGet, googleUserURL, nil) + token.SetAuthHeader(req) + + response, err := h.HTTPClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed getting user info: %s", err.Error()) + } + defer func() { _ = response.Body.Close() }() + + var user struct { + Email string `json:"email"` + } + err = json.NewDecoder(response.Body).Decode(&user) + return user.Email, err +} diff --git a/pkg/oauth/google_test.go b/pkg/oauth/google_test.go new file mode 100644 index 0000000..3ae86f2 --- /dev/null +++ b/pkg/oauth/google_test.go @@ -0,0 +1,51 @@ +package oauth + +import ( + "encoding/json" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "io" + "log/slog" + "net/http" + "strings" + "testing" +) + +func TestGoogleHandler_GetUserEmailAddress(t *testing.T) { + s := oauthServer{ + roundTrip: func(r *http.Request) (*http.Response, error) { + var resp http.Response + switch r.URL.Path { + case "/token": + resp.StatusCode = http.StatusOK + resp.Body = io.NopCloser(strings.NewReader(`{"access_token":"123456789"}`)) + case "/v1/userinfo": + resp.StatusCode = http.StatusOK + resp.Body = io.NopCloser(strings.NewReader(`{"email":"foo@example.com"}`)) + default: + resp.StatusCode = http.StatusNotFound + resp.Body = io.NopCloser(strings.NewReader(`{"path":"` + r.URL.Path + `"}`)) + } + return &resp, nil + }, + } + + h, _ := NewHandler("google", "1234", "1234567", "https://auth.example.com/_oauth", slog.Default()) + h.(*GoogleHandler).HTTPClient = &http.Client{Transport: s} + + user, err := h.GetUserEmailAddress("abcd1234") + require.NoError(t, err) + assert.Equal(t, "foo@example.com", user) +} + +func TestHandler_userInfoEndpoint(t *testing.T) { + resp, err := http.Get("https://accounts.google.com/.well-known/openid-configuration") + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + + var response struct { + UserInfoEndpoint string `json:"userinfo_endpoint"` + } + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&response)) + assert.Equal(t, googleUserURL, response.UserInfoEndpoint, "google userinfo endpoint has changed") +} diff --git a/pkg/oauth/oauth.go b/pkg/oauth/oauth.go index 63514fe..3e87549 100644 --- a/pkg/oauth/oauth.go +++ b/pkg/oauth/oauth.go @@ -2,52 +2,35 @@ package oauth import ( "context" - "encoding/json" "fmt" "golang.org/x/oauth2" + "log/slog" "net/http" ) -type Handler struct { - oauth2.Config - HTTPClient *http.Client -} - -func (o Handler) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { - return o.Config.AuthCodeURL(state, opts...) +type Handler interface { + AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string + GetUserEmailAddress(code string) (string, error) } -func (o Handler) GetUserEmailAddress(code string) (string, error) { - // Use code to get token and get user info from Google. - token, err := o.getAccessToken(code) - if err == nil { - return o.getUserEmailAddress(token) +func NewHandler(provider, clientID, clientSecret, authURL string, logger *slog.Logger) (Handler, error) { + switch provider { + case "google": + return NewGoogleHandler(clientID, clientSecret, authURL, logger), nil + case "github": + return NewGitHubHandler(clientID, clientSecret, authURL, logger), nil + default: + return nil, fmt.Errorf("unknown provider: %s", provider) } - return "", err } -func (o Handler) getAccessToken(code string) (string, error) { - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, o.HTTPClient) - var accessToken string - token, err := o.Config.Exchange(ctx, code) - if err == nil { - accessToken = token.AccessToken - } - return accessToken, err +type BaseHandler struct { + oauth2.Config + HTTPClient *http.Client + Logger *slog.Logger } -const userInfoURL = "https://openidconnect.googleapis.com/v1/userinfo" - -func (o Handler) getUserEmailAddress(token string) (string, error) { - response, err := o.HTTPClient.Get(userInfoURL + "?access_token=" + token) - if err != nil { - return "", fmt.Errorf("failed getting user info: %s", err.Error()) - } - defer func() { _ = response.Body.Close() }() - - var user struct { - Email string `json:"email"` - } - err = json.NewDecoder(response.Body).Decode(&user) - return user.Email, err +func (h BaseHandler) getAccessToken(code string) (*oauth2.Token, error) { + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, h.HTTPClient) + return h.Config.Exchange(ctx, code) } diff --git a/pkg/oauth/oauth_test.go b/pkg/oauth/oauth_test.go index 7aaba7c..6fde109 100644 --- a/pkg/oauth/oauth_test.go +++ b/pkg/oauth/oauth_test.go @@ -1,84 +1,31 @@ package oauth import ( - "encoding/json" - "fmt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" - "golang.org/x/oauth2/google" - "io" + "log/slog" "net/http" "net/url" - "strings" "testing" ) func TestHandler_AuthCodeURL(t *testing.T) { - o := Handler{ - Config: oauth2.Config{ - ClientID: "CLIENT_ID", - ClientSecret: "CLIENT_SECRET", - Endpoint: google.Endpoint, - RedirectURL: "http://localhost", - Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"}, - }, - } + h, _ := NewHandler("google", "CLIENT_ID", "CLIENT_SECRET", "https://auth/example.com/_oauth", slog.Default()) - u, err := url.Parse(o.AuthCodeURL("state", oauth2.SetAuthURLParam("prompt", "select_profile"))) + u, err := url.Parse(h.AuthCodeURL("state", oauth2.SetAuthURLParam("prompt", "select_profile"))) require.NoError(t, err) q := u.Query() assert.Equal(t, "state", q.Get("state")) assert.Equal(t, "select_profile", q.Get("prompt")) } -func TestHandler_GetUserEmailAddress(t *testing.T) { - s := oauthServer{} - o := Handler{ - HTTPClient: &http.Client{Transport: s}, - Config: oauth2.Config{ - ClientID: "1234", - ClientSecret: "1234567", - Endpoint: google.Endpoint, - RedirectURL: "/", - Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"}, - }, - } - - user, err := o.GetUserEmailAddress("abcd1234") - require.NoError(t, err) - assert.Equal(t, "foo@example.com", user) -} - -func TestHandler_userInfoEndpoint(t *testing.T) { - resp, err := http.Get("https://accounts.google.com/.well-known/openid-configuration") - require.NoError(t, err) - defer func() { _ = resp.Body.Close() }() - - var response struct { - UserInfoEndpoint string `json:"userinfo_endpoint"` - } - assert.NoError(t, json.NewDecoder(resp.Body).Decode(&response)) - assert.Equal(t, userInfoURL, response.UserInfoEndpoint, "google userinfo endpoint has changed") -} - var _ http.RoundTripper = &oauthServer{} -type oauthServer struct{} +type oauthServer struct { + roundTrip func(*http.Request) (*http.Response, error) +} func (o oauthServer) RoundTrip(r *http.Request) (*http.Response, error) { - var resp http.Response - switch r.URL.Path { - case "/token": - resp.StatusCode = http.StatusOK - resp.Body = io.NopCloser(strings.NewReader(`{"access_token":"123456789"}`)) - case "/v1/userinfo": - resp.StatusCode = http.StatusOK - resp.Body = io.NopCloser(strings.NewReader(`{"email":"foo@example.com"}`)) - default: - fmt.Printf("Unsupported path: %v\n", r.URL.Path) - resp.StatusCode = http.StatusNotFound - resp.Body = io.NopCloser(strings.NewReader(`{"path":"` + r.URL.Path + `"}`)) - } - return &resp, nil + return o.roundTrip(r) }