Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 87 additions & 25 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -447,44 +447,106 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

// go through create items, statically check input to determine if post-create
// check is needed, and also validate zod schema
let needPostCreateCheck = false;
for (const item of enumerate(args.data)) {
const validationResult = this.validateCreateInputSchema(this.model, item);
if (validationResult !== item) {
this.policyUtils.replace(item, validationResult);
}

const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create');
if (inputCheck === false) {
// unconditionally deny
throw this.policyUtils.deniedByPolicy(
this.model,
'create',
undefined,
CrudFailureReason.ACCESS_POLICY_VIOLATION
);
} else if (inputCheck === true) {
// unconditionally allow
} else if (inputCheck === undefined) {
// static policy check is not possible, need to do post-create check
needPostCreateCheck = true;
}
}
const needPostCreateCheck = this.validateCreateInput(args);

if (!needPostCreateCheck) {
// direct create
return this.modelClient.createMany(args);
} else {
// create entities in a transaction with post-create checks
return this.queryUtils.transaction(this.prisma, async (tx) => {
const { result, postWriteChecks } = await this.doCreateMany(this.model, args, tx);
// post-create check
await this.runPostWriteChecks(postWriteChecks, tx);
return result;
return { count: result.length };
});
}
});
}

createManyAndReturn(args: { select: any; include: any; data: any; skipDuplicates?: boolean }) {
if (!args) {
throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required');
}
if (!args.data) {
throw prismaClientValidationError(
this.prisma,
this.prismaModule,
'data field is required in query argument'
);
}

return createDeferredPromise(async () => {
this.policyUtils.tryReject(this.prisma, this.model, 'create');

const origArgs = args;
args = clone(args);

// go through create items, statically check input to determine if post-create
// check is needed, and also validate zod schema
const needPostCreateCheck = this.validateCreateInput(args);

let result: { result: unknown; error?: Error }[];

if (!needPostCreateCheck) {
// direct create
const created = await this.modelClient.createManyAndReturn(args);

// process read-back
result = await Promise.all(
created.map((item) => this.policyUtils.readBack(this.prisma, this.model, 'create', origArgs, item))
);
} else {
// create entities in a transaction with post-create checks
result = await this.queryUtils.transaction(this.prisma, async (tx) => {
const { result: created, postWriteChecks } = await this.doCreateMany(this.model, args, tx);
// post-create check
await this.runPostWriteChecks(postWriteChecks, tx);

// process read-back
return Promise.all(
created.map((item) => this.policyUtils.readBack(tx, this.model, 'create', origArgs, item))
);
});
}

// throw read-back error if any of create result read-back fails
const error = result.find((r) => !!r.error)?.error;
if (error) {
throw error;
} else {
return result.map((r) => r.result);
}
});
}

private validateCreateInput(args: { data: any; skipDuplicates?: boolean | undefined }) {
let needPostCreateCheck = false;
for (const item of enumerate(args.data)) {
const validationResult = this.validateCreateInputSchema(this.model, item);
if (validationResult !== item) {
this.policyUtils.replace(item, validationResult);
}

const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create');
if (inputCheck === false) {
// unconditionally deny
throw this.policyUtils.deniedByPolicy(
this.model,
'create',
undefined,
CrudFailureReason.ACCESS_POLICY_VIOLATION
);
} else if (inputCheck === true) {
// unconditionally allow
} else if (inputCheck === undefined) {
// static policy check is not possible, need to do post-create check
needPostCreateCheck = true;
}
}
return needPostCreateCheck;
}

private async doCreateMany(model: string, args: { data: any; skipDuplicates?: boolean }, db: CrudContract) {
// We can't call the native "createMany" because we can't get back what was created
// for post-create checks. Instead, do a "create" for each item and collect the results.
Expand All @@ -511,7 +573,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
createResult = createResult.filter((p) => !!p);

return {
result: { count: createResult.length },
result: createResult,
postWriteChecks: createResult.map((item) => ({
model,
operation: 'create' as PolicyOperationKind,
Expand Down
6 changes: 6 additions & 0 deletions packages/runtime/src/enhancements/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ export interface PrismaProxyHandler {

createMany(args: { data: any; skipDuplicates?: boolean }): Promise<BatchResult>;

createManyAndReturn(args: { data: any; select: any; include: any; skipDuplicates?: boolean }): Promise<unknown[]>;

update(args: any): Promise<unknown>;

updateMany(args: any): Promise<BatchResult>;
Expand Down Expand Up @@ -122,6 +124,10 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
return this.deferred<{ count: number }>('createMany', args, false);
}

createManyAndReturn(args: { data: any; select: any; include: any; skipDuplicates?: boolean }) {
return this.deferred<unknown[]>('createManyAndReturn', args);
}

update(args: any) {
return this.deferred('update', args);
}
Expand Down
3 changes: 2 additions & 1 deletion packages/runtime/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ export interface DbOperations {
findUnique(args: unknown): PrismaPromise<any>;
findUniqueOrThrow(args: unknown): PrismaPromise<any>;
create(args: unknown): Promise<any>;
createMany(args: unknown, skipDuplicates?: boolean): Promise<{ count: number }>;
createMany(args: unknown): Promise<{ count: number }>;
createManyAndReturn(args: unknown): Promise<unknown[]>;
update(args: unknown): Promise<any>;
updateMany(args: unknown): Promise<{ count: number }>;
upsert(args: unknown): Promise<any>;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import { loadSchema } from '@zenstackhq/testtools';

describe('Test API createManyAndReturn', () => {
it('model-level policies', async () => {
const { prisma, enhance } = await loadSchema(
`
model User {
id Int @id @default(autoincrement())
posts Post[]
level Int

@@allow('read', level > 0)
}

model Post {
id Int @id @default(autoincrement())
title String
published Boolean @default(false)
userId Int
user User @relation(fields: [userId], references: [id])

@@allow('read', published)
@@allow('create', contains(title, 'hello'))
}
`
);

await prisma.user.createMany({
data: [
{ id: 1, level: 1 },
{ id: 2, level: 0 },
],
});

const db = enhance();

// create rule violation
await expect(
db.post.createManyAndReturn({
data: [{ title: 'foo', userId: 1 }],
})
).toBeRejectedByPolicy();

// success
let r = await db.post.createManyAndReturn({
data: [{ id: 1, title: 'hello1', userId: 1, published: true }],
});
expect(r.length).toBe(1);

// read-back check
await expect(
db.post.createManyAndReturn({
data: [
{ id: 2, title: 'hello2', userId: 1, published: true },
{ id: 3, title: 'hello3', userId: 1, published: false },
],
})
).toBeRejectedByPolicy(['result is not allowed to be read back']);
await expect(prisma.post.findMany()).resolves.toHaveLength(3);

// return relation
await prisma.post.deleteMany();
r = await db.post.createManyAndReturn({
include: { user: true },
data: [{ id: 1, title: 'hello1', userId: 1, published: true }],
});
expect(r[0]).toMatchObject({ user: { id: 1 } });

// relation filtered
await prisma.post.deleteMany();
await expect(
db.post.createManyAndReturn({
include: { user: true },
data: [{ id: 1, title: 'hello1', userId: 2, published: true }],
})
).toBeRejectedByPolicy(['result is not allowed to be read back']);
await expect(prisma.post.findMany()).resolves.toHaveLength(1);
});

it('field-level policies', async () => {
const { prisma, enhance } = await loadSchema(
`
model Post {
id Int @id @default(autoincrement())
title String @allow('read', published)
published Boolean @default(false)

@@allow('all', true)
}
`
);

const db = enhance();

const r = await db.post.createManyAndReturn({
data: [
{ title: 'post1', published: true },
{ title: 'post2', published: false },
],
});
expect(r).toHaveLength(2);
expect(r[0].title).toBe('post1');
expect(r[1].title).toBeUndefined();
});
});