Skip to content

Commit 0067817

Browse files
authored
fix: orderBy issue when used with groupBy, add zod cache and $connect API (#317)
* fix: orderBy issue when used with groupBy, add zod cache and $connect API * addressing PR comments
1 parent 2be71a5 commit 0067817

File tree

12 files changed

+145
-93
lines changed

12 files changed

+145
-93
lines changed

packages/language/res/stdlib.zmodel

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,11 @@ function dbgenerated(expr: String?): Any {
126126
function contains(field: String, search: String, caseInSensitive: Boolean?): Boolean {
127127
} @@@expressionContext([AccessPolicy, ValidationRule])
128128

129-
/**
130-
* If the field value matches the search condition with [full-text-search](https://www.prisma.io/docs/concepts/components/prisma-client/full-text-search). Need to enable "fullTextSearch" preview feature to use.
131-
*/
132-
function search(field: String, search: String): Boolean {
133-
} @@@expressionContext([AccessPolicy])
129+
// /**
130+
// * If the field value matches the search condition with [full-text-search](https://www.prisma.io/docs/concepts/components/prisma-client/full-text-search). Need to enable "fullTextSearch" preview feature to use.
131+
// */
132+
// function search(field: String, search: String): Boolean {
133+
// } @@@expressionContext([AccessPolicy])
134134

135135
/**
136136
* Checks the field value starts with the search string. By default, the search is case-sensitive, and
@@ -151,25 +151,25 @@ function endsWith(field: String, search: String, caseInSensitive: Boolean?): Boo
151151
} @@@expressionContext([AccessPolicy, ValidationRule])
152152

153153
/**
154-
* If the field value (a list) has the given search value
154+
* Checks if the list field has the given search value
155155
*/
156156
function has(field: Any[], search: Any): Boolean {
157157
} @@@expressionContext([AccessPolicy, ValidationRule])
158158

159159
/**
160-
* If the field value (a list) has every element of the search list
160+
* Checks if the list field has at least one element of the search list
161161
*/
162-
function hasEvery(field: Any[], search: Any[]): Boolean {
162+
function hasSome(field: Any[], search: Any[]): Boolean {
163163
} @@@expressionContext([AccessPolicy, ValidationRule])
164164

165165
/**
166-
* If the field value (a list) has at least one element of the search list
166+
* Checks if the list field has every element of the search list
167167
*/
168-
function hasSome(field: Any[], search: Any[]): Boolean {
168+
function hasEvery(field: Any[], search: Any[]): Boolean {
169169
} @@@expressionContext([AccessPolicy, ValidationRule])
170170

171171
/**
172-
* If the field value (a list) is empty
172+
* Checks if the list field is empty
173173
*/
174174
function isEmpty(field: Any[]): Boolean {
175175
} @@@expressionContext([AccessPolicy, ValidationRule])
@@ -551,9 +551,9 @@ function length(field: Any): Int {
551551

552552

553553
/**
554-
* Validates a string field value matches a regex.
554+
* Validates a string field value matches a regex pattern.
555555
*/
556-
function regex(field: String, regex: String): Boolean {
556+
function regex(field: String, pattern: String): Boolean {
557557
} @@@expressionContext([ValidationRule])
558558

559559
/**

packages/language/src/utils.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,9 @@ export function getAllDeclarationsIncludingImports(documents: LangiumDocuments,
443443
}
444444

445445
export function getAuthDecl(decls: (DataModel | TypeDef)[]) {
446-
let authModel = decls.find((m) => hasAttribute(m, '@@auth'));
446+
let authModel = decls.find((d) => hasAttribute(d, '@@auth'));
447447
if (!authModel) {
448-
authModel = decls.find((m) => m.name === 'User');
448+
authModel = decls.find((d) => d.name === 'User');
449449
}
450450
return authModel;
451451
}

packages/runtime/src/client/client-impl.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ export class ClientImpl<Schema extends SchemaDef> {
233233
return (procOptions[name] as Function).apply(this, [this, ...args]);
234234
}
235235

236+
async $connect() {
237+
await this.kysely.connection().execute(async (conn) => {
238+
await conn.executeQuery(sql`select 1`.compile(this.kysely));
239+
});
240+
}
241+
236242
async $disconnect() {
237243
await this.kysely.destroy();
238244
}

packages/runtime/src/client/contract.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,12 @@ export type ClientContract<Schema extends SchemaDef> = {
151151
$unuseAll(): ClientContract<Schema>;
152152

153153
/**
154-
* Disconnects the underlying Kysely instance from the database.
154+
* Eagerly connects to the database.
155+
*/
156+
$connect(): Promise<void>;
157+
158+
/**
159+
* Explicitly disconnects from the database.
155160
*/
156161
$disconnect(): Promise<void>;
157162

packages/runtime/src/client/crud/dialects/base-dialect.ts

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
8989
result = this.buildSkipTake(result, skip, take);
9090

9191
// orderBy
92-
result = this.buildOrderBy(
93-
result,
94-
model,
95-
modelAlias,
96-
args.orderBy,
97-
skip !== undefined || take !== undefined,
98-
negateOrderBy,
99-
);
92+
result = this.buildOrderBy(result, model, modelAlias, args.orderBy, negateOrderBy);
10093

10194
// distinct
10295
if ('distinct' in args && (args as any).distinct) {
@@ -748,15 +741,10 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
748741
model: string,
749742
modelAlias: string,
750743
orderBy: OrArray<OrderBy<Schema, GetModels<Schema>, boolean, boolean>> | undefined,
751-
useDefaultIfEmpty: boolean,
752744
negated: boolean,
753745
) {
754746
if (!orderBy) {
755-
if (useDefaultIfEmpty) {
756-
orderBy = makeDefaultOrderBy(this.schema, model);
757-
} else {
758-
return query;
759-
}
747+
return query;
760748
}
761749

762750
let result = query;
@@ -862,7 +850,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
862850
),
863851
);
864852
});
865-
result = this.buildOrderBy(result, fieldDef.type, relationModel, value, false, negated);
853+
result = this.buildOrderBy(result, fieldDef.type, relationModel, value, negated);
866854
}
867855
}
868856
}

packages/runtime/src/client/crud/operations/aggregate.ts

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,7 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
5252
subQuery = this.dialect.buildSkipTake(subQuery, skip, take);
5353

5454
// orderBy
55-
subQuery = this.dialect.buildOrderBy(
56-
subQuery,
57-
this.model,
58-
this.model,
59-
parsedArgs.orderBy,
60-
skip !== undefined || take !== undefined,
61-
negateOrderBy,
62-
);
55+
subQuery = this.dialect.buildOrderBy(subQuery, this.model, this.model, parsedArgs.orderBy, negateOrderBy);
6356

6457
return subQuery.as('$sub');
6558
});

packages/runtime/src/client/crud/operations/group-by.ts

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,51 +11,33 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
1111
// parse args
1212
const parsedArgs = this.inputValidator.validateGroupByArgs(this.model, normalizedArgs);
1313

14-
let query = this.kysely.selectFrom((eb) => {
15-
// nested query for filtering and pagination
16-
17-
// where
18-
let subQuery = eb
19-
.selectFrom(this.model)
20-
.selectAll()
21-
.where(() => this.dialect.buildFilter(this.model, this.model, parsedArgs?.where));
22-
23-
// skip & take
24-
const skip = parsedArgs?.skip;
25-
let take = parsedArgs?.take;
26-
let negateOrderBy = false;
27-
if (take !== undefined && take < 0) {
28-
negateOrderBy = true;
29-
take = -take;
30-
}
31-
subQuery = this.dialect.buildSkipTake(subQuery, skip, take);
32-
33-
// default orderBy
34-
subQuery = this.dialect.buildOrderBy(
35-
subQuery,
36-
this.model,
37-
this.model,
38-
undefined,
39-
skip !== undefined || take !== undefined,
40-
negateOrderBy,
41-
);
42-
43-
return subQuery.as('$sub');
44-
});
14+
let query = this.kysely
15+
.selectFrom(this.model)
16+
.where(() => this.dialect.buildFilter(this.model, this.model, parsedArgs?.where));
4517

46-
const fieldRef = (field: string) => this.dialect.fieldRef(this.model, field, '$sub');
18+
const fieldRef = (field: string) => this.dialect.fieldRef(this.model, field);
4719

4820
// groupBy
4921
const bys = typeof parsedArgs.by === 'string' ? [parsedArgs.by] : (parsedArgs.by as string[]);
5022
query = query.groupBy(bys.map((by) => fieldRef(by)));
5123

24+
// skip & take
25+
const skip = parsedArgs?.skip;
26+
let take = parsedArgs?.take;
27+
let negateOrderBy = false;
28+
if (take !== undefined && take < 0) {
29+
negateOrderBy = true;
30+
take = -take;
31+
}
32+
query = this.dialect.buildSkipTake(query, skip, take);
33+
5234
// orderBy
5335
if (parsedArgs.orderBy) {
54-
query = this.dialect.buildOrderBy(query, this.model, '$sub', parsedArgs.orderBy, false, false);
36+
query = this.dialect.buildOrderBy(query, this.model, this.model, parsedArgs.orderBy, negateOrderBy);
5537
}
5638

5739
if (parsedArgs.having) {
58-
query = query.having(() => this.dialect.buildFilter(this.model, '$sub', parsedArgs.having));
40+
query = query.having(() => this.dialect.buildFilter(this.model, this.model, parsedArgs.having));
5941
}
6042

6143
// select all by fields

packages/runtime/src/client/crud/validator/index.ts

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ import {
5050
addStringValidation,
5151
} from './utils';
5252

53+
const schemaCache = new WeakMap<SchemaDef, Map<string, ZodType>>();
54+
5355
type GetSchemaFunc<Schema extends SchemaDef, Options> = (model: GetModels<Schema>, options: Options) => ZodType;
5456

5557
export class InputValidator<Schema extends SchemaDef> {
56-
private schemaCache = new Map<string, ZodType>();
57-
5858
constructor(private readonly client: ClientContract<Schema>) {}
5959

6060
private get schema() {
@@ -192,6 +192,24 @@ export class InputValidator<Schema extends SchemaDef> {
192192
);
193193
}
194194

195+
private getSchemaCache(cacheKey: string) {
196+
let thisCache = schemaCache.get(this.schema);
197+
if (!thisCache) {
198+
thisCache = new Map<string, ZodType>();
199+
schemaCache.set(this.schema, thisCache);
200+
}
201+
return thisCache.get(cacheKey);
202+
}
203+
204+
private setSchemaCache(cacheKey: string, schema: ZodType) {
205+
let thisCache = schemaCache.get(this.schema);
206+
if (!thisCache) {
207+
thisCache = new Map<string, ZodType>();
208+
schemaCache.set(this.schema, thisCache);
209+
}
210+
return thisCache.set(cacheKey, schema);
211+
}
212+
195213
private validate<T, Options = undefined>(
196214
model: GetModels<Schema>,
197215
operation: string,
@@ -200,14 +218,16 @@ export class InputValidator<Schema extends SchemaDef> {
200218
args: unknown,
201219
) {
202220
const cacheKey = stableStringify({
221+
type: 'model',
203222
model,
204223
operation,
205224
options,
225+
extraValidationsEnabled: this.extraValidationsEnabled,
206226
});
207-
let schema = this.schemaCache.get(cacheKey!);
227+
let schema = this.getSchemaCache(cacheKey!);
208228
if (!schema) {
209229
schema = getSchema(model, options);
210-
this.schemaCache.set(cacheKey!, schema);
230+
this.setSchemaCache(cacheKey!, schema);
211231
}
212232
const { error, data } = schema.safeParse(args);
213233
if (error) {
@@ -293,8 +313,12 @@ export class InputValidator<Schema extends SchemaDef> {
293313
}
294314

295315
private makeTypeDefSchema(type: string): z.ZodType {
296-
const key = `$typedef-${type}`;
297-
let schema = this.schemaCache.get(key);
316+
const key = stableStringify({
317+
type: 'typedef',
318+
name: type,
319+
extraValidationsEnabled: this.extraValidationsEnabled,
320+
});
321+
let schema = this.getSchemaCache(key!);
298322
if (schema) {
299323
return schema;
300324
}
@@ -316,7 +340,7 @@ export class InputValidator<Schema extends SchemaDef> {
316340
),
317341
)
318342
.passthrough();
319-
this.schemaCache.set(key, schema);
343+
this.setSchemaCache(key!, schema);
320344
return schema;
321345
}
322346

packages/runtime/src/client/executor/zenstack-query-executor.ts

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import {
2222
type RootOperationNode,
2323
} from 'kysely';
2424
import { match } from 'ts-pattern';
25-
import type { GetModels, SchemaDef } from '../../schema';
25+
import type { GetModels, ModelDef, SchemaDef, TypeDefDef } from '../../schema';
2626
import { type ClientImpl } from '../client-impl';
2727
import { TransactionIsolationLevel, type ClientContract } from '../contract';
2828
import { InternalError, QueryError, ZenStackError } from '../errors';
@@ -42,7 +42,7 @@ type MutationInfo<Schema extends SchemaDef> = {
4242
};
4343

4444
export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQueryExecutor {
45-
private readonly nameMapper: QueryNameMapper;
45+
private readonly nameMapper: QueryNameMapper | undefined;
4646

4747
constructor(
4848
private client: ClientImpl<Schema>,
@@ -54,7 +54,21 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
5454
private suppressMutationHooks: boolean = false,
5555
) {
5656
super(compiler, adapter, connectionProvider, plugins);
57-
this.nameMapper = new QueryNameMapper(client.$schema);
57+
58+
if (this.schemaHasMappedNames(client.$schema)) {
59+
this.nameMapper = new QueryNameMapper(client.$schema);
60+
}
61+
}
62+
63+
private schemaHasMappedNames(schema: Schema) {
64+
const hasMapAttr = (decl: ModelDef | TypeDefDef) => {
65+
if (decl.attributes?.some((attr) => attr.name === '@@map')) {
66+
return true;
67+
}
68+
return Object.values(decl.fields).some((field) => field.attributes?.some((attr) => attr.name === '@map'));
69+
};
70+
71+
return Object.values(schema.models).some(hasMapAttr) || Object.values(schema.typeDefs ?? []).some(hasMapAttr);
5872
}
5973

6074
private get kysely() {
@@ -170,7 +184,7 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
170184

171185
if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) {
172186
// no need to handle mutation hooks, just proceed
173-
const finalQuery = this.nameMapper.transformNode(query);
187+
const finalQuery = this.processNameMapping(query);
174188
compiled = this.compileQuery(finalQuery);
175189
if (parameters) {
176190
compiled = { ...compiled, parameters };
@@ -189,7 +203,7 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
189203
returning: ReturningNode.create([SelectionNode.createSelectAll()]),
190204
};
191205
}
192-
const finalQuery = this.nameMapper.transformNode(query);
206+
const finalQuery = this.processNameMapping(query);
193207
compiled = this.compileQuery(finalQuery);
194208
if (parameters) {
195209
compiled = { ...compiled, parameters };
@@ -239,6 +253,10 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
239253
return result;
240254
}
241255

256+
private processNameMapping<Node extends RootOperationNode>(query: Node): Node {
257+
return this.nameMapper?.transformNode(query) ?? query;
258+
}
259+
242260
private createClientForConnection(connection: DatabaseConnection, inTx: boolean) {
243261
const innerExecutor = this.withConnectionProvider(new SingleConnectionProvider(connection));
244262
innerExecutor.suppressMutationHooks = true;

0 commit comments

Comments
 (0)