Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
- [x] Raw queries
- [ ] Transactions
- [x] Interactive transaction
- [ ] Batch transaction
- [x] Sequential transaction
- [ ] Extensions
- [x] Query builder API
- [x] Computed fields
Expand Down
1 change: 1 addition & 0 deletions packages/cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"data modeling"
],
"bin": {
"zen": "bin/cli",
"zenstack": "bin/cli"
},
"scripts": {
Expand Down
91 changes: 77 additions & 14 deletions packages/runtime/src/client/client-impl.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { lowerCaseFirst } from '@zenstackhq/common-helpers';
import { invariant, lowerCaseFirst } from '@zenstackhq/common-helpers';
import type { QueryExecutor, SqliteDialectConfig } from 'kysely';
import {
CompiledQuery,
Expand All @@ -15,7 +15,8 @@ import {
import { match } from 'ts-pattern';
import type { GetModels, ProcedureDef, SchemaDef } from '../schema';
import type { AuthType } from '../schema/auth';
import type { ClientConstructor, ClientContract, ModelOperations } from './contract';
import type { UnwrapTuplePromises } from '../utils/type-utils';
import type { ClientConstructor, ClientContract, ModelOperations, TransactionIsolationLevel } from './contract';
import { AggregateOperationHandler } from './crud/operations/aggregate';
import type { CrudOperation } from './crud/operations/base';
import { BaseOperationHandler } from './crud/operations/base';
Expand All @@ -33,7 +34,7 @@ import * as BuiltinFunctions from './functions';
import { SchemaDbPusher } from './helpers/schema-db-pusher';
import type { ClientOptions, ProceduresOptions } from './options';
import type { RuntimePlugin } from './plugin';
import { createDeferredPromise } from './promise';
import { createZenStackPromise, type ZenStackPromise } from './promise';
import type { ToKysely } from './query-builder';
import { ResultProcessor } from './result-processor';

Expand Down Expand Up @@ -123,6 +124,10 @@ export class ClientImpl<Schema extends SchemaDef> {
return this.kyselyRaw;
}

get isTransaction() {
return this.kysely.isTransaction;
}

/**
* Create a new client with a new query executor.
*/
Expand All @@ -145,20 +150,78 @@ export class ClientImpl<Schema extends SchemaDef> {
return new SqliteDialect(this.options.dialectConfig as SqliteDialectConfig);
}

async $transaction<T>(callback: (tx: ClientContract<Schema>) => Promise<T>): Promise<T> {
// overload for interactive transaction
$transaction<T>(
callback: (tx: ClientContract<Schema>) => Promise<T>,
options?: { isolationLevel?: TransactionIsolationLevel },
): Promise<T>;

// overload for sequential transaction
$transaction<P extends ZenStackPromise<Schema, any>[]>(
arg: [...P],
options?: { isolationLevel?: TransactionIsolationLevel },
): Promise<UnwrapTuplePromises<P>>;

// implementation
async $transaction(input: any, options?: { isolationLevel?: TransactionIsolationLevel }) {
invariant(
typeof input === 'function' || (Array.isArray(input) && input.every((p) => p.then && p.cb)),
'Invalid transaction input, expected a function or an array of ZenStackPromise',
);
if (typeof input === 'function') {
return this.interactiveTransaction(input, options);
} else {
return this.sequentialTransaction(input, options);
}
}

private async interactiveTransaction(
callback: (tx: ClientContract<Schema>) => Promise<any>,
options?: { isolationLevel?: TransactionIsolationLevel },
): Promise<any> {
if (this.kysely.isTransaction) {
// proceed directly if already in a transaction
return callback(this as unknown as ClientContract<Schema>);
} else {
// otherwise, create a new transaction, clone the client, and execute the callback
return this.kysely.transaction().execute((tx) => {
const txClient = new ClientImpl<Schema>(this.schema, this.$options);
let txBuilder = this.kysely.transaction();
if (options?.isolationLevel) {
txBuilder = txBuilder.setIsolationLevel(options.isolationLevel);
}
return txBuilder.execute((tx) => {
const txClient = new ClientImpl<Schema>(this.schema, this.$options, this);
txClient.kysely = tx;
return callback(txClient as unknown as ClientContract<Schema>);
});
}
}

private async sequentialTransaction(
arg: ZenStackPromise<Schema, any>[],
options?: { isolationLevel?: TransactionIsolationLevel },
) {
const execute = async (tx: Kysely<any>) => {
const txClient = new ClientImpl<Schema>(this.schema, this.$options, this);
txClient.kysely = tx;
const result: any[] = [];
for (const promise of arg) {
result.push(await promise.cb(txClient as unknown as ClientContract<Schema>));
}
return result;
};
if (this.kysely.isTransaction) {
// proceed directly if already in a transaction
return execute(this.kysely);
} else {
// otherwise, create a new transaction, clone the client, and execute the callback
let txBuilder = this.kysely.transaction();
if (options?.isolationLevel) {
txBuilder = txBuilder.setIsolationLevel(options.isolationLevel);
}
return txBuilder.execute((tx) => execute(tx as Kysely<any>));
}
}

get $procedures() {
return Object.keys(this.$schema.procedures ?? {}).reduce((acc, name) => {
acc[name] = (...args: unknown[]) => this.handleProc(name, args);
Expand Down Expand Up @@ -229,29 +292,29 @@ export class ClientImpl<Schema extends SchemaDef> {
}

$executeRaw(query: TemplateStringsArray, ...values: any[]) {
return createDeferredPromise(async () => {
return createZenStackPromise(async () => {
const result = await sql(query, ...values).execute(this.kysely);
return Number(result.numAffectedRows ?? 0);
});
}

$executeRawUnsafe(query: string, ...values: any[]) {
return createDeferredPromise(async () => {
return createZenStackPromise(async () => {
const compiledQuery = this.createRawCompiledQuery(query, values);
const result = await this.kysely.executeQuery(compiledQuery);
return Number(result.numAffectedRows ?? 0);
});
}

$queryRaw<T = unknown>(query: TemplateStringsArray, ...values: any[]) {
return createDeferredPromise(async () => {
return createZenStackPromise(async () => {
const result = await sql(query, ...values).execute(this.kysely);
return result.rows as T;
});
}

$queryRawUnsafe<T = unknown>(query: string, ...values: any[]) {
return createDeferredPromise(async () => {
return createZenStackPromise(async () => {
const compiledQuery = this.createRawCompiledQuery(query, values);
const result = await this.kysely.executeQuery(compiledQuery);
return result.rows as T;
Expand All @@ -278,7 +341,7 @@ function createClientProxy<Schema extends SchemaDef>(client: ClientImpl<Schema>)
const model = Object.keys(client.$schema.models).find((m) => m.toLowerCase() === prop.toLowerCase());
if (model) {
return createModelCrudHandler(
client as ClientContract<Schema>,
client as unknown as ClientContract<Schema>,
model as GetModels<Schema>,
inputValidator,
resultProcessor,
Expand All @@ -304,9 +367,9 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
postProcess = false,
throwIfNoResult = false,
) => {
return createDeferredPromise(async () => {
let proceed = async (_args?: unknown, tx?: ClientContract<Schema>) => {
const _handler = tx ? handler.withClient(tx) : handler;
return createZenStackPromise(async (txClient?: ClientContract<Schema>) => {
let proceed = async (_args?: unknown) => {
const _handler = txClient ? handler.withClient(txClient) : handler;
const r = await _handler.handle(operation, _args ?? args);
if (!r && throwIfNoResult) {
throw new NotFoundError(model);
Expand Down
5 changes: 5 additions & 0 deletions packages/runtime/src/client/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ export const CONTEXT_COMMENT_PREFIX = '-- $$context:';
* The types of fields that are numeric.
*/
export const NUMERIC_FIELD_TYPES = ['Int', 'Float', 'BigInt', 'Decimal'];

/**
* Client API methods that are not supported in transactions.
*/
export const TRANSACTION_UNSUPPORTED_METHODS = ['$transaction', '$disconnect', '$use'] as const;
Loading