Skip to content

Commit

Permalink
Context fixes
Browse files Browse the repository at this point in the history
Go has strong convention around passing context.Context as the first call
to functions.

I intentionally did not fix auth/ and credential_provider/ misuses of
context, as that will be fixed in a later PR updating the AWS Go SDK.

Signed-off-by: Micah Hausler <mhausler@amazon.com>
  • Loading branch information
micahhausler committed Feb 12, 2025
1 parent e9d3aca commit b8204c9
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions provider/parameter_store_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (p *ParameterStoreProvider) fetchParameterStoreValue(
) (values []*SecretValue, err error) {

for _, client := range p.clients {
batchValues, err := p.fetchParameterStoreBatch(client, ctx, batchDescriptors, curMap)
batchValues, err := p.fetchParameterStoreBatch(ctx, client, batchDescriptors, curMap)

if utils.IsFatalError(err) {
return nil, err
Expand All @@ -104,8 +104,8 @@ func (p *ParameterStoreProvider) fetchParameterStoreValue(
// if any parameter is failed to fetch, the parameter is returned as invalid parameter
// and the version information is updated in the current version map.
func (p *ParameterStoreProvider) fetchParameterStoreBatch(
client ParameterStoreClient,
ctx context.Context,
client ParameterStoreClient,
batchDescriptors []*SecretDescriptor,
curMap map[string]*v1alpha1.ObjectVersion,
) (v []*SecretValue, err error) {
Expand Down
8 changes: 4 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (s *CSIDriverProviderServer) Mount(ctx context.Context, req *v1alpha1.Mount
return nil, fmt.Errorf("failed to unmarshal file permission, error: %+v", err)
}

regions, err := s.getAwsRegions(region, failoverRegion, nameSpace, podName, ctx)
regions, err := s.getAwsRegions(ctx, region, failoverRegion, nameSpace, podName)
if err != nil {
klog.ErrorS(err, "Failed to initialize AWS session")
return nil, err
Expand All @@ -135,7 +135,7 @@ func (s *CSIDriverProviderServer) Mount(ctx context.Context, req *v1alpha1.Mount
}
}

awsSessions, err := s.getAwsSessions(nameSpace, svcAcct, ctx, regions, usePodIdentity, podName, preferredAddressType)
awsSessions, err := s.getAwsSessions(ctx, nameSpace, svcAcct, regions, usePodIdentity, podName, preferredAddressType)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -193,7 +193,7 @@ func (s *CSIDriverProviderServer) Mount(ctx context.Context, req *v1alpha1.Mount
// If a region is not specified in the mount request, we must lookup the region from node label and add as primary region to the lookup region list
// If both the region and node label region are not available, error will be thrown
// If backupRegion is provided and is equal to region/node region, error will be thrown else backupRegion is added to the lookup region list
func (s *CSIDriverProviderServer) getAwsRegions(region, backupRegion, nameSpace, podName string, ctx context.Context) (response []string, err error) {
func (s *CSIDriverProviderServer) getAwsRegions(ctx context.Context, region, backupRegion, nameSpace, podName string) (response []string, err error) {
var lookupRegionList []string

// Find primary region. Fall back to region node if unavailable.
Expand All @@ -220,7 +220,7 @@ func (s *CSIDriverProviderServer) getAwsRegions(region, backupRegion, nameSpace,
// Gets the pod's AWS creds for each lookup region
// Establishes the connection using Aws cred for each lookup region
// If atleast one session is not created, error will be thrown
func (s *CSIDriverProviderServer) getAwsSessions(nameSpace, svcAcct string, ctx context.Context, lookupRegionList []string, usePodIdentity bool, podName string, preferredAddressType string) (response []*session.Session, err error) {
func (s *CSIDriverProviderServer) getAwsSessions(ctx context.Context, nameSpace, svcAcct string, lookupRegionList []string, usePodIdentity bool, podName string, preferredAddressType string) (response []*session.Session, err error) {
// Get the pod's AWS creds for each lookup region.
var awsSessionsList []*session.Session

Expand Down
14 changes: 7 additions & 7 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2004,7 +2004,7 @@ func TestMounts(t *testing.T) {

// Do the mount
req := buildMountReq(dir, tst, []*v1alpha1.ObjectVersion{})
rsp, err := svr.Mount(nil, req)
rsp, err := svr.Mount(context.Background(), req)
if len(tst.expErr) == 0 && err != nil {
t.Fatalf("%s: Got unexpected error: %s", tst.testName, err)
}
Expand Down Expand Up @@ -2045,7 +2045,7 @@ func TestMountsNoWrite(t *testing.T) {

// Do the mount
req := buildMountReq(dir, tst, []*v1alpha1.ObjectVersion{})
rsp, err := svr.Mount(nil, req)
rsp, err := svr.Mount(context.Background(), req)
if len(tst.expErr) == 0 && err != nil {
t.Fatalf("%s: Got unexpected error: %s", tst.testName, err)
}
Expand Down Expand Up @@ -2468,7 +2468,7 @@ func TestReMounts(t *testing.T) {

// Do the mount
req := buildMountReq(dir, tst, curState)
rsp, err := svr.Mount(nil, req)
rsp, err := svr.Mount(context.Background(), req)
if len(tst.expErr) == 0 && err != nil {
t.Fatalf("%s: Got unexpected error: %s", tst.testName, err)
}
Expand Down Expand Up @@ -2510,7 +2510,7 @@ func TestNoWriteReMounts(t *testing.T) {

// Do the mount
req := buildMountReq(dir, tst, curState)
rsp, err := svr.Mount(nil, req)
rsp, err := svr.Mount(context.Background(), req)
if len(tst.expErr) == 0 && err != nil {
t.Fatalf("%s: Got unexpected error: %s", tst.testName, err)
}
Expand Down Expand Up @@ -2547,7 +2547,7 @@ func TestEmptyAttributes(t *testing.T) {
Permission: "420",
CurrentObjectVersion: []*v1alpha1.ObjectVersion{},
}
rsp, err := svr.Mount(nil, req)
rsp, err := svr.Mount(context.Background(), req)

if rsp != nil {
t.Fatalf("TestEmptyAttributes: got unexpected response")
Expand All @@ -2567,7 +2567,7 @@ func TestNoPath(t *testing.T) {
Permission: "420",
CurrentObjectVersion: []*v1alpha1.ObjectVersion{},
}
rsp, err := svr.Mount(nil, req)
rsp, err := svr.Mount(context.Background(), req)

if rsp != nil {
t.Fatalf("TestNoPath: got unexpected response")
Expand All @@ -2590,7 +2590,7 @@ func TestDriverVersion(t *testing.T) {
t.Fatalf("TestDriverVersion: got empty server")
}

ver, err := svr.Version(nil, &v1alpha1.VersionRequest{})
ver, err := svr.Version(context.Background(), &v1alpha1.VersionRequest{})
if err != nil {
t.Fatalf("TestDriverVersion: got unexpected error %s", err.Error())
}
Expand Down

0 comments on commit b8204c9

Please # to comment.