Skip to content

Commit

Permalink
add support for token file and eks container endpoint in general HTTP…
Browse files Browse the repository at this point in the history
… provider
  • Loading branch information
wty-Bryant authored Nov 13, 2023
1 parent f300f13 commit b3103f2
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 8 deletions.
10 changes: 10 additions & 0 deletions .changelog/0593bfc1f00841febcb73c7957839f01.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"id": "0593bfc1-f008-41fe-bcb7-3c7957839f01",
"type": "feature",
"collapse": true,
"description": "Add support for dynamic auth token from file and EKS container host in absolute/relative URIs in the HTTP credential provider.",
"modules": [
"config",
"credentials"
]
}
81 changes: 74 additions & 7 deletions config/resolve_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package config
import (
"context"
"fmt"
"io/ioutil"
"net"
"net/url"
"os"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -21,11 +24,33 @@ import (

const (
// valid credential source values
credSourceEc2Metadata = "Ec2InstanceMetadata"
credSourceEnvironment = "Environment"
credSourceECSContainer = "EcsContainer"
credSourceEc2Metadata = "Ec2InstanceMetadata"
credSourceEnvironment = "Environment"
credSourceECSContainer = "EcsContainer"
httpProviderAuthFileEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE"
)

// direct representation of the IPv4 address for the ECS container
// "169.254.170.2"
var ecsContainerIPv4 net.IP = []byte{
169, 254, 170, 2,
}

// direct representation of the IPv4 address for the EKS container
// "169.254.170.23"
var eksContainerIPv4 net.IP = []byte{
169, 254, 170, 23,
}

// direct representation of the IPv6 address for the EKS container
// "fd00:ec2::23"
var eksContainerIPv6 net.IP = []byte{
0xFD, 0, 0xE, 0xC2,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0x23,
}

var (
ecsContainerEndpoint = "http://169.254.170.2" // not constant to allow for swapping during unit-testing
)
Expand Down Expand Up @@ -222,6 +247,36 @@ func processCredentials(ctx context.Context, cfg *aws.Config, sharedConfig *Shar
return nil
}

// isAllowedHost allows host to be loopback or known ECS/EKS container IPs
//
// host can either be an IP address OR an unresolved hostname - resolution will
// be automatically performed in the latter case
func isAllowedHost(host string) (bool, error) {
if ip := net.ParseIP(host); ip != nil {
return isIPAllowed(ip), nil
}

addrs, err := lookupHostFn(host)
if err != nil {
return false, err
}

for _, addr := range addrs {
if ip := net.ParseIP(addr); ip == nil || !isIPAllowed(ip) {
return false, nil
}
}

return true, nil
}

func isIPAllowed(ip net.IP) bool {
return ip.IsLoopback() ||
ip.Equal(ecsContainerIPv4) ||
ip.Equal(eksContainerIPv4) ||
ip.Equal(eksContainerIPv6)
}

func resolveLocalHTTPCredProvider(ctx context.Context, cfg *aws.Config, endpointURL, authToken string, configs configs) error {
var resolveErr error

Expand All @@ -232,10 +287,12 @@ func resolveLocalHTTPCredProvider(ctx context.Context, cfg *aws.Config, endpoint
host := parsed.Hostname()
if len(host) == 0 {
resolveErr = fmt.Errorf("unable to parse host from local HTTP cred provider URL")
} else if isLoopback, loopbackErr := isLoopbackHost(host); loopbackErr != nil {
resolveErr = fmt.Errorf("failed to resolve host %q, %v", host, loopbackErr)
} else if !isLoopback {
resolveErr = fmt.Errorf("invalid endpoint host, %q, only loopback hosts are allowed", host)
} else if parsed.Scheme == "http" {
if isAllowedHost, allowHostErr := isAllowedHost(host); allowHostErr != nil {
resolveErr = fmt.Errorf("failed to resolve host %q, %v", host, allowHostErr)
} else if !isAllowedHost {
resolveErr = fmt.Errorf("invalid endpoint host, %q, only loopback/ecs/eks hosts are allowed", host)
}
}
}

Expand All @@ -252,6 +309,16 @@ func resolveHTTPCredProvider(ctx context.Context, cfg *aws.Config, url, authToke
if len(authToken) != 0 {
options.AuthorizationToken = authToken
}
if authFilePath := os.Getenv(httpProviderAuthFileEnvVar); authFilePath != "" {
options.AuthorizationTokenProvider = endpointcreds.TokenProviderFunc(func() (string, error) {
var contents []byte
var err error
if contents, err = ioutil.ReadFile(authFilePath); err != nil {
return "", fmt.Errorf("failed to read authorization token from %v: %v", authFilePath, err)
}
return string(contents), nil
})
}
options.APIOptions = cfg.APIOptions
if cfg.Retryer != nil {
options.Retryer = cfg.Retryer()
Expand Down
58 changes: 57 additions & 1 deletion credentials/endpointcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"context"
"fmt"
"net/http"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client"
Expand Down Expand Up @@ -81,7 +82,37 @@ type Options struct {

// Optional authorization token value if set will be used as the value of
// the Authorization header of the endpoint credential request.
//
// When constructed from environment, the provider will use the value of
// AWS_CONTAINER_AUTHORIZATION_TOKEN environment variable as the token
//
// Will be overridden if AuthorizationTokenProvider is configured
AuthorizationToken string

// Optional auth provider func to dynamically load the auth token from a file
// everytime a credential is retrieved
//
// When constructed from environment, the provider will read and use the content
// of the file pointed to by AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE environment variable
// as the auth token everytime credentials are retrieved
//
// Will override AuthorizationToken if configured
AuthorizationTokenProvider AuthTokenProvider
}

// AuthTokenProvider defines an interface to dynamically load a value to be passed
// for the Authorization header of a credentials request.
type AuthTokenProvider interface {
GetToken() (string, error)
}

// TokenProviderFunc is a func type implementing AuthTokenProvider interface
// and enables customizing token provider behavior
type TokenProviderFunc func() (string, error)

// GetToken func retrieves auth token according to TokenProviderFunc implementation
func (p TokenProviderFunc) GetToken() (string, error) {
return p()
}

// New returns a credentials Provider for retrieving AWS credentials
Expand Down Expand Up @@ -132,5 +163,30 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
}

func (p *Provider) getCredentials(ctx context.Context) (*client.GetCredentialsOutput, error) {
return p.client.GetCredentials(ctx, &client.GetCredentialsInput{AuthorizationToken: p.options.AuthorizationToken})
authToken, err := p.resolveAuthToken()
if err != nil {
return nil, fmt.Errorf("resolve auth token: %v", err)
}

return p.client.GetCredentials(ctx, &client.GetCredentialsInput{
AuthorizationToken: authToken,
})
}

func (p *Provider) resolveAuthToken() (string, error) {
authToken := p.options.AuthorizationToken

var err error
if p.options.AuthorizationTokenProvider != nil {
authToken, err = p.options.AuthorizationTokenProvider.GetToken()
if err != nil {
return "", err
}
}

if strings.ContainsAny(authToken, "\r\n") {
return "", fmt.Errorf("authorization token contains invalid newline sequence")
}

return authToken, nil
}
84 changes: 84 additions & 0 deletions credentials/endpointcreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,90 @@ func TestRetrieveStaticCredentials(t *testing.T) {
}
}

func TestAuthTokenProvider(t *testing.T) {
cases := map[string]struct {
AuthToken string
AuthTokenProvider endpointcreds.AuthTokenProvider
ExpectAuthToken string
ExpectError bool
}{
"AuthToken": {
AuthToken: "Basic abc123",
ExpectAuthToken: "Basic abc123",
},
"AuthFileToken": {
AuthToken: "Basic abc123",
AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
return "Hello %20world", nil
}),
ExpectAuthToken: "Hello %20world",
},
"RetrieveFileTokenError": {
AuthToken: "Basic abc123",
AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
return "", fmt.Errorf("test error")
}),
ExpectAuthToken: "Hello %20world",
ExpectError: true,
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
orig := sdk.NowTime
defer func() { sdk.NowTime = orig }()

var actualToken string
p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
actualToken = r.Header["Authorization"][0]
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader([]byte(`{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET"
}`))),
}, nil
})
o.AuthorizationToken = c.AuthToken
o.AuthorizationTokenProvider = c.AuthTokenProvider
})
creds, err := p.Retrieve(context.Background())

if err != nil && !c.ExpectError {
t.Errorf("expect no error, got %v", err)
} else if err == nil && c.ExpectError {
t.Errorf("expect error, got nil")
}

if c.ExpectError {
return
}

if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if e, a := c.ExpectAuthToken, actualToken; e != a {
t.Errorf("Expect %v, got %v", e, a)
}

sdk.NowTime = func() time.Time {
return time.Date(3000, 12, 16, 1, 30, 37, 0, time.UTC)
}

if creds.Expired() {
t.Errorf("expect not to be expired")
}
})
}
}

func TestFailedRetrieveCredentials(t *testing.T) {
p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
Expand Down

0 comments on commit b3103f2

Please # to comment.