Skip to content

Commit 192cc01

Browse files
authored
fix: using _count in relation selection (#143)
1 parent ceaaaf8 commit 192cc01

File tree

8 files changed

+145
-90
lines changed

8 files changed

+145
-90
lines changed

packages/runtime/src/client/crud-types.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,12 @@ export type SelectInput<
393393
[Key in NonRelationFields<Schema, Model>]?: true;
394394
} & (AllowRelation extends true ? IncludeInput<Schema, Model> : {}) & // relation fields
395395
// relation count
396-
(AllowCount extends true ? { _count?: SelectCount<Schema, Model> } : {});
396+
(AllowCount extends true
397+
? // _count is only allowed if the model has to-many relations
398+
HasToManyRelations<Schema, Model> extends true
399+
? { _count?: SelectCount<Schema, Model> }
400+
: {}
401+
: {});
397402

398403
type SelectCount<Schema extends SchemaDef, Model extends GetModels<Schema>> =
399404
| true
@@ -1181,4 +1186,10 @@ type NonOwnedRelationFields<Schema extends SchemaDef, Model extends GetModels<Sc
11811186
: Key]: true;
11821187
};
11831188

1189+
type HasToManyRelations<Schema extends SchemaDef, Model extends GetModels<Schema>> = keyof {
1190+
[Key in RelationFields<Schema, Model> as FieldIsArray<Schema, Model, Key> extends true ? Key : never]: true;
1191+
} extends never
1192+
? false
1193+
: true;
1194+
11841195
// #endregion

packages/runtime/src/client/crud/dialects/base.ts

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,56 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
847847
return query;
848848
}
849849

850+
buildCountJson(model: string, eb: ExpressionBuilder<any, any>, parentAlias: string, payload: any) {
851+
const modelDef = requireModel(this.schema, model);
852+
const toManyRelations = Object.entries(modelDef.fields).filter(([, field]) => field.relation && field.array);
853+
854+
const selections =
855+
payload === true
856+
? {
857+
select: toManyRelations.reduce(
858+
(acc, [field]) => {
859+
acc[field] = true;
860+
return acc;
861+
},
862+
{} as Record<string, boolean>,
863+
),
864+
}
865+
: payload;
866+
867+
const jsonObject: Record<string, Expression<any>> = {};
868+
869+
for (const [field, value] of Object.entries(selections.select)) {
870+
const fieldDef = requireField(this.schema, model, field);
871+
const fieldModel = fieldDef.type;
872+
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel);
873+
874+
// build a nested query to count the number of records in the relation
875+
let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));
876+
877+
// join conditions
878+
for (const [left, right] of joinPairs) {
879+
fieldCountQuery = fieldCountQuery.whereRef(left, '=', right);
880+
}
881+
882+
// merge _count filter
883+
if (
884+
value &&
885+
typeof value === 'object' &&
886+
'where' in value &&
887+
value.where &&
888+
typeof value.where === 'object'
889+
) {
890+
const filter = this.buildFilter(eb, fieldModel, fieldModel, value.where);
891+
fieldCountQuery = fieldCountQuery.where(filter);
892+
}
893+
894+
jsonObject[field] = fieldCountQuery;
895+
}
896+
897+
return this.buildJsonObject(eb, jsonObject);
898+
}
899+
850900
// #endregion
851901

852902
// #region utils

packages/runtime/src/client/crud/dialects/postgresql.ts

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
200200
relationField: string,
201201
eb: ExpressionBuilder<any, any>,
202202
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
203-
parentName: string,
203+
parentAlias: string,
204204
) {
205205
const relationModelDef = requireModel(this.schema, relationModel);
206206
const objArgs: Array<
@@ -238,14 +238,24 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
238238
objArgs.push(
239239
...Object.entries(payload.select)
240240
.filter(([, value]) => value)
241-
.map(([field]) => {
242-
const fieldDef = requireField(this.schema, relationModel, field);
243-
const fieldValue = fieldDef.relation
244-
? // reference the synthesized JSON field
245-
eb.ref(`${parentName}$${relationField}$${field}.$j`)
246-
: // reference a plain field
247-
buildFieldRef(this.schema, relationModel, field, this.options, eb);
248-
return [sql.lit(field), fieldValue];
241+
.map(([field, value]) => {
242+
if (field === '_count') {
243+
const subJson = this.buildCountJson(
244+
relationModel as GetModels<Schema>,
245+
eb,
246+
`${parentAlias}$${relationField}`,
247+
value,
248+
);
249+
return [sql.lit(field), subJson];
250+
} else {
251+
const fieldDef = requireField(this.schema, relationModel, field);
252+
const fieldValue = fieldDef.relation
253+
? // reference the synthesized JSON field
254+
eb.ref(`${parentAlias}$${relationField}$${field}.$j`)
255+
: // reference a plain field
256+
buildFieldRef(this.schema, relationModel, field, this.options, eb);
257+
return [sql.lit(field), fieldValue];
258+
}
249259
})
250260
.flatMap((v) => v),
251261
);
@@ -259,7 +269,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
259269
.map(([field]) => [
260270
sql.lit(field),
261271
// reference the synthesized JSON field
262-
eb.ref(`${parentName}$${relationField}$${field}.$j`),
272+
eb.ref(`${parentAlias}$${relationField}$${field}.$j`),
263273
])
264274
.flatMap((v) => v),
265275
);

packages/runtime/src/client/crud/dialects/sqlite.ts

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,14 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
6767
model: string,
6868
eb: ExpressionBuilder<any, any>,
6969
relationField: string,
70-
parentName: string,
70+
parentAlias: string,
7171
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
7272
) {
7373
const relationFieldDef = requireField(this.schema, model, relationField);
7474
const relationModel = relationFieldDef.type as GetModels<Schema>;
7575
const relationModelDef = requireModel(this.schema, relationModel);
7676

77-
const subQueryName = `${parentName}$${relationField}`;
77+
const subQueryName = `${parentAlias}$${relationField}`;
7878

7979
let tbl = eb.selectFrom(() => {
8080
let subQuery = this.buildSelectModel(eb, relationModel);
@@ -129,18 +129,18 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
129129
eb
130130
.selectFrom(m2m.joinTable)
131131
.select(`${m2m.joinTable}.${m2m.otherFkName}`)
132-
.whereRef(`${parentName}.${parentIds[0]}`, '=', `${m2m.joinTable}.${m2m.parentFkName}`),
132+
.whereRef(`${parentAlias}.${parentIds[0]}`, '=', `${m2m.joinTable}.${m2m.parentFkName}`),
133133
),
134134
);
135135
} else {
136136
const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs(this.schema, model, relationField);
137137
keyPairs.forEach(({ fk, pk }) => {
138138
if (ownedByModel) {
139139
// the parent model owns the fk
140-
subQuery = subQuery.whereRef(`${relationModel}.${pk}`, '=', `${parentName}.${fk}`);
140+
subQuery = subQuery.whereRef(`${relationModel}.${pk}`, '=', `${parentAlias}.${fk}`);
141141
} else {
142142
// the relation side owns the fk
143-
subQuery = subQuery.whereRef(`${relationModel}.${fk}`, '=', `${parentName}.${pk}`);
143+
subQuery = subQuery.whereRef(`${relationModel}.${fk}`, '=', `${parentAlias}.${pk}`);
144144
}
145145
});
146146
}
@@ -183,21 +183,31 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
183183
...Object.entries<any>(payload.select)
184184
.filter(([, value]) => value)
185185
.map(([field, value]) => {
186-
const fieldDef = requireField(this.schema, relationModel, field);
187-
if (fieldDef.relation) {
188-
const subJson = this.buildRelationJSON(
186+
if (field === '_count') {
187+
const subJson = this.buildCountJson(
189188
relationModel as GetModels<Schema>,
190189
eb,
191-
field,
192-
`${parentName}$${relationField}`,
190+
`${parentAlias}$${relationField}`,
193191
value,
194192
);
195-
return [sql.lit(field), subJson as ArgsType];
193+
return [sql.lit(field), subJson];
196194
} else {
197-
return [
198-
sql.lit(field),
199-
buildFieldRef(this.schema, relationModel, field, this.options, eb) as ArgsType,
200-
];
195+
const fieldDef = requireField(this.schema, relationModel, field);
196+
if (fieldDef.relation) {
197+
const subJson = this.buildRelationJSON(
198+
relationModel as GetModels<Schema>,
199+
eb,
200+
field,
201+
`${parentAlias}$${relationField}`,
202+
value,
203+
);
204+
return [sql.lit(field), subJson];
205+
} else {
206+
return [
207+
sql.lit(field),
208+
buildFieldRef(this.schema, relationModel, field, this.options, eb) as ArgsType,
209+
];
210+
}
201211
}
202212
})
203213
.flatMap((v) => v),
@@ -214,7 +224,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
214224
relationModel as GetModels<Schema>,
215225
eb,
216226
field,
217-
`${parentName}$${relationField}`,
227+
`${parentAlias}$${relationField}`,
218228
value,
219229
);
220230
return [sql.lit(field), subJson];

packages/runtime/src/client/crud/operations/base.ts

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import {
88
UpdateResult,
99
type Compilable,
1010
type IsolationLevel,
11-
type Expression as KyselyExpression,
1211
type QueryResult,
1312
type SelectQueryBuilder,
1413
} from 'kysely';
@@ -31,7 +30,6 @@ import { InternalError, NotFoundError, QueryError } from '../../errors';
3130
import type { ToKysely } from '../../query-builder';
3231
import {
3332
buildFieldRef,
34-
buildJoinPairs,
3533
ensureArray,
3634
extractIdFields,
3735
flattenCompoundUniqueFilters,
@@ -298,56 +296,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
298296
parentAlias: string,
299297
payload: any,
300298
) {
301-
const modelDef = requireModel(this.schema, model);
302-
const toManyRelations = Object.entries(modelDef.fields).filter(([, field]) => field.relation && field.array);
303-
304-
const selections =
305-
payload === true
306-
? {
307-
select: toManyRelations.reduce(
308-
(acc, [field]) => {
309-
acc[field] = true;
310-
return acc;
311-
},
312-
{} as Record<string, boolean>,
313-
),
314-
}
315-
: payload;
316-
317-
const eb = expressionBuilder<any, any>();
318-
const jsonObject: Record<string, KyselyExpression<any>> = {};
319-
320-
for (const [field, value] of Object.entries(selections.select)) {
321-
const fieldDef = requireField(this.schema, model, field);
322-
const fieldModel = fieldDef.type;
323-
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel);
324-
325-
// build a nested query to count the number of records in the relation
326-
let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));
327-
328-
// join conditions
329-
for (const [left, right] of joinPairs) {
330-
fieldCountQuery = fieldCountQuery.whereRef(left, '=', right);
331-
}
332-
333-
// merge _count filter
334-
if (
335-
value &&
336-
typeof value === 'object' &&
337-
'where' in value &&
338-
value.where &&
339-
typeof value.where === 'object'
340-
) {
341-
const filter = this.dialect.buildFilter(eb, fieldModel, fieldModel, value.where);
342-
fieldCountQuery = fieldCountQuery.where(filter);
343-
}
344-
345-
jsonObject[field] = fieldCountQuery;
346-
}
347-
348-
query = query.select((eb) => this.dialect.buildJsonObject(eb, jsonObject).as('_count'));
349-
350-
return query;
299+
return query.select((eb) => this.dialect.buildCountJson(model, eb, parentAlias, payload).as('_count'));
351300
}
352301

353302
private buildCursorFilter(

packages/runtime/src/client/crud/validator.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -544,17 +544,17 @@ export class InputValidator<Schema extends SchemaDef> {
544544
}
545545
}
546546

547-
const toManyRelations = Object.entries(modelDef.fields).filter(([, value]) => value.relation && value.array);
547+
const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array);
548548

549549
if (toManyRelations.length > 0) {
550550
fields['_count'] = z
551551
.union([
552552
z.literal(true),
553553
z.object(
554554
toManyRelations.reduce(
555-
(acc, [name, fieldDef]) => ({
555+
(acc, fieldDef) => ({
556556
...acc,
557-
[name]: z
557+
[fieldDef.name]: z
558558
.union([
559559
z.boolean(),
560560
z.object({

packages/runtime/src/client/query-utils.ts

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,23 +97,23 @@ export function getRelationForeignKeyFieldPairs(schema: SchemaDef, model: string
9797
}
9898

9999
export function isScalarField(schema: SchemaDef, model: string, field: string): boolean {
100-
const fieldDef = requireField(schema, model, field);
101-
return !fieldDef.relation && !fieldDef.foreignKeyFor;
100+
const fieldDef = getField(schema, model, field);
101+
return !fieldDef?.relation && !fieldDef?.foreignKeyFor;
102102
}
103103

104104
export function isForeignKeyField(schema: SchemaDef, model: string, field: string): boolean {
105-
const fieldDef = requireField(schema, model, field);
106-
return !!fieldDef.foreignKeyFor;
105+
const fieldDef = getField(schema, model, field);
106+
return !!fieldDef?.foreignKeyFor;
107107
}
108108

109109
export function isRelationField(schema: SchemaDef, model: string, field: string): boolean {
110-
const fieldDef = requireField(schema, model, field);
111-
return !!fieldDef.relation;
110+
const fieldDef = getField(schema, model, field);
111+
return !!fieldDef?.relation;
112112
}
113113

114114
export function isInheritedField(schema: SchemaDef, model: string, field: string): boolean {
115-
const fieldDef = requireField(schema, model, field);
116-
return !!fieldDef.originModel;
115+
const fieldDef = getField(schema, model, field);
116+
return !!fieldDef?.originModel;
117117
}
118118

119119
export function getUniqueFields(schema: SchemaDef, model: string) {

packages/runtime/test/client-api/find.test.ts

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,31 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider',
832832
_count: { posts: 2 },
833833
});
834834

835+
await expect(
836+
client.user.findUnique({
837+
where: { id: user1.id },
838+
select: {
839+
id: true,
840+
posts: {
841+
select: { _count: true },
842+
},
843+
},
844+
}),
845+
).resolves.toMatchObject({
846+
id: user1.id,
847+
posts: [{ _count: { comments: 0 } }, { _count: { comments: 0 } }],
848+
});
849+
850+
client.comment.findFirst({
851+
// @ts-expect-error Comment has no to-many relations to count
852+
select: { _count: true },
853+
});
854+
855+
client.post.findFirst({
856+
// @ts-expect-error Comment has no to-many relations to count
857+
select: { comments: { _count: true } },
858+
});
859+
835860
await expect(
836861
client.user.findUnique({
837862
where: { id: user1.id },

0 commit comments

Comments
 (0)