Skip to content

Commit cb74ccb

Browse files
denisvmediachavacava
and
chavacava
authored
chore: Improve sortables detection (#1151)
Co-authored-by: chavacava <salvador.cavadini@gmail.com>
1 parent 72b91f0 commit cb74ccb

File tree

3 files changed

+120
-30
lines changed

3 files changed

+120
-30
lines changed

internal/astutils/ast_utils.go

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package astutils
2+
3+
import "go/ast"
4+
5+
// FuncSignatureIs returns true if the given func decl satisfies a signature characterized
6+
// by the given name, parameters types and return types; false otherwise.
7+
//
8+
// Example: to check if a function declaration has the signature Foo(int, string) (bool,error)
9+
// call to FuncSignatureIs(funcDecl,"Foo",[]string{"int","string"},[]string{"bool","error"})
10+
func FuncSignatureIs(funcDecl *ast.FuncDecl, wantName string, wantParametersTypes, wantResultsTypes []string) bool {
11+
if wantName != funcDecl.Name.String() {
12+
return false // func name doesn't match expected one
13+
}
14+
15+
funcParametersTypes := getTypeNames(funcDecl.Type.Params)
16+
if len(wantParametersTypes) != len(funcParametersTypes) {
17+
return false // func has not the expected number of parameters
18+
}
19+
20+
funcResultsTypes := getTypeNames(funcDecl.Type.Results)
21+
if len(wantResultsTypes) != len(funcResultsTypes) {
22+
return false // func has not the expected number of return values
23+
}
24+
25+
for i, wantType := range wantParametersTypes {
26+
if wantType != funcParametersTypes[i] {
27+
return false // type of a func's parameter does not match the type of the corresponding expected parameter
28+
}
29+
}
30+
31+
for i, wantType := range wantResultsTypes {
32+
if wantType != funcResultsTypes[i] {
33+
return false // type of a func's return value does not match the type of the corresponding expected return value
34+
}
35+
}
36+
37+
return true
38+
}
39+
40+
func getTypeNames(fields *ast.FieldList) []string {
41+
result := []string{}
42+
43+
if fields == nil {
44+
return result
45+
}
46+
47+
for _, field := range fields.List {
48+
typeName := field.Type.(*ast.Ident).Name
49+
if field.Names == nil { // unnamed field
50+
result = append(result, typeName)
51+
continue
52+
}
53+
54+
for range field.Names { // add one type name for each field name
55+
result = append(result, typeName)
56+
}
57+
}
58+
59+
return result
60+
}

lint/package.go

+42-30
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package lint
22

33
import (
4+
"errors"
45
"go/ast"
56
"go/importer"
67
"go/token"
@@ -9,6 +10,7 @@ import (
910

1011
goversion "github.com/hashicorp/go-version"
1112

13+
"github.com/mgechev/revive/internal/astutils"
1214
"github.com/mgechev/revive/internal/typeparams"
1315
)
1416

@@ -31,7 +33,6 @@ type Package struct {
3133
var (
3234
trueValue = 1
3335
falseValue = 2
34-
notSet = 3
3536

3637
go121 = goversion.Must(goversion.NewVersion("1.21"))
3738
go122 = goversion.Must(goversion.NewVersion("1.22"))
@@ -111,6 +112,11 @@ func (p *Package) TypeCheck() error {
111112
astFiles = append(astFiles, f.AST)
112113
}
113114

115+
if anyFile == nil {
116+
// this is unlikely to happen, but technically guarantees anyFile to not be nil
117+
return errors.New("no ast.File found")
118+
}
119+
114120
typesPkg, err := check(config, anyFile.AST.Name.Name, p.fset, astFiles, info)
115121

116122
// Remember the typechecking info, even if config.Check failed,
@@ -135,47 +141,40 @@ func check(config *types.Config, n string, fset *token.FileSet, astFiles []*ast.
135141
return config.Check(n, fset, astFiles, info)
136142
}
137143

138-
// TypeOf returns the type of an expression.
144+
// TypeOf returns the type of expression.
139145
func (p *Package) TypeOf(expr ast.Expr) types.Type {
140146
if p.typesInfo == nil {
141147
return nil
142148
}
143149
return p.typesInfo.TypeOf(expr)
144150
}
145151

146-
type walker struct {
147-
nmap map[string]int
148-
has map[string]int
149-
}
152+
type sortableMethodsFlags int
150153

151-
func (w *walker) Visit(n ast.Node) ast.Visitor {
152-
fn, ok := n.(*ast.FuncDecl)
153-
if !ok || fn.Recv == nil || len(fn.Recv.List) == 0 {
154-
return w
155-
}
156-
// TODO(dsymonds): We could check the signature to be more precise.
157-
recv := typeparams.ReceiverType(fn)
158-
if i, ok := w.nmap[fn.Name.Name]; ok {
159-
w.has[recv] |= i
160-
}
161-
return w
162-
}
154+
// flags for sortable interface methods.
155+
const (
156+
bfLen sortableMethodsFlags = 1 << iota
157+
bfLess
158+
bfSwap
159+
)
163160

164161
func (p *Package) scanSortable() {
165-
p.sortable = map[string]bool{}
166-
167-
// bitfield for which methods exist on each type.
168-
const (
169-
bfLen = 1 << iota
170-
bfLess
171-
bfSwap
172-
)
173-
nmap := map[string]int{"Len": bfLen, "Less": bfLess, "Swap": bfSwap}
174-
has := map[string]int{}
162+
sortableFlags := map[string]sortableMethodsFlags{}
175163
for _, f := range p.files {
176-
ast.Walk(&walker{nmap, has}, f.AST)
164+
for _, decl := range f.AST.Decls {
165+
fn, ok := decl.(*ast.FuncDecl)
166+
isAMethodDeclaration := ok && fn.Recv != nil && len(fn.Recv.List) != 0
167+
if !isAMethodDeclaration {
168+
continue
169+
}
170+
171+
recvType := typeparams.ReceiverType(fn)
172+
sortableFlags[recvType] |= getSortableMethodFlagForFunction(fn)
173+
}
177174
}
178-
for typ, ms := range has {
175+
176+
p.sortable = make(map[string]bool, len(sortableFlags))
177+
for typ, ms := range sortableFlags {
179178
if ms == bfLen|bfLess|bfSwap {
180179
p.sortable[typ] = true
181180
}
@@ -204,3 +203,16 @@ func (p *Package) IsAtLeastGo121() bool {
204203
func (p *Package) IsAtLeastGo122() bool {
205204
return p.goVersion.GreaterThanOrEqual(go122)
206205
}
206+
207+
func getSortableMethodFlagForFunction(fn *ast.FuncDecl) sortableMethodsFlags {
208+
switch {
209+
case astutils.FuncSignatureIs(fn, "Len", []string{}, []string{"int"}):
210+
return bfLen
211+
case astutils.FuncSignatureIs(fn, "Less", []string{"int", "int"}, []string{"bool"}):
212+
return bfLess
213+
case astutils.FuncSignatureIs(fn, "Swap", []string{"int", "int"}, []string{}):
214+
return bfSwap
215+
default:
216+
return 0
217+
}
218+
}

testdata/golint/sort.go

+18
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,21 @@ func (u U) Less(i, j int) bool { return u[i] < u[j] }
1818
func (u U) Swap(i, j int) { u[i], u[j] = u[j], u[i] }
1919

2020
func (u U) Other() {} // MATCH /exported method U.Other should have comment or be unexported/
21+
22+
// V is ...
23+
type V []int
24+
25+
func (v V) Len() (result int) { return len(w) }
26+
func (v V) Less(i int, j int) (result bool) { return w[i] < w[j] }
27+
func (v V) Swap(i int, j int) { v[i], v[j] = v[j], v[i] }
28+
29+
// W is ...
30+
type W []int
31+
32+
func (w W) Swap(i int, j int) {} // MATCH /exported method W.Swap should have comment or be unexported/
33+
34+
// Vv is ...
35+
type Vv []int
36+
37+
func (vv Vv) Len() (result int) { return len(w) } // MATCH /exported method Vv.Len should have comment or be unexported/
38+
func (vv Vv) Less(i int, j int) (result bool) { return w[i] < w[j] } // MATCH /exported method Vv.Less should have comment or be unexported/

0 commit comments

Comments
 (0)