Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"skipFiles": ["<node_internals>/**"],
"type": "node",
"args": ["generate"],
"cwd": "${workspaceFolder}/samples/blog/zenstack"
"cwd": "${workspaceFolder}/samples/blog"
},
{
"name": "Debug with TSX",
Expand Down
68 changes: 65 additions & 3 deletions packages/cli/src/actions/action-utils.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import { loadDocument } from '@zenstackhq/language';
import { isDataSource } from '@zenstackhq/language/ast';
import { createZModelServices, loadDocument, type ZModelServices } from '@zenstackhq/language';
import { isDataSource, isPlugin, Model } from '@zenstackhq/language/ast';
import { getLiteral } from '@zenstackhq/language/utils';
import { PrismaSchemaGenerator } from '@zenstackhq/sdk';
import colors from 'colors';
import fs from 'node:fs';
import path from 'node:path';
import { fileURLToPath } from 'node:url';
import { CliError } from '../cli-error';
import { PLUGIN_MODULE_NAME } from '../constants';

export function getSchemaFile(file?: string) {
if (file) {
Expand Down Expand Up @@ -34,7 +37,9 @@ export function getSchemaFile(file?: string) {
}

export async function loadSchemaDocument(schemaFile: string) {
const loadResult = await loadDocument(schemaFile);
const { ZModelLanguage: services } = createZModelServices();
const pluginDocs = await getPluginDocuments(services, schemaFile);
const loadResult = await loadDocument(schemaFile, pluginDocs);
if (!loadResult.success) {
loadResult.errors.forEach((err) => {
console.error(colors.red(err));
Expand All @@ -47,6 +52,63 @@ export async function loadSchemaDocument(schemaFile: string) {
return loadResult.model;
}

export async function getPluginDocuments(services: ZModelServices, fileName: string): Promise<string[]> {
// parse the user document (without validation)
const parseResult = services.parser.LangiumParser.parse(fs.readFileSync(fileName, { encoding: 'utf-8' }));
const parsed = parseResult.value as Model;

// balk if there are syntax errors
if (parseResult.lexerErrors.length > 0 || parseResult.parserErrors.length > 0) {
return [];
}

// traverse plugins and collect "plugin.zmodel" documents
const result: string[] = [];
for (const decl of parsed.declarations.filter(isPlugin)) {
const providerField = decl.fields.find((f) => f.name === 'provider');
if (!providerField) {
continue;
}

const provider = getLiteral<string>(providerField.value);
if (!provider) {
continue;
}

let pluginModelFile: string | undefined;

// first try to treat provider as a path
let providerPath = path.resolve(path.dirname(fileName), provider);
if (fs.existsSync(providerPath)) {
if (fs.statSync(providerPath).isDirectory()) {
providerPath = path.join(providerPath, 'index.js');
}

// try plugin.zmodel next to the provider file
pluginModelFile = path.resolve(path.dirname(providerPath), PLUGIN_MODULE_NAME);
if (!fs.existsSync(pluginModelFile)) {
// try to find upwards
pluginModelFile = findUp([PLUGIN_MODULE_NAME], path.dirname(providerPath));
}
}

if (!pluginModelFile) {
// try loading it as a ESM module
try {
const resolvedUrl = import.meta.resolve(`${provider}/${PLUGIN_MODULE_NAME}`);
pluginModelFile = fileURLToPath(resolvedUrl);
} catch {
// noop
}
}

if (pluginModelFile && fs.existsSync(pluginModelFile)) {
result.push(pluginModelFile);
}
}
return result;
}

export function handleSubProcessError(err: unknown) {
if (err instanceof Error && 'status' in err && typeof err.status === 'number') {
process.exit(err.status);
Expand Down
10 changes: 6 additions & 4 deletions packages/cli/src/actions/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async function runPlugins(schemaFile: string, model: Model, outputPath: string,
for (const plugin of plugins) {
const provider = getPluginProvider(plugin);

let cliPlugin: CliPlugin;
let cliPlugin: CliPlugin | undefined;
if (provider.startsWith('@core/')) {
cliPlugin = (corePlugins as any)[provider.slice('@core/'.length)];
if (!cliPlugin) {
Expand All @@ -78,12 +78,14 @@ async function runPlugins(schemaFile: string, model: Model, outputPath: string,
}
try {
cliPlugin = (await import(moduleSpec)).default as CliPlugin;
} catch (error) {
throw new CliError(`Failed to load plugin ${provider}: ${error}`);
} catch {
// plugin may not export a generator so we simply ignore the error here
}
}

processedPlugins.push({ cliPlugin, pluginOptions: getPluginOptions(plugin) });
if (cliPlugin) {
processedPlugins.push({ cliPlugin, pluginOptions: getPluginOptions(plugin) });
}
}

const defaultPlugins = [corePlugins['typescript']].reverse();
Expand Down
3 changes: 3 additions & 0 deletions packages/cli/src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
// replaced at build time
export const TELEMETRY_TRACKING_TOKEN = '<TELEMETRY_TRACKING_TOKEN>';

// plugin-contributed model file name
export const PLUGIN_MODULE_NAME = 'plugin.zmodel';
76 changes: 0 additions & 76 deletions packages/language/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -174,29 +174,6 @@ function hasSome(field: Any[], search: Any[]): Boolean {
function isEmpty(field: Any[]): Boolean {
} @@@expressionContext([AccessPolicy, ValidationRule])

/**
* The name of the model for which the policy rule is defined. If the rule is
* inherited to a sub model, this function returns the name of the sub model.
*
* @param optional parameter to control the casing of the returned value. Valid
* values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults
* to "original".
*/
function currentModel(casing: String?): String {
} @@@expressionContext([AccessPolicy])

/**
* The operation for which the policy rule is defined for. Note that a rule with
* "all" operation is expanded to "create", "read", "update", and "delete" rules,
* and the function returns corresponding value for each expanded version.
*
* @param optional parameter to control the casing of the returned value. Valid
* values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults
* to "original".
*/
function currentOperation(casing: String?): String {
} @@@expressionContext([AccessPolicy])

/**
* Marks an attribute to be only applicable to certain field types.
*/
Expand Down Expand Up @@ -658,56 +635,3 @@ attribute @meta(_ name: String, _ value: Any)
* Marks an attribute as deprecated.
*/
attribute @@@deprecated(_ message: String)

/* --- Policy Plugin --- */

/**
* Defines an access policy that allows a set of operations when the given condition is true.
*
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be allowed.
*/
attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean)

/**
* Defines an access policy that allows the annotated field to be read or updated.
* You can pass a third argument as `true` to make it override the model-level policies.
*
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be allowed.
* @param override: a boolean value that controls if the field-level policy should override the model-level policy.
*/
// attribute @allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean, _ override: Boolean?)

/**
* Defines an access policy that denies a set of operations when the given condition is true.
*
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be denied.
*/
attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean)

/**
* Defines an access policy that denies the annotated field to be read or updated.
*
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be denied.
*/
// attribute @deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean)

/**
* Checks if the current user can perform the given operation on the given field.
*
* @param field: The field to check access for
* @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided,
* it defaults the operation of the containing policy rule.
*/
function check(field: Any, operation: String?): Boolean {
} @@@expressionContext([AccessPolicy])

/**
* Gets entity's value before an update. Only valid when used in a "post-update" policy rule.
*/
function before(): Any {
} @@@expressionContext([AccessPolicy])

4 changes: 2 additions & 2 deletions packages/language/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export async function loadDocument(
);

// load additional model files
const pluginDocs = await Promise.all(
const additionalDocs = await Promise.all(
additionalModelFiles.map((file) =>
services.shared.workspace.LangiumDocuments.getOrCreateDocument(URI.file(path.resolve(file))),
),
Expand All @@ -69,7 +69,7 @@ export async function loadDocument(
}

// build the document together with standard library, plugin modules, and imported documents
await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document, ...importedDocuments], {
await services.shared.workspace.DocumentBuilder.build([stdLib, ...additionalDocs, document, ...importedDocuments], {
validation: {
stopAfterLexingErrors: true,
stopAfterParsingErrors: true,
Expand Down
3 changes: 2 additions & 1 deletion packages/language/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,9 @@ export function getAuthDecl(decls: (DataModel | TypeDef)[]) {
return authModel;
}

// TODO: move to policy plugin
export function isBeforeInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'before' && isFromStdlib(node.function.ref);
return isInvocationExpr(node) && node.function.ref?.name === 'before';
}

export function isCollectionPredicate(node: AstNode): node is BinaryExpr {
Expand Down
59 changes: 27 additions & 32 deletions packages/language/src/validators/function-invocation-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import {
getLiteral,
isCheckInvocation,
isDataFieldReference,
isFromStdlib,
typeAssignable,
} from '../utils';
import type { AstValidator } from './common';
Expand Down Expand Up @@ -52,43 +51,39 @@ export default class FunctionInvocationValidator implements AstValidator<Express
return;
}

if (isFromStdlib(funcDecl)) {
// validate standard library functions

// find the containing attribute context for the invocation
let curr: AstNode | undefined = expr.$container;
let containerAttribute: DataModelAttribute | DataFieldAttribute | undefined;
while (curr) {
if (isDataModelAttribute(curr) || isDataFieldAttribute(curr)) {
containerAttribute = curr;
break;
}
curr = curr.$container;
// find the containing attribute context for the invocation
let curr: AstNode | undefined = expr.$container;
let containerAttribute: DataModelAttribute | DataFieldAttribute | undefined;
while (curr) {
if (isDataModelAttribute(curr) || isDataFieldAttribute(curr)) {
containerAttribute = curr;
break;
}
curr = curr.$container;
}

// validate the context allowed for the function
const exprContext = this.getExpressionContext(containerAttribute);
// validate the context allowed for the function
const exprContext = this.getExpressionContext(containerAttribute);

// get the context allowed for the function
const funcAllowedContext = getFunctionExpressionContext(funcDecl);
// get the context allowed for the function
const funcAllowedContext = getFunctionExpressionContext(funcDecl);

if (exprContext && !funcAllowedContext.includes(exprContext)) {
accept('error', `function "${funcDecl.name}" is not allowed in the current context: ${exprContext}`, {
node: expr,
});
return;
}
if (exprContext && !funcAllowedContext.includes(exprContext)) {
accept('error', `function "${funcDecl.name}" is not allowed in the current context: ${exprContext}`, {
node: expr,
});
return;
}

// TODO: express function validation rules declaratively in ZModel
// TODO: express function validation rules declaratively in ZModel

const allCasing = ['original', 'upper', 'lower', 'capitalize', 'uncapitalize'];
if (['currentModel', 'currentOperation'].includes(funcDecl.name)) {
const arg = getLiteral<string>(expr.args[0]?.value);
if (arg && !allCasing.includes(arg)) {
accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, {
node: expr.args[0]!,
});
}
const allCasing = ['original', 'upper', 'lower', 'capitalize', 'uncapitalize'];
if (['currentModel', 'currentOperation'].includes(funcDecl.name)) {
const arg = getLiteral<string>(expr.args[0]?.value);
if (arg && !allCasing.includes(arg)) {
accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, {
node: expr.args[0]!,
});
}
}

Expand Down
8 changes: 6 additions & 2 deletions packages/plugins/policy/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
"author": "ZenStack Team",
"license": "MIT",
"files": [
"dist"
"dist",
"plugin.zmodel"
],
"exports": {
".": {
Expand All @@ -26,14 +27,17 @@
"default": "./dist/index.cjs"
}
},
"./plugin.zmodel": {
"import": "./plugin.zmodel",
"require": "./plugin.zmodel"
},
"./package.json": {
"import": "./package.json",
"require": "./package.json"
}
},
"dependencies": {
"@zenstackhq/common-helpers": "workspace:*",
"@zenstackhq/sdk": "workspace:*",
"@zenstackhq/runtime": "workspace:*",
"ts-pattern": "catalog:"
},
Expand Down
Loading
Loading