From 87f7ce7e0f22afa26bcf3d9aedac1b8afb68f97c Mon Sep 17 00:00:00 2001 From: Leonardo Di Giovanna Date: Wed, 25 Dec 2024 17:05:23 +0100 Subject: [PATCH] fix(decl/cmd/test): wait for goroutines termination upon failures Signed-off-by: Leonardo Di Giovanna --- cmd/declarative/test/test.go | 47 ++++++++++++++++++++++---------- pkg/test/tester/tester/tester.go | 16 +++-------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/cmd/declarative/test/test.go b/cmd/declarative/test/test.go index 9937471c..dc0a62b4 100644 --- a/cmd/declarative/test/test.go +++ b/cmd/declarative/test/test.go @@ -171,7 +171,7 @@ func (cw *CommandWrapper) run(cmd *cobra.Command, _ []string) { ctx, cancel := context.WithTimeout(ctx, cw.TestsTimeout) defer cancel() - exitAndCancel := func() { + cancelAndExit := func() { cancel() os.Exit(1) } @@ -179,7 +179,7 @@ func (cw *CommandWrapper) run(cmd *cobra.Command, _ []string) { labels, err := label.ParseSet(cw.Labels) if err != nil { logger.Error(err, "Error parsing labels") - exitAndCancel() + cancelAndExit() } logger = enrichLoggerWithLabels(logger, labels) @@ -187,13 +187,13 @@ func (cw *CommandWrapper) run(cmd *cobra.Command, _ []string) { description, err := loadTestsDescription(logger, cw.TestsDescriptionFile, cw.TestsDescription) if err != nil { logger.Error(err, "Error loading tests description") - exitAndCancel() + cancelAndExit() } runnerBuilder, err := cw.createRunnerBuilder() if err != nil { logger.Error(err, "Error creating runner builder") - exitAndCancel() + cancelAndExit() } // Retrieve the already populated test ID. The test ID absence is used to uniquely identify the root process in the @@ -201,22 +201,35 @@ func (cw *CommandWrapper) run(cmd *cobra.Command, _ []string) { testID := cw.TestID isRootProcess := testID == "" + // globalWaitGroup accounts for all spawned goroutines. + globalWaitGroup := sync.WaitGroup{} + waitAndExit := func() { + cancel() + globalWaitGroup.Wait() + os.Exit(1) + } + // Initialize tester and Falco alerts collection. var testr tester.Tester - testerWaitGroup := sync.WaitGroup{} if isRootProcess && !cw.skipOutcomeVerification { if testr, err = cw.createTester(logger); err != nil { logger.Error(err, "Error creating tester") - exitAndCancel() + cancelAndExit() } + + globalWaitGroup.Add(1) go func() { + defer globalWaitGroup.Done() + defer cancel() if err := testr.StartAlertsCollection(ctx); err != nil { logger.Error(err, "Error starting tester execution") - exitAndCancel() } }() } + // testerWaitGroup accounts for all goroutines producing reports. + testerWaitGroup := sync.WaitGroup{} + // Prepare parameters shared by all runners. runnerEnviron := cw.buildRunnerEnviron(cmd) var runnerLabels *label.Set @@ -248,7 +261,7 @@ func (cw *CommandWrapper) run(cmd *cobra.Command, _ []string) { testUID, err = extractTestUID(testID) if err != nil { logger.Error(err, "Error extracting test UID from test ID", "testId", testID) - exitAndCancel() + waitAndExit() } } @@ -267,23 +280,25 @@ func (cw *CommandWrapper) run(cmd *cobra.Command, _ []string) { runnerInstance, err := runnerBuilder.Build(testDesc.Runner, runnerLogger, runnerDescription) if err != nil { logger.Error(err, "Error creating runner") - exitAndCancel() + waitAndExit() } logger.Info("Starting test execution...") if err := runnerInstance.Run(ctx, testID, testIndex, testDesc); err != nil { logRunnerError(logger, err) - exitAndCancel() + waitAndExit() } logger.Info("Test execution completed") if testr != nil { - produceReport(&testerWaitGroup, testr, &testUID, testDesc, cw.reportFormat) + produceReport(&globalWaitGroup, &testerWaitGroup, testr, &testUID, testDesc, cw.reportFormat) } } testerWaitGroup.Wait() + cancel() + globalWaitGroup.Wait() } // enrichLoggerWithLabels creates a new logger, starting from the provided one, with the information extracted from the @@ -462,11 +477,13 @@ func logRunnerError(logger logr.Logger, err error) { } // produceReport produces a report for the given test by using the provided tester. -func produceReport(wg *sync.WaitGroup, testr tester.Tester, testUID *uuid.UUID, testDesc *loader.Test, - reportFmt reportFormat) { - wg.Add(1) +func produceReport(globalWaitGroup, testerWaitGroup *sync.WaitGroup, testr tester.Tester, testUID *uuid.UUID, + testDesc *loader.Test, reportFmt reportFormat) { + globalWaitGroup.Add(1) + testerWaitGroup.Add(1) go func() { - defer wg.Done() + defer globalWaitGroup.Done() + defer testerWaitGroup.Done() testName, ruleName := testDesc.Name, testDesc.Rule report := getReport(testr, testUID, ruleName, &testDesc.ExpectedOutcome) report.TestName, report.RuleName = testName, ruleName diff --git a/pkg/test/tester/tester/tester.go b/pkg/test/tester/tester/tester.go index f8feb5a2..8ae05074 100644 --- a/pkg/test/tester/tester/tester.go +++ b/pkg/test/tester/tester/tester.go @@ -65,7 +65,7 @@ func (t *testerImpl) StartAlertsCollection(ctx context.Context) error { } alertInfoCh := t.filterAlertsWithUID(ctx, alertCh) - t.startAlertsCaching(ctx, alertInfoCh) + t.startAlertsCaching(alertInfoCh) return nil } @@ -142,17 +142,9 @@ func (t *testerImpl) findUID(alrt *alert.Alert) *uuid.UUID { } // startAlertsCaching starts caching the alerts received through the provided channel. -func (t *testerImpl) startAlertsCaching(ctx context.Context, alertInfoCh <-chan *alertInfo) { - for { - select { - case <-ctx.Done(): - return - case info, ok := <-alertInfoCh: - if !ok { - return - } - t.cacheAlert(info.uid, info.alert) - } +func (t *testerImpl) startAlertsCaching(alertInfoCh <-chan *alertInfo) { + for info := range alertInfoCh { + t.cacheAlert(info.uid, info.alert) } }