Skip to content

Commit

Permalink
lsp: Rule head completion improvements
Browse files Browse the repository at this point in the history
These new completion providers will complete keywords and common rule
names in a limited number of but common scenarios.

Signed-off-by: Charlie Egan <charlie@styra.com>
  • Loading branch information
charlieegan3 committed May 29, 2024
1 parent 51bd977 commit d70e6e6
Show file tree
Hide file tree
Showing 6 changed files with 455 additions and 2 deletions.
12 changes: 10 additions & 2 deletions internal/lsp/completions/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ func NewDefaultManager(c *cache.Cache) *Manager {
m.RegisterProvider(&providers.PackageRefs{})
m.RegisterProvider(&providers.RuleRefs{})
m.RegisterProvider(&providers.RuleHead{})
m.RegisterProvider(&providers.RuleHeadKeyword{})
m.RegisterProvider(&providers.CommonRule{})

return m
}
Expand All @@ -49,8 +51,14 @@ func (m *Manager) Run(params types.CompletionParams, opts *providers.Options) ([
return nil, fmt.Errorf("error running completion provider: %w", err)
}

if len(providerCompletions) > 0 {
completions = append(completions, providerCompletions...)
for _, completion := range providerCompletions {
// if a provider returns a mandatory completion, return it immediately
// as it is the only completion that should be shown.
if completion.Mandatory {
return []types.CompletionItem{completion}, nil
}

completions = append(completions, completion)
}
}

Expand Down
90 changes: 90 additions & 0 deletions internal/lsp/completions/providers/commonrule.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//nolint:dupl
package providers

import (
"fmt"
"strings"

"github.com/styrainc/regal/internal/lsp/cache"
"github.com/styrainc/regal/internal/lsp/types"
"github.com/styrainc/regal/internal/lsp/types/completion"
)

// CommonRule will return completions for new rules based on common Rego rule names.
type CommonRule struct{}

func (*CommonRule) Run(c *cache.Cache, params types.CompletionParams, _ *Options) ([]types.CompletionItem, error) {
fileURI := params.TextDocument.URI

_, currentLine := completionLineHelper(c, fileURI, params.Position.Line)

if patternRuleBody.MatchString(currentLine) { // if in rule body
return []types.CompletionItem{}, nil
}

words := patternWhiteSpace.Split(currentLine, -1)
if len(words) != 1 {
return []types.CompletionItem{}, nil
}

// if the file already contains a rule with the same name, we do not want to
// suggest it again. In order to be able to do this later, we need to record
// all the existing rules in the file.
existingRules := make(map[string]struct{})

for _, ref := range c.GetFileRefs(fileURI) {
if ref.Kind == types.Rule || ref.Kind == types.ConstantRule || ref.Kind == types.Function {
parts := strings.Split(ref.Label, ".")
existingRules[parts[len(parts)-1]] = struct{}{}
}
}

lastWord := strings.TrimSpace(currentLine)

var label string

var newText string

for _, word := range []string{"allow", "deny"} {
if strings.HasPrefix(word, lastWord) {
// if the rule is defined, we can skip it as it'll be suggested by
// another provider
if _, ok := existingRules[word]; ok {
return []types.CompletionItem{}, nil
}

label = word
newText = word + " "

break
}
}

if label == "" {
return []types.CompletionItem{}, nil
}

return []types.CompletionItem{
{
Label: label,
Kind: completion.Snippet,
Documentation: &types.MarkupContent{
Kind: "markdown",
Value: fmt.Sprintf("%q is a common rule name", label),
},
TextEdit: &types.TextEdit{
Range: types.Range{
Start: types.Position{
Line: params.Position.Line,
Character: params.Position.Character - uint(len(lastWord)),
},
End: types.Position{
Line: params.Position.Line,
Character: uint(len(currentLine)),
},
},
NewText: newText,
},
},
}, nil
}
132 changes: 132 additions & 0 deletions internal/lsp/completions/providers/commonrule_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
//nolint:dupl
package providers

import (
"testing"

"github.com/styrainc/regal/internal/lsp/cache"
"github.com/styrainc/regal/internal/lsp/completions/refs"
"github.com/styrainc/regal/internal/lsp/types"
"github.com/styrainc/regal/internal/parse"
)

func TestCommonRule_TypedA(t *testing.T) {
t.Parallel()

c := cache.NewCache()

fileContents := `package policy
a
`

c.SetFileContents(testCaseFileURI, fileContents)

p := &CommonRule{}

completionParams := types.CompletionParams{
TextDocument: types.TextDocumentIdentifier{
URI: testCaseFileURI,
},
Position: types.Position{
Line: 2,
Character: 1,
},
}

completions, err := p.Run(c, completionParams, nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if len(completions) != 1 {
t.Fatalf("Expected exactly one completion, got: %v", completions)
}

comp := completions[0]
if comp.Label != "allow" {
t.Fatalf("Expected label to be 'allow', got: %v", comp.Label)
}
}

func TestCommonRule_TypedD(t *testing.T) {
t.Parallel()

c := cache.NewCache()

fileContents := `package policy
d
`

c.SetFileContents(testCaseFileURI, fileContents)

p := &CommonRule{}

completionParams := types.CompletionParams{
TextDocument: types.TextDocumentIdentifier{
URI: testCaseFileURI,
},
Position: types.Position{
Line: 2,
Character: 1,
},
}

completions, err := p.Run(c, completionParams, nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if len(completions) != 1 {
t.Fatalf("Expected exactly one completion, got: %v", completions)
}

comp := completions[0]
if comp.Label != "deny" {
t.Fatalf("Expected label to be 'deny', got: %v", comp.Label)
}
}

func TestCommonRule_TypedDAlreadyDefined(t *testing.T) {
t.Parallel()

c := cache.NewCache()

fileContents := `package policy
deny := false
`

c.SetFileContents(testCaseFileURI, fileContents+"d")

mod, err := parse.Module(testCaseFileURI, fileContents)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

c.SetModule(testCaseFileURI, mod)
c.SetFileRefs(testCaseFileURI, refs.ForModule(mod))

p := &CommonRule{}

completionParams := types.CompletionParams{
TextDocument: types.TextDocumentIdentifier{
URI: testCaseFileURI,
},
Position: types.Position{
Line: 4,
Character: 1,
},
}

completions, err := p.Run(c, completionParams, nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if len(completions) != 0 {
t.Fatalf("Expected no completions, got: %v", completions)
}
}
80 changes: 80 additions & 0 deletions internal/lsp/completions/providers/ruleheadkeyword.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
//nolint:dupl
package providers

import (
"strings"

"github.com/styrainc/regal/internal/lsp/cache"
"github.com/styrainc/regal/internal/lsp/types"
"github.com/styrainc/regal/internal/lsp/types/completion"
)

// RuleHeadKeyword will return completions for the keywords when starting a new rule.
// The current cases are supported:
// - [rule-name] if
// - [rule-name] contains
// - [rule-name] contains if
// These completions are mandatory, that means they are the only ones to be shown.
type RuleHeadKeyword struct{}

func (*RuleHeadKeyword) Run(c *cache.Cache, params types.CompletionParams, _ *Options) ([]types.CompletionItem, error) {
fileURI := params.TextDocument.URI

_, currentLine := completionLineHelper(c, fileURI, params.Position.Line)

if patternRuleBody.MatchString(currentLine) { // if in rule body
return []types.CompletionItem{}, nil
}

words := patternWhiteSpace.Split(currentLine, -1)
if len(words) < 2 {
return []types.CompletionItem{}, nil
}

lastWord := words[len(words)-1]

const keyWdContains = "contains"

const keyWdIf = "if"

var label string

switch {
// suggest contains after the name of the rule in the rule head
//nolint:gocritic
case len(words) == 2 && strings.HasPrefix(keyWdContains, lastWord):
label = "contains"
// suggest if at the end of the rule head
case len(words) == 4 && words[1] == keyWdContains:
label = keyWdIf
// suggest if after the rule name
//nolint:gocritic
case len(words) == 2 && strings.HasPrefix(keyWdIf, lastWord):
label = keyWdIf
}

if label == "" {
return []types.CompletionItem{}, nil
}

return []types.CompletionItem{
{
Label: label,
Kind: completion.Keyword,
TextEdit: &types.TextEdit{
Range: types.Range{
Start: types.Position{
Line: params.Position.Line,
Character: params.Position.Character - uint(len(lastWord)),
},
End: types.Position{
Line: params.Position.Line,
Character: uint(len(currentLine)),
},
},
NewText: label + " ",
},
Mandatory: true,
},
}, nil
}
Loading

0 comments on commit d70e6e6

Please # to comment.