Skip to content

Commit

Permalink
refactor: add domain as authenticator attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
clambin committed Feb 2, 2025
1 parent ca76229 commit e158a6c
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ on:
push:
branches:
- main
- pprof
- refactor
permissions:
contents: read
packages: write
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ on:
push:
branches-ignore:
- main
- pprof
- refactor
permissions:
contents: read
jobs:
Expand Down
12 changes: 7 additions & 5 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,27 @@ import (
// Authenticator creates and validate JWT tokens inside a http.Cookie.
type Authenticator struct {
CookieName string
Domain string
Secret []byte
Expiration time.Duration
parser *jwt.Parser
}

func New(cookieName string, secret []byte, expiration time.Duration) *Authenticator {
func New(cookieName string, domain string, secret []byte, expiration time.Duration) *Authenticator {
return &Authenticator{
CookieName: cookieName,
Domain: domain,
Secret: secret,
Expiration: expiration,
parser: jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})),
}
}

// CookieWithSignedToken returns a http.Cookie with a signed token.
func (a *Authenticator) CookieWithSignedToken(userID string, domain string) (c *http.Cookie, err error) {
func (a *Authenticator) CookieWithSignedToken(userID string) (c *http.Cookie, err error) {
var token string
if token, err = a.makeSignedToken(userID); err == nil {
c = a.Cookie(token, a.Expiration, domain)
c = a.Cookie(token, a.Expiration)
}
return c, err
}
Expand All @@ -50,13 +52,13 @@ func (a *Authenticator) makeSignedToken(userID string) (string, error) {
}

// Cookie returns a new http.Cookie for the provided token, expiration time and domain.
func (a *Authenticator) Cookie(token string, expiration time.Duration, domain string) *http.Cookie {
func (a *Authenticator) Cookie(token string, expiration time.Duration) *http.Cookie {
return &http.Cookie{
Name: a.CookieName,
Value: token,
MaxAge: int(expiration.Seconds()),
Path: "/",
Domain: domain,
Domain: a.Domain,
HttpOnly: true,
Secure: true,
//SameSite: http.SameSiteStrictMode,
Expand Down
16 changes: 8 additions & 8 deletions internal/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestAuthenticator_Authenticate(t *testing.T) {
{
name: "valid cookie",
cookie: func(a *Authenticator) *http.Cookie {
c, _ := a.CookieWithSignedToken("foo@example.com", "example.com")
c, _ := a.CookieWithSignedToken("foo@example.com")
return c
},
err: assert.NoError,
Expand All @@ -29,7 +29,7 @@ func TestAuthenticator_Authenticate(t *testing.T) {
{
name: "empty cookie",
cookie: func(a *Authenticator) *http.Cookie {
c, _ := a.CookieWithSignedToken("", "example.com")
c, _ := a.CookieWithSignedToken("")
c.Value = ""
return c
},
Expand All @@ -43,16 +43,16 @@ func TestAuthenticator_Authenticate(t *testing.T) {
name: "expired cookie",
cookie: func(a *Authenticator) *http.Cookie {
a.Expiration = -time.Hour
c, _ := a.CookieWithSignedToken("foo@example.com", "example.com")
c, _ := a.CookieWithSignedToken("foo@example.com")
return c
},
err: assert.Error,
},
{
name: "invalid HMAC",
cookie: func(a *Authenticator) *http.Cookie {
b := New(a.CookieName, []byte("wrong-secret"), a.Expiration)
c, _ := b.CookieWithSignedToken("foo@example.com", "example.com")
b := New(a.CookieName, "example.com", []byte("wrong-secret"), a.Expiration)
c, _ := b.CookieWithSignedToken("foo@example.com")
return c
},
err: assert.Error,
Expand All @@ -68,7 +68,7 @@ func TestAuthenticator_Authenticate(t *testing.T) {
}
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
unsignedToken, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
return a.Cookie(unsignedToken, a.Expiration, "example.com")
return a.Cookie(unsignedToken, a.Expiration)
},
err: assert.Error,
},
Expand All @@ -82,15 +82,15 @@ func TestAuthenticator_Authenticate(t *testing.T) {
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, _ := token.SignedString(a.Secret)
return a.Cookie(signedToken, a.Expiration, "example.com")
return a.Cookie(signedToken, a.Expiration)
},
err: assert.Error,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := New("_auth", []byte("secret"), time.Hour)
a := New("_auth", "example.com", []byte("secret"), time.Hour)
r := httptest.NewRequest(http.MethodGet, "/", nil)
if tt.cookie != nil {
r.AddCookie(tt.cookie(a))
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func run(ctx context.Context, cfg configuration.Configuration, r prometheus.Regi
// create the server
metrics := server.NewMetrics("traefik_simple_auth", "", prometheus.Labels{"provider": cfg.Provider})
r.MustRegister(metrics)
authenticator := auth.New(cfg.SessionCookieName, cfg.Secret, cfg.SessionExpiration)
authenticator := auth.New(cfg.SessionCookieName, string(cfg.Domain), cfg.Secret, cfg.SessionExpiration)
stateStore := state.New(cfg.StateConfiguration)
s := server.New(ctx, authenticator, stateStore, cfg, metrics, logger)

Expand Down
10 changes: 7 additions & 3 deletions internal/server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,19 @@ func forwardAuthHandler(

// logoutHandler logs out the user: it removes the cookie from the cookie store and sends an empty Cookie to the user.
// This means that the user's next request has an invalid cookie, triggering a new oauth flow.
func logoutHandler(authenticator *auth.Authenticator, authorizer authorizer, logger *slog.Logger) http.Handler {
func logoutHandler(
authenticator *auth.Authenticator,
authorizer authorizer,
logger *slog.Logger,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Debug("request received", "request", (*request)(r))

// verify that the request is authorized
user, err := authorizer.AuthorizeRequest(r)
if err == nil {
// Write a blank cookie to override/clear the current valid one.
http.SetCookie(w, authenticator.Cookie("", 0, string(authorizer.Domain)))
http.SetCookie(w, authenticator.Cookie("", 0))
logger.Info("user has been logged out", "user", user)
http.Error(w, "You have been logged out", http.StatusUnauthorized)
return
Expand Down Expand Up @@ -144,7 +148,7 @@ func oAuth2CallbackHandler(

// Valid user. Create a cookie and redirect the user to the final destination.
logger.Info("user logged in", "user", user, "url", targetURL)
c, _ := authenticator.CookieWithSignedToken(user, string(authorizer.Domain))
c, _ := authenticator.CookieWithSignedToken(user)
logger.Debug("sending cookie to user", "user", user, "cookie", c)
http.SetCookie(w, c)
http.Redirect(w, r, targetURL, http.StatusTemporaryRedirect)
Expand Down
6 changes: 3 additions & 3 deletions internal/server/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,23 @@ func TestServer_withMetrics(t *testing.T) {

// invalid domain: forbidden
r = testutils.ForwardAuthRequest(http.MethodGet, "https://example.org/foo")
c, _ := authenticator.CookieWithSignedToken("foo@example.com", "example.org")
c, _ := authenticator.CookieWithSignedToken("foo@example.com")
r.AddCookie(c)
w = httptest.NewRecorder()
s.ServeHTTP(w, r)
assert.Equal(t, http.StatusForbidden, w.Code)

// invalid user: forbidden
r = testutils.ForwardAuthRequest(http.MethodGet, "https://example.com/foo")
c, _ = authenticator.CookieWithSignedToken("bar@example.com", "example.com")
c, _ = authenticator.CookieWithSignedToken("bar@example.com")
r.AddCookie(c)
w = httptest.NewRecorder()
s.ServeHTTP(w, r)
assert.Equal(t, http.StatusForbidden, w.Code)

// valid user & domain, valid token: ok
r = testutils.ForwardAuthRequest(http.MethodGet, "https://example.com/foo")
c, _ = authenticator.CookieWithSignedToken("foo@example.com", "example.com")
c, _ = authenticator.CookieWithSignedToken("foo@example.com")
r.AddCookie(c)
w = httptest.NewRecorder()
s.ServeHTTP(w, r)
Expand Down
6 changes: 3 additions & 3 deletions internal/server/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
)

func TestSessionExtractor(t *testing.T) {
a := auth.New("_auth", []byte("secret"), time.Hour)
a := auth.New("_auth", "example.com", []byte("secret"), time.Hour)
extractor := authenticate(a)
validCookie, _ := a.CookieWithSignedToken("foo@example.com", "example.com")
validCookie, _ := a.CookieWithSignedToken("foo@example.com")

tests := []struct {
name string
Expand All @@ -28,7 +28,7 @@ func TestSessionExtractor(t *testing.T) {
},
{
name: "bad cookie",
cookie: a.Cookie("invalid-token", time.Hour, "example.com"),
cookie: a.Cookie("invalid-token", time.Hour),
wantErr: require.Error,
},
{
Expand Down
12 changes: 6 additions & 6 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestServer_Panics(t *testing.T) {
Provider: "foobar",
Domain: domain.Domain("example.com"),
}
authenticator := auth.New("_traefik-simple-auth", []byte("secret"), time.Hour)
authenticator := auth.New("_traefik-simple-auth", "example.com", []byte("secret"), time.Hour)
stateStore := state.New(state.Configuration{CacheType: "memory", TTL: time.Minute})
assert.Panics(t, func() {
_ = New(context.Background(), authenticator, stateStore, cfg, nil, testutils.DiscardLogger)
Expand All @@ -39,7 +39,7 @@ func TestForwardAuthHandler(t *testing.T) {
t.Cleanup(cancel)

authenticator, _, _, handler := setupServer(ctx, t, nil)
validSession, _ := authenticator.CookieWithSignedToken("foo@example.com", "example.com")
validSession, _ := authenticator.CookieWithSignedToken("foo@example.com")

type args struct {
target string
Expand Down Expand Up @@ -115,7 +115,7 @@ func TestLogoutHandler(t *testing.T) {

t.Run("logging out clears the browser's cookie", func(t *testing.T) {
r := testutils.ForwardAuthRequest(http.MethodGet, "https://example.com/_oauth/logout")
c, _ := authenticator.CookieWithSignedToken("foo@example.com", "example.com")
c, _ := authenticator.CookieWithSignedToken("foo@example.com")
r.AddCookie(c)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
Expand Down Expand Up @@ -243,7 +243,7 @@ func setupServer(ctx context.Context, t *testing.T, metrics metrics.RequestMetri
Domain: domain.Domain("example.com"),
Whitelist: list,
}
authenticator := auth.New("_auth", []byte("secret"), time.Hour)
authenticator := auth.New("_auth", "example.com", []byte("secret"), time.Hour)
stateStore := state.New(state.Configuration{CacheType: "memory", TTL: time.Minute})
return authenticator, stateStore, oidcServer, New(ctx, authenticator, stateStore, cfg, metrics, testutils.DiscardLogger)
}
Expand Down Expand Up @@ -284,11 +284,11 @@ func BenchmarkForwardAuthHandler(b *testing.B) {
Whitelist: whiteList,
Provider: "google",
}
authenticator := auth.New("_traefik-simple-auth", []byte("secret"), time.Hour)
authenticator := auth.New("_traefik-simple-auth", "example.com", []byte("secret"), time.Hour)
states := state.New(state.Configuration{CacheType: "memory", TTL: time.Minute})
s := New(context.Background(), authenticator, states, config, nil, testutils.DiscardLogger)

c, _ := authenticator.CookieWithSignedToken("foo@example.com", "example.com")
c, _ := authenticator.CookieWithSignedToken("foo@example.com")
req := testutils.ForwardAuthRequest(http.MethodGet, "https://example.com/foo")
req.AddCookie(c)

Expand Down

0 comments on commit e158a6c

Please # to comment.