diff --git a/rule/modifies_value_receiver.go b/rule/modifies_value_receiver.go index 2f92991f5..d811c486b 100644 --- a/rule/modifies_value_receiver.go +++ b/rule/modifies_value_receiver.go @@ -12,99 +12,35 @@ import ( type ModifiesValRecRule struct{} // Apply applies the rule to given file. -func (*ModifiesValRecRule) Apply(file *lint.File, _ lint.Arguments) []lint.Failure { +func (r *ModifiesValRecRule) Apply(file *lint.File, _ lint.Arguments) []lint.Failure { var failures []lint.Failure - onFailure := func(failure lint.Failure) { - failures = append(failures, failure) - } - - w := lintModifiesValRecRule{file: file, onFailure: onFailure} file.Pkg.TypeCheck() - ast.Walk(w, file.AST) - - return failures -} - -// Name returns the rule name. -func (*ModifiesValRecRule) Name() string { - return "modifies-value-receiver" -} - -type lintModifiesValRecRule struct { - file *lint.File - onFailure func(lint.Failure) -} - -func (w lintModifiesValRecRule) Visit(node ast.Node) ast.Visitor { - switch n := node.(type) { - case *ast.FuncDecl: - if n.Recv == nil { - return nil // skip, not a method - } - - receiver := n.Recv.List[0] - if _, ok := receiver.Type.(*ast.StarExpr); ok { - return nil // skip, method with pointer receiver + for _, decl := range file.AST.Decls { + funcDecl, ok := decl.(*ast.FuncDecl) + isAMethod := ok && funcDecl.Recv != nil + if !isAMethod { + continue // skip, not a method } - if w.skipType(receiver.Type) { - return nil // skip, receiver is a map or array - } - - if len(receiver.Names) < 1 { - return nil // skip, anonymous receiver + receiver := funcDecl.Recv.List[0] + if r.mustSkip(receiver, file.Pkg) { + continue } receiverName := receiver.Names[0].Name - if receiverName == "_" { - return nil // skip, anonymous receiver - } - - receiverAssignmentFinder := func(n ast.Node) bool { - // look for assignments with the receiver in the right hand - assignment, ok := n.(*ast.AssignStmt) - if !ok { - return false - } - - for _, exp := range assignment.Lhs { - switch e := exp.(type) { - case *ast.IndexExpr: // receiver...[] = ... - continue - case *ast.StarExpr: // *receiver = ... - continue - case *ast.SelectorExpr: // receiver.field = ... - name := w.getNameFromExpr(e.X) - if name == "" || name != receiverName { - continue - } - case *ast.Ident: // receiver := ... - if e.Name != receiverName { - continue - } - default: - continue - } - - return true - } - - return false - } - - assignmentsToReceiver := pick(n.Body, receiverAssignmentFinder) + assignmentsToReceiver := r.getAssignmentsToReceiver(receiverName, funcDecl.Body) if len(assignmentsToReceiver) == 0 { - return nil // receiver is not modified + continue // receiver is not modified } - methodReturnsReceiver := len(w.findReturnReceiverStatements(receiverName, n.Body)) > 0 + methodReturnsReceiver := len(r.findReturnReceiverStatements(receiverName, funcDecl.Body)) > 0 if methodReturnsReceiver { - return nil // modification seems legit (see issue #1066) + continue // modification seems legit (see issue #1066) } for _, assignment := range assignmentsToReceiver { - w.onFailure(lint.Failure{ + failures = append(failures, lint.Failure{ Node: assignment, Confidence: 1, Failure: "suspicious assignment to a by-value method receiver", @@ -112,11 +48,16 @@ func (w lintModifiesValRecRule) Visit(node ast.Node) ast.Visitor { } } - return w + return failures +} + +// Name returns the rule name. +func (*ModifiesValRecRule) Name() string { + return "modifies-value-receiver" } -func (w lintModifiesValRecRule) skipType(t ast.Expr) bool { - rt := w.file.Pkg.TypeOf(t) +func (r *ModifiesValRecRule) skipType(t ast.Expr, pkg *lint.Package) bool { + rt := pkg.TypeOf(t) if rt == nil { return false } @@ -128,7 +69,7 @@ func (w lintModifiesValRecRule) skipType(t ast.Expr) bool { return strings.HasPrefix(rtName, "[]") || strings.HasPrefix(rtName, "map[") } -func (lintModifiesValRecRule) getNameFromExpr(ie ast.Expr) string { +func (*ModifiesValRecRule) getNameFromExpr(ie ast.Expr) string { ident, ok := ie.(*ast.Ident) if !ok { return "" @@ -137,7 +78,7 @@ func (lintModifiesValRecRule) getNameFromExpr(ie ast.Expr) string { return ident.Name } -func (w lintModifiesValRecRule) findReturnReceiverStatements(receiverName string, target ast.Node) []ast.Node { +func (r *ModifiesValRecRule) findReturnReceiverStatements(receiverName string, target ast.Node) []ast.Node { finder := func(n ast.Node) bool { // look for returns with the receiver as value returnStatement, ok := n.(*ast.ReturnStmt) @@ -148,7 +89,7 @@ func (w lintModifiesValRecRule) findReturnReceiverStatements(receiverName string for _, exp := range returnStatement.Results { switch e := exp.(type) { case *ast.SelectorExpr: // receiver.field = ... - name := w.getNameFromExpr(e.X) + name := r.getNameFromExpr(e.X) if name == "" || name != receiverName { continue } @@ -160,7 +101,7 @@ func (w lintModifiesValRecRule) findReturnReceiverStatements(receiverName string if e.Op != token.AND { continue } - name := w.getNameFromExpr(e.X) + name := r.getNameFromExpr(e.X) if name == "" || name != receiverName { continue } @@ -177,3 +118,60 @@ func (w lintModifiesValRecRule) findReturnReceiverStatements(receiverName string return pick(target, finder) } + +func (r *ModifiesValRecRule) mustSkip(receiver *ast.Field, pkg *lint.Package) bool { + if _, ok := receiver.Type.(*ast.StarExpr); ok { + return true // skip, method with pointer receiver + } + + if len(receiver.Names) < 1 { + return true // skip, anonymous receiver + } + + receiverName := receiver.Names[0].Name + if receiverName == "_" { + return true // skip, anonymous receiver + } + + if r.skipType(receiver.Type, pkg) { + return true // skip, receiver is a map or array + } + + return false +} + +func (r *ModifiesValRecRule) getAssignmentsToReceiver(receiverName string, funcBody *ast.BlockStmt) []ast.Node { + receiverAssignmentFinder := func(n ast.Node) bool { + // look for assignments with the receiver in the right hand + assignment, ok := n.(*ast.AssignStmt) + if !ok { + return false + } + + for _, exp := range assignment.Lhs { + switch e := exp.(type) { + case *ast.IndexExpr: // receiver...[] = ... + continue + case *ast.StarExpr: // *receiver = ... + continue + case *ast.SelectorExpr: // receiver.field = ... + name := r.getNameFromExpr(e.X) + if name == "" || name != receiverName { + continue + } + case *ast.Ident: // receiver := ... + if e.Name != receiverName { + continue + } + default: + continue + } + + return true + } + + return false + } + + return pick(funcBody, receiverAssignmentFinder) +}