From d70e6e6726c7b47d2c5fa249d9194b1aadc23e57 Mon Sep 17 00:00:00 2001 From: Charlie Egan Date: Wed, 29 May 2024 17:29:14 +0100 Subject: [PATCH] lsp: Rule head completion improvements These new completion providers will complete keywords and common rule names in a limited number of but common scenarios. Signed-off-by: Charlie Egan --- internal/lsp/completions/manager.go | 12 +- .../lsp/completions/providers/commonrule.go | 90 ++++++++++++ .../completions/providers/commonrule_test.go | 132 +++++++++++++++++ .../completions/providers/ruleheadkeyword.go | 80 ++++++++++ .../providers/ruleheadkeyword_test.go | 138 ++++++++++++++++++ internal/lsp/types/types.go | 5 + 6 files changed, 455 insertions(+), 2 deletions(-) create mode 100644 internal/lsp/completions/providers/commonrule.go create mode 100644 internal/lsp/completions/providers/commonrule_test.go create mode 100644 internal/lsp/completions/providers/ruleheadkeyword.go create mode 100644 internal/lsp/completions/providers/ruleheadkeyword_test.go diff --git a/internal/lsp/completions/manager.go b/internal/lsp/completions/manager.go index 12dec8009..2f9d71ab3 100644 --- a/internal/lsp/completions/manager.go +++ b/internal/lsp/completions/manager.go @@ -36,6 +36,8 @@ func NewDefaultManager(c *cache.Cache) *Manager { m.RegisterProvider(&providers.PackageRefs{}) m.RegisterProvider(&providers.RuleRefs{}) m.RegisterProvider(&providers.RuleHead{}) + m.RegisterProvider(&providers.RuleHeadKeyword{}) + m.RegisterProvider(&providers.CommonRule{}) return m } @@ -49,8 +51,14 @@ func (m *Manager) Run(params types.CompletionParams, opts *providers.Options) ([ return nil, fmt.Errorf("error running completion provider: %w", err) } - if len(providerCompletions) > 0 { - completions = append(completions, providerCompletions...) + for _, completion := range providerCompletions { + // if a provider returns a mandatory completion, return it immediately + // as it is the only completion that should be shown. + if completion.Mandatory { + return []types.CompletionItem{completion}, nil + } + + completions = append(completions, completion) } } diff --git a/internal/lsp/completions/providers/commonrule.go b/internal/lsp/completions/providers/commonrule.go new file mode 100644 index 000000000..da83063a2 --- /dev/null +++ b/internal/lsp/completions/providers/commonrule.go @@ -0,0 +1,90 @@ +//nolint:dupl +package providers + +import ( + "fmt" + "strings" + + "github.com/styrainc/regal/internal/lsp/cache" + "github.com/styrainc/regal/internal/lsp/types" + "github.com/styrainc/regal/internal/lsp/types/completion" +) + +// CommonRule will return completions for new rules based on common Rego rule names. +type CommonRule struct{} + +func (*CommonRule) Run(c *cache.Cache, params types.CompletionParams, _ *Options) ([]types.CompletionItem, error) { + fileURI := params.TextDocument.URI + + _, currentLine := completionLineHelper(c, fileURI, params.Position.Line) + + if patternRuleBody.MatchString(currentLine) { // if in rule body + return []types.CompletionItem{}, nil + } + + words := patternWhiteSpace.Split(currentLine, -1) + if len(words) != 1 { + return []types.CompletionItem{}, nil + } + + // if the file already contains a rule with the same name, we do not want to + // suggest it again. In order to be able to do this later, we need to record + // all the existing rules in the file. + existingRules := make(map[string]struct{}) + + for _, ref := range c.GetFileRefs(fileURI) { + if ref.Kind == types.Rule || ref.Kind == types.ConstantRule || ref.Kind == types.Function { + parts := strings.Split(ref.Label, ".") + existingRules[parts[len(parts)-1]] = struct{}{} + } + } + + lastWord := strings.TrimSpace(currentLine) + + var label string + + var newText string + + for _, word := range []string{"allow", "deny"} { + if strings.HasPrefix(word, lastWord) { + // if the rule is defined, we can skip it as it'll be suggested by + // another provider + if _, ok := existingRules[word]; ok { + return []types.CompletionItem{}, nil + } + + label = word + newText = word + " " + + break + } + } + + if label == "" { + return []types.CompletionItem{}, nil + } + + return []types.CompletionItem{ + { + Label: label, + Kind: completion.Snippet, + Documentation: &types.MarkupContent{ + Kind: "markdown", + Value: fmt.Sprintf("%q is a common rule name", label), + }, + TextEdit: &types.TextEdit{ + Range: types.Range{ + Start: types.Position{ + Line: params.Position.Line, + Character: params.Position.Character - uint(len(lastWord)), + }, + End: types.Position{ + Line: params.Position.Line, + Character: uint(len(currentLine)), + }, + }, + NewText: newText, + }, + }, + }, nil +} diff --git a/internal/lsp/completions/providers/commonrule_test.go b/internal/lsp/completions/providers/commonrule_test.go new file mode 100644 index 000000000..5caa5e0ef --- /dev/null +++ b/internal/lsp/completions/providers/commonrule_test.go @@ -0,0 +1,132 @@ +//nolint:dupl +package providers + +import ( + "testing" + + "github.com/styrainc/regal/internal/lsp/cache" + "github.com/styrainc/regal/internal/lsp/completions/refs" + "github.com/styrainc/regal/internal/lsp/types" + "github.com/styrainc/regal/internal/parse" +) + +func TestCommonRule_TypedA(t *testing.T) { + t.Parallel() + + c := cache.NewCache() + + fileContents := `package policy + +a +` + + c.SetFileContents(testCaseFileURI, fileContents) + + p := &CommonRule{} + + completionParams := types.CompletionParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: testCaseFileURI, + }, + Position: types.Position{ + Line: 2, + Character: 1, + }, + } + + completions, err := p.Run(c, completionParams, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(completions) != 1 { + t.Fatalf("Expected exactly one completion, got: %v", completions) + } + + comp := completions[0] + if comp.Label != "allow" { + t.Fatalf("Expected label to be 'allow', got: %v", comp.Label) + } +} + +func TestCommonRule_TypedD(t *testing.T) { + t.Parallel() + + c := cache.NewCache() + + fileContents := `package policy + +d +` + + c.SetFileContents(testCaseFileURI, fileContents) + + p := &CommonRule{} + + completionParams := types.CompletionParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: testCaseFileURI, + }, + Position: types.Position{ + Line: 2, + Character: 1, + }, + } + + completions, err := p.Run(c, completionParams, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(completions) != 1 { + t.Fatalf("Expected exactly one completion, got: %v", completions) + } + + comp := completions[0] + if comp.Label != "deny" { + t.Fatalf("Expected label to be 'deny', got: %v", comp.Label) + } +} + +func TestCommonRule_TypedDAlreadyDefined(t *testing.T) { + t.Parallel() + + c := cache.NewCache() + + fileContents := `package policy + +deny := false + +` + + c.SetFileContents(testCaseFileURI, fileContents+"d") + + mod, err := parse.Module(testCaseFileURI, fileContents) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + c.SetModule(testCaseFileURI, mod) + c.SetFileRefs(testCaseFileURI, refs.ForModule(mod)) + + p := &CommonRule{} + + completionParams := types.CompletionParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: testCaseFileURI, + }, + Position: types.Position{ + Line: 4, + Character: 1, + }, + } + + completions, err := p.Run(c, completionParams, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(completions) != 0 { + t.Fatalf("Expected no completions, got: %v", completions) + } +} diff --git a/internal/lsp/completions/providers/ruleheadkeyword.go b/internal/lsp/completions/providers/ruleheadkeyword.go new file mode 100644 index 000000000..2c371ad3d --- /dev/null +++ b/internal/lsp/completions/providers/ruleheadkeyword.go @@ -0,0 +1,80 @@ +//nolint:dupl +package providers + +import ( + "strings" + + "github.com/styrainc/regal/internal/lsp/cache" + "github.com/styrainc/regal/internal/lsp/types" + "github.com/styrainc/regal/internal/lsp/types/completion" +) + +// RuleHeadKeyword will return completions for the keywords when starting a new rule. +// The current cases are supported: +// - [rule-name] if +// - [rule-name] contains +// - [rule-name] contains if +// These completions are mandatory, that means they are the only ones to be shown. +type RuleHeadKeyword struct{} + +func (*RuleHeadKeyword) Run(c *cache.Cache, params types.CompletionParams, _ *Options) ([]types.CompletionItem, error) { + fileURI := params.TextDocument.URI + + _, currentLine := completionLineHelper(c, fileURI, params.Position.Line) + + if patternRuleBody.MatchString(currentLine) { // if in rule body + return []types.CompletionItem{}, nil + } + + words := patternWhiteSpace.Split(currentLine, -1) + if len(words) < 2 { + return []types.CompletionItem{}, nil + } + + lastWord := words[len(words)-1] + + const keyWdContains = "contains" + + const keyWdIf = "if" + + var label string + + switch { + // suggest contains after the name of the rule in the rule head + //nolint:gocritic + case len(words) == 2 && strings.HasPrefix(keyWdContains, lastWord): + label = "contains" + // suggest if at the end of the rule head + case len(words) == 4 && words[1] == keyWdContains: + label = keyWdIf + // suggest if after the rule name + //nolint:gocritic + case len(words) == 2 && strings.HasPrefix(keyWdIf, lastWord): + label = keyWdIf + } + + if label == "" { + return []types.CompletionItem{}, nil + } + + return []types.CompletionItem{ + { + Label: label, + Kind: completion.Keyword, + TextEdit: &types.TextEdit{ + Range: types.Range{ + Start: types.Position{ + Line: params.Position.Line, + Character: params.Position.Character - uint(len(lastWord)), + }, + End: types.Position{ + Line: params.Position.Line, + Character: uint(len(currentLine)), + }, + }, + NewText: label + " ", + }, + Mandatory: true, + }, + }, nil +} diff --git a/internal/lsp/completions/providers/ruleheadkeyword_test.go b/internal/lsp/completions/providers/ruleheadkeyword_test.go new file mode 100644 index 000000000..8c879f7c8 --- /dev/null +++ b/internal/lsp/completions/providers/ruleheadkeyword_test.go @@ -0,0 +1,138 @@ +//nolint:dupl +package providers + +import ( + "testing" + + "github.com/styrainc/regal/internal/lsp/cache" + "github.com/styrainc/regal/internal/lsp/types" +) + +func TestRuleHeadKeyword_TypedIAfterRuleName(t *testing.T) { + t.Parallel() + + c := cache.NewCache() + + fileContents := `package policy + +deny i +` + + c.SetFileContents(testCaseFileURI, fileContents) + + p := &RuleHeadKeyword{} + + completionParams := types.CompletionParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: testCaseFileURI, + }, + Position: types.Position{ + Line: 2, + Character: 5, + }, + } + + completions, err := p.Run(c, completionParams, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(completions) != 1 { + t.Fatalf("Expected exactly one completion, got: %v", completions) + } + + comp := completions[0] + if comp.Label != "if" { + t.Fatalf("Expected label to be 'if', got: %v", comp.Label) + } + + if comp.Mandatory != true { + t.Fatalf("Expected mandatory to be true, got: %v", comp.Mandatory) + } +} + +func TestRuleHeadKeyword_TypedC(t *testing.T) { + t.Parallel() + + c := cache.NewCache() + + fileContents := `package policy + +deny c +` + + c.SetFileContents(testCaseFileURI, fileContents) + + p := &RuleHeadKeyword{} + + completionParams := types.CompletionParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: testCaseFileURI, + }, + Position: types.Position{ + Line: 2, + Character: 5, + }, + } + + completions, err := p.Run(c, completionParams, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(completions) != 1 { + t.Fatalf("Expected exactly one completion, got: %v", completions) + } + + comp := completions[0] + if comp.Label != "contains" { + t.Fatalf("Expected label to be 'contains', got: %v", comp.Label) + } + + if comp.Mandatory != true { + t.Fatalf("Expected mandatory to be true, got: %v", comp.Mandatory) + } +} + +func TestRuleHeadKeyword_TypedI(t *testing.T) { + t.Parallel() + + c := cache.NewCache() + + fileContents := `package policy + +deny contains message i +` + + c.SetFileContents(testCaseFileURI, fileContents) + + p := &RuleHeadKeyword{} + + completionParams := types.CompletionParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: testCaseFileURI, + }, + Position: types.Position{ + Line: 2, + Character: 23, + }, + } + + completions, err := p.Run(c, completionParams, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(completions) != 1 { + t.Fatalf("Expected exactly one completion, got: %v", completions) + } + + comp := completions[0] + if comp.Label != "if" { + t.Fatalf("Expected label to be 'if' got: %v", comp.Label) + } + + if comp.Mandatory != true { + t.Fatalf("Expected mandatory to be true, got: %v", comp.Mandatory) + } +} diff --git a/internal/lsp/types/types.go b/internal/lsp/types/types.go index 57b9d2cbf..975df15f7 100644 --- a/internal/lsp/types/types.go +++ b/internal/lsp/types/types.go @@ -136,6 +136,11 @@ type CompletionItem struct { Documentation *MarkupContent `json:"documentation,omitempty"` Preselect bool `json:"preselect"` TextEdit *TextEdit `json:"textEdit,omitempty"` + + // Mandatory is used to indicate that the completion item is mandatory and should be offered + // as an exclusive completion. This is not part of the LSP spec, but used in regal providers + // to indicate that the completion item is the only valid completion. + Mandatory bool `json:"-"` } type CompletionItemLabelDetails struct {