Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions packages/runtime/src/client/client-impl.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import { lowerCaseFirst } from '@zenstackhq/common-helpers';
import type { SqliteDialectConfig } from 'kysely';
import {
CompiledQuery,
DefaultConnectionProvider,
DefaultQueryExecutor,
Kysely,
Log,
PostgresDialect,
sql,
SqliteDialect,
type KyselyProps,
type PostgresDialectConfig,
Expand Down Expand Up @@ -209,6 +211,36 @@ export class ClientImpl<Schema extends SchemaDef> {
get $auth() {
return this.auth;
}

$executeRaw(query: TemplateStringsArray, ...values: any[]) {
return createDeferredPromise(async () => {
const result = await sql(query, ...values).execute(this.kysely);
return Number(result.numAffectedRows ?? 0);
});
}

$executeRawUnsafe(query: string, ...values: any[]) {
return createDeferredPromise(async () => {
const compiledQuery = CompiledQuery.raw(query, values);
const result = await this.kysely.executeQuery(compiledQuery);
return Number(result.numAffectedRows ?? 0);
});
}

$queryRaw<T = unknown>(query: TemplateStringsArray, ...values: any[]) {
return createDeferredPromise(async () => {
const result = await sql(query, ...values).execute(this.kysely);
return result.rows as T;
});
}

$queryRawUnsafe<T = unknown>(query: string, ...values: any[]) {
return createDeferredPromise(async () => {
const compiledQuery = CompiledQuery.raw(query, values);
const result = await this.kysely.executeQuery(compiledQuery);
return result.rows as T;
});
}
}

function createClientProxy<Schema extends SchemaDef>(client: ClientImpl<Schema>): ClientImpl<Schema> {
Expand Down
38 changes: 38 additions & 0 deletions packages/runtime/src/client/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,44 @@ export type ClientContract<Schema extends SchemaDef> = {
*/
readonly $options: ClientOptions<Schema>;

/**
* Executes a prepared raw query and returns the number of affected rows.
* @example
* ```
* const result = await client.$executeRaw`UPDATE User SET cool = ${true} WHERE email = ${'[email protected]'};`
* ```
*/
$executeRaw(query: TemplateStringsArray, ...values: any[]): Promise<number>;

/**
* Executes a raw query and returns the number of affected rows.
* This method is susceptible to SQL injections.
* @example
* ```
* const result = await client.$executeRawUnsafe('UPDATE User SET cool = $1 WHERE email = $2 ;', true, '[email protected]')
* ```
*/
$executeRawUnsafe(query: string, ...values: any[]): Promise<number>;

/**
* Performs a prepared raw query and returns the `SELECT` data.
* @example
* ```
* const result = await client.$queryRaw`SELECT * FROM User WHERE id = ${1} OR email = ${'[email protected]'};`
* ```
*/
$queryRaw<T = unknown>(query: TemplateStringsArray, ...values: any[]): Promise<T>;

/**
* Performs a raw query and returns the `SELECT` data.
* This method is susceptible to SQL injections.
* @example
* ```
* const result = await client.$queryRawUnsafe('SELECT * FROM User WHERE id = $1 OR email = $2;', 1, '[email protected]')
* ```
*/
$queryRawUnsafe<T = unknown>(query: string, ...values: any[]): Promise<T>;

/**
* The current user identity.
*/
Expand Down
14 changes: 9 additions & 5 deletions packages/runtime/src/client/executor/zenstack-query-executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
}

// proceed with the query with kysely interceptors
const result = await this.proceedQueryWithKyselyInterceptors(queryNode, queryId);
const result = await this.proceedQueryWithKyselyInterceptors(queryNode, compiledQuery.parameters, queryId);

// call after mutation hooks
await this.callAfterQueryInterceptionFilters(result, queryNode, mutationInterceptionInfo);
Expand All @@ -96,8 +96,12 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
return this.executeWithTransaction(task, !!mutationInterceptionInfo?.useTransactionForMutation);
}

private proceedQueryWithKyselyInterceptors(queryNode: RootOperationNode, queryId: QueryId) {
let proceed = (q: RootOperationNode) => this.proceedQuery(q, queryId);
private proceedQueryWithKyselyInterceptors(
queryNode: RootOperationNode,
parameters: readonly unknown[],
queryId: QueryId,
) {
let proceed = (q: RootOperationNode) => this.proceedQuery(q, parameters, queryId);

const makeTx = (p: typeof proceed) => (callback: OnKyselyQueryTransactionCallback) => {
return this.executeWithTransaction(() => callback(p));
Expand Down Expand Up @@ -125,10 +129,10 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
return proceed(queryNode);
}

private async proceedQuery(query: RootOperationNode, queryId: QueryId) {
private async proceedQuery(query: RootOperationNode, parameters: readonly unknown[], queryId: QueryId) {
// run built-in transformers
const finalQuery = this.nameMapper.transformNode(query);
const compiled = this.compileQuery(finalQuery);
const compiled: CompiledQuery = { ...this.compileQuery(finalQuery), parameters };
try {
return this.driver.txConnection
? await super
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/test/client-api/client-specs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { getSchema, schema } from '../test-schema';
import { makePostgresClient, makeSqliteClient } from '../utils';
import type { ClientContract } from '../../src';

export function createClientSpecs(dbName: string, logQueries = false, providers = ['sqlite', 'postgresql'] as const) {
export function createClientSpecs(dbName: string, logQueries = false, providers: string[] = ['sqlite', 'postgresql']) {
const logger = (provider: string) => (event: LogEvent) => {
if (event.level === 'query') {
console.log(`query(${provider}):`, event.query.sql, event.query.parameters);
Expand Down
79 changes: 79 additions & 0 deletions packages/runtime/test/client-api/raw-query.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import type { ClientContract } from '../../src/client';
import { schema } from '../test-schema';
import { createClientSpecs } from './client-specs';

const PG_DB_NAME = 'client-api-raw-query-tests';

describe.each(createClientSpecs(PG_DB_NAME, true))('Client raw query tests', ({ createClient, provider }) => {
let client: ClientContract<typeof schema>;

beforeEach(async () => {
client = await createClient();
});

afterEach(async () => {
await client?.$disconnect();
});

it('works with executeRaw', async () => {
await client.user.create({
data: {
id: '1',
email: '[email protected]',
},
});

await expect(
client.$executeRaw`UPDATE "User" SET "email" = ${'[email protected]'} WHERE "id" = ${'1'}`,
).resolves.toBe(1);
await expect(client.user.findFirst()).resolves.toMatchObject({ email: '[email protected]' });
});

it('works with executeRawUnsafe', async () => {
await client.user.create({
data: {
id: '1',
email: '[email protected]',
},
});

const sql =
provider === 'postgresql'
? `UPDATE "User" SET "email" = $1 WHERE "id" = $2`
: `UPDATE "User" SET "email" = ? WHERE "id" = ?`;
await expect(client.$executeRawUnsafe(sql, '[email protected]', '1')).resolves.toBe(1);
await expect(client.user.findFirst()).resolves.toMatchObject({ email: '[email protected]' });
});

it('works with queryRaw', async () => {
await client.user.create({
data: {
id: '1',
email: '[email protected]',
},
});

const uid = '1';
const users = await client.$queryRaw<
{ id: string; email: string }[]
>`SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = ${uid}`;
expect(users).toEqual([{ id: '1', email: '[email protected]' }]);
});

it('works with queryRawUnsafe', async () => {
await client.user.create({
data: {
id: '1',
email: '[email protected]',
},
});

const sql =
provider === 'postgresql'
? `SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = $1`
: `SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = ?`;
const users = await client.$queryRawUnsafe<{ id: string; email: string }[]>(sql, '1');
expect(users).toEqual([{ id: '1', email: '[email protected]' }]);
});
});
2 changes: 1 addition & 1 deletion packages/testtools/src/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export async function generateTsSchema(
extraSourceFiles?: Record<string, string>,
) {
const workDir = createTestProject();
console.log(`Working directory: ${workDir}`);
console.log(`Work directory: ${workDir}`);

const zmodelPath = path.join(workDir, 'schema.zmodel');
const noPrelude = schemaText.includes('datasource ');
Expand Down
1 change: 1 addition & 0 deletions turbo.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"tasks": {
"build": {
"dependsOn": ["^build"],
"inputs": ["src/**"],
"outputs": ["dist/**"]
},
"lint": {
Expand Down
Loading