Skip to content

Commit 043dd32

Browse files
kliukovkinabhinav
andauthored
add interfaces flag with unit test (#200)
This is a proposal for `-interfaces` flag using source mode. Using this flag it is possible to list only required interfaces to be mocked in source mode. Currently, source mode mocks all interfaces which -source file contains. Thanks @KastenMike for raising this! Fixes golang/mock#660 --------- Co-authored-by: Abhinav Gupta <mail@abhinavg.net>
1 parent dcbbb16 commit 043dd32

File tree

4 files changed

+118
-4
lines changed

4 files changed

+118
-4
lines changed

mockgen/internal/tests/typed/bugreport.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package typed
22

3-
//go:generate mockgen -typed -aux_files faux=faux/faux.go -destination bugreport_mock.go -package typed -source=bugreport.go Example
3+
//go:generate mockgen -typed -aux_files faux=faux/faux.go -destination bugreport_mock.go -package typed -source=bugreport.go Source
44

55
import (
66
"log"

mockgen/internal/tests/typed/bugreport_mock.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mockgen/parse.go

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package main
1818

1919
import (
2020
"errors"
21+
"flag"
2122
"fmt"
2223
"go/ast"
2324
"go/build"
@@ -61,6 +62,18 @@ func sourceMode(source string) (*model.Package, error) {
6162
srcDir: srcDir,
6263
}
6364

65+
// positional interface names -> include set
66+
if flag.NArg() > 1 {
67+
return nil, errors.New("-source mode accepts at most one argument")
68+
}
69+
if flag.NArg() == 1 {
70+
ifaces := strings.Split(flag.Arg(0), ",")
71+
p.includeNamesSet = make(map[string]struct{}, len(ifaces))
72+
for _, name := range ifaces {
73+
p.includeNamesSet[name] = struct{}{}
74+
}
75+
}
76+
6477
// Handle -imports.
6578
dotImports := make(map[string]bool)
6679
if *imports != "" {
@@ -92,6 +105,7 @@ func sourceMode(source string) (*model.Package, error) {
92105
for pkgPath := range dotImports {
93106
pkg.DotImports = append(pkg.DotImports, pkgPath)
94107
}
108+
95109
return pkg, nil
96110
}
97111

@@ -168,6 +182,7 @@ type fileParser struct {
168182
auxInterfaces *interfaceCache
169183
srcDir string
170184
excludeNamesSet map[string]struct{}
185+
includeNamesSet map[string]struct{} // empty to include all
171186
}
172187

173188
func (p *fileParser) errorf(pos token.Pos, format string, args ...any) error {
@@ -228,18 +243,31 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag
228243

229244
var is []*model.Interface
230245
for ni := range iterInterfaces(file) {
231-
if _, ok := p.excludeNamesSet[ni.name.String()]; ok {
246+
name := ni.name.String()
247+
248+
if _, ok := p.excludeNamesSet[name]; ok {
232249
continue
233250
}
234-
i, err := p.parseInterface(ni.name.String(), importPath, ni)
251+
252+
// All interfaces are included if no filter was specified.
253+
if len(p.includeNamesSet) > 0 {
254+
if _, ok := p.includeNamesSet[name]; !ok {
255+
continue
256+
}
257+
}
258+
259+
i, err := p.parseInterface(name, importPath, ni)
235260
if errors.Is(err, errConstraintInterface) {
236261
continue
237262
}
238263
if err != nil {
239264
return nil, err
240265
}
241266
is = append(is, i)
267+
268+
delete(p.includeNamesSet, name)
242269
}
270+
243271
return &model.Package{
244272
Name: file.Name.String(),
245273
PkgPath: importPath,

mockgen/parse_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package main
33
import (
44
"go/parser"
55
"go/token"
6+
"strings"
67
"testing"
78
)
89

@@ -143,3 +144,88 @@ func TestParseArrayWithConstLength(t *testing.T) {
143144
}
144145
}
145146
}
147+
148+
func TestParseFile_IncludeOnlyRequested(t *testing.T) {
149+
fs := token.NewFileSet()
150+
file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
151+
if err != nil {
152+
t.Fatalf("Unexpected error: %v", err)
153+
}
154+
155+
p := fileParser{
156+
fileSet: fs,
157+
imports: make(map[string]importedPackage),
158+
importedInterfaces: newInterfaceCache(),
159+
// include только один интерфейс
160+
includeNamesSet: map[string]struct{}{"InputMaker": {}},
161+
}
162+
163+
pkg, err := p.parseFile("", file)
164+
if err != nil {
165+
t.Fatalf("Unexpected error: %v", err)
166+
}
167+
168+
if len(pkg.Interfaces) != 1 || pkg.Interfaces[0].Name != "InputMaker" {
169+
t.Fatalf("Expected only InputMaker, got %v", pkg.Interfaces)
170+
}
171+
}
172+
173+
// When requested interface is missing, parser should ignore it (no error, no interfaces).
174+
func TestParseFile_IncludeMissing_Ignored(t *testing.T) {
175+
fs := token.NewFileSet()
176+
file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
177+
if err != nil {
178+
t.Fatalf("Unexpected error: %v", err)
179+
}
180+
181+
p := fileParser{
182+
fileSet: fs,
183+
imports: make(map[string]importedPackage),
184+
importedInterfaces: newInterfaceCache(),
185+
includeNamesSet: map[string]struct{}{"DoesNotExist": {}},
186+
}
187+
188+
pkg, err := p.parseFile("", file)
189+
if err != nil {
190+
t.Fatalf("Unexpected error: %v", err)
191+
}
192+
if len(pkg.Interfaces) != 0 {
193+
t.Fatalf("Expected no interfaces, got %v", pkg.Interfaces)
194+
}
195+
}
196+
197+
func TestParseFile_IncludeWithDuplicates_Dedupes(t *testing.T) {
198+
fs := token.NewFileSet()
199+
file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
200+
if err != nil {
201+
t.Fatalf("Unexpected error: %v", err)
202+
}
203+
204+
// Эмулируем «случайно указали дубликаты» как это делает sourceMode (через позиционные аргументы)
205+
args := []string{"InputMaker", "InputMaker"} // дубликаты
206+
include := make(map[string]struct{})
207+
for _, a := range args {
208+
for _, name := range strings.Split(a, ",") {
209+
name = strings.TrimSpace(name)
210+
if name != "" {
211+
include[name] = struct{}{}
212+
}
213+
}
214+
}
215+
216+
p := fileParser{
217+
fileSet: fs,
218+
imports: make(map[string]importedPackage),
219+
importedInterfaces: newInterfaceCache(),
220+
includeNamesSet: include,
221+
}
222+
223+
pkg, err := p.parseFile("", file)
224+
if err != nil {
225+
t.Fatalf("Unexpected error: %v", err)
226+
}
227+
228+
if len(pkg.Interfaces) != 1 || pkg.Interfaces[0].Name != "InputMaker" {
229+
t.Fatalf("Expected only InputMaker once, got %v", pkg.Interfaces)
230+
}
231+
}

0 commit comments

Comments
 (0)