From dbdd898f16a873d785f95d78a2d2a1a999f9882f Mon Sep 17 00:00:00 2001 From: Drew MacInnis Date: Thu, 17 Nov 2016 22:43:06 -0500 Subject: [PATCH] Handle vault redirects --- vault/app-id_strategy.go | 22 ++++++++++--- vault/client.go | 49 ++++++++++++++++++----------- vault/http.go | 47 +++++++++++++++++++++++++++ vault/http_test.go | 68 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 163 insertions(+), 23 deletions(-) create mode 100644 vault/http.go create mode 100644 vault/http_test.go diff --git a/vault/app-id_strategy.go b/vault/app-id_strategy.go index 2d8b04b5c..06603f0b7 100644 --- a/vault/app-id_strategy.go +++ b/vault/app-id_strategy.go @@ -28,20 +28,34 @@ func NewAppIDAuthStrategy() *AppIDAuthStrategy { return nil } -// GetToken - log in to the app-id auth backend and return the client token -func (a *AppIDAuthStrategy) GetToken(addr *url.URL) (string, error) { +// GetHTTPClient configures the HTTP client with a timeout +func (a *AppIDAuthStrategy) GetHTTPClient() *http.Client { if a.hc == nil { a.hc = &http.Client{Timeout: time.Second * 5} } - client := a.hc + return a.hc +} + +// SetToken is a no-op for AppIDAuthStrategy as a token hasn't been acquired yet +func (a *AppIDAuthStrategy) SetToken(req *http.Request) { + // no-op +} +// Do wraps http.Client.Do +func (a *AppIDAuthStrategy) Do(req *http.Request) (*http.Response, error) { + hc := a.GetHTTPClient() + return hc.Do(req) +} + +// GetToken - log in to the app-id auth backend and return the client token +func (a *AppIDAuthStrategy) GetToken(addr *url.URL) (string, error) { buf := new(bytes.Buffer) json.NewEncoder(buf).Encode(&a) u := &url.URL{} *u = *addr u.Path = "/v1/auth/app-id/login" - res, err := client.Post(u.String(), "application/json; charset=utf-8", buf) + res, err := requestAndFollow(a, "POST", u, buf.Bytes()) if err != nil { return "", err } diff --git a/vault/client.go b/vault/client.go index 3a8cad658..7e80fa276 100644 --- a/vault/client.go +++ b/vault/client.go @@ -54,6 +54,31 @@ func getAuthStrategy() AuthStrategy { return nil } +// GetHTTPClient returns a client configured w/X-Vault-Token header +func (c *Client) GetHTTPClient() *http.Client { + if c.hc == nil { + c.hc = &http.Client{ + Timeout: time.Second * 5, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + c.SetToken(req) + return nil + }, + } + } + return c.hc +} + +// SetToken adds an X-Vault-Token header to the request +func (c *Client) SetToken(req *http.Request) { + req.Header.Set("X-Vault-Token", c.token) +} + +// Do wraps http.Client.Do +func (c *Client) Do(req *http.Request) (*http.Response, error) { + hc := c.GetHTTPClient() + return hc.Do(req) +} + // Login - log in to Vault with the discovered auth backend and save the token func (c *Client) Login() error { token, err := c.Auth.GetToken(c.Addr) @@ -72,17 +97,12 @@ func (c *Client) RevokeToken() { return } - if c.hc == nil { - c.hc = &http.Client{Timeout: time.Second * 5} - } - u := &url.URL{} *u = *c.Addr u.Path = "/v1/auth/token/revoke-self" - req, _ := http.NewRequest("POST", u.String(), nil) - req.Header.Set("X-Vault-Token", c.token) - res, err := c.hc.Do(req) + res, err := requestAndFollow(c, "POST", u, nil) + if err != nil { log.Println("Error while revoking Vault Token", err) } @@ -94,24 +114,15 @@ func (c *Client) RevokeToken() { func (c *Client) Read(path string) ([]byte, error) { path = normalizeURLPath(path) - if c.hc == nil { - c.hc = &http.Client{Timeout: time.Second * 5} - } u := &url.URL{} *u = *c.Addr u.Path = "/v1" + path - req, err := http.NewRequest("GET", u.String(), nil) - if err != nil { - return nil, err - } - req.Header.Set("X-Vault-Token", c.token) - res, err := c.hc.Do(req) + res, err := requestAndFollow(c, "GET", u, nil) if err != nil { return nil, err } - body, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { @@ -119,7 +130,7 @@ func (c *Client) Read(path string) ([]byte, error) { } if res.StatusCode != 200 { - err = fmt.Errorf("Unexpected HTTP status %d on Read from %s: %s", res.StatusCode, u, body) + err = fmt.Errorf("Unexpected HTTP status %d on Read from %s: %s", res.StatusCode, path, body) return nil, err } @@ -131,7 +142,7 @@ func (c *Client) Read(path string) ([]byte, error) { } if _, ok := response["data"]; !ok { - return nil, fmt.Errorf("Unexpected HTTP body on Read for %s: %s", u, body) + return nil, fmt.Errorf("Unexpected HTTP body on Read for %s: %s", path, body) } return json.Marshal(response["data"]) diff --git a/vault/http.go b/vault/http.go new file mode 100644 index 000000000..f8335796a --- /dev/null +++ b/vault/http.go @@ -0,0 +1,47 @@ +package vault + +import ( + "bytes" + "net/http" + "net/url" +) + +// httpClient +type httpClient interface { + GetHTTPClient() *http.Client + SetToken(req *http.Request) + Do(req *http.Request) (*http.Response, error) +} + +func requestAndFollow(hc httpClient, method string, u *url.URL, body []byte) (*http.Response, error) { + var res *http.Response + var err error + for attempts := 0; attempts < 2; attempts++ { + reader := bytes.NewReader(body) + req, err := http.NewRequest(method, u.String(), reader) + + if err != nil { + return nil, err + } + hc.SetToken(req) + if method == "POST" { + req.Header.Set("Content-Type", "application/json; charset=utf-8") + } + + res, err = hc.Do(req) + if err != nil { + return nil, err + } + if res.StatusCode == http.StatusTemporaryRedirect { + res.Body.Close() + location, errLocation := res.Location() + if errLocation != nil { + return nil, errLocation + } + u.Host = location.Host + } else { + break + } + } + return res, err +} diff --git a/vault/http_test.go b/vault/http_test.go new file mode 100644 index 000000000..e2ee4429b --- /dev/null +++ b/vault/http_test.go @@ -0,0 +1,68 @@ +package vault + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" +) + +type testClient struct{} + +func (tc *testClient) GetHTTPClient() *http.Client { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqStr := fmt.Sprintf("%s %s", r.Method, r.URL) + switch reqStr { + case "POST http://vaultA:8500/v1/foo": + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Location", "http://vaultB:8500/v1/foo") + w.WriteHeader(http.StatusTemporaryRedirect) + fmt.Fprintln(w, "{}") + case "POST http://vaultB:8500/v1/foo": + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, "{}") + default: + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "{ 'message': 'Unexpected request: %s'}", reqStr) + } + })) + return &http.Client{ + Transport: &http.Transport{ + Proxy: func(req *http.Request) (*url.URL, error) { + return url.Parse(server.URL) + }, + }, + } +} + +func (tc *testClient) SetToken(req *http.Request) { + req.Header.Set("X-Vault-Token", "dead-beef-cafe-babe") +} + +func (tc *testClient) Do(req *http.Request) (*http.Response, error) { + hc := tc.GetHTTPClient() + return hc.Do(req) +} + +func TestRequestAndFollow_GetWithRedirect(t *testing.T) { + tc := &testClient{} + u, _ := url.Parse("http://vaultA:8500/v1/foo") + + res, err := requestAndFollow(tc, "POST", u, nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) + +} + +func TestRequestAndFollow_GetNoRedirect(t *testing.T) { + tc := &testClient{} + u, _ := url.Parse("http://vaultB:8500/v1/foo") + + res, err := requestAndFollow(tc, "POST", u, nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) +}