diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index f26c2038..fcbabe43 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -134,6 +134,102 @@ export class PolicyHandler extends OperationNodeTransf // return result; } + // #region overrides + + protected override transformSelectQuery(node: SelectQueryNode) { + let whereNode = node.where; + + node.from?.froms.forEach((from) => { + const extractResult = this.extractTableName(from); + if (extractResult) { + const { model, alias } = extractResult; + const filter = this.buildPolicyFilter(model, alias, 'read'); + whereNode = WhereNode.create( + whereNode?.where ? conjunction(this.dialect, [whereNode.where, filter]) : filter, + ); + } + }); + + const baseResult = super.transformSelectQuery({ + ...node, + where: undefined, + }); + + return { + ...baseResult, + where: whereNode, + }; + } + + protected override transformInsertQuery(node: InsertQueryNode) { + // pre-insert check is done in `handle()` + + let onConflict = node.onConflict; + + if (onConflict?.updates) { + // for "on conflict do update", we need to apply policy filter to the "where" clause + const mutationModel = this.getMutationModel(node); + const filter = this.buildPolicyFilter(mutationModel, undefined, 'update'); + if (onConflict.updateWhere) { + onConflict = { + ...onConflict, + updateWhere: WhereNode.create(conjunction(this.dialect, [onConflict.updateWhere.where, filter])), + }; + } else { + onConflict = { + ...onConflict, + updateWhere: WhereNode.create(filter), + }; + } + } + + // merge updated onConflict + const processedNode = onConflict ? { ...node, onConflict } : node; + + const result = super.transformInsertQuery(processedNode); + + if (!node.returning) { + return result; + } + + if (this.onlyReturningId(node)) { + return result; + } else { + // only return ID fields, that's enough for reading back the inserted row + const idFields = getIdFields(this.client.$schema, this.getMutationModel(node)); + return { + ...result, + returning: ReturningNode.create( + idFields.map((field) => SelectionNode.create(ColumnNode.create(field))), + ), + }; + } + } + + protected override transformUpdateQuery(node: UpdateQueryNode) { + const result = super.transformUpdateQuery(node); + const mutationModel = this.getMutationModel(node); + const filter = this.buildPolicyFilter(mutationModel, undefined, 'update'); + return { + ...result, + where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), + }; + } + + protected override transformDeleteQuery(node: DeleteQueryNode) { + const result = super.transformDeleteQuery(node); + const mutationModel = this.getMutationModel(node); + const filter = this.buildPolicyFilter(mutationModel, undefined, 'delete'); + return { + ...result, + where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), + }; + } + + // #endregion + + // #region helpers + private onlyReturningId(node: MutationQueryNode) { if (!node.returning) { return true; @@ -397,70 +493,6 @@ export class PolicyHandler extends OperationNodeTransf return combinedPolicy; } - protected override transformSelectQuery(node: SelectQueryNode) { - let whereNode = node.where; - - node.from?.froms.forEach((from) => { - const extractResult = this.extractTableName(from); - if (extractResult) { - const { model, alias } = extractResult; - const filter = this.buildPolicyFilter(model, alias, 'read'); - whereNode = WhereNode.create( - whereNode?.where ? conjunction(this.dialect, [whereNode.where, filter]) : filter, - ); - } - }); - - const baseResult = super.transformSelectQuery({ - ...node, - where: undefined, - }); - - return { - ...baseResult, - where: whereNode, - }; - } - - protected override transformInsertQuery(node: InsertQueryNode) { - const result = super.transformInsertQuery(node); - if (!node.returning) { - return result; - } - if (this.onlyReturningId(node)) { - return result; - } else { - // only return ID fields, that's enough for reading back the inserted row - const idFields = getIdFields(this.client.$schema, this.getMutationModel(node)); - return { - ...result, - returning: ReturningNode.create( - idFields.map((field) => SelectionNode.create(ColumnNode.create(field))), - ), - }; - } - } - - protected override transformUpdateQuery(node: UpdateQueryNode) { - const result = super.transformUpdateQuery(node); - const mutationModel = this.getMutationModel(node); - const filter = this.buildPolicyFilter(mutationModel, undefined, 'update'); - return { - ...result, - where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), - }; - } - - protected override transformDeleteQuery(node: DeleteQueryNode) { - const result = super.transformDeleteQuery(node); - const mutationModel = this.getMutationModel(node); - const filter = this.buildPolicyFilter(mutationModel, undefined, 'delete'); - return { - ...result, - where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), - }; - } - private extractTableName(from: OperationNode): { model: GetModels; alias?: string } | undefined { if (TableNode.is(from)) { return { model: from.table.identifier.name as GetModels }; @@ -528,4 +560,6 @@ export class PolicyHandler extends OperationNodeTransf } return result; } + + // #endregion } diff --git a/packages/runtime/test/policy/crud/delete.test.ts b/packages/runtime/test/policy/crud/delete.test.ts new file mode 100644 index 00000000..f515f0dc --- /dev/null +++ b/packages/runtime/test/policy/crud/delete.test.ts @@ -0,0 +1,51 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Delete policy tests', () => { + it('works with top-level delete/deleteMany', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create,read', true) + @@allow('delete', x > 0) +} +`, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.delete({ where: { id: 1 } })).toBeRejectedNotFound(); + + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.delete({ where: { id: 2 } })).toResolveTruthy(); + await expect(db.foo.count()).resolves.toBe(1); + + await db.foo.create({ data: { id: 3, x: 1 } }); + await expect(db.foo.deleteMany()).resolves.toMatchObject({ count: 1 }); + await expect(db.foo.count()).resolves.toBe(1); + }); + + it('works with query builder delete', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create,read', true) + @@allow('delete', x > 0) +} +`, + ); + await db.foo.create({ data: { id: 1, x: 0 } }); + await db.foo.create({ data: { id: 2, x: 1 } }); + + await expect(db.$qb.deleteFrom('Foo').where('id', '=', 1).executeTakeFirst()).resolves.toMatchObject({ + numDeletedRows: 0n, + }); + await expect(db.foo.count()).resolves.toBe(2); + + await expect(db.$qb.deleteFrom('Foo').executeTakeFirst()).resolves.toMatchObject({ numDeletedRows: 1n }); + await expect(db.foo.count()).resolves.toBe(1); + }); +}); diff --git a/packages/runtime/test/policy/crud/update.test.ts b/packages/runtime/test/policy/crud/update.test.ts index ef56b40f..f7b2b820 100644 --- a/packages/runtime/test/policy/crud/update.test.ts +++ b/packages/runtime/test/policy/crud/update.test.ts @@ -953,4 +953,100 @@ model Foo { ); }); }); + + describe('Query builder tests', () => { + it('works with simple update', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('update', x > 1) + @@allow('read', true) +} +`, + ); + + await db.foo.createMany({ + data: [ + { id: 1, x: 1 }, + { id: 2, x: 2 }, + { id: 3, x: 3 }, + ], + }); + + // not updatable + await expect( + db.$qb.updateTable('Foo').set({ x: 5 }).where('id', '=', 1).executeTakeFirst(), + ).resolves.toMatchObject({ numUpdatedRows: 0n }); + + // with where + await expect( + db.$qb.updateTable('Foo').set({ x: 5 }).where('id', '=', 2).executeTakeFirst(), + ).resolves.toMatchObject({ numUpdatedRows: 1n }); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 5 }); + + // without where + await expect(db.$qb.updateTable('Foo').set({ x: 6 }).executeTakeFirst()).resolves.toMatchObject({ + numUpdatedRows: 2n, + }); + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + }); + + it('works with insert on conflict do update', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('update', x > 1) + @@allow('read', true) +} +`, + ); + + await db.foo.createMany({ + data: [ + { id: 1, x: 1 }, + { id: 2, x: 2 }, + { id: 3, x: 3 }, + ], + }); + + // #1 not updatable + await expect( + db.$qb + .insertInto('Foo') + .values({ id: 1, x: 5 }) + .onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 })) + .executeTakeFirst(), + ).resolves.toMatchObject({ numInsertedOrUpdatedRows: 0n }); + await expect(db.foo.count()).resolves.toBe(3); + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + + // with where, #1 not updatable + await expect( + db.$qb + .insertInto('Foo') + .values({ id: 1, x: 5 }) + .onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 }).where('id', '=', 1)) + .executeTakeFirst(), + ).resolves.toMatchObject({ numInsertedOrUpdatedRows: 0n }); + await expect(db.foo.count()).resolves.toBe(3); + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + + // with where, #2 updatable + await expect( + db.$qb + .insertInto('Foo') + .values({ id: 2, x: 5 }) + .onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 6 }).where('id', '=', 2)) + .executeTakeFirst(), + ).resolves.toMatchObject({ numInsertedOrUpdatedRows: 1n }); + await expect(db.foo.count()).resolves.toBe(3); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 6 }); + }); + }); });