diff --git a/AUTHORS b/AUTHORS index 4e84c37..e8d6043 100644 --- a/AUTHORS +++ b/AUTHORS @@ -2,6 +2,7 @@ # Please keep the list sorted. adiabatic +Florian D. Loch Google LLC (https://opensource.google.com) jamesgroat Joshua Carp diff --git a/csrf.go b/csrf.go index f21e0a2..21223d4 100644 --- a/csrf.go +++ b/csrf.go @@ -274,16 +274,22 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - // If the token returned from the session store is nil for non-idempotent - // ("unsafe") methods, call the error handler. - if realToken == nil { + // Retrieve the combined token (pad + masked) token... + maskedToken, err := cs.requestToken(r) + if err != nil { + r = envError(r, ErrBadToken) + cs.opts.ErrorHandler.ServeHTTP(w, r) + return + } + + if maskedToken == nil { r = envError(r, ErrNoToken) cs.opts.ErrorHandler.ServeHTTP(w, r) return } - // Retrieve the combined token (pad + masked) token and unmask it. - requestToken := unmask(cs.requestToken(r)) + // ... and unmask it. + requestToken := unmask(maskedToken) // Compare the request token against the real token if !compareTokens(requestToken, realToken) { diff --git a/csrf_test.go b/csrf_test.go index 0841d5f..a97fb10 100644 --- a/csrf_test.go +++ b/csrf_test.go @@ -374,6 +374,48 @@ func TestWithReferer(t *testing.T) { } } +// Requests without a token should fail with ErrNoToken. +func TestNoTokenProvided(t *testing.T) { + var finalErr error + + s := http.NewServeMux() + p := Protect(testKey, ErrorHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + finalErr = FailureReason(r) + })))(s) + + var token string + s.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token = Token(r) + })) + + // Obtain a CSRF cookie via a GET request. + r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + p.ServeHTTP(rr, r) + + // POST the token back in the header. + r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil) + if err != nil { + t.Fatal(err) + } + + setCookie(rr, r) + // By accident we use the wrong header name for the token... + r.Header.Set("X-CSRF-nekot", token) + r.Header.Set("Referer", "http://www.gorillatoolkit.org/") + + rr = httptest.NewRecorder() + p.ServeHTTP(rr, r) + + if finalErr != nil && finalErr != ErrNoToken { + t.Fatalf("middleware failed to return correct error: got '%v' want '%v'", finalErr, ErrNoToken) + } +} + func setCookie(rr *httptest.ResponseRecorder, r *http.Request) { r.Header.Set("Cookie", rr.Header().Get("Set-Cookie")) } diff --git a/helpers.go b/helpers.go index 3dacfd2..c19dc47 100644 --- a/helpers.go +++ b/helpers.go @@ -105,7 +105,7 @@ func unmask(issued []byte) []byte { // requestToken returns the issued token (pad + masked token) from the HTTP POST // body or HTTP header. It will return nil if the token fails to decode. -func (cs *csrf) requestToken(r *http.Request) []byte { +func (cs *csrf) requestToken(r *http.Request) ([]byte, error) { // 1. Check the HTTP header first. issued := r.Header.Get(cs.opts.RequestHeader) @@ -123,14 +123,19 @@ func (cs *csrf) requestToken(r *http.Request) []byte { } } + // Return nil (equivalent to empty byte slice) if no token was found + if issued == "" { + return nil, nil + } + // Decode the "issued" (pad + masked) token sent in the request. Return a // nil byte slice on a decoding error (this will fail upstream). decoded, err := base64.StdEncoding.DecodeString(issued) if err != nil { - return nil + return nil, err } - return decoded + return decoded, nil } // generateRandomBytes returns securely generated random bytes.