Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/ide/jetbrains/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"homepage": "https://zenstack.dev",
"private": true,
"scripts": {
"build": "./gradlew buildPlugin"
"build": "./gradlew buildPlugin"
},
"author": "ZenStack Team",
"license": "MIT",
Expand Down
7 changes: 6 additions & 1 deletion packages/runtime/src/cross/model-meta.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import { lowerCaseFirst } from 'lower-case-first';

/**
* An access key in the user context object (e.g. `profile.picture.url`)
*/
export type AuthContextSelector = string;

/**
* Runtime information of a data model or field attribute
*/
export type RuntimeAttribute = {
name: string;
args: Array<{ name?: string; value: unknown }>;
args: Array<{ name?: string; value: unknown } | { name: 'auth()'; value: AuthContextSelector }>;
};

/**
Expand Down
25 changes: 23 additions & 2 deletions packages/runtime/src/enhancements/create-enhancement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { withPassword } from './password';
import { withPolicy } from './policy';
import type { ErrorTransformer } from './proxy';
import type { PolicyDef, ZodSchemas } from './types';
import { withDefaultAuth } from './default-auth';

/**
* Kinds of enhancements to `PrismaClient`
Expand All @@ -15,6 +16,7 @@ export enum EnhancementKind {
Password = 'password',
Omit = 'omit',
Policy = 'policy',
DefaultAuth = 'defaultAuth',
}

/**
Expand Down Expand Up @@ -92,6 +94,7 @@ export type EnhancementContext = {

let hasPassword: boolean | undefined = undefined;
let hasOmit: boolean | undefined = undefined;
let hasDefaultAuth: boolean | undefined = undefined;

/**
* Gets a Prisma client enhanced with all enhancement behaviors, including access
Expand Down Expand Up @@ -120,13 +123,26 @@ export function createEnhancement<DbClient extends object>(

let result = prisma;

if (hasPassword === undefined || hasOmit === undefined) {
if (
process.env.ZENSTACK_TEST === '1' || // avoid caching in tests
hasPassword === undefined ||
hasOmit === undefined ||
hasDefaultAuth === undefined
) {
const allFields = Object.values(options.modelMeta.fields).flatMap((modelInfo) => Object.values(modelInfo));
hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password'));
hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit'));
hasDefaultAuth = allFields.some((field) =>
field.attributes?.some((attr) => attr.name === '@default' && attr.args[0]?.name === 'auth()')
);
}

const kinds = options.kinds ?? [EnhancementKind.Password, EnhancementKind.Omit, EnhancementKind.Policy];
const kinds = options.kinds ?? [
EnhancementKind.Password,
EnhancementKind.Omit,
EnhancementKind.Policy,
EnhancementKind.DefaultAuth,
];

if (hasPassword && kinds.includes(EnhancementKind.Password)) {
// @password proxy
Expand All @@ -138,6 +154,11 @@ export function createEnhancement<DbClient extends object>(
result = withOmit(result, options);
}

if (hasDefaultAuth && kinds.includes(EnhancementKind.DefaultAuth)) {
// @default(auth()) proxy
result = withDefaultAuth(result, options, context);
}

// policy proxy
if (kinds.includes(EnhancementKind.Policy)) {
result = withPolicy(result, options, context);
Expand Down
83 changes: 83 additions & 0 deletions packages/runtime/src/enhancements/default-auth.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/* eslint-disable @typescript-eslint/no-unused-vars */
/* eslint-disable @typescript-eslint/no-explicit-any */

import { NestedWriteVisitor, PrismaWriteActionType, FieldInfo } from '../cross';
import { DbClientContract } from '../types';
import { EnhancementContext, EnhancementOptions } from './create-enhancement';
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';

/**
* Gets an enhanced Prisma client that supports `@default(auth())` attribute.
*
* @private
*/
export function withDefaultAuth<DbClient extends object>(
prisma: DbClient,
options: EnhancementOptions,
context?: EnhancementContext
): DbClient {
return makeProxy(
prisma,
options.modelMeta,
(_prisma, model) => new DefaultAuthHandler(_prisma as DbClientContract, model, options, context),
'defaultAuth'
);
}

class DefaultAuthHandler extends DefaultPrismaProxyHandler {
private readonly db: DbClientContract;
constructor(
prisma: DbClientContract,
model: string,
private readonly options: EnhancementOptions,
private readonly context?: EnhancementContext
) {
super(prisma, model);
this.db = prisma;
}

// base override
protected async preprocessArgs(action: PrismaProxyActions, args: any) {
const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert'];
if (actionsOfInterest.includes(action)) {
await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args);
}
return args;
}

private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) {
const visitor = new NestedWriteVisitor(this.options.modelMeta, {
field: async (field, action, data, context) => {
const userContext = this.context?.user;
if (!userContext) {
throw new Error(`Invalid user context`);
}
const fields = this.options.modelMeta.fields[model];
const isDefaultAuthField = (fieldInfo: FieldInfo) =>
fieldInfo.attributes?.find((attr) => attr.name === '@default' && attr.args?.[0]?.name === 'auth()');
const defaultAuthSelectorFields = Object.fromEntries(
Object.entries(fields)
.filter(([_, fieldInfo]) => isDefaultAuthField(fieldInfo))
.map(([field, fieldInfo]) => [
field,
fieldInfo.attributes?.find((attr) => attr.name === '@default')?.args[0]?.value as
| string
| undefined,
])
);
const defaultAuthFields = Object.fromEntries(
Object.entries(defaultAuthSelectorFields).map(([field, selector]) => [
field,
selector ? userContext[selector] : userContext,
])
);
console.log('defaultAuthFields :', defaultAuthFields);
for (const [field, defaultValue] of Object.entries(defaultAuthFields)) {
context.parent[field] = defaultValue;
}
},
});

await visitor.visit(model, action, args);
}
}
1 change: 0 additions & 1 deletion packages/schema/src/plugins/prisma/prisma-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ export class FunctionCallArg {
return this.name ? `${this.name}: ${this.value}` : this.value;
}
}

export class Enum extends ContainerDeclaration {
public fields: EnumField[] = [];

Expand Down
12 changes: 11 additions & 1 deletion packages/schema/src/plugins/prisma/schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import {

const MODEL_PASSTHROUGH_ATTR = '@@prisma.passthrough';
const FIELD_PASSTHROUGH_ATTR = '@prisma.passthrough';
const NO_FIELD_ATTRIBUTE = '';

/**
* Generates Prisma schema file
Expand Down Expand Up @@ -285,7 +286,12 @@ export default class PrismaSchemaGenerator {
!!attrDecl.attributes.find((a) => a.decl.ref?.name === '@@@prisma') ||
// the special pass-through attribute
attrDecl.name === MODEL_PASSTHROUGH_ATTR ||
attrDecl.name === FIELD_PASSTHROUGH_ATTR
attrDecl.name === FIELD_PASSTHROUGH_ATTR ||
// auth() in @default() is not supported by Prisma
// FIXME: condition is inverted to avoid error...
!!attrDecl.attributes.find(
(a) => a.decl.ref?.name === '@default' && a.args[0].value.$cstNode?.text.startsWith('auth()')
)
);
}

Expand Down Expand Up @@ -334,6 +340,10 @@ export default class PrismaSchemaGenerator {
} else {
throw new PluginError(name, `Invalid arguments for ${FIELD_PASSTHROUGH_ATTR} attribute`);
}
// do not write @default(auth()) field attribute as it is not supported by Prisma
// TODO: we should add a comment to the field
} else if (attrName === '@default' && attr.args[0].value.$cstNode?.text.startsWith('auth()')) {
return new PrismaPassThroughAttribute(NO_FIELD_ATTRIBUTE);
} else {
return new PrismaFieldAttribute(
attrName,
Expand Down
4 changes: 2 additions & 2 deletions packages/schema/src/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function env(name: String): String {
* Gets the current login user.
*/
function auth(): Any {
} @@@expressionContext([AccessPolicy])
} @@@expressionContext([DefaultValue, AccessPolicy])

/**
* Gets current date-time (as DateTime type).
Expand Down Expand Up @@ -204,7 +204,7 @@ attribute @id(map: String?, length: Int?, sort: String?, clustered: Boolean?) @@

/**
* Defines a default value for a field.
* @param value: An expression (e.g. 5, true, now()).
* @param value: An expression (e.g. 5, true, now(), auth()).
*/
attribute @default(_ value: ContextType, map: String?) @@@prisma

Expand Down
2 changes: 2 additions & 0 deletions packages/schema/tests/generator/prisma-generator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ describe('Prisma generator test', () => {
id String @id @default(nanoid(6))
x String @default(nanoid())
y String @default(dbgenerated("gen_random_uuid()"))
z String @default(auth().id)
}
`);

Expand All @@ -142,6 +143,7 @@ describe('Prisma generator test', () => {
expect(content).toContain('@default(nanoid(6))');
expect(content).toContain('@default(nanoid())');
expect(content).toContain('@default(dbgenerated("gen_random_uuid()"))');
expect(content).not.toContain('@default(auth().id)');
});

it('triple slash comments', async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,35 @@ describe('Attribute tests', () => {
});

it('auth function check', async () => {
await loadModel(`
${prelude}

model User {
id String @id
name String
}
model B {
id String @id
userId String @default(auth().id)
userName String @default(auth().name)
}
`);

// expect(
// await loadModelWithError(`
// ${prelude}

// model User {
// id String @id
// name String
// }
// model B {
// id String @id
// userData String @default(auth())
// }
// `)
// ).toContain("Value is not assignable to parameter");

expect(
await loadModelWithError(`
${prelude}
Expand Down Expand Up @@ -1124,14 +1153,14 @@ describe('Attribute tests', () => {
});

it('incorrect function expression context', async () => {
expect(
await loadModelWithError(`
${prelude}
model M {
id String @id @default(auth())
}
`)
).toContain('function "auth" is not allowed in the current context: DefaultValue');
// expect(
// await loadModelWithError(`
// ${prelude}
// model M {
// id String @id @default(auth())
// }
// `)
// ).toContain('function "auth" is not allowed in the current context: DefaultValue');

expect(
await loadModelWithError(`
Expand Down
7 changes: 7 additions & 0 deletions packages/sdk/src/model-meta-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ function getFieldAttributes(field: DataModelField): RuntimeAttribute[] {
args.push({ name: arg.name, value: v });
} else if (isStringLiteral(arg.value) || isBooleanLiteral(arg.value)) {
args.push({ name: arg.name, value: arg.value.value });
} else if (
attr.decl.ref?.name === '@default' &&
attr.args[0].value.$cstNode?.text.startsWith('auth()')
) {
const authValue = attr.args[0].value.$cstNode?.text;
const authSelector = authValue === 'auth()' ? authValue : authValue.slice('auth().'.length);
args.push({ name: 'auth()', value: authSelector });
} else {
// non-literal args are ignored
}
Expand Down
55 changes: 55 additions & 0 deletions tests/integration/tests/enhancements/with-policy/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -363,4 +363,59 @@ describe('With Policy: auth() test', () => {
enhance({ id: '1', posts: [{ id: '1', published: true, comments: [] }] }).post.create(createPayload)
).toResolveTruthy();
});

it('Default auth() on literal fields', async () => {
const { enhance } = await loadSchema(
`
model User {
id String @id
name String
score Int

}

model Post {
id String @id @default(uuid())
title String
score Int? @default(auth().score)
authorName String? @default(auth().name)

@@allow('all', true)
}
`
);

const userDb = enhance({ id: '1', name: 'user1', score: 10 });
await expect(userDb.post.create({ data: { title: 'abc' } })).toResolveTruthy();
await expect(userDb.post.findMany()).resolves.toHaveLength(1);
await expect(userDb.post.count({ where: { authorName: 'user1', score: 10 } })).resolves.toBe(1);
});

it('Default auth() with foreign key', async () => {
const { enhance } = await loadSchema(
`
model User {
id String @id
posts Post[]

@@allow('all', true)

}

model Post {
id String @id @default(uuid())
title String
author User @relation(fields: [authorId], references: [id])
authorId String @default(auth().id)

@@allow('all', true)
}
`
);

const db = enhance({ id: 'userId-1' });
await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy();
await expect(db.post.create({ data: { title: 'abc' } })).toResolveTruthy();
await expect(db.post.count({ where: { authorId: 'userId-1' } })).resolves.toBe(1);
});
});