Skip to content

Commit 359202c

Browse files
JamyDevlinzhpsywhang
authored
Support for archive mode (#258)
This adds support for a 3rd option for creating mocks after reflect and source-mode: the archive mode. Archive mode lets you load archive files to create mocks. This can come in handy for writing Bazel rules that produce intermediary archive files and automatically codegen mocks in Bazel environments. Rebased version of #125 --------- Co-authored-by: Zhongpeng Lin <[email protected]> Co-authored-by: Sung Yoon Whang <[email protected]> Co-authored-by: Sung Yoon Whang <[email protected]>
1 parent 871d86b commit 359202c

File tree

4 files changed

+272
-14
lines changed

4 files changed

+272
-14
lines changed

README.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,23 @@ export PATH=$PATH:$(go env GOPATH)/bin
4040

4141
## Running mockgen
4242

43-
`mockgen` has two modes of operation: source and package.
43+
`mockgen` has three modes of operation: archive, source and package.
44+
45+
### Archive mode
46+
47+
Archive mode generates mock interfaces from a package archive
48+
file (.a). It is enabled by using the -archive flag. An import
49+
path and a comma-separated list of symbols should be provided
50+
as a non-flag argument to the command.
51+
52+
Example:
53+
54+
```bash
55+
# Build the package to a archive.
56+
go build -o pkg.a database/sql/driver
57+
58+
mockgen -archive=pkg.a database/sql/driver Conn,Driver
59+
```
4460

4561
### Source mode
4662

@@ -77,6 +93,8 @@ The `mockgen` command is used to generate source code for a mock
7793
class given a Go source file containing interfaces to be mocked.
7894
It supports the following flags:
7995

96+
- `-archive`: A package archive file containing interfaces to be mocked.
97+
8098
- `-source`: A file containing interfaces to be mocked.
8199

82100
- `-destination`: A file to which to write the resulting source code. If you

mockgen/archive.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"go/token"
6+
"go/types"
7+
"os"
8+
9+
"go.uber.org/mock/mockgen/model"
10+
11+
"golang.org/x/tools/go/gcexportdata"
12+
)
13+
14+
func archiveMode(importPath string, symbols []string, archive string) (*model.Package, error) {
15+
f, err := os.Open(archive)
16+
if err != nil {
17+
return nil, err
18+
}
19+
defer f.Close()
20+
r, err := gcexportdata.NewReader(f)
21+
if err != nil {
22+
return nil, fmt.Errorf("read export data %q: %v", archive, err)
23+
}
24+
25+
fset := token.NewFileSet()
26+
imports := make(map[string]*types.Package)
27+
tp, err := gcexportdata.Read(r, fset, imports, importPath)
28+
if err != nil {
29+
return nil, err
30+
}
31+
32+
pkg := &model.Package{
33+
Name: tp.Name(),
34+
PkgPath: tp.Path(),
35+
Interfaces: make([]*model.Interface, 0, len(symbols)),
36+
}
37+
for _, name := range symbols {
38+
m := tp.Scope().Lookup(name)
39+
tn, ok := m.(*types.TypeName)
40+
if !ok {
41+
continue
42+
}
43+
ti, ok := tn.Type().Underlying().(*types.Interface)
44+
if !ok {
45+
continue
46+
}
47+
it, err := model.InterfaceFromGoTypesType(ti)
48+
if err != nil {
49+
return nil, err
50+
}
51+
it.Name = m.Name()
52+
pkg.Interfaces = append(pkg.Interfaces, it)
53+
}
54+
return pkg, nil
55+
}

mockgen/mockgen.go

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ var (
5454
)
5555

5656
var (
57+
archive = flag.String("archive", "", "(archive mode) Input Go archive file; enables archive mode.")
5758
source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.")
5859
destination = flag.String("destination", "", "Output file; defaults to stdout.")
5960
mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.")
@@ -68,11 +69,10 @@ var (
6869
typed = flag.Bool("typed", false, "Generate Type-safe 'Return', 'Do', 'DoAndReturn' function")
6970
imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")
7071
auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")
71-
excludeInterfaces = flag.String("exclude_interfaces", "", "(source mode) Comma-separated names of interfaces to be excluded")
7272
modelGob = flag.String("model_gob", "", "Skip package/source loading entirely and use the gob encoded model.Package at the given path")
73-
74-
debugParser = flag.Bool("debug_parser", false, "Print out parser results only.")
75-
showVersion = flag.Bool("version", false, "Print version.")
73+
excludeInterfaces = flag.String("exclude_interfaces", "", "Comma-separated names of interfaces to be excluded")
74+
debugParser = flag.Bool("debug_parser", false, "Print out parser results only.")
75+
showVersion = flag.Bool("version", false, "Print version.")
7676
)
7777

7878
func main() {
@@ -89,17 +89,24 @@ func main() {
8989
var pkg *model.Package
9090
var err error
9191
var packageName string
92-
if *modelGob != "" {
92+
93+
// Switch between modes
94+
switch {
95+
case *modelGob != "": // gob mode
9396
pkg, err = gobMode(*modelGob)
94-
} else if *source != "" {
97+
case *source != "": // source mode
9598
pkg, err = sourceMode(*source)
96-
} else {
97-
if flag.NArg() != 2 {
98-
usage()
99-
log.Fatal("Expected exactly two arguments")
100-
}
99+
case *archive != "": // archive mode
100+
checkArgs()
101+
packageName = flag.Arg(0)
102+
interfaces := strings.Split(flag.Arg(1), ",")
103+
pkg, err = archiveMode(packageName, interfaces, *archive)
104+
105+
default: // package mode
106+
checkArgs()
101107
packageName = flag.Arg(0)
102108
interfaces := strings.Split(flag.Arg(1), ",")
109+
103110
if packageName == "." {
104111
dir, err := os.Getwd()
105112
if err != nil {
@@ -109,10 +116,12 @@ func main() {
109116
if err != nil {
110117
log.Fatalf("Parse package name failed: %v", err)
111118
}
119+
112120
}
113121
parser := packageModeParser{}
114122
pkg, err = parser.parsePackage(packageName, interfaces)
115123
}
124+
116125
if err != nil {
117126
log.Fatalf("Loading input failed: %v", err)
118127
}
@@ -155,6 +164,8 @@ func main() {
155164
}
156165
if *source != "" {
157166
g.filename = *source
167+
} else if *archive != "" {
168+
g.filename = *archive
158169
} else {
159170
g.srcPackage = packageName
160171
g.srcInterfaces = flag.Arg(1)
@@ -230,12 +241,19 @@ func parseExcludeInterfaces(names string) map[string]struct{} {
230241
return namesSet
231242
}
232243

244+
func checkArgs() {
245+
if flag.NArg() != 2 {
246+
usage()
247+
log.Fatal("Expected exactly two arguments")
248+
}
249+
}
250+
233251
func usage() {
234252
_, _ = io.WriteString(os.Stderr, usageText)
235253
flag.PrintDefaults()
236254
}
237255

238-
const usageText = `mockgen has two modes of operation: source and package.
256+
const usageText = `mockgen has three modes of operation: archive, source and package.
239257
240258
Source mode generates mock interfaces from a source file.
241259
It is enabled by using the -source flag. Other flags that
@@ -245,12 +263,19 @@ Example:
245263
246264
Package mode works by specifying the package and interface names.
247265
It is enabled by passing two non-flag arguments: an import path, and a
248-
comma-separated list of symbols.
266+
comma-separated list of symbols.
249267
You can use "." to refer to the current path's package.
250268
Example:
251269
mockgen database/sql/driver Conn,Driver
252270
mockgen . SomeInterface
253271
272+
Archive mode generates mock interfaces from a package archive
273+
file (.a). It is enabled by using the -archive flag and two
274+
non-flag arguments: an import path, and a comma-separated
275+
list of symbols.
276+
Example:
277+
mockgen -archive=pkg.a database/sql/driver Conn,Driver
278+
254279
`
255280

256281
type generator struct {

mockgen/model/model_gotypes.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package model
2+
3+
import (
4+
"fmt"
5+
"go/types"
6+
)
7+
8+
// InterfaceFromGoTypesType returns a pointer to an interface for the
9+
// given interface type loaded from archive.
10+
func InterfaceFromGoTypesType(it *types.Interface) (*Interface, error) {
11+
intf := &Interface{}
12+
13+
for i := 0; i < it.NumMethods(); i++ {
14+
mt := it.Method(i)
15+
// Skip unexported methods.
16+
if !mt.Exported() {
17+
continue
18+
}
19+
m := &Method{
20+
Name: mt.Name(),
21+
}
22+
23+
var err error
24+
m.In, m.Variadic, m.Out, err = funcArgsFromGoTypesType(mt.Type().(*types.Signature))
25+
if err != nil {
26+
return nil, fmt.Errorf("method %q: %w", mt.Name(), err)
27+
}
28+
29+
intf.AddMethod(m)
30+
}
31+
32+
return intf, nil
33+
}
34+
35+
func funcArgsFromGoTypesType(t *types.Signature) (in []*Parameter, variadic *Parameter, out []*Parameter, err error) {
36+
nin := t.Params().Len()
37+
if t.Variadic() {
38+
nin--
39+
}
40+
for i := 0; i < nin; i++ {
41+
p, err := parameterFromGoTypesType(t.Params().At(i), false)
42+
if err != nil {
43+
return nil, nil, nil, err
44+
}
45+
in = append(in, p)
46+
}
47+
if t.Variadic() {
48+
p, err := parameterFromGoTypesType(t.Params().At(nin), true)
49+
if err != nil {
50+
return nil, nil, nil, err
51+
}
52+
variadic = p
53+
}
54+
for i := 0; i < t.Results().Len(); i++ {
55+
p, err := parameterFromGoTypesType(t.Results().At(i), false)
56+
if err != nil {
57+
return nil, nil, nil, err
58+
}
59+
out = append(out, p)
60+
}
61+
return
62+
}
63+
64+
func parameterFromGoTypesType(v *types.Var, variadic bool) (*Parameter, error) {
65+
t := v.Type()
66+
if variadic {
67+
t = t.(*types.Slice).Elem()
68+
}
69+
tt, err := typeFromGoTypesType(t)
70+
if err != nil {
71+
return nil, err
72+
}
73+
return &Parameter{Name: v.Name(), Type: tt}, nil
74+
}
75+
76+
func typeFromGoTypesType(t types.Type) (Type, error) {
77+
if t, ok := t.(*types.Named); ok {
78+
tn := t.Obj()
79+
if tn.Pkg() == nil {
80+
return PredeclaredType(tn.Name()), nil
81+
}
82+
return &NamedType{
83+
Package: tn.Pkg().Path(),
84+
Type: tn.Name(),
85+
}, nil
86+
}
87+
88+
// only unnamed or predeclared types after here
89+
90+
// Lots of types have element types. Let's do the parsing and error checking for all of them.
91+
var elemType Type
92+
if t, ok := t.(interface{ Elem() types.Type }); ok {
93+
var err error
94+
elemType, err = typeFromGoTypesType(t.Elem())
95+
if err != nil {
96+
return nil, err
97+
}
98+
}
99+
100+
switch t := t.(type) {
101+
case *types.Array:
102+
return &ArrayType{
103+
Len: int(t.Len()),
104+
Type: elemType,
105+
}, nil
106+
case *types.Basic:
107+
return PredeclaredType(t.String()), nil
108+
case *types.Chan:
109+
var dir ChanDir
110+
switch t.Dir() {
111+
case types.RecvOnly:
112+
dir = RecvDir
113+
case types.SendOnly:
114+
dir = SendDir
115+
}
116+
return &ChanType{
117+
Dir: dir,
118+
Type: elemType,
119+
}, nil
120+
case *types.Signature:
121+
in, variadic, out, err := funcArgsFromGoTypesType(t)
122+
if err != nil {
123+
return nil, err
124+
}
125+
return &FuncType{
126+
In: in,
127+
Out: out,
128+
Variadic: variadic,
129+
}, nil
130+
case *types.Interface:
131+
if t.NumMethods() == 0 {
132+
return PredeclaredType("interface{}"), nil
133+
}
134+
case *types.Map:
135+
kt, err := typeFromGoTypesType(t.Key())
136+
if err != nil {
137+
return nil, err
138+
}
139+
return &MapType{
140+
Key: kt,
141+
Value: elemType,
142+
}, nil
143+
case *types.Pointer:
144+
return &PointerType{
145+
Type: elemType,
146+
}, nil
147+
case *types.Slice:
148+
return &ArrayType{
149+
Len: -1,
150+
Type: elemType,
151+
}, nil
152+
case *types.Struct:
153+
if t.NumFields() == 0 {
154+
return PredeclaredType("struct{}"), nil
155+
}
156+
// TODO: UnsafePointer
157+
}
158+
159+
return nil, fmt.Errorf("can't yet turn %v (%T) into a model.Type", t.String(), t)
160+
}

0 commit comments

Comments
 (0)