Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 96 additions & 28 deletions utils/src/ast-grep/import-statement.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import assert from "node:assert/strict";
import { describe, it } from "node:test";
import astGrep from '@ast-grep/napi';
import dedent from 'dedent';
import astGrep from "@ast-grep/napi";
import dedent from "dedent";
import { getNodeImportStatements, getNodeImportCalls } from "./import-statement.ts";

describe("import-statement", () => {
Expand All @@ -21,21 +21,21 @@ describe("import-statement", () => {
`;
const ast = astGrep.parse(astGrep.Lang.JavaScript, code);

const fsImports = getNodeImportStatements(ast, 'fs');
const fsImports = getNodeImportStatements(ast, "fs");
assert.strictEqual(fsImports.length, 1);
assert.strictEqual(fsImports[0].field('source')?.text(), "'fs'");
assert.strictEqual(fsImports[0].field("source")?.text(), "'fs'");

const pathImports = getNodeImportStatements(ast, 'path');
const pathImports = getNodeImportStatements(ast, "path");
assert.strictEqual(pathImports.length, 1);
assert.strictEqual(pathImports[0].field('source')?.text(), "'node:path'");
assert.strictEqual(pathImports[0].field("source")?.text(), "'node:path'");

const childProcessImports = getNodeImportStatements(ast, 'child_process');
const childProcessImports = getNodeImportStatements(ast, "child_process");
assert.strictEqual(childProcessImports.length, 1);
assert.strictEqual(childProcessImports[0].field('source')?.text(), '"child_process"');
assert.strictEqual(childProcessImports[0].field("source")?.text(), '"child_process"');

const utilImports = getNodeImportStatements(ast, 'util');
const utilImports = getNodeImportStatements(ast, "util");
assert.strictEqual(utilImports.length, 1);
assert.strictEqual(utilImports[0].field('source')?.text(), '"node:util"');
assert.strictEqual(utilImports[0].field("source")?.text(), '"node:util"');
});

it("should return import calls", () => {
Expand All @@ -53,29 +53,53 @@ describe("import-statement", () => {
`;
const ast = astGrep.parse(astGrep.Lang.JavaScript, code);

const fsCalls = getNodeImportCalls(ast, 'fs');
const fsCalls = getNodeImportCalls(ast, "fs");
assert.strictEqual(fsCalls.length, 1);
const fsCallExpr = fsCalls[0].field('value')?.children()[1]; // await_expression -> call_expression
assert.strictEqual(fsCallExpr?.field('function')?.text(), 'import');
assert.strictEqual(fsCallExpr?.field('arguments')?.find({ rule: { kind: "string" } })?.text(), "'fs'");
const fsCallExpr = fsCalls[0].field("value")?.children()[1]; // await_expression -> call_expression
assert.strictEqual(fsCallExpr?.field("function")?.text(), "import");
assert.strictEqual(
fsCallExpr
?.field("arguments")
?.find({ rule: { kind: "string" } })
?.text(),
"'fs'",
);

const pathCalls = getNodeImportCalls(ast, 'path');
const pathCalls = getNodeImportCalls(ast, "path");
assert.strictEqual(pathCalls.length, 1);
const pathCallExpr = pathCalls[0].field('value')?.children()[1]; // await_expression -> call_expression
assert.strictEqual(pathCallExpr?.field('function')?.text(), 'import');
assert.strictEqual(pathCallExpr?.field('arguments')?.find({ rule: { kind: "string" } })?.text(), "'node:path'");
const pathCallExpr = pathCalls[0].field("value")?.children()[1]; // await_expression -> call_expression
assert.strictEqual(pathCallExpr?.field("function")?.text(), "import");
assert.strictEqual(
pathCallExpr
?.field("arguments")
?.find({ rule: { kind: "string" } })
?.text(),
"'node:path'",
);

const childProcessCalls = getNodeImportCalls(ast, 'child_process');
const childProcessCalls = getNodeImportCalls(ast, "child_process");
assert.strictEqual(childProcessCalls.length, 1);
const childProcessCallExpr = childProcessCalls[0].field('value')?.children()[1]; // await_expression -> call_expression
assert.strictEqual(childProcessCallExpr?.field('function')?.text(), 'import');
assert.strictEqual(childProcessCallExpr?.field('arguments')?.find({ rule: { kind: "string" } })?.text(), '"child_process"');
const childProcessCallExpr = childProcessCalls[0].field("value")?.children()[1]; // await_expression -> call_expression
assert.strictEqual(childProcessCallExpr?.field("function")?.text(), "import");
assert.strictEqual(
childProcessCallExpr
?.field("arguments")
?.find({ rule: { kind: "string" } })
?.text(),
'"child_process"',
);

const utilCalls = getNodeImportCalls(ast, 'util');
const utilCalls = getNodeImportCalls(ast, "util");
assert.strictEqual(utilCalls.length, 1);
const utilCallExpr = utilCalls[0].field('value')?.children()[1]; // await_expression -> call_expression
assert.strictEqual(utilCallExpr?.field('function')?.text(), 'import');
assert.strictEqual(utilCallExpr?.field('arguments')?.find({ rule: { kind: "string" } })?.text(), '"node:util"');
const utilCallExpr = utilCalls[0].field("value")?.children()[1]; // await_expression -> call_expression
assert.strictEqual(utilCallExpr?.field("function")?.text(), "import");
assert.strictEqual(
utilCallExpr
?.field("arguments")
?.find({ rule: { kind: "string" } })
?.text(),
'"node:util"',
);
});

it("shouldn't catch pending promises during import calls", () => {
Expand All @@ -84,7 +108,51 @@ describe("import-statement", () => {
`;
const ast = astGrep.parse(astGrep.Lang.JavaScript, code);

const moduleCalls = getNodeImportCalls(ast, 'module');
const moduleCalls = getNodeImportCalls(ast, "module");
assert.strictEqual(moduleCalls.length, 0, "Pending import calls should not be caught");
});
})

it("should catch thenable during import calls", () => {
const code = dedent`
import("node:fs").then((mdl) => {
const readFile = mdl.readFile;

readFile("package.json", "utf8", (err, data) => {
if (err) throw err;
console.log({ data });
});
});
`;
const ast = astGrep.parse(astGrep.Lang.JavaScript, code);

const moduleCalls = getNodeImportCalls(ast, "fs");
assert.strictEqual(moduleCalls.length, 1, "thenable import calls should be caught");
});

it("should catch thenable during import calls with catch block", () => {
const code = dedent`
import("node:fs").then((mdl) => {
const readFile = mdl.readFile;

readFile("package.json", "utf8", (err, data) => {
if (err) throw err;
console.log({ data });
});
}).catch(console.log);
`;
const ast = astGrep.parse(astGrep.Lang.JavaScript, code);

const moduleCalls = getNodeImportCalls(ast, "fs");
assert.strictEqual(moduleCalls.length, 1, "thenable import calls should be caught");
});

it("shouldn't catch when there is no 'then'", () => {
const code = dedent`
import("node:fs").catch(console.log);
`;
const ast = astGrep.parse(astGrep.Lang.JavaScript, code);

const moduleCalls = getNodeImportCalls(ast, "fs");
assert.strictEqual(moduleCalls.length, 0, "dynamic import without then shouldn't be caught");
});
});
114 changes: 80 additions & 34 deletions utils/src/ast-grep/import-statement.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import type { SgRoot, SgNode } from '@codemod.com/jssg-types/main';
import type { SgRoot, SgNode } from "@codemod.com/jssg-types/main";

export const getNodeImportStatements = (rootNode: SgRoot, nodeModuleName: string): SgNode[] =>
rootNode
.root()
.findAll({
rule: {
kind: "import_statement",
has: {
field: "source",
kind: "string",
regex: `^['"](node:)?${nodeModuleName}['"]$`
}
}
});
rootNode.root().findAll({
rule: {
kind: "import_statement",
has: {
field: "source",
kind: "string",
regex: `^['"](node:)?${nodeModuleName}['"]$`,
},
},
});

/**
* We just catch `variable_declarator` nodes that use `import` to import a module
Expand All @@ -22,21 +20,16 @@ export const getNodeImportStatements = (rootNode: SgRoot, nodeModuleName: string
* We also don't catch pending promises, like `const pending = import("node:module");`
* because it's will became to complex to handle in codemod context. (storing var name, checking is method is used, etc.)
*/
export const getNodeImportCalls = (rootNode: SgRoot, nodeModuleName: string): SgNode[] =>
rootNode
.root()
.findAll({
export const getNodeImportCalls = (rootNode: SgRoot, nodeModuleName: string): SgNode[] => {
const nodes = rootNode.root().findAll({
rule: {
kind: "variable_declarator",
all: [
{
has: {
field: "name",
any: [
{ kind: "object_pattern" },
{ kind: "identifier" }
]
}
any: [{ kind: "object_pattern" }, { kind: "identifier" }],
},
},
{
has: {
Expand All @@ -48,23 +41,76 @@ export const getNodeImportCalls = (rootNode: SgRoot, nodeModuleName: string): Sg
{
has: {
field: "function",
kind: "import"
}
kind: "import",
},
},
{
has: {
field: "arguments",
kind: "arguments",
has: {
kind: "string",
regex: `^['"](node:)?${nodeModuleName}['"]$`
}
}
}
]
}
}
}
]
}
regex: `^['"](node:)?${nodeModuleName}['"]$`,
},
},
},
],
},
},
},
],
},
});

const dynamicImports = rootNode.root().findAll({
rule: {
kind: "call_expression",
all: [
{
has: {
field: "function",
kind: "import",
},
},
{
has: {
field: "arguments",
kind: "arguments",
has: {
kind: "string",
regex: `^['"](node:)?${nodeModuleName}['"]$`,
},
},
},
],
},
});

for (const node of dynamicImports) {
let parentNode = node.parent();
// iterate through all chained methods until reaching the expression_statement
// that marks the beginning of the import line
while (parentNode !== null && parentNode.kind() !== "expression_statement") {
parentNode = parentNode.parent();
}

// if it is a valid import add to list of nodes that will be retuned
if (parentNode?.kind() === "expression_statement") {
const thenBlock = parentNode.find({
rule: {
kind: "member_expression",
has: {
kind: "property_identifier",
regex: "then",
},
},
});

if (thenBlock !== null) {
nodes.push(parentNode);
}
}
}

return nodes;
};