Skip to content

Commit fb0aa0f

Browse files
authored
fix(zod): avoid importing Prisma enum, recognize enum fields with default (#2307)
1 parent dfa6402 commit fb0aa0f

File tree

4 files changed

+35
-12
lines changed

4 files changed

+35
-12
lines changed

packages/schema/src/plugins/zod/generator.ts

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ export class ZodSchemaGenerator {
340340
});
341341
this.sourceFiles.push(sf);
342342
sf.replaceWithText((writer) => {
343-
this.addPreludeAndImports(typeDef, writer, output);
343+
this.addPreludeAndImports(typeDef, writer);
344344

345345
writer.write(`const baseSchema = z.object(`);
346346
writer.inlineBlock(() => {
@@ -383,7 +383,7 @@ export const ${typeDef.name}Schema = ${refineFuncName}(${noRefineSchema});
383383
return schemaName;
384384
}
385385

386-
private addPreludeAndImports(decl: DataModel | TypeDef, writer: CodeBlockWriter, output: string) {
386+
private addPreludeAndImports(decl: DataModel | TypeDef, writer: CodeBlockWriter) {
387387
writer.writeLine(`import { z } from 'zod/${this.zodVersion}';`);
388388

389389
// import user-defined enums from Prisma as they might be referenced in the expressions
@@ -396,10 +396,6 @@ export const ${typeDef.name}Schema = ${refineFuncName}(${noRefineSchema});
396396
}
397397
}
398398
}
399-
if (importEnums.size > 0) {
400-
const prismaImport = computePrismaClientImport(path.join(output, 'models'), this.options);
401-
writer.writeLine(`import { ${[...importEnums].join(', ')} } from '${prismaImport}';`);
402-
}
403399

404400
// import enum schemas
405401
const importedEnumSchemas = new Set<string>();
@@ -448,7 +444,7 @@ export const ${typeDef.name}Schema = ${refineFuncName}(${noRefineSchema});
448444
const relations = model.fields.filter((field) => isDataModel(field.type.reference?.ref));
449445
const fkFields = model.fields.filter((field) => isForeignKeyField(field));
450446

451-
this.addPreludeAndImports(model, writer, output);
447+
this.addPreludeAndImports(model, writer);
452448

453449
// base schema - including all scalar fields, with optionality following the schema
454450
this.createModelBaseSchema('baseSchema', writer, scalarFields, true);
@@ -730,9 +726,7 @@ export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};
730726
/**
731727
* Schema refinement function for applying \`@@validate\` rules.
732728
*/
733-
export function ${refineFuncName}<T>(schema: z.ZodType<T>) { return schema${refinements.join(
734-
'\n'
735-
)};
729+
export function ${refineFuncName}<T>(schema: z.ZodType<T>) { return schema${refinements.join('\n')};
736730
}
737731
`
738732
);
@@ -766,6 +760,7 @@ export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};
766760
let expr = new TypeScriptExpressionTransformer({
767761
context: ExpressionContext.ValidationRule,
768762
fieldReferenceContext: 'value',
763+
useLiteralEnum: true,
769764
}).transform(valueArg);
770765

771766
if (isDataModelFieldReference(valueArg)) {

packages/schema/src/plugins/zod/utils/schema-gen.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { upperCaseFirst } from '@zenstackhq/runtime/local-helpers';
2-
import { getLiteral, hasAttribute, isFromStdlib } from '@zenstackhq/sdk';
2+
import { getLiteral, hasAttribute, isEnumFieldReference, isFromStdlib } from '@zenstackhq/sdk';
33
import {
44
DataModelField,
55
DataModelFieldAttribute,
@@ -246,6 +246,8 @@ export function getFieldSchemaDefault(field: DataModelField | TypeDefField) {
246246
return arg.value.value;
247247
} else if (isBooleanLiteral(arg.value)) {
248248
return arg.value.value;
249+
} else if (isEnumFieldReference(arg.value) && arg.value.target.ref) {
250+
return JSON.stringify(arg.value.target.ref.name);
249251
} else if (
250252
isInvocationExpr(arg.value) &&
251253
isFromStdlib(arg.value.function.ref!) &&

packages/sdk/src/typescript-expression-transformer.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ type Options = {
4040
futureRefContext?: string;
4141
context: ExpressionContext;
4242
operationContext?: 'read' | 'create' | 'update' | 'postUpdate' | 'delete';
43+
useLiteralEnum?: boolean;
4344
};
4445

4546
type Casing = 'original' | 'upper' | 'lower' | 'capitalize' | 'uncapitalize';
@@ -392,7 +393,9 @@ export class TypeScriptExpressionTransformer {
392393
}
393394

394395
if (isEnumField(expr.target.ref)) {
395-
return `${expr.target.ref.$container.name}.${expr.target.ref.name}`;
396+
return this.options.useLiteralEnum
397+
? JSON.stringify(expr.target.ref.name)
398+
: `${expr.target.ref.$container.name}.${expr.target.ref.name}`;
396399
} else {
397400
if (this.options?.isPostGuard) {
398401
// if we're processing post-update, any direct field access should be
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import { loadSchema } from '@zenstackhq/testtools';
2+
3+
describe('Issue 2291', () => {
4+
it('should work', async () => {
5+
const { zodSchemas } = await loadSchema(
6+
`
7+
enum SomeEnum {
8+
Ex1
9+
Ex2
10+
}
11+
12+
/// Post model
13+
model Post {
14+
id String @id @default(cuid())
15+
e SomeEnum @default(Ex1)
16+
}
17+
`,
18+
{ fullZod: true }
19+
);
20+
21+
expect(zodSchemas.models.PostSchema.parse({ id: '1' })).toEqual({ id: '1', e: 'Ex1' });
22+
});
23+
});

0 commit comments

Comments
 (0)