Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Handle redirects from vault server versions earlier than v0.6.2 #76

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions vault/app-id_strategy.go
Original file line number Diff line number Diff line change
@@ -28,20 +28,34 @@ func NewAppIDAuthStrategy() *AppIDAuthStrategy {
return nil
}

// GetToken - log in to the app-id auth backend and return the client token
func (a *AppIDAuthStrategy) GetToken(addr *url.URL) (string, error) {
// GetHTTPClient configures the HTTP client with a timeout
func (a *AppIDAuthStrategy) GetHTTPClient() *http.Client {
if a.hc == nil {
a.hc = &http.Client{Timeout: time.Second * 5}
}
client := a.hc
return a.hc
}

// SetToken is a no-op for AppIDAuthStrategy as a token hasn't been acquired yet
func (a *AppIDAuthStrategy) SetToken(req *http.Request) {
// no-op
}

// Do wraps http.Client.Do
func (a *AppIDAuthStrategy) Do(req *http.Request) (*http.Response, error) {
hc := a.GetHTTPClient()
return hc.Do(req)
}

// GetToken - log in to the app-id auth backend and return the client token
func (a *AppIDAuthStrategy) GetToken(addr *url.URL) (string, error) {
buf := new(bytes.Buffer)
json.NewEncoder(buf).Encode(&a)

u := &url.URL{}
*u = *addr
u.Path = "/v1/auth/app-id/#"
res, err := client.Post(u.String(), "application/json; charset=utf-8", buf)
res, err := requestAndFollow(a, "POST", u, buf.Bytes())
if err != nil {
return "", err
}
49 changes: 30 additions & 19 deletions vault/client.go
Original file line number Diff line number Diff line change
@@ -54,6 +54,31 @@ func getAuthStrategy() AuthStrategy {
return nil
}

// GetHTTPClient returns a client configured w/X-Vault-Token header
func (c *Client) GetHTTPClient() *http.Client {
if c.hc == nil {
c.hc = &http.Client{
Timeout: time.Second * 5,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
c.SetToken(req)
return nil
},
}
}
return c.hc
}

// SetToken adds an X-Vault-Token header to the request
func (c *Client) SetToken(req *http.Request) {
req.Header.Set("X-Vault-Token", c.token)
}

// Do wraps http.Client.Do
func (c *Client) Do(req *http.Request) (*http.Response, error) {
hc := c.GetHTTPClient()
return hc.Do(req)
}

// Login - log in to Vault with the discovered auth backend and save the token
func (c *Client) Login() error {
token, err := c.Auth.GetToken(c.Addr)
@@ -72,17 +97,12 @@ func (c *Client) RevokeToken() {
return
}

if c.hc == nil {
c.hc = &http.Client{Timeout: time.Second * 5}
}

u := &url.URL{}
*u = *c.Addr
u.Path = "/v1/auth/token/revoke-self"
req, _ := http.NewRequest("POST", u.String(), nil)
req.Header.Set("X-Vault-Token", c.token)

res, err := c.hc.Do(req)
res, err := requestAndFollow(c, "POST", u, nil)

if err != nil {
log.Println("Error while revoking Vault Token", err)
}
@@ -94,32 +114,23 @@ func (c *Client) RevokeToken() {

func (c *Client) Read(path string) ([]byte, error) {
path = normalizeURLPath(path)
if c.hc == nil {
c.hc = &http.Client{Timeout: time.Second * 5}
}

u := &url.URL{}
*u = *c.Addr
u.Path = "/v1" + path
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return nil, err
}
req.Header.Set("X-Vault-Token", c.token)

res, err := c.hc.Do(req)
res, err := requestAndFollow(c, "GET", u, nil)
if err != nil {
return nil, err
}

body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
return nil, err
}

if res.StatusCode != 200 {
err = fmt.Errorf("Unexpected HTTP status %d on Read from %s: %s", res.StatusCode, u, body)
err = fmt.Errorf("Unexpected HTTP status %d on Read from %s: %s", res.StatusCode, path, body)
return nil, err
}

@@ -131,7 +142,7 @@ func (c *Client) Read(path string) ([]byte, error) {
}

if _, ok := response["data"]; !ok {
return nil, fmt.Errorf("Unexpected HTTP body on Read for %s: %s", u, body)
return nil, fmt.Errorf("Unexpected HTTP body on Read for %s: %s", path, body)
}

return json.Marshal(response["data"])
47 changes: 47 additions & 0 deletions vault/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package vault

import (
"bytes"
"net/http"
"net/url"
)

// httpClient
type httpClient interface {
GetHTTPClient() *http.Client
SetToken(req *http.Request)
Do(req *http.Request) (*http.Response, error)
}

func requestAndFollow(hc httpClient, method string, u *url.URL, body []byte) (*http.Response, error) {
var res *http.Response
var err error
for attempts := 0; attempts < 2; attempts++ {
reader := bytes.NewReader(body)
req, err := http.NewRequest(method, u.String(), reader)

if err != nil {
return nil, err
}
hc.SetToken(req)
if method == "POST" {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}

res, err = hc.Do(req)
if err != nil {
return nil, err
}
if res.StatusCode == http.StatusTemporaryRedirect {
res.Body.Close()
location, errLocation := res.Location()
if errLocation != nil {
return nil, errLocation
}
u.Host = location.Host
} else {
break
}
}
return res, err
}
68 changes: 68 additions & 0 deletions vault/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package vault

import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
)

type testClient struct{}

func (tc *testClient) GetHTTPClient() *http.Client {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqStr := fmt.Sprintf("%s %s", r.Method, r.URL)
switch reqStr {
case "POST http://vaultA:8500/v1/foo":
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", "http://vaultB:8500/v1/foo")
w.WriteHeader(http.StatusTemporaryRedirect)
fmt.Fprintln(w, "{}")
case "POST http://vaultB:8500/v1/foo":
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, "{}")
default:
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "{ 'message': 'Unexpected request: %s'}", reqStr)
}
}))
return &http.Client{
Transport: &http.Transport{
Proxy: func(req *http.Request) (*url.URL, error) {
return url.Parse(server.URL)
},
},
}
}

func (tc *testClient) SetToken(req *http.Request) {
req.Header.Set("X-Vault-Token", "dead-beef-cafe-babe")
}

func (tc *testClient) Do(req *http.Request) (*http.Response, error) {
hc := tc.GetHTTPClient()
return hc.Do(req)
}

func TestRequestAndFollow_GetWithRedirect(t *testing.T) {
tc := &testClient{}
u, _ := url.Parse("http://vaultA:8500/v1/foo")

res, err := requestAndFollow(tc, "POST", u, nil)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)

}

func TestRequestAndFollow_GetNoRedirect(t *testing.T) {
tc := &testClient{}
u, _ := url.Parse("http://vaultB:8500/v1/foo")

res, err := requestAndFollow(tc, "POST", u, nil)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
}