diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index 478db5ff7..7f24f1f60 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -4,6 +4,7 @@ import { DataModelAttribute, Expression, ExpressionType, + isArrayExpr, isDataModel, isDataModelAttribute, isDataModelField, @@ -82,6 +83,8 @@ export default class ExpressionValidator implements AstValidator { node: expr.right, }); } + + this.validateCrossModelFieldComparison(expr, accept); break; } @@ -137,6 +140,7 @@ export default class ExpressionValidator implements AstValidator { accept('error', 'incompatible operand types', { node: expr }); } + this.validateCrossModelFieldComparison(expr, accept); break; } @@ -158,43 +162,8 @@ export default class ExpressionValidator implements AstValidator { break; } - // not supported: - // - foo.a == bar - // - foo.user.id == userId - // except: - // - future().userId == userId - if ( - (isMemberAccessExpr(expr.left) && - isDataModelField(expr.left.member.ref) && - expr.left.member.ref.$container != getContainingDataModel(expr)) || - (isMemberAccessExpr(expr.right) && - isDataModelField(expr.right.member.ref) && - expr.right.member.ref.$container != getContainingDataModel(expr)) - ) { - // foo.user.id == auth().id - // foo.user.id == "123" - // foo.user.id == null - // foo.user.id == EnumValue - if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) { - const containingPolicyAttr = findUpAst( - expr, - (node) => isDataModelAttribute(node) && ['@@allow', '@@deny'].includes(node.decl.$refText) - ) as DataModelAttribute | undefined; - - if (containingPolicyAttr) { - const operation = getAttributeArgLiteral(containingPolicyAttr, 'operation'); - if (operation?.split(',').includes('all') || operation?.split(',').includes('read')) { - accept( - 'error', - 'comparison between fields of different models is not supported in model-level "read" rules', - { - node: expr, - } - ); - break; - } - } - } + if (!this.validateCrossModelFieldComparison(expr, accept)) { + break; } if ( @@ -262,6 +231,49 @@ export default class ExpressionValidator implements AstValidator { } } + private validateCrossModelFieldComparison(expr: BinaryExpr, accept: ValidationAcceptor) { + // not supported in "read" rules: + // - foo.a == bar + // - foo.user.id == userId + // except: + // - future().userId == userId + if ( + (isMemberAccessExpr(expr.left) && + isDataModelField(expr.left.member.ref) && + expr.left.member.ref.$container != getContainingDataModel(expr)) || + (isMemberAccessExpr(expr.right) && + isDataModelField(expr.right.member.ref) && + expr.right.member.ref.$container != getContainingDataModel(expr)) + ) { + // foo.user.id == auth().id + // foo.user.id == "123" + // foo.user.id == null + // foo.user.id == EnumValue + if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) { + const containingPolicyAttr = findUpAst( + expr, + (node) => isDataModelAttribute(node) && ['@@allow', '@@deny'].includes(node.decl.$refText) + ) as DataModelAttribute | undefined; + + if (containingPolicyAttr) { + const operation = getAttributeArgLiteral(containingPolicyAttr, 'operation'); + if (operation?.split(',').includes('all') || operation?.split(',').includes('read')) { + accept( + 'error', + 'comparison between fields of different models is not supported in model-level "read" rules', + { + node: expr, + } + ); + return false; + } + } + } + } + + return true; + } + private validateCollectionPredicate(expr: BinaryExpr, accept: ValidationAcceptor) { if (!expr.$resolvedType) { accept('error', 'collection predicate can only be used on an array of model type', { node: expr }); @@ -273,9 +285,18 @@ export default class ExpressionValidator implements AstValidator { return findUpAst(node, (n) => isDataModelAttribute(n) && n.decl.$refText === '@@validate'); } - private isNotModelFieldExpr(expr: Expression) { + private isNotModelFieldExpr(expr: Expression): boolean { return ( - isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr) + // literal + isLiteralExpr(expr) || + // enum field + isEnumFieldReference(expr) || + // null + isNullExpr(expr) || + // `auth()` access + this.isAuthOrAuthMemberAccess(expr) || + // array + (isArrayExpr(expr) && expr.items.every((item) => this.isNotModelFieldExpr(item))) ); } diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index b2ac1544b..778f37d7a 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -701,6 +701,37 @@ describe('Attribute tests', () => { `) ).toContain('comparison between fields of different models is not supported in model-level "read" rules'); + expect( + await loadModelWithError(` + ${prelude} + model User { + id Int @id + lists List[] + todos Todo[] + value Int + } + + model List { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + todos Todo[] + } + + model Todo { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + list List @relation(fields: [listId], references: [id]) + listId Int + value Int + + @@allow('all', list.user.value > value) + } + + `) + ).toContain('comparison between fields of different models is not supported in model-level "read" rules'); + expect( await loadModel(` ${prelude} diff --git a/tests/regression/tests/issue-1506.test.ts b/tests/regression/tests/issue-1506.test.ts index b9866a6f8..e8154c676 100644 --- a/tests/regression/tests/issue-1506.test.ts +++ b/tests/regression/tests/issue-1506.test.ts @@ -1,8 +1,9 @@ -import { loadSchema } from '@zenstackhq/testtools'; +import { loadModelWithError } from '@zenstackhq/testtools'; describe('issue 1506', () => { it('regression', async () => { - const { prisma, enhance } = await loadSchema( - ` + await expect( + loadModelWithError( + ` model A { id Int @id @default(autoincrement()) value Int @@ -29,29 +30,10 @@ describe('issue 1506', () => { @@allow('read', true) } - `, - { preserveTsFiles: true, logPrismaQuery: true } + ` + ) + ).resolves.toContain( + 'comparison between fields of different models is not supported in model-level "read" rules' ); - - await prisma.a.create({ - data: { - value: 3, - b: { - create: { - value: 2, - c: { - create: { - value: 1, - }, - }, - }, - }, - }, - }); - - const db = enhance(); - const read = await db.a.findMany({ include: { b: true } }); - expect(read).toHaveLength(1); - expect(read[0].b).toBeTruthy(); }); });