From 1e2cab4893445078e9f005628654ffe48c660ef2 Mon Sep 17 00:00:00 2001 From: Charlie Egan Date: Thu, 16 Jan 2025 17:15:53 +0000 Subject: [PATCH] fix: add error based on expected regal dir Fixes https://github.com/StyraInc/regal/issues/1341 Signed-off-by: Charlie Egan --- cmd/fix.go | 15 +++++---- pkg/config/config.go | 4 +++ pkg/config/config_test.go | 70 ++++++++++++++++++++++++++------------- pkg/fixer/fixer.go | 7 ++-- 4 files changed, 64 insertions(+), 32 deletions(-) diff --git a/cmd/fix.go b/cmd/fix.go index 457fcf47..d08f43ac 100644 --- a/cmd/fix.go +++ b/cmd/fix.go @@ -270,15 +270,18 @@ func fix(args []string, params *fixCommandParams) error { return fmt.Errorf("could not find potential roots: %w", err) } - versionsMap, err := config.AllRegoVersions(regalDir.Name(), &userConfig) - if err != nil { - return fmt.Errorf("failed to get all Rego versions: %w", err) - } - f := fixer.NewFixer() f.RegisterRoots(roots...) f.RegisterFixes(fixes.NewDefaultFixes()...) - f.SetRegoVersionsMap(versionsMap) + + if userConfigFile != nil { + versionsMap, err := config.AllRegoVersions(filepath.Dir(userConfigFile.Name()), &userConfig) + if err != nil { + return fmt.Errorf("failed to get all Rego versions: %w", err) + } + + f.SetRegoVersionsMap(versionsMap) + } if !slices.Contains([]string{"error", "rename"}, params.conflictMode) { return fmt.Errorf("invalid conflict mode: %s, expected 'error' or 'rename'", params.conflictMode) diff --git a/pkg/config/config.go b/pkg/config/config.go index 50eef99c..3be6ae3d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -355,6 +355,10 @@ func FromMap(confMap map[string]any) (Config, error) { func AllRegoVersions(root string, conf *Config) (map[string]ast.RegoVersion, error) { versionsMap := make(map[string]ast.RegoVersion) + if conf == nil { + return versionsMap, nil + } + manifestLocations, err := rio.FindManifestLocations(root) if err != nil { return nil, fmt.Errorf("failed to find manifest locations: %w", err) diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 0596d68e..cb133c5a 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -542,37 +542,61 @@ func TestUnmarshalProjectRootsAsStringOrObject(t *testing.T) { func TestAllRegoVersions(t *testing.T) { t.Parallel() - bs := []byte(`project: + testCases := map[string]struct { + Config string + FS map[string]string + Expected map[string]ast.RegoVersion + }{ + "values from config": { + Config: `project: rego-version: 0 roots: - path: foo rego-version: 1 -`) +`, + FS: map[string]string{ + "bar/baz/.manifest": `{"rego_version": 1}`, + }, + Expected: map[string]ast.RegoVersion{ + "": ast.RegoV0, + "bar/baz": ast.RegoV1, + "foo": ast.RegoV1, + }, + }, + "no config": { + Config: "", + FS: map[string]string{ + "bar/baz/.manifest": `{"rego_version": 1}`, + }, + Expected: map[string]ast.RegoVersion{}, + }, + } - var conf Config + for testName, testData := range testCases { + t.Run(testName, func(t *testing.T) { + t.Parallel() - if err := yaml.Unmarshal(bs, &conf); err != nil { - t.Fatal(err) - } + var conf *Config - fs := map[string]string{ - "bar/baz/.manifest": `{"rego_version": 1}`, - } + if testData.Config != "" { + var loadedConf Config + if err := yaml.Unmarshal([]byte(testData.Config), &loadedConf); err != nil { + t.Fatal(err) + } - test.WithTempFS(fs, func(root string) { - versions, err := AllRegoVersions(root, &conf) - if err != nil { - t.Fatal(err) - } + conf = &loadedConf + } - expected := map[string]ast.RegoVersion{ - "": ast.RegoV0, - "foo": ast.RegoV1, - "bar/baz": ast.RegoV1, - } + test.WithTempFS(testData.FS, func(root string) { + versions, err := AllRegoVersions(root, conf) + if err != nil { + t.Fatal(err) + } - if !maps.Equal(versions, expected) { - t.Errorf("expected %v, got %v", expected, versions) - } - }) + if !maps.Equal(versions, testData.Expected) { + t.Errorf("expected %v, got %v", testData.Expected, versions) + } + }) + }) + } } diff --git a/pkg/fixer/fixer.go b/pkg/fixer/fixer.go index b47e9a63..6568b1f7 100644 --- a/pkg/fixer/fixer.go +++ b/pkg/fixer/fixer.go @@ -295,14 +295,15 @@ func (f *Fixer) applyLinterFixes( return fmt.Errorf("failed to list files: %w", err) } - if f.versionsMap == nil { - return errors.New("rego versions map not set") + var versionsMap map[string]ast.RegoVersion + if f.versionsMap != nil { + versionsMap = f.versionsMap } for { fixMadeInIteration := false - in, err := fp.ToInput(f.versionsMap) + in, err := fp.ToInput(versionsMap) if err != nil { return fmt.Errorf("failed to generate linter input: %w", err) }