Skip to content

Commit

Permalink
refactor basic_auth_test to utilize table driven tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eolson authored and aldas committed Oct 26, 2024
1 parent 822d11a commit 03c0236
Showing 1 changed file with 95 additions and 62 deletions.
157 changes: 95 additions & 62 deletions middleware/basic_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package middleware

import (
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -16,78 +17,110 @@ import (

func TestBasicAuth(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
f := func(u, p string, c echo.Context) (bool, error) {

mockValidator := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
}
return false, nil
}
h := BasicAuth(f)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

h = BasicAuthWithConfig(BasicAuthConfig{
Validator: f,
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

// Invalid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
// Define the test cases
tests := []struct {
name string
authHeader string
expectedCode int
expectedAuth string
skipperResult bool
expectedErr bool
expectedErrMsg string
}{
{
name: "Valid credentials",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
expectedCode: http.StatusOK,
skipperResult: false,
},
{
name: "Case-insensitive header scheme",
authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
expectedCode: http.StatusOK,
skipperResult: false,
},
{
name: "Invalid credentials",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")),
expectedCode: http.StatusUnauthorized,
expectedAuth: basic + ` realm="someRealm"`,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Invalid base64 string",
authHeader: basic + " invalidString",
expectedCode: http.StatusBadRequest,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Bad Request",
},
{
name: "Missing Authorization header",
expectedCode: http.StatusUnauthorized,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Invalid Authorization header",
authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")),
expectedCode: http.StatusUnauthorized,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Skipped Request",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")),
expectedCode: http.StatusOK,
skipperResult: true,
},
}

// Invalid base64 string
auth = basic + " invalidString"
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)

// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
if tt.authHeader != "" {
req.Header.Set(echo.HeaderAuthorization, tt.authHeader)
}

h = BasicAuthWithConfig(BasicAuthConfig{
Validator: f,
Realm: "someRealm",
Skipper: func(c echo.Context) bool {
return true
},
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
h := BasicAuthWithConfig(BasicAuthConfig{
Validator: mockValidator,
Realm: "someRealm",
Skipper: func(c echo.Context) bool {
return tt.skipperResult
},
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

// Skipped Request
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
err := h(c)

if tt.expectedErr {
var he *echo.HTTPError
errors.As(err, &he)
assert.Equal(t, tt.expectedCode, he.Code)
if tt.expectedAuth != "" {
assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
}
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedCode, res.Code)
}
})
}
}

0 comments on commit 03c0236

Please # to comment.