Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
- [x] Raw queries
- [ ] Transactions
- [x] Interactive transaction
- [ ] Batch transaction
- [x] Sequential transaction
- [ ] Extensions
- [x] Query builder API
- [x] Computed fields
Expand Down
1 change: 1 addition & 0 deletions packages/cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"data modeling"
],
"bin": {
"zen": "bin/cli",
"zenstack": "bin/cli"
},
"scripts": {
Expand Down
83 changes: 69 additions & 14 deletions packages/runtime/src/client/client-impl.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { lowerCaseFirst } from '@zenstackhq/common-helpers';
import { invariant, lowerCaseFirst } from '@zenstackhq/common-helpers';
import type { QueryExecutor, SqliteDialectConfig } from 'kysely';
import {
CompiledQuery,
Expand All @@ -15,7 +15,7 @@ import {
import { match } from 'ts-pattern';
import type { GetModels, ProcedureDef, SchemaDef } from '../schema';
import type { AuthType } from '../schema/auth';
import type { ClientConstructor, ClientContract, ModelOperations } from './contract';
import type { ClientConstructor, ClientContract, ModelOperations, TransactionIsolationLevel } from './contract';
import { AggregateOperationHandler } from './crud/operations/aggregate';
import type { CrudOperation } from './crud/operations/base';
import { BaseOperationHandler } from './crud/operations/base';
Expand All @@ -33,7 +33,7 @@ import * as BuiltinFunctions from './functions';
import { SchemaDbPusher } from './helpers/schema-db-pusher';
import type { ClientOptions, ProceduresOptions } from './options';
import type { RuntimePlugin } from './plugin';
import { createDeferredPromise } from './promise';
import { createZenStackPromise, type ZenStackPromise } from './promise';
import type { ToKysely } from './query-builder';
import { ResultProcessor } from './result-processor';

Expand Down Expand Up @@ -145,20 +145,75 @@ export class ClientImpl<Schema extends SchemaDef> {
return new SqliteDialect(this.options.dialectConfig as SqliteDialectConfig);
}

async $transaction<T>(callback: (tx: ClientContract<Schema>) => Promise<T>): Promise<T> {
// overload for interactive transaction
$transaction<T>(
callback: (tx: ClientContract<Schema>) => Promise<T>,
options?: { isolationLevel?: TransactionIsolationLevel },
): Promise<T>;

// overload for sequential transaction
$transaction<P extends Promise<any>[]>(arg: [...P], options?: { isolationLevel?: TransactionIsolationLevel }): P;

// implementation
async $transaction(input: any, options?: { isolationLevel?: TransactionIsolationLevel }) {
invariant(
typeof input === 'function' || (Array.isArray(input) && input.every((p) => p.then)),
'Invalid transaction input, expected a function or an array of ZenStackClient promises',
);
if (typeof input === 'function') {
return this.interactiveTransaction(input, options);
} else {
return this.sequentialTransaction(input, options);
}
}

private async interactiveTransaction(
callback: (tx: ClientContract<Schema>) => Promise<any>,
options?: { isolationLevel?: TransactionIsolationLevel },
): Promise<any> {
if (this.kysely.isTransaction) {
// proceed directly if already in a transaction
return callback(this as unknown as ClientContract<Schema>);
} else {
// otherwise, create a new transaction, clone the client, and execute the callback
return this.kysely.transaction().execute((tx) => {
const txClient = new ClientImpl<Schema>(this.schema, this.$options);
let txBuilder = this.kysely.transaction();
if (options?.isolationLevel) {
txBuilder = txBuilder.setIsolationLevel(options.isolationLevel);
}
return txBuilder.execute((tx) => {
const txClient = new ClientImpl<Schema>(this.schema, this.$options, this);
txClient.kysely = tx;
return callback(txClient as unknown as ClientContract<Schema>);
});
}
}

private async sequentialTransaction(
arg: ZenStackPromise<Schema, any>[],
options?: { isolationLevel?: TransactionIsolationLevel },
) {
const execute = async (tx: Kysely<any>) => {
const txClient = new ClientImpl<Schema>(this.schema, this.$options, this);
txClient.kysely = tx;
const result: any[] = [];
for (const promise of arg) {
result.push(await promise.cb(txClient as unknown as ClientContract<Schema>));
}
return result;
};
if (this.kysely.isTransaction) {
// proceed directly if already in a transaction
return execute(this.kysely);
} else {
// otherwise, create a new transaction, clone the client, and execute the callback
let txBuilder = this.kysely.transaction();
if (options?.isolationLevel) {
txBuilder = txBuilder.setIsolationLevel(options.isolationLevel);
}
return txBuilder.execute((tx) => execute(tx as Kysely<any>));
}
}

get $procedures() {
return Object.keys(this.$schema.procedures ?? {}).reduce((acc, name) => {
acc[name] = (...args: unknown[]) => this.handleProc(name, args);
Expand Down Expand Up @@ -229,29 +284,29 @@ export class ClientImpl<Schema extends SchemaDef> {
}

$executeRaw(query: TemplateStringsArray, ...values: any[]) {
return createDeferredPromise(async () => {
return createZenStackPromise(async () => {
const result = await sql(query, ...values).execute(this.kysely);
return Number(result.numAffectedRows ?? 0);
});
}

$executeRawUnsafe(query: string, ...values: any[]) {
return createDeferredPromise(async () => {
return createZenStackPromise(async () => {
const compiledQuery = this.createRawCompiledQuery(query, values);
const result = await this.kysely.executeQuery(compiledQuery);
return Number(result.numAffectedRows ?? 0);
});
}

$queryRaw<T = unknown>(query: TemplateStringsArray, ...values: any[]) {
return createDeferredPromise(async () => {
return createZenStackPromise(async () => {
const result = await sql(query, ...values).execute(this.kysely);
return result.rows as T;
});
}

$queryRawUnsafe<T = unknown>(query: string, ...values: any[]) {
return createDeferredPromise(async () => {
return createZenStackPromise(async () => {
const compiledQuery = this.createRawCompiledQuery(query, values);
const result = await this.kysely.executeQuery(compiledQuery);
return result.rows as T;
Expand All @@ -278,7 +333,7 @@ function createClientProxy<Schema extends SchemaDef>(client: ClientImpl<Schema>)
const model = Object.keys(client.$schema.models).find((m) => m.toLowerCase() === prop.toLowerCase());
if (model) {
return createModelCrudHandler(
client as ClientContract<Schema>,
client as unknown as ClientContract<Schema>,
model as GetModels<Schema>,
inputValidator,
resultProcessor,
Expand All @@ -304,9 +359,9 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
postProcess = false,
throwIfNoResult = false,
) => {
return createDeferredPromise(async () => {
let proceed = async (_args?: unknown, tx?: ClientContract<Schema>) => {
const _handler = tx ? handler.withClient(tx) : handler;
return createZenStackPromise(async (txClient?: ClientContract<Schema>) => {
let proceed = async (_args?: unknown) => {
const _handler = txClient ? handler.withClient(txClient) : handler;
const r = await _handler.handle(operation, _args ?? args);
if (!r && throwIfNoResult) {
throw new NotFoundError(model);
Expand Down
5 changes: 5 additions & 0 deletions packages/runtime/src/client/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ export const CONTEXT_COMMENT_PREFIX = '-- $$context:';
* The types of fields that are numeric.
*/
export const NUMERIC_FIELD_TYPES = ['Int', 'Float', 'BigInt', 'Decimal'];

/**
* Client API methods that are not supported in transactions.
*/
export const TRANSACTION_UNSUPPORTED_METHODS = ['$transaction', '$disconnect', '$use'] as const;
31 changes: 28 additions & 3 deletions packages/runtime/src/client/contract.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { Decimal } from 'decimal.js';
import { type GetModels, type ProcedureDef, type SchemaDef } from '../schema';
import type { AuthType } from '../schema/auth';
import type { OrUndefinedIf } from '../utils/type-utils';
import type { OrUndefinedIf, UnwrapTuplePromises } from '../utils/type-utils';
import type {
AggregateArgs,
AggregateResult,
Expand All @@ -28,6 +28,20 @@ import type {
import type { ClientOptions } from './options';
import type { RuntimePlugin } from './plugin';
import type { ToKysely } from './query-builder';
import type { TRANSACTION_UNSUPPORTED_METHODS } from './constants';

type TransactionUnsupportedMethods = (typeof TRANSACTION_UNSUPPORTED_METHODS)[number];

/**
* Transaction isolation levels.
*/
export enum TransactionIsolationLevel {
ReadUncommitted = 'read uncommitted',
ReadCommitted = 'read committed',
RepeatableRead = 'repeatable read',
Serializable = 'serializable',
Snapshot = 'snapshot',
}

/**
* ZenStack client interface.
Expand Down Expand Up @@ -99,9 +113,20 @@ export type ClientContract<Schema extends SchemaDef> = {
readonly $qbRaw: ToKysely<any>;

/**
* Starts a transaction.
* Starts an interactive transaction.
*/
$transaction<T>(
callback: (tx: Omit<ClientContract<Schema>, TransactionUnsupportedMethods>) => Promise<T>,
options?: { isolationLevel?: TransactionIsolationLevel },
): Promise<T>;

/**
* Starts a sequential transaction.
*/
$transaction<T>(callback: (tx: ClientContract<Schema>) => Promise<T>): Promise<T>;
$transaction<P extends Promise<any>[]>(
arg: [...P],
options?: { isolationLevel?: TransactionIsolationLevel },
): Promise<UnwrapTuplePromises<P>>;

/**
* Returns a new client with the specified plugin installed.
Expand Down
23 changes: 20 additions & 3 deletions packages/runtime/src/client/promise.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import type { SchemaDef } from '../schema';
import type { ClientContract } from './contract';

/**
* A promise that only executes when it's awaited or .then() is called.
*/
export type ZenStackPromise<Schema extends SchemaDef, T> = Promise<T> & {
/**
* @private
* Callable to get a plain promise.
*/
cb: (txClient?: ClientContract<Schema>) => Promise<T>;
};

/**
* Creates a promise that only executes when it's awaited or .then() is called.
* @see https://github.com/prisma/prisma/blob/main/packages/client/src/runtime/core/request/createPrismaPromise.ts
*/
export function createDeferredPromise<T>(callback: () => Promise<T>): Promise<T> {
export function createZenStackPromise<Schema extends SchemaDef, T>(
callback: (txClient?: ClientContract<Schema>) => Promise<T>,
): ZenStackPromise<Schema, T> {
let promise: Promise<T> | undefined;
const cb = () => {
const cb = (txClient?: ClientContract<Schema>) => {
try {
return (promise ??= valueToPromise(callback()));
return (promise ??= valueToPromise(callback(txClient)));
} catch (err) {
// deal with synchronous errors
return Promise.reject<T>(err);
Expand All @@ -23,6 +39,7 @@ export function createDeferredPromise<T>(callback: () => Promise<T>): Promise<T>
finally(onFinally) {
return cb().finally(onFinally);
},
cb,
[Symbol.toStringTag]: 'ZenStackPromise',
};
}
Expand Down
4 changes: 4 additions & 0 deletions packages/runtime/src/utils/type-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,7 @@ export type PrependParameter<Param, Func> = Func extends (...args: any[]) => inf
: never;

export type OrUndefinedIf<T, Condition extends boolean> = Condition extends true ? T | undefined : T;

export type UnwrapTuplePromises<T extends readonly unknown[]> = {
[K in keyof T]: Awaited<T[K]>;
};
Loading
Loading