Skip to content
This repository has been archived by the owner on Jun 27, 2023. It is now read-only.

mockgen: handle more cases of "duplicate" imports #405

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions mockgen/internal/tests/import_embedded_interface/bugreport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//go:generate mockgen -destination bugreport_mock.go -package bugreport -source=bugreport.go
stevendanna marked this conversation as resolved.
Show resolved Hide resolved

package bugreport

import (
"log"

"github.com/golang/mock/mockgen/internal/tests/import_embedded_interface/ersatz"
"github.com/golang/mock/mockgen/internal/tests/import_embedded_interface/faux"
)

// Source is an interface w/ an embedded foreign interface
type Source interface {
ersatz.Embedded
faux.Foreign
}

func CallForeignMethod(s Source) {
log.Println(s.Ersatz())
log.Println(s.OtherErsatz())
}
63 changes: 63 additions & 0 deletions mockgen/internal/tests/import_embedded_interface/bugreport_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions mockgen/internal/tests/import_embedded_interface/bugreport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package bugreport

import (
"testing"

"github.com/golang/mock/gomock"
)

// TestValidInterface assesses whether or not the generated mock is valid
func TestValidInterface(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

s := NewMockSource(ctrl)
s.EXPECT().Ersatz().Return("")
s.EXPECT().OtherErsatz().Return("")
CallForeignMethod(s)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package ersatz

type Embedded interface {
Ersatz() Return
}

type Return interface{}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package faux

import "github.com/golang/mock/mockgen/internal/tests/import_embedded_interface/other/log"

func Conflict1() {
log.Foo()
}
15 changes: 15 additions & 0 deletions mockgen/internal/tests/import_embedded_interface/faux/faux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package faux

import (
"log"

"github.com/golang/mock/mockgen/internal/tests/import_embedded_interface/other/ersatz"
)

type Foreign interface {
ersatz.Embedded
}

func Conflict0() {
log.Println()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package ersatz

type Embedded interface {
OtherErsatz() Return
}

type Return interface{}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package log

func Foo() {}
143 changes: 103 additions & 40 deletions mockgen/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func sourceMode(source string) (*model.Package, error) {

p := &fileParser{
fileSet: fs,
imports: make(map[string]string),
imports: make(map[string]importedPackage),
importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
auxInterfaces: make(map[string]map[string]*ast.InterfaceType),
srcDir: srcDir,
Expand All @@ -79,7 +79,7 @@ func sourceMode(source string) (*model.Package, error) {
dotImports[v] = true
} else {
// TODO: Catch dupes?
p.imports[k] = v
p.imports[k] = importedPkg{path: v}
}
}
}
Expand All @@ -100,9 +100,39 @@ func sourceMode(source string) (*model.Package, error) {
return pkg, nil
}

type importedPackage interface {
Path() string
Parser() *fileParser
}

type importedPkg struct {
path string
parser *fileParser
}

func (i importedPkg) Path() string { return i.path }
func (i importedPkg) Parser() *fileParser { return i.parser }

// duplicateImport is a bit of a misnomer. Currently the parser can't
// handle cases of multi-file packages importing different packages
// under the same name. Often these imports would not be problematic,
// so this type lets us defer raising an error unless the package name
// is actually used.
type duplicateImport struct {
name string
duplicates []string
}

func (d duplicateImport) Error() string {
return fmt.Sprintf("%q is ambigous because of duplicate imports: %v", d.name, d.duplicates)
}

func (d duplicateImport) Path() string { log.Fatal(d.Error()); return "" }
func (d duplicateImport) Parser() *fileParser { log.Fatal(d.Error()); return nil }

type fileParser struct {
fileSet *token.FileSet
imports map[string]string // package name => import path
imports map[string]importedPackage // package name => imported package
importedInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface

auxFiles []*ast.File
Expand Down Expand Up @@ -154,18 +184,18 @@ func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) {
func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) {
allImports, dotImports := importsOfFile(file)
// Don't stomp imports provided by -imports. Those should take precedence.
for pkg, pkgPath := range allImports {
for pkg, pkgI := range allImports {
if _, ok := p.imports[pkg]; !ok {
p.imports[pkg] = pkgPath
p.imports[pkg] = pkgI
}
}
// Add imports from auxiliary files, which might be needed for embedded interfaces.
// Don't stomp any other imports.
for _, f := range p.auxFiles {
auxImports, _ := importsOfFile(f)
for pkg, pkgPath := range auxImports {
for pkg, pkgI := range auxImports {
if _, ok := p.imports[pkg]; !ok {
p.imports[pkg] = pkgPath
p.imports[pkg] = pkgI
}
}
}
Expand All @@ -186,31 +216,38 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag
}, nil
}

// parsePackage loads package specified by path, parses it and populates
// corresponding imports and importedInterfaces into the fileParser.
func (p *fileParser) parsePackage(path string) error {
// parsePackage loads package specified by path, parses it and returns
// a new fileParser with the parsed imports and interfaces.
func (p *fileParser) parsePackage(path string) (*fileParser, error) {
newP := &fileParser{
fileSet: token.NewFileSet(),
imports: make(map[string]importedPackage),
importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
auxInterfaces: make(map[string]map[string]*ast.InterfaceType),
srcDir: p.srcDir,
}

var pkgs map[string]*ast.Package
if imp, err := build.Import(path, p.srcDir, build.FindOnly); err != nil {
return err
} else if pkgs, err = parser.ParseDir(p.fileSet, imp.Dir, nil, 0); err != nil {
return err
if imp, err := build.Import(path, newP.srcDir, build.FindOnly); err != nil {
return nil, err
} else if pkgs, err = parser.ParseDir(newP.fileSet, imp.Dir, nil, 0); err != nil {
return nil, err
}

for _, pkg := range pkgs {
file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates)
if _, ok := p.importedInterfaces[path]; !ok {
p.importedInterfaces[path] = make(map[string]*ast.InterfaceType)
newP.importedInterfaces[path] = make(map[string]*ast.InterfaceType)
}
for ni := range iterInterfaces(file) {
p.importedInterfaces[path][ni.name.Name] = ni.it
newP.importedInterfaces[path][ni.name.Name] = ni.it
}
imports, _ := importsOfFile(file)
for pkgName, pkgPath := range imports {
if _, ok := p.imports[pkgName]; !ok {
p.imports[pkgName] = pkgPath
}
for pkgName, pkgI := range imports {
newP.imports[pkgName] = pkgI
}
}
return nil
return newP, nil
}

func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) {
Expand Down Expand Up @@ -252,21 +289,36 @@ func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*m
if !ok {
return nil, p.errorf(v.X.Pos(), "unknown package %s", fpkg)
}

var eintf *model.Interface
var err error
ei := p.auxInterfaces[fpkg][sel]
if ei == nil {
fpkg = epkg
if _, ok = p.importedInterfaces[epkg]; !ok {
if err := p.parsePackage(epkg); err != nil {
return nil, p.errorf(v.Pos(), "could not parse package %s: %v", fpkg, err)
if ei != nil {
eintf, err = p.parseInterface(sel, fpkg, ei)
if err != nil {
return nil, err
}
} else {
path := epkg.Path()
parser := epkg.Parser()
if parser == nil {
ip, err := p.parsePackage(path)
if err != nil {
return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err)
}
parser = ip
p.imports[fpkg] = importedPkg{
path: epkg.Path(),
parser: parser,
}
}
if ei = p.importedInterfaces[epkg][sel]; ei == nil {
return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", fpkg, sel)
if ei = parser.importedInterfaces[path][sel]; ei == nil {
return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel)
}
eintf, err = parser.parseInterface(sel, path, ei)
if err != nil {
return nil, err
}
}
eintf, err := p.parseInterface(sel, fpkg, ei)
if err != nil {
return nil, err
}
// Copy the methods.
// TODO: apply shadowing rules.
Expand Down Expand Up @@ -383,7 +435,7 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
// if so, patch the import w/ the fully qualified import
maybeImportedPkg, ok := p.imports[pkg]
if ok {
pkg = maybeImportedPkg
pkg = maybeImportedPkg.Path()
}
// assume type in this package
return &model.NamedType{Package: pkg, Type: v.Name}, nil
Expand Down Expand Up @@ -412,7 +464,7 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
if !ok {
return nil, p.errorf(v.Pos(), "unknown package %q", pkgName)
}
return &model.NamedType{Package: pkg, Type: v.Sel.String()}, nil
return &model.NamedType{Package: pkg.Path(), Type: v.Sel.String()}, nil
case *ast.StarExpr:
t, err := p.parseType(pkg, v.X)
if err != nil {
Expand All @@ -431,7 +483,7 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {

// importsOfFile returns a map of package name to import path
// of the imports in file.
func importsOfFile(file *ast.File) (normalImports map[string]string, dotImports []string) {
func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) {
var importPaths []string
for _, is := range file.Imports {
if is.Name != nil {
Expand All @@ -441,7 +493,7 @@ func importsOfFile(file *ast.File) (normalImports map[string]string, dotImports
importPaths = append(importPaths, importPath)
}
packagesName := createPackageMap(importPaths)
normalImports = make(map[string]string)
normalImports = make(map[string]importedPackage)
dotImports = make([]string, 0)
for _, is := range file.Imports {
var pkgName string
Expand Down Expand Up @@ -469,11 +521,22 @@ func importsOfFile(file *ast.File) (normalImports map[string]string, dotImports
if pkgName == "." {
dotImports = append(dotImports, importPath)
} else {

if _, ok := normalImports[pkgName]; ok {
log.Fatalf("imported package collision: %q imported twice", pkgName)
if pkg, ok := normalImports[pkgName]; ok {
switch p := pkg.(type) {
case duplicateImport:
normalImports[pkgName] = duplicateImport{
name: p.name,
duplicates: append([]string{importPath}, p.duplicates...),
}
case importedPkg:
normalImports[pkgName] = duplicateImport{
name: pkgName,
duplicates: []string{p.path, importPath},
}
}
} else {
normalImports[pkgName] = importedPkg{path: importPath}
}
normalImports[pkgName] = importPath
}
}
return
Expand Down
Loading