Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add ability to pass a custom allowed hosts function #82

Merged
9 changes: 9 additions & 0 deletions secure.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ const (
// SSLHostFunc a type whose pointer is the type of field `SSLHostFunc` of `Options` struct
type SSLHostFunc func(host string) (newHost string)

// AllowedHostsFunc a custom function type that returns a list of strings used in place of AllowedHosts list
type AllowedHostsFunc func() []string

func defaultBadHostHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Bad Host", http.StatusInternalServerError)
}
Expand Down Expand Up @@ -90,6 +93,8 @@ type Options struct {
CrossOriginOpenerPolicy string
// SSLHost is the host name that is used to redirect http requests to https. Default is "", which indicates to use the same host.
SSLHost string
// AllowedHostsFunc is a custom function that returns a list of fully qualified domain names that are allowed. If set, AllowedHosts will be ignored
AllowedHostsFunc AllowedHostsFunc
// AllowedHosts is a list of fully qualified domain names that are allowed. Default is empty list, which allows any and all host names.
AllowedHosts []string
// AllowedHostsAreRegex determines, if the provided slice contains valid regular expressions. If this flag is set to true, every request's
Expand Down Expand Up @@ -290,6 +295,10 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He
}

// Allowed hosts check.
if s.opt.AllowedHostsFunc != nil {
s.opt.AllowedHosts = s.opt.AllowedHostsFunc()
}

if len(s.opt.AllowedHosts) > 0 && !s.opt.IsDevelopment {
isGoodHost := false
if s.opt.AllowedHostsAreRegex {
Expand Down
31 changes: 31 additions & 0 deletions secure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,37 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) {
expect(t, s2Headers.Get(featurePolicyHeader), s2.opt.FeaturePolicy)
}

func TestAllowHostsFunc(t *testing.T) {
s := New(Options{
AllowedHostsFunc: func() []string { return []string{"www.example.com"} },
})

res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/foo", nil)
req.Host = "www.example.com"

s.Handler(myHandler).ServeHTTP(res, req)

expect(t, res.Code, http.StatusOK)
expect(t, res.Body.String(), `bar`)
}

func TestAllowHostsFuncIgnoreAllowedHostsList(t *testing.T) {
s := New(Options{
AllowedHosts: []string{"www.blocked.com"},
AllowedHostsFunc: func() []string { return []string{"www.allow.com"} },
})

res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/foo", nil)
req.Host = "www.allow.com"

s.Handler(myHandler).ServeHTTP(res, req)

expect(t, res.Code, http.StatusOK)
expect(t, res.Body.String(), `bar`)
}

/* Test Helpers */
func expect(t *testing.T, a interface{}, b interface{}) {
if a != b {
Expand Down