diff --git a/packages/schema/src/cli/actions/generate.ts b/packages/schema/src/cli/actions/generate.ts index 4da84c05d..44ad55612 100644 --- a/packages/schema/src/cli/actions/generate.ts +++ b/packages/schema/src/cli/actions/generate.ts @@ -1,4 +1,5 @@ import { PluginError } from '@zenstackhq/sdk'; +import { isPlugin } from '@zenstackhq/sdk/ast'; import colors from 'colors'; import path from 'path'; import { CliError } from '../cli-error'; @@ -18,6 +19,8 @@ type Options = { dependencyCheck: boolean; versionCheck: boolean; compile: boolean; + withPlugins?: string[]; + withoutPlugins?: string[]; defaultPlugins: boolean; }; @@ -57,9 +60,19 @@ async function runPlugins(options: Options) { const model = await loadDocument(schema); + for (const name of [...(options.withPlugins ?? []), ...(options.withoutPlugins ?? [])]) { + const pluginDecl = model.declarations.find((d) => isPlugin(d) && d.name === name); + if (!pluginDecl) { + console.error(colors.red(`Plugin "${name}" not found in schema.`)); + throw new CliError(`Plugin "${name}" not found in schema.`); + } + } + const runnerOpts: PluginRunnerOptions = { schema: model, schemaPath: path.resolve(schema), + withPlugins: options.withPlugins, + withoutPlugins: options.withoutPlugins, defaultPlugins: options.defaultPlugins, output: options.output, compile: options.compile, diff --git a/packages/schema/src/cli/index.ts b/packages/schema/src/cli/index.ts index 600d14b47..20207772a 100644 --- a/packages/schema/src/cli/index.ts +++ b/packages/schema/src/cli/index.ts @@ -120,6 +120,8 @@ export function createProgram() { .description('Run code generation.') .addOption(schemaOption) .addOption(new Option('-o, --output ', 'default output directory for core plugins')) + .addOption(new Option('--with-plugins ', 'only run specific plugins')) + .addOption(new Option('--without-plugins ', 'exclude specific plugins')) .addOption(new Option('--no-default-plugins', 'do not run default plugins')) .addOption(new Option('--no-compile', 'do not compile the output of core plugins')) .addOption(noVersionCheckOption) diff --git a/packages/schema/src/cli/plugin-runner.ts b/packages/schema/src/cli/plugin-runner.ts index ea00809fe..2912bfb60 100644 --- a/packages/schema/src/cli/plugin-runner.ts +++ b/packages/schema/src/cli/plugin-runner.ts @@ -40,6 +40,8 @@ export type PluginRunnerOptions = { schema: Model; schemaPath: string; output?: string; + withPlugins?: string[]; + withoutPlugins?: string[]; defaultPlugins: boolean; compile: boolean; }; @@ -137,7 +139,17 @@ export class PluginRunner { const project = createProject(); for (const { name, description, run, options: pluginOptions } of corePlugins) { const options = { ...pluginOptions, prismaClientPath }; - const r = await this.runPlugin(name, description, run, runnerOptions, options, dmmf, shortNameMap, project); + const r = await this.runPlugin( + name, + description, + run, + runnerOptions, + options, + dmmf, + shortNameMap, + project, + true + ); warnings.push(...(r?.warnings ?? [])); // the null-check is for backward compatibility if (r.dmmf) { @@ -162,7 +174,17 @@ export class PluginRunner { // run user plugins for (const { name, description, run, options: pluginOptions } of userPlugins) { const options = { ...pluginOptions, prismaClientPath }; - const r = await this.runPlugin(name, description, run, runnerOptions, options, dmmf, shortNameMap, project); + const r = await this.runPlugin( + name, + description, + run, + runnerOptions, + options, + dmmf, + shortNameMap, + project, + false + ); warnings.push(...(r?.warnings ?? [])); // the null-check is for backward compatibility } @@ -180,8 +202,7 @@ export class PluginRunner { if (existingPrisma) { corePlugins.push(existingPrisma); plugins.splice(plugins.indexOf(existingPrisma), 1); - } else if (options.defaultPlugins || plugins.some((p) => p.provider !== CorePlugins.Prisma)) { - // "@core/prisma" is enabled as default or if any other plugin is configured + } else if (options.defaultPlugins) { corePlugins.push(this.makeCorePlugin(CorePlugins.Prisma, options.schemaPath, {})); } @@ -215,7 +236,8 @@ export class PluginRunner { if ( !corePlugins.some((p) => p.provider === CorePlugins.Zod) && - (options.defaultPlugins || corePlugins.some((p) => p.provider === CorePlugins.Enhancer)) && + options.defaultPlugins && + corePlugins.some((p) => p.provider === CorePlugins.Enhancer) && hasValidation ) { // ensure "@core/zod" is enabled if "@core/enhancer" is enabled and there're validation rules @@ -319,10 +341,17 @@ export class PluginRunner { options: PluginDeclaredOptions, dmmf: DMMF.Document | undefined, shortNameMap: Map | undefined, - project: Project + project: Project, + isCorePlugin: boolean ) { + if (!isCorePlugin && !this.isPluginEnabled(name, runnerOptions)) { + ora(`Plugin "${name}" is skipped`).start().warn(); + return { warnings: [] }; + } + const title = description ?? `Running plugin ${colors.cyan(name)}`; const spinner = ora(title).start(); + try { const r = await telemetry.trackSpan( 'cli:plugin:start', @@ -358,6 +387,18 @@ export class PluginRunner { } } + private isPluginEnabled(name: string, runnerOptions: PluginRunnerOptions) { + if (runnerOptions.withPlugins && !runnerOptions.withPlugins.includes(name)) { + return false; + } + + if (runnerOptions.withoutPlugins && runnerOptions.withoutPlugins.includes(name)) { + return false; + } + + return true; + } + private getPluginModulePath(provider: string, schemaPath: string) { if (process.env.ZENSTACK_TEST === '1' && provider.startsWith('@zenstackhq/')) { // test code runs with its own sandbox of node_modules, make sure we don't diff --git a/tests/integration/tests/cli/generate.test.ts b/tests/integration/tests/cli/generate.test.ts index d90ce14cc..c4aed8e51 100644 --- a/tests/integration/tests/cli/generate.test.ts +++ b/tests/integration/tests/cli/generate.test.ts @@ -112,42 +112,28 @@ model Post { expect(fs.existsSync('./prisma/schema.prisma')).toBeTruthy(); }); - it('generate no default plugins with access-policy with zod', async () => { + it('generate no default plugins with enhancer and zod', async () => { fs.appendFileSync( 'schema.zmodel', ` - plugin enhancer { - provider = '@core/enhancer' + plugin prisma { + provider = '@core/prisma' + } + + plugin zod { + provider = '@core/zod' } - ` - ); - const program = createProgram(); - await program.parseAsync(['generate', '--no-dependency-check', '--no-default-plugins'], { from: 'user' }); - expect(fs.existsSync('./node_modules/.zenstack/policy.js')).toBeTruthy(); - expect(fs.existsSync('./node_modules/.zenstack/model-meta.js')).toBeTruthy(); - expect(fs.existsSync('./prisma/schema.prisma')).toBeTruthy(); - }); - it('generate no default plugins with access-policy without zod', async () => { - fs.appendFileSync( - 'schema.zmodel', - ` plugin enhancer { provider = '@core/enhancer' } ` ); - let content = fs.readFileSync('schema.zmodel', 'utf-8'); - content = content.replace('@email', ''); - fs.writeFileSync('schema.zmodel', content, 'utf-8'); - const program = createProgram(); await program.parseAsync(['generate', '--no-dependency-check', '--no-default-plugins'], { from: 'user' }); expect(fs.existsSync('./node_modules/.zenstack/policy.js')).toBeTruthy(); expect(fs.existsSync('./node_modules/.zenstack/model-meta.js')).toBeTruthy(); expect(fs.existsSync('./prisma/schema.prisma')).toBeTruthy(); - const z = require(path.join(process.cwd(), './node_modules/.zenstack/zod/models')); - expect(z).toEqual({}); }); it('generate no compile', async () => {