Skip to content

Commit 6610bd0

Browse files
authored
feat: allow comparing fields from different models in mutation policies (#1476)
1 parent fe97241 commit 6610bd0

File tree

13 files changed

+1618
-329
lines changed

13 files changed

+1618
-329
lines changed

packages/runtime/src/enhancements/policy/constraint-solver.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import Logic from 'logic-solver';
22
import { match } from 'ts-pattern';
33
import type {
4-
CheckerConstraint,
54
ComparisonConstraint,
65
ComparisonTerm,
76
LogicalConstraint,
7+
PermissionCheckerConstraint,
88
ValueConstraint,
99
VariableConstraint,
1010
} from '../types';
@@ -22,7 +22,7 @@ export class ConstraintSolver {
2222
/**
2323
* Check the satisfiability of the given constraint.
2424
*/
25-
checkSat(constraint: CheckerConstraint): boolean {
25+
checkSat(constraint: PermissionCheckerConstraint): boolean {
2626
// reset state
2727
this.stringTable = [];
2828
this.variables = new Map<string, Logic.Formula>();
@@ -46,7 +46,7 @@ export class ConstraintSolver {
4646
return !!solver.solve();
4747
}
4848

49-
private buildFormula(constraint: CheckerConstraint): Logic.Formula {
49+
private buildFormula(constraint: PermissionCheckerConstraint): Logic.Formula {
5050
return match(constraint)
5151
.when(
5252
(c): c is ValueConstraint => c.kind === 'value',
@@ -100,11 +100,11 @@ export class ConstraintSolver {
100100
return Logic.not(this.buildFormula(constraint.children[0]));
101101
}
102102

103-
private isTrue(constraint: CheckerConstraint): unknown {
103+
private isTrue(constraint: PermissionCheckerConstraint): unknown {
104104
return constraint.kind === 'value' && constraint.value === true;
105105
}
106106

107-
private isFalse(constraint: CheckerConstraint): unknown {
107+
private isFalse(constraint: PermissionCheckerConstraint): unknown {
108108
return constraint.kind === 'value' && constraint.value === false;
109109
}
110110

packages/runtime/src/enhancements/policy/handler.ts

Lines changed: 151 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/* eslint-disable @typescript-eslint/no-explicit-any */
22

3+
import deepmerge from 'deepmerge';
34
import { lowerCaseFirst } from 'lower-case-first';
45
import invariant from 'tiny-invariant';
56
import { P, match } from 'ts-pattern';
@@ -23,7 +24,7 @@ import { Logger } from '../logger';
2324
import { createDeferredPromise, createFluentPromise } from '../promise';
2425
import { PrismaProxyHandler } from '../proxy';
2526
import { QueryUtils } from '../query-utils';
26-
import type { CheckerConstraint } from '../types';
27+
import type { EntityCheckerFunc, PermissionCheckerConstraint } from '../types';
2728
import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils';
2829
import { ConstraintSolver } from './constraint-solver';
2930
import { PolicyUtil } from './policy-utils';
@@ -152,8 +153,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
152153
}
153154

154155
const result = await this.modelClient[actionName](_args);
155-
this.policyUtils.postProcessForRead(result, this.model, origArgs);
156-
return result;
156+
return this.policyUtils.postProcessForRead(result, this.model, origArgs);
157157
}
158158

159159
//#endregion
@@ -779,10 +779,27 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
779779
}
780780
};
781781

782-
const _connectDisconnect = async (model: string, args: any, context: NestedWriteVisitorContext) => {
782+
const _connectDisconnect = async (
783+
model: string,
784+
args: any,
785+
context: NestedWriteVisitorContext,
786+
operation: 'connect' | 'disconnect'
787+
) => {
783788
if (context.field?.backLink) {
784789
const backLinkField = this.policyUtils.getModelField(model, context.field.backLink);
785790
if (backLinkField?.isRelationOwner) {
791+
let uniqueFilter = args;
792+
if (operation === 'disconnect') {
793+
// disconnect filter is not unique, need to build a reversed query to
794+
// locate the entity and use its id fields as unique filter
795+
const reversedQuery = this.policyUtils.buildReversedQuery(context);
796+
const found = await db[model].findUnique({
797+
where: reversedQuery,
798+
select: this.policyUtils.makeIdSelection(model),
799+
});
800+
uniqueFilter = found && this.policyUtils.getIdFieldValues(model, found);
801+
}
802+
786803
// update happens on the related model, require updatable,
787804
// translate args to foreign keys so field-level policies can be checked
788805
const checkArgs: any = {};
@@ -794,10 +811,15 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
794811
}
795812
}
796813
}
797-
await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, checkArgs);
798814

799-
// register post-update check
800-
await _registerPostUpdateCheck(model, args, args);
815+
// `uniqueFilter` can be undefined if the entity to be disconnected doesn't exist
816+
if (uniqueFilter) {
817+
// check for update
818+
await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, checkArgs);
819+
820+
// register post-update check
821+
await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter);
822+
}
801823
}
802824
}
803825
};
@@ -970,14 +992,14 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
970992
}
971993
},
972994

973-
connect: async (model, args, context) => _connectDisconnect(model, args, context),
995+
connect: async (model, args, context) => _connectDisconnect(model, args, context, 'connect'),
974996

975997
connectOrCreate: async (model, args, context) => {
976998
// the where condition is already unique, so we can use it to check if the target exists
977999
const existing = await this.policyUtils.checkExistence(db, model, args.where);
9781000
if (existing) {
9791001
// connect
980-
await _connectDisconnect(model, args.where, context);
1002+
await _connectDisconnect(model, args.where, context, 'connect');
9811003
return true;
9821004
} else {
9831005
// create
@@ -997,7 +1019,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
9971019
}
9981020
},
9991021

1000-
disconnect: async (model, args, context) => _connectDisconnect(model, args, context),
1022+
disconnect: async (model, args, context) => _connectDisconnect(model, args, context, 'disconnect'),
10011023

10021024
set: async (model, args, context) => {
10031025
// find the set of items to be replaced
@@ -1012,10 +1034,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
10121034
const currentSet = await db[model].findMany(findCurrSetArgs);
10131035

10141036
// register current set for update (foreign key)
1015-
await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context)));
1037+
await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context, 'disconnect')));
10161038

10171039
// proceed with connecting the new set
1018-
await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context)));
1040+
await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context, 'connect')));
10191041
},
10201042

10211043
delete: async (model, args, context) => {
@@ -1160,48 +1182,78 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
11601182

11611183
args.data = this.validateUpdateInputSchema(this.model, args.data);
11621184

1163-
if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) {
1164-
// use a transaction to do post-update checks
1165-
const postWriteChecks: PostWriteCheckRecord[] = [];
1166-
return this.queryUtils.transaction(this.prisma, async (tx) => {
1167-
// collect pre-update values
1168-
let select = this.policyUtils.makeIdSelection(this.model);
1169-
const preValueSelect = this.policyUtils.getPreValueSelect(this.model);
1170-
if (preValueSelect) {
1171-
select = { ...select, ...preValueSelect };
1172-
}
1173-
const currentSetQuery = { select, where: args.where };
1174-
this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read');
1175-
1176-
if (this.shouldLogQuery) {
1177-
this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`);
1178-
}
1179-
const currentSet = await tx[this.model].findMany(currentSetQuery);
1180-
1181-
postWriteChecks.push(
1182-
...currentSet.map((preValue) => ({
1183-
model: this.model,
1184-
operation: 'postUpdate' as PolicyOperationKind,
1185-
uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue),
1186-
preValue: preValueSelect ? preValue : undefined,
1187-
}))
1188-
);
1189-
1190-
// proceed with the update
1191-
const result = await tx[this.model].updateMany(args);
1185+
const entityChecker = this.policyUtils.getEntityChecker(this.model, 'update');
11921186

1193-
// run post-write checks
1194-
await this.runPostWriteChecks(postWriteChecks, tx);
1187+
const canProceedWithoutTransaction =
1188+
// no post-update rules
1189+
!this.policyUtils.hasAuthGuard(this.model, 'postUpdate') &&
1190+
// no Zod schema
1191+
!this.policyUtils.getZodSchema(this.model) &&
1192+
// no entity checker
1193+
!entityChecker;
11951194

1196-
return result;
1197-
});
1198-
} else {
1195+
if (canProceedWithoutTransaction) {
11991196
// proceed without a transaction
12001197
if (this.shouldLogQuery) {
12011198
this.logger.info(`[policy] \`updateMany\` ${this.model}: ${formatObject(args)}`);
12021199
}
12031200
return this.modelClient.updateMany(args);
12041201
}
1202+
1203+
// collect post-update checks
1204+
const postWriteChecks: PostWriteCheckRecord[] = [];
1205+
1206+
return this.queryUtils.transaction(this.prisma, async (tx) => {
1207+
// collect pre-update values
1208+
let select = this.policyUtils.makeIdSelection(this.model);
1209+
const preValueSelect = this.policyUtils.getPreValueSelect(this.model);
1210+
if (preValueSelect) {
1211+
select = { ...select, ...preValueSelect };
1212+
}
1213+
1214+
// merge selection required for running additional checker
1215+
const entityChecker = this.policyUtils.getEntityChecker(this.model, 'update');
1216+
if (entityChecker?.selector) {
1217+
select = deepmerge(select, entityChecker.selector);
1218+
}
1219+
1220+
const currentSetQuery = { select, where: args.where };
1221+
this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'update');
1222+
1223+
if (this.shouldLogQuery) {
1224+
this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`);
1225+
}
1226+
let candidates = await tx[this.model].findMany(currentSetQuery);
1227+
1228+
if (entityChecker) {
1229+
// filter candidates with additional checker and build an id filter
1230+
const r = this.buildIdFilterWithEntityChecker(candidates, entityChecker.func);
1231+
candidates = r.filteredCandidates;
1232+
1233+
// merge id filter into update's where clause
1234+
args.where = args.where ? { AND: [args.where, r.idFilter] } : r.idFilter;
1235+
}
1236+
1237+
postWriteChecks.push(
1238+
...candidates.map((preValue) => ({
1239+
model: this.model,
1240+
operation: 'postUpdate' as PolicyOperationKind,
1241+
uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue),
1242+
preValue: preValueSelect ? preValue : undefined,
1243+
}))
1244+
);
1245+
1246+
// proceed with the update
1247+
if (this.shouldLogQuery) {
1248+
this.logger.info(`[policy] \`updateMany\` in tx for ${this.model}: ${formatObject(args)}`);
1249+
}
1250+
const result = await tx[this.model].updateMany(args);
1251+
1252+
// run post-write checks
1253+
await this.runPostWriteChecks(postWriteChecks, tx);
1254+
1255+
return result;
1256+
});
12051257
});
12061258
}
12071259

@@ -1328,14 +1380,49 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
13281380
this.policyUtils.tryReject(this.prisma, this.model, 'delete');
13291381

13301382
// inject policy conditions
1331-
args = args ?? {};
1383+
args = clone(args);
13321384
this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete');
13331385

1334-
// conduct the deletion
1335-
if (this.shouldLogQuery) {
1336-
this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`);
1386+
const entityChecker = this.policyUtils.getEntityChecker(this.model, 'delete');
1387+
if (entityChecker) {
1388+
// additional checker exists, need to run deletion inside a transaction
1389+
return this.queryUtils.transaction(this.prisma, async (tx) => {
1390+
// find the delete candidates, selecting id fields and fields needed for
1391+
// running the additional checker
1392+
let candidateSelect = this.policyUtils.makeIdSelection(this.model);
1393+
if (entityChecker.selector) {
1394+
candidateSelect = deepmerge(candidateSelect, entityChecker.selector);
1395+
}
1396+
1397+
if (this.shouldLogQuery) {
1398+
this.logger.info(
1399+
`[policy] \`findMany\` ${this.model}: ${formatObject({
1400+
where: args.where,
1401+
select: candidateSelect,
1402+
})}`
1403+
);
1404+
}
1405+
const candidates = await tx[this.model].findMany({ where: args.where, select: candidateSelect });
1406+
1407+
// build a ID filter based on id values filtered by the additional checker
1408+
const { idFilter } = this.buildIdFilterWithEntityChecker(candidates, entityChecker.func);
1409+
1410+
// merge the ID filter into the where clause
1411+
args.where = args.where ? { AND: [args.where, idFilter] } : idFilter;
1412+
1413+
// finally, conduct the deletion with the combined where clause
1414+
if (this.shouldLogQuery) {
1415+
this.logger.info(`[policy] \`deleteMany\` in tx for ${this.model}:\n${formatObject(args)}`);
1416+
}
1417+
return tx[this.model].deleteMany(args);
1418+
});
1419+
} else {
1420+
// conduct the deletion directly
1421+
if (this.shouldLogQuery) {
1422+
this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`);
1423+
}
1424+
return this.modelClient.deleteMany(args);
13371425
}
1338-
return this.modelClient.deleteMany(args);
13391426
});
13401427
}
13411428

@@ -1469,7 +1556,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
14691556
if (args.where) {
14701557
// combine runtime filters with generated constraints
14711558

1472-
const extraConstraints: CheckerConstraint[] = [];
1559+
const extraConstraints: PermissionCheckerConstraint[] = [];
14731560
for (const [field, value] of Object.entries(args.where)) {
14741561
if (value === undefined) {
14751562
continue;
@@ -1599,5 +1686,17 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
15991686
}
16001687
}
16011688

1689+
private buildIdFilterWithEntityChecker(candidates: any[], entityChecker: EntityCheckerFunc) {
1690+
const filteredCandidates = candidates.filter((value) => entityChecker(value, { user: this.context?.user }));
1691+
const idFields = this.policyUtils.getIdFields(this.model);
1692+
let idFilter: any;
1693+
if (idFields.length === 1) {
1694+
idFilter = { [idFields[0].name]: { in: filteredCandidates.map((x) => x[idFields[0].name]) } };
1695+
} else {
1696+
idFilter = { AND: filteredCandidates.map((x) => this.policyUtils.getIdFieldValues(this.model, x)) };
1697+
}
1698+
return { filteredCandidates, idFilter };
1699+
}
1700+
16021701
//#endregion
16031702
}

0 commit comments

Comments
 (0)