1- import { lowerCaseFirst } from '@zenstackhq/common-helpers' ;
1+ import { invariant , lowerCaseFirst } from '@zenstackhq/common-helpers' ;
22import type { QueryExecutor , SqliteDialectConfig } from 'kysely' ;
33import {
44 CompiledQuery ,
@@ -15,7 +15,7 @@ import {
1515import { match } from 'ts-pattern' ;
1616import type { GetModels , ProcedureDef , SchemaDef } from '../schema' ;
1717import type { AuthType } from '../schema/auth' ;
18- import type { ClientConstructor , ClientContract , ModelOperations } from './contract' ;
18+ import type { ClientConstructor , ClientContract , ModelOperations , TransactionIsolationLevel } from './contract' ;
1919import { AggregateOperationHandler } from './crud/operations/aggregate' ;
2020import type { CrudOperation } from './crud/operations/base' ;
2121import { BaseOperationHandler } from './crud/operations/base' ;
@@ -33,7 +33,7 @@ import * as BuiltinFunctions from './functions';
3333import { SchemaDbPusher } from './helpers/schema-db-pusher' ;
3434import type { ClientOptions , ProceduresOptions } from './options' ;
3535import type { RuntimePlugin } from './plugin' ;
36- import { createDeferredPromise } from './promise' ;
36+ import { createZenStackPromise , type ZenStackPromise } from './promise' ;
3737import type { ToKysely } from './query-builder' ;
3838import { ResultProcessor } from './result-processor' ;
3939
@@ -145,20 +145,75 @@ export class ClientImpl<Schema extends SchemaDef> {
145145 return new SqliteDialect ( this . options . dialectConfig as SqliteDialectConfig ) ;
146146 }
147147
148- async $transaction < T > ( callback : ( tx : ClientContract < Schema > ) => Promise < T > ) : Promise < T > {
148+ // overload for interactive transaction
149+ $transaction < T > (
150+ callback : ( tx : ClientContract < Schema > ) => Promise < T > ,
151+ options ?: { isolationLevel ?: TransactionIsolationLevel } ,
152+ ) : Promise < T > ;
153+
154+ // overload for sequential transaction
155+ $transaction < P extends Promise < any > [ ] > ( arg : [ ...P ] , options ?: { isolationLevel ?: TransactionIsolationLevel } ) : P ;
156+
157+ // implementation
158+ async $transaction ( input : any , options ?: { isolationLevel ?: TransactionIsolationLevel } ) {
159+ invariant (
160+ typeof input === 'function' || ( Array . isArray ( input ) && input . every ( ( p ) => p . then ) ) ,
161+ 'Invalid transaction input, expected a function or an array of ZenStackClient promises' ,
162+ ) ;
163+ if ( typeof input === 'function' ) {
164+ return this . interactiveTransaction ( input , options ) ;
165+ } else {
166+ return this . sequentialTransaction ( input , options ) ;
167+ }
168+ }
169+
170+ private async interactiveTransaction (
171+ callback : ( tx : ClientContract < Schema > ) => Promise < any > ,
172+ options ?: { isolationLevel ?: TransactionIsolationLevel } ,
173+ ) : Promise < any > {
149174 if ( this . kysely . isTransaction ) {
150175 // proceed directly if already in a transaction
151176 return callback ( this as unknown as ClientContract < Schema > ) ;
152177 } else {
153178 // otherwise, create a new transaction, clone the client, and execute the callback
154- return this . kysely . transaction ( ) . execute ( ( tx ) => {
155- const txClient = new ClientImpl < Schema > ( this . schema , this . $options ) ;
179+ let txBuilder = this . kysely . transaction ( ) ;
180+ if ( options ?. isolationLevel ) {
181+ txBuilder = txBuilder . setIsolationLevel ( options . isolationLevel ) ;
182+ }
183+ return txBuilder . execute ( ( tx ) => {
184+ const txClient = new ClientImpl < Schema > ( this . schema , this . $options , this ) ;
156185 txClient . kysely = tx ;
157186 return callback ( txClient as unknown as ClientContract < Schema > ) ;
158187 } ) ;
159188 }
160189 }
161190
191+ private async sequentialTransaction (
192+ arg : ZenStackPromise < Schema , any > [ ] ,
193+ options ?: { isolationLevel ?: TransactionIsolationLevel } ,
194+ ) {
195+ const execute = async ( tx : Kysely < any > ) => {
196+ const txClient = new ClientImpl < Schema > ( this . schema , this . $options , this ) ;
197+ txClient . kysely = tx ;
198+ const result : any [ ] = [ ] ;
199+ for ( const promise of arg ) {
200+ result . push ( await promise . cb ( txClient as unknown as ClientContract < Schema > ) ) ;
201+ }
202+ return result ;
203+ } ;
204+ if ( this . kysely . isTransaction ) {
205+ // proceed directly if already in a transaction
206+ return execute ( this . kysely ) ;
207+ } else {
208+ // otherwise, create a new transaction, clone the client, and execute the callback
209+ let txBuilder = this . kysely . transaction ( ) ;
210+ if ( options ?. isolationLevel ) {
211+ txBuilder = txBuilder . setIsolationLevel ( options . isolationLevel ) ;
212+ }
213+ return txBuilder . execute ( ( tx ) => execute ( tx as Kysely < any > ) ) ;
214+ }
215+ }
216+
162217 get $procedures ( ) {
163218 return Object . keys ( this . $schema . procedures ?? { } ) . reduce ( ( acc , name ) => {
164219 acc [ name ] = ( ...args : unknown [ ] ) => this . handleProc ( name , args ) ;
@@ -229,29 +284,29 @@ export class ClientImpl<Schema extends SchemaDef> {
229284 }
230285
231286 $executeRaw ( query : TemplateStringsArray , ...values : any [ ] ) {
232- return createDeferredPromise ( async ( ) => {
287+ return createZenStackPromise ( async ( ) => {
233288 const result = await sql ( query , ...values ) . execute ( this . kysely ) ;
234289 return Number ( result . numAffectedRows ?? 0 ) ;
235290 } ) ;
236291 }
237292
238293 $executeRawUnsafe ( query : string , ...values : any [ ] ) {
239- return createDeferredPromise ( async ( ) => {
294+ return createZenStackPromise ( async ( ) => {
240295 const compiledQuery = this . createRawCompiledQuery ( query , values ) ;
241296 const result = await this . kysely . executeQuery ( compiledQuery ) ;
242297 return Number ( result . numAffectedRows ?? 0 ) ;
243298 } ) ;
244299 }
245300
246301 $queryRaw < T = unknown > ( query : TemplateStringsArray , ...values : any [ ] ) {
247- return createDeferredPromise ( async ( ) => {
302+ return createZenStackPromise ( async ( ) => {
248303 const result = await sql ( query , ...values ) . execute ( this . kysely ) ;
249304 return result . rows as T ;
250305 } ) ;
251306 }
252307
253308 $queryRawUnsafe < T = unknown > ( query : string , ...values : any [ ] ) {
254- return createDeferredPromise ( async ( ) => {
309+ return createZenStackPromise ( async ( ) => {
255310 const compiledQuery = this . createRawCompiledQuery ( query , values ) ;
256311 const result = await this . kysely . executeQuery ( compiledQuery ) ;
257312 return result . rows as T ;
@@ -278,7 +333,7 @@ function createClientProxy<Schema extends SchemaDef>(client: ClientImpl<Schema>)
278333 const model = Object . keys ( client . $schema . models ) . find ( ( m ) => m . toLowerCase ( ) === prop . toLowerCase ( ) ) ;
279334 if ( model ) {
280335 return createModelCrudHandler (
281- client as ClientContract < Schema > ,
336+ client as unknown as ClientContract < Schema > ,
282337 model as GetModels < Schema > ,
283338 inputValidator ,
284339 resultProcessor ,
@@ -304,9 +359,9 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
304359 postProcess = false ,
305360 throwIfNoResult = false ,
306361 ) => {
307- return createDeferredPromise ( async ( ) => {
308- let proceed = async ( _args ?: unknown , tx ?: ClientContract < Schema > ) => {
309- const _handler = tx ? handler . withClient ( tx ) : handler ;
362+ return createZenStackPromise ( async ( txClient ?: ClientContract < Schema > ) => {
363+ let proceed = async ( _args ?: unknown ) => {
364+ const _handler = txClient ? handler . withClient ( txClient ) : handler ;
310365 const r = await _handler . handle ( operation , _args ?? args ) ;
311366 if ( ! r && throwIfNoResult ) {
312367 throw new NotFoundError ( model ) ;
0 commit comments