Skip to content

Commit 52aec66

Browse files
authored
feat: got policies to work with todo sample (#1)
* WIP * got todo sample running with policies
1 parent aa7630c commit 52aec66

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+2993
-925
lines changed

packages/cli/package.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@
2727
"pack": "pnpm pack"
2828
},
2929
"dependencies": {
30-
"@types/node": "^20.12.7",
30+
"@types/node": "^18.0.0",
3131
"@zenstackhq/language": "workspace:*",
3232
"@zenstackhq/runtime": "workspace:*",
33+
"@zenstackhq/sdk": "workspace:*",
3334
"async-exit-hook": "^2.0.1",
3435
"colors": "1.4.0",
3536
"commander": "^8.3.0",
@@ -43,6 +44,7 @@
4344
"typescript": "^5.0.0"
4445
},
4546
"devDependencies": {
47+
"@zenstackhq/testtools": "workspace:*",
4648
"@types/async-exit-hook": "^2.0.0",
4749
"@types/better-sqlite3": "^7.6.13",
4850
"@types/semver": "^7.3.13",

packages/cli/src/actions/generate.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import { isPlugin, LiteralExpr, type Model } from '@zenstackhq/language/ast';
22
import type { CliGenerator } from '@zenstackhq/runtime/client';
3+
import { TsSchemaGenerator } from '@zenstackhq/sdk';
34
import colors from 'colors';
45
import fs from 'node:fs';
56
import path from 'node:path';
67
import invariant from 'tiny-invariant';
78
import { PrismaSchemaGenerator } from '../prisma/prisma-schema-generator';
8-
import { TsSchemaGenerator } from '../zmodel/ts-schema-generator';
99
import { getSchemaFile, loadSchemaDocument } from './action-utils';
1010

1111
type Options = {
@@ -25,7 +25,7 @@ export async function run(options: Options) {
2525

2626
// generate TS schema
2727
const tsSchemaFile = path.join(outputPath, 'schema.ts');
28-
await new TsSchemaGenerator().generate(schemaFile, tsSchemaFile);
28+
await new TsSchemaGenerator().generate(schemaFile, [], tsSchemaFile);
2929

3030
await runPlugins(model, outputPath, tsSchemaFile);
3131

packages/cli/src/prisma/prisma-schema-generator.ts

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,7 @@ import {
3232
import { AstUtils } from 'langium';
3333
import { match, P } from 'ts-pattern';
3434

35-
import {
36-
hasAttribute,
37-
isDelegateModel,
38-
isIdField,
39-
} from '../zmodel/model-utils';
40-
import { ZModelCodeGenerator } from '../zmodel/zmodel-code-generator';
35+
import { ModelUtils, ZModelCodeGenerator } from '@zenstackhq/sdk';
4136
import {
4237
AttributeArgValue,
4338
ModelField,
@@ -165,7 +160,7 @@ export class PrismaSchemaGenerator {
165160
? prisma.addView(decl.name)
166161
: prisma.addModel(decl.name);
167162
for (const field of decl.fields) {
168-
if (hasAttribute(field, '@computed')) {
163+
if (ModelUtils.hasAttribute(field, '@computed')) {
169164
continue; // skip computed fields
170165
}
171166
// TODO: exclude fields inherited from delegate
@@ -274,7 +269,7 @@ export class PrismaSchemaGenerator {
274269
(attr) =>
275270
// when building physical schema, exclude `@default` for id fields inherited from delegate base
276271
!(
277-
isIdField(field) &&
272+
ModelUtils.isIdField(field) &&
278273
this.isInheritedFromDelegate(field) &&
279274
attr.decl.$refText === '@default'
280275
)
@@ -360,7 +355,10 @@ export class PrismaSchemaGenerator {
360355
}
361356

362357
private isInheritedFromDelegate(field: DataModelField) {
363-
return field.$inheritedFrom && isDelegateModel(field.$inheritedFrom);
358+
return (
359+
field.$inheritedFrom &&
360+
ModelUtils.isDelegateModel(field.$inheritedFrom)
361+
);
364362
}
365363

366364
private makeFieldAttribute(attr: DataModelFieldAttribute) {

packages/cli/test/ts-schema-gen.test.ts

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { describe, expect, it } from 'vitest';
2-
import { generateTsSchema } from './utils';
2+
import { generateTsSchema } from '@zenstackhq/testtools';
33

44
describe('TypeScript schema generation tests', () => {
55
it('generates correct data models', async () => {
@@ -144,8 +144,7 @@ model Post {
144144
kind: 'array',
145145
items: [
146146
{
147-
kind: 'ref',
148-
model: 'Post',
147+
kind: 'field',
149148
field: 'authorId',
150149
},
151150
],
@@ -157,8 +156,7 @@ model Post {
157156
kind: 'array',
158157
items: [
159158
{
160-
kind: 'ref',
161-
model: 'User',
159+
kind: 'field',
162160
field: 'id',
163161
},
164162
],
@@ -167,9 +165,8 @@ model Post {
167165
{
168166
name: 'onDelete',
169167
value: {
170-
kind: 'ref',
171-
model: 'ReferentialAction',
172-
field: 'Cascade',
168+
kind: 'literal',
169+
value: 'Cascade',
173170
},
174171
},
175172
],

packages/cli/tsup.config.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ export default defineConfig({
99
sourcemap: true,
1010
clean: true,
1111
dts: true,
12-
format: ['esm'],
12+
format: ['esm', 'cjs'],
1313
});

packages/language/res/stdlib.zmodel

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,3 +707,15 @@ attribute @json() @@@targetField([TypeDefField])
707707
* Marks a field to be computed.
708708
*/
709709
attribute @computed()
710+
711+
/**
712+
* Gets the current login user.
713+
*/
714+
function auth(): Any {
715+
} @@@expressionContext([DefaultValue, AccessPolicy])
716+
717+
/**
718+
* Used to specify the model for resolving `auth()` function call in access policies. A Zmodel
719+
* can have at most one model with this attribute. By default, the model named "User" is used.
720+
*/
721+
attribute @@auth() @@@supportTypeDef

packages/language/src/index.ts

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ export class DocumentLoadError extends Error {
1818
}
1919

2020
export async function loadDocument(
21-
fileName: string
21+
fileName: string,
22+
pluginModelFiles: string[] = []
2223
): Promise<
2324
| { success: true; model: Model; warnings: string[] }
2425
| { success: false; errors: string[]; warnings: string[] }
@@ -55,16 +56,28 @@ export async function loadDocument(
5556
)
5657
);
5758

58-
const langiumDocuments = services.shared.workspace.LangiumDocuments;
59+
// load plugin model files
60+
const pluginDocs = await Promise.all(
61+
pluginModelFiles.map((file) =>
62+
services.shared.workspace.LangiumDocuments.getOrCreateDocument(
63+
URI.file(path.resolve(file))
64+
)
65+
)
66+
);
67+
5968
// load the document
69+
const langiumDocuments = services.shared.workspace.LangiumDocuments;
6070
const document = await langiumDocuments.getOrCreateDocument(
6171
URI.file(path.resolve(fileName))
6272
);
6373

6474
// build the document together with standard library, plugin modules, and imported documents
65-
await services.shared.workspace.DocumentBuilder.build([stdLib, document], {
66-
validation: true,
67-
});
75+
await services.shared.workspace.DocumentBuilder.build(
76+
[stdLib, ...pluginDocs, document],
77+
{
78+
validation: true,
79+
}
80+
);
6881

6982
const diagnostics = langiumDocuments.all
7083
.flatMap((doc) =>

packages/language/src/utils.ts

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ export function isFromStdlib(node: AstNode) {
7676
);
7777
}
7878

79-
// export function isAuthInvocation(node: AstNode) {
80-
// return (
81-
// isInvocationExpr(node) &&
82-
// node.function.ref?.name === 'auth' &&
83-
// isFromStdlib(node.function.ref)
84-
// );
85-
// }
79+
export function isAuthInvocation(node: AstNode) {
80+
return (
81+
isInvocationExpr(node) &&
82+
node.function.ref?.name === 'auth' &&
83+
isFromStdlib(node.function.ref)
84+
);
85+
}
8686

8787
/**
8888
* Try getting string value from a potential string literal expression
@@ -161,12 +161,12 @@ export function mapBuiltinTypeToExpressionType(
161161
}
162162
}
163163

164-
// export function isAuthOrAuthMemberAccess(expr: Expression): boolean {
165-
// return (
166-
// isAuthInvocation(expr) ||
167-
// (isMemberAccessExpr(expr) && isAuthOrAuthMemberAccess(expr.operand))
168-
// );
169-
// }
164+
export function isAuthOrAuthMemberAccess(expr: Expression): boolean {
165+
return (
166+
isAuthInvocation(expr) ||
167+
(isMemberAccessExpr(expr) && isAuthOrAuthMemberAccess(expr.operand))
168+
);
169+
}
170170

171171
export function isEnumFieldReference(node: AstNode): node is ReferenceExpr {
172172
return isReferenceExpr(node) && isEnumField(node.target.ref);
@@ -598,13 +598,13 @@ export function getAllDeclarationsIncludingImports(
598598
return model.declarations.concat(...imports.map((imp) => imp.declarations));
599599
}
600600

601-
// export function getAuthDecl(decls: (DataModel | TypeDef)[]) {
602-
// let authModel = decls.find((m) => hasAttribute(m, '@@auth'));
603-
// if (!authModel) {
604-
// authModel = decls.find((m) => m.name === 'User');
605-
// }
606-
// return authModel;
607-
// }
601+
export function getAuthDecl(decls: (DataModel | TypeDef)[]) {
602+
let authModel = decls.find((m) => hasAttribute(m, '@@auth'));
603+
if (!authModel) {
604+
authModel = decls.find((m) => m.name === 'User');
605+
}
606+
return authModel;
607+
}
608608

609609
export function isFutureInvocation(node: AstNode) {
610610
return (

packages/language/src/validators/expression-validator.ts

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import {
1919
import {
2020
findUpAst,
2121
getAttributeArgLiteral,
22+
isAuthInvocation,
23+
isAuthOrAuthMemberAccess,
2224
isDataModelFieldReference,
2325
isEnumFieldReference,
2426
typeAssignable,
@@ -32,34 +34,32 @@ export default class ExpressionValidator implements AstValidator<Expression> {
3234
validate(expr: Expression, accept: ValidationAcceptor): void {
3335
// deal with a few cases where reference resolution fail silently
3436
if (!expr.$resolvedType) {
35-
// TODO: revisit this
36-
// if (isAuthInvocation(expr)) {
37-
// // check was done at link time
38-
// accept(
39-
// 'error',
40-
// 'auth() cannot be resolved because no model marked with "@@auth()" or named "User" is found',
41-
// { node: expr }
42-
// );
43-
// } else {
44-
45-
const hasReferenceResolutionError = AstUtils.streamAst(expr).some(
46-
(node) => {
37+
if (isAuthInvocation(expr)) {
38+
// check was done at link time
39+
accept(
40+
'error',
41+
'auth() cannot be resolved because no model marked with "@@auth()" or named "User" is found',
42+
{ node: expr }
43+
);
44+
} else {
45+
const hasReferenceResolutionError = AstUtils.streamAst(
46+
expr
47+
).some((node) => {
4748
if (isMemberAccessExpr(node)) {
4849
return !!node.member.error;
4950
}
5051
if (isReferenceExpr(node)) {
5152
return !!node.target.error;
5253
}
5354
return false;
54-
}
55-
);
56-
if (!hasReferenceResolutionError) {
57-
// report silent errors not involving linker errors
58-
accept('error', 'Expression cannot be resolved', {
59-
node: expr,
6055
});
56+
if (!hasReferenceResolutionError) {
57+
// report silent errors not involving linker errors
58+
accept('error', 'Expression cannot be resolved', {
59+
node: expr,
60+
});
61+
}
6162
}
62-
// }
6363
}
6464

6565
// extra validations by expression type
@@ -379,9 +379,8 @@ export default class ExpressionValidator implements AstValidator<Expression> {
379379
isEnumFieldReference(expr) ||
380380
// null
381381
isNullExpr(expr) ||
382-
// TODO: revise cross-model field comparison
383-
// // `auth()` access
384-
// isAuthOrAuthMemberAccess(expr) ||
382+
// `auth()` access
383+
isAuthOrAuthMemberAccess(expr) ||
385384
// array
386385
(isArrayExpr(expr) &&
387386
expr.items.every((item) => this.isNotModelFieldExpr(item)))

packages/language/src/zmodel-linker.ts

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,10 @@ import {
5555
} from './ast';
5656
import {
5757
getAllLoadedAndReachableDataModelsAndTypeDefs,
58+
getAuthDecl,
5859
getContainingDataModel,
5960
getModelFieldsWithBases,
61+
isAuthInvocation,
6062
isFutureExpr,
6163
isMemberContainer,
6264
mapBuiltinTypeToExpressionType,
@@ -360,23 +362,20 @@ export class ZModelLinker extends DefaultLinker {
360362
// eslint-disable-next-line @typescript-eslint/ban-types
361363
const funcDecl = node.function.ref as FunctionDecl;
362364

363-
// TODO: revisit this
364-
// if (isAuthInvocation(node)) {
365-
// // auth() function is resolved against all loaded and reachable documents
365+
if (isAuthInvocation(node)) {
366+
// auth() function is resolved against all loaded and reachable documents
366367

367-
// // get all data models from loaded and reachable documents
368-
// const allDecls = getAllLoadedAndReachableDataModelsAndTypeDefs(
369-
// this.langiumDocuments(),
370-
// AstUtils.getContainerOfType(node, isDataModel)
371-
// );
372-
373-
// const authDecl = getAuthDecl(allDecls);
374-
// if (authDecl) {
375-
// node.$resolvedType = { decl: authDecl, nullable: true };
376-
// }
377-
// } else
368+
// get all data models from loaded and reachable documents
369+
const allDecls = getAllLoadedAndReachableDataModelsAndTypeDefs(
370+
this.langiumDocuments(),
371+
AstUtils.getContainerOfType(node, isDataModel)
372+
);
378373

379-
if (isFutureExpr(node)) {
374+
const authDecl = getAuthDecl(allDecls);
375+
if (authDecl) {
376+
node.$resolvedType = { decl: authDecl, nullable: true };
377+
}
378+
} else if (isFutureExpr(node)) {
380379
// future() function is resolved to current model
381380
node.$resolvedType = { decl: getContainingDataModel(node) };
382381
} else {
@@ -413,13 +412,11 @@ export class ZModelLinker extends DefaultLinker {
413412
// member access is resolved only in the context of the operand type
414413
if (node.member.ref) {
415414
this.resolveToDeclaredType(node, node.member.ref.type);
416-
417-
// TODO: revisit this
418-
// if (node.$resolvedType && isAuthInvocation(node.operand)) {
419-
// // member access on auth() function is nullable
420-
// // because user may not have provided all fields
421-
// node.$resolvedType.nullable = true;
422-
// }
415+
if (node.$resolvedType && isAuthInvocation(node.operand)) {
416+
// member access on auth() function is nullable
417+
// because user may not have provided all fields
418+
node.$resolvedType.nullable = true;
419+
}
423420
}
424421
}
425422
}

0 commit comments

Comments
 (0)