diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index f1b46e84..7df81364 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -126,11 +126,11 @@ function dbgenerated(expr: String?): Any { function contains(field: String, search: String, caseInSensitive: Boolean?): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) -/** - * If the field value matches the search condition with [full-text-search](https://www.prisma.io/docs/concepts/components/prisma-client/full-text-search). Need to enable "fullTextSearch" preview feature to use. - */ -function search(field: String, search: String): Boolean { -} @@@expressionContext([AccessPolicy]) +// /** +// * If the field value matches the search condition with [full-text-search](https://www.prisma.io/docs/concepts/components/prisma-client/full-text-search). Need to enable "fullTextSearch" preview feature to use. +// */ +// function search(field: String, search: String): Boolean { +// } @@@expressionContext([AccessPolicy]) /** * Checks the field value starts with the search string. By default, the search is case-sensitive, and @@ -151,25 +151,25 @@ function endsWith(field: String, search: String, caseInSensitive: Boolean?): Boo } @@@expressionContext([AccessPolicy, ValidationRule]) /** - * If the field value (a list) has the given search value + * Checks if the list field has the given search value */ function has(field: Any[], search: Any): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) /** - * If the field value (a list) has every element of the search list + * Checks if the list field has at least one element of the search list */ -function hasEvery(field: Any[], search: Any[]): Boolean { +function hasSome(field: Any[], search: Any[]): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) /** - * If the field value (a list) has at least one element of the search list + * Checks if the list field has every element of the search list */ -function hasSome(field: Any[], search: Any[]): Boolean { +function hasEvery(field: Any[], search: Any[]): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) /** - * If the field value (a list) is empty + * Checks if the list field is empty */ function isEmpty(field: Any[]): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) @@ -551,9 +551,9 @@ function length(field: Any): Int { /** - * Validates a string field value matches a regex. + * Validates a string field value matches a regex pattern. */ -function regex(field: String, regex: String): Boolean { +function regex(field: String, pattern: String): Boolean { } @@@expressionContext([ValidationRule]) /** diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index c361feee..894c6fc7 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -443,9 +443,9 @@ export function getAllDeclarationsIncludingImports(documents: LangiumDocuments, } export function getAuthDecl(decls: (DataModel | TypeDef)[]) { - let authModel = decls.find((m) => hasAttribute(m, '@@auth')); + let authModel = decls.find((d) => hasAttribute(d, '@@auth')); if (!authModel) { - authModel = decls.find((m) => m.name === 'User'); + authModel = decls.find((d) => d.name === 'User'); } return authModel; } diff --git a/packages/runtime/src/client/client-impl.ts b/packages/runtime/src/client/client-impl.ts index bcb753a3..24b2d2be 100644 --- a/packages/runtime/src/client/client-impl.ts +++ b/packages/runtime/src/client/client-impl.ts @@ -233,6 +233,12 @@ export class ClientImpl { return (procOptions[name] as Function).apply(this, [this, ...args]); } + async $connect() { + await this.kysely.connection().execute(async (conn) => { + await conn.executeQuery(sql`select 1`.compile(this.kysely)); + }); + } + async $disconnect() { await this.kysely.destroy(); } diff --git a/packages/runtime/src/client/contract.ts b/packages/runtime/src/client/contract.ts index d2c19cb1..0d90bc88 100644 --- a/packages/runtime/src/client/contract.ts +++ b/packages/runtime/src/client/contract.ts @@ -151,7 +151,12 @@ export type ClientContract = { $unuseAll(): ClientContract; /** - * Disconnects the underlying Kysely instance from the database. + * Eagerly connects to the database. + */ + $connect(): Promise; + + /** + * Explicitly disconnects from the database. */ $disconnect(): Promise; diff --git a/packages/runtime/src/client/crud/dialects/base-dialect.ts b/packages/runtime/src/client/crud/dialects/base-dialect.ts index 1b8b1e1c..a1c7501b 100644 --- a/packages/runtime/src/client/crud/dialects/base-dialect.ts +++ b/packages/runtime/src/client/crud/dialects/base-dialect.ts @@ -89,14 +89,7 @@ export abstract class BaseCrudDialect { result = this.buildSkipTake(result, skip, take); // orderBy - result = this.buildOrderBy( - result, - model, - modelAlias, - args.orderBy, - skip !== undefined || take !== undefined, - negateOrderBy, - ); + result = this.buildOrderBy(result, model, modelAlias, args.orderBy, negateOrderBy); // distinct if ('distinct' in args && (args as any).distinct) { @@ -748,15 +741,10 @@ export abstract class BaseCrudDialect { model: string, modelAlias: string, orderBy: OrArray, boolean, boolean>> | undefined, - useDefaultIfEmpty: boolean, negated: boolean, ) { if (!orderBy) { - if (useDefaultIfEmpty) { - orderBy = makeDefaultOrderBy(this.schema, model); - } else { - return query; - } + return query; } let result = query; @@ -862,7 +850,7 @@ export abstract class BaseCrudDialect { ), ); }); - result = this.buildOrderBy(result, fieldDef.type, relationModel, value, false, negated); + result = this.buildOrderBy(result, fieldDef.type, relationModel, value, negated); } } } diff --git a/packages/runtime/src/client/crud/operations/aggregate.ts b/packages/runtime/src/client/crud/operations/aggregate.ts index 6362fbe6..f92a8518 100644 --- a/packages/runtime/src/client/crud/operations/aggregate.ts +++ b/packages/runtime/src/client/crud/operations/aggregate.ts @@ -52,14 +52,7 @@ export class AggregateOperationHandler extends BaseOpe subQuery = this.dialect.buildSkipTake(subQuery, skip, take); // orderBy - subQuery = this.dialect.buildOrderBy( - subQuery, - this.model, - this.model, - parsedArgs.orderBy, - skip !== undefined || take !== undefined, - negateOrderBy, - ); + subQuery = this.dialect.buildOrderBy(subQuery, this.model, this.model, parsedArgs.orderBy, negateOrderBy); return subQuery.as('$sub'); }); diff --git a/packages/runtime/src/client/crud/operations/group-by.ts b/packages/runtime/src/client/crud/operations/group-by.ts index 4f4a083f..1f8b880a 100644 --- a/packages/runtime/src/client/crud/operations/group-by.ts +++ b/packages/runtime/src/client/crud/operations/group-by.ts @@ -11,51 +11,33 @@ export class GroupByOperationHandler extends BaseOpera // parse args const parsedArgs = this.inputValidator.validateGroupByArgs(this.model, normalizedArgs); - let query = this.kysely.selectFrom((eb) => { - // nested query for filtering and pagination - - // where - let subQuery = eb - .selectFrom(this.model) - .selectAll() - .where(() => this.dialect.buildFilter(this.model, this.model, parsedArgs?.where)); - - // skip & take - const skip = parsedArgs?.skip; - let take = parsedArgs?.take; - let negateOrderBy = false; - if (take !== undefined && take < 0) { - negateOrderBy = true; - take = -take; - } - subQuery = this.dialect.buildSkipTake(subQuery, skip, take); - - // default orderBy - subQuery = this.dialect.buildOrderBy( - subQuery, - this.model, - this.model, - undefined, - skip !== undefined || take !== undefined, - negateOrderBy, - ); - - return subQuery.as('$sub'); - }); + let query = this.kysely + .selectFrom(this.model) + .where(() => this.dialect.buildFilter(this.model, this.model, parsedArgs?.where)); - const fieldRef = (field: string) => this.dialect.fieldRef(this.model, field, '$sub'); + const fieldRef = (field: string) => this.dialect.fieldRef(this.model, field); // groupBy const bys = typeof parsedArgs.by === 'string' ? [parsedArgs.by] : (parsedArgs.by as string[]); query = query.groupBy(bys.map((by) => fieldRef(by))); + // skip & take + const skip = parsedArgs?.skip; + let take = parsedArgs?.take; + let negateOrderBy = false; + if (take !== undefined && take < 0) { + negateOrderBy = true; + take = -take; + } + query = this.dialect.buildSkipTake(query, skip, take); + // orderBy if (parsedArgs.orderBy) { - query = this.dialect.buildOrderBy(query, this.model, '$sub', parsedArgs.orderBy, false, false); + query = this.dialect.buildOrderBy(query, this.model, this.model, parsedArgs.orderBy, negateOrderBy); } if (parsedArgs.having) { - query = query.having(() => this.dialect.buildFilter(this.model, '$sub', parsedArgs.having)); + query = query.having(() => this.dialect.buildFilter(this.model, this.model, parsedArgs.having)); } // select all by fields diff --git a/packages/runtime/src/client/crud/validator/index.ts b/packages/runtime/src/client/crud/validator/index.ts index 09bfb8e5..1e32a865 100644 --- a/packages/runtime/src/client/crud/validator/index.ts +++ b/packages/runtime/src/client/crud/validator/index.ts @@ -50,11 +50,11 @@ import { addStringValidation, } from './utils'; +const schemaCache = new WeakMap>(); + type GetSchemaFunc = (model: GetModels, options: Options) => ZodType; export class InputValidator { - private schemaCache = new Map(); - constructor(private readonly client: ClientContract) {} private get schema() { @@ -192,6 +192,24 @@ export class InputValidator { ); } + private getSchemaCache(cacheKey: string) { + let thisCache = schemaCache.get(this.schema); + if (!thisCache) { + thisCache = new Map(); + schemaCache.set(this.schema, thisCache); + } + return thisCache.get(cacheKey); + } + + private setSchemaCache(cacheKey: string, schema: ZodType) { + let thisCache = schemaCache.get(this.schema); + if (!thisCache) { + thisCache = new Map(); + schemaCache.set(this.schema, thisCache); + } + return thisCache.set(cacheKey, schema); + } + private validate( model: GetModels, operation: string, @@ -200,14 +218,16 @@ export class InputValidator { args: unknown, ) { const cacheKey = stableStringify({ + type: 'model', model, operation, options, + extraValidationsEnabled: this.extraValidationsEnabled, }); - let schema = this.schemaCache.get(cacheKey!); + let schema = this.getSchemaCache(cacheKey!); if (!schema) { schema = getSchema(model, options); - this.schemaCache.set(cacheKey!, schema); + this.setSchemaCache(cacheKey!, schema); } const { error, data } = schema.safeParse(args); if (error) { @@ -293,8 +313,12 @@ export class InputValidator { } private makeTypeDefSchema(type: string): z.ZodType { - const key = `$typedef-${type}`; - let schema = this.schemaCache.get(key); + const key = stableStringify({ + type: 'typedef', + name: type, + extraValidationsEnabled: this.extraValidationsEnabled, + }); + let schema = this.getSchemaCache(key!); if (schema) { return schema; } @@ -316,7 +340,7 @@ export class InputValidator { ), ) .passthrough(); - this.schemaCache.set(key, schema); + this.setSchemaCache(key!, schema); return schema; } diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index f3e855fa..e1fb6298 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -22,7 +22,7 @@ import { type RootOperationNode, } from 'kysely'; import { match } from 'ts-pattern'; -import type { GetModels, SchemaDef } from '../../schema'; +import type { GetModels, ModelDef, SchemaDef, TypeDefDef } from '../../schema'; import { type ClientImpl } from '../client-impl'; import { TransactionIsolationLevel, type ClientContract } from '../contract'; import { InternalError, QueryError, ZenStackError } from '../errors'; @@ -42,7 +42,7 @@ type MutationInfo = { }; export class ZenStackQueryExecutor extends DefaultQueryExecutor { - private readonly nameMapper: QueryNameMapper; + private readonly nameMapper: QueryNameMapper | undefined; constructor( private client: ClientImpl, @@ -54,7 +54,21 @@ export class ZenStackQueryExecutor extends DefaultQuer private suppressMutationHooks: boolean = false, ) { super(compiler, adapter, connectionProvider, plugins); - this.nameMapper = new QueryNameMapper(client.$schema); + + if (this.schemaHasMappedNames(client.$schema)) { + this.nameMapper = new QueryNameMapper(client.$schema); + } + } + + private schemaHasMappedNames(schema: Schema) { + const hasMapAttr = (decl: ModelDef | TypeDefDef) => { + if (decl.attributes?.some((attr) => attr.name === '@@map')) { + return true; + } + return Object.values(decl.fields).some((field) => field.attributes?.some((attr) => attr.name === '@map')); + }; + + return Object.values(schema.models).some(hasMapAttr) || Object.values(schema.typeDefs ?? []).some(hasMapAttr); } private get kysely() { @@ -170,7 +184,7 @@ export class ZenStackQueryExecutor extends DefaultQuer if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) { // no need to handle mutation hooks, just proceed - const finalQuery = this.nameMapper.transformNode(query); + const finalQuery = this.processNameMapping(query); compiled = this.compileQuery(finalQuery); if (parameters) { compiled = { ...compiled, parameters }; @@ -189,7 +203,7 @@ export class ZenStackQueryExecutor extends DefaultQuer returning: ReturningNode.create([SelectionNode.createSelectAll()]), }; } - const finalQuery = this.nameMapper.transformNode(query); + const finalQuery = this.processNameMapping(query); compiled = this.compileQuery(finalQuery); if (parameters) { compiled = { ...compiled, parameters }; @@ -239,6 +253,10 @@ export class ZenStackQueryExecutor extends DefaultQuer return result; } + private processNameMapping(query: Node): Node { + return this.nameMapper?.transformNode(query) ?? query; + } + private createClientForConnection(connection: DatabaseConnection, inTx: boolean) { const innerExecutor = this.withConnectionProvider(new SingleConnectionProvider(connection)); innerExecutor.suppressMutationHooks = true; diff --git a/tests/e2e/orm/client-api/connect-disconnect.test.ts b/tests/e2e/orm/client-api/connect-disconnect.test.ts new file mode 100644 index 00000000..5154ead9 --- /dev/null +++ b/tests/e2e/orm/client-api/connect-disconnect.test.ts @@ -0,0 +1,29 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Client $connect and $disconnect tests', () => { + it('works with connect and disconnect', async () => { + const db = await createTestClient( + ` + model User { + id String @id @default(cuid()) + email String @unique + } + `, + ); + + // connect to the database + await db.$connect(); + + // perform a simple operation + await db.user.create({ + data: { + email: 'u1@test.com', + }, + }); + + await db.$disconnect(); + + await expect(db.user.findFirst()).rejects.toThrow(); + }); +}); diff --git a/tests/e2e/orm/client-api/find.test.ts b/tests/e2e/orm/client-api/find.test.ts index 6d188843..cdf67584 100644 --- a/tests/e2e/orm/client-api/find.test.ts +++ b/tests/e2e/orm/client-api/find.test.ts @@ -187,6 +187,7 @@ describe('Client find tests ', () => { await expect( client.user.findMany({ cursor: { id: user2.id }, + orderBy: { id: 'asc' }, }), ).resolves.toEqual([user2, user3]); @@ -195,6 +196,7 @@ describe('Client find tests ', () => { client.user.findMany({ skip: 1, cursor: { id: user1.id }, + orderBy: { id: 'asc' }, }), ).resolves.toEqual([user2, user3]); @@ -221,6 +223,7 @@ describe('Client find tests ', () => { client.user.findMany({ skip: 1, cursor: { id: user1.id, role: 'ADMIN' }, + orderBy: { id: 'asc' }, }), ).resolves.toEqual([user2, user3]); @@ -238,6 +241,7 @@ describe('Client find tests ', () => { skip: 1, take: -2, cursor: { id: user3.id }, + orderBy: { id: 'asc' }, }), ).resolves.toEqual([user1, user2]); }); @@ -343,6 +347,7 @@ describe('Client find tests ', () => { posts: { skip: 1, take: -2, + orderBy: { id: 'asc' }, }, }, }), diff --git a/tests/e2e/orm/client-api/group-by.test.ts b/tests/e2e/orm/client-api/group-by.test.ts index 08da59fe..d4eb2beb 100644 --- a/tests/e2e/orm/client-api/group-by.test.ts +++ b/tests/e2e/orm/client-api/group-by.test.ts @@ -1,7 +1,7 @@ -import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '@zenstackhq/runtime'; -import { schema } from '../schemas/basic'; import { createTestClient } from '@zenstackhq/testtools'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import { schema } from '../schemas/basic'; import { createPosts, createUser } from './utils'; describe('Client groupBy tests', () => { @@ -57,7 +57,7 @@ describe('Client groupBy tests', () => { take: -1, orderBy: { email: 'desc' }, }), - ).resolves.toEqual([{ email: 'u1@test.com' }]); + ).resolves.toEqual([{ email: 'u3@test.com' }]); await expect( client.user.groupBy({ @@ -66,7 +66,7 @@ describe('Client groupBy tests', () => { take: -2, orderBy: { email: 'desc' }, }), - ).resolves.toEqual(expect.arrayContaining([{ email: 'u2@test.com' }, { email: 'u1@test.com' }])); + ).resolves.toEqual(expect.arrayContaining([{ email: 'u2@test.com' }, { email: 'u3@test.com' }])); await expect( client.user.groupBy({ @@ -88,10 +88,12 @@ describe('Client groupBy tests', () => { }, _count: true, }), - ).resolves.toEqual([ - { name: 'User', role: 'USER', _count: 2 }, - { name: 'Admin', role: 'ADMIN', _count: 1 }, - ]); + ).resolves.toEqual( + expect.arrayContaining([ + { name: 'User', role: 'USER', _count: 2 }, + { name: 'Admin', role: 'ADMIN', _count: 1 }, + ]), + ); await expect( client.post.groupBy({