Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .coderabbit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: 'en-US'
early_access: false
reviews:
auto_review:
enabled: true
chat:
auto_reply: true
32 changes: 32 additions & 0 deletions packages/runtime/src/client/client-impl.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import { lowerCaseFirst } from '@zenstackhq/common-helpers';
import type { SqliteDialectConfig } from 'kysely';
import {
CompiledQuery,
DefaultConnectionProvider,
DefaultQueryExecutor,
Kysely,
Log,
PostgresDialect,
sql,
SqliteDialect,
type KyselyProps,
type PostgresDialectConfig,
Expand Down Expand Up @@ -209,6 +211,36 @@ export class ClientImpl<Schema extends SchemaDef> {
get $auth() {
return this.auth;
}

$executeRaw(query: TemplateStringsArray, ...values: any[]) {
return createDeferredPromise(async () => {
const result = await sql(query, ...values).execute(this.kysely);
return Number(result.numAffectedRows ?? 0);
});
}

$executeRawUnsafe(query: string, ...values: any[]) {
return createDeferredPromise(async () => {
const compiledQuery = CompiledQuery.raw(query, values);
const result = await this.kysely.executeQuery(compiledQuery);
return Number(result.numAffectedRows ?? 0);
});
}

$queryRaw<T = unknown>(query: TemplateStringsArray, ...values: any[]) {
return createDeferredPromise(async () => {
const result = await sql(query, ...values).execute(this.kysely);
return result.rows as T;
});
}

$queryRawUnsafe<T = unknown>(query: string, ...values: any[]) {
return createDeferredPromise(async () => {
const compiledQuery = CompiledQuery.raw(query, values);
const result = await this.kysely.executeQuery(compiledQuery);
return result.rows as T;
});
}
}

function createClientProxy<Schema extends SchemaDef>(client: ClientImpl<Schema>): ClientImpl<Schema> {
Expand Down
38 changes: 38 additions & 0 deletions packages/runtime/src/client/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,44 @@ export type ClientContract<Schema extends SchemaDef> = {
*/
readonly $options: ClientOptions<Schema>;

/**
* Executes a prepared raw query and returns the number of affected rows.
* @example
* ```
* const result = await client.$executeRaw`UPDATE User SET cool = ${true} WHERE email = ${'[email protected]'};`
* ```
*/
$executeRaw(query: TemplateStringsArray, ...values: any[]): Promise<number>;

/**
* Executes a raw query and returns the number of affected rows.
* This method is susceptible to SQL injections.
* @example
* ```
* const result = await client.$executeRawUnsafe('UPDATE User SET cool = $1 WHERE email = $2 ;', true, '[email protected]')
* ```
*/
$executeRawUnsafe(query: string, ...values: any[]): Promise<number>;

/**
* Performs a prepared raw query and returns the `SELECT` data.
* @example
* ```
* const result = await client.$queryRaw`SELECT * FROM User WHERE id = ${1} OR email = ${'[email protected]'};`
* ```
*/
$queryRaw<T = unknown>(query: TemplateStringsArray, ...values: any[]): Promise<T>;

/**
* Performs a raw query and returns the `SELECT` data.
* This method is susceptible to SQL injections.
* @example
* ```
* const result = await client.$queryRawUnsafe('SELECT * FROM User WHERE id = $1 OR email = $2;', 1, '[email protected]')
* ```
*/
$queryRawUnsafe<T = unknown>(query: string, ...values: any[]): Promise<T>;

/**
* The current user identity.
*/
Expand Down
14 changes: 9 additions & 5 deletions packages/runtime/src/client/executor/zenstack-query-executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
}

// proceed with the query with kysely interceptors
const result = await this.proceedQueryWithKyselyInterceptors(queryNode, queryId);
const result = await this.proceedQueryWithKyselyInterceptors(queryNode, compiledQuery.parameters, queryId);

// call after mutation hooks
await this.callAfterQueryInterceptionFilters(result, queryNode, mutationInterceptionInfo);
Expand All @@ -96,8 +96,12 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
return this.executeWithTransaction(task, !!mutationInterceptionInfo?.useTransactionForMutation);
}

private proceedQueryWithKyselyInterceptors(queryNode: RootOperationNode, queryId: QueryId) {
let proceed = (q: RootOperationNode) => this.proceedQuery(q, queryId);
private proceedQueryWithKyselyInterceptors(
queryNode: RootOperationNode,
parameters: readonly unknown[],
queryId: QueryId,
) {
let proceed = (q: RootOperationNode) => this.proceedQuery(q, parameters, queryId);

const makeTx = (p: typeof proceed) => (callback: OnKyselyQueryTransactionCallback) => {
return this.executeWithTransaction(() => callback(p));
Expand Down Expand Up @@ -125,10 +129,10 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
return proceed(queryNode);
}

private async proceedQuery(query: RootOperationNode, queryId: QueryId) {
private async proceedQuery(query: RootOperationNode, parameters: readonly unknown[], queryId: QueryId) {
// run built-in transformers
const finalQuery = this.nameMapper.transformNode(query);
const compiled = this.compileQuery(finalQuery);
const compiled: CompiledQuery = { ...this.compileQuery(finalQuery), parameters };
try {
return this.driver.txConnection
? await super
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/test/client-api/client-specs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { getSchema, schema } from '../test-schema';
import { makePostgresClient, makeSqliteClient } from '../utils';
import type { ClientContract } from '../../src';

export function createClientSpecs(dbName: string, logQueries = false, providers = ['sqlite', 'postgresql'] as const) {
export function createClientSpecs(dbName: string, logQueries = false, providers: string[] = ['sqlite', 'postgresql']) {
const logger = (provider: string) => (event: LogEvent) => {
if (event.level === 'query') {
console.log(`query(${provider}):`, event.query.sql, event.query.parameters);
Expand Down
79 changes: 79 additions & 0 deletions packages/runtime/test/client-api/raw-query.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import type { ClientContract } from '../../src/client';
import { schema } from '../test-schema';
import { createClientSpecs } from './client-specs';

const PG_DB_NAME = 'client-api-raw-query-tests';

describe.each(createClientSpecs(PG_DB_NAME, true))('Client raw query tests', ({ createClient, provider }) => {
let client: ClientContract<typeof schema>;

beforeEach(async () => {
client = await createClient();
});

afterEach(async () => {
await client?.$disconnect();
});

it('works with executeRaw', async () => {
await client.user.create({
data: {
id: '1',
email: '[email protected]',
},
});

await expect(
client.$executeRaw`UPDATE "User" SET "email" = ${'[email protected]'} WHERE "id" = ${'1'}`,
).resolves.toBe(1);
await expect(client.user.findFirst()).resolves.toMatchObject({ email: '[email protected]' });
});

it('works with executeRawUnsafe', async () => {
await client.user.create({
data: {
id: '1',
email: '[email protected]',
},
});

const sql =
provider === 'postgresql'
? `UPDATE "User" SET "email" = $1 WHERE "id" = $2`
: `UPDATE "User" SET "email" = ? WHERE "id" = ?`;
await expect(client.$executeRawUnsafe(sql, '[email protected]', '1')).resolves.toBe(1);
await expect(client.user.findFirst()).resolves.toMatchObject({ email: '[email protected]' });
});

it('works with queryRaw', async () => {
await client.user.create({
data: {
id: '1',
email: '[email protected]',
},
});

const uid = '1';
const users = await client.$queryRaw<
{ id: string; email: string }[]
>`SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = ${uid}`;
expect(users).toEqual([{ id: '1', email: '[email protected]' }]);
});

it('works with queryRawUnsafe', async () => {
await client.user.create({
data: {
id: '1',
email: '[email protected]',
},
});

const sql =
provider === 'postgresql'
? `SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = $1`
: `SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = ?`;
const users = await client.$queryRawUnsafe<{ id: string; email: string }[]>(sql, '1');
expect(users).toEqual([{ id: '1', email: '[email protected]' }]);
});
});
2 changes: 1 addition & 1 deletion packages/testtools/src/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export async function generateTsSchema(
extraSourceFiles?: Record<string, string>,
) {
const workDir = createTestProject();
console.log(`Working directory: ${workDir}`);
console.log(`Work directory: ${workDir}`);

const zmodelPath = path.join(workDir, 'schema.zmodel');
const noPrelude = schemaText.includes('datasource ');
Expand Down
1 change: 1 addition & 0 deletions turbo.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"tasks": {
"build": {
"dependsOn": ["^build"],
"inputs": ["src/**"],
"outputs": ["dist/**"]
},
"lint": {
Expand Down
Loading