Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 101 additions & 154 deletions packages/runtime/src/client/executor/zenstack-query-executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,47 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
return this.client.$options;
}

override async executeQuery(compiledQuery: CompiledQuery, queryId: QueryId) {
override executeQuery(compiledQuery: CompiledQuery, queryId: QueryId) {
// proceed with the query with kysely interceptors
// if the query is a raw query, we need to carry over the parameters
const queryParams = (compiledQuery as any).$raw ? compiledQuery.parameters : undefined;
const result = await this.proceedQueryWithKyselyInterceptors(compiledQuery.query, queryParams, queryId.queryId);

return result.result;
return this.provideConnection(async (connection) => {
let startedTx = false;
try {
// mutations are wrapped in tx if not already in one
if (this.isMutationNode(compiledQuery.query) && !this.driver.isTransactionConnection(connection)) {
await this.driver.beginTransaction(connection, {
isolationLevel: TransactionIsolationLevel.RepeatableRead,
});
startedTx = true;
}
const result = await this.proceedQueryWithKyselyInterceptors(
connection,
compiledQuery.query,
queryParams,
queryId.queryId,
);
if (startedTx) {
await this.driver.commitTransaction(connection);
}
return result;
} catch (err) {
if (startedTx) {
await this.driver.rollbackTransaction(connection);
}
throw err;
}
});
}

private async proceedQueryWithKyselyInterceptors(
connection: DatabaseConnection,
queryNode: RootOperationNode,
parameters: readonly unknown[] | undefined,
queryId: string,
) {
let proceed = (q: RootOperationNode) => this.proceedQuery(q, parameters, queryId);
let proceed = (q: RootOperationNode) => this.proceedQuery(connection, q, parameters, queryId);

const hooks: OnKyselyQueryCallback<Schema>[] = [];
// tsc perf
Expand All @@ -92,18 +118,14 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
for (const hook of hooks) {
const _proceed = proceed;
proceed = async (query: RootOperationNode) => {
const _p = async (q: RootOperationNode) => {
const r = await _proceed(q);
return r.result;
};

const _p = (q: RootOperationNode) => _proceed(q);
const hookResult = await hook!({
client: this.client as ClientContract<Schema>,
schema: this.client.$schema,
query,
proceed: _p,
});
return { result: hookResult };
return hookResult;
};
}

Expand Down Expand Up @@ -132,157 +154,88 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
return { model, action, where };
}

private async proceedQuery(query: RootOperationNode, parameters: readonly unknown[] | undefined, queryId: string) {
private async proceedQuery(
connection: DatabaseConnection,
query: RootOperationNode,
parameters: readonly unknown[] | undefined,
queryId: string,
) {
let compiled: CompiledQuery | undefined;

try {
return await this.provideConnection(async (connection) => {
if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) {
// no need to handle mutation hooks, just proceed
const finalQuery = this.nameMapper.transformNode(query);
compiled = this.compileQuery(finalQuery);
if (parameters) {
compiled = { ...compiled, parameters };
}
const result = await connection.executeQuery<any>(compiled);
return { result };
}

if (
(InsertQueryNode.is(query) || UpdateQueryNode.is(query)) &&
this.hasEntityMutationPluginsWithAfterMutationHooks
) {
// need to make sure the query node has "returnAll" for insert and update queries
// so that after-mutation hooks can get the mutated entities with all fields
query = {
...query,
returning: ReturningNode.create([SelectionNode.createSelectAll()]),
};
}
if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) {
// no need to handle mutation hooks, just proceed
const finalQuery = this.nameMapper.transformNode(query);
compiled = this.compileQuery(finalQuery);
if (parameters) {
compiled = { ...compiled, parameters };
}
return connection.executeQuery<any>(compiled);
}

// the client passed to hooks needs to be in sync with current in-transaction
// status so that it doesn't try to create a nested one
const currentlyInTx = this.driver.isTransactionConnection(connection);

const connectionClient = this.createClientForConnection(connection, currentlyInTx);

const mutationInfo = this.getMutationInfo(finalQuery);

// cache already loaded before-mutation entities
let beforeMutationEntities: Record<string, unknown>[] | undefined;
const loadBeforeMutationEntities = async () => {
if (
beforeMutationEntities === undefined &&
(UpdateQueryNode.is(query) || DeleteQueryNode.is(query))
) {
beforeMutationEntities = await this.loadEntities(
mutationInfo.model,
mutationInfo.where,
connection,
);
}
return beforeMutationEntities;
if (
(InsertQueryNode.is(query) || UpdateQueryNode.is(query)) &&
this.hasEntityMutationPluginsWithAfterMutationHooks
) {
// need to make sure the query node has "returnAll" for insert and update queries
// so that after-mutation hooks can get the mutated entities with all fields
query = {
...query,
returning: ReturningNode.create([SelectionNode.createSelectAll()]),
};
}
const finalQuery = this.nameMapper.transformNode(query);
compiled = this.compileQuery(finalQuery);
if (parameters) {
compiled = { ...compiled, parameters };
}

// call before mutation hooks
await this.callBeforeMutationHooks(
finalQuery,
mutationInfo,
loadBeforeMutationEntities,
connectionClient,
queryId,
);
// the client passed to hooks needs to be in sync with current in-transaction
// status so that it doesn't try to create a nested one
const currentlyInTx = this.driver.isTransactionConnection(connection);

// if mutation interceptor demands to run afterMutation hook in the transaction but we're not already
// inside one, we need to create one on the fly
const shouldCreateTx =
this.hasPluginRequestingAfterMutationWithinTransaction &&
!this.driver.isTransactionConnection(connection);

if (!shouldCreateTx) {
// if no on-the-fly tx is needed, just proceed with the query as is
const result = await connection.executeQuery<any>(compiled);

if (!this.driver.isTransactionConnection(connection)) {
// not in a transaction, just call all after-mutation hooks
await this.callAfterMutationHooks(
result,
finalQuery,
mutationInfo,
connectionClient,
'all',
queryId,
);
} else {
// run after-mutation hooks that are requested to be run inside tx
await this.callAfterMutationHooks(
result,
finalQuery,
mutationInfo,
connectionClient,
'inTx',
queryId,
);

// register other after-mutation hooks to be run after the tx is committed
this.driver.registerTransactionCommitCallback(connection, () =>
this.callAfterMutationHooks(
result,
finalQuery,
mutationInfo,
connectionClient,
'outTx',
queryId,
),
);
}

return { result };
} else {
// if an on-the-fly tx is created, create one and wrap the query execution inside
await this.driver.beginTransaction(connection, {
isolationLevel: TransactionIsolationLevel.ReadCommitted,
});
try {
// execute the query inside the on-the-fly transaction
const result = await connection.executeQuery<any>(compiled);

// run after-mutation hooks that are requested to be run inside tx
await this.callAfterMutationHooks(
result,
finalQuery,
mutationInfo,
connectionClient,
'inTx',
queryId,
);

// commit the transaction
await this.driver.commitTransaction(connection);

// run other after-mutation hooks after the tx is committed
await this.callAfterMutationHooks(
result,
finalQuery,
mutationInfo,
connectionClient,
'outTx',
queryId,
);

return { result };
} catch (err) {
// rollback the transaction
await this.driver.rollbackTransaction(connection);
throw err;
}
const connectionClient = this.createClientForConnection(connection, currentlyInTx);

const mutationInfo = this.getMutationInfo(finalQuery);

// cache already loaded before-mutation entities
let beforeMutationEntities: Record<string, unknown>[] | undefined;
const loadBeforeMutationEntities = async () => {
if (beforeMutationEntities === undefined && (UpdateQueryNode.is(query) || DeleteQueryNode.is(query))) {
beforeMutationEntities = await this.loadEntities(
mutationInfo.model,
mutationInfo.where,
connection,
);
}
});
return beforeMutationEntities;
};

// call before mutation hooks
await this.callBeforeMutationHooks(
finalQuery,
mutationInfo,
loadBeforeMutationEntities,
connectionClient,
queryId,
);

const result = await connection.executeQuery<any>(compiled);

if (!this.driver.isTransactionConnection(connection)) {
// not in a transaction, just call all after-mutation hooks
await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'all', queryId);
} else {
// run after-mutation hooks that are requested to be run inside tx
await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'inTx', queryId);

// register other after-mutation hooks to be run after the tx is committed
this.driver.registerTransactionCommitCallback(connection, () =>
this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'outTx', queryId),
);
}

return result;
} catch (err) {
const message = `Failed to execute query: ${err}, sql: ${compiled?.sql}`;
throw new QueryError(message, err);
Expand All @@ -307,12 +260,6 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
return (this.client.$options.plugins ?? []).some((plugin) => plugin.onEntityMutation?.afterEntityMutation);
}

private get hasPluginRequestingAfterMutationWithinTransaction() {
return (this.client.$options.plugins ?? []).some(
(plugin) => plugin.onEntityMutation?.runAfterMutationWithinTransaction,
);
}

private isMutationNode(queryNode: RootOperationNode): queryNode is MutationQueryNode {
return InsertQueryNode.is(queryNode) || UpdateQueryNode.is(queryNode) || DeleteQueryNode.is(queryNode);
}
Expand Down
13 changes: 7 additions & 6 deletions packages/runtime/test/policy/crud/post-update.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ describe('Policy post-update tests', () => {
await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 3 });
});

// TODO: fix transaction issue
it.skip('works with query builder API', async () => {
it('works with query builder API', async () => {
const db = await createPolicyTestClient(
`
model Foo {
Expand All @@ -153,14 +152,16 @@ describe('Policy post-update tests', () => {
await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 });
await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 2 });

await expect(db.$qb.updateTable('Foo').set({ x: 2 }).where('id', '=', 1).execute()).resolves.toMatchObject({
numAffectedRows: 1n,
await expect(
db.$qb.updateTable('Foo').set({ x: 2 }).where('id', '=', 1).executeTakeFirst(),
).resolves.toMatchObject({
numUpdatedRows: 1n,
});
// check updated
await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 2 });

await expect(db.$qb.updateTable('Foo').set({ x: 3 }).execute()).resolves.toMatchObject({
numAffectedRows: 2n,
await expect(db.$qb.updateTable('Foo').set({ x: 3 }).executeTakeFirst()).resolves.toMatchObject({
numUpdatedRows: 2n,
});
// check updated
await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 3 });
Expand Down
Loading