Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@
- [x] Count
- [x] Aggregate
- [x] Group by
- [ ] Raw queries
- [x] Raw queries
- [ ] Transactions
- [x] Interactive transaction
- [ ] Batch transaction
- [ ] Extensions
- [x] Query builder API
- [x] Computed fields
Expand All @@ -69,6 +72,8 @@
- [x] Custom field name
- [ ] Strict undefined checks
- [ ] Benchmark
- [ ] Plugin
- [ ] Post-mutation hooks should be called after transaction is committed
- [ ] Polymorphism
- [ ] Validation
- [ ] Access Policy
Expand Down
50 changes: 33 additions & 17 deletions packages/runtime/src/client/client-impl.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { lowerCaseFirst } from '@zenstackhq/common-helpers';
import type { SqliteDialectConfig } from 'kysely';
import type { QueryExecutor, SqliteDialectConfig } from 'kysely';
import {
CompiledQuery,
DefaultConnectionProvider,
Expand Down Expand Up @@ -60,6 +60,7 @@ export class ClientImpl<Schema extends SchemaDef> {
private readonly schema: Schema,
private options: ClientOptions<Schema>,
baseClient?: ClientImpl<Schema>,
executor?: QueryExecutor,
) {
this.$schema = schema;
this.$options = options ?? ({} as ClientOptions<Schema>);
Expand All @@ -73,22 +74,24 @@ export class ClientImpl<Schema extends SchemaDef> {
if (baseClient) {
this.kyselyProps = {
...baseClient.kyselyProps,
executor: new ZenStackQueryExecutor(
this,
baseClient.kyselyProps.driver as ZenStackDriver,
baseClient.kyselyProps.dialect.createQueryCompiler(),
baseClient.kyselyProps.dialect.createAdapter(),
new DefaultConnectionProvider(baseClient.kyselyProps.driver),
),
executor:
executor ??
new ZenStackQueryExecutor(
this,
baseClient.kyselyProps.driver as ZenStackDriver,
baseClient.kyselyProps.dialect.createQueryCompiler(),
baseClient.kyselyProps.dialect.createAdapter(),
new DefaultConnectionProvider(baseClient.kyselyProps.driver),
),
};
this.kyselyRaw = baseClient.kyselyRaw;
this.auth = baseClient.auth;
} else {
const dialect = this.getKyselyDialect();
const driver = new ZenStackDriver(dialect.createDriver(), new Log(this.$options.log ?? []));
const compiler = dialect.createQueryCompiler();
const adapter = dialect.createAdapter();
const connectionProvider = new DefaultConnectionProvider(driver);
const executor = new ZenStackQueryExecutor(this, driver, compiler, adapter, connectionProvider);

this.kyselyProps = {
config: {
Expand All @@ -97,7 +100,7 @@ export class ClientImpl<Schema extends SchemaDef> {
},
dialect,
driver,
executor,
executor: executor ?? new ZenStackQueryExecutor(this, driver, compiler, adapter, connectionProvider),
};

// raw kysely instance with default executor
Expand All @@ -112,14 +115,21 @@ export class ClientImpl<Schema extends SchemaDef> {
return createClientProxy(this);
}

public get $qb() {
get $qb() {
return this.kysely;
}

public get $qbRaw() {
get $qbRaw() {
return this.kyselyRaw;
}

/**
* Create a new client with a new query executor.
*/
withExecutor(executor: QueryExecutor) {
return new ClientImpl(this.schema, this.$options, this, executor);
}

private getKyselyDialect() {
return match(this.schema.provider.type)
.with('sqlite', () => this.makeSqliteKyselyDialect())
Expand All @@ -136,11 +146,17 @@ export class ClientImpl<Schema extends SchemaDef> {
}

async $transaction<T>(callback: (tx: ClientContract<Schema>) => Promise<T>): Promise<T> {
return this.kysely.transaction().execute((tx) => {
const txClient = new ClientImpl<Schema>(this.schema, this.$options);
txClient.kysely = tx;
return callback(txClient as unknown as ClientContract<Schema>);
});
if (this.kysely.isTransaction) {
// proceed directly if already in a transaction
return callback(this as unknown as ClientContract<Schema>);
} else {
// otherwise, create a new transaction, clone the client, and execute the callback
return this.kysely.transaction().execute((tx) => {
const txClient = new ClientImpl<Schema>(this.schema, this.$options);
txClient.kysely = tx;
return callback(txClient as unknown as ClientContract<Schema>);
});
}
}

get $procedures() {
Expand Down
40 changes: 25 additions & 15 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import {
ExpressionWrapper,
sql,
UpdateResult,
type IsolationLevel,
type Expression as KyselyExpression,
type SelectQueryBuilder,
} from 'kysely';
import { nanoid } from 'nanoid';
import { inspect } from 'node:util';
import { match } from 'ts-pattern';
import { ulid } from 'ulid';
import * as uuid from 'uuid';
Expand Down Expand Up @@ -203,7 +205,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
result = await query.execute();
} catch (err) {
const { sql, parameters } = query.compile();
throw new QueryError(`Failed to execute query: ${err}, sql: ${sql}, parameters: ${parameters}`);
let message = `Failed to execute query: ${err}, sql: ${sql}`;
if (this.options.debug) {
message += `, parameters: \n${parameters.map((p) => inspect(p)).join('\n')}`;
}
throw new QueryError(message, err);
}

if (inMemoryDistinct) {
Expand Down Expand Up @@ -1181,18 +1187,13 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {

query = query.modifyEnd(this.makeContextComment({ model, operation: 'update' }));

try {
if (!returnData) {
const result = await query.executeTakeFirstOrThrow();
return { count: Number(result.numUpdatedRows) } as Result;
} else {
const idFields = getIdFields(this.schema, model);
const result = await query.returning(idFields as any).execute();
return result as Result;
}
} catch (err) {
const { sql, parameters } = query.compile();
throw new QueryError(`Error during updateMany: ${err}, sql: ${sql}, parameters: ${parameters}`);
if (!returnData) {
const result = await query.executeTakeFirstOrThrow();
return { count: Number(result.numUpdatedRows) } as Result;
} else {
const idFields = getIdFields(this.schema, model);
const result = await query.returning(idFields as any).execute();
return result as Result;
}
}

Expand Down Expand Up @@ -1900,11 +1901,20 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return returnRelation;
}

protected async safeTransaction<T>(callback: (tx: ToKysely<Schema>) => Promise<T>) {
protected async safeTransaction<T>(
callback: (tx: ToKysely<Schema>) => Promise<T>,
isolationLevel?: IsolationLevel,
) {
if (this.kysely.isTransaction) {
// proceed directly if already in a transaction
return callback(this.kysely);
} else {
return this.kysely.transaction().setIsolationLevel('repeatable read').execute(callback);
// otherwise, create a new transaction and execute the callback
let txBuilder = this.kysely.transaction();
if (isolationLevel) {
txBuilder = txBuilder.setIsolationLevel(isolationLevel);
}
return txBuilder.execute(callback);
}
}

Expand Down
8 changes: 4 additions & 4 deletions packages/runtime/src/client/crud/validator.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { invariant } from '@zenstackhq/common-helpers';
import Decimal from 'decimal.js';
import stableStringify from 'json-stable-stringify';
import { match, P } from 'ts-pattern';
Expand All @@ -19,9 +20,8 @@ import {
type UpdateManyArgs,
type UpsertArgs,
} from '../crud-types';
import { InternalError, QueryError } from '../errors';
import { InputValidationError, InternalError, QueryError } from '../errors';
import { fieldHasDefaultValue, getEnum, getModel, getUniqueFields, requireField, requireModel } from '../query-utils';
import { invariant } from '@zenstackhq/common-helpers';

type GetSchemaFunc<Schema extends SchemaDef, Options> = (model: GetModels<Schema>, options: Options) => ZodType;

Expand Down Expand Up @@ -179,7 +179,7 @@ export class InputValidator<Schema extends SchemaDef> {
}
const { error } = schema.safeParse(args);
if (error) {
throw new QueryError(`Invalid ${operation} args: ${error.message}`);
throw new InputValidationError(`Invalid ${operation} args: ${error.message}`, error);
}
return args as T;
}
Expand Down Expand Up @@ -233,7 +233,7 @@ export class InputValidator<Schema extends SchemaDef> {
private makeWhereSchema(model: string, unique: boolean, withoutRelationFields = false): ZodType {
const modelDef = getModel(this.schema, model);
if (!modelDef) {
throw new QueryError(`Model "${model}" not found`);
throw new QueryError(`Model "${model}" not found in schema`);
}

const fields: Record<string, any> = {};
Expand Down
22 changes: 20 additions & 2 deletions packages/runtime/src/client/errors.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
/**
* Error thrown when input validation fails.
*/
export class InputValidationError extends Error {
constructor(message: string, cause?: unknown) {
super(message, { cause });
}
}

/**
* Error thrown when a query fails.
*/
export class QueryError extends Error {
constructor(message: string) {
super(message);
constructor(message: string, cause?: unknown) {
super(message, { cause });
}
}

/**
* Error thrown when an internal error occurs.
*/
export class InternalError extends Error {
constructor(message: string) {
super(message);
}
}

/**
* Error thrown when an entity is not found.
*/
export class NotFoundError extends Error {
constructor(model: string) {
super(`Entity not found for model "${model}"`);
Expand Down
Loading