@@ -6,23 +6,23 @@ import type {
66 Expression ,
77 Identifier ,
88 ImportDeclaration ,
9- ImportExpression ,
109 VariableDeclaration ,
1110} from 'estree'
1211import type { SourceMap } from 'magic-string'
12+ import type { RollupAstNode } from 'rollup'
1313import type { Plugin , Rollup } from 'vite'
1414import type { Node , Positioned } from './esmWalker'
1515import { findNodeAround } from 'acorn-walk'
1616import MagicString from 'magic-string'
1717import { createFilter } from 'vite'
18- import { esmWalker , getArbitraryModuleIdentifier } from './esmWalker'
18+ import { esmWalker } from './esmWalker'
1919
2020interface HoistMocksOptions {
2121 /**
2222 * List of modules that should always be imported before compiler hints.
23- * @default [ 'vitest']
23+ * @default 'vitest'
2424 */
25- hoistedModules ?: string [ ]
25+ hoistedModule ?: string
2626 /**
2727 * @default ["vi", "vitest"]
2828 */
@@ -106,11 +106,14 @@ function isIdentifier(node: any): node is Positioned<Identifier> {
106106 return node . type === 'Identifier'
107107}
108108
109- function getBetterEnd ( code : string , node : Node ) {
109+ function getNodeTail ( code : string , node : Node ) {
110110 let end = node . end
111111 if ( code [ node . end ] === ';' ) {
112112 end += 1
113113 }
114+ if ( code [ node . end ] === '\n' ) {
115+ return end + 1
116+ }
114117 if ( code [ node . end + 1 ] === '\n' ) {
115118 end += 1
116119 }
@@ -160,48 +163,43 @@ export function hoistMocks(
160163 dynamicImportMockMethodNames = [ 'mock' , 'unmock' , 'doMock' , 'doUnmock' ] ,
161164 hoistedMethodNames = [ 'hoisted' ] ,
162165 utilsObjectNames = [ 'vi' , 'vitest' ] ,
163- hoistedModules = [ 'vitest' ] ,
166+ hoistedModule = 'vitest' ,
164167 } = options
165168
166- const hoistIndex = code . match ( hashbangRE ) ?. [ 0 ] . length ?? 0
169+ // hoist at the start of the file, after the hashbang
170+ let hoistIndex = hashbangRE . exec ( code ) ?. [ 0 ] . length ?? 0
167171
168172 let hoistedModuleImported = false
169173
170174 let uid = 0
171175 const idToImportMap = new Map < string , string > ( )
172176
177+ const imports : {
178+ node : RollupAstNode < ImportDeclaration >
179+ id : string
180+ } [ ] = [ ]
181+
173182 // this will transform import statements into dynamic ones, if there are imports
174183 // it will keep the import as is, if we don't need to mock anything
175184 // in browser environment it will wrap the module value with "vitest_wrap_module" function
176185 // that returns a proxy to the module so that named exports can be mocked
177- const transformImportDeclaration = ( node : ImportDeclaration ) => {
178- const source = node . source . value as string
179-
180- const importId = `__vi_import_${ uid ++ } __`
181- const hasSpecifiers = node . specifiers . length > 0
182- const code = hasSpecifiers
183- ? `const ${ importId } = await import('${ source } ')\n`
184- : `await import('${ source } ')\n`
185- return {
186- code,
187- id : importId ,
188- }
189- }
190-
191- function defineImport ( node : Positioned < ImportDeclaration > ) {
186+ function defineImport (
187+ importNode : ImportDeclaration & {
188+ start : number
189+ end : number
190+ } ,
191+ ) {
192+ const source = importNode . source . value as string
192193 // always hoist vitest import to top of the file, so
193194 // "vi" helpers can access it
194- if ( hoistedModules . includes ( node . source . value as string ) ) {
195+ if ( hoistedModule === source ) {
195196 hoistedModuleImported = true
196197 return
197198 }
199+ const importId = `__vi_import_${ uid ++ } __`
200+ imports . push ( { id : importId , node : importNode } )
198201
199- const declaration = transformImportDeclaration ( node )
200- if ( ! declaration ) {
201- return null
202- }
203- s . appendLeft ( hoistIndex , declaration . code )
204- return declaration . id
202+ return importId
205203 }
206204
207205 // 1. check all import statements and record id -> importName map
@@ -214,13 +212,20 @@ export function hoistMocks(
214212 if ( ! importId ) {
215213 continue
216214 }
217- s . remove ( node . start , getBetterEnd ( code , node ) )
218215 for ( const spec of node . specifiers ) {
219216 if ( spec . type === 'ImportSpecifier' ) {
220- idToImportMap . set (
221- spec . local . name ,
222- `${ importId } .${ getArbitraryModuleIdentifier ( spec . imported ) } ` ,
223- )
217+ if ( spec . imported . type === 'Identifier' ) {
218+ idToImportMap . set (
219+ spec . local . name ,
220+ `${ importId } .${ spec . imported . name } ` ,
221+ )
222+ }
223+ else {
224+ idToImportMap . set (
225+ spec . local . name ,
226+ `${ importId } [${ JSON . stringify ( spec . imported . value as string ) } ]` ,
227+ )
228+ }
224229 }
225230 else if ( spec . type === 'ImportDefaultSpecifier' ) {
226231 idToImportMap . set ( spec . local . name , `${ importId } .default` )
@@ -235,7 +240,7 @@ export function hoistMocks(
235240
236241 const declaredConst = new Set < string > ( )
237242 const hoistedNodes : Positioned <
238- CallExpression | VariableDeclaration | AwaitExpression
243+ CallExpression | VariableDeclaration | AwaitExpression
239244 > [ ] = [ ]
240245
241246 function createSyntaxError ( node : Positioned < Node > , message : string ) {
@@ -300,6 +305,8 @@ export function hoistMocks(
300305 }
301306 }
302307
308+ const usedUtilityExports = new Set < string > ( )
309+
303310 esmWalker ( ast , {
304311 onIdentifier ( id , info , parentStack ) {
305312 const binding = idToImportMap . get ( id . name )
@@ -333,6 +340,7 @@ export function hoistMocks(
333340 && isIdentifier ( node . callee . property )
334341 ) {
335342 const methodName = node . callee . property . name
343+ usedUtilityExports . add ( node . callee . object . name )
336344
337345 if ( hoistableMockMethodNames . includes ( methodName ) ) {
338346 const method = `${ node . callee . object . name } .${ methodName } `
@@ -347,6 +355,35 @@ export function hoistMocks(
347355 `Cannot export the result of "${ method } ". Remove export declaration because "${ method } " doesn\'t return anything.` ,
348356 )
349357 }
358+ // rewrite vi.mock(import('..')) into vi.mock('..')
359+ if (
360+ node . type === 'CallExpression'
361+ && node . callee . type === 'MemberExpression'
362+ && dynamicImportMockMethodNames . includes ( ( node . callee . property as Identifier ) . name )
363+ ) {
364+ const moduleInfo = node . arguments [ 0 ] as Positioned < Expression >
365+ // vi.mock(import('./path')) -> vi.mock('./path')
366+ if ( moduleInfo . type === 'ImportExpression' ) {
367+ const source = moduleInfo . source as Positioned < Expression >
368+ s . overwrite (
369+ moduleInfo . start ,
370+ moduleInfo . end ,
371+ s . slice ( source . start , source . end ) ,
372+ )
373+ }
374+ // vi.mock(await import('./path')) -> vi.mock('./path')
375+ if (
376+ moduleInfo . type === 'AwaitExpression'
377+ && moduleInfo . argument . type === 'ImportExpression'
378+ ) {
379+ const source = moduleInfo . argument . source as Positioned < Expression >
380+ s . overwrite (
381+ moduleInfo . start ,
382+ moduleInfo . end ,
383+ s . slice ( source . start , source . end ) ,
384+ )
385+ }
386+ }
350387 hoistedNodes . push ( node )
351388 }
352389 // vi.doMock(import('./path')) -> vi.doMock('./path')
@@ -394,9 +431,8 @@ export function hoistMocks(
394431 'AwaitExpression' ,
395432 ) ?. node as Positioned < AwaitExpression > | undefined
396433 // hoist "await vi.hoisted(async () => {})" or "vi.hoisted(() => {})"
397- hoistedNodes . push (
398- awaitedExpression ?. argument === node ? awaitedExpression : node ,
399- )
434+ const moveNode = awaitedExpression ?. argument === node ? awaitedExpression : node
435+ hoistedNodes . push ( moveNode )
400436 }
401437 }
402438 }
@@ -446,24 +482,6 @@ export function hoistMocks(
446482 )
447483 }
448484
449- function rewriteMockDynamicImport (
450- nodeCode : string ,
451- moduleInfo : Positioned < ImportExpression > ,
452- expressionStart : number ,
453- expressionEnd : number ,
454- mockStart : number ,
455- ) {
456- const source = moduleInfo . source as Positioned < Expression >
457- const importPath = s . slice ( source . start , source . end )
458- const nodeCodeStart = expressionStart - mockStart
459- const nodeCodeEnd = expressionEnd - mockStart
460- return (
461- nodeCode . slice ( 0 , nodeCodeStart )
462- + importPath
463- + nodeCode . slice ( nodeCodeEnd )
464- )
465- }
466-
467485 // validate hoistedNodes doesn't have nodes inside other nodes
468486 for ( let i = 0 ; i < hoistedNodes . length ; i ++ ) {
469487 const node = hoistedNodes [ i ]
@@ -479,61 +497,55 @@ export function hoistMocks(
479497 }
480498 }
481499
482- // Wait for imports to be hoisted and then hoist the mocks
483- const hoistedCode = hoistedNodes
484- . map ( ( node ) => {
485- const end = getBetterEnd ( code , node )
486- /**
487- * In the following case, we need to change the `user` to user: __vi_import_x__.user
488- * So we should get the latest code from `s`.
489- *
490- * import user from './user'
491- * vi.mock('./mock.js', () => ({ getSession: vi.fn().mockImplementation(() => ({ user })) }))
492- */
493- let nodeCode = s . slice ( node . start , end )
494-
495- // rewrite vi.mock(import('..')) into vi.mock('..')
496- if (
497- node . type === 'CallExpression'
498- && node . callee . type === 'MemberExpression'
499- && dynamicImportMockMethodNames . includes ( ( node . callee . property as Identifier ) . name )
500- ) {
501- const moduleInfo = node . arguments [ 0 ] as Positioned < Expression >
502- // vi.mock(import('./path')) -> vi.mock('./path')
503- if ( moduleInfo . type === 'ImportExpression' ) {
504- nodeCode = rewriteMockDynamicImport (
505- nodeCode ,
506- moduleInfo ,
507- moduleInfo . start ,
508- moduleInfo . end ,
509- node . start ,
510- )
511- }
512- // vi.mock(await import('./path')) -> vi.mock('./path')
513- if (
514- moduleInfo . type === 'AwaitExpression'
515- && moduleInfo . argument . type === 'ImportExpression'
516- ) {
517- nodeCode = rewriteMockDynamicImport (
518- nodeCode ,
519- moduleInfo . argument as Positioned < ImportExpression > ,
520- moduleInfo . start ,
521- moduleInfo . end ,
522- node . start ,
523- )
524- }
525- }
500+ // hoist vi.mock/vi.hoisted
501+ for ( const node of hoistedNodes ) {
502+ const end = getNodeTail ( code , node )
503+ if ( hoistIndex === end ) {
504+ hoistIndex = end
505+ }
506+ // don't hoist into itself if it's already at the top
507+ else if ( hoistIndex !== node . start ) {
508+ s . move ( node . start , end , hoistIndex )
509+ }
510+ }
526511
527- s . remove ( node . start , end )
528- return `${ nodeCode } ${ nodeCode . endsWith ( '\n' ) ? '' : '\n' } `
529- } )
530- . join ( '' )
512+ // hoist actual dynamic imports last so they are inserted after all hoisted mocks
513+ for ( const { node : importNode , id : importId } of imports ) {
514+ const source = importNode . source . value as string
531515
532- if ( hoistedCode || hoistedModuleImported ) {
533- s . prepend (
534- ( ! hoistedModuleImported && hoistedCode ? API_NOT_FOUND_CHECK ( utilsObjectNames ) : '' )
535- + hoistedCode ,
516+ s . update (
517+ importNode . start ,
518+ importNode . end ,
519+ `const ${ importId } = await import(${ JSON . stringify (
520+ source ,
521+ ) } );\n`,
536522 )
523+
524+ if ( importNode . start === hoistIndex ) {
525+ // no need to hoist, but update hoistIndex to keep the order
526+ hoistIndex = importNode . end
527+ }
528+ else {
529+ // There will be an error if the module is called before it is imported,
530+ // so the module import statement is hoisted to the top
531+ s . move ( importNode . start , importNode . end , hoistIndex )
532+ }
533+ }
534+
535+ if ( ! hoistedModuleImported && hoistedNodes . length ) {
536+ const utilityImports = [ ...usedUtilityExports ]
537+ // "vi" or "vitest" is imported from a module other than "vitest"
538+ if ( utilityImports . some ( name => idToImportMap . has ( name ) ) ) {
539+ s . prepend ( API_NOT_FOUND_CHECK ( utilityImports ) )
540+ }
541+ // if "vi" or "vitest" are not imported at all, import them
542+ else if ( utilityImports . length ) {
543+ s . prepend (
544+ `import { ${ [ ...usedUtilityExports ] . join ( ', ' ) } } from ${ JSON . stringify (
545+ hoistedModule ,
546+ ) } \n`,
547+ )
548+ }
537549 }
538550
539551 return {
0 commit comments