diff --git a/packages/runtime/src/client/errors.ts b/packages/runtime/src/client/errors.ts index 1d6134e9..15961811 100644 --- a/packages/runtime/src/client/errors.ts +++ b/packages/runtime/src/client/errors.ts @@ -1,7 +1,12 @@ +/** + * Base for all ZenStack runtime errors. + */ +export class ZenStackError extends Error {} + /** * Error thrown when input validation fails. */ -export class InputValidationError extends Error { +export class InputValidationError extends ZenStackError { constructor(message: string, cause?: unknown) { super(message, { cause }); } @@ -10,7 +15,7 @@ export class InputValidationError extends Error { /** * Error thrown when a query fails. */ -export class QueryError extends Error { +export class QueryError extends ZenStackError { constructor(message: string, cause?: unknown) { super(message, { cause }); } @@ -19,12 +24,12 @@ export class QueryError extends Error { /** * Error thrown when an internal error occurs. */ -export class InternalError extends Error {} +export class InternalError extends ZenStackError {} /** * Error thrown when an entity is not found. */ -export class NotFoundError extends Error { +export class NotFoundError extends ZenStackError { constructor(model: string, details?: string) { super(`Entity not found for model "${model}"${details ? `: ${details}` : ''}`); } diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index 2d2395cb..c307bc4e 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -25,7 +25,7 @@ import { match } from 'ts-pattern'; import type { GetModels, SchemaDef } from '../../schema'; import { type ClientImpl } from '../client-impl'; import { TransactionIsolationLevel, type ClientContract } from '../contract'; -import { InternalError, QueryError } from '../errors'; +import { InternalError, QueryError, ZenStackError } from '../errors'; import { stripAlias } from '../kysely-utils'; import type { AfterEntityMutationCallback, OnKyselyQueryCallback } from '../plugin'; import { QueryNameMapper } from './name-mapper'; @@ -65,21 +65,53 @@ export class ZenStackQueryExecutor extends DefaultQuer return this.client.$options; } - override async executeQuery(compiledQuery: CompiledQuery, queryId: QueryId) { + override executeQuery(compiledQuery: CompiledQuery, queryId: QueryId) { // proceed with the query with kysely interceptors // if the query is a raw query, we need to carry over the parameters const queryParams = (compiledQuery as any).$raw ? compiledQuery.parameters : undefined; - const result = await this.proceedQueryWithKyselyInterceptors(compiledQuery.query, queryParams, queryId.queryId); - return result.result; + return this.provideConnection(async (connection) => { + let startedTx = false; + try { + // mutations are wrapped in tx if not already in one + if (this.isMutationNode(compiledQuery.query) && !this.driver.isTransactionConnection(connection)) { + await this.driver.beginTransaction(connection, { + isolationLevel: TransactionIsolationLevel.RepeatableRead, + }); + startedTx = true; + } + const result = await this.proceedQueryWithKyselyInterceptors( + connection, + compiledQuery.query, + queryParams, + queryId.queryId, + ); + if (startedTx) { + await this.driver.commitTransaction(connection); + } + return result; + } catch (err) { + if (startedTx) { + await this.driver.rollbackTransaction(connection); + } + if (err instanceof ZenStackError) { + throw err; + } else { + // wrap error + const message = `Failed to execute query: ${err}, sql: ${compiledQuery?.sql}`; + throw new QueryError(message, err); + } + } + }); } private async proceedQueryWithKyselyInterceptors( + connection: DatabaseConnection, queryNode: RootOperationNode, parameters: readonly unknown[] | undefined, queryId: string, ) { - let proceed = (q: RootOperationNode) => this.proceedQuery(q, parameters, queryId); + let proceed = (q: RootOperationNode) => this.proceedQuery(connection, q, parameters, queryId); const hooks: OnKyselyQueryCallback[] = []; // tsc perf @@ -92,18 +124,14 @@ export class ZenStackQueryExecutor extends DefaultQuer for (const hook of hooks) { const _proceed = proceed; proceed = async (query: RootOperationNode) => { - const _p = async (q: RootOperationNode) => { - const r = await _proceed(q); - return r.result; - }; - + const _p = (q: RootOperationNode) => _proceed(q); const hookResult = await hook!({ client: this.client as ClientContract, schema: this.client.$schema, query, proceed: _p, }); - return { result: hookResult }; + return hookResult; }; } @@ -132,161 +160,83 @@ export class ZenStackQueryExecutor extends DefaultQuer return { model, action, where }; } - private async proceedQuery(query: RootOperationNode, parameters: readonly unknown[] | undefined, queryId: string) { + private async proceedQuery( + connection: DatabaseConnection, + query: RootOperationNode, + parameters: readonly unknown[] | undefined, + queryId: string, + ) { let compiled: CompiledQuery | undefined; - try { - return await this.provideConnection(async (connection) => { - if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) { - // no need to handle mutation hooks, just proceed - const finalQuery = this.nameMapper.transformNode(query); - compiled = this.compileQuery(finalQuery); - if (parameters) { - compiled = { ...compiled, parameters }; - } - const result = await connection.executeQuery(compiled); - return { result }; - } + if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) { + // no need to handle mutation hooks, just proceed + const finalQuery = this.nameMapper.transformNode(query); + compiled = this.compileQuery(finalQuery); + if (parameters) { + compiled = { ...compiled, parameters }; + } + return connection.executeQuery(compiled); + } - if ( - (InsertQueryNode.is(query) || UpdateQueryNode.is(query)) && - this.hasEntityMutationPluginsWithAfterMutationHooks - ) { - // need to make sure the query node has "returnAll" for insert and update queries - // so that after-mutation hooks can get the mutated entities with all fields - query = { - ...query, - returning: ReturningNode.create([SelectionNode.createSelectAll()]), - }; - } - const finalQuery = this.nameMapper.transformNode(query); - compiled = this.compileQuery(finalQuery); - if (parameters) { - compiled = { ...compiled, parameters }; - } + if ( + (InsertQueryNode.is(query) || UpdateQueryNode.is(query)) && + this.hasEntityMutationPluginsWithAfterMutationHooks + ) { + // need to make sure the query node has "returnAll" for insert and update queries + // so that after-mutation hooks can get the mutated entities with all fields + query = { + ...query, + returning: ReturningNode.create([SelectionNode.createSelectAll()]), + }; + } + const finalQuery = this.nameMapper.transformNode(query); + compiled = this.compileQuery(finalQuery); + if (parameters) { + compiled = { ...compiled, parameters }; + } - // the client passed to hooks needs to be in sync with current in-transaction - // status so that it doesn't try to create a nested one - const currentlyInTx = this.driver.isTransactionConnection(connection); - - const connectionClient = this.createClientForConnection(connection, currentlyInTx); - - const mutationInfo = this.getMutationInfo(finalQuery); - - // cache already loaded before-mutation entities - let beforeMutationEntities: Record[] | undefined; - const loadBeforeMutationEntities = async () => { - if ( - beforeMutationEntities === undefined && - (UpdateQueryNode.is(query) || DeleteQueryNode.is(query)) - ) { - beforeMutationEntities = await this.loadEntities( - mutationInfo.model, - mutationInfo.where, - connection, - ); - } - return beforeMutationEntities; - }; - - // call before mutation hooks - await this.callBeforeMutationHooks( - finalQuery, - mutationInfo, - loadBeforeMutationEntities, - connectionClient, - queryId, - ); + // the client passed to hooks needs to be in sync with current in-transaction + // status so that it doesn't try to create a nested one + const currentlyInTx = this.driver.isTransactionConnection(connection); - // if mutation interceptor demands to run afterMutation hook in the transaction but we're not already - // inside one, we need to create one on the fly - const shouldCreateTx = - this.hasPluginRequestingAfterMutationWithinTransaction && - !this.driver.isTransactionConnection(connection); - - if (!shouldCreateTx) { - // if no on-the-fly tx is needed, just proceed with the query as is - const result = await connection.executeQuery(compiled); - - if (!this.driver.isTransactionConnection(connection)) { - // not in a transaction, just call all after-mutation hooks - await this.callAfterMutationHooks( - result, - finalQuery, - mutationInfo, - connectionClient, - 'all', - queryId, - ); - } else { - // run after-mutation hooks that are requested to be run inside tx - await this.callAfterMutationHooks( - result, - finalQuery, - mutationInfo, - connectionClient, - 'inTx', - queryId, - ); - - // register other after-mutation hooks to be run after the tx is committed - this.driver.registerTransactionCommitCallback(connection, () => - this.callAfterMutationHooks( - result, - finalQuery, - mutationInfo, - connectionClient, - 'outTx', - queryId, - ), - ); - } - - return { result }; - } else { - // if an on-the-fly tx is created, create one and wrap the query execution inside - await this.driver.beginTransaction(connection, { - isolationLevel: TransactionIsolationLevel.ReadCommitted, - }); - try { - // execute the query inside the on-the-fly transaction - const result = await connection.executeQuery(compiled); - - // run after-mutation hooks that are requested to be run inside tx - await this.callAfterMutationHooks( - result, - finalQuery, - mutationInfo, - connectionClient, - 'inTx', - queryId, - ); - - // commit the transaction - await this.driver.commitTransaction(connection); - - // run other after-mutation hooks after the tx is committed - await this.callAfterMutationHooks( - result, - finalQuery, - mutationInfo, - connectionClient, - 'outTx', - queryId, - ); - - return { result }; - } catch (err) { - // rollback the transaction - await this.driver.rollbackTransaction(connection); - throw err; - } - } - }); - } catch (err) { - const message = `Failed to execute query: ${err}, sql: ${compiled?.sql}`; - throw new QueryError(message, err); + const connectionClient = this.createClientForConnection(connection, currentlyInTx); + + const mutationInfo = this.getMutationInfo(finalQuery); + + // cache already loaded before-mutation entities + let beforeMutationEntities: Record[] | undefined; + const loadBeforeMutationEntities = async () => { + if (beforeMutationEntities === undefined && (UpdateQueryNode.is(query) || DeleteQueryNode.is(query))) { + beforeMutationEntities = await this.loadEntities(mutationInfo.model, mutationInfo.where, connection); + } + return beforeMutationEntities; + }; + + // call before mutation hooks + await this.callBeforeMutationHooks( + finalQuery, + mutationInfo, + loadBeforeMutationEntities, + connectionClient, + queryId, + ); + + const result = await connection.executeQuery(compiled); + + if (!this.driver.isTransactionConnection(connection)) { + // not in a transaction, just call all after-mutation hooks + await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'all', queryId); + } else { + // run after-mutation hooks that are requested to be run inside tx + await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'inTx', queryId); + + // register other after-mutation hooks to be run after the tx is committed + this.driver.registerTransactionCommitCallback(connection, () => + this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'outTx', queryId), + ); } + + return result; } private createClientForConnection(connection: DatabaseConnection, inTx: boolean) { @@ -307,12 +257,6 @@ export class ZenStackQueryExecutor extends DefaultQuer return (this.client.$options.plugins ?? []).some((plugin) => plugin.onEntityMutation?.afterEntityMutation); } - private get hasPluginRequestingAfterMutationWithinTransaction() { - return (this.client.$options.plugins ?? []).some( - (plugin) => plugin.onEntityMutation?.runAfterMutationWithinTransaction, - ); - } - private isMutationNode(queryNode: RootOperationNode): queryNode is MutationQueryNode { return InsertQueryNode.is(queryNode) || UpdateQueryNode.is(queryNode) || DeleteQueryNode.is(queryNode); } diff --git a/packages/runtime/src/plugins/policy/errors.ts b/packages/runtime/src/plugins/policy/errors.ts index 675506d6..42d57b18 100644 --- a/packages/runtime/src/plugins/policy/errors.ts +++ b/packages/runtime/src/plugins/policy/errors.ts @@ -1,3 +1,5 @@ +import { ZenStackError } from '../../client'; + /** * Reason code for policy rejection. */ @@ -21,7 +23,7 @@ export enum RejectedByPolicyReason { /** * Error thrown when an operation is rejected by access policy. */ -export class RejectedByPolicyError extends Error { +export class RejectedByPolicyError extends ZenStackError { constructor( public readonly model: string | undefined, public readonly reason: RejectedByPolicyReason = RejectedByPolicyReason.NO_ACCESS, diff --git a/packages/runtime/test/policy/crud/post-update.test.ts b/packages/runtime/test/policy/crud/post-update.test.ts index 585ee180..98c2b5b5 100644 --- a/packages/runtime/test/policy/crud/post-update.test.ts +++ b/packages/runtime/test/policy/crud/post-update.test.ts @@ -126,8 +126,7 @@ describe('Policy post-update tests', () => { await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 3 }); }); - // TODO: fix transaction issue - it.skip('works with query builder API', async () => { + it('works with query builder API', async () => { const db = await createPolicyTestClient( ` model Foo { @@ -153,14 +152,16 @@ describe('Policy post-update tests', () => { await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 2 }); - await expect(db.$qb.updateTable('Foo').set({ x: 2 }).where('id', '=', 1).execute()).resolves.toMatchObject({ - numAffectedRows: 1n, + await expect( + db.$qb.updateTable('Foo').set({ x: 2 }).where('id', '=', 1).executeTakeFirst(), + ).resolves.toMatchObject({ + numUpdatedRows: 1n, }); // check updated await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 2 }); - await expect(db.$qb.updateTable('Foo').set({ x: 3 }).execute()).resolves.toMatchObject({ - numAffectedRows: 2n, + await expect(db.$qb.updateTable('Foo').set({ x: 3 }).executeTakeFirst()).resolves.toMatchObject({ + numUpdatedRows: 2n, }); // check updated await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 3 }); diff --git a/packages/runtime/test/policy/migrated/multi-field-unique.test.ts b/packages/runtime/test/policy/migrated/multi-field-unique.test.ts index 7edbe019..fba22b09 100644 --- a/packages/runtime/test/policy/migrated/multi-field-unique.test.ts +++ b/packages/runtime/test/policy/migrated/multi-field-unique.test.ts @@ -1,19 +1,8 @@ -import path from 'path'; -import { afterEach, beforeAll, describe, expect, it } from 'vitest'; -import { createPolicyTestClient } from '../utils'; +import { describe, expect, it } from 'vitest'; import { QueryError } from '../../../src'; +import { createPolicyTestClient } from '../utils'; describe('Policy tests multi-field unique', () => { - let origDir: string; - - beforeAll(async () => { - origDir = path.resolve('.'); - }); - - afterEach(() => { - process.chdir(origDir); - }); - it('toplevel crud test unnamed constraint', async () => { const db = await createPolicyTestClient( `