diff --git a/Makefile b/Makefile index 1d1a0e4..b7eb690 100644 --- a/Makefile +++ b/Makefile @@ -158,7 +158,6 @@ ifneq ($(shell test -d $(MOCKS_DESTINATION); echo $$?), 0) $(MOCKGEN) --source pkg/cloudmap/client.go --destination $(MOCKS_DESTINATION)/pkg/cloudmap/client_mock.go --package cloudmap_mock $(MOCKGEN) --source pkg/cloudmap/cache.go --destination $(MOCKS_DESTINATION)/pkg/cloudmap/cache_mock.go --package cloudmap_mock $(MOCKGEN) --source pkg/cloudmap/operation_poller.go --destination $(MOCKS_DESTINATION)/pkg/cloudmap/operation_poller_mock.go --package cloudmap_mock - $(MOCKGEN) --source pkg/cloudmap/operation_collector.go --destination $(MOCKS_DESTINATION)/pkg/cloudmap/operation_collector_mock.go --package cloudmap_mock $(MOCKGEN) --source pkg/cloudmap/api.go --destination $(MOCKS_DESTINATION)/pkg/cloudmap/api_mock.go --package cloudmap_mock $(MOCKGEN) --source pkg/cloudmap/aws_facade.go --destination $(MOCKS_DESTINATION)/pkg/cloudmap/aws_facade_mock.go --package cloudmap_mock $(MOCKGEN) --source integration/janitor/api.go --destination $(MOCKS_DESTINATION)/integration/janitor/api_mock.go --package janitor_mock diff --git a/integration/janitor/janitor.go b/integration/janitor/janitor.go index 0a05fcf..260fc1e 100644 --- a/integration/janitor/janitor.go +++ b/integration/janitor/janitor.go @@ -72,7 +72,7 @@ func (j *cloudMapJanitor) Cleanup(ctx context.Context, nsName string) { opId, err := j.sdApi.DeleteNamespace(ctx, ns.Id) if err == nil { fmt.Println("namespace delete in progress") - _, err = j.sdApi.PollNamespaceOperation(ctx, opId) + _, err = cloudmap.NewOperationPoller(j.sdApi).Poll(ctx, opId) } j.checkOrFail(err, "clean up successful", "could not cleanup namespace") } @@ -87,17 +87,17 @@ func (j *cloudMapJanitor) deregisterInstances(ctx context.Context, nsName string fmt.Sprintf("service has %d instances to clean", len(insts)), "could not list instances to cleanup") - opColl := cloudmap.NewOperationCollector() + opPoller := cloudmap.NewOperationPoller(j.sdApi) for _, inst := range insts { instId := aws.ToString(inst.InstanceId) fmt.Printf("found instance to clean: %s\n", instId) - opColl.Add(func() (opId string, err error) { + opPoller.Submit(ctx, func() (opId string, err error) { return j.sdApi.DeregisterInstance(ctx, svcId, instId) }) } - opErr := cloudmap.NewDeregisterInstancePoller(j.sdApi, svcId, opColl.Collect(), opColl.GetStartTime()).Poll(ctx) - j.checkOrFail(opErr, "instances de-registered", "could not cleanup instances") + err = opPoller.Await() + j.checkOrFail(err, "instances de-registered", "could not cleanup instances") } func (j *cloudMapJanitor) checkOrFail(err error, successMsg string, failMsg string) { diff --git a/integration/janitor/janitor_test.go b/integration/janitor/janitor_test.go index 6f1d8c9..95e25c7 100644 --- a/integration/janitor/janitor_test.go +++ b/integration/janitor/janitor_test.go @@ -39,14 +39,15 @@ func TestCleanupHappyCase(t *testing.T) { tj.mockApi.EXPECT().DeregisterInstance(context.TODO(), test.SvcId, test.EndptId1). Return(test.OpId1, nil) - tj.mockApi.EXPECT().ListOperations(context.TODO(), gomock.Any()). - Return(map[string]types.OperationStatus{test.OpId1: types.OperationStatusSuccess}, nil) + tj.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId1). + Return(&types.Operation{Status: types.OperationStatusSuccess}, nil) tj.mockApi.EXPECT().DeleteService(context.TODO(), test.SvcId). Return(nil) tj.mockApi.EXPECT().DeleteNamespace(context.TODO(), test.HttpNsId). Return(test.OpId2, nil) - tj.mockApi.EXPECT().PollNamespaceOperation(context.TODO(), test.OpId2). - Return(test.HttpNsId, nil) + tj.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId2). + Return(&types.Operation{Status: types.OperationStatusSuccess, + Targets: map[string]string{string(types.OperationTargetTypeNamespace): test.HttpNsId}}, nil) tj.janitor.Cleanup(context.TODO(), test.HttpNsName) assert.False(t, *tj.failed) diff --git a/pkg/cloudmap/api.go b/pkg/cloudmap/api.go index 0342b5b..cb91f45 100644 --- a/pkg/cloudmap/api.go +++ b/pkg/cloudmap/api.go @@ -2,18 +2,12 @@ package cloudmap import ( "context" - "errors" - "fmt" - "time" - - "golang.org/x/time/rate" "github.com/aws/aws-cloud-map-mcs-controller-for-k8s/pkg/common" "github.com/aws/aws-cloud-map-mcs-controller-for-k8s/pkg/model" "github.com/aws/aws-sdk-go-v2/aws" sd "github.com/aws/aws-sdk-go-v2/service/servicediscovery" "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" - "k8s.io/apimachinery/pkg/util/wait" ) const ( @@ -32,9 +26,6 @@ type ServiceDiscoveryApi interface { // DiscoverInstances returns a list of service instances registered to a given service. DiscoverInstances(ctx context.Context, nsName string, svcName string, queryParameters map[string]string) (insts []types.HttpInstanceSummary, err error) - // ListOperations returns a map of operations to their status matching a list of filters. - ListOperations(ctx context.Context, opFilters []types.OperationFilter) (operationStatusMap map[string]types.OperationStatus, err error) - // GetOperation returns an operation. GetOperation(ctx context.Context, operationId string) (operation *types.Operation, err error) @@ -49,32 +40,25 @@ type ServiceDiscoveryApi interface { // DeregisterInstance de-registers a service instance in Cloud Map. DeregisterInstance(ctx context.Context, serviceId string, instanceId string) (operationId string, err error) - - // PollNamespaceOperation polls a namespace operation, and returns the namespace ID. - PollNamespaceOperation(ctx context.Context, operationId string) (namespaceId string, err error) } type serviceDiscoveryApi struct { - log common.Logger - awsFacade AwsFacade - nsRateLimiter *rate.Limiter - svcRateLimiter *rate.Limiter - opRateLimiter *rate.Limiter + log common.Logger + awsFacade AwsFacade + rateLimiter common.RateLimiter } // NewServiceDiscoveryApiFromConfig creates a new AWS Cloud Map API connection manager from an AWS client config. func NewServiceDiscoveryApiFromConfig(cfg *aws.Config) ServiceDiscoveryApi { return &serviceDiscoveryApi{ - log: common.NewLogger("cloudmap", "api"), - awsFacade: NewAwsFacadeFromConfig(cfg), - nsRateLimiter: rate.NewLimiter(rate.Every(1*time.Second), 5), // 1 per second - svcRateLimiter: rate.NewLimiter(rate.Every(2*time.Second), 10), // 2 per second - opRateLimiter: rate.NewLimiter(rate.Every(100*time.Second), 200), // 100 per second + log: common.NewLogger("cloudmap", "api"), + awsFacade: NewAwsFacadeFromConfig(cfg), + rateLimiter: common.NewDefaultRateLimiter(), } } func (sdApi *serviceDiscoveryApi) GetNamespaceMap(ctx context.Context) (map[string]*model.Namespace, error) { - err := sdApi.nsRateLimiter.Wait(ctx) + err := sdApi.rateLimiter.Wait(ctx, common.ListNamespaces) if err != nil { return nil, err } @@ -105,7 +89,7 @@ func (sdApi *serviceDiscoveryApi) GetNamespaceMap(ctx context.Context) (map[stri } func (sdApi *serviceDiscoveryApi) GetServiceIdMap(ctx context.Context, nsId string) (map[string]string, error) { - err := sdApi.svcRateLimiter.Wait(ctx) + err := sdApi.rateLimiter.Wait(ctx, common.ListServices) if err != nil { return nil, err } @@ -151,30 +135,8 @@ func (sdApi *serviceDiscoveryApi) DiscoverInstances(ctx context.Context, nsName return out.Instances, nil } -func (sdApi *serviceDiscoveryApi) ListOperations(ctx context.Context, opFilters []types.OperationFilter) (map[string]types.OperationStatus, error) { - opStatusMap := make(map[string]types.OperationStatus) - - pages := sd.NewListOperationsPaginator(sdApi.awsFacade, &sd.ListOperationsInput{ - Filters: opFilters, - }) - - for pages.HasMorePages() { - output, err := pages.NextPage(ctx) - - if err != nil { - return opStatusMap, err - } - - for _, sdOp := range output.Operations { - opStatusMap[aws.ToString(sdOp.Id)] = sdOp.Status - } - } - - return opStatusMap, nil -} - func (sdApi *serviceDiscoveryApi) GetOperation(ctx context.Context, opId string) (operation *types.Operation, err error) { - err = sdApi.opRateLimiter.Wait(ctx) + err = sdApi.rateLimiter.Wait(ctx, common.GetOperation) if err != nil { return nil, err } @@ -236,6 +198,11 @@ func (sdApi *serviceDiscoveryApi) getDnsConfig() types.DnsConfig { } func (sdApi *serviceDiscoveryApi) RegisterInstance(ctx context.Context, svcId string, instId string, instAttrs map[string]string) (opId string, err error) { + err = sdApi.rateLimiter.Wait(ctx, common.RegisterInstance) + if err != nil { + return "", err + } + regResp, err := sdApi.awsFacade.RegisterInstance(ctx, &sd.RegisterInstanceInput{ Attributes: instAttrs, InstanceId: &instId, @@ -250,6 +217,11 @@ func (sdApi *serviceDiscoveryApi) RegisterInstance(ctx context.Context, svcId st } func (sdApi *serviceDiscoveryApi) DeregisterInstance(ctx context.Context, svcId string, instId string) (opId string, err error) { + err = sdApi.rateLimiter.Wait(ctx, common.DeregisterInstance) + if err != nil { + return "", err + } + deregResp, err := sdApi.awsFacade.DeregisterInstance(ctx, &sd.DeregisterInstanceInput{ InstanceId: &instId, ServiceId: &svcId, @@ -261,31 +233,3 @@ func (sdApi *serviceDiscoveryApi) DeregisterInstance(ctx context.Context, svcId return aws.ToString(deregResp.OperationId), err } - -func (sdApi *serviceDiscoveryApi) PollNamespaceOperation(ctx context.Context, opId string) (nsId string, err error) { - err = wait.Poll(defaultOperationPollInterval, defaultOperationPollTimeout, func() (done bool, err error) { - sdApi.log.Info("polling operation", "opId", opId) - op, err := sdApi.GetOperation(ctx, opId) - - if err != nil { - return true, err - } - - if op.Status == types.OperationStatusFail { - return true, fmt.Errorf("failed to create namespace: %s", aws.ToString(op.ErrorMessage)) - } - - if op.Status == types.OperationStatusSuccess { - nsId = op.Targets[string(types.OperationTargetTypeNamespace)] - return true, nil - } - - return false, nil - }) - - if err == wait.ErrWaitTimeout { - err = errors.New(operationPollTimoutErrorMessage) - } - - return nsId, err -} diff --git a/pkg/cloudmap/api_test.go b/pkg/cloudmap/api_test.go index 70143ce..8990d51 100644 --- a/pkg/cloudmap/api_test.go +++ b/pkg/cloudmap/api_test.go @@ -5,9 +5,6 @@ import ( "errors" "fmt" "testing" - "time" - - "golang.org/x/time/rate" aboutv1alpha1 "github.com/aws/aws-cloud-map-mcs-controller-for-k8s/pkg/apis/about/v1alpha1" @@ -131,26 +128,6 @@ func TestServiceDiscoveryApi_DiscoverInstances_HappyCase(t *testing.T) { assert.Equal(t, test.EndptId2, *insts[1].InstanceId) } -func TestServiceDiscoveryApi_ListOperations_HappyCase(t *testing.T) { - mockController := gomock.NewController(t) - defer mockController.Finish() - - awsFacade := cloudmapMock.NewMockAwsFacade(mockController) - sdApi := getServiceDiscoveryApi(t, awsFacade) - - filters := make([]types.OperationFilter, 0) - awsFacade.EXPECT().ListOperations(context.TODO(), &sd.ListOperationsInput{Filters: filters}). - Return(&sd.ListOperationsOutput{ - Operations: []types.OperationSummary{ - {Id: aws.String(test.OpId1), Status: types.OperationStatusSuccess}, - }}, nil) - - ops, err := sdApi.ListOperations(context.TODO(), filters) - assert.Nil(t, err, "No error for happy case") - assert.True(t, len(ops) == 1) - assert.Equal(t, ops[test.OpId1], types.OperationStatusSuccess) -} - func TestServiceDiscoveryApi_GetOperation_HappyCase(t *testing.T) { mockController := gomock.NewController(t) defer mockController.Finish() @@ -316,33 +293,12 @@ func TestServiceDiscoveryApi_DeregisterInstance_Error(t *testing.T) { assert.Equal(t, sdkErr, err) } -func TestServiceDiscoveryApi_PollNamespaceOperation_HappyCase(t *testing.T) { - mockController := gomock.NewController(t) - defer mockController.Finish() - - awsFacade := cloudmapMock.NewMockAwsFacade(mockController) - awsFacade.EXPECT().GetOperation(context.TODO(), &sd.GetOperationInput{OperationId: aws.String(test.OpId1)}). - Return(&sd.GetOperationOutput{Operation: &types.Operation{Status: types.OperationStatusPending}}, nil) - - awsFacade.EXPECT().GetOperation(context.TODO(), &sd.GetOperationInput{OperationId: aws.String(test.OpId1)}). - Return(&sd.GetOperationOutput{Operation: &types.Operation{Status: types.OperationStatusSuccess, - Targets: map[string]string{string(types.OperationTargetTypeNamespace): test.HttpNsId}}}, nil) - - sdApi := getServiceDiscoveryApi(t, awsFacade) - - nsId, err := sdApi.PollNamespaceOperation(context.TODO(), test.OpId1) - assert.Nil(t, err) - assert.Equal(t, test.HttpNsId, nsId) -} - func getServiceDiscoveryApi(t *testing.T, awsFacade *cloudmapMock.MockAwsFacade) ServiceDiscoveryApi { scheme := runtime.NewScheme() scheme.AddKnownTypes(aboutv1alpha1.GroupVersion, &aboutv1alpha1.ClusterProperty{}) return &serviceDiscoveryApi{ - log: common.NewLoggerWithLogr(testr.New(t)), - awsFacade: awsFacade, - nsRateLimiter: rate.NewLimiter(rate.Every(1*time.Second), 2), // 1 per second - svcRateLimiter: rate.NewLimiter(rate.Every(2*time.Second), 4), // 2 per second - opRateLimiter: rate.NewLimiter(rate.Every(10*time.Second), 100), // 10 per second + log: common.NewLoggerWithLogr(testr.New(t)), + awsFacade: awsFacade, + rateLimiter: common.NewDefaultRateLimiter(), } } diff --git a/pkg/cloudmap/client.go b/pkg/cloudmap/client.go index afe1cde..6cd4ddc 100644 --- a/pkg/cloudmap/client.go +++ b/pkg/cloudmap/client.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-cloud-map-mcs-controller-for-k8s/pkg/common" "github.com/aws/aws-cloud-map-mcs-controller-for-k8s/pkg/model" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" ) // ServiceDiscoveryClient provides the service endpoint management functionality required by the AWS Cloud Map @@ -149,27 +150,21 @@ func (sdc *serviceDiscoveryClient) RegisterEndpoints(ctx context.Context, nsName return err } - opCollector := NewOperationCollector() - + operationPoller := NewOperationPoller(sdc.sdApi) for _, endpt := range endpts { endptId := endpt.Id endptAttrs := endpt.GetCloudMapAttributes() - opCollector.Add(func() (opId string, err error) { + operationPoller.Submit(ctx, func() (opId string, err error) { return sdc.sdApi.RegisterInstance(ctx, svcId, endptId, endptAttrs) }) } - err = NewRegisterInstancePoller(sdc.sdApi, svcId, opCollector.Collect(), opCollector.GetStartTime()).Poll(ctx) - // Evict cache entry so next list call reflects changes sdc.cache.EvictEndpoints(nsName, svcName) + err = operationPoller.Await() if err != nil { - return err - } - - if !opCollector.IsAllOperationsCreated() { - return errors.New("failure while registering endpoints") + return common.Wrap(err, errors.New("failure while registering endpoints")) } return nil @@ -188,29 +183,23 @@ func (sdc *serviceDiscoveryClient) DeleteEndpoints(ctx context.Context, nsName s return err } - opCollector := NewOperationCollector() - + operationPoller := NewOperationPoller(sdc.sdApi) for _, endpt := range endpts { endptId := endpt.Id - // add operation to delete endpoint - opCollector.Add(func() (opId string, err error) { + operationPoller.Submit(ctx, func() (opId string, err error) { return sdc.sdApi.DeregisterInstance(ctx, svcId, endptId) }) } - err = NewDeregisterInstancePoller(sdc.sdApi, svcId, opCollector.Collect(), opCollector.GetStartTime()).Poll(ctx) - // Evict cache entry so next list call reflects changes sdc.cache.EvictEndpoints(nsName, svcName) - if err != nil { - return err - } - if !opCollector.IsAllOperationsCreated() { - return errors.New("failure while de-registering endpoints") + err = operationPoller.Await() + if err != nil { + return common.Wrap(err, errors.New("failure while de-registering endpoints")) } - return nil + return err } func (sdc *serviceDiscoveryClient) getEndpoints(ctx context.Context, nsName string, svcName string) (endpts []*model.Endpoint, err error) { @@ -315,12 +304,13 @@ func (sdc *serviceDiscoveryClient) createNamespace(ctx context.Context, nsName s return nil, err } - nsId, err := sdc.sdApi.PollNamespaceOperation(ctx, opId) + op, err := NewOperationPoller(sdc.sdApi).Poll(ctx, opId) if err != nil { return nil, err } + nsId := op.Targets[string(types.OperationTargetTypeNamespace)] - sdc.log.Info("namespace created", "nsId", nsId) + sdc.log.Info("namespace created", "nsId", nsId, "namespace", nsName) // Default namespace type HTTP namespace = &model.Namespace{ diff --git a/pkg/cloudmap/client_test.go b/pkg/cloudmap/client_test.go index 84da49f..3de7f56 100644 --- a/pkg/cloudmap/client_test.go +++ b/pkg/cloudmap/client_test.go @@ -2,6 +2,7 @@ package cloudmap import ( "context" + "errors" "strconv" "testing" @@ -17,7 +18,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" "github.com/go-logr/logr/testr" "github.com/golang/mock/gomock" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) @@ -186,7 +186,9 @@ func TestServiceDiscoveryClient_CreateService_NamespaceNotFound(t *testing.T) { tc.mockCache.EXPECT().GetNamespaceMap().Return(map[string]*model.Namespace{}, true) tc.mockApi.EXPECT().CreateHttpNamespace(context.TODO(), test.HttpNsName).Return(test.OpId1, nil) - tc.mockApi.EXPECT().PollNamespaceOperation(context.TODO(), test.OpId1).Return(test.HttpNsId, nil) + tc.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId1). + Return(&types.Operation{Status: types.OperationStatusSuccess, + Targets: map[string]string{string(types.OperationTargetTypeNamespace): test.HttpNsId}}, nil) tc.mockCache.EXPECT().EvictNamespaceMap() tc.mockApi.EXPECT().CreateService(context.TODO(), *test.GetTestHttpNamespace(), test.SvcName). @@ -221,8 +223,9 @@ func TestServiceDiscoveryClient_CreateService_CreatesNamespace_HappyCase(t *test tc.mockApi.EXPECT().CreateHttpNamespace(context.TODO(), test.HttpNsName). Return(test.OpId1, nil) - tc.mockApi.EXPECT().PollNamespaceOperation(context.TODO(), test.OpId1). - Return(test.HttpNsId, nil) + tc.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId1). + Return(&types.Operation{Status: types.OperationStatusSuccess, + Targets: map[string]string{string(types.OperationTargetTypeNamespace): test.HttpNsId}}, nil) tc.mockCache.EXPECT().EvictNamespaceMap() tc.mockApi.EXPECT().CreateService(context.TODO(), *test.GetTestHttpNamespace(), test.SvcName). @@ -242,8 +245,8 @@ func TestServiceDiscoveryClient_CreateService_CreatesNamespace_PollError(t *test pollErr := errors.New("polling error") tc.mockApi.EXPECT().CreateHttpNamespace(context.TODO(), test.HttpNsName). Return(test.OpId1, nil) - tc.mockApi.EXPECT().PollNamespaceOperation(context.TODO(), test.OpId1). - Return("", pollErr) + tc.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId1). + Return(nil, pollErr) err := tc.client.CreateService(context.TODO(), test.HttpNsName, test.SvcName) assert.Equal(t, pollErr, err) @@ -283,8 +286,7 @@ func TestServiceDiscoveryClient_GetService_HappyCase(t *testing.T) { tc.mockCache.EXPECT().GetEndpoints(test.HttpNsName, test.SvcName).Return([]*model.Endpoint{}, false) tc.mockApi.EXPECT().DiscoverInstances(context.TODO(), test.HttpNsName, test.SvcName, map[string]string{ model.ClusterSetIdAttr: test.ClusterSet, - }). - Return(getHttpInstanceSummaryForTest(), nil) + }).Return(getHttpInstanceSummaryForTest(), nil) tc.mockCache.EXPECT().CacheEndpoints(test.HttpNsName, test.SvcName, []*model.Endpoint{test.GetTestEndpoint1(), test.GetTestEndpoint2()}) @@ -335,51 +337,14 @@ func TestServiceDiscoveryClient_RegisterEndpoints(t *testing.T) { tc.mockCache.EXPECT().GetServiceIdMap(test.HttpNsName).Return(getServiceIdMapForTest(), true) - attrs1 := map[string]string{ - model.ClusterIdAttr: test.ClusterId1, - model.ClusterSetIdAttr: test.ClusterSet, - model.EndpointIpv4Attr: test.EndptIp1, - model.EndpointPortAttr: test.PortStr1, - model.EndpointPortNameAttr: test.PortName1, - model.EndpointProtocolAttr: test.Protocol1, - model.EndpointReadyAttr: test.EndptReadyTrue, - model.ServicePortNameAttr: test.PortName1, - model.ServicePortAttr: test.ServicePortStr1, - model.ServiceProtocolAttr: test.Protocol1, - model.ServiceTargetPortAttr: test.PortStr1, - model.ServiceTypeAttr: test.SvcType, - model.EndpointHostnameAttr: test.Hostname, - model.EndpointNodeNameAttr: test.Nodename, - model.ServiceExportCreationAttr: strconv.FormatInt(test.SvcExportCreationTimestamp, 10), - model.K8sVersionAttr: test.PackageVersion, - } - attrs2 := map[string]string{ - model.ClusterIdAttr: test.ClusterId1, - model.ClusterSetIdAttr: test.ClusterSet, - model.EndpointIpv4Attr: test.EndptIp2, - model.EndpointPortAttr: test.PortStr2, - model.EndpointPortNameAttr: test.PortName2, - model.EndpointProtocolAttr: test.Protocol2, - model.EndpointReadyAttr: test.EndptReadyTrue, - model.ServicePortNameAttr: test.PortName2, - model.ServicePortAttr: test.ServicePortStr2, - model.ServiceProtocolAttr: test.Protocol2, - model.ServiceTargetPortAttr: test.PortStr2, - model.ServiceTypeAttr: test.SvcType, - model.EndpointHostnameAttr: test.Hostname, - model.EndpointNodeNameAttr: test.Nodename, - model.ServiceExportCreationAttr: strconv.FormatInt(test.SvcExportCreationTimestamp, 10), - model.K8sVersionAttr: test.PackageVersion, - } - - tc.mockApi.EXPECT().RegisterInstance(context.TODO(), test.SvcId, test.EndptId1, attrs1). + tc.mockApi.EXPECT().RegisterInstance(context.TODO(), test.SvcId, test.EndptId1, getAttrs1()). Return(test.OpId1, nil) - tc.mockApi.EXPECT().RegisterInstance(context.TODO(), test.SvcId, test.EndptId2, attrs2). + tc.mockApi.EXPECT().RegisterInstance(context.TODO(), test.SvcId, test.EndptId2, getAttrs2()). Return(test.OpId2, nil) - tc.mockApi.EXPECT().ListOperations(context.TODO(), gomock.Any()). - Return(map[string]types.OperationStatus{ - test.OpId1: types.OperationStatusSuccess, - test.OpId2: types.OperationStatusSuccess}, nil) + tc.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId1). + Return(&types.Operation{Status: types.OperationStatusSuccess}, nil) + tc.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId2). + Return(&types.Operation{Status: types.OperationStatusSuccess}, nil) tc.mockCache.EXPECT().EvictEndpoints(test.HttpNsName, test.SvcName) @@ -389,18 +354,44 @@ func TestServiceDiscoveryClient_RegisterEndpoints(t *testing.T) { assert.Nil(t, err) } +func TestServiceDiscoveryClient_RegisterEndpoints_PollFailure(t *testing.T) { + tc := getTestSdClient(t) + defer tc.close() + + tc.mockCache.EXPECT().GetServiceIdMap(test.HttpNsName).Return(getServiceIdMapForTest(), true) + + tc.mockApi.EXPECT().RegisterInstance(context.TODO(), test.SvcId, test.EndptId1, getAttrs1()). + Return(test.OpId1, nil) + tc.mockApi.EXPECT().RegisterInstance(context.TODO(), test.SvcId, test.EndptId2, getAttrs2()). + Return(test.OpId2, nil) + tc.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId1). + Return(&types.Operation{Status: types.OperationStatusFail}, nil) + tc.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId2). + Return(&types.Operation{Status: types.OperationStatusSuccess}, nil) + + tc.mockCache.EXPECT().EvictEndpoints(test.HttpNsName, test.SvcName) + + err := tc.client.RegisterEndpoints(context.TODO(), test.HttpNsName, test.SvcName, + []*model.Endpoint{test.GetTestEndpoint1(), test.GetTestEndpoint2()}) + + assert.NotNil(t, err) + assert.Contains(t, err.Error(), test.OpId1) +} + func TestServiceDiscoveryClient_DeleteEndpoints(t *testing.T) { tc := getTestSdClient(t) defer tc.close() tc.mockCache.EXPECT().GetServiceIdMap(test.HttpNsName).Return(getServiceIdMapForTest(), true) - tc.mockApi.EXPECT().DeregisterInstance(context.TODO(), test.SvcId, test.EndptId1).Return(test.OpId1, nil) - tc.mockApi.EXPECT().DeregisterInstance(context.TODO(), test.SvcId, test.EndptId2).Return(test.OpId2, nil) - tc.mockApi.EXPECT().ListOperations(context.TODO(), gomock.Any()). - Return(map[string]types.OperationStatus{ - test.OpId1: types.OperationStatusSuccess, - test.OpId2: types.OperationStatusSuccess}, nil) + tc.mockApi.EXPECT().DeregisterInstance(context.TODO(), test.SvcId, test.EndptId1). + Return(test.OpId1, nil) + tc.mockApi.EXPECT().DeregisterInstance(context.TODO(), test.SvcId, test.EndptId2). + Return(test.OpId2, nil) + tc.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId1). + Return(&types.Operation{Status: types.OperationStatusSuccess}, nil) + tc.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId2). + Return(&types.Operation{Status: types.OperationStatusSuccess}, nil) tc.mockCache.EXPECT().EvictEndpoints(test.HttpNsName, test.SvcName) @@ -412,6 +403,33 @@ func TestServiceDiscoveryClient_DeleteEndpoints(t *testing.T) { assert.Nil(t, err) } +func TestServiceDiscoveryClient_DeleteEndpoints_PollFailure(t *testing.T) { + tc := getTestSdClient(t) + defer tc.close() + + tc.mockCache.EXPECT().GetServiceIdMap(test.HttpNsName).Return(getServiceIdMapForTest(), true) + + tc.mockApi.EXPECT().DeregisterInstance(context.TODO(), test.SvcId, test.EndptId1). + Return(test.OpId1, nil) + tc.mockApi.EXPECT().DeregisterInstance(context.TODO(), test.SvcId, test.EndptId2). + Return(test.OpId2, nil) + tc.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId1). + Return(&types.Operation{Status: types.OperationStatusFail}, nil) + tc.mockApi.EXPECT().GetOperation(context.TODO(), test.OpId2). + Return(&types.Operation{Status: types.OperationStatusSuccess}, nil) + + tc.mockCache.EXPECT().EvictEndpoints(test.HttpNsName, test.SvcName) + + err := tc.client.DeleteEndpoints(context.TODO(), test.HttpNsName, test.SvcName, + []*model.Endpoint{ + {Id: test.EndptId1, ClusterId: test.ClusterId1, ClusterSetId: test.ClusterSet}, + {Id: test.EndptId2, ClusterId: test.ClusterId1, ClusterSetId: test.ClusterSet}, + }) + + assert.NotNil(t, err) + assert.Contains(t, err.Error(), test.OpId1) +} + func getTestSdClient(t *testing.T) *testSdClient { test.SetTestVersion() mockController := gomock.NewController(t) @@ -490,3 +508,45 @@ func getNamespaceMapForTest() map[string]*model.Namespace { func getServiceIdMapForTest() map[string]string { return map[string]string{test.SvcName: test.SvcId} } + +func getAttrs2() map[string]string { + return map[string]string{ + model.ClusterIdAttr: test.ClusterId1, + model.ClusterSetIdAttr: test.ClusterSet, + model.EndpointIpv4Attr: test.EndptIp2, + model.EndpointPortAttr: test.PortStr2, + model.EndpointPortNameAttr: test.PortName2, + model.EndpointProtocolAttr: test.Protocol2, + model.EndpointReadyAttr: test.EndptReadyTrue, + model.ServicePortNameAttr: test.PortName2, + model.ServicePortAttr: test.ServicePortStr2, + model.ServiceProtocolAttr: test.Protocol2, + model.ServiceTargetPortAttr: test.PortStr2, + model.ServiceTypeAttr: test.SvcType, + model.EndpointHostnameAttr: test.Hostname, + model.EndpointNodeNameAttr: test.Nodename, + model.ServiceExportCreationAttr: strconv.FormatInt(test.SvcExportCreationTimestamp, 10), + model.K8sVersionAttr: test.PackageVersion, + } +} + +func getAttrs1() map[string]string { + return map[string]string{ + model.ClusterIdAttr: test.ClusterId1, + model.ClusterSetIdAttr: test.ClusterSet, + model.EndpointIpv4Attr: test.EndptIp1, + model.EndpointPortAttr: test.PortStr1, + model.EndpointPortNameAttr: test.PortName1, + model.EndpointProtocolAttr: test.Protocol1, + model.EndpointReadyAttr: test.EndptReadyTrue, + model.ServicePortNameAttr: test.PortName1, + model.ServicePortAttr: test.ServicePortStr1, + model.ServiceProtocolAttr: test.Protocol1, + model.ServiceTargetPortAttr: test.PortStr1, + model.ServiceTypeAttr: test.SvcType, + model.EndpointHostnameAttr: test.Hostname, + model.EndpointNodeNameAttr: test.Nodename, + model.ServiceExportCreationAttr: strconv.FormatInt(test.SvcExportCreationTimestamp, 10), + model.K8sVersionAttr: test.PackageVersion, + } +} diff --git a/pkg/cloudmap/operation_collector.go b/pkg/cloudmap/operation_collector.go deleted file mode 100644 index 93293be..0000000 --- a/pkg/cloudmap/operation_collector.go +++ /dev/null @@ -1,84 +0,0 @@ -package cloudmap - -import ( - "sync" - - "github.com/aws/aws-cloud-map-mcs-controller-for-k8s/pkg/common" -) - -// OperationCollector collects a list of operation IDs asynchronously with thread safety. -type OperationCollector interface { - // Add calls an operation provider function to asynchronously collect operations to poll. - Add(operationProvider func() (operationId string, err error)) - - // Collect waits for all create operation results to be provided and returns a list of the successfully created operation IDs. - Collect() []string - - // GetStartTime returns the start time range to poll the collected operations. - GetStartTime() int64 - - // IsAllOperationsCreated returns true if all operations were created successfully. - IsAllOperationsCreated() bool -} - -type opCollector struct { - log common.Logger - opChan chan opResult - wg sync.WaitGroup - startTime int64 - createOpsSuccess bool -} - -type opResult struct { - opId string - err error -} - -func NewOperationCollector() OperationCollector { - return &opCollector{ - log: common.NewLogger("cloudmap"), - opChan: make(chan opResult), - startTime: Now(), - createOpsSuccess: true, - } -} - -func (opColl *opCollector) Add(opProvider func() (opId string, err error)) { - opColl.wg.Add(1) - go func() { - defer opColl.wg.Done() - - opId, opErr := opProvider() - opColl.opChan <- opResult{opId, opErr} - }() -} - -func (opColl *opCollector) Collect() []string { - opIds := make([]string, 0) - - // Run wait in separate go routine to unblock reading from the channel. - go func() { - opColl.wg.Wait() - close(opColl.opChan) - }() - - for op := range opColl.opChan { - if op.err != nil { - opColl.log.Info("could not create operation", "error", op.err) - opColl.createOpsSuccess = false - continue - } - - opIds = append(opIds, op.opId) - } - - return opIds -} - -func (opColl *opCollector) GetStartTime() int64 { - return opColl.startTime -} - -func (opColl *opCollector) IsAllOperationsCreated() bool { - return opColl.createOpsSuccess -} diff --git a/pkg/cloudmap/operation_collector_test.go b/pkg/cloudmap/operation_collector_test.go deleted file mode 100644 index cb6e6f5..0000000 --- a/pkg/cloudmap/operation_collector_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package cloudmap - -import ( - "errors" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -const ( - op1 = "one" - op2 = "two" -) - -func TestOpCollector_HappyCase(t *testing.T) { - oc := NewOperationCollector() - oc.Add(func() (opId string, err error) { return op1, nil }) - oc.Add(func() (opId string, err error) { return op2, nil }) - - result := oc.Collect() - assert.True(t, oc.IsAllOperationsCreated()) - assert.Equal(t, 2, len(result)) - assert.Contains(t, result, op1) - assert.Contains(t, result, op2) -} - -func TestOpCollector_AllFail(t *testing.T) { - oc := NewOperationCollector() - oc.Add(func() (opId string, err error) { return op1, errors.New("fail one") }) - oc.Add(func() (opId string, err error) { return op2, errors.New("fail two") }) - - result := oc.Collect() - assert.False(t, oc.IsAllOperationsCreated()) - assert.Equal(t, 0, len(result)) -} - -func TestOpCollector_MixedSuccess(t *testing.T) { - oc := NewOperationCollector() - oc.Add(func() (opId string, err error) { return op1, errors.New("fail one") }) - oc.Add(func() (opId string, err error) { return op2, nil }) - - result := oc.Collect() - assert.False(t, oc.IsAllOperationsCreated()) - assert.Equal(t, []string{op2}, result) -} - -func TestOpCollector_GetStartTime(t *testing.T) { - oc1 := NewOperationCollector() - time.Sleep(time.Second) - oc2 := NewOperationCollector() - - assert.Equal(t, oc1.GetStartTime(), oc1.GetStartTime(), "Start time should not change") - assert.NotEqual(t, oc1.GetStartTime(), oc2.GetStartTime(), "Start time should reflect instantiation") - assert.Less(t, oc1.GetStartTime(), oc2.GetStartTime(), - "Start time should increase for later instantiations") -} diff --git a/pkg/cloudmap/operation_poller.go b/pkg/cloudmap/operation_poller.go index aacf4a3..d62278d 100644 --- a/pkg/cloudmap/operation_poller.go +++ b/pkg/cloudmap/operation_poller.go @@ -2,8 +2,8 @@ package cloudmap import ( "context" - "errors" - "strconv" + "fmt" + "sync" "time" "github.com/aws/aws-cloud-map-mcs-controller-for-k8s/pkg/common" @@ -14,150 +14,116 @@ import ( const ( // Interval between each getOperation call. - defaultOperationPollInterval = 3 * time.Second + defaultOperationPollInterval = 2 * time.Second // Time until we stop polling the operation - defaultOperationPollTimeout = 5 * time.Minute + defaultOperationPollTimeout = 1 * time.Minute operationPollTimoutErrorMessage = "timed out while polling operations" ) -// OperationPoller polls a list operations for a terminal status. +// OperationPoller polls a list operations for a terminal status type OperationPoller interface { - // Poll monitors operations until they reach terminal state. - Poll(ctx context.Context) error -} + // Submit operations to async poll + Submit(ctx context.Context, opProvider func() (opId string, err error)) -type operationPoller struct { - log common.Logger - sdApi ServiceDiscoveryApi - timeout time.Duration - - opIds []string - svcId string - opType types.OperationType - start int64 -} + // Poll operations for a terminal state + Poll(ctx context.Context, opId string) (*types.Operation, error) -func newOperationPoller(sdApi ServiceDiscoveryApi, svcId string, opIds []string, startTime int64) operationPoller { - return operationPoller{ - log: common.NewLogger("cloudmap"), - sdApi: sdApi, - timeout: defaultOperationPollTimeout, + // Await waits for all operation results from async poll + Await() (err error) +} - opIds: opIds, - svcId: svcId, - start: startTime, - } +type operationPoller struct { + log common.Logger + sdApi ServiceDiscoveryApi + opChan chan opResult + waitGroup sync.WaitGroup + pollInterval time.Duration + pollTimeout time.Duration } -// NewRegisterInstancePoller creates a new operation poller for register instance operations. -func NewRegisterInstancePoller(sdApi ServiceDiscoveryApi, serviceId string, opIds []string, startTime int64) OperationPoller { - poller := newOperationPoller(sdApi, serviceId, opIds, startTime) - poller.opType = types.OperationTypeRegisterInstance - return &poller +type opResult struct { + opId string + err error } -// NewDeregisterInstancePoller creates a new operation poller for de-register instance operations. -func NewDeregisterInstancePoller(sdApi ServiceDiscoveryApi, serviceId string, opIds []string, startTime int64) OperationPoller { - poller := newOperationPoller(sdApi, serviceId, opIds, startTime) - poller.opType = types.OperationTypeDeregisterInstance - return &poller +// NewOperationPoller creates a new operation poller +func NewOperationPoller(sdApi ServiceDiscoveryApi) OperationPoller { + return NewOperationPollerWithConfig(defaultOperationPollInterval, defaultOperationPollTimeout, sdApi) } -func (opPoller *operationPoller) Poll(ctx context.Context) (err error) { - if len(opPoller.opIds) == 0 { - opPoller.log.Info("no operations to poll") - return nil +// NewOperationPollerWithConfig creates a new operation poller +func NewOperationPollerWithConfig(pollInterval, pollTimeout time.Duration, sdApi ServiceDiscoveryApi) OperationPoller { + return &operationPoller{ + log: common.NewLogger("cloudmap", "OperationPoller"), + sdApi: sdApi, + opChan: make(chan opResult), + pollInterval: pollInterval, + pollTimeout: pollTimeout, } +} - err = wait.Poll(defaultOperationPollInterval, opPoller.timeout, func() (done bool, err error) { - opPoller.log.Info("polling operations", "operations", opPoller.opIds) +func (p *operationPoller) Submit(ctx context.Context, opProvider func() (opId string, err error)) { + p.waitGroup.Add(1) - sdOps, err := opPoller.sdApi.ListOperations(ctx, opPoller.buildFilters()) + // Poll for the operation in a separate go routine + go func() { + // Indicate the polling done i.e. decrement the WaitGroup counter when the goroutine returns + defer p.waitGroup.Done() - if err != nil { - return true, err + opId, err := opProvider() + // Poll for the operationId if the provider doesn't throw error + if err == nil { + _, err = p.Poll(ctx, opId) } - failedOps := make([]string, 0) + p.opChan <- opResult{opId: opId, err: err} + }() +} - for _, pollOp := range opPoller.opIds { - status, hasVal := sdOps[pollOp] - if !hasVal { - // polled operation not terminal - return false, nil - } +func (p *operationPoller) Poll(ctx context.Context, opId string) (op *types.Operation, err error) { + // poll tries a condition func until it returns true, an error, or the timeout is reached. + err = wait.Poll(p.pollInterval, p.pollTimeout, func() (done bool, err error) { + p.log.Info("polling operation", "opId", opId) - if status == types.OperationStatusFail { - failedOps = append(failedOps, pollOp) - } + op, err = p.sdApi.GetOperation(ctx, opId) + if err != nil { + return true, err } - if len(failedOps) != 0 { - for _, failedOp := range failedOps { - opPoller.log.Info("operation failed", "failedOp", failedOp, "reason", opPoller.getFailedOpReason(ctx, failedOp)) - } - return true, errors.New("operation failure") + switch op.Status { + case types.OperationStatusSuccess: + return true, nil + case types.OperationStatusFail: + return true, fmt.Errorf("operation failed, opId: %s, reason: %s", opId, aws.ToString(op.ErrorMessage)) + default: + return false, nil } - - opPoller.log.Info("operations completed successfully") - return true, nil }) - if err == wait.ErrWaitTimeout { - return errors.New(operationPollTimoutErrorMessage) - } - - return err -} - -func (opPoller *operationPoller) buildFilters() []types.OperationFilter { - svcFilter := types.OperationFilter{ - Name: types.OperationFilterNameServiceId, - Values: []string{opPoller.svcId}, + err = fmt.Errorf("%s, opId: %s", operationPollTimoutErrorMessage, opId) } - statusFilter := types.OperationFilter{ - Name: types.OperationFilterNameStatus, - Condition: types.FilterConditionIn, - Values: []string{ - string(types.OperationStatusFail), - string(types.OperationStatusSuccess)}, - } - typeFilter := types.OperationFilter{ - Name: types.OperationFilterNameType, - Values: []string{string(opPoller.opType)}, - } - - timeFilter := types.OperationFilter{ - Name: types.OperationFilterNameUpdateDate, - Condition: types.FilterConditionBetween, - Values: []string{ - Itoa(opPoller.start), - // Add one minute to end range in case op updates while list request is in flight - Itoa(Now() + 60000), - }, - } - - return []types.OperationFilter{svcFilter, statusFilter, typeFilter, timeFilter} + return op, err } -// getFailedOpReason returns operation error message, which is not available in ListOperations response -func (opPoller *operationPoller) getFailedOpReason(ctx context.Context, opId string) string { - op, err := opPoller.sdApi.GetOperation(ctx, opId) - - if err != nil { - return "failed to retrieve operation failure reason" +func (p *operationPoller) Await() (err error) { + // Run wait in separate go routine to unblock reading from the channel. + go func() { + // Block till the polling done i.e. WaitGroup counter is zero, and then close the channel + p.waitGroup.Wait() + close(p.opChan) + }() + + for res := range p.opChan { + if res.err != nil { + p.log.Error(res.err, "operation failed", "opId", res.opId) + err = common.Wrap(err, res.err) + } else { + p.log.Info("operations completed successfully", "opId", res.opId) + } } - return aws.ToString(op.ErrorMessage) -} -func Itoa(i int64) string { - return strconv.FormatInt(i, 10) -} - -// Now returns current time with milliseconds, as used by operation filter UPDATE_DATE field -func Now() int64 { - return time.Now().UnixNano() / 1000000 + return err } diff --git a/pkg/cloudmap/operation_poller_test.go b/pkg/cloudmap/operation_poller_test.go index 08bcc7c..87ec236 100644 --- a/pkg/cloudmap/operation_poller_test.go +++ b/pkg/cloudmap/operation_poller_test.go @@ -2,224 +2,165 @@ package cloudmap import ( "context" - "errors" - "strconv" + "fmt" "testing" "time" cloudmapMock "github.com/aws/aws-cloud-map-mcs-controller-for-k8s/mocks/pkg/cloudmap" - "github.com/aws/aws-cloud-map-mcs-controller-for-k8s/pkg/common" - "github.com/aws/aws-cloud-map-mcs-controller-for-k8s/test" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" - "github.com/go-logr/logr/testr" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) -func TestOperationPoller_HappyCases(t *testing.T) { +const ( + op1 = "one" + op2 = "two" + op3 = "three" + interval = 100 * time.Millisecond + timeout = 500 * time.Millisecond +) + +func TestOperationPoller_HappyCase(t *testing.T) { mockController := gomock.NewController(t) defer mockController.Finish() sdApi := cloudmapMock.NewMockServiceDiscoveryApi(mockController) - pollerTypes := []struct { - constructor func() OperationPoller - expectedOpType types.OperationType - }{ - { - constructor: func() OperationPoller { - return NewRegisterInstancePoller(sdApi, test.SvcId, []string{test.OpId1, test.OpId2}, test.OpStart) - }, - expectedOpType: types.OperationTypeRegisterInstance, - }, - { - constructor: func() OperationPoller { - return NewDeregisterInstancePoller(sdApi, test.SvcId, []string{test.OpId1, test.OpId2}, test.OpStart) - }, - expectedOpType: types.OperationTypeDeregisterInstance, - }, - } + op1First := sdApi.EXPECT().GetOperation(gomock.Any(), op1).Return(opSubmitted(), nil) + op1Second := sdApi.EXPECT().GetOperation(gomock.Any(), op1).Return(opPending(), nil) + op1Third := sdApi.EXPECT().GetOperation(gomock.Any(), op1).Return(opSuccess(), nil) + gomock.InOrder(op1First, op1Second, op1Third) - for _, pollerType := range pollerTypes { - p := pollerType.constructor() - - var firstEnd int - - sdApi.EXPECT(). - ListOperations(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, filters []types.OperationFilter) (map[string]types.OperationStatus, error) { - assert.Contains(t, filters, - types.OperationFilter{ - Name: types.OperationFilterNameServiceId, - Values: []string{test.SvcId}, - }) - assert.Contains(t, filters, - types.OperationFilter{ - Name: types.OperationFilterNameStatus, - Condition: types.FilterConditionIn, - - Values: []string{ - string(types.OperationStatusFail), - string(types.OperationStatusSuccess)}, - }) - assert.Contains(t, filters, - types.OperationFilter{ - Name: types.OperationFilterNameType, - Values: []string{string(pollerType.expectedOpType)}, - }) - - timeFilter := findUpdateDateFilter(t, filters) - assert.NotNil(t, timeFilter) - assert.Equal(t, types.FilterConditionBetween, timeFilter.Condition) - assert.Equal(t, 2, len(timeFilter.Values)) - - filterStart, _ := strconv.Atoi(timeFilter.Values[0]) - assert.Equal(t, test.OpStart, filterStart) - - firstEnd, _ = strconv.Atoi(timeFilter.Values[1]) - - return map[string]types.OperationStatus{}, nil - }) - - sdApi.EXPECT(). - ListOperations(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, filters []types.OperationFilter) (map[string]types.OperationStatus, error) { - timeFilter := findUpdateDateFilter(t, filters) - secondEnd, _ := strconv.Atoi(timeFilter.Values[1]) - assert.Greater(t, secondEnd, firstEnd, - "Filter time frame for operations must increase between invocations of ListOperations") - - return map[string]types.OperationStatus{ - test.OpId1: types.OperationStatusSuccess, - test.OpId2: types.OperationStatusSuccess, - }, nil - }) - - err := p.Poll(context.TODO()) - - assert.Nil(t, err) - } -} + op2First := sdApi.EXPECT().GetOperation(gomock.Any(), op2).Return(opPending(), nil) + op2Second := sdApi.EXPECT().GetOperation(gomock.Any(), op2).Return(opSuccess(), nil) + gomock.InOrder(op2First, op2Second) -func TestOperationPoller_PollEmpty(t *testing.T) { - mockController := gomock.NewController(t) - defer mockController.Finish() + sdApi.EXPECT().GetOperation(gomock.Any(), op3).Return(opSuccess(), nil) - sdApi := cloudmapMock.NewMockServiceDiscoveryApi(mockController) + op := NewOperationPollerWithConfig(interval, timeout, sdApi) + op.Submit(context.TODO(), func() (opId string, err error) { return op1, nil }) + op.Submit(context.TODO(), func() (opId string, err error) { return op2, nil }) + op.Submit(context.TODO(), func() (opId string, err error) { return op3, nil }) - p := NewRegisterInstancePoller(sdApi, test.SvcId, []string{}, test.OpStart) - err := p.Poll(context.TODO()) - assert.Nil(t, err) + result := op.Await() + assert.Nil(t, result) } -func TestOperationPoller_PollFailure(t *testing.T) { +func TestOperationPoller_AllFail(t *testing.T) { mockController := gomock.NewController(t) defer mockController.Finish() sdApi := cloudmapMock.NewMockServiceDiscoveryApi(mockController) - p := NewRegisterInstancePoller(sdApi, test.SvcId, []string{test.OpId1, test.OpId2}, test.OpStart) - - pollErr := errors.New("error polling operations") - - sdApi.EXPECT(). - ListOperations(gomock.Any(), gomock.Any()). - Return(map[string]types.OperationStatus{}, pollErr) - - err := p.Poll(context.TODO()) - assert.Equal(t, pollErr, err) + op1First := sdApi.EXPECT().GetOperation(gomock.Any(), op1).Return(opSubmitted(), nil) + op1Second := sdApi.EXPECT().GetOperation(gomock.Any(), op1).Return(opPending(), nil) + op1Third := sdApi.EXPECT().GetOperation(gomock.Any(), op1).Return(opFailed(), nil) + gomock.InOrder(op1First, op1Second, op1Third) + + op2First := sdApi.EXPECT().GetOperation(gomock.Any(), op2).Return(opSubmitted(), nil) + op2Second := sdApi.EXPECT().GetOperation(gomock.Any(), op2).Return(opFailed(), nil) + gomock.InOrder(op2First, op2Second) + + op := NewOperationPollerWithConfig(interval, timeout, sdApi) + op.Submit(context.TODO(), func() (opId string, err error) { return op1, nil }) + op.Submit(context.TODO(), func() (opId string, err error) { return op2, nil }) + unknown := "failed to reg error" + op.Submit(context.TODO(), func() (opId string, err error) { + return "", fmt.Errorf(unknown) + }) + + err := op.Await() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), op1) + assert.Contains(t, err.Error(), op2) + assert.Contains(t, err.Error(), unknown) } -func TestOperationPoller_PollOpFailure(t *testing.T) { +func TestOperationPoller_Mixed(t *testing.T) { mockController := gomock.NewController(t) defer mockController.Finish() sdApi := cloudmapMock.NewMockServiceDiscoveryApi(mockController) - p := NewRegisterInstancePoller(sdApi, test.SvcId, []string{test.OpId1, test.OpId2}, test.OpStart) - - sdApi.EXPECT(). - ListOperations(gomock.Any(), gomock.Any()). - Return( - map[string]types.OperationStatus{ - test.OpId1: types.OperationStatusSuccess, - test.OpId2: types.OperationStatusFail, - }, nil) + op1First := sdApi.EXPECT().GetOperation(gomock.Any(), op1).Return(opSubmitted(), nil) + op1Second := sdApi.EXPECT().GetOperation(gomock.Any(), op1).Return(opPending(), nil) + op1Third := sdApi.EXPECT().GetOperation(gomock.Any(), op1).Return(opFailed(), nil) + gomock.InOrder(op1First, op1Second, op1Third) - opErr := "operation failure message" + op2First := sdApi.EXPECT().GetOperation(gomock.Any(), op2).Return(opSubmitted(), nil) + op2Second := sdApi.EXPECT().GetOperation(gomock.Any(), op2).Return(opPending(), nil) + op2Third := sdApi.EXPECT().GetOperation(gomock.Any(), op2).Return(opSuccess(), nil) + gomock.InOrder(op2First, op2Second, op2Third) - sdApi.EXPECT(). - GetOperation(gomock.Any(), test.OpId2). - Return(&types.Operation{ErrorMessage: &opErr}, nil) + op := NewOperationPollerWithConfig(interval, timeout, sdApi) + op.Submit(context.TODO(), func() (opId string, err error) { return op1, nil }) + op.Submit(context.TODO(), func() (opId string, err error) { return op2, nil }) - err := p.Poll(context.TODO()) - assert.Equal(t, "operation failure", err.Error()) + err := op.Await() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), op1) + assert.NotContains(t, err.Error(), op2) } -func TestOperationPoller_PollOpFailureAndMessageFailure(t *testing.T) { +func TestOperationPoller_Timeout(t *testing.T) { mockController := gomock.NewController(t) defer mockController.Finish() sdApi := cloudmapMock.NewMockServiceDiscoveryApi(mockController) - p := NewRegisterInstancePoller(sdApi, test.SvcId, []string{test.OpId1, test.OpId2}, test.OpStart) + sdApi.EXPECT().GetOperation(gomock.Any(), op1).Return(opPending(), nil).AnyTimes() - sdApi.EXPECT(). - ListOperations(gomock.Any(), gomock.Any()). - Return( - map[string]types.OperationStatus{ - test.OpId1: types.OperationStatusFail, - test.OpId2: types.OperationStatusSuccess, - }, nil) + op2First := sdApi.EXPECT().GetOperation(gomock.Any(), op2).Return(opPending(), nil) + op2Second := sdApi.EXPECT().GetOperation(gomock.Any(), op2).Return(opSuccess(), nil) + gomock.InOrder(op2First, op2Second) - sdApi.EXPECT(). - GetOperation(gomock.Any(), test.OpId1). - Return(nil, errors.New("failed to retrieve operation failure reason")) + op := NewOperationPollerWithConfig(interval, timeout, sdApi) + op.Submit(context.TODO(), func() (opId string, err error) { return op1, nil }) + op.Submit(context.TODO(), func() (opId string, err error) { return op2, nil }) - err := p.Poll(context.TODO()) - assert.Equal(t, "operation failure", err.Error()) + err := op.Await() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), op1) + assert.Contains(t, err.Error(), operationPollTimoutErrorMessage) + assert.NotContains(t, err.Error(), op2) } -func TestOperationPoller_PollTimeout(t *testing.T) { +func TestOperationPoller_Poll_HappyCase(t *testing.T) { mockController := gomock.NewController(t) defer mockController.Finish() sdApi := cloudmapMock.NewMockServiceDiscoveryApi(mockController) - p := operationPoller{ - log: common.NewLoggerWithLogr(testr.New(t)), - sdApi: sdApi, - timeout: 2 * time.Millisecond, - opIds: []string{test.OpId1, test.OpId2}, - } - - sdApi.EXPECT(). - ListOperations(gomock.Any(), gomock.Any()). - Return( - map[string]types.OperationStatus{}, nil) + sdApi.EXPECT().GetOperation(context.TODO(), op1).Return(opPending(), nil) + sdApi.EXPECT().GetOperation(context.TODO(), op1).Return(opSuccess(), nil) - err := p.Poll(context.TODO()) - assert.Equal(t, operationPollTimoutErrorMessage, err.Error()) + op := NewOperationPollerWithConfig(interval, timeout, sdApi) + _, err := op.Poll(context.TODO(), op1) + assert.Nil(t, err) } -func TestItoa(t *testing.T) { - assert.Equal(t, "7", Itoa(7)) +func opPending() *types.Operation { + return &types.Operation{ + Status: types.OperationStatusPending, + } } -func TestNow(t *testing.T) { - now1 := Now() - time.Sleep(time.Millisecond * 5) - now2 := Now() - assert.Greater(t, now2, now1) +func opFailed() *types.Operation { + return &types.Operation{ + Status: types.OperationStatusFail, + ErrorMessage: aws.String("fail"), + } } -func findUpdateDateFilter(t *testing.T, filters []types.OperationFilter) *types.OperationFilter { - for _, filter := range filters { - if filter.Name == types.OperationFilterNameUpdateDate { - return &filter - } +func opSubmitted() *types.Operation { + return &types.Operation{ + Status: types.OperationStatusSubmitted, } +} - t.Errorf("Missing update date filter") - return nil +func opSuccess() *types.Operation { + return &types.Operation{ + Status: types.OperationStatusSuccess, + } } diff --git a/pkg/common/ratelimiter.go b/pkg/common/ratelimiter.go new file mode 100644 index 0000000..bae8346 --- /dev/null +++ b/pkg/common/ratelimiter.go @@ -0,0 +1,43 @@ +package common + +import ( + "context" + "fmt" + + "golang.org/x/time/rate" +) + +const ( + ListNamespaces Event = "ListNamespaces" + ListServices Event = "ListServices" + GetOperation Event = "GetOperation" + RegisterInstance Event = "RegisterInstance" + DeregisterInstance Event = "DeregisterInstance" +) + +type Event string + +type RateLimiter struct { + rateLimiters map[Event]*rate.Limiter +} + +// NewDefaultRateLimiter returns the rate limiters with the default limits for the AWS CloudMap's API calls +func NewDefaultRateLimiter() RateLimiter { + return RateLimiter{rateLimiters: map[Event]*rate.Limiter{ + // Below are the default limits for the AWS CloudMap's APIs + // TODO: make it customizable in the future + ListNamespaces: rate.NewLimiter(rate.Limit(1), 5), // 1 ListNamespaces API calls per second + ListServices: rate.NewLimiter(rate.Limit(2), 10), // 2 ListServices API calls per second + GetOperation: rate.NewLimiter(rate.Limit(100), 200), // 100 GetOperation API calls per second + RegisterInstance: rate.NewLimiter(rate.Limit(50), 100), // 50 RegisterInstance API calls per second + DeregisterInstance: rate.NewLimiter(rate.Limit(50), 100), // 50 DeregisterInstance API calls per second + }} +} + +// Wait blocks until limit permits an event to happen. It returns an error if the Context is canceled, or the expected wait time exceeds the Context's Deadline. +func (r RateLimiter) Wait(ctx context.Context, event Event) error { + if limiter, ok := r.rateLimiters[event]; ok { + return limiter.Wait(ctx) + } + return fmt.Errorf("event %s not found in the list of limiters", event) +} diff --git a/pkg/common/ratelimiter_test.go b/pkg/common/ratelimiter_test.go new file mode 100644 index 0000000..f1b6afb --- /dev/null +++ b/pkg/common/ratelimiter_test.go @@ -0,0 +1,64 @@ +package common + +import ( + "context" + "testing" +) + +func TestRateLimiter_Wait(t *testing.T) { + type fields struct { + RateLimiter RateLimiter + } + type args struct { + ctx context.Context + event Event + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "happy", + fields: fields{RateLimiter: NewDefaultRateLimiter()}, + args: args{ + ctx: context.TODO(), + event: ListServices, + }, + wantErr: false, + }, + { + name: "not_found", + fields: fields{RateLimiter: NewDefaultRateLimiter()}, + args: args{ + ctx: context.TODO(), + event: "test", + }, + wantErr: true, + }, + { + name: "error_ctx_canceled", + fields: fields{RateLimiter: NewDefaultRateLimiter()}, + args: args{ + ctx: ctxCanceled(context.TODO()), + event: ListNamespaces, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := tt.fields.RateLimiter + if err := r.Wait(tt.args.ctx, tt.args.event); (err != nil) != tt.wantErr { + t.Errorf("Wait() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func ctxCanceled(ctx context.Context) context.Context { + ret, cancel := context.WithCancel(ctx) + defer cancel() // cancel after function call + return ret +}