1- import { lowerCaseFirst } from '@zenstackhq/common-helpers' ;
1+ import { invariant , lowerCaseFirst } from '@zenstackhq/common-helpers' ;
22import type { QueryExecutor , SqliteDialectConfig } from 'kysely' ;
33import {
44 CompiledQuery ,
@@ -15,7 +15,8 @@ import {
1515import { match } from 'ts-pattern' ;
1616import type { GetModels , ProcedureDef , SchemaDef } from '../schema' ;
1717import type { AuthType } from '../schema/auth' ;
18- import type { ClientConstructor , ClientContract , ModelOperations } from './contract' ;
18+ import type { UnwrapTuplePromises } from '../utils/type-utils' ;
19+ import type { ClientConstructor , ClientContract , ModelOperations , TransactionIsolationLevel } from './contract' ;
1920import { AggregateOperationHandler } from './crud/operations/aggregate' ;
2021import type { CrudOperation } from './crud/operations/base' ;
2122import { BaseOperationHandler } from './crud/operations/base' ;
@@ -33,7 +34,7 @@ import * as BuiltinFunctions from './functions';
3334import { SchemaDbPusher } from './helpers/schema-db-pusher' ;
3435import type { ClientOptions , ProceduresOptions } from './options' ;
3536import type { RuntimePlugin } from './plugin' ;
36- import { createDeferredPromise } from './promise' ;
37+ import { createZenStackPromise , type ZenStackPromise } from './promise' ;
3738import type { ToKysely } from './query-builder' ;
3839import { ResultProcessor } from './result-processor' ;
3940
@@ -123,6 +124,10 @@ export class ClientImpl<Schema extends SchemaDef> {
123124 return this . kyselyRaw ;
124125 }
125126
127+ get isTransaction ( ) {
128+ return this . kysely . isTransaction ;
129+ }
130+
126131 /**
127132 * Create a new client with a new query executor.
128133 */
@@ -145,20 +150,78 @@ export class ClientImpl<Schema extends SchemaDef> {
145150 return new SqliteDialect ( this . options . dialectConfig as SqliteDialectConfig ) ;
146151 }
147152
148- async $transaction < T > ( callback : ( tx : ClientContract < Schema > ) => Promise < T > ) : Promise < T > {
153+ // overload for interactive transaction
154+ $transaction < T > (
155+ callback : ( tx : ClientContract < Schema > ) => Promise < T > ,
156+ options ?: { isolationLevel ?: TransactionIsolationLevel } ,
157+ ) : Promise < T > ;
158+
159+ // overload for sequential transaction
160+ $transaction < P extends ZenStackPromise < Schema , any > [ ] > (
161+ arg : [ ...P ] ,
162+ options ?: { isolationLevel ?: TransactionIsolationLevel } ,
163+ ) : Promise < UnwrapTuplePromises < P > > ;
164+
165+ // implementation
166+ async $transaction ( input : any , options ?: { isolationLevel ?: TransactionIsolationLevel } ) {
167+ invariant (
168+ typeof input === 'function' || ( Array . isArray ( input ) && input . every ( ( p ) => p . then && p . cb ) ) ,
169+ 'Invalid transaction input, expected a function or an array of ZenStackPromise' ,
170+ ) ;
171+ if ( typeof input === 'function' ) {
172+ return this . interactiveTransaction ( input , options ) ;
173+ } else {
174+ return this . sequentialTransaction ( input , options ) ;
175+ }
176+ }
177+
178+ private async interactiveTransaction (
179+ callback : ( tx : ClientContract < Schema > ) => Promise < any > ,
180+ options ?: { isolationLevel ?: TransactionIsolationLevel } ,
181+ ) : Promise < any > {
149182 if ( this . kysely . isTransaction ) {
150183 // proceed directly if already in a transaction
151184 return callback ( this as unknown as ClientContract < Schema > ) ;
152185 } else {
153186 // otherwise, create a new transaction, clone the client, and execute the callback
154- return this . kysely . transaction ( ) . execute ( ( tx ) => {
155- const txClient = new ClientImpl < Schema > ( this . schema , this . $options ) ;
187+ let txBuilder = this . kysely . transaction ( ) ;
188+ if ( options ?. isolationLevel ) {
189+ txBuilder = txBuilder . setIsolationLevel ( options . isolationLevel ) ;
190+ }
191+ return txBuilder . execute ( ( tx ) => {
192+ const txClient = new ClientImpl < Schema > ( this . schema , this . $options , this ) ;
156193 txClient . kysely = tx ;
157194 return callback ( txClient as unknown as ClientContract < Schema > ) ;
158195 } ) ;
159196 }
160197 }
161198
199+ private async sequentialTransaction (
200+ arg : ZenStackPromise < Schema , any > [ ] ,
201+ options ?: { isolationLevel ?: TransactionIsolationLevel } ,
202+ ) {
203+ const execute = async ( tx : Kysely < any > ) => {
204+ const txClient = new ClientImpl < Schema > ( this . schema , this . $options , this ) ;
205+ txClient . kysely = tx ;
206+ const result : any [ ] = [ ] ;
207+ for ( const promise of arg ) {
208+ result . push ( await promise . cb ( txClient as unknown as ClientContract < Schema > ) ) ;
209+ }
210+ return result ;
211+ } ;
212+ if ( this . kysely . isTransaction ) {
213+ // proceed directly if already in a transaction
214+ return execute ( this . kysely ) ;
215+ } else {
216+ // otherwise, create a new transaction, clone the client, and execute the callback
217+ let txBuilder = this . kysely . transaction ( ) ;
218+ if ( options ?. isolationLevel ) {
219+ txBuilder = txBuilder . setIsolationLevel ( options . isolationLevel ) ;
220+ }
221+ return txBuilder . execute ( ( tx ) => execute ( tx as Kysely < any > ) ) ;
222+ }
223+ }
224+
162225 get $procedures ( ) {
163226 return Object . keys ( this . $schema . procedures ?? { } ) . reduce ( ( acc , name ) => {
164227 acc [ name ] = ( ...args : unknown [ ] ) => this . handleProc ( name , args ) ;
@@ -229,29 +292,29 @@ export class ClientImpl<Schema extends SchemaDef> {
229292 }
230293
231294 $executeRaw ( query : TemplateStringsArray , ...values : any [ ] ) {
232- return createDeferredPromise ( async ( ) => {
295+ return createZenStackPromise ( async ( ) => {
233296 const result = await sql ( query , ...values ) . execute ( this . kysely ) ;
234297 return Number ( result . numAffectedRows ?? 0 ) ;
235298 } ) ;
236299 }
237300
238301 $executeRawUnsafe ( query : string , ...values : any [ ] ) {
239- return createDeferredPromise ( async ( ) => {
302+ return createZenStackPromise ( async ( ) => {
240303 const compiledQuery = this . createRawCompiledQuery ( query , values ) ;
241304 const result = await this . kysely . executeQuery ( compiledQuery ) ;
242305 return Number ( result . numAffectedRows ?? 0 ) ;
243306 } ) ;
244307 }
245308
246309 $queryRaw < T = unknown > ( query : TemplateStringsArray , ...values : any [ ] ) {
247- return createDeferredPromise ( async ( ) => {
310+ return createZenStackPromise ( async ( ) => {
248311 const result = await sql ( query , ...values ) . execute ( this . kysely ) ;
249312 return result . rows as T ;
250313 } ) ;
251314 }
252315
253316 $queryRawUnsafe < T = unknown > ( query : string , ...values : any [ ] ) {
254- return createDeferredPromise ( async ( ) => {
317+ return createZenStackPromise ( async ( ) => {
255318 const compiledQuery = this . createRawCompiledQuery ( query , values ) ;
256319 const result = await this . kysely . executeQuery ( compiledQuery ) ;
257320 return result . rows as T ;
@@ -278,7 +341,7 @@ function createClientProxy<Schema extends SchemaDef>(client: ClientImpl<Schema>)
278341 const model = Object . keys ( client . $schema . models ) . find ( ( m ) => m . toLowerCase ( ) === prop . toLowerCase ( ) ) ;
279342 if ( model ) {
280343 return createModelCrudHandler (
281- client as ClientContract < Schema > ,
344+ client as unknown as ClientContract < Schema > ,
282345 model as GetModels < Schema > ,
283346 inputValidator ,
284347 resultProcessor ,
@@ -304,9 +367,9 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
304367 postProcess = false ,
305368 throwIfNoResult = false ,
306369 ) => {
307- return createDeferredPromise ( async ( ) => {
308- let proceed = async ( _args ?: unknown , tx ?: ClientContract < Schema > ) => {
309- const _handler = tx ? handler . withClient ( tx ) : handler ;
370+ return createZenStackPromise ( async ( txClient ?: ClientContract < Schema > ) => {
371+ let proceed = async ( _args ?: unknown ) => {
372+ const _handler = txClient ? handler . withClient ( txClient ) : handler ;
310373 const r = await _handler . handle ( operation , _args ?? args ) ;
311374 if ( ! r && throwIfNoResult ) {
312375 throw new NotFoundError ( model ) ;
0 commit comments