From fa1efd11307f6d21903cf6a2064de44c249d2272 Mon Sep 17 00:00:00 2001 From: Denis Voytyuk <5462781+denisvmedia@users.noreply.github.com> Date: Sat, 30 Nov 2024 10:32:30 +0100 Subject: [PATCH 1/5] chore: Improve sortables detection --- lint/package.go | 69 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/lint/package.go b/lint/package.go index 873f8a002..e212205ee 100644 --- a/lint/package.go +++ b/lint/package.go @@ -1,6 +1,7 @@ package lint import ( + "errors" "go/ast" "go/importer" "go/token" @@ -31,7 +32,6 @@ type Package struct { var ( trueValue = 1 falseValue = 2 - notSet = 3 go121 = goversion.Must(goversion.NewVersion("1.21")) go122 = goversion.Must(goversion.NewVersion("1.22")) @@ -111,6 +111,11 @@ func (p *Package) TypeCheck() error { astFiles = append(astFiles, f.AST) } + if anyFile == nil { + // this is unlikely to happen, but technically guarantees anyFile to not be nil + return errors.New("no ast.File found") + } + typesPkg, err := check(config, anyFile.AST.Name.Name, p.fset, astFiles, info) // Remember the typechecking info, even if config.Check failed, @@ -135,7 +140,7 @@ func check(config *types.Config, n string, fset *token.FileSet, astFiles []*ast. return config.Check(n, fset, astFiles, info) } -// TypeOf returns the type of an expression. +// TypeOf returns the type of expression. func (p *Package) TypeOf(expr ast.Expr) types.Type { if p.typesInfo == nil { return nil @@ -148,32 +153,72 @@ type walker struct { has map[string]int } +// bitfield for which methods exist on each type. +const ( + bfLen = 1 << iota + bfLess + bfSwap +) + func (w *walker) Visit(n ast.Node) ast.Visitor { fn, ok := n.(*ast.FuncDecl) if !ok || fn.Recv == nil || len(fn.Recv.List) == 0 { return w } - // TODO(dsymonds): We could check the signature to be more precise. + recv := typeparams.ReceiverType(fn) - if i, ok := w.nmap[fn.Name.Name]; ok { - w.has[recv] |= i + + // Ensure the method signature matches expectations. + switch fn.Name.Name { + case "Len": + if fn.Type.Params.NumFields() == 0 && fn.Type.Results.NumFields() == 1 { + resultType := fn.Type.Results.List[0].Type + if _, ok := resultType.(*ast.Ident); ok && resultType.(*ast.Ident).Name == "int" { + w.has[recv] |= bfLen + } + } + case "Less": + if fn.Type.Params.NumFields() == 2 && fn.Type.Results.NumFields() == 1 { + param1 := fn.Type.Params.List[0].Type + var param2 ast.Expr + if len(fn.Type.Params.List) == 2 { + param2 = fn.Type.Params.List[1].Type + } else { + param2 = param1 + } + resultType := fn.Type.Results.List[0].Type + + // Ensure parameters have the same type and the result is a bool. + if typesEqual(param1, param2) && isBool(resultType) { + w.has[recv] |= bfLess + } + } + case "Swap": + if fn.Type.Params.NumFields() == 2 && fn.Type.Results.NumFields() == 0 { + w.has[recv] |= bfSwap + } } return w } +func typesEqual(a, b ast.Expr) bool { + identA, okA := a.(*ast.Ident) + identB, okB := b.(*ast.Ident) + return okA && okB && identA.Name == identB.Name +} + +func isBool(t ast.Expr) bool { + ident, ok := t.(*ast.Ident) + return ok && ident.Name == "bool" +} + func (p *Package) scanSortable() { p.sortable = map[string]bool{} - // bitfield for which methods exist on each type. - const ( - bfLen = 1 << iota - bfLess - bfSwap - ) nmap := map[string]int{"Len": bfLen, "Less": bfLess, "Swap": bfSwap} has := map[string]int{} for _, f := range p.files { - ast.Walk(&walker{nmap, has}, f.AST) + ast.Walk(&walker{nmap: nmap, has: has}, f.AST) } for typ, ms := range has { if ms == bfLen|bfLess|bfSwap { From 4cd7820ae4ff52c8d54b8d698b9b36b7fabab612 Mon Sep 17 00:00:00 2001 From: chavacava Date: Sun, 1 Dec 2024 16:11:54 +0100 Subject: [PATCH 2/5] refactors detection of types implementing sortable interface --- lint/package.go | 113 +++++++++++++++++++++++++--------------- testdata/golint/sort.go | 4 ++ 2 files changed, 75 insertions(+), 42 deletions(-) diff --git a/lint/package.go b/lint/package.go index e212205ee..42eb946ce 100644 --- a/lint/package.go +++ b/lint/package.go @@ -166,50 +166,12 @@ func (w *walker) Visit(n ast.Node) ast.Visitor { return w } - recv := typeparams.ReceiverType(fn) - - // Ensure the method signature matches expectations. - switch fn.Name.Name { - case "Len": - if fn.Type.Params.NumFields() == 0 && fn.Type.Results.NumFields() == 1 { - resultType := fn.Type.Results.List[0].Type - if _, ok := resultType.(*ast.Ident); ok && resultType.(*ast.Ident).Name == "int" { - w.has[recv] |= bfLen - } - } - case "Less": - if fn.Type.Params.NumFields() == 2 && fn.Type.Results.NumFields() == 1 { - param1 := fn.Type.Params.List[0].Type - var param2 ast.Expr - if len(fn.Type.Params.List) == 2 { - param2 = fn.Type.Params.List[1].Type - } else { - param2 = param1 - } - resultType := fn.Type.Results.List[0].Type - - // Ensure parameters have the same type and the result is a bool. - if typesEqual(param1, param2) && isBool(resultType) { - w.has[recv] |= bfLess - } - } - case "Swap": - if fn.Type.Params.NumFields() == 2 && fn.Type.Results.NumFields() == 0 { - w.has[recv] |= bfSwap - } - } - return w -} + recvType := typeparams.ReceiverType(fn) + bf := getBitfieldForFunction(fn) -func typesEqual(a, b ast.Expr) bool { - identA, okA := a.(*ast.Ident) - identB, okB := b.(*ast.Ident) - return okA && okB && identA.Name == identB.Name -} + w.has[recvType] |= bf -func isBool(t ast.Expr) bool { - ident, ok := t.(*ast.Ident) - return ok && ident.Name == "bool" + return w } func (p *Package) scanSortable() { @@ -249,3 +211,70 @@ func (p *Package) IsAtLeastGo121() bool { func (p *Package) IsAtLeastGo122() bool { return p.goVersion.GreaterThanOrEqual(go122) } + +func getBitfieldForFunction(fn *ast.FuncDecl) int { + switch { + case funcSignatureIs(fn, "Len", []string{}, []string{"int"}): + return bfLen + case funcSignatureIs(fn, "Less", []string{"int", "int"}, []string{"bool"}): + return bfLess + case funcSignatureIs(fn, "Swap", []string{"int", "int"}, []string{}): + return bfSwap + default: + return 0 + } +} + +// funcSignatureIs returns true if the given func decl satisfies has a signature characterized +// by the given name, parameters types and return types; false otherwise +func funcSignatureIs(funcDecl *ast.FuncDecl, wantName string, wantParametersTypes, wantResultsTypes []string) bool { + if wantName != funcDecl.Name.String() { + return false // func name doesn't match expected one + } + + funcParametersTypes := getTypeNames(funcDecl.Type.Params) + if len(wantParametersTypes) != len(funcParametersTypes) { + return false // func has not the expected number of parameters + } + + funcResultsTypes := getTypeNames(funcDecl.Type.Results) + if len(wantResultsTypes) != len(funcResultsTypes) { + return false // func has not the expected number of return values + } + + for i, wantType := range wantParametersTypes { + if wantType != funcParametersTypes[i] { + return false // type of a func's parameter does not match the type of the corresponding expected parameter + } + } + + for i, wantType := range wantResultsTypes { + if wantType != funcResultsTypes[i] { + return false // type of a func's return value does not match the type of the corresponding expected return value + } + } + + return true +} + +func getTypeNames(fields *ast.FieldList) []string { + result := []string{} + + if fields == nil { + return result + } + + for _, field := range fields.List { + typeName := field.Type.(*ast.Ident).Name + if field.Names == nil { // unnamed field + result = append(result, typeName) + continue + } + + for range field.Names { // add one type name for each field name + result = append(result, typeName) + } + } + + return result +} diff --git a/testdata/golint/sort.go b/testdata/golint/sort.go index 331ce7167..953ab452a 100644 --- a/testdata/golint/sort.go +++ b/testdata/golint/sort.go @@ -17,4 +17,8 @@ func (u U) Len() int { return len(u) } func (u U) Less(i, j int) bool { return u[i] < u[j] } func (u U) Swap(i, j int) { u[i], u[j] = u[j], u[i] } +func (u U) Len() (result int) { return len(u) } +func (u U) Less(i int, j int) (result bool) { return u[i] < u[j] } +func (u U) Swap(i int, j int) { u[i], u[j] = u[j], u[i] } + func (u U) Other() {} // MATCH /exported method U.Other should have comment or be unexported/ From a6c01cceabb7f5a09ae63cd27a8c6b7bcf1d5427 Mon Sep 17 00:00:00 2001 From: chavacava Date: Sun, 1 Dec 2024 16:28:04 +0100 Subject: [PATCH 3/5] removes unused field from walker and adds type for sortable method flags --- lint/package.go | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/lint/package.go b/lint/package.go index 42eb946ce..6a8b23c43 100644 --- a/lint/package.go +++ b/lint/package.go @@ -148,14 +148,14 @@ func (p *Package) TypeOf(expr ast.Expr) types.Type { return p.typesInfo.TypeOf(expr) } +type sortableMethodsFlags int type walker struct { - nmap map[string]int - has map[string]int + sortableMethodFlagsByTypeName map[string]sortableMethodsFlags } -// bitfield for which methods exist on each type. +// flags for sortable interface methods. const ( - bfLen = 1 << iota + bfLen sortableMethodsFlags = 1 << iota bfLess bfSwap ) @@ -167,9 +167,7 @@ func (w *walker) Visit(n ast.Node) ast.Visitor { } recvType := typeparams.ReceiverType(fn) - bf := getBitfieldForFunction(fn) - - w.has[recvType] |= bf + w.sortableMethodFlagsByTypeName[recvType] |= getSortableMethodFlagForFunction(fn) return w } @@ -177,12 +175,11 @@ func (w *walker) Visit(n ast.Node) ast.Visitor { func (p *Package) scanSortable() { p.sortable = map[string]bool{} - nmap := map[string]int{"Len": bfLen, "Less": bfLess, "Swap": bfSwap} - has := map[string]int{} + sortableFlags := map[string]sortableMethodsFlags{} for _, f := range p.files { - ast.Walk(&walker{nmap: nmap, has: has}, f.AST) + ast.Walk(&walker{sortableMethodFlagsByTypeName: sortableFlags}, f.AST) } - for typ, ms := range has { + for typ, ms := range sortableFlags { if ms == bfLen|bfLess|bfSwap { p.sortable[typ] = true } @@ -212,7 +209,7 @@ func (p *Package) IsAtLeastGo122() bool { return p.goVersion.GreaterThanOrEqual(go122) } -func getBitfieldForFunction(fn *ast.FuncDecl) int { +func getSortableMethodFlagForFunction(fn *ast.FuncDecl) sortableMethodsFlags { switch { case funcSignatureIs(fn, "Len", []string{}, []string{"int"}): return bfLen From 8fc7733bf3bf3ffa05b979cd5c797421715af013 Mon Sep 17 00:00:00 2001 From: chavacava Date: Sun, 1 Dec 2024 16:40:18 +0100 Subject: [PATCH 4/5] moves auxiliary functions to a new internal/astutils package --- internal/astutils/ast_utils.go | 60 +++++++++++++++++++++++++++++++++ lint/package.go | 61 +++------------------------------- 2 files changed, 64 insertions(+), 57 deletions(-) create mode 100644 internal/astutils/ast_utils.go diff --git a/internal/astutils/ast_utils.go b/internal/astutils/ast_utils.go new file mode 100644 index 000000000..5a66f11fa --- /dev/null +++ b/internal/astutils/ast_utils.go @@ -0,0 +1,60 @@ +package astutils + +import "go/ast" + +// FuncSignatureIs returns true if the given func decl satisfies a signature characterized +// by the given name, parameters types and return types; false otherwise. +// +// Example: to check if a function declaration has the signature Foo(int, string) (bool,error) +// call to FuncSignatureIs(funcDecl,"Foo",[]string{"int","string"},[]string{"bool","error"}) +func FuncSignatureIs(funcDecl *ast.FuncDecl, wantName string, wantParametersTypes, wantResultsTypes []string) bool { + if wantName != funcDecl.Name.String() { + return false // func name doesn't match expected one + } + + funcParametersTypes := getTypeNames(funcDecl.Type.Params) + if len(wantParametersTypes) != len(funcParametersTypes) { + return false // func has not the expected number of parameters + } + + funcResultsTypes := getTypeNames(funcDecl.Type.Results) + if len(wantResultsTypes) != len(funcResultsTypes) { + return false // func has not the expected number of return values + } + + for i, wantType := range wantParametersTypes { + if wantType != funcParametersTypes[i] { + return false // type of a func's parameter does not match the type of the corresponding expected parameter + } + } + + for i, wantType := range wantResultsTypes { + if wantType != funcResultsTypes[i] { + return false // type of a func's return value does not match the type of the corresponding expected return value + } + } + + return true +} + +func getTypeNames(fields *ast.FieldList) []string { + result := []string{} + + if fields == nil { + return result + } + + for _, field := range fields.List { + typeName := field.Type.(*ast.Ident).Name + if field.Names == nil { // unnamed field + result = append(result, typeName) + continue + } + + for range field.Names { // add one type name for each field name + result = append(result, typeName) + } + } + + return result +} diff --git a/lint/package.go b/lint/package.go index 6a8b23c43..66cd3ba79 100644 --- a/lint/package.go +++ b/lint/package.go @@ -10,6 +10,7 @@ import ( goversion "github.com/hashicorp/go-version" + "github.com/mgechev/revive/internal/astutils" "github.com/mgechev/revive/internal/typeparams" ) @@ -211,67 +212,13 @@ func (p *Package) IsAtLeastGo122() bool { func getSortableMethodFlagForFunction(fn *ast.FuncDecl) sortableMethodsFlags { switch { - case funcSignatureIs(fn, "Len", []string{}, []string{"int"}): + case astutils.FuncSignatureIs(fn, "Len", []string{}, []string{"int"}): return bfLen - case funcSignatureIs(fn, "Less", []string{"int", "int"}, []string{"bool"}): + case astutils.FuncSignatureIs(fn, "Less", []string{"int", "int"}, []string{"bool"}): return bfLess - case funcSignatureIs(fn, "Swap", []string{"int", "int"}, []string{}): + case astutils.FuncSignatureIs(fn, "Swap", []string{"int", "int"}, []string{}): return bfSwap default: return 0 } } - -// funcSignatureIs returns true if the given func decl satisfies has a signature characterized -// by the given name, parameters types and return types; false otherwise -func funcSignatureIs(funcDecl *ast.FuncDecl, wantName string, wantParametersTypes, wantResultsTypes []string) bool { - if wantName != funcDecl.Name.String() { - return false // func name doesn't match expected one - } - - funcParametersTypes := getTypeNames(funcDecl.Type.Params) - if len(wantParametersTypes) != len(funcParametersTypes) { - return false // func has not the expected number of parameters - } - - funcResultsTypes := getTypeNames(funcDecl.Type.Results) - if len(wantResultsTypes) != len(funcResultsTypes) { - return false // func has not the expected number of return values - } - - for i, wantType := range wantParametersTypes { - if wantType != funcParametersTypes[i] { - return false // type of a func's parameter does not match the type of the corresponding expected parameter - } - } - - for i, wantType := range wantResultsTypes { - if wantType != funcResultsTypes[i] { - return false // type of a func's return value does not match the type of the corresponding expected return value - } - } - - return true -} - -func getTypeNames(fields *ast.FieldList) []string { - result := []string{} - - if fields == nil { - return result - } - - for _, field := range fields.List { - typeName := field.Type.(*ast.Ident).Name - if field.Names == nil { // unnamed field - result = append(result, typeName) - continue - } - - for range field.Names { // add one type name for each field name - result = append(result, typeName) - } - } - - return result -} From f468e69f2e7151a526309ecb30f95cb5fb206d74 Mon Sep 17 00:00:00 2001 From: chavacava Date: Mon, 2 Dec 2024 07:00:10 +0100 Subject: [PATCH 5/5] replaces unnecessary AST walker by a iteration on declarations --- lint/package.go | 30 ++++++++++++------------------ testdata/golint/sort.go | 22 ++++++++++++++++++---- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/lint/package.go b/lint/package.go index 66cd3ba79..41575f480 100644 --- a/lint/package.go +++ b/lint/package.go @@ -150,9 +150,6 @@ func (p *Package) TypeOf(expr ast.Expr) types.Type { } type sortableMethodsFlags int -type walker struct { - sortableMethodFlagsByTypeName map[string]sortableMethodsFlags -} // flags for sortable interface methods. const ( @@ -161,25 +158,22 @@ const ( bfSwap ) -func (w *walker) Visit(n ast.Node) ast.Visitor { - fn, ok := n.(*ast.FuncDecl) - if !ok || fn.Recv == nil || len(fn.Recv.List) == 0 { - return w - } - - recvType := typeparams.ReceiverType(fn) - w.sortableMethodFlagsByTypeName[recvType] |= getSortableMethodFlagForFunction(fn) - - return w -} - func (p *Package) scanSortable() { - p.sortable = map[string]bool{} - sortableFlags := map[string]sortableMethodsFlags{} for _, f := range p.files { - ast.Walk(&walker{sortableMethodFlagsByTypeName: sortableFlags}, f.AST) + for _, decl := range f.AST.Decls { + fn, ok := decl.(*ast.FuncDecl) + isAMethodDeclaration := ok && fn.Recv != nil && len(fn.Recv.List) != 0 + if !isAMethodDeclaration { + continue + } + + recvType := typeparams.ReceiverType(fn) + sortableFlags[recvType] |= getSortableMethodFlagForFunction(fn) + } } + + p.sortable = make(map[string]bool, len(sortableFlags)) for typ, ms := range sortableFlags { if ms == bfLen|bfLess|bfSwap { p.sortable[typ] = true diff --git a/testdata/golint/sort.go b/testdata/golint/sort.go index 953ab452a..80901f48b 100644 --- a/testdata/golint/sort.go +++ b/testdata/golint/sort.go @@ -17,8 +17,22 @@ func (u U) Len() int { return len(u) } func (u U) Less(i, j int) bool { return u[i] < u[j] } func (u U) Swap(i, j int) { u[i], u[j] = u[j], u[i] } -func (u U) Len() (result int) { return len(u) } -func (u U) Less(i int, j int) (result bool) { return u[i] < u[j] } -func (u U) Swap(i int, j int) { u[i], u[j] = u[j], u[i] } - func (u U) Other() {} // MATCH /exported method U.Other should have comment or be unexported/ + +// V is ... +type V []int + +func (v V) Len() (result int) { return len(w) } +func (v V) Less(i int, j int) (result bool) { return w[i] < w[j] } +func (v V) Swap(i int, j int) { v[i], v[j] = v[j], v[i] } + +// W is ... +type W []int + +func (w W) Swap(i int, j int) {} // MATCH /exported method W.Swap should have comment or be unexported/ + +// Vv is ... +type Vv []int + +func (vv Vv) Len() (result int) { return len(w) } // MATCH /exported method Vv.Len should have comment or be unexported/ +func (vv Vv) Less(i int, j int) (result bool) { return w[i] < w[j] } // MATCH /exported method Vv.Less should have comment or be unexported/