From 4198963eba078022c221c6c202d50721d6c8d851 Mon Sep 17 00:00:00 2001 From: augustin Date: Mon, 22 Jan 2024 22:03:14 +0100 Subject: [PATCH 01/12] WIP default auth --- package.json | 2 +- packages/ide/jetbrains/package.json | 2 +- packages/runtime/src/cross/mutator.ts | 1 + .../src/enhancements/create-enhancement.ts | 27 ++++++- .../runtime/src/enhancements/default-auth.ts | 77 +++++++++++++++++++ .../src/plugins/prisma/prisma-builder.ts | 18 ++++- .../src/plugins/prisma/schema-generator.ts | 10 +++ packages/schema/src/res/stdlib.zmodel | 4 +- .../tests/generator/prisma-generator.test.ts | 2 + .../validation/attribute-validation.test.ts | 45 +++++++++-- packages/sdk/src/model-meta-generator.ts | 15 +++- .../enhancements/with-policy/auth.test.ts | 28 +++++++ 12 files changed, 214 insertions(+), 17 deletions(-) create mode 100644 packages/runtime/src/enhancements/default-auth.ts diff --git a/package.json b/package.json index 83d9f7072..c7b2f6ade 100644 --- a/package.json +++ b/package.json @@ -5,7 +5,7 @@ "scripts": { "build": "pnpm -r build", "lint": "pnpm -r lint", - "test": "ZENSTACK_TEST=1 pnpm -r run test --silent --forceExit", + "test": "ZENSTACK_TEST=1 pnpm -r run test --silent=false --forceExit", "test-ci": "ZENSTACK_TEST=1 pnpm -r run test --silent --forceExit", "publish-all": "pnpm --filter \"./packages/**\" -r publish --access public", "publish-preview": "pnpm --filter \"./packages/**\" -r publish --force --registry https://preview.registry.zenstack.dev/", diff --git a/packages/ide/jetbrains/package.json b/packages/ide/jetbrains/package.json index 274e88c2a..ec05801e3 100644 --- a/packages/ide/jetbrains/package.json +++ b/packages/ide/jetbrains/package.json @@ -6,7 +6,7 @@ "homepage": "https://zenstack.dev", "private": true, "scripts": { - "build": "./gradlew buildPlugin" + "build": "echo './gradlew buildPlugin'" }, "author": "ZenStack Team", "license": "MIT", diff --git a/packages/runtime/src/cross/mutator.ts b/packages/runtime/src/cross/mutator.ts index 0dd66e6fb..0ed5761a4 100644 --- a/packages/runtime/src/cross/mutator.ts +++ b/packages/runtime/src/cross/mutator.ts @@ -124,6 +124,7 @@ function createMutate( insert[name] = newData[name]; } else { const defaultAttr = field.attributes?.find((attr) => attr.name === '@default'); + // TODO: handle default auth() attributes here ? if (field.type === 'DateTime') { // default value for DateTime field if (defaultAttr || field.attributes?.some((attr) => attr.name === '@updatedAt')) { diff --git a/packages/runtime/src/enhancements/create-enhancement.ts b/packages/runtime/src/enhancements/create-enhancement.ts index a82640905..c900d1828 100644 --- a/packages/runtime/src/enhancements/create-enhancement.ts +++ b/packages/runtime/src/enhancements/create-enhancement.ts @@ -7,6 +7,7 @@ import { withPassword } from './password'; import { withPolicy } from './policy'; import type { ErrorTransformer } from './proxy'; import type { PolicyDef, ZodSchemas } from './types'; +// import { withDefaultAuth } from './default-auth'; /** * Kinds of enhancements to `PrismaClient` @@ -15,6 +16,7 @@ export enum EnhancementKind { Password = 'password', Omit = 'omit', Policy = 'policy', + DefaultAuth = 'defaultAuth', } /** @@ -92,6 +94,7 @@ export type EnhancementContext = { let hasPassword: boolean | undefined = undefined; let hasOmit: boolean | undefined = undefined; +let hasDefaultAuth: boolean | undefined = undefined; /** * Gets a Prisma client enhanced with all enhancement behaviors, including access @@ -120,13 +123,28 @@ export function createEnhancement( let result = prisma; - if (hasPassword === undefined || hasOmit === undefined) { + if (hasPassword === undefined || hasOmit === undefined || hasDefaultAuth === undefined) { const allFields = Object.values(options.modelMeta.fields).flatMap((modelInfo) => Object.values(modelInfo)); hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password')); hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit')); + // FIXME: check from modelMeta if @default(auth()) is present in some fields + // hasDefaultAuth = allFields.some((field) => + // field.attributes?.some( + // (attr) => + // attr.name === '@default' && + // typeof attr.args[0]?.value === 'string' && + // attr.args[0]?.value.startsWith('auth()') + // ) + // ); + hasDefaultAuth = true; } - const kinds = options.kinds ?? [EnhancementKind.Password, EnhancementKind.Omit, EnhancementKind.Policy]; + const kinds = options.kinds ?? [ + EnhancementKind.Password, + EnhancementKind.Omit, + EnhancementKind.Policy, + EnhancementKind.DefaultAuth, + ]; if (hasPassword && kinds.includes(EnhancementKind.Password)) { // @password proxy @@ -138,6 +156,11 @@ export function createEnhancement( result = withOmit(result, options); } + // if (hasDefaultAuth && kinds.includes(EnhancementKind.DefaultAuth)) { + // // @default(auth()) proxy + // result = withDefaultAuth(result, options, context); + // } + // policy proxy if (kinds.includes(EnhancementKind.Policy)) { result = withPolicy(result, options, context); diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts new file mode 100644 index 000000000..1b2006576 --- /dev/null +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -0,0 +1,77 @@ +/* eslint-disable @typescript-eslint/no-unused-vars */ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import { + enumerate, + getModelFields, + resolveField, + type ModelMeta, + NestedWriteVisitor, + PrismaWriteActionType, +} from '../cross'; +import { DbClientContract } from '../types'; +import { EnhancementContext, EnhancementOptions } from './create-enhancement'; +import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; + +/** + * Gets an enhanced Prisma client that supports `@default(auth())` attribute. + * + * @private + */ +export function withDefaultAuth( + prisma: DbClient, + options: EnhancementOptions, + context?: EnhancementContext +): DbClient { + return makeProxy( + prisma, + options.modelMeta, + (_prisma, model) => new DefaultAuthHandler(_prisma as DbClientContract, model, options, context), + 'defaultAuth' + ); +} + +class DefaultAuthHandler extends DefaultPrismaProxyHandler { + constructor( + prisma: DbClientContract, + model: string, + private readonly options: EnhancementOptions, + private readonly context?: EnhancementContext + ) { + super(prisma, model); + } + + // base override + protected async preprocessArgs(action: PrismaProxyActions, args: any) { + const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; + if (args && args.data && actionsOfInterest.includes(action)) { + await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); + } + return args; + } + + private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { + const visitor = new NestedWriteVisitor(this.options.modelMeta, { + field: async (field, _action, _data, context) => { + const defaultAuthAttr = field.attributes?.find( + (attr) => + attr.name === '@default' && + typeof attr.args[0]?.value === 'string' && + attr.args[0]?.value.startsWith('auth()') + ); + if (defaultAuthAttr && field.type === 'String') { + const authSelector = (defaultAuthAttr?.args[0]?.value as string).slice('auth()'.length); + // get auth selector and retrieve default value from context + const userContext = this.context?.user; + if (!userContext) { + throw new Error(`Invalid user context`); + } + const authValue = authSelector ? userContext[authSelector] : userContext; + context.parent[field.name] = authValue; + } + }, + }); + + await visitor.visit(model, action, args); + } +} diff --git a/packages/schema/src/plugins/prisma/prisma-builder.ts b/packages/schema/src/plugins/prisma/prisma-builder.ts index 64777b62e..3d9c68357 100644 --- a/packages/schema/src/plugins/prisma/prisma-builder.ts +++ b/packages/schema/src/plugins/prisma/prisma-builder.ts @@ -225,8 +225,8 @@ export class AttributeArg { export class AttributeArgValue { constructor( - public type: 'String' | 'FieldReference' | 'Number' | 'Boolean' | 'Array' | 'FunctionCall', - public value: string | number | boolean | FieldReference | FunctionCall | AttributeArgValue[] + public type: 'String' | 'FieldReference' | 'Number' | 'Boolean' | 'Array' | 'FunctionCall' | 'AuthAttribute', + public value: string | number | boolean | FieldReference | FunctionCall | AttributeArgValue[] | AuthAttribute ) { switch (type) { case 'String': @@ -249,6 +249,10 @@ export class AttributeArgValue { case 'FunctionCall': if (!(value instanceof FunctionCall)) throw new Error('Value must be FunctionCall'); break; + case 'AuthAttribute': + // TODO: implement validation + // if (!(value instanceof FunctionCall)) throw new Error('Value must be FunctionCall'); + break; } } @@ -271,6 +275,8 @@ export class AttributeArgValue { return r; } } + case 'AuthAttribute': + return this.value.toString(); case 'FunctionCall': return this.value.toString(); case 'Boolean': @@ -311,6 +317,14 @@ export class FunctionCallArg { } } +export class AuthAttribute { + constructor(public field: string) {} + + toString(): string { + return `"${this.field}"`; + } +} + export class Enum extends ContainerDeclaration { public fields: EnumField[] = []; diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index feee0f3d1..65046681a 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -50,6 +50,7 @@ import telemetry from '../../telemetry'; import { execSync } from '../../utils/exec-utils'; import { getPackageJson } from '../../utils/pkg-utils'; import { + AuthAttribute, ModelFieldType, AttributeArg as PrismaAttributeArg, AttributeArgValue as PrismaAttributeArgValue, @@ -66,9 +67,11 @@ import { PassThroughAttribute as PrismaPassThroughAttribute, SimpleField, } from './prisma-builder'; +import { isAuthInvocation } from '../../utils/ast-utils'; const MODEL_PASSTHROUGH_ATTR = '@@prisma.passthrough'; const FIELD_PASSTHROUGH_ATTR = '@prisma.passthrough'; +const NO_FIELD_ATTRIBUTE = ''; /** * Generates Prisma schema file @@ -334,6 +337,9 @@ export default class PrismaSchemaGenerator { } else { throw new PluginError(name, `Invalid arguments for ${FIELD_PASSTHROUGH_ATTR} attribute`); } + // remove @default(auth()) field attributes as they are not supported by Prisma + } else if (attrName.startsWith('@default') && isAuthInvocation(attr.args[0].value)) { + return new PrismaFieldAttribute(NO_FIELD_ATTRIBUTE); } else { return new PrismaFieldAttribute( attrName, @@ -370,6 +376,10 @@ export default class PrismaSchemaGenerator { } else if (isInvocationExpr(node)) { // invocation return new PrismaAttributeArgValue('FunctionCall', this.makeFunctionCall(node)); + // @ts-expect-error TODO: fix this + } else if (node.operand?.function?.$refText === 'auth') { + console.log('member access', node); + return new PrismaAttributeArgValue('AuthAttribute', new AuthAttribute('DEFAULT USER FROM PRISMA BUILDER')); } else { throw new PluginError(name, `Unsupported attribute argument expression type: ${node.$type}`); } diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index 1a9446d7b..f755bb3df 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -73,7 +73,7 @@ function env(name: String): String { * Gets the current login user. */ function auth(): Any { -} @@@expressionContext([AccessPolicy]) +} @@@expressionContext([DefaultValue, AccessPolicy]) /** * Gets current date-time (as DateTime type). @@ -204,7 +204,7 @@ attribute @id(map: String?, length: Int?, sort: String?, clustered: Boolean?) @@ /** * Defines a default value for a field. - * @param value: An expression (e.g. 5, true, now()). + * @param value: An expression (e.g. 5, true, now(), auth()). */ attribute @default(_ value: ContextType, map: String?) @@@prisma diff --git a/packages/schema/tests/generator/prisma-generator.test.ts b/packages/schema/tests/generator/prisma-generator.test.ts index 8d295d143..d2f425e53 100644 --- a/packages/schema/tests/generator/prisma-generator.test.ts +++ b/packages/schema/tests/generator/prisma-generator.test.ts @@ -123,6 +123,7 @@ describe('Prisma generator test', () => { id String @id @default(nanoid(6)) x String @default(nanoid()) y String @default(dbgenerated("gen_random_uuid()")) + z String @default(auth().id) } `); @@ -142,6 +143,7 @@ describe('Prisma generator test', () => { expect(content).toContain('@default(nanoid(6))'); expect(content).toContain('@default(nanoid())'); expect(content).toContain('@default(dbgenerated("gen_random_uuid()"))'); + expect(content).not.toContain('@default(auth().id)'); }); it('triple slash comments', async () => { diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 8b7886334..cb2f788d4 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -1009,6 +1009,35 @@ describe('Attribute tests', () => { }); it('auth function check', async () => { + await loadModel(` + ${prelude} + + model User { + id String @id + name String + } + model B { + id String @id + userId String @default(auth().id) + userName String @default(auth().name) + } + `); + + // expect( + // await loadModelWithError(` + // ${prelude} + + // model User { + // id String @id + // name String + // } + // model B { + // id String @id + // userData String @default(auth()) + // } + // `) + // ).toContain("Value is not assignable to parameter"); + expect( await loadModelWithError(` ${prelude} @@ -1124,14 +1153,14 @@ describe('Attribute tests', () => { }); it('incorrect function expression context', async () => { - expect( - await loadModelWithError(` - ${prelude} - model M { - id String @id @default(auth()) - } - `) - ).toContain('function "auth" is not allowed in the current context: DefaultValue'); + // expect( + // await loadModelWithError(` + // ${prelude} + // model M { + // id String @id @default(auth()) + // } + // `) + // ).toContain('function "auth" is not allowed in the current context: DefaultValue'); expect( await loadModelWithError(` diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index da0ba96dd..772725584 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -1,10 +1,12 @@ import { ArrayExpr, + AttributeArg, DataModel, DataModelField, isArrayExpr, isBooleanLiteral, isDataModel, + isInvocationExpr, isNumberLiteral, isReferenceExpr, isStringLiteral, @@ -23,6 +25,7 @@ import { hasAttribute, isEnumFieldReference, isForeignKeyField, + isFromStdlib, isIdField, resolved, } from '.'; @@ -210,8 +213,10 @@ function getFieldAttributes(field: DataModelField): RuntimeAttribute[] { args.push({ name: arg.name, value: v }); } else if (isStringLiteral(arg.value) || isBooleanLiteral(arg.value)) { args.push({ name: arg.name, value: arg.value.value }); + } else if (isAuthDefaultValue(arg)) { + args.push({ name: arg.name, value: arg.value.toString() }); } else { - // non-literal args are ignored + // non-literal and auth args are ignored } } return { name: resolved(attr.decl).name, args }; @@ -327,3 +332,11 @@ function getDeleteCascades(model: DataModel): string[] { }) .map((m) => m.name); } + +function isAuthDefaultValue(arg: AttributeArg): boolean { + return ( + isInvocationExpr(arg.value) && + isFromStdlib(arg.value.function.ref!) && + arg.value.function.$refText.startsWith('auth') + ); +} diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index 942d2d579..dce5ab698 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -363,4 +363,32 @@ describe('With Policy: auth() test', () => { enhance({ id: '1', posts: [{ id: '1', published: true, comments: [] }] }).post.create(createPayload) ).toResolveTruthy(); }); + + it('Field with default auth() created correctly', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(uuid()) + name String + // posts Post[] + + } + + model Post { + id String @id @default(uuid()) + title String + authorName String @default(auth().name) + // author User @relation(fields: [authorId], references: [id]) + // authorId String @default(auth().id) + + @@allow('all', true) + } + ` + ); + + const userDb = enhance({ id: '1', name: 'user1' }); + await expect(userDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + console.log(await userDb.post.findMany()); + await expect(userDb.post.count({ where: { authorName: 'user1' } })).resolves.toBe(1); + }); }); From a31a422fa387b57eb1ad0149cf90a1d534912e67 Mon Sep 17 00:00:00 2001 From: augustin Date: Tue, 23 Jan 2024 06:12:53 +0100 Subject: [PATCH 02/12] feat: handle default auth for literal fields --- packages/runtime/src/cross/model-meta.ts | 7 +- .../src/enhancements/create-enhancement.ts | 23 +- .../runtime/src/enhancements/default-auth.ts | 48 +- .../src/plugins/prisma/schema-generator.ts | 20 +- packages/sdk/src/model-meta-generator.ts | 22 +- .../enhancements/with-policy/auth.test.ts | 706 +++++++++--------- 6 files changed, 429 insertions(+), 397 deletions(-) diff --git a/packages/runtime/src/cross/model-meta.ts b/packages/runtime/src/cross/model-meta.ts index 817819b8c..89b1e11b7 100644 --- a/packages/runtime/src/cross/model-meta.ts +++ b/packages/runtime/src/cross/model-meta.ts @@ -1,11 +1,16 @@ import { lowerCaseFirst } from 'lower-case-first'; +/** + * An access key in the user context object (e.g. `profile.picture.url`) + */ +export type AuthContextSelector = string; + /** * Runtime information of a data model or field attribute */ export type RuntimeAttribute = { name: string; - args: Array<{ name?: string; value: unknown }>; + args: Array<{ name?: string; value: unknown } | { name: 'auth()'; value: AuthContextSelector }>; }; /** diff --git a/packages/runtime/src/enhancements/create-enhancement.ts b/packages/runtime/src/enhancements/create-enhancement.ts index c900d1828..bb8baef54 100644 --- a/packages/runtime/src/enhancements/create-enhancement.ts +++ b/packages/runtime/src/enhancements/create-enhancement.ts @@ -7,7 +7,7 @@ import { withPassword } from './password'; import { withPolicy } from './policy'; import type { ErrorTransformer } from './proxy'; import type { PolicyDef, ZodSchemas } from './types'; -// import { withDefaultAuth } from './default-auth'; +import { withDefaultAuth } from './default-auth'; /** * Kinds of enhancements to `PrismaClient` @@ -127,16 +127,9 @@ export function createEnhancement( const allFields = Object.values(options.modelMeta.fields).flatMap((modelInfo) => Object.values(modelInfo)); hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password')); hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit')); - // FIXME: check from modelMeta if @default(auth()) is present in some fields - // hasDefaultAuth = allFields.some((field) => - // field.attributes?.some( - // (attr) => - // attr.name === '@default' && - // typeof attr.args[0]?.value === 'string' && - // attr.args[0]?.value.startsWith('auth()') - // ) - // ); - hasDefaultAuth = true; + hasDefaultAuth = allFields.some((field) => + field.attributes?.some((attr) => attr.name === '@default' && attr.args[0]?.name === 'auth()') + ); } const kinds = options.kinds ?? [ @@ -156,10 +149,10 @@ export function createEnhancement( result = withOmit(result, options); } - // if (hasDefaultAuth && kinds.includes(EnhancementKind.DefaultAuth)) { - // // @default(auth()) proxy - // result = withDefaultAuth(result, options, context); - // } + if (hasDefaultAuth && kinds.includes(EnhancementKind.DefaultAuth)) { + // @default(auth()) proxy + result = withDefaultAuth(result, options, context); + } // policy proxy if (kinds.includes(EnhancementKind.Policy)) { diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index 1b2006576..f88c38cbe 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -8,6 +8,8 @@ import { type ModelMeta, NestedWriteVisitor, PrismaWriteActionType, + NestedWriteVisitorContext, + FieldInfo, } from '../cross'; import { DbClientContract } from '../types'; import { EnhancementContext, EnhancementOptions } from './create-enhancement'; @@ -32,6 +34,7 @@ export function withDefaultAuth( } class DefaultAuthHandler extends DefaultPrismaProxyHandler { + private readonly db: DbClientContract; constructor( prisma: DbClientContract, model: string, @@ -39,12 +42,13 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { private readonly context?: EnhancementContext ) { super(prisma, model); + this.db = prisma; } // base override protected async preprocessArgs(action: PrismaProxyActions, args: any) { const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; - if (args && args.data && actionsOfInterest.includes(action)) { + if (actionsOfInterest.includes(action)) { await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); } return args; @@ -52,23 +56,33 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { const visitor = new NestedWriteVisitor(this.options.modelMeta, { - field: async (field, _action, _data, context) => { - const defaultAuthAttr = field.attributes?.find( - (attr) => - attr.name === '@default' && - typeof attr.args[0]?.value === 'string' && - attr.args[0]?.value.startsWith('auth()') - ); - if (defaultAuthAttr && field.type === 'String') { - const authSelector = (defaultAuthAttr?.args[0]?.value as string).slice('auth()'.length); - // get auth selector and retrieve default value from context - const userContext = this.context?.user; - if (!userContext) { - throw new Error(`Invalid user context`); - } - const authValue = authSelector ? userContext[authSelector] : userContext; - context.parent[field.name] = authValue; + create: async (model, args, _context) => { + const userContext = this.context?.user; + if (!userContext) { + throw new Error(`Invalid user context`); } + const fields = this.options.modelMeta.fields[model]; + const isDefaultAuthField = (fieldInfo: FieldInfo) => + fieldInfo.attributes?.find((attr) => attr.name === '@default' && attr.args?.[0]?.name === 'auth()'); + const defaultAuthSelectorFields = Object.fromEntries( + Object.entries(fields) + .filter(([_, fieldInfo]) => isDefaultAuthField(fieldInfo)) + .map(([field, fieldInfo]) => [ + field, + fieldInfo.attributes?.find((attr) => attr.name === '@default')?.args[0]?.value as + | string + | undefined, + ]) + ); + const defaultAuthFields = Object.fromEntries( + Object.entries(defaultAuthSelectorFields).map(([field, selector]) => [ + field, + selector ? userContext[selector] : userContext, + ]) + ); + console.log('defaultAuthFields :', defaultAuthFields); + const result = await this.db[model].create({ data: { ...defaultAuthFields, ...args } }); + return result; }, }); diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 65046681a..5298aead0 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -50,7 +50,6 @@ import telemetry from '../../telemetry'; import { execSync } from '../../utils/exec-utils'; import { getPackageJson } from '../../utils/pkg-utils'; import { - AuthAttribute, ModelFieldType, AttributeArg as PrismaAttributeArg, AttributeArgValue as PrismaAttributeArgValue, @@ -67,7 +66,6 @@ import { PassThroughAttribute as PrismaPassThroughAttribute, SimpleField, } from './prisma-builder'; -import { isAuthInvocation } from '../../utils/ast-utils'; const MODEL_PASSTHROUGH_ATTR = '@@prisma.passthrough'; const FIELD_PASSTHROUGH_ATTR = '@prisma.passthrough'; @@ -288,7 +286,12 @@ export default class PrismaSchemaGenerator { !!attrDecl.attributes.find((a) => a.decl.ref?.name === '@@@prisma') || // the special pass-through attribute attrDecl.name === MODEL_PASSTHROUGH_ATTR || - attrDecl.name === FIELD_PASSTHROUGH_ATTR + attrDecl.name === FIELD_PASSTHROUGH_ATTR || + // auth() in @default() is not supported by Prisma + // FIXME: condition is inverted to avoid error... + !!attrDecl.attributes.find( + (a) => a.decl.ref?.name === '@default' && a.args[0].value.$cstNode?.text.startsWith('auth()') + ) ); } @@ -337,9 +340,10 @@ export default class PrismaSchemaGenerator { } else { throw new PluginError(name, `Invalid arguments for ${FIELD_PASSTHROUGH_ATTR} attribute`); } - // remove @default(auth()) field attributes as they are not supported by Prisma - } else if (attrName.startsWith('@default') && isAuthInvocation(attr.args[0].value)) { - return new PrismaFieldAttribute(NO_FIELD_ATTRIBUTE); + // do not write @default(auth()) field attribute as it is not supported by Prisma + // TODO: we should add a comment to the field + } else if (attrName === '@default' && attr.args[0].value.$cstNode?.text.startsWith('auth()')) { + return new PrismaPassThroughAttribute(NO_FIELD_ATTRIBUTE); } else { return new PrismaFieldAttribute( attrName, @@ -376,10 +380,6 @@ export default class PrismaSchemaGenerator { } else if (isInvocationExpr(node)) { // invocation return new PrismaAttributeArgValue('FunctionCall', this.makeFunctionCall(node)); - // @ts-expect-error TODO: fix this - } else if (node.operand?.function?.$refText === 'auth') { - console.log('member access', node); - return new PrismaAttributeArgValue('AuthAttribute', new AuthAttribute('DEFAULT USER FROM PRISMA BUILDER')); } else { throw new PluginError(name, `Unsupported attribute argument expression type: ${node.$type}`); } diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index 772725584..9f06932df 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -1,12 +1,10 @@ import { ArrayExpr, - AttributeArg, DataModel, DataModelField, isArrayExpr, isBooleanLiteral, isDataModel, - isInvocationExpr, isNumberLiteral, isReferenceExpr, isStringLiteral, @@ -25,7 +23,6 @@ import { hasAttribute, isEnumFieldReference, isForeignKeyField, - isFromStdlib, isIdField, resolved, } from '.'; @@ -213,10 +210,15 @@ function getFieldAttributes(field: DataModelField): RuntimeAttribute[] { args.push({ name: arg.name, value: v }); } else if (isStringLiteral(arg.value) || isBooleanLiteral(arg.value)) { args.push({ name: arg.name, value: arg.value.value }); - } else if (isAuthDefaultValue(arg)) { - args.push({ name: arg.name, value: arg.value.toString() }); + } else if ( + attr.decl.ref?.name === '@default' && + attr.args[0].value.$cstNode?.text.startsWith('auth()') + ) { + const authValue = attr.args[0].value.$cstNode?.text; + const authSelector = authValue === 'auth()' ? authValue : authValue.slice('auth().'.length); + args.push({ name: 'auth()', value: authSelector }); } else { - // non-literal and auth args are ignored + // non-literal args are ignored } } return { name: resolved(attr.decl).name, args }; @@ -332,11 +334,3 @@ function getDeleteCascades(model: DataModel): string[] { }) .map((m) => m.name); } - -function isAuthDefaultValue(arg: AttributeArg): boolean { - return ( - isInvocationExpr(arg.value) && - isFromStdlib(arg.value.function.ref!) && - arg.value.function.$refText.startsWith('auth') - ); -} diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index dce5ab698..973177c27 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -12,374 +12,400 @@ describe('With Policy: auth() test', () => { process.chdir(origDir); }); - it('undefined user with string id simple', async () => { + // it('undefined user with string id simple', async () => { + // const { enhance } = await loadSchema( + // ` + // model User { + // id String @id @default(uuid()) + // } + + // model Post { + // id String @id @default(uuid()) + // title String + + // @@allow('read', true) + // @@allow('create', auth() != null) + // } + // ` + // ); + + // const db = enhance(); + // await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + // const authDb = enhance({ id: 'user1' }); + // await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + // }); + + // it('undefined user with string id more', async () => { + // const { enhance } = await loadSchema( + // ` + // model User { + // id String @id @default(uuid()) + // } + + // model Post { + // id String @id @default(uuid()) + // title String + + // @@allow('read', true) + // @@allow('create', auth().id != null) + // } + // ` + // ); + + // const db = enhance(); + // await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + // const authDb = enhance({ id: 'user1' }); + // await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + // }); + + // it('undefined user with int id', async () => { + // const { enhance } = await loadSchema( + // ` + // model User { + // id Int @id @default(autoincrement()) + // } + + // model Post { + // id String @id @default(uuid()) + // title String + + // @@allow('read', true) + // @@allow('create', auth() != null) + // } + // ` + // ); + + // const db = enhance(); + // await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + // const authDb = enhance({ id: 'user1' }); + // await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + // }); + + // it('undefined user compared with field', async () => { + // const { enhance } = await loadSchema( + // ` + // model User { + // id String @id @default(uuid()) + // posts Post[] + + // @@allow('all', true) + // } + + // model Post { + // id String @id @default(uuid()) + // title String + // author User @relation(fields: [authorId], references: [id]) + // authorId String + + // @@allow('create,read', true) + // @@allow('update', auth() == author) + // } + // ` + // ); + + // const db = enhance(); + // await expect(db.user.create({ data: { id: 'user1' } })).toResolveTruthy(); + // await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); + + // const authDb = enhance(); + // await expect(authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); + + // expect(() => enhance({ id: null })).toThrow(/Invalid user context/); + + // const authDb2 = enhance({ id: 'user1' }); + // await expect(authDb2.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); + // }); + + // it('undefined user compared with field more', async () => { + // const { enhance } = await loadSchema( + // ` + // model User { + // id String @id @default(uuid()) + // posts Post[] + + // @@allow('all', true) + // } + + // model Post { + // id String @id @default(uuid()) + // title String + // author User @relation(fields: [authorId], references: [id]) + // authorId String + + // @@allow('create,read', true) + // @@allow('update', auth().id == author.id) + // } + // ` + // ); + + // const db = enhance(); + // await expect(db.user.create({ data: { id: 'user1' } })).toResolveTruthy(); + // await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); + + // await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); + + // const authDb2 = enhance({ id: 'user1' }); + // await expect(authDb2.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); + // }); + + // it('undefined user non-id field', async () => { + // const { enhance } = await loadSchema( + // ` + // model User { + // id String @id @default(uuid()) + // posts Post[] + // role String + + // @@allow('all', true) + // } + + // model Post { + // id String @id @default(uuid()) + // title String + // author User @relation(fields: [authorId], references: [id]) + // authorId String + + // @@allow('create,read', true) + // @@allow('update', auth().role == 'ADMIN') + // } + // ` + // ); + + // const db = enhance(); + // await expect(db.user.create({ data: { id: 'user1', role: 'USER' } })).toResolveTruthy(); + // await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); + // await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); + + // const authDb = enhance({ id: 'user1', role: 'USER' }); + // await expect(authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); + + // const authDb1 = enhance({ id: 'user2', role: 'ADMIN' }); + // await expect(authDb1.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); + // }); + + // it('non User auth model', async () => { + // const { enhance } = await loadSchema( + // ` + // model Foo { + // id String @id @default(uuid()) + // role String + + // @@auth() + // } + + // model Post { + // id String @id @default(uuid()) + // title String + + // @@allow('read', true) + // @@allow('create', auth().role == 'ADMIN') + // } + // ` + // ); + + // const userDb = enhance({ id: 'user1', role: 'USER' }); + // await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + // const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); + // await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + // }); + + // it('User model ignored', async () => { + // const { enhance } = await loadSchema( + // ` + // model User { + // id String @id @default(uuid()) + // role String + + // @@ignore + // } + + // model Post { + // id String @id @default(uuid()) + // title String + + // @@allow('read', true) + // @@allow('create', auth().role == 'ADMIN') + // } + // ` + // ); + + // const userDb = enhance({ id: 'user1', role: 'USER' }); + // await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + // const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); + // await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + // }); + + // it('Auth model ignored', async () => { + // const { enhance } = await loadSchema( + // ` + // model Foo { + // id String @id @default(uuid()) + // role String + + // @@auth() + // @@ignore + // } + + // model Post { + // id String @id @default(uuid()) + // title String + + // @@allow('read', true) + // @@allow('create', auth().role == 'ADMIN') + // } + // ` + // ); + + // const userDb = enhance({ id: 'user1', role: 'USER' }); + // await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + // const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); + // await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + // }); + + // it('collection predicate', async () => { + // const { enhance, prisma } = await loadSchema( + // ` + // model User { + // id String @id @default(uuid()) + // posts Post[] + + // @@allow('all', true) + // } + + // model Post { + // id String @id @default(uuid()) + // title String + // published Boolean @default(false) + // author User @relation(fields: [authorId], references: [id]) + // authorId String + // comments Comment[] + + // @@allow('read', true) + // @@allow('create', auth().posts?[published && comments![published]]) + // } + + // model Comment { + // id String @id @default(uuid()) + // published Boolean @default(false) + // post Post @relation(fields: [postId], references: [id]) + // postId String + + // @@allow('all', true) + // } + // ` + // ); + + // const user = await prisma.user.create({ data: {} }); + + // const createPayload = { + // data: { title: 'Post 1', author: { connect: { id: user.id } } }, + // }; + + // // no post + // await expect(enhance({ id: '1' }).post.create(createPayload)).toBeRejectedByPolicy(); + + // // post not published + // await expect( + // enhance({ id: '1', posts: [{ id: '1', published: false }] }).post.create(createPayload) + // ).toBeRejectedByPolicy(); + + // // no comments + // await expect( + // enhance({ id: '1', posts: [{ id: '1', published: true }] }).post.create(createPayload) + // ).toBeRejectedByPolicy(); + + // // not all comments published + // await expect( + // enhance({ + // id: '1', + // posts: [ + // { + // id: '1', + // published: true, + // comments: [ + // { id: '1', published: true }, + // { id: '2', published: false }, + // ], + // }, + // ], + // }).post.create(createPayload) + // ).toBeRejectedByPolicy(); + + // // comments published but parent post is not + // await expect( + // enhance({ + // id: '1', + // posts: [ + // { id: '1', published: false, comments: [{ id: '1', published: true }] }, + // { id: '2', published: true }, + // ], + // }).post.create(createPayload) + // ).toBeRejectedByPolicy(); + + // await expect( + // enhance({ + // id: '1', + // posts: [ + // { id: '1', published: true, comments: [{ id: '1', published: true }] }, + // { id: '2', published: false }, + // ], + // }).post.create(createPayload) + // ).toResolveTruthy(); + + // // no comments ("every" evaluates to tru in this case) + // await expect( + // enhance({ id: '1', posts: [{ id: '1', published: true, comments: [] }] }).post.create(createPayload) + // ).toResolveTruthy(); + // }); + + it('Default auth() on literal fields', async () => { const { enhance } = await loadSchema( ` model User { - id String @id @default(uuid()) - } - - model Post { - id String @id @default(uuid()) - title String - - @@allow('read', true) - @@allow('create', auth() != null) - } - ` - ); - - const db = enhance(); - await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - const authDb = enhance({ id: 'user1' }); - await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - }); - - it('undefined user with string id more', async () => { - const { enhance } = await loadSchema( - ` - model User { - id String @id @default(uuid()) - } - - model Post { - id String @id @default(uuid()) - title String - - @@allow('read', true) - @@allow('create', auth().id != null) - } - ` - ); - - const db = enhance(); - await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - const authDb = enhance({ id: 'user1' }); - await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - }); - - it('undefined user with int id', async () => { - const { enhance } = await loadSchema( - ` - model User { - id Int @id @default(autoincrement()) - } - - model Post { - id String @id @default(uuid()) - title String - - @@allow('read', true) - @@allow('create', auth() != null) - } - ` - ); - - const db = enhance(); - await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - const authDb = enhance({ id: 'user1' }); - await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - }); - - it('undefined user compared with field', async () => { - const { enhance } = await loadSchema( - ` - model User { - id String @id @default(uuid()) - posts Post[] - - @@allow('all', true) - } - - model Post { - id String @id @default(uuid()) - title String - author User @relation(fields: [authorId], references: [id]) - authorId String - - @@allow('create,read', true) - @@allow('update', auth() == author) - } - ` - ); - - const db = enhance(); - await expect(db.user.create({ data: { id: 'user1' } })).toResolveTruthy(); - await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); - - const authDb = enhance(); - await expect(authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - - expect(() => enhance({ id: null })).toThrow(/Invalid user context/); - - const authDb2 = enhance({ id: 'user1' }); - await expect(authDb2.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); - }); - - it('undefined user compared with field more', async () => { - const { enhance } = await loadSchema( - ` - model User { - id String @id @default(uuid()) - posts Post[] + id String @id @default(autoincrement()) + name String + score Int - @@allow('all', true) } model Post { id String @id @default(uuid()) title String - author User @relation(fields: [authorId], references: [id]) - authorId String - - @@allow('create,read', true) - @@allow('update', auth().id == author.id) - } - ` - ); - - const db = enhance(); - await expect(db.user.create({ data: { id: 'user1' } })).toResolveTruthy(); - await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); - - await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - - const authDb2 = enhance({ id: 'user1' }); - await expect(authDb2.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); - }); - - it('undefined user non-id field', async () => { - const { enhance } = await loadSchema( - ` - model User { - id String @id @default(uuid()) - posts Post[] - role String + score Int? @default(auth().score) + authorName String? @default(auth().name) @@allow('all', true) } - - model Post { - id String @id @default(uuid()) - title String - author User @relation(fields: [authorId], references: [id]) - authorId String - - @@allow('create,read', true) - @@allow('update', auth().role == 'ADMIN') - } - ` - ); - - const db = enhance(); - await expect(db.user.create({ data: { id: 'user1', role: 'USER' } })).toResolveTruthy(); - await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); - await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - - const authDb = enhance({ id: 'user1', role: 'USER' }); - await expect(authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - - const authDb1 = enhance({ id: 'user2', role: 'ADMIN' }); - await expect(authDb1.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); - }); - - it('non User auth model', async () => { - const { enhance } = await loadSchema( - ` - model Foo { - id String @id @default(uuid()) - role String - - @@auth() - } - - model Post { - id String @id @default(uuid()) - title String - - @@allow('read', true) - @@allow('create', auth().role == 'ADMIN') - } - ` - ); - - const userDb = enhance({ id: 'user1', role: 'USER' }); - await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); - await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - }); - - it('User model ignored', async () => { - const { enhance } = await loadSchema( - ` - model User { - id String @id @default(uuid()) - role String - - @@ignore - } - - model Post { - id String @id @default(uuid()) - title String - - @@allow('read', true) - @@allow('create', auth().role == 'ADMIN') - } ` ); - const userDb = enhance({ id: 'user1', role: 'USER' }); - await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); - await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + const userDb = enhance({ id: '1', name: 'user1', score: 10 }); + await expect(userDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + console.log(await userDb.post.findMany()); + await expect(userDb.post.count({ where: { authorName: 'user1', score: 10 } })).resolves.toBe(1); }); - it('Auth model ignored', async () => { + it('Default auth() with foreign key', async () => { const { enhance } = await loadSchema( ` - model Foo { - id String @id @default(uuid()) - role String - - @@auth() - @@ignore - } - - model Post { - id String @id @default(uuid()) - title String - - @@allow('read', true) - @@allow('create', auth().role == 'ADMIN') - } - ` - ); - - const userDb = enhance({ id: 'user1', role: 'USER' }); - await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); - await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - }); - - it('collection predicate', async () => { - const { enhance, prisma } = await loadSchema( - ` model User { - id String @id @default(uuid()) + id String @id @default(autoincrement()) + name String posts Post[] - @@allow('all', true) } model Post { id String @id @default(uuid()) title String - published Boolean @default(false) author User @relation(fields: [authorId], references: [id]) - authorId String - comments Comment[] - - @@allow('read', true) - @@allow('create', auth().posts?[published && comments![published]]) - } - - model Comment { - id String @id @default(uuid()) - published Boolean @default(false) - post Post @relation(fields: [postId], references: [id]) - postId String - - @@allow('all', true) - } - ` - ); - - const user = await prisma.user.create({ data: {} }); - - const createPayload = { - data: { title: 'Post 1', author: { connect: { id: user.id } } }, - }; - - // no post - await expect(enhance({ id: '1' }).post.create(createPayload)).toBeRejectedByPolicy(); - - // post not published - await expect( - enhance({ id: '1', posts: [{ id: '1', published: false }] }).post.create(createPayload) - ).toBeRejectedByPolicy(); - - // no comments - await expect( - enhance({ id: '1', posts: [{ id: '1', published: true }] }).post.create(createPayload) - ).toBeRejectedByPolicy(); - - // not all comments published - await expect( - enhance({ - id: '1', - posts: [ - { - id: '1', - published: true, - comments: [ - { id: '1', published: true }, - { id: '2', published: false }, - ], - }, - ], - }).post.create(createPayload) - ).toBeRejectedByPolicy(); - - // comments published but parent post is not - await expect( - enhance({ - id: '1', - posts: [ - { id: '1', published: false, comments: [{ id: '1', published: true }] }, - { id: '2', published: true }, - ], - }).post.create(createPayload) - ).toBeRejectedByPolicy(); - - await expect( - enhance({ - id: '1', - posts: [ - { id: '1', published: true, comments: [{ id: '1', published: true }] }, - { id: '2', published: false }, - ], - }).post.create(createPayload) - ).toResolveTruthy(); - - // no comments ("every" evaluates to tru in this case) - await expect( - enhance({ id: '1', posts: [{ id: '1', published: true, comments: [] }] }).post.create(createPayload) - ).toResolveTruthy(); - }); - - it('Field with default auth() created correctly', async () => { - const { enhance } = await loadSchema( - ` - model User { - id String @id @default(uuid()) - name String - // posts Post[] - - } - - model Post { - id String @id @default(uuid()) - title String - authorName String @default(auth().name) - // author User @relation(fields: [authorId], references: [id]) - // authorId String @default(auth().id) + authorId String @default(auth().id) @@allow('all', true) } From 0cc43bbf2e7c774217b780bc4c841f3d338ed799 Mon Sep 17 00:00:00 2001 From: augustin Date: Tue, 23 Jan 2024 10:12:35 +0100 Subject: [PATCH 03/12] fix: avoid create operation called twice --- .../runtime/src/enhancements/default-auth.ts | 77 +- .../enhancements/with-policy/auth.test.ts | 721 +++++++++--------- 2 files changed, 427 insertions(+), 371 deletions(-) diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index f88c38cbe..05718b172 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -1,16 +1,8 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { - enumerate, - getModelFields, - resolveField, - type ModelMeta, - NestedWriteVisitor, - PrismaWriteActionType, - NestedWriteVisitorContext, - FieldInfo, -} from '../cross'; +import deepcopy from 'deepcopy'; +import { NestedWriteVisitor, PrismaWriteActionType, FieldInfo } from '../cross'; import { DbClientContract } from '../types'; import { EnhancementContext, EnhancementOptions } from './create-enhancement'; import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; @@ -56,7 +48,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { const visitor = new NestedWriteVisitor(this.options.modelMeta, { - create: async (model, args, _context) => { + field: async (field, action, data, context) => { const userContext = this.context?.user; if (!userContext) { throw new Error(`Invalid user context`); @@ -81,11 +73,70 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { ]) ); console.log('defaultAuthFields :', defaultAuthFields); - const result = await this.db[model].create({ data: { ...defaultAuthFields, ...args } }); - return result; + for (const [field, defaultValue] of Object.entries(defaultAuthFields)) { + // const fieldInfo = fields[field]; + // console.log('fieldInfo :', fieldInfo); + // console.log('isForeignKey :', fieldInfo.isForeignKey); + // if (fieldInfo.isForeignKey) { + // console.log('field :', field); + // console.log('defaultValue :', defaultValue); + // const data = getConnectDefaultValue(fields, field, defaultValue); + // const connectedField = Object.keys(data)[0]; + // console.log('data : ', data); + // context.parent[connectedField] = data[connectedField]; + // } else { + context.parent[field] = defaultValue; + // } + } }, }); await visitor.visit(model, action, args); } } + +// function hasForeignKeyMapping(fieldInfo: FieldInfo) { +// return fieldInfo.foreignKeyMapping !== undefined; +// } + +// function getConnectDefaultValue(fields: Record, field: string, defaultValue: unknown) { +// for (const key in fields) { +// const fieldInfo = fields[key]; +// if (hasForeignKeyMapping(fieldInfo)) { +// const connectedRawValue = { connect: fieldInfo.foreignKeyMapping! }; +// const connectedValue = replaceFirstValue(connectedRawValue, defaultValue); +// console.log('old data :', { [fieldInfo.name]: connectedRawValue }); +// return { [fieldInfo.name]: connectedValue }; +// } +// } +// return {}; +// } + +// function replaceFirstValue(obj: Record, newValue: any) { +// // Fonction récursive pour parcourir l'objet +// function replaceFirstValueRecursive(currentObj: Record) { +// for (const key in currentObj) { +// if (typeof currentObj[key] === 'object' && currentObj[key] !== null) { +// // Remplace la première valeur trouvée dans l'objet +// for (const nestedKey in currentObj[key]) { +// // eslint-disable-next-line no-prototype-builtins +// if (currentObj[key].hasOwnProperty(nestedKey)) { +// currentObj[key][nestedKey] = newValue; +// return; +// } +// } + +// // Continue la recherche récursive +// replaceFirstValueRecursive(currentObj[key]); +// } +// } +// } + +// // Clone l'objet pour ne pas modifier l'original +// const clonedObj = JSON.parse(JSON.stringify(obj)); + +// // Appelle la fonction récursive avec l'objet cloné +// replaceFirstValueRecursive(clonedObj); + +// return clonedObj; +// } diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index 973177c27..10b1c961b 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -12,363 +12,363 @@ describe('With Policy: auth() test', () => { process.chdir(origDir); }); - // it('undefined user with string id simple', async () => { - // const { enhance } = await loadSchema( - // ` - // model User { - // id String @id @default(uuid()) - // } - - // model Post { - // id String @id @default(uuid()) - // title String - - // @@allow('read', true) - // @@allow('create', auth() != null) - // } - // ` - // ); - - // const db = enhance(); - // await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - // const authDb = enhance({ id: 'user1' }); - // await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - // }); - - // it('undefined user with string id more', async () => { - // const { enhance } = await loadSchema( - // ` - // model User { - // id String @id @default(uuid()) - // } - - // model Post { - // id String @id @default(uuid()) - // title String - - // @@allow('read', true) - // @@allow('create', auth().id != null) - // } - // ` - // ); - - // const db = enhance(); - // await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - // const authDb = enhance({ id: 'user1' }); - // await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - // }); - - // it('undefined user with int id', async () => { - // const { enhance } = await loadSchema( - // ` - // model User { - // id Int @id @default(autoincrement()) - // } - - // model Post { - // id String @id @default(uuid()) - // title String - - // @@allow('read', true) - // @@allow('create', auth() != null) - // } - // ` - // ); - - // const db = enhance(); - // await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - // const authDb = enhance({ id: 'user1' }); - // await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - // }); - - // it('undefined user compared with field', async () => { - // const { enhance } = await loadSchema( - // ` - // model User { - // id String @id @default(uuid()) - // posts Post[] - - // @@allow('all', true) - // } - - // model Post { - // id String @id @default(uuid()) - // title String - // author User @relation(fields: [authorId], references: [id]) - // authorId String - - // @@allow('create,read', true) - // @@allow('update', auth() == author) - // } - // ` - // ); - - // const db = enhance(); - // await expect(db.user.create({ data: { id: 'user1' } })).toResolveTruthy(); - // await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); - - // const authDb = enhance(); - // await expect(authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - - // expect(() => enhance({ id: null })).toThrow(/Invalid user context/); - - // const authDb2 = enhance({ id: 'user1' }); - // await expect(authDb2.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); - // }); - - // it('undefined user compared with field more', async () => { - // const { enhance } = await loadSchema( - // ` - // model User { - // id String @id @default(uuid()) - // posts Post[] - - // @@allow('all', true) - // } - - // model Post { - // id String @id @default(uuid()) - // title String - // author User @relation(fields: [authorId], references: [id]) - // authorId String - - // @@allow('create,read', true) - // @@allow('update', auth().id == author.id) - // } - // ` - // ); - - // const db = enhance(); - // await expect(db.user.create({ data: { id: 'user1' } })).toResolveTruthy(); - // await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); - - // await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - - // const authDb2 = enhance({ id: 'user1' }); - // await expect(authDb2.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); - // }); - - // it('undefined user non-id field', async () => { - // const { enhance } = await loadSchema( - // ` - // model User { - // id String @id @default(uuid()) - // posts Post[] - // role String - - // @@allow('all', true) - // } - - // model Post { - // id String @id @default(uuid()) - // title String - // author User @relation(fields: [authorId], references: [id]) - // authorId String - - // @@allow('create,read', true) - // @@allow('update', auth().role == 'ADMIN') - // } - // ` - // ); - - // const db = enhance(); - // await expect(db.user.create({ data: { id: 'user1', role: 'USER' } })).toResolveTruthy(); - // await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); - // await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - - // const authDb = enhance({ id: 'user1', role: 'USER' }); - // await expect(authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - - // const authDb1 = enhance({ id: 'user2', role: 'ADMIN' }); - // await expect(authDb1.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); - // }); - - // it('non User auth model', async () => { - // const { enhance } = await loadSchema( - // ` - // model Foo { - // id String @id @default(uuid()) - // role String - - // @@auth() - // } - - // model Post { - // id String @id @default(uuid()) - // title String - - // @@allow('read', true) - // @@allow('create', auth().role == 'ADMIN') - // } - // ` - // ); - - // const userDb = enhance({ id: 'user1', role: 'USER' }); - // await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - // const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); - // await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - // }); - - // it('User model ignored', async () => { - // const { enhance } = await loadSchema( - // ` - // model User { - // id String @id @default(uuid()) - // role String - - // @@ignore - // } - - // model Post { - // id String @id @default(uuid()) - // title String - - // @@allow('read', true) - // @@allow('create', auth().role == 'ADMIN') - // } - // ` - // ); - - // const userDb = enhance({ id: 'user1', role: 'USER' }); - // await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - // const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); - // await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - // }); - - // it('Auth model ignored', async () => { - // const { enhance } = await loadSchema( - // ` - // model Foo { - // id String @id @default(uuid()) - // role String - - // @@auth() - // @@ignore - // } - - // model Post { - // id String @id @default(uuid()) - // title String - - // @@allow('read', true) - // @@allow('create', auth().role == 'ADMIN') - // } - // ` - // ); - - // const userDb = enhance({ id: 'user1', role: 'USER' }); - // await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - - // const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); - // await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - // }); - - // it('collection predicate', async () => { - // const { enhance, prisma } = await loadSchema( - // ` - // model User { - // id String @id @default(uuid()) - // posts Post[] - - // @@allow('all', true) - // } - - // model Post { - // id String @id @default(uuid()) - // title String - // published Boolean @default(false) - // author User @relation(fields: [authorId], references: [id]) - // authorId String - // comments Comment[] - - // @@allow('read', true) - // @@allow('create', auth().posts?[published && comments![published]]) - // } - - // model Comment { - // id String @id @default(uuid()) - // published Boolean @default(false) - // post Post @relation(fields: [postId], references: [id]) - // postId String - - // @@allow('all', true) - // } - // ` - // ); - - // const user = await prisma.user.create({ data: {} }); - - // const createPayload = { - // data: { title: 'Post 1', author: { connect: { id: user.id } } }, - // }; - - // // no post - // await expect(enhance({ id: '1' }).post.create(createPayload)).toBeRejectedByPolicy(); - - // // post not published - // await expect( - // enhance({ id: '1', posts: [{ id: '1', published: false }] }).post.create(createPayload) - // ).toBeRejectedByPolicy(); - - // // no comments - // await expect( - // enhance({ id: '1', posts: [{ id: '1', published: true }] }).post.create(createPayload) - // ).toBeRejectedByPolicy(); - - // // not all comments published - // await expect( - // enhance({ - // id: '1', - // posts: [ - // { - // id: '1', - // published: true, - // comments: [ - // { id: '1', published: true }, - // { id: '2', published: false }, - // ], - // }, - // ], - // }).post.create(createPayload) - // ).toBeRejectedByPolicy(); - - // // comments published but parent post is not - // await expect( - // enhance({ - // id: '1', - // posts: [ - // { id: '1', published: false, comments: [{ id: '1', published: true }] }, - // { id: '2', published: true }, - // ], - // }).post.create(createPayload) - // ).toBeRejectedByPolicy(); - - // await expect( - // enhance({ - // id: '1', - // posts: [ - // { id: '1', published: true, comments: [{ id: '1', published: true }] }, - // { id: '2', published: false }, - // ], - // }).post.create(createPayload) - // ).toResolveTruthy(); - - // // no comments ("every" evaluates to tru in this case) - // await expect( - // enhance({ id: '1', posts: [{ id: '1', published: true, comments: [] }] }).post.create(createPayload) - // ).toResolveTruthy(); - // }); + it('undefined user with string id simple', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(uuid()) + } + + model Post { + id String @id @default(uuid()) + title String + + @@allow('read', true) + @@allow('create', auth() != null) + } + ` + ); + + const db = enhance(); + await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + const authDb = enhance({ id: 'user1' }); + await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + }); + + it('undefined user with string id more', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(uuid()) + } + + model Post { + id String @id @default(uuid()) + title String + + @@allow('read', true) + @@allow('create', auth().id != null) + } + ` + ); + + const db = enhance(); + await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + const authDb = enhance({ id: 'user1' }); + await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + }); + + it('undefined user with int id', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + } + + model Post { + id String @id @default(uuid()) + title String + + @@allow('read', true) + @@allow('create', auth() != null) + } + ` + ); + + const db = enhance(); + await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + const authDb = enhance({ id: 'user1' }); + await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + }); + + it('undefined user compared with field', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(uuid()) + posts Post[] + + @@allow('all', true) + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String + + @@allow('create,read', true) + @@allow('update', auth() == author) + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 'user1' } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); + + const authDb = enhance(); + await expect(authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); + + expect(() => enhance({ id: null })).toThrow(/Invalid user context/); + + const authDb2 = enhance({ id: 'user1' }); + await expect(authDb2.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); + }); + + it('undefined user compared with field more', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(uuid()) + posts Post[] + + @@allow('all', true) + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String + + @@allow('create,read', true) + @@allow('update', auth().id == author.id) + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 'user1' } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); + + await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); + + const authDb2 = enhance({ id: 'user1' }); + await expect(authDb2.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); + }); + + it('undefined user non-id field', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(uuid()) + posts Post[] + role String + + @@allow('all', true) + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String + + @@allow('create,read', true) + @@allow('update', auth().role == 'ADMIN') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 'user1', role: 'USER' } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); + await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); + + const authDb = enhance({ id: 'user1', role: 'USER' }); + await expect(authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); + + const authDb1 = enhance({ id: 'user2', role: 'ADMIN' }); + await expect(authDb1.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); + }); + + it('non User auth model', async () => { + const { enhance } = await loadSchema( + ` + model Foo { + id String @id @default(uuid()) + role String + + @@auth() + } + + model Post { + id String @id @default(uuid()) + title String + + @@allow('read', true) + @@allow('create', auth().role == 'ADMIN') + } + ` + ); + + const userDb = enhance({ id: 'user1', role: 'USER' }); + await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); + await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + }); + + it('User model ignored', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(uuid()) + role String + + @@ignore + } + + model Post { + id String @id @default(uuid()) + title String + + @@allow('read', true) + @@allow('create', auth().role == 'ADMIN') + } + ` + ); + + const userDb = enhance({ id: 'user1', role: 'USER' }); + await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); + await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + }); + + it('Auth model ignored', async () => { + const { enhance } = await loadSchema( + ` + model Foo { + id String @id @default(uuid()) + role String + + @@auth() + @@ignore + } + + model Post { + id String @id @default(uuid()) + title String + + @@allow('read', true) + @@allow('create', auth().role == 'ADMIN') + } + ` + ); + + const userDb = enhance({ id: 'user1', role: 'USER' }); + await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); + await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + }); + + it('collection predicate', async () => { + const { enhance, prisma } = await loadSchema( + ` + model User { + id String @id @default(uuid()) + posts Post[] + + @@allow('all', true) + } + + model Post { + id String @id @default(uuid()) + title String + published Boolean @default(false) + author User @relation(fields: [authorId], references: [id]) + authorId String + comments Comment[] + + @@allow('read', true) + @@allow('create', auth().posts?[published && comments![published]]) + } + + model Comment { + id String @id @default(uuid()) + published Boolean @default(false) + post Post @relation(fields: [postId], references: [id]) + postId String + + @@allow('all', true) + } + ` + ); + + const user = await prisma.user.create({ data: {} }); + + const createPayload = { + data: { title: 'Post 1', author: { connect: { id: user.id } } }, + }; + + // no post + await expect(enhance({ id: '1' }).post.create(createPayload)).toBeRejectedByPolicy(); + + // post not published + await expect( + enhance({ id: '1', posts: [{ id: '1', published: false }] }).post.create(createPayload) + ).toBeRejectedByPolicy(); + + // no comments + await expect( + enhance({ id: '1', posts: [{ id: '1', published: true }] }).post.create(createPayload) + ).toBeRejectedByPolicy(); + + // not all comments published + await expect( + enhance({ + id: '1', + posts: [ + { + id: '1', + published: true, + comments: [ + { id: '1', published: true }, + { id: '2', published: false }, + ], + }, + ], + }).post.create(createPayload) + ).toBeRejectedByPolicy(); + + // comments published but parent post is not + await expect( + enhance({ + id: '1', + posts: [ + { id: '1', published: false, comments: [{ id: '1', published: true }] }, + { id: '2', published: true }, + ], + }).post.create(createPayload) + ).toBeRejectedByPolicy(); + + await expect( + enhance({ + id: '1', + posts: [ + { id: '1', published: true, comments: [{ id: '1', published: true }] }, + { id: '2', published: false }, + ], + }).post.create(createPayload) + ).toResolveTruthy(); + + // no comments ("every" evaluates to tru in this case) + await expect( + enhance({ id: '1', posts: [{ id: '1', published: true, comments: [] }] }).post.create(createPayload) + ).toResolveTruthy(); + }); it('Default auth() on literal fields', async () => { const { enhance } = await loadSchema( ` model User { - id String @id @default(autoincrement()) + id String @id name String score Int @@ -387,6 +387,8 @@ describe('With Policy: auth() test', () => { const userDb = enhance({ id: '1', name: 'user1', score: 10 }); await expect(userDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + await expect(userDb.post.findMany()).resolves.toHaveLength(1); + console.log(await userDb.post.findMany()); await expect(userDb.post.count({ where: { authorName: 'user1', score: 10 } })).resolves.toBe(1); }); @@ -395,10 +397,11 @@ describe('With Policy: auth() test', () => { const { enhance } = await loadSchema( ` model User { - id String @id @default(autoincrement()) - name String + id String @id posts Post[] + @@allow('all', true) + } model Post { @@ -412,9 +415,11 @@ describe('With Policy: auth() test', () => { ` ); - const userDb = enhance({ id: '1', name: 'user1' }); - await expect(userDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - console.log(await userDb.post.findMany()); - await expect(userDb.post.count({ where: { authorName: 'user1' } })).resolves.toBe(1); + const db = enhance({ id: 'userId-1' }); + await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); + console.log(await db.user.findMany()); + await expect(db.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + console.log(await db.post.findMany()); + await expect(db.post.count({ where: { authorId: 'userId-1' } })).resolves.toBe(1); }); }); From a551f30458f65c8625927b360872814060a61be3 Mon Sep 17 00:00:00 2001 From: augustin Date: Tue, 23 Jan 2024 10:15:51 +0100 Subject: [PATCH 04/12] remove unused code --- .../runtime/src/enhancements/default-auth.ts | 59 ------------------- .../enhancements/with-policy/auth.test.ts | 4 -- 2 files changed, 63 deletions(-) diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index 05718b172..9027038f7 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -1,7 +1,6 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import deepcopy from 'deepcopy'; import { NestedWriteVisitor, PrismaWriteActionType, FieldInfo } from '../cross'; import { DbClientContract } from '../types'; import { EnhancementContext, EnhancementOptions } from './create-enhancement'; @@ -74,19 +73,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { ); console.log('defaultAuthFields :', defaultAuthFields); for (const [field, defaultValue] of Object.entries(defaultAuthFields)) { - // const fieldInfo = fields[field]; - // console.log('fieldInfo :', fieldInfo); - // console.log('isForeignKey :', fieldInfo.isForeignKey); - // if (fieldInfo.isForeignKey) { - // console.log('field :', field); - // console.log('defaultValue :', defaultValue); - // const data = getConnectDefaultValue(fields, field, defaultValue); - // const connectedField = Object.keys(data)[0]; - // console.log('data : ', data); - // context.parent[connectedField] = data[connectedField]; - // } else { context.parent[field] = defaultValue; - // } } }, }); @@ -94,49 +81,3 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { await visitor.visit(model, action, args); } } - -// function hasForeignKeyMapping(fieldInfo: FieldInfo) { -// return fieldInfo.foreignKeyMapping !== undefined; -// } - -// function getConnectDefaultValue(fields: Record, field: string, defaultValue: unknown) { -// for (const key in fields) { -// const fieldInfo = fields[key]; -// if (hasForeignKeyMapping(fieldInfo)) { -// const connectedRawValue = { connect: fieldInfo.foreignKeyMapping! }; -// const connectedValue = replaceFirstValue(connectedRawValue, defaultValue); -// console.log('old data :', { [fieldInfo.name]: connectedRawValue }); -// return { [fieldInfo.name]: connectedValue }; -// } -// } -// return {}; -// } - -// function replaceFirstValue(obj: Record, newValue: any) { -// // Fonction récursive pour parcourir l'objet -// function replaceFirstValueRecursive(currentObj: Record) { -// for (const key in currentObj) { -// if (typeof currentObj[key] === 'object' && currentObj[key] !== null) { -// // Remplace la première valeur trouvée dans l'objet -// for (const nestedKey in currentObj[key]) { -// // eslint-disable-next-line no-prototype-builtins -// if (currentObj[key].hasOwnProperty(nestedKey)) { -// currentObj[key][nestedKey] = newValue; -// return; -// } -// } - -// // Continue la recherche récursive -// replaceFirstValueRecursive(currentObj[key]); -// } -// } -// } - -// // Clone l'objet pour ne pas modifier l'original -// const clonedObj = JSON.parse(JSON.stringify(obj)); - -// // Appelle la fonction récursive avec l'objet cloné -// replaceFirstValueRecursive(clonedObj); - -// return clonedObj; -// } diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index 10b1c961b..619d5aa47 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -388,8 +388,6 @@ describe('With Policy: auth() test', () => { const userDb = enhance({ id: '1', name: 'user1', score: 10 }); await expect(userDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); await expect(userDb.post.findMany()).resolves.toHaveLength(1); - - console.log(await userDb.post.findMany()); await expect(userDb.post.count({ where: { authorName: 'user1', score: 10 } })).resolves.toBe(1); }); @@ -417,9 +415,7 @@ describe('With Policy: auth() test', () => { const db = enhance({ id: 'userId-1' }); await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); - console.log(await db.user.findMany()); await expect(db.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - console.log(await db.post.findMany()); await expect(db.post.count({ where: { authorId: 'userId-1' } })).resolves.toBe(1); }); }); From c8f049c4e6e77a84d6e216b806e4b8c5e35e3f8e Mon Sep 17 00:00:00 2001 From: augustin Date: Tue, 23 Jan 2024 11:14:44 +0100 Subject: [PATCH 05/12] clean up --- package.json | 2 +- packages/ide/jetbrains/package.json | 2 +- packages/runtime/src/cross/mutator.ts | 1 - .../src/plugins/prisma/prisma-builder.ts | 19 ++----------------- 4 files changed, 4 insertions(+), 20 deletions(-) diff --git a/package.json b/package.json index c7b2f6ade..83d9f7072 100644 --- a/package.json +++ b/package.json @@ -5,7 +5,7 @@ "scripts": { "build": "pnpm -r build", "lint": "pnpm -r lint", - "test": "ZENSTACK_TEST=1 pnpm -r run test --silent=false --forceExit", + "test": "ZENSTACK_TEST=1 pnpm -r run test --silent --forceExit", "test-ci": "ZENSTACK_TEST=1 pnpm -r run test --silent --forceExit", "publish-all": "pnpm --filter \"./packages/**\" -r publish --access public", "publish-preview": "pnpm --filter \"./packages/**\" -r publish --force --registry https://preview.registry.zenstack.dev/", diff --git a/packages/ide/jetbrains/package.json b/packages/ide/jetbrains/package.json index ec05801e3..4e7fc26df 100644 --- a/packages/ide/jetbrains/package.json +++ b/packages/ide/jetbrains/package.json @@ -6,7 +6,7 @@ "homepage": "https://zenstack.dev", "private": true, "scripts": { - "build": "echo './gradlew buildPlugin'" + "build": "./gradlew buildPlugin" }, "author": "ZenStack Team", "license": "MIT", diff --git a/packages/runtime/src/cross/mutator.ts b/packages/runtime/src/cross/mutator.ts index 0ed5761a4..0dd66e6fb 100644 --- a/packages/runtime/src/cross/mutator.ts +++ b/packages/runtime/src/cross/mutator.ts @@ -124,7 +124,6 @@ function createMutate( insert[name] = newData[name]; } else { const defaultAttr = field.attributes?.find((attr) => attr.name === '@default'); - // TODO: handle default auth() attributes here ? if (field.type === 'DateTime') { // default value for DateTime field if (defaultAttr || field.attributes?.some((attr) => attr.name === '@updatedAt')) { diff --git a/packages/schema/src/plugins/prisma/prisma-builder.ts b/packages/schema/src/plugins/prisma/prisma-builder.ts index 3d9c68357..68336baeb 100644 --- a/packages/schema/src/plugins/prisma/prisma-builder.ts +++ b/packages/schema/src/plugins/prisma/prisma-builder.ts @@ -225,8 +225,8 @@ export class AttributeArg { export class AttributeArgValue { constructor( - public type: 'String' | 'FieldReference' | 'Number' | 'Boolean' | 'Array' | 'FunctionCall' | 'AuthAttribute', - public value: string | number | boolean | FieldReference | FunctionCall | AttributeArgValue[] | AuthAttribute + public type: 'String' | 'FieldReference' | 'Number' | 'Boolean' | 'Array' | 'FunctionCall', + public value: string | number | boolean | FieldReference | FunctionCall | AttributeArgValue[] ) { switch (type) { case 'String': @@ -249,10 +249,6 @@ export class AttributeArgValue { case 'FunctionCall': if (!(value instanceof FunctionCall)) throw new Error('Value must be FunctionCall'); break; - case 'AuthAttribute': - // TODO: implement validation - // if (!(value instanceof FunctionCall)) throw new Error('Value must be FunctionCall'); - break; } } @@ -275,8 +271,6 @@ export class AttributeArgValue { return r; } } - case 'AuthAttribute': - return this.value.toString(); case 'FunctionCall': return this.value.toString(); case 'Boolean': @@ -316,15 +310,6 @@ export class FunctionCallArg { return this.name ? `${this.name}: ${this.value}` : this.value; } } - -export class AuthAttribute { - constructor(public field: string) {} - - toString(): string { - return `"${this.field}"`; - } -} - export class Enum extends ContainerDeclaration { public fields: EnumField[] = []; From 08a7282bd23cb09c85c74437b495136f70faaba1 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 24 Jan 2024 11:39:55 +0800 Subject: [PATCH 06/12] fix: do not cache proxy detection flags in test environment --- packages/runtime/src/enhancements/create-enhancement.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/packages/runtime/src/enhancements/create-enhancement.ts b/packages/runtime/src/enhancements/create-enhancement.ts index bb8baef54..0b128d82e 100644 --- a/packages/runtime/src/enhancements/create-enhancement.ts +++ b/packages/runtime/src/enhancements/create-enhancement.ts @@ -123,7 +123,12 @@ export function createEnhancement( let result = prisma; - if (hasPassword === undefined || hasOmit === undefined || hasDefaultAuth === undefined) { + if ( + process.env.ZENSTACK_TEST === '1' || // avoid caching in tests + hasPassword === undefined || + hasOmit === undefined || + hasDefaultAuth === undefined + ) { const allFields = Object.values(options.modelMeta.fields).flatMap((modelInfo) => Object.values(modelInfo)); hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password')); hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit')); From 92de62e1b63575f206dba4d218087de0731ffc97 Mon Sep 17 00:00:00 2001 From: augustin Date: Wed, 24 Jan 2024 13:11:24 +0100 Subject: [PATCH 07/12] fix: use create callback to add default auth args --- packages/runtime/src/cross/model-meta.ts | 2 +- .../runtime/src/cross/nested-write-visitor.ts | 2 +- .../runtime/src/enhancements/default-auth.ts | 39 ++++++++++--------- packages/runtime/src/enhancements/utils.ts | 15 +++++++ .../src/plugins/prisma/schema-generator.ts | 26 ++++++------- packages/sdk/src/utils.ts | 7 ++++ .../enhancements/with-policy/auth.test.ts | 36 ++++++++++++++++- 7 files changed, 92 insertions(+), 35 deletions(-) diff --git a/packages/runtime/src/cross/model-meta.ts b/packages/runtime/src/cross/model-meta.ts index 89b1e11b7..fc6581684 100644 --- a/packages/runtime/src/cross/model-meta.ts +++ b/packages/runtime/src/cross/model-meta.ts @@ -3,7 +3,7 @@ import { lowerCaseFirst } from 'lower-case-first'; /** * An access key in the user context object (e.g. `profile.picture.url`) */ -export type AuthContextSelector = string; +export type AuthContextSelector = string | undefined; /** * Runtime information of a data model or field attribute diff --git a/packages/runtime/src/cross/nested-write-visitor.ts b/packages/runtime/src/cross/nested-write-visitor.ts index 7d67f6d9b..477117dbd 100644 --- a/packages/runtime/src/cross/nested-write-visitor.ts +++ b/packages/runtime/src/cross/nested-write-visitor.ts @@ -34,7 +34,7 @@ export type NestedWriteVisitorContext = { * to let the visitor traverse it instead of its original children. */ export type NestedWriterVisitorCallback = { - create?: (model: string, args: any[], context: NestedWriteVisitorContext) => MaybePromise; + create?: (model: string, data: any, context: NestedWriteVisitorContext) => MaybePromise; createMany?: ( model: string, diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index 9027038f7..108004a33 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -1,10 +1,11 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { NestedWriteVisitor, PrismaWriteActionType, FieldInfo } from '../cross'; +import { NestedWriteVisitor, PrismaWriteActionType, FieldInfo, AuthContextSelector } from '../cross'; import { DbClientContract } from '../types'; import { EnhancementContext, EnhancementOptions } from './create-enhancement'; import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; +import { deepGet } from './utils'; /** * Gets an enhanced Prisma client that supports `@default(auth())` attribute. @@ -40,44 +41,46 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { protected async preprocessArgs(action: PrismaProxyActions, args: any) { const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; if (actionsOfInterest.includes(action)) { - await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); + const newArgs = await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); + return newArgs; } return args; } private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { + let newArgs = {}; const visitor = new NestedWriteVisitor(this.options.modelMeta, { - field: async (field, action, data, context) => { + create: (model, _data, _context) => { const userContext = this.context?.user; if (!userContext) { throw new Error(`Invalid user context`); } const fields = this.options.modelMeta.fields[model]; - const isDefaultAuthField = (fieldInfo: FieldInfo) => - fieldInfo.attributes?.find((attr) => attr.name === '@default' && attr.args?.[0]?.name === 'auth()'); - const defaultAuthSelectorFields = Object.fromEntries( + const defaultAuthSelectorFields: Record = Object.fromEntries( Object.entries(fields) - .filter(([_, fieldInfo]) => isDefaultAuthField(fieldInfo)) - .map(([field, fieldInfo]) => [ - field, - fieldInfo.attributes?.find((attr) => attr.name === '@default')?.args[0]?.value as - | string - | undefined, - ]) + .filter(([_, fieldInfo]) => this.isDefaultAuthField(fieldInfo)) + .map(([field, fieldInfo]) => [field, this.getAuthSelector(fieldInfo)]) ); const defaultAuthFields = Object.fromEntries( - Object.entries(defaultAuthSelectorFields).map(([field, selector]) => [ + Object.entries(defaultAuthSelectorFields).map(([field, authSelector]) => [ field, - selector ? userContext[selector] : userContext, + deepGet(userContext, authSelector, userContext), ]) ); console.log('defaultAuthFields :', defaultAuthFields); - for (const [field, defaultValue] of Object.entries(defaultAuthFields)) { - context.parent[field] = defaultValue; - } + newArgs = { ...args, data: { ...args.data, ...defaultAuthFields } }; }, }); await visitor.visit(model, action, args); + return newArgs; + } + + private isDefaultAuthField(field: FieldInfo): boolean { + return !!field.attributes?.find((attr) => attr.name === '@default' && attr.args?.[0]?.name === 'auth()'); + } + + private getAuthSelector(fieldInfo: FieldInfo): AuthContextSelector { + return fieldInfo.attributes?.find((attr) => attr.name === '@default')?.args[0]?.value as AuthContextSelector; } } diff --git a/packages/runtime/src/enhancements/utils.ts b/packages/runtime/src/enhancements/utils.ts index ba2f9a2d8..c68bfaf61 100644 --- a/packages/runtime/src/enhancements/utils.ts +++ b/packages/runtime/src/enhancements/utils.ts @@ -22,3 +22,18 @@ export function prismaClientKnownRequestError(prisma: DbClientContract, prismaMo export function prismaClientUnknownRequestError(prismaModule: any, ...args: unknown[]): Error { throw new prismaModule.PrismaClientUnknownRequestError(...args); } + +export function deepGet(object: object, path: string | string[] | undefined, defaultValue: unknown): unknown { + if (path === undefined) { + return defaultValue; + } + const keys = Array.isArray(path) ? path : path.split('.'); + for (const key of keys) { + if (object && typeof object === 'object' && key in object) { + object = object[key as keyof typeof object]; + } else { + return defaultValue; + } + } + return object !== undefined ? object : defaultValue; +} diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 5298aead0..0f25ab1b8 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -33,6 +33,7 @@ import { getDMMF, getLiteral, getPrismaVersion, + isDefaultAuthField, PluginError, PluginOptions, resolved, @@ -69,7 +70,6 @@ import { const MODEL_PASSTHROUGH_ATTR = '@@prisma.passthrough'; const FIELD_PASSTHROUGH_ATTR = '@prisma.passthrough'; -const NO_FIELD_ATTRIBUTE = ''; /** * Generates Prisma schema file @@ -286,12 +286,7 @@ export default class PrismaSchemaGenerator { !!attrDecl.attributes.find((a) => a.decl.ref?.name === '@@@prisma') || // the special pass-through attribute attrDecl.name === MODEL_PASSTHROUGH_ATTR || - attrDecl.name === FIELD_PASSTHROUGH_ATTR || - // auth() in @default() is not supported by Prisma - // FIXME: condition is inverted to avoid error... - !!attrDecl.attributes.find( - (a) => a.decl.ref?.name === '@default' && a.args[0].value.$cstNode?.text.startsWith('auth()') - ) + attrDecl.name === FIELD_PASSTHROUGH_ATTR ); } @@ -317,9 +312,7 @@ export default class PrismaSchemaGenerator { const type = new ModelFieldType(fieldType, field.type.array, field.type.optional); - const attributes = field.attributes - .filter((attr) => this.isPrismaAttribute(attr)) - .map((attr) => this.makeFieldAttribute(attr)); + const attributes = this.getAttributesToGenerate(field); const nonPrismaAttributes = field.attributes.filter((attr) => attr.decl.ref && !this.isPrismaAttribute(attr)); @@ -331,6 +324,15 @@ export default class PrismaSchemaGenerator { field.comments.forEach((c) => result.addComment(c)); } + private getAttributesToGenerate(field: DataModelField) { + if (isDefaultAuthField(field)) { + return []; + } + return field.attributes + .filter((attr) => this.isPrismaAttribute(attr)) + .map((attr) => this.makeFieldAttribute(attr)); + } + private makeFieldAttribute(attr: DataModelFieldAttribute) { const attrName = resolved(attr.decl).name; if (attrName === FIELD_PASSTHROUGH_ATTR) { @@ -340,10 +342,6 @@ export default class PrismaSchemaGenerator { } else { throw new PluginError(name, `Invalid arguments for ${FIELD_PASSTHROUGH_ATTR} attribute`); } - // do not write @default(auth()) field attribute as it is not supported by Prisma - // TODO: we should add a comment to the field - } else if (attrName === '@default' && attr.args[0].value.$cstNode?.text.startsWith('auth()')) { - return new PrismaPassThroughAttribute(NO_FIELD_ATTRIBUTE); } else { return new PrismaFieldAttribute( attrName, diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index afd043565..d16a42031 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -280,6 +280,13 @@ export function isForeignKeyField(field: DataModelField) { }); } +export function isDefaultAuthField(field: DataModelField) { + return ( + hasAttribute(field, '@default') && + !!field.attributes.find((attr) => attr.args?.[0]?.value.$cstNode?.text.startsWith('auth()')) + ); +} + export function resolvePath(_path: string, options: Pick) { if (path.isAbsolute(_path)) { return _path; diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index 619d5aa47..f8ced4e7e 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -392,7 +392,7 @@ describe('With Policy: auth() test', () => { }); it('Default auth() with foreign key', async () => { - const { enhance } = await loadSchema( + const { enhance, modelMeta } = await loadSchema( ` model User { id String @id @@ -414,8 +414,42 @@ describe('With Policy: auth() test', () => { ); const db = enhance({ id: 'userId-1' }); + const attributes = modelMeta.fields.post.authorId.attributes; + expect(attributes).toHaveProperty('0.name', '@default'); + expect(attributes).toHaveProperty('0.args.0.name', 'auth()'); + expect(attributes).toHaveProperty('0.args.0.value', 'id'); await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); await expect(db.post.create({ data: { title: 'abc' } })).toResolveTruthy(); await expect(db.post.count({ where: { authorId: 'userId-1' } })).resolves.toBe(1); }); + + it('Default auth() with nested user context value', async () => { + const { enhance, modelMeta } = await loadSchema( + ` + model User { + id String @id + + @@allow('all', true) + + } + + model Post { + id String @id @default(uuid()) + title String + defaultImageUrl string @default(auth().profile.image.url) + + @@allow('all', true) + } + ` + ); + const url = 'https://zenstack.dev'; + const db = enhance({ id: 'userId-1', profile: { image: { url } } }); + const attributes = modelMeta.fields.post.authorId.attributes; + expect(attributes).toHaveProperty('0.name', '@default'); + expect(attributes).toHaveProperty('0.args.0.name', 'auth()'); + expect(attributes).toHaveProperty('0.args.0.value', 'profile.image.url'); + await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); + await expect(db.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + await expect(db.post.count({ where: { defaultImageUrl: url } })).resolves.toBe(1); + }); }); From 258b56b064f2ce3ac28714bd6e020bc9357c3b77 Mon Sep 17 00:00:00 2001 From: augustin Date: Wed, 24 Jan 2024 13:53:51 +0100 Subject: [PATCH 08/12] fix: default args should not override passed args --- .../runtime/src/enhancements/default-auth.ts | 2 +- .../enhancements/with-policy/auth.test.ts | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index 108004a33..ffd901e0c 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -68,7 +68,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { ]) ); console.log('defaultAuthFields :', defaultAuthFields); - newArgs = { ...args, data: { ...args.data, ...defaultAuthFields } }; + newArgs = { ...args, data: { ...defaultAuthFields, ...args.data } }; }, }); diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index f8ced4e7e..b1c2bafcd 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -391,6 +391,31 @@ describe('With Policy: auth() test', () => { await expect(userDb.post.count({ where: { authorName: 'user1', score: 10 } })).resolves.toBe(1); }); + it('Default auth() data should not override passed args', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id + name String + + } + + model Post { + id String @id @default(uuid()) + authorName String? @default(auth().name) + + @@allow('all', true) + } + ` + ); + + const userContextName = 'user1'; + const overrideName = 'no-default-auth-name'; + const userDb = enhance({ id: '1', name: userContextName }); + await expect(userDb.post.create({ data: { authorName: overrideName } })).toResolveTruthy(); + await expect(userDb.post.count({ where: { authorName: overrideName } })).resolves.toBe(1); + }); + it('Default auth() with foreign key', async () => { const { enhance, modelMeta } = await loadSchema( ` From 33115303ba8bfb587aeb59803e7f37ea35808494 Mon Sep 17 00:00:00 2001 From: augustin Date: Wed, 24 Jan 2024 17:53:21 +0100 Subject: [PATCH 09/12] considering auth() without selector --- packages/runtime/src/cross/model-meta.ts | 2 +- .../runtime/src/enhancements/default-auth.ts | 23 +++++++----- packages/runtime/src/enhancements/utils.ts | 2 +- packages/sdk/src/model-meta-generator.ts | 4 +- .../enhancements/with-policy/auth.test.ts | 37 +++++++++++++++++++ 5 files changed, 55 insertions(+), 13 deletions(-) diff --git a/packages/runtime/src/cross/model-meta.ts b/packages/runtime/src/cross/model-meta.ts index fc6581684..89b1e11b7 100644 --- a/packages/runtime/src/cross/model-meta.ts +++ b/packages/runtime/src/cross/model-meta.ts @@ -3,7 +3,7 @@ import { lowerCaseFirst } from 'lower-case-first'; /** * An access key in the user context object (e.g. `profile.picture.url`) */ -export type AuthContextSelector = string | undefined; +export type AuthContextSelector = string; /** * Runtime information of a data model or field attribute diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index ffd901e0c..8f3a0cdff 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -56,16 +56,21 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { throw new Error(`Invalid user context`); } const fields = this.options.modelMeta.fields[model]; - const defaultAuthSelectorFields: Record = Object.fromEntries( - Object.entries(fields) - .filter(([_, fieldInfo]) => this.isDefaultAuthField(fieldInfo)) - .map(([field, fieldInfo]) => [field, this.getAuthSelector(fieldInfo)]) - ); + const defaultAuthSelectorFields: Record = + Object.fromEntries( + Object.entries(fields) + .filter(([_, fieldInfo]) => this.isDefaultAuthField(fieldInfo)) + .map(([field, fieldInfo]) => [ + field, + { fieldType: fieldInfo.type, selector: this.getAuthSelector(fieldInfo) }, + ]) + ); const defaultAuthFields = Object.fromEntries( - Object.entries(defaultAuthSelectorFields).map(([field, authSelector]) => [ - field, - deepGet(userContext, authSelector, userContext), - ]) + Object.entries(defaultAuthSelectorFields).map(([field, { fieldType, selector }]) => { + // if field type is String, we expect auth() to return the whole user context as string + const defaultValue = fieldType === 'String' ? JSON.stringify(userContext) : userContext; + return [field, deepGet(userContext, selector, defaultValue)]; + }) ); console.log('defaultAuthFields :', defaultAuthFields); newArgs = { ...args, data: { ...defaultAuthFields, ...args.data } }; diff --git a/packages/runtime/src/enhancements/utils.ts b/packages/runtime/src/enhancements/utils.ts index c68bfaf61..2879a3119 100644 --- a/packages/runtime/src/enhancements/utils.ts +++ b/packages/runtime/src/enhancements/utils.ts @@ -24,7 +24,7 @@ export function prismaClientUnknownRequestError(prismaModule: any, ...args: unkn } export function deepGet(object: object, path: string | string[] | undefined, defaultValue: unknown): unknown { - if (path === undefined) { + if (path === undefined || path === '') { return defaultValue; } const keys = Array.isArray(path) ? path : path.split('.'); diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index 9f06932df..928929bb6 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -10,7 +10,7 @@ import { isStringLiteral, ReferenceExpr, } from '@zenstackhq/language/ast'; -import type { RuntimeAttribute } from '@zenstackhq/runtime'; +import type { AuthContextSelector, RuntimeAttribute } from '@zenstackhq/runtime'; import { lowerCaseFirst } from 'lower-case-first'; import { CodeBlockWriter, Project, VariableDeclarationKind } from 'ts-morph'; import { @@ -215,7 +215,7 @@ function getFieldAttributes(field: DataModelField): RuntimeAttribute[] { attr.args[0].value.$cstNode?.text.startsWith('auth()') ) { const authValue = attr.args[0].value.$cstNode?.text; - const authSelector = authValue === 'auth()' ? authValue : authValue.slice('auth().'.length); + const authSelector: AuthContextSelector = authValue.slice('auth().'.length); args.push({ name: 'auth()', value: authSelector }); } else { // non-literal args are ignored diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index b1c2bafcd..95838128a 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -416,6 +416,43 @@ describe('With Policy: auth() test', () => { await expect(userDb.post.count({ where: { authorName: overrideName } })).resolves.toBe(1); }); + it('Default auth() return user context', async () => { + const { enhance, modelMeta } = await loadSchema( + ` + model User { + id String @id + name String + + } + + model Post { + id String @id @default(uuid()) + authorName String? @default(auth().name) + userDetailsAsString String? @default(auth()) + userDetailsAsJson Json? @default(auth()) + + @@allow('all', true) + } + `, + { + provider: 'postgresql', + compile: true, + } + ); + + const userContext = { id: '1', name: 'user1' }; + const db = enhance(userContext); + const userDetailsAttributes = modelMeta.fields.post.userDetails.attributes; + expect(userDetailsAttributes).toHaveProperty('0.name', '@default'); + expect(userDetailsAttributes).toHaveProperty('0.args.0.name', 'auth()'); + expect(userDetailsAttributes).toHaveProperty('0.args.0.value', ''); + await expect(db.post.create({ data: {} })).toResolveTruthy(); + await expect(db.post.count()).resolves.toBe(1); + const post = await db.post.findFirst(); + expect(post?.userDetailsAsString).toEqual(JSON.stringify(userContext)); + expect(post?.userDetailsAsJson).toEqual(userContext); + }); + it('Default auth() with foreign key', async () => { const { enhance, modelMeta } = await loadSchema( ` From 37c5cd174561c8d91a2e06a67e925212216d455d Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Fri, 26 Jan 2024 15:08:39 +0800 Subject: [PATCH 10/12] A few refactors - Generate a function to provide value for fields using `auth()` in `@default` so we don't need to evaluate at runtime - Correct the way of visiting nested create payload --- packages/runtime/src/cross/model-meta.ts | 17 ++-- .../src/enhancements/create-enhancement.ts | 4 +- .../runtime/src/enhancements/default-auth.ts | 77 ++++++++------- .../validator/expression-validator.ts | 37 ++++---- .../function-invocation-validator.ts | 9 +- .../src/language-server/zmodel-linker.ts | 13 +-- .../enhancer/policy/expression-writer.ts | 9 +- .../enhancer/policy/policy-guard-generator.ts | 10 +- .../src/plugins/zod/utils/schema-gen.ts | 6 +- packages/schema/src/utils/ast-utils.ts | 48 +--------- packages/sdk/package.json | 2 + packages/sdk/src/index.ts | 1 + packages/sdk/src/model-meta-generator.ts | 65 ++++++++++--- .../src}/typescript-expression-transformer.ts | 4 +- packages/sdk/src/utils.ts | 40 +++++++- pnpm-lock.yaml | 6 ++ .../enhancements/with-policy/auth.test.ts | 93 +++++++++---------- 17 files changed, 246 insertions(+), 195 deletions(-) rename packages/{schema/src/utils => sdk/src}/typescript-expression-transformer.ts (98%) diff --git a/packages/runtime/src/cross/model-meta.ts b/packages/runtime/src/cross/model-meta.ts index 89b1e11b7..a45aae3f1 100644 --- a/packages/runtime/src/cross/model-meta.ts +++ b/packages/runtime/src/cross/model-meta.ts @@ -1,18 +1,18 @@ import { lowerCaseFirst } from 'lower-case-first'; -/** - * An access key in the user context object (e.g. `profile.picture.url`) - */ -export type AuthContextSelector = string; - /** * Runtime information of a data model or field attribute */ export type RuntimeAttribute = { name: string; - args: Array<{ name?: string; value: unknown } | { name: 'auth()'; value: AuthContextSelector }>; + args: Array<{ name?: string; value: unknown }>; }; +/** + * Function for computing default value for a field + */ +export type FieldDefaultValueProvider = (userContext: unknown) => unknown; + /** * Runtime information of a data model field */ @@ -71,6 +71,11 @@ export type FieldInfo = { * Mapping from foreign key field names to relation field names */ foreignKeyMapping?: Record; + + /** + * A function that provides a default value for the field + */ + defaultValueProvider?: FieldDefaultValueProvider; }; /** diff --git a/packages/runtime/src/enhancements/create-enhancement.ts b/packages/runtime/src/enhancements/create-enhancement.ts index 0b128d82e..e3204cd52 100644 --- a/packages/runtime/src/enhancements/create-enhancement.ts +++ b/packages/runtime/src/enhancements/create-enhancement.ts @@ -132,9 +132,7 @@ export function createEnhancement( const allFields = Object.values(options.modelMeta.fields).flatMap((modelInfo) => Object.values(modelInfo)); hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password')); hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit')); - hasDefaultAuth = allFields.some((field) => - field.attributes?.some((attr) => attr.name === '@default' && attr.args[0]?.name === 'auth()') - ); + hasDefaultAuth = allFields.some((field) => field.defaultValueProvider); } const kinds = options.kinds ?? [ diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index 8f3a0cdff..48af0ed73 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -1,11 +1,11 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { NestedWriteVisitor, PrismaWriteActionType, FieldInfo, AuthContextSelector } from '../cross'; +import deepcopy from 'deepcopy'; +import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields } from '../cross'; import { DbClientContract } from '../types'; import { EnhancementContext, EnhancementOptions } from './create-enhancement'; import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; -import { deepGet } from './utils'; /** * Gets an enhanced Prisma client that supports `@default(auth())` attribute. @@ -27,6 +27,8 @@ export function withDefaultAuth( class DefaultAuthHandler extends DefaultPrismaProxyHandler { private readonly db: DbClientContract; + private readonly userContext: any; + constructor( prisma: DbClientContract, model: string, @@ -35,6 +37,12 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { ) { super(prisma, model); this.db = prisma; + + if (!this.context?.user) { + throw new Error(`Using \`auth()\` in \`@default\` requires a user context`); + } + + this.userContext = this.context.user; } // base override @@ -48,44 +56,47 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { } private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { - let newArgs = {}; + const newArgs = deepcopy(args); + + const processCreatePayload = (model: string, data: any) => { + const fields = getFields(this.options.modelMeta, model); + for (const fieldInfo of Object.values(fields)) { + if (fieldInfo.name in data) { + // create payload already sets field value + continue; + } + + if (!fieldInfo.defaultValueProvider) { + // field doesn't have a runtime default value provider + continue; + } + + const authDefaultValue = this.getDefaultValueFromAuth(fieldInfo); + if (authDefaultValue !== undefined) { + // set field value extracted from `auth()` + data[fieldInfo.name] = authDefaultValue; + } + } + }; + + // visit create payload and set default value to fields using `auth()` in `@default()` const visitor = new NestedWriteVisitor(this.options.modelMeta, { - create: (model, _data, _context) => { - const userContext = this.context?.user; - if (!userContext) { - throw new Error(`Invalid user context`); + create: (model, data) => { + processCreatePayload(model, data); + }, + + createMany: (model, args) => { + for (const item of enumerate(args.data)) { + processCreatePayload(model, item); } - const fields = this.options.modelMeta.fields[model]; - const defaultAuthSelectorFields: Record = - Object.fromEntries( - Object.entries(fields) - .filter(([_, fieldInfo]) => this.isDefaultAuthField(fieldInfo)) - .map(([field, fieldInfo]) => [ - field, - { fieldType: fieldInfo.type, selector: this.getAuthSelector(fieldInfo) }, - ]) - ); - const defaultAuthFields = Object.fromEntries( - Object.entries(defaultAuthSelectorFields).map(([field, { fieldType, selector }]) => { - // if field type is String, we expect auth() to return the whole user context as string - const defaultValue = fieldType === 'String' ? JSON.stringify(userContext) : userContext; - return [field, deepGet(userContext, selector, defaultValue)]; - }) - ); - console.log('defaultAuthFields :', defaultAuthFields); - newArgs = { ...args, data: { ...defaultAuthFields, ...args.data } }; }, }); - await visitor.visit(model, action, args); + await visitor.visit(model, action, newArgs); return newArgs; } - private isDefaultAuthField(field: FieldInfo): boolean { - return !!field.attributes?.find((attr) => attr.name === '@default' && attr.args?.[0]?.name === 'auth()'); - } - - private getAuthSelector(fieldInfo: FieldInfo): AuthContextSelector { - return fieldInfo.attributes?.find((attr) => attr.name === '@default')?.args[0]?.value as AuthContextSelector; + private getDefaultValueFromAuth(fieldInfo: FieldInfo) { + return fieldInfo.defaultValueProvider?.(this.userContext); } } diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index 7644521b8..cfc8a39af 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -3,16 +3,16 @@ import { Expression, ExpressionType, isDataModel, + isDataModelField, isEnum, + isLiteralExpr, isMemberAccessExpr, isNullExpr, isThisExpr, - isDataModelField, - isLiteralExpr, } from '@zenstackhq/language/ast'; -import { isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk'; +import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk'; import { ValidationAcceptor } from 'langium'; -import { getContainingDataModel, isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils'; +import { getContainingDataModel, isCollectionPredicate } from '../../utils/ast-utils'; import { AstValidator } from '../types'; import { typeAssignable } from './utils'; @@ -132,18 +132,24 @@ export default class ExpressionValidator implements AstValidator { // - foo.user.id == userId // except: // - future().userId == userId - if(isMemberAccessExpr(expr.left) && isDataModelField(expr.left.member.ref) && expr.left.member.ref.$container != getContainingDataModel(expr) - || isMemberAccessExpr(expr.right) && isDataModelField(expr.right.member.ref) && expr.right.member.ref.$container != getContainingDataModel(expr)) - { + if ( + (isMemberAccessExpr(expr.left) && + isDataModelField(expr.left.member.ref) && + expr.left.member.ref.$container != getContainingDataModel(expr)) || + (isMemberAccessExpr(expr.right) && + isDataModelField(expr.right.member.ref) && + expr.right.member.ref.$container != getContainingDataModel(expr)) + ) { // foo.user.id == auth().id // foo.user.id == "123" // foo.user.id == null // foo.user.id == EnumValue - if(!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) - { - accept('error', 'comparison between fields of different models are not supported', { node: expr }); - break; - } + if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) { + accept('error', 'comparison between fields of different models are not supported', { + node: expr, + }); + break; + } } if ( @@ -205,14 +211,13 @@ export default class ExpressionValidator implements AstValidator { } } - private isNotModelFieldExpr(expr: Expression) { - return isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr) + return ( + isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr) + ); } private isAuthOrAuthMemberAccess(expr: Expression) { return isAuthInvocation(expr) || (isMemberAccessExpr(expr) && isAuthInvocation(expr.operand)); } - } - diff --git a/packages/schema/src/language-server/validator/function-invocation-validator.ts b/packages/schema/src/language-server/validator/function-invocation-validator.ts index 3bc364bd2..50b974a53 100644 --- a/packages/schema/src/language-server/validator/function-invocation-validator.ts +++ b/packages/schema/src/language-server/validator/function-invocation-validator.ts @@ -11,10 +11,15 @@ import { isDataModelFieldAttribute, isLiteralExpr, } from '@zenstackhq/language/ast'; -import { ExpressionContext, getFunctionExpressionContext, isEnumFieldReference, isFromStdlib } from '@zenstackhq/sdk'; +import { + ExpressionContext, + getDataModelFieldReference, + getFunctionExpressionContext, + isEnumFieldReference, + isFromStdlib, +} from '@zenstackhq/sdk'; import { AstNode, ValidationAcceptor } from 'langium'; import { P, match } from 'ts-pattern'; -import { getDataModelFieldReference } from '../../utils/ast-utils'; import { AstValidator } from '../types'; import { typeAssignable } from './utils'; diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index ef97cf4b6..8c8fb2c98 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -35,7 +35,7 @@ import { isReferenceExpr, isStringLiteral, } from '@zenstackhq/language/ast'; -import { getContainingModel, hasAttribute, isFromStdlib } from '@zenstackhq/sdk'; +import { getContainingModel, hasAttribute, isAuthInvocation, isFutureExpr } from '@zenstackhq/sdk'; import { AstNode, AstNodeDescription, @@ -52,12 +52,7 @@ import { } from 'langium'; import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { - getAllDeclarationsFromImports, - getContainingDataModel, - isAuthInvocation, - isCollectionPredicate, -} from '../utils/ast-utils'; +import { getAllDeclarationsFromImports, getContainingDataModel, isCollectionPredicate } from '../utils/ast-utils'; import { mapBuiltinTypeToExpressionType } from './validator/utils'; interface DefaultReference extends Reference { @@ -329,7 +324,7 @@ export class ZModelLinker extends DefaultLinker { if (node.function.ref) { // eslint-disable-next-line @typescript-eslint/ban-types const funcDecl = node.function.ref as FunctionDecl; - if (funcDecl.name === 'auth' && isFromStdlib(funcDecl)) { + if (isAuthInvocation(node)) { // auth() function is resolved to User model in the current document const model = getContainingModel(node); @@ -346,7 +341,7 @@ export class ZModelLinker extends DefaultLinker { node.$resolvedType = { decl: authModel, nullable: true }; } } - } else if (funcDecl.name === 'future' && isFromStdlib(funcDecl)) { + } else if (isFutureExpr(node)) { // future() function is resolved to current model node.$resolvedType = { decl: getContainingDataModel(node) }; } else { diff --git a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts index 0cc80c7ea..e38a34c29 100644 --- a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts +++ b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts @@ -19,19 +19,18 @@ import { import { ExpressionContext, getFunctionExpressionContext, + getIdFields, getLiteral, + isAuthInvocation, isDataModelFieldReference, isFutureExpr, PluginError, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, } from '@zenstackhq/sdk'; import { lowerCaseFirst } from 'lower-case-first'; import { CodeBlockWriter } from 'ts-morph'; import { name } from '..'; -import { getIdFields, isAuthInvocation } from '../../../utils/ast-utils'; -import { - TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, -} from '../../../utils/typescript-expression-transformer'; type ComparisonOperator = '==' | '!=' | '>' | '>=' | '<' | '<='; diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index e5017383d..149858cd6 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -33,14 +33,18 @@ import { PluginError, PluginOptions, RUNTIME_PACKAGE, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, analyzePolicies, getAttributeArg, getAuthModel, getDataModels, + getIdFields, getLiteral, getPrismaClientImportSpec, hasAttribute, hasValidationAttributes, + isAuthInvocation, isEnumFieldReference, isForeignKeyField, isFromStdlib, @@ -52,11 +56,7 @@ import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; import { name } from '..'; -import { getIdFields, isAuthInvocation, isCollectionPredicate } from '../../../utils/ast-utils'; -import { - TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, -} from '../../../utils/typescript-expression-transformer'; +import { isCollectionPredicate } from '../../../utils/ast-utils'; import { ALL_OPERATION_KINDS } from '../../plugin-utils'; import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; diff --git a/packages/schema/src/plugins/zod/utils/schema-gen.ts b/packages/schema/src/plugins/zod/utils/schema-gen.ts index 802127c58..02607d4c7 100644 --- a/packages/schema/src/plugins/zod/utils/schema-gen.ts +++ b/packages/schema/src/plugins/zod/utils/schema-gen.ts @@ -1,6 +1,8 @@ import { ExpressionContext, PluginError, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, getAttributeArg, getAttributeArgLiteral, getLiteral, @@ -18,10 +20,6 @@ import { } from '@zenstackhq/sdk/ast'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '..'; -import { - TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, -} from '../../../utils/typescript-expression-transformer'; export function makeFieldSchema(field: DataModelField, respectDefault = false) { if (isDataModel(field.type.reference?.ref)) { diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 661f14b26..80543d6a2 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -1,21 +1,13 @@ import { BinaryExpr, DataModel, - DataModelField, Expression, - isArrayExpr, isBinaryExpr, isDataModel, - isDataModelField, - isInvocationExpr, - isMemberAccessExpr, isModel, - isReferenceExpr, Model, ModelImport, - ReferenceExpr, } from '@zenstackhq/language/ast'; -import { isFromStdlib } from '@zenstackhq/sdk'; import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium'; import { URI, Utils } from 'vscode-uri'; @@ -56,43 +48,6 @@ function updateContainer(nodes: T[], container: AstNode): Mut }); } -export function getIdFields(dataModel: DataModel) { - const fieldLevelId = dataModel.$resolvedFields.find((f) => - f.attributes.some((attr) => attr.decl.$refText === '@id') - ); - if (fieldLevelId) { - return [fieldLevelId]; - } else { - // get model level @@id attribute - const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id'); - if (modelIdAttr) { - // get fields referenced in the attribute: @@id([field1, field2]]) - if (!isArrayExpr(modelIdAttr.args[0].value)) { - return []; - } - const argValue = modelIdAttr.args[0].value; - return argValue.items - .filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr)) - .map((expr) => expr.target.ref as DataModelField); - } - } - return []; -} - -export function isAuthInvocation(node: AstNode) { - return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref); -} - -export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { - if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { - return expr.target.ref; - } else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) { - return expr.member.ref; - } else { - return undefined; - } -} - export function resolveImportUri(imp: ModelImport): URI | undefined { if (imp.path === undefined || imp.path.length === 0) { return undefined; @@ -157,7 +112,6 @@ export function isCollectionPredicate(node: AstNode): node is BinaryExpr { return isBinaryExpr(node) && ['?', '!', '^'].includes(node.operator); } - export function getContainingDataModel(node: Expression): DataModel | undefined { let curr: AstNode | undefined = node.$container; while (curr) { @@ -167,4 +121,4 @@ export function getContainingDataModel(node: Expression): DataModel | undefined curr = curr.$container; } return undefined; -} \ No newline at end of file +} diff --git a/packages/sdk/package.json b/packages/sdk/package.json index beddaad70..ac8bcaf1d 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -23,10 +23,12 @@ "@prisma/internals-v5": "npm:@prisma/internals@^5.0.0", "@zenstackhq/language": "workspace:*", "@zenstackhq/runtime": "workspace:*", + "langium": "1.2.0", "lower-case-first": "^2.0.2", "prettier": "^2.8.3 || 3.x", "semver": "^7.5.2", "ts-morph": "^16.0.0", + "ts-pattern": "^4.3.0", "upper-case-first": "^2.0.2" }, "devDependencies": { diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index 64060390e..5013267e8 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -4,6 +4,7 @@ export { generate as generateModelMeta } from './model-meta-generator'; export * from './policy'; export * from './prisma'; export * from './types'; +export * from './typescript-expression-transformer'; export * from './utils'; export * from './validation'; export * from './zmodel-code-generator'; diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index 928929bb6..76290a57e 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -10,10 +10,12 @@ import { isStringLiteral, ReferenceExpr, } from '@zenstackhq/language/ast'; -import type { AuthContextSelector, RuntimeAttribute } from '@zenstackhq/runtime'; +import type { RuntimeAttribute } from '@zenstackhq/runtime'; +import { streamAst } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; -import { CodeBlockWriter, Project, VariableDeclarationKind } from 'ts-morph'; +import { CodeBlockWriter, Project, SourceFile, VariableDeclarationKind } from 'ts-morph'; import { + ExpressionContext, getAttribute, getAttributeArg, getAttributeArgs, @@ -21,10 +23,12 @@ import { getDataModels, getLiteral, hasAttribute, + isAuthInvocation, isEnumFieldReference, isForeignKeyField, isIdField, resolved, + TypeScriptExpressionTransformer, } from '.'; export type ModelMetaGeneratorOptions = { @@ -37,13 +41,20 @@ export async function generate(project: Project, models: DataModel[], options: M sf.addStatements('/* eslint-disable */'); sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, - declarations: [{ name: 'metadata', initializer: (writer) => generateModelMetadata(models, writer, options) }], + declarations: [ + { name: 'metadata', initializer: (writer) => generateModelMetadata(models, sf, writer, options) }, + ], }); sf.addStatements('export default metadata;'); return sf; } -function generateModelMetadata(dataModels: DataModel[], writer: CodeBlockWriter, options: ModelMetaGeneratorOptions) { +function generateModelMetadata( + dataModels: DataModel[], + sourceFile: SourceFile, + writer: CodeBlockWriter, + options: ModelMetaGeneratorOptions +) { writer.block(() => { writer.write('fields:'); writer.block(() => { @@ -119,6 +130,12 @@ function generateModelMetadata(dataModels: DataModel[], writer: CodeBlockWriter, foreignKeyMapping: ${JSON.stringify(fkMapping)},`); } + const defaultValueProvider = generateDefaultValueProvider(f, sourceFile); + if (defaultValueProvider) { + writer.write(` + defaultValueProvider: ${defaultValueProvider},`); + } + writer.write(` },`); } @@ -210,13 +227,6 @@ function getFieldAttributes(field: DataModelField): RuntimeAttribute[] { args.push({ name: arg.name, value: v }); } else if (isStringLiteral(arg.value) || isBooleanLiteral(arg.value)) { args.push({ name: arg.name, value: arg.value.value }); - } else if ( - attr.decl.ref?.name === '@default' && - attr.args[0].value.$cstNode?.text.startsWith('auth()') - ) { - const authValue = attr.args[0].value.$cstNode?.text; - const authSelector: AuthContextSelector = authValue.slice('auth().'.length); - args.push({ name: 'auth()', value: authSelector }); } else { // non-literal args are ignored } @@ -334,3 +344,36 @@ function getDeleteCascades(model: DataModel): string[] { }) .map((m) => m.name); } + +function generateDefaultValueProvider(field: DataModelField, sourceFile: SourceFile) { + const defaultAttr = getAttribute(field, '@default'); + if (!defaultAttr) { + return undefined; + } + + const expr = defaultAttr.args[0]?.value; + if (!expr) { + return undefined; + } + + // find `auth()` in default value expression + const hasAuth = streamAst(expr).some(isAuthInvocation); + if (!hasAuth) { + return undefined; + } + + // generates a provider function like: + // function $default$Model$field(user: any) { ... } + const func = sourceFile.addFunction({ + name: `$default$${field.$container.name}$${field.name}`, + parameters: [{ name: 'user', type: 'any' }], + returnType: 'unknown', + statements: (writer) => { + const tsWriter = new TypeScriptExpressionTransformer({ context: ExpressionContext.DefaultValue }); + const code = tsWriter.transform(expr, false); + writer.write(`return ${code};`); + }, + }); + + return func.getName(); +} diff --git a/packages/schema/src/utils/typescript-expression-transformer.ts b/packages/sdk/src/typescript-expression-transformer.ts similarity index 98% rename from packages/schema/src/utils/typescript-expression-transformer.ts rename to packages/sdk/src/typescript-expression-transformer.ts index cd868d76c..20585118c 100644 --- a/packages/schema/src/utils/typescript-expression-transformer.ts +++ b/packages/sdk/src/typescript-expression-transformer.ts @@ -17,9 +17,9 @@ import { ThisExpr, UnaryExpr, } from '@zenstackhq/language/ast'; -import { ExpressionContext, getLiteral, isFromStdlib, isFutureExpr } from '@zenstackhq/sdk'; import { match, P } from 'ts-pattern'; -import { getIdFields } from './ast-utils'; +import { ExpressionContext } from './constants'; +import { getIdFields, getLiteral, isFromStdlib, isFutureExpr } from './utils'; export class TypeScriptExpressionTransformerError extends Error { constructor(message: string) { diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index d16a42031..2f046b692 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -22,6 +22,7 @@ import { isGeneratorDecl, isInvocationExpr, isLiteralExpr, + isMemberAccessExpr, isModel, isObjectExpr, isReferenceExpr, @@ -341,7 +342,11 @@ export function getFunctionExpressionContext(funcDecl: FunctionDecl) { } export function isFutureExpr(node: AstNode) { - return !!(isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)); + return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref); +} + +export function isAuthInvocation(node: AstNode) { + return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref); } export function isFromStdlib(node: AstNode) { @@ -380,3 +385,36 @@ export function getAuthModel(dataModels: DataModel[]) { } return authModel; } + +export function getIdFields(dataModel: DataModel) { + const fieldLevelId = dataModel.$resolvedFields.find((f) => + f.attributes.some((attr) => attr.decl.$refText === '@id') + ); + if (fieldLevelId) { + return [fieldLevelId]; + } else { + // get model level @@id attribute + const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id'); + if (modelIdAttr) { + // get fields referenced in the attribute: @@id([field1, field2]]) + if (!isArrayExpr(modelIdAttr.args[0].value)) { + return []; + } + const argValue = modelIdAttr.args[0].value; + return argValue.items + .filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr)) + .map((expr) => expr.target.ref as DataModelField); + } + } + return []; +} + +export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { + if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { + return expr.target.ref; + } else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) { + return expr.member.ref; + } else { + return undefined; + } +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 84ef3e88d..641633679 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -620,6 +620,9 @@ importers: '@zenstackhq/runtime': specifier: workspace:* version: link:../runtime/dist + langium: + specifier: 1.2.0 + version: 1.2.0 lower-case-first: specifier: ^2.0.2 version: 2.0.2 @@ -632,6 +635,9 @@ importers: ts-morph: specifier: ^16.0.0 version: 16.0.0 + ts-pattern: + specifier: ^4.3.0 + version: 4.3.0 upper-case-first: specifier: ^2.0.2 version: 2.0.2 diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index 95838128a..f5b4e2f4f 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -416,43 +416,6 @@ describe('With Policy: auth() test', () => { await expect(userDb.post.count({ where: { authorName: overrideName } })).resolves.toBe(1); }); - it('Default auth() return user context', async () => { - const { enhance, modelMeta } = await loadSchema( - ` - model User { - id String @id - name String - - } - - model Post { - id String @id @default(uuid()) - authorName String? @default(auth().name) - userDetailsAsString String? @default(auth()) - userDetailsAsJson Json? @default(auth()) - - @@allow('all', true) - } - `, - { - provider: 'postgresql', - compile: true, - } - ); - - const userContext = { id: '1', name: 'user1' }; - const db = enhance(userContext); - const userDetailsAttributes = modelMeta.fields.post.userDetails.attributes; - expect(userDetailsAttributes).toHaveProperty('0.name', '@default'); - expect(userDetailsAttributes).toHaveProperty('0.args.0.name', 'auth()'); - expect(userDetailsAttributes).toHaveProperty('0.args.0.value', ''); - await expect(db.post.create({ data: {} })).toResolveTruthy(); - await expect(db.post.count()).resolves.toBe(1); - const post = await db.post.findFirst(); - expect(post?.userDetailsAsString).toEqual(JSON.stringify(userContext)); - expect(post?.userDetailsAsJson).toEqual(userContext); - }); - it('Default auth() with foreign key', async () => { const { enhance, modelMeta } = await loadSchema( ` @@ -476,29 +439,41 @@ describe('With Policy: auth() test', () => { ); const db = enhance({ id: 'userId-1' }); - const attributes = modelMeta.fields.post.authorId.attributes; - expect(attributes).toHaveProperty('0.name', '@default'); - expect(attributes).toHaveProperty('0.args.0.name', 'auth()'); - expect(attributes).toHaveProperty('0.args.0.value', 'id'); await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); - await expect(db.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - await expect(db.post.count({ where: { authorId: 'userId-1' } })).resolves.toBe(1); + await expect(db.post.create({ data: { title: 'abc' } })).resolves.toMatchObject({ authorId: 'userId-1' }); }); it('Default auth() with nested user context value', async () => { - const { enhance, modelMeta } = await loadSchema( + const { enhance } = await loadSchema( ` model User { id String @id + profile Profile? + posts Post[] @@allow('all', true) + } + model Profile { + id String @id @default(uuid()) + image Image? + user User @relation(fields: [userId], references: [id]) + userId String @unique + } + + model Image { + id String @id @default(uuid()) + url String + profile Profile @relation(fields: [profileId], references: [id]) + profileId String @unique } model Post { id String @id @default(uuid()) title String - defaultImageUrl string @default(auth().profile.image.url) + defaultImageUrl String @default(auth().profile.image.url) + author User @relation(fields: [authorId], references: [id]) + authorId String @@allow('all', true) } @@ -506,12 +481,28 @@ describe('With Policy: auth() test', () => { ); const url = 'https://zenstack.dev'; const db = enhance({ id: 'userId-1', profile: { image: { url } } }); - const attributes = modelMeta.fields.post.authorId.attributes; - expect(attributes).toHaveProperty('0.name', '@default'); - expect(attributes).toHaveProperty('0.args.0.name', 'auth()'); - expect(attributes).toHaveProperty('0.args.0.value', 'profile.image.url'); + + // top-level create await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); - await expect(db.post.create({ data: { title: 'abc' } })).toResolveTruthy(); - await expect(db.post.count({ where: { defaultImageUrl: url } })).resolves.toBe(1); + await expect( + db.post.create({ data: { title: 'abc', author: { connect: { id: 'userId-1' } } } }) + ).resolves.toMatchObject({ defaultImageUrl: url }); + + // nested create + let result = await db.user.create({ + data: { + id: 'userId-2', + posts: { + create: [{ title: 'p1' }, { title: 'p2' }], + }, + }, + include: { posts: true }, + }); + expect(result.posts).toEqual( + expect.arrayContaining([ + expect.objectContaining({ title: 'p1', defaultImageUrl: url }), + expect.objectContaining({ title: 'p2', defaultImageUrl: url }), + ]) + ); }); }); From 6bc36a5f4f7fba6c06e0b68c7c2cfaa6cc7e3795 Mon Sep 17 00:00:00 2001 From: augustin Date: Fri, 26 Jan 2024 12:06:39 +0100 Subject: [PATCH 11/12] remove useless newArgs in writer visitor --- packages/runtime/src/enhancements/default-auth.ts | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index 48af0ed73..3673d7dfd 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -1,7 +1,6 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import deepcopy from 'deepcopy'; import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields } from '../cross'; import { DbClientContract } from '../types'; import { EnhancementContext, EnhancementOptions } from './create-enhancement'; @@ -49,15 +48,12 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { protected async preprocessArgs(action: PrismaProxyActions, args: any) { const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; if (actionsOfInterest.includes(action)) { - const newArgs = await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); - return newArgs; + await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); } return args; } private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { - const newArgs = deepcopy(args); - const processCreatePayload = (model: string, data: any) => { const fields = getFields(this.options.modelMeta, model); for (const fieldInfo of Object.values(fields)) { @@ -92,8 +88,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { }, }); - await visitor.visit(model, action, newArgs); - return newArgs; + await visitor.visit(model, action, args); } private getDefaultValueFromAuth(fieldInfo: FieldInfo) { From a601e9f76922ae35e57e71695af44bd4410cb50a Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Fri, 26 Jan 2024 19:50:00 +0800 Subject: [PATCH 12/12] Revert "remove useless newArgs in writer visitor" This reverts commit 6bc36a5f4f7fba6c06e0b68c7c2cfaa6cc7e3795. --- packages/runtime/src/enhancements/default-auth.ts | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index 3673d7dfd..48af0ed73 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-explicit-any */ +import deepcopy from 'deepcopy'; import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields } from '../cross'; import { DbClientContract } from '../types'; import { EnhancementContext, EnhancementOptions } from './create-enhancement'; @@ -48,12 +49,15 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { protected async preprocessArgs(action: PrismaProxyActions, args: any) { const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; if (actionsOfInterest.includes(action)) { - await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); + const newArgs = await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); + return newArgs; } return args; } private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { + const newArgs = deepcopy(args); + const processCreatePayload = (model: string, data: any) => { const fields = getFields(this.options.modelMeta, model); for (const fieldInfo of Object.values(fields)) { @@ -88,7 +92,8 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { }, }); - await visitor.visit(model, action, args); + await visitor.visit(model, action, newArgs); + return newArgs; } private getDefaultValueFromAuth(fieldInfo: FieldInfo) {