diff --git a/.coderabbit.yaml b/.coderabbit.yaml new file mode 100644 index 00000000..c50c3b9e --- /dev/null +++ b/.coderabbit.yaml @@ -0,0 +1,8 @@ +# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json +language: 'en-US' +early_access: false +reviews: + auto_review: + enabled: true +chat: + auto_reply: true diff --git a/packages/runtime/src/client/client-impl.ts b/packages/runtime/src/client/client-impl.ts index ce836d6e..d17cd23f 100644 --- a/packages/runtime/src/client/client-impl.ts +++ b/packages/runtime/src/client/client-impl.ts @@ -1,11 +1,13 @@ import { lowerCaseFirst } from '@zenstackhq/common-helpers'; import type { SqliteDialectConfig } from 'kysely'; import { + CompiledQuery, DefaultConnectionProvider, DefaultQueryExecutor, Kysely, Log, PostgresDialect, + sql, SqliteDialect, type KyselyProps, type PostgresDialectConfig, @@ -209,6 +211,41 @@ export class ClientImpl { get $auth() { return this.auth; } + + $executeRaw(query: TemplateStringsArray, ...values: any[]) { + return createDeferredPromise(async () => { + const result = await sql(query, ...values).execute(this.kysely); + return Number(result.numAffectedRows ?? 0); + }); + } + + $executeRawUnsafe(query: string, ...values: any[]) { + return createDeferredPromise(async () => { + const compiledQuery = this.createRawCompiledQuery(query, values); + const result = await this.kysely.executeQuery(compiledQuery); + return Number(result.numAffectedRows ?? 0); + }); + } + + $queryRaw(query: TemplateStringsArray, ...values: any[]) { + return createDeferredPromise(async () => { + const result = await sql(query, ...values).execute(this.kysely); + return result.rows as T; + }); + } + + $queryRawUnsafe(query: string, ...values: any[]) { + return createDeferredPromise(async () => { + const compiledQuery = this.createRawCompiledQuery(query, values); + const result = await this.kysely.executeQuery(compiledQuery); + return result.rows as T; + }); + } + + private createRawCompiledQuery(query: string, values: any[]) { + const q = CompiledQuery.raw(query, values); + return { ...q, $raw: true } as CompiledQuery; + } } function createClientProxy(client: ClientImpl): ClientImpl { diff --git a/packages/runtime/src/client/contract.ts b/packages/runtime/src/client/contract.ts index 6a99af69..1a10a421 100644 --- a/packages/runtime/src/client/contract.ts +++ b/packages/runtime/src/client/contract.ts @@ -40,6 +40,44 @@ export type ClientContract = { */ readonly $options: ClientOptions; + /** + * Executes a prepared raw query and returns the number of affected rows. + * @example + * ``` + * const result = await client.$executeRaw`UPDATE User SET cool = ${true} WHERE email = ${'user@email.com'};` + * ``` + */ + $executeRaw(query: TemplateStringsArray, ...values: any[]): Promise; + + /** + * Executes a raw query and returns the number of affected rows. + * This method is susceptible to SQL injections. + * @example + * ``` + * const result = await client.$executeRawUnsafe('UPDATE User SET cool = $1 WHERE email = $2 ;', true, 'user@email.com') + * ``` + */ + $executeRawUnsafe(query: string, ...values: any[]): Promise; + + /** + * Performs a prepared raw query and returns the `SELECT` data. + * @example + * ``` + * const result = await client.$queryRaw`SELECT * FROM User WHERE id = ${1} OR email = ${'user@email.com'};` + * ``` + */ + $queryRaw(query: TemplateStringsArray, ...values: any[]): Promise; + + /** + * Performs a raw query and returns the `SELECT` data. + * This method is susceptible to SQL injections. + * @example + * ``` + * const result = await client.$queryRawUnsafe('SELECT * FROM User WHERE id = $1 OR email = $2;', 1, 'user@email.com') + * ``` + */ + $queryRawUnsafe(query: string, ...values: any[]): Promise; + /** * The current user identity. */ diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index 381ec9af..bb9a5472 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -81,7 +81,9 @@ export class ZenStackQueryExecutor extends DefaultQuer } // proceed with the query with kysely interceptors - const result = await this.proceedQueryWithKyselyInterceptors(queryNode, queryId); + // if the query is a raw query, we need to carry over the parameters + const queryParams = (compiledQuery as any).$raw ? compiledQuery.parameters : undefined; + const result = await this.proceedQueryWithKyselyInterceptors(queryNode, queryParams, queryId); // call after mutation hooks await this.callAfterQueryInterceptionFilters(result, queryNode, mutationInterceptionInfo); @@ -96,8 +98,12 @@ export class ZenStackQueryExecutor extends DefaultQuer return this.executeWithTransaction(task, !!mutationInterceptionInfo?.useTransactionForMutation); } - private proceedQueryWithKyselyInterceptors(queryNode: RootOperationNode, queryId: QueryId) { - let proceed = (q: RootOperationNode) => this.proceedQuery(q, queryId); + private proceedQueryWithKyselyInterceptors( + queryNode: RootOperationNode, + parameters: readonly unknown[] | undefined, + queryId: QueryId, + ) { + let proceed = (q: RootOperationNode) => this.proceedQuery(q, parameters, queryId); const makeTx = (p: typeof proceed) => (callback: OnKyselyQueryTransactionCallback) => { return this.executeWithTransaction(() => callback(p)); @@ -125,10 +131,13 @@ export class ZenStackQueryExecutor extends DefaultQuer return proceed(queryNode); } - private async proceedQuery(query: RootOperationNode, queryId: QueryId) { + private async proceedQuery(query: RootOperationNode, parameters: readonly unknown[] | undefined, queryId: QueryId) { // run built-in transformers const finalQuery = this.nameMapper.transformNode(query); - const compiled = this.compileQuery(finalQuery); + let compiled = this.compileQuery(finalQuery); + if (parameters) { + compiled = { ...compiled, parameters }; + } try { return this.driver.txConnection ? await super diff --git a/packages/runtime/test/client-api/client-specs.ts b/packages/runtime/test/client-api/client-specs.ts index f05c5fcd..6a14ab43 100644 --- a/packages/runtime/test/client-api/client-specs.ts +++ b/packages/runtime/test/client-api/client-specs.ts @@ -3,7 +3,7 @@ import { getSchema, schema } from '../test-schema'; import { makePostgresClient, makeSqliteClient } from '../utils'; import type { ClientContract } from '../../src'; -export function createClientSpecs(dbName: string, logQueries = false, providers = ['sqlite', 'postgresql'] as const) { +export function createClientSpecs(dbName: string, logQueries = false, providers: string[] = ['sqlite', 'postgresql']) { const logger = (provider: string) => (event: LogEvent) => { if (event.level === 'query') { console.log(`query(${provider}):`, event.query.sql, event.query.parameters); diff --git a/packages/runtime/test/client-api/raw-query.test.ts b/packages/runtime/test/client-api/raw-query.test.ts new file mode 100644 index 00000000..f8ad6d41 --- /dev/null +++ b/packages/runtime/test/client-api/raw-query.test.ts @@ -0,0 +1,79 @@ +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import type { ClientContract } from '../../src/client'; +import { schema } from '../test-schema'; +import { createClientSpecs } from './client-specs'; + +const PG_DB_NAME = 'client-api-raw-query-tests'; + +describe.each(createClientSpecs(PG_DB_NAME, true))('Client raw query tests', ({ createClient, provider }) => { + let client: ClientContract; + + beforeEach(async () => { + client = await createClient(); + }); + + afterEach(async () => { + await client?.$disconnect(); + }); + + it('works with executeRaw', async () => { + await client.user.create({ + data: { + id: '1', + email: 'u1@test.com', + }, + }); + + await expect( + client.$executeRaw`UPDATE "User" SET "email" = ${'u2@test.com'} WHERE "id" = ${'1'}`, + ).resolves.toBe(1); + await expect(client.user.findFirst()).resolves.toMatchObject({ email: 'u2@test.com' }); + }); + + it('works with executeRawUnsafe', async () => { + await client.user.create({ + data: { + id: '1', + email: 'u1@test.com', + }, + }); + + const sql = + provider === 'postgresql' + ? `UPDATE "User" SET "email" = $1 WHERE "id" = $2` + : `UPDATE "User" SET "email" = ? WHERE "id" = ?`; + await expect(client.$executeRawUnsafe(sql, 'u2@test.com', '1')).resolves.toBe(1); + await expect(client.user.findFirst()).resolves.toMatchObject({ email: 'u2@test.com' }); + }); + + it('works with queryRaw', async () => { + await client.user.create({ + data: { + id: '1', + email: 'u1@test.com', + }, + }); + + const uid = '1'; + const users = await client.$queryRaw< + { id: string; email: string }[] + >`SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = ${uid}`; + expect(users).toEqual([{ id: '1', email: 'u1@test.com' }]); + }); + + it('works with queryRawUnsafe', async () => { + await client.user.create({ + data: { + id: '1', + email: 'u1@test.com', + }, + }); + + const sql = + provider === 'postgresql' + ? `SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = $1` + : `SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = ?`; + const users = await client.$queryRawUnsafe<{ id: string; email: string }[]>(sql, '1'); + expect(users).toEqual([{ id: '1', email: 'u1@test.com' }]); + }); +}); diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 46b733ea..b5a6f7d3 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -35,7 +35,7 @@ export async function generateTsSchema( extraSourceFiles?: Record, ) { const workDir = createTestProject(); - console.log(`Working directory: ${workDir}`); + console.log(`Work directory: ${workDir}`); const zmodelPath = path.join(workDir, 'schema.zmodel'); const noPrelude = schemaText.includes('datasource '); diff --git a/turbo.json b/turbo.json index 72d14c56..31aad504 100644 --- a/turbo.json +++ b/turbo.json @@ -3,6 +3,7 @@ "tasks": { "build": { "dependsOn": ["^build"], + "inputs": ["src/**"], "outputs": ["dist/**"] }, "lint": {