Skip to content

Commit 4069781

Browse files
committed
feat: implement sequential transaction
1 parent d6834d8 commit 4069781

File tree

8 files changed

+239
-79
lines changed

8 files changed

+239
-79
lines changed

TODO.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
- [x] Raw queries
5555
- [ ] Transactions
5656
- [x] Interactive transaction
57-
- [ ] Batch transaction
57+
- [x] Sequential transaction
5858
- [ ] Extensions
5959
- [x] Query builder API
6060
- [x] Computed fields

packages/cli/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"data modeling"
1919
],
2020
"bin": {
21+
"zen": "bin/cli",
2122
"zenstack": "bin/cli"
2223
},
2324
"scripts": {

packages/runtime/src/client/client-impl.ts

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { lowerCaseFirst } from '@zenstackhq/common-helpers';
1+
import { invariant, lowerCaseFirst } from '@zenstackhq/common-helpers';
22
import type { QueryExecutor, SqliteDialectConfig } from 'kysely';
33
import {
44
CompiledQuery,
@@ -15,7 +15,7 @@ import {
1515
import { match } from 'ts-pattern';
1616
import type { GetModels, ProcedureDef, SchemaDef } from '../schema';
1717
import type { AuthType } from '../schema/auth';
18-
import type { ClientConstructor, ClientContract, ModelOperations } from './contract';
18+
import type { ClientConstructor, ClientContract, ModelOperations, TransactionIsolationLevel } from './contract';
1919
import { AggregateOperationHandler } from './crud/operations/aggregate';
2020
import type { CrudOperation } from './crud/operations/base';
2121
import { BaseOperationHandler } from './crud/operations/base';
@@ -33,7 +33,7 @@ import * as BuiltinFunctions from './functions';
3333
import { SchemaDbPusher } from './helpers/schema-db-pusher';
3434
import type { ClientOptions, ProceduresOptions } from './options';
3535
import type { RuntimePlugin } from './plugin';
36-
import { createDeferredPromise } from './promise';
36+
import { createZenStackPromise, type ZenStackPromise } from './promise';
3737
import type { ToKysely } from './query-builder';
3838
import { ResultProcessor } from './result-processor';
3939

@@ -145,20 +145,75 @@ export class ClientImpl<Schema extends SchemaDef> {
145145
return new SqliteDialect(this.options.dialectConfig as SqliteDialectConfig);
146146
}
147147

148-
async $transaction<T>(callback: (tx: ClientContract<Schema>) => Promise<T>): Promise<T> {
148+
// overload for interactive transaction
149+
$transaction<T>(
150+
callback: (tx: ClientContract<Schema>) => Promise<T>,
151+
options?: { isolationLevel?: TransactionIsolationLevel },
152+
): Promise<T>;
153+
154+
// overload for sequential transaction
155+
$transaction<P extends Promise<any>[]>(arg: [...P], options?: { isolationLevel?: TransactionIsolationLevel }): P;
156+
157+
// implementation
158+
async $transaction(input: any, options?: { isolationLevel?: TransactionIsolationLevel }) {
159+
invariant(
160+
typeof input === 'function' || (Array.isArray(input) && input.every((p) => p.then)),
161+
'Invalid transaction input, expected a function or an array of ZenStackClient promises',
162+
);
163+
if (typeof input === 'function') {
164+
return this.interactiveTransaction(input, options);
165+
} else {
166+
return this.sequentialTransaction(input, options);
167+
}
168+
}
169+
170+
private async interactiveTransaction(
171+
callback: (tx: ClientContract<Schema>) => Promise<any>,
172+
options?: { isolationLevel?: TransactionIsolationLevel },
173+
): Promise<any> {
149174
if (this.kysely.isTransaction) {
150175
// proceed directly if already in a transaction
151176
return callback(this as unknown as ClientContract<Schema>);
152177
} else {
153178
// otherwise, create a new transaction, clone the client, and execute the callback
154-
return this.kysely.transaction().execute((tx) => {
155-
const txClient = new ClientImpl<Schema>(this.schema, this.$options);
179+
let txBuilder = this.kysely.transaction();
180+
if (options?.isolationLevel) {
181+
txBuilder = txBuilder.setIsolationLevel(options.isolationLevel);
182+
}
183+
return txBuilder.execute((tx) => {
184+
const txClient = new ClientImpl<Schema>(this.schema, this.$options, this);
156185
txClient.kysely = tx;
157186
return callback(txClient as unknown as ClientContract<Schema>);
158187
});
159188
}
160189
}
161190

191+
private async sequentialTransaction(
192+
arg: ZenStackPromise<Schema, any>[],
193+
options?: { isolationLevel?: TransactionIsolationLevel },
194+
) {
195+
const execute = async (tx: Kysely<any>) => {
196+
const txClient = new ClientImpl<Schema>(this.schema, this.$options, this);
197+
txClient.kysely = tx;
198+
const result: any[] = [];
199+
for (const promise of arg) {
200+
result.push(await promise.cb(txClient as unknown as ClientContract<Schema>));
201+
}
202+
return result;
203+
};
204+
if (this.kysely.isTransaction) {
205+
// proceed directly if already in a transaction
206+
return execute(this.kysely);
207+
} else {
208+
// otherwise, create a new transaction, clone the client, and execute the callback
209+
let txBuilder = this.kysely.transaction();
210+
if (options?.isolationLevel) {
211+
txBuilder = txBuilder.setIsolationLevel(options.isolationLevel);
212+
}
213+
return txBuilder.execute((tx) => execute(tx as Kysely<any>));
214+
}
215+
}
216+
162217
get $procedures() {
163218
return Object.keys(this.$schema.procedures ?? {}).reduce((acc, name) => {
164219
acc[name] = (...args: unknown[]) => this.handleProc(name, args);
@@ -229,29 +284,29 @@ export class ClientImpl<Schema extends SchemaDef> {
229284
}
230285

231286
$executeRaw(query: TemplateStringsArray, ...values: any[]) {
232-
return createDeferredPromise(async () => {
287+
return createZenStackPromise(async () => {
233288
const result = await sql(query, ...values).execute(this.kysely);
234289
return Number(result.numAffectedRows ?? 0);
235290
});
236291
}
237292

238293
$executeRawUnsafe(query: string, ...values: any[]) {
239-
return createDeferredPromise(async () => {
294+
return createZenStackPromise(async () => {
240295
const compiledQuery = this.createRawCompiledQuery(query, values);
241296
const result = await this.kysely.executeQuery(compiledQuery);
242297
return Number(result.numAffectedRows ?? 0);
243298
});
244299
}
245300

246301
$queryRaw<T = unknown>(query: TemplateStringsArray, ...values: any[]) {
247-
return createDeferredPromise(async () => {
302+
return createZenStackPromise(async () => {
248303
const result = await sql(query, ...values).execute(this.kysely);
249304
return result.rows as T;
250305
});
251306
}
252307

253308
$queryRawUnsafe<T = unknown>(query: string, ...values: any[]) {
254-
return createDeferredPromise(async () => {
309+
return createZenStackPromise(async () => {
255310
const compiledQuery = this.createRawCompiledQuery(query, values);
256311
const result = await this.kysely.executeQuery(compiledQuery);
257312
return result.rows as T;
@@ -278,7 +333,7 @@ function createClientProxy<Schema extends SchemaDef>(client: ClientImpl<Schema>)
278333
const model = Object.keys(client.$schema.models).find((m) => m.toLowerCase() === prop.toLowerCase());
279334
if (model) {
280335
return createModelCrudHandler(
281-
client as ClientContract<Schema>,
336+
client as unknown as ClientContract<Schema>,
282337
model as GetModels<Schema>,
283338
inputValidator,
284339
resultProcessor,
@@ -304,9 +359,9 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
304359
postProcess = false,
305360
throwIfNoResult = false,
306361
) => {
307-
return createDeferredPromise(async () => {
308-
let proceed = async (_args?: unknown, tx?: ClientContract<Schema>) => {
309-
const _handler = tx ? handler.withClient(tx) : handler;
362+
return createZenStackPromise(async (txClient?: ClientContract<Schema>) => {
363+
let proceed = async (_args?: unknown) => {
364+
const _handler = txClient ? handler.withClient(txClient) : handler;
310365
const r = await _handler.handle(operation, _args ?? args);
311366
if (!r && throwIfNoResult) {
312367
throw new NotFoundError(model);

packages/runtime/src/client/constants.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ export const CONTEXT_COMMENT_PREFIX = '-- $$context:';
77
* The types of fields that are numeric.
88
*/
99
export const NUMERIC_FIELD_TYPES = ['Int', 'Float', 'BigInt', 'Decimal'];
10+
11+
/**
12+
* Client API methods that are not supported in transactions.
13+
*/
14+
export const TRANSACTION_UNSUPPORTED_METHODS = ['$transaction', '$disconnect', '$use'] as const;

packages/runtime/src/client/contract.ts

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import type { Decimal } from 'decimal.js';
22
import { type GetModels, type ProcedureDef, type SchemaDef } from '../schema';
33
import type { AuthType } from '../schema/auth';
4-
import type { OrUndefinedIf } from '../utils/type-utils';
4+
import type { OrUndefinedIf, UnwrapTuplePromises } from '../utils/type-utils';
55
import type {
66
AggregateArgs,
77
AggregateResult,
@@ -28,6 +28,20 @@ import type {
2828
import type { ClientOptions } from './options';
2929
import type { RuntimePlugin } from './plugin';
3030
import type { ToKysely } from './query-builder';
31+
import type { TRANSACTION_UNSUPPORTED_METHODS } from './constants';
32+
33+
type TransactionUnsupportedMethods = (typeof TRANSACTION_UNSUPPORTED_METHODS)[number];
34+
35+
/**
36+
* Transaction isolation levels.
37+
*/
38+
export enum TransactionIsolationLevel {
39+
ReadUncommitted = 'read uncommitted',
40+
ReadCommitted = 'read committed',
41+
RepeatableRead = 'repeatable read',
42+
Serializable = 'serializable',
43+
Snapshot = 'snapshot',
44+
}
3145

3246
/**
3347
* ZenStack client interface.
@@ -99,9 +113,20 @@ export type ClientContract<Schema extends SchemaDef> = {
99113
readonly $qbRaw: ToKysely<any>;
100114

101115
/**
102-
* Starts a transaction.
116+
* Starts an interactive transaction.
117+
*/
118+
$transaction<T>(
119+
callback: (tx: Omit<ClientContract<Schema>, TransactionUnsupportedMethods>) => Promise<T>,
120+
options?: { isolationLevel?: TransactionIsolationLevel },
121+
): Promise<T>;
122+
123+
/**
124+
* Starts a sequential transaction.
103125
*/
104-
$transaction<T>(callback: (tx: ClientContract<Schema>) => Promise<T>): Promise<T>;
126+
$transaction<P extends Promise<any>[]>(
127+
arg: [...P],
128+
options?: { isolationLevel?: TransactionIsolationLevel },
129+
): Promise<UnwrapTuplePromises<P>>;
105130

106131
/**
107132
* Returns a new client with the specified plugin installed.

packages/runtime/src/client/promise.ts

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
1+
import type { SchemaDef } from '../schema';
2+
import type { ClientContract } from './contract';
3+
4+
/**
5+
* A promise that only executes when it's awaited or .then() is called.
6+
*/
7+
export type ZenStackPromise<Schema extends SchemaDef, T> = Promise<T> & {
8+
/**
9+
* @private
10+
* Callable to get a plain promise.
11+
*/
12+
cb: (txClient?: ClientContract<Schema>) => Promise<T>;
13+
};
14+
115
/**
216
* Creates a promise that only executes when it's awaited or .then() is called.
317
* @see https://github.com/prisma/prisma/blob/main/packages/client/src/runtime/core/request/createPrismaPromise.ts
418
*/
5-
export function createDeferredPromise<T>(callback: () => Promise<T>): Promise<T> {
19+
export function createZenStackPromise<Schema extends SchemaDef, T>(
20+
callback: (txClient?: ClientContract<Schema>) => Promise<T>,
21+
): ZenStackPromise<Schema, T> {
622
let promise: Promise<T> | undefined;
7-
const cb = () => {
23+
const cb = (txClient?: ClientContract<Schema>) => {
824
try {
9-
return (promise ??= valueToPromise(callback()));
25+
return (promise ??= valueToPromise(callback(txClient)));
1026
} catch (err) {
1127
// deal with synchronous errors
1228
return Promise.reject<T>(err);
@@ -23,6 +39,7 @@ export function createDeferredPromise<T>(callback: () => Promise<T>): Promise<T>
2339
finally(onFinally) {
2440
return cb().finally(onFinally);
2541
},
42+
cb,
2643
[Symbol.toStringTag]: 'ZenStackPromise',
2744
};
2845
}

packages/runtime/src/utils/type-utils.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,7 @@ export type PrependParameter<Param, Func> = Func extends (...args: any[]) => inf
6868
: never;
6969

7070
export type OrUndefinedIf<T, Condition extends boolean> = Condition extends true ? T | undefined : T;
71+
72+
export type UnwrapTuplePromises<T extends readonly unknown[]> = {
73+
[K in keyof T]: Awaited<T[K]>;
74+
};

0 commit comments

Comments
 (0)