From 2e53ba0c329f9fb2b723725a4b6dafe18a4b9d61 Mon Sep 17 00:00:00 2001 From: Christophe Lambin Date: Sat, 27 Apr 2024 00:36:38 +0200 Subject: [PATCH] feat: empty user list matches any user --- pkg/oauth/oauth.go | 2 +- pkg/state/state.go | 8 ++++---- pkg/whitelist/whitelist.go | 8 ++++++++ pkg/whitelist/whitelist_test.go | 5 +++-- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/pkg/oauth/oauth.go b/pkg/oauth/oauth.go index 154960d..e7a6fc3 100644 --- a/pkg/oauth/oauth.go +++ b/pkg/oauth/oauth.go @@ -26,7 +26,7 @@ func NewHandler(provider, clientID, clientSecret, authURL string, logger *slog.L } } -// BaseHandler implements the provider-agnostic part of a Handler. +// BaseHandler implements the generic part of a Handler. type BaseHandler struct { oauth2.Config HTTPClient *http.Client diff --git a/pkg/state/state.go b/pkg/state/state.go index 1d55dc0..5ca7a6c 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -13,13 +13,13 @@ const stateSize = 32 // before redirecting to the oauth provider, we generate a random state. During callback, we then check if the oauth provider // sent us back the same state. The state is maintained for a limited amount of time to prevent (very unlikely) replay attacks. type Store[T any] struct { - cache *cache.Cache[string, T] + values *cache.Cache[string, T] } // New creates a new state Store func New[T any](retention time.Duration) Store[T] { return Store[T]{ - cache: cache.New[string, T](retention, time.Minute), + values: cache.New[string, T](retention, time.Minute), } } @@ -29,11 +29,11 @@ func (s Store[T]) Add(value T) string { // theoretically this could fail, but in practice this will never happen. _, _ = rand.Read(state) encodedState := hex.EncodeToString(state) - s.cache.Add(encodedState, value) + s.values.Add(encodedState, value) return encodedState } // Get checks if the state exists and returns the associated value func (s Store[T]) Get(state string) (T, bool) { - return s.cache.Get(state) + return s.values.Get(state) } diff --git a/pkg/whitelist/whitelist.go b/pkg/whitelist/whitelist.go index d876a13..50f0f53 100644 --- a/pkg/whitelist/whitelist.go +++ b/pkg/whitelist/whitelist.go @@ -19,6 +19,14 @@ func (w Whitelist) Contains(email string) bool { return ok } +// Match returns true if the email address is on the whitelist, or if the whitelist is empty +func (w Whitelist) Match(email string) bool { + if len(w) == 0 { + return true + } + return w.Contains(email) +} + func (w Whitelist) list() []string { list := make([]string, 0, len(w)) for email := range w { diff --git a/pkg/whitelist/whitelist_test.go b/pkg/whitelist/whitelist_test.go index ec631e6..cefa3ed 100644 --- a/pkg/whitelist/whitelist_test.go +++ b/pkg/whitelist/whitelist_test.go @@ -28,7 +28,8 @@ func Test_whitelist(t *testing.T) { { name: "empty", emails: []string{}, - want: assert.False, + email: "foo@example.com", + want: assert.True, }, { name: "case-insensitive", @@ -43,7 +44,7 @@ func Test_whitelist(t *testing.T) { t.Parallel() list := New(tt.emails) - tt.want(t, list.Contains(tt.email)) + tt.want(t, list.Match(tt.email)) sortedList := list.list() slices.Sort(sortedList)