From 161a0fd2e6710b0b8899ade0244e05614ddb8c91 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 2 Jun 2024 15:31:49 +0800 Subject: [PATCH] feat: support for Prisma 5.14's `createManyAndReturn` --- .../src/enhancements/policy/handler.ts | 112 ++++++++++++++---- packages/runtime/src/enhancements/proxy.ts | 6 + packages/runtime/src/types.ts | 3 +- .../create-many-and-return.test.ts | 105 ++++++++++++++++ 4 files changed, 200 insertions(+), 26 deletions(-) create mode 100644 tests/integration/tests/enhancements/with-policy/create-many-and-return.test.ts diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 7ce3a8987..b6088ed25 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -447,31 +447,10 @@ export class PolicyProxyHandler implements Pr // go through create items, statically check input to determine if post-create // check is needed, and also validate zod schema - let needPostCreateCheck = false; - for (const item of enumerate(args.data)) { - const validationResult = this.validateCreateInputSchema(this.model, item); - if (validationResult !== item) { - this.policyUtils.replace(item, validationResult); - } - - const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create'); - if (inputCheck === false) { - // unconditionally deny - throw this.policyUtils.deniedByPolicy( - this.model, - 'create', - undefined, - CrudFailureReason.ACCESS_POLICY_VIOLATION - ); - } else if (inputCheck === true) { - // unconditionally allow - } else if (inputCheck === undefined) { - // static policy check is not possible, need to do post-create check - needPostCreateCheck = true; - } - } + const needPostCreateCheck = this.validateCreateInput(args); if (!needPostCreateCheck) { + // direct create return this.modelClient.createMany(args); } else { // create entities in a transaction with post-create checks @@ -479,12 +458,95 @@ export class PolicyProxyHandler implements Pr const { result, postWriteChecks } = await this.doCreateMany(this.model, args, tx); // post-create check await this.runPostWriteChecks(postWriteChecks, tx); - return result; + return { count: result.length }; }); } }); } + createManyAndReturn(args: { select: any; include: any; data: any; skipDuplicates?: boolean }) { + if (!args) { + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); + } + if (!args.data) { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'data field is required in query argument' + ); + } + + return createDeferredPromise(async () => { + this.policyUtils.tryReject(this.prisma, this.model, 'create'); + + const origArgs = args; + args = clone(args); + + // go through create items, statically check input to determine if post-create + // check is needed, and also validate zod schema + const needPostCreateCheck = this.validateCreateInput(args); + + let result: { result: unknown; error?: Error }[]; + + if (!needPostCreateCheck) { + // direct create + const created = await this.modelClient.createManyAndReturn(args); + + // process read-back + result = await Promise.all( + created.map((item) => this.policyUtils.readBack(this.prisma, this.model, 'create', origArgs, item)) + ); + } else { + // create entities in a transaction with post-create checks + result = await this.queryUtils.transaction(this.prisma, async (tx) => { + const { result: created, postWriteChecks } = await this.doCreateMany(this.model, args, tx); + // post-create check + await this.runPostWriteChecks(postWriteChecks, tx); + + // process read-back + return Promise.all( + created.map((item) => this.policyUtils.readBack(tx, this.model, 'create', origArgs, item)) + ); + }); + } + + // throw read-back error if any of create result read-back fails + const error = result.find((r) => !!r.error)?.error; + if (error) { + throw error; + } else { + return result.map((r) => r.result); + } + }); + } + + private validateCreateInput(args: { data: any; skipDuplicates?: boolean | undefined }) { + let needPostCreateCheck = false; + for (const item of enumerate(args.data)) { + const validationResult = this.validateCreateInputSchema(this.model, item); + if (validationResult !== item) { + this.policyUtils.replace(item, validationResult); + } + + const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create'); + if (inputCheck === false) { + // unconditionally deny + throw this.policyUtils.deniedByPolicy( + this.model, + 'create', + undefined, + CrudFailureReason.ACCESS_POLICY_VIOLATION + ); + } else if (inputCheck === true) { + // unconditionally allow + } else if (inputCheck === undefined) { + // static policy check is not possible, need to do post-create check + needPostCreateCheck = true; + } + } + return needPostCreateCheck; + } + private async doCreateMany(model: string, args: { data: any; skipDuplicates?: boolean }, db: CrudContract) { // We can't call the native "createMany" because we can't get back what was created // for post-create checks. Instead, do a "create" for each item and collect the results. @@ -511,7 +573,7 @@ export class PolicyProxyHandler implements Pr createResult = createResult.filter((p) => !!p); return { - result: { count: createResult.length }, + result: createResult, postWriteChecks: createResult.map((item) => ({ model, operation: 'create' as PolicyOperationKind, diff --git a/packages/runtime/src/enhancements/proxy.ts b/packages/runtime/src/enhancements/proxy.ts index e7f55a88c..70d8f27e9 100644 --- a/packages/runtime/src/enhancements/proxy.ts +++ b/packages/runtime/src/enhancements/proxy.ts @@ -35,6 +35,8 @@ export interface PrismaProxyHandler { createMany(args: { data: any; skipDuplicates?: boolean }): Promise; + createManyAndReturn(args: { data: any; select: any; include: any; skipDuplicates?: boolean }): Promise; + update(args: any): Promise; updateMany(args: any): Promise; @@ -122,6 +124,10 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { return this.deferred<{ count: number }>('createMany', args, false); } + createManyAndReturn(args: { data: any; select: any; include: any; skipDuplicates?: boolean }) { + return this.deferred('createManyAndReturn', args); + } + update(args: any) { return this.deferred('update', args); } diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index b9497b7ee..bf1660090 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -12,7 +12,8 @@ export interface DbOperations { findUnique(args: unknown): PrismaPromise; findUniqueOrThrow(args: unknown): PrismaPromise; create(args: unknown): Promise; - createMany(args: unknown, skipDuplicates?: boolean): Promise<{ count: number }>; + createMany(args: unknown): Promise<{ count: number }>; + createManyAndReturn(args: unknown): Promise; update(args: unknown): Promise; updateMany(args: unknown): Promise<{ count: number }>; upsert(args: unknown): Promise; diff --git a/tests/integration/tests/enhancements/with-policy/create-many-and-return.test.ts b/tests/integration/tests/enhancements/with-policy/create-many-and-return.test.ts new file mode 100644 index 000000000..c96f16256 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/create-many-and-return.test.ts @@ -0,0 +1,105 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Test API createManyAndReturn', () => { + it('model-level policies', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + posts Post[] + level Int + + @@allow('read', level > 0) + } + + model Post { + id Int @id @default(autoincrement()) + title String + published Boolean @default(false) + userId Int + user User @relation(fields: [userId], references: [id]) + + @@allow('read', published) + @@allow('create', contains(title, 'hello')) + } + ` + ); + + await prisma.user.createMany({ + data: [ + { id: 1, level: 1 }, + { id: 2, level: 0 }, + ], + }); + + const db = enhance(); + + // create rule violation + await expect( + db.post.createManyAndReturn({ + data: [{ title: 'foo', userId: 1 }], + }) + ).toBeRejectedByPolicy(); + + // success + let r = await db.post.createManyAndReturn({ + data: [{ id: 1, title: 'hello1', userId: 1, published: true }], + }); + expect(r.length).toBe(1); + + // read-back check + await expect( + db.post.createManyAndReturn({ + data: [ + { id: 2, title: 'hello2', userId: 1, published: true }, + { id: 3, title: 'hello3', userId: 1, published: false }, + ], + }) + ).toBeRejectedByPolicy(['result is not allowed to be read back']); + await expect(prisma.post.findMany()).resolves.toHaveLength(3); + + // return relation + await prisma.post.deleteMany(); + r = await db.post.createManyAndReturn({ + include: { user: true }, + data: [{ id: 1, title: 'hello1', userId: 1, published: true }], + }); + expect(r[0]).toMatchObject({ user: { id: 1 } }); + + // relation filtered + await prisma.post.deleteMany(); + await expect( + db.post.createManyAndReturn({ + include: { user: true }, + data: [{ id: 1, title: 'hello1', userId: 2, published: true }], + }) + ).toBeRejectedByPolicy(['result is not allowed to be read back']); + await expect(prisma.post.findMany()).resolves.toHaveLength(1); + }); + + it('field-level policies', async () => { + const { prisma, enhance } = await loadSchema( + ` + model Post { + id Int @id @default(autoincrement()) + title String @allow('read', published) + published Boolean @default(false) + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + const r = await db.post.createManyAndReturn({ + data: [ + { title: 'post1', published: true }, + { title: 'post2', published: false }, + ], + }); + expect(r).toHaveLength(2); + expect(r[0].title).toBe('post1'); + expect(r[1].title).toBeUndefined(); + }); +});