Skip to content

Commit

Permalink
AWSPrometheusRemoteWriteExporter - Add SDK and system information to …
Browse files Browse the repository at this point in the history
…User-Agent header (#3317)

* added aws-sdk-go version and exec-env to the aws prw exporter user-agent header

* Added comments

* refined adding user-agent header

* Updated adding to the User-Agent header

* adding header to cloned request

* added aws-sdk-go version and exec-env to the aws prw exporter user-agent header

* Added comments

* refined adding user-agent header

* Updated adding to the User-Agent header

* adding header to cloned request

* compute string as constant and store it in signingRoundTripper struct

* added tests for user-agent header and check to see if user-agent header exists

* changed check for exisiting UA Header

* retrigger checks

Co-authored-by: Anthony J Mirabella <a9@aneurysm9.com>
  • Loading branch information
dhruv-vora and Aneurysm9 authored May 10, 2021
1 parent 15f661b commit 40f1090
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 16 deletions.
35 changes: 24 additions & 11 deletions exporter/awsprometheusremotewriteexporter/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ const defaultAMPSigV4Service = "aps"

// signingRoundTripper is a Custom RoundTripper that performs AWS Sig V4.
type signingRoundTripper struct {
transport http.RoundTripper
signer *v4.Signer
region string
service string
transport http.RoundTripper
signer *v4.Signer
region string
service string
runtimeInfo string
}

func (si *signingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
Expand All @@ -57,6 +58,17 @@ func (si *signingRoundTripper) RoundTrip(req *http.Request) (*http.Response, err

// Clone request to ensure thread safety.
req2 := cloneRequest(req)

// Add the runtime information to the User-Agent header of the request
ua := req2.Header.Get("User-Agent")
if len(ua) > 0 {
ua = ua + " " + si.runtimeInfo
} else {
ua = si.runtimeInfo
}
req2.Header.Set("User-Agent", ua)

// Sign the request
_, err = si.signer.Sign(req2, body, si.service, si.region, time.Now())
if err != nil {
return nil, err
Expand All @@ -71,7 +83,7 @@ func (si *signingRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
return resp, err
}

func newSigningRoundTripper(cfg *Config, next http.RoundTripper) (http.RoundTripper, error) {
func newSigningRoundTripper(cfg *Config, next http.RoundTripper, runtimeInfo string) (http.RoundTripper, error) {
auth := cfg.AuthConfig
if auth.Region == "" {
region, err := parseEndpointRegion(cfg.Config.HTTPClientSettings.Endpoint)
Expand All @@ -88,7 +100,7 @@ func newSigningRoundTripper(cfg *Config, next http.RoundTripper) (http.RoundTrip
if err != nil {
return next, err
}
return newSigningRoundTripperWithCredentials(auth, creds, next)
return newSigningRoundTripperWithCredentials(auth, creds, next, runtimeInfo)
}

func getCredsFromConfig(auth AuthConfig) (*credentials.Credentials, error) {
Expand Down Expand Up @@ -122,16 +134,17 @@ func parseEndpointRegion(endpoint string) (region string, err error) {
return p[1], nil
}

func newSigningRoundTripperWithCredentials(auth AuthConfig, creds *credentials.Credentials, next http.RoundTripper) (http.RoundTripper, error) {
func newSigningRoundTripperWithCredentials(auth AuthConfig, creds *credentials.Credentials, next http.RoundTripper, runtimeInfo string) (http.RoundTripper, error) {
if creds == nil {
return nil, errors.New("no AWS credentials exist")
}
signer := v4.NewSigner(creds)
rt := signingRoundTripper{
transport: next,
signer: signer,
region: auth.Region,
service: auth.Service,
transport: next,
signer: signer,
region: auth.Region,
service: auth.Service,
runtimeInfo: runtimeInfo,
}
return &rt, nil
}
Expand Down
16 changes: 12 additions & 4 deletions exporter/awsprometheusremotewriteexporter/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ package awsprometheusremotewriteexporter

import (
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"runtime"
"strings"
"sync"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/stretchr/testify/assert"
Expand All @@ -32,6 +35,8 @@ import (
"go.opentelemetry.io/collector/config/configtls"
)

var sdkInformation string = fmt.Sprintf("%s/%s (%s; %s; %s)", aws.SDKName, aws.SDKVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH)

func TestRequestSignature(t *testing.T) {
// Some form of AWS credentials must be set up for tests to succeed
awsCreds := fetchMockCredentials()
Expand All @@ -40,6 +45,7 @@ func TestRequestSignature(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := v4.GetSignedRequestSignature(r)
assert.NoError(t, err)
assert.Equal(t, sdkInformation, r.Header.Get("User-Agent"))
w.WriteHeader(200)
}))
defer server.Close()
Expand All @@ -54,7 +60,7 @@ func TestRequestSignature(t *testing.T) {
WriteBufferSize: 0,
Timeout: 0,
CustomRoundTripper: func(next http.RoundTripper) (http.RoundTripper, error) {
return newSigningRoundTripperWithCredentials(authConfig, awsCreds, next)
return newSigningRoundTripperWithCredentials(authConfig, awsCreds, next, sdkInformation)
},
}
client, _ := setting.ToClient()
Expand Down Expand Up @@ -85,6 +91,7 @@ func TestLeakingBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := v4.GetSignedRequestSignature(r)
assert.NoError(t, err)
assert.Equal(t, sdkInformation, r.Header.Get("User-Agent"))
w.WriteHeader(200)
}))
defer server.Close()
Expand All @@ -99,7 +106,7 @@ func TestLeakingBody(t *testing.T) {
WriteBufferSize: 0,
Timeout: 0,
CustomRoundTripper: func(next http.RoundTripper) (http.RoundTripper, error) {
return newSigningRoundTripperWithCredentials(authConfig, awsCreds, next)
return newSigningRoundTripperWithCredentials(authConfig, awsCreds, next, sdkInformation)
},
}
client, _ := setting.ToClient()
Expand Down Expand Up @@ -177,12 +184,13 @@ func TestRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := v4.GetSignedRequestSignature(r)
assert.NoError(t, err)
assert.Equal(t, sdkInformation, r.Header.Get("User-Agent"))
w.WriteHeader(200)
}))
defer server.Close()
serverURL, _ := url.Parse(server.URL)
authConfig := AuthConfig{Region: "region", Service: "service"}
rt, err := newSigningRoundTripperWithCredentials(authConfig, awsCreds, tt.rt)
rt, err := newSigningRoundTripperWithCredentials(authConfig, awsCreds, tt.rt, sdkInformation)
assert.NoError(t, err)
req, err := http.NewRequest("POST", serverURL.String(), strings.NewReader(""))
assert.NoError(t, err)
Expand Down Expand Up @@ -239,7 +247,7 @@ func TestCreateSigningRoundTripperWithCredentials(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rtp, err := newSigningRoundTripperWithCredentials(tt.authConfig, tt.creds, tt.roundTripper)
rtp, err := newSigningRoundTripperWithCredentials(tt.authConfig, tt.creds, tt.roundTripper, sdkInformation)
if tt.returnError {
assert.Error(t, err)
return
Expand Down
12 changes: 11 additions & 1 deletion exporter/awsprometheusremotewriteexporter/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ package awsprometheusremotewriteexporter

import (
"context"
"fmt"
"net/http"
"os"
"runtime"
"strings"

"github.com/aws/aws-sdk-go/aws"
"go.opentelemetry.io/collector/component"
"go.opentelemetry.io/collector/config"
prw "go.opentelemetry.io/collector/exporter/prometheusremotewriteexporter"
Expand Down Expand Up @@ -56,7 +61,12 @@ func (af *awsFactory) CreateDefaultConfig() config.Exporter {

cfg.ExporterSettings = config.NewExporterSettings(config.NewID(typeStr))
cfg.HTTPClientSettings.CustomRoundTripper = func(next http.RoundTripper) (http.RoundTripper, error) {
return newSigningRoundTripper(cfg, next)
extras := []string{runtime.Version(), runtime.GOOS, runtime.GOARCH}
if v := os.Getenv("AWS_EXECUTION_ENV"); v != "" {
extras = append(extras, v)
}
runtimeInfo := fmt.Sprintf("%s/%s (%s)", aws.SDKName, aws.SDKVersion, strings.Join(extras, "; "))
return newSigningRoundTripper(cfg, next, runtimeInfo)
}

return cfg
Expand Down

0 comments on commit 40f1090

Please # to comment.