44 "bytes"
55 "fmt"
66 "go/ast"
7- "go/build"
87 "go/format"
9- "go/parser"
108 "go/token"
119 "io/ioutil"
1210 "log"
@@ -17,7 +15,7 @@ import (
1715
1816 "github.com/pkg/errors"
1917 "github.com/spf13/pflag"
20- "golang.org/x/tools/go/loader "
18+ "golang.org/x/tools/go/packages "
2119 "golang.org/x/tools/imports"
2220)
2321
@@ -27,7 +25,8 @@ type options struct {
2725 debug bool
2826 cmpImportName string
2927 showLoaderErrors bool
30- useAllFiles bool
28+ buildFlags []string
29+ localImportPath string
3130}
3231
3332func main () {
@@ -62,12 +61,14 @@ func setupFlags(name string) (*pflag.FlagSet, *options) {
6261 "import alias to use for the assert/cmp package" )
6362 flags .BoolVar (& opts .showLoaderErrors , "print-loader-errors" , false ,
6463 "print errors from loading source" )
65- flags .BoolVar (& opts .useAllFiles , "ignore-build-tags" , false ,
66- "migrate all files ignoring build tags" )
64+ flags .StringSliceVar (& opts .buildFlags , "build-tags" , nil ,
65+ "build to pass to Go when loading source files" )
66+ flags .StringVar (& opts .localImportPath , "local-import-path" , "" ,
67+ "value to pass to 'goimports -local' flag for sorting local imports" )
6768 flags .Usage = func () {
6869 fmt .Fprintf (os .Stderr , `Usage: %s [OPTIONS] PACKAGE [PACKAGE...]
6970
70- Migrate calls from testify/{assert|require} to gotest.tools/assert.
71+ Migrate calls from testify/{assert|require} to gotest.tools/v3/ assert.
7172
7273%s` , name , flags .FlagUsages ())
7374 }
@@ -87,18 +88,19 @@ func handleExitError(name string, err error) {
8788}
8889
8990func run (opts options ) error {
90- program , err := loadProgram (opts )
91+ imports .LocalPrefix = opts .localImportPath
92+
93+ fset := token .NewFileSet ()
94+ pkgs , err := loadPackages (opts , fset )
9195 if err != nil {
9296 return errors .Wrapf (err , "failed to load program" )
9397 }
9498
95- pkgs := program .InitialPackages ()
9699 debugf ("package count: %d" , len (pkgs ))
97-
98- fileset := program .Fset
99100 for _ , pkg := range pkgs {
100- for _ , astFile := range pkg .Files {
101- absFilename := fileset .File (astFile .Pos ()).Name ()
101+ debugf ("file count for package %v: %d" , pkg .PkgPath , len (pkg .Syntax ))
102+ for _ , astFile := range pkg .Syntax {
103+ absFilename := fset .File (astFile .Pos ()).Name ()
102104 filename := relativePath (absFilename )
103105 importNames := newImportNames (astFile .Imports , opts )
104106 if ! importNames .hasTestifyImports () {
@@ -109,9 +111,9 @@ func run(opts options) error {
109111 debugf ("migrating %s with imports: %#v" , filename , importNames )
110112 m := migration {
111113 file : astFile ,
112- fileset : fileset ,
114+ fileset : fset ,
113115 importNames : importNames ,
114- pkgInfo : pkg ,
116+ pkgInfo : pkg . TypesInfo ,
115117 }
116118 migrateFile (m )
117119 if opts .dryRun {
@@ -132,47 +134,33 @@ func run(opts options) error {
132134 return nil
133135}
134136
135- func loadProgram (opts options ) (* loader.Program , error ) {
136- fakeImporter , err := newFakeImporter ()
137+ var loadMode = packages .NeedName |
138+ packages .NeedFiles |
139+ packages .NeedCompiledGoFiles |
140+ packages .NeedDeps |
141+ packages .NeedImports |
142+ packages .NeedTypes |
143+ packages .NeedTypesInfo |
144+ packages .NeedTypesSizes |
145+ packages .NeedSyntax
146+
147+ func loadPackages (opts options , fset * token.FileSet ) ([]* packages.Package , error ) {
148+ conf := & packages.Config {
149+ Mode : loadMode ,
150+ Fset : fset ,
151+ Tests : true ,
152+ Logf : debugf ,
153+ BuildFlags : opts .buildFlags ,
154+ }
155+
156+ pkgs , err := packages .Load (conf , opts .pkgs ... )
137157 if err != nil {
138158 return nil , err
139159 }
140- defer fakeImporter .Close ()
141-
142- conf := loader.Config {
143- Fset : token .NewFileSet (),
144- ParserMode : parser .ParseComments ,
145- Build : buildContext (opts ),
146- AllowErrors : true ,
147- FindPackage : fakeImporter .Import ,
148- }
149- for _ , pkg := range opts .pkgs {
150- conf .ImportWithTests (pkg )
151- }
152- if ! opts .showLoaderErrors {
153- conf .TypeChecker .Error = func (e error ) {}
154- }
155- program , err := conf .Load ()
156160 if opts .showLoaderErrors {
157- for p , pkg := range program .AllPackages {
158- if len (pkg .Errors ) > 0 {
159- fmt .Printf ("Package %s loaded with some errors:\n " , p .Name ())
160- for _ , err := range pkg .Errors {
161- fmt .Println (" " , err .Error ())
162- }
163- }
164- }
165- }
166- return program , err
167- }
168-
169- func buildContext (opts options ) * build.Context {
170- c := build .Default
171- c .UseAllFiles = opts .useAllFiles
172- if val , ok := os .LookupEnv ("GOPATH" ); ok {
173- c .GOPATH = val
161+ packages .PrintErrors (pkgs )
174162 }
175- return & c
163+ return pkgs , nil
176164}
177165
178166func relativePath (p string ) string {
@@ -214,8 +202,9 @@ func (p importNames) funcNameFromTestifyName(name string) string {
214202}
215203
216204func newImportNames (imports []* ast.ImportSpec , opt options ) importNames {
205+ defaultAssertAlias := path .Base (pkgAssert )
217206 importNames := importNames {
218- assert : path . Base ( pkgAssert ) ,
207+ assert : defaultAssertAlias ,
219208 cmp : path .Base (pkgCmp ),
220209 }
221210 for _ , spec := range imports {
@@ -225,7 +214,18 @@ func newImportNames(imports []*ast.ImportSpec, opt options) importNames {
225214 case pkgTestifyRequire , pkgGopkgTestifyRequire :
226215 importNames .testifyRequire = identOrDefault (spec .Name , "require" )
227216 default :
228- if importedAs (spec , path .Base (pkgAssert )) {
217+ pkgPath := strings .Trim (spec .Path .Value , `"` )
218+
219+ switch {
220+ // v3/assert is already imported and has an alias
221+ case pkgPath == pkgAssert :
222+ if spec .Name != nil && spec .Name .Name != "" {
223+ importNames .assert = spec .Name .Name
224+ }
225+ continue
226+
227+ // some other package is imported as assert
228+ case importedAs (spec , path .Base (pkgAssert )) && importNames .assert == defaultAssertAlias :
229229 importNames .assert = "gtyassert"
230230 }
231231 }
0 commit comments