From 9fd06535d591835ed71223c00eb1285ce75dcab6 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 14 Jul 2024 12:05:42 -0700 Subject: [PATCH] fix(cli): generated TS typing for `auth()` access is too strong --- .../enhancer/enhance/auth-type-generator.ts | 61 ++++++------------- .../enhancements/with-policy/auth.test.ts | 26 ++++++++ 2 files changed, 45 insertions(+), 42 deletions(-) diff --git a/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts b/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts index bf92d1c9d..a4e09fbb2 100644 --- a/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts +++ b/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts @@ -1,4 +1,4 @@ -import { getIdFields, hasAttribute, isAuthInvocation, isDataModelFieldReference } from '@zenstackhq/sdk'; +import { getIdFields, isAuthInvocation, isDataModelFieldReference } from '@zenstackhq/sdk'; import { DataModel, DataModelField, @@ -18,41 +18,27 @@ export function generateAuthType(model: Model, authModel: DataModel) { const types = new Map< string, { - // scalar fields to directly pick from Prisma-generated type - pickFields: string[]; - - // relation fields to include - addFields: { name: string; type: string }[]; + // relation fields to require + requiredRelations: { name: string; type: string }[]; } >(); - types.set(authModel.name, { pickFields: getIdFields(authModel).map((f) => f.name), addFields: [] }); + types.set(authModel.name, { requiredRelations: [] }); const ensureType = (model: string) => { if (!types.has(model)) { - types.set(model, { pickFields: [], addFields: [] }); - } - }; - - const addPickField = (model: string, field: string) => { - let fields = types.get(model); - if (!fields) { - fields = { pickFields: [], addFields: [] }; - types.set(model, fields); - } - if (!fields.pickFields.includes(field)) { - fields.pickFields.push(field); + types.set(model, { requiredRelations: [] }); } }; const addAddField = (model: string, name: string, type: string, array: boolean) => { let fields = types.get(model); if (!fields) { - fields = { pickFields: [], addFields: [] }; + fields = { requiredRelations: [] }; types.set(model, fields); } - if (!fields.addFields.find((f) => f.name === name)) { - fields.addFields.push({ name, type: array ? `${type}[]` : type }); + if (!fields.requiredRelations.find((f) => f.name === name)) { + fields.requiredRelations.push({ name, type: array ? `${type}[]` : type }); } }; @@ -71,11 +57,6 @@ export function generateAuthType(model: Model, authModel: DataModel) { const fieldType = memberDecl.type.reference.ref.name; ensureType(fieldType); addAddField(exprType.name, memberDecl.name, fieldType, memberDecl.type.array); - } else { - // member is a scalar - if (!isIgnoredField(node.member.ref)) { - addPickField(exprType.name, node.member.$refText); - } } } } @@ -88,11 +69,6 @@ export function generateAuthType(model: Model, authModel: DataModel) { // field is a relation ensureType(fieldType.name); addAddField(fieldDecl.$container.name, node.target.$refText, fieldType.name, fieldDecl.type.array); - } else { - if (!isIgnoredField(fieldDecl)) { - // field is a scalar - addPickField(fieldDecl.$container.name, node.target.$refText); - } } } }); @@ -112,16 +88,21 @@ ${Array.from(types.entries()) .map(([model, fields]) => { let result = `Partial<_P.${model}>`; - if (fields.pickFields.length > 0) { - result = `WithRequired<${result}, ${fields.pickFields - .map((f) => `'${f}'`) - .join('|')}> & Record`; + if (model === authModel.name) { + // auth model's id fields are always required + const idFields = getIdFields(authModel).map((f) => f.name); + if (idFields.length > 0) { + result = `WithRequired<${result}, ${idFields.map((f) => `'${f}'`).join('|')}>`; + } } - if (fields.addFields.length > 0) { - result = `${result} & { ${fields.addFields.map(({ name, type }) => `${name}: ${type}`).join('; ')} }`; + if (fields.requiredRelations.length > 0) { + // merge required relation fields + result = `${result} & { ${fields.requiredRelations.map((f) => `${f.name}: ${f.type}`).join('; ')} }`; } + result = `${result} & Record`; + return ` export type ${model} = ${result};`; }) .join('\n')} @@ -145,7 +126,3 @@ function isAuthAccess(node: AstNode): node is Expression { return false; } - -function isIgnoredField(field: DataModelField | undefined) { - return !!(field && hasAttribute(field, '@ignore')); -} diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index ab2e55d73..f397fa804 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -809,4 +809,30 @@ describe('auth() compile-time test', () => { } ); }); + + it('optional field stays optional', async () => { + await loadSchema( + ` + model User { + id Int @id + age Int? + + @@allow('all', auth().age > 0) + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { enhance } from ".zenstack/enhance"; + import { PrismaClient } from '@prisma/client'; + enhance(new PrismaClient(), { user: { id: 1 } }); + `, + }, + ], + } + ); + }); });