Skip to content

Commit 36e1b77

Browse files
authored
feat: count and aggregate for delegate models (#115)
* feat: count and aggregate for delegate models * fixes
1 parent cefe223 commit 36e1b77

File tree

5 files changed

+195
-16
lines changed

5 files changed

+195
-16
lines changed

packages/runtime/src/client/crud/dialects/base.ts

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -818,20 +818,21 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
818818
return result;
819819
}
820820

821-
buildSelectField(query: SelectQueryBuilder<any, any, any>, model: string, modelAlias: string, field: string) {
821+
buildSelectField(
822+
query: SelectQueryBuilder<any, any, any>,
823+
model: string,
824+
modelAlias: string,
825+
field: string,
826+
): SelectQueryBuilder<any, any, any> {
822827
const fieldDef = requireField(this.schema, model, field);
823-
824828
if (fieldDef.computed) {
825829
// TODO: computed field from delegate base?
826830
return query.select((eb) => buildFieldRef(this.schema, model, field, this.options, eb).as(field));
827831
} else if (!fieldDef.originModel) {
828832
// regular field
829833
return query.select(sql.ref(`${modelAlias}.${field}`).as(field));
830834
} else {
831-
// field from delegate base, build a join
832-
let result = query;
833-
result = this.buildSelectField(result, fieldDef.originModel, fieldDef.originModel, field);
834-
return result;
835+
return this.buildSelectField(query, fieldDef.originModel, fieldDef.originModel, field);
835836
}
836837
}
837838

packages/runtime/src/client/crud/operations/aggregate.ts

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { ExpressionBuilder } from 'kysely';
12
import { sql } from 'kysely';
23
import { match } from 'ts-pattern';
34
import type { SchemaDef } from '../../../schema';
@@ -15,12 +16,33 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
1516
let query = this.kysely.selectFrom((eb) => {
1617
// nested query for filtering and pagination
1718

18-
// where
19-
let subQuery = eb
20-
.selectFrom(this.model)
21-
.selectAll(this.model as any) // TODO: check typing
19+
// table and where
20+
let subQuery = this.dialect
21+
.buildSelectModel(eb as ExpressionBuilder<any, any>, this.model)
2222
.where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where));
2323

24+
// select fields: collect fields from aggregation body
25+
const selectedFields: string[] = [];
26+
for (const [key, value] of Object.entries(parsedArgs)) {
27+
if (key.startsWith('_') && value && typeof value === 'object') {
28+
// select fields
29+
Object.entries(value)
30+
.filter(([field]) => field !== '_all')
31+
.filter(([, val]) => val === true)
32+
.forEach(([field]) => {
33+
if (!selectedFields.includes(field)) selectedFields.push(field);
34+
});
35+
}
36+
}
37+
if (selectedFields.length > 0) {
38+
for (const field of selectedFields) {
39+
subQuery = this.dialect.buildSelectField(subQuery, this.model, this.model, field);
40+
}
41+
} else {
42+
// if no field is explicitly selected, just do a `select 1` so `_count` works
43+
subQuery = subQuery.select(() => eb.lit(1).as('_all'));
44+
}
45+
2446
// skip & take
2547
const skip = parsedArgs?.skip;
2648
let take = parsedArgs?.take;

packages/runtime/src/client/crud/operations/count.ts

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { ExpressionBuilder } from 'kysely';
12
import { sql } from 'kysely';
23
import type { SchemaDef } from '../../../schema';
34
import { BaseOperationHandler } from './base';
@@ -9,15 +10,29 @@ export class CountOperationHandler<Schema extends SchemaDef> extends BaseOperati
910

1011
// parse args
1112
const parsedArgs = this.inputValidator.validateCountArgs(this.model, normalizedArgs);
13+
const subQueryName = '$sub';
1214

1315
let query = this.kysely.selectFrom((eb) => {
1416
// nested query for filtering and pagination
15-
let subQuery = eb
16-
.selectFrom(this.model)
17-
.selectAll()
17+
18+
let subQuery = this.dialect
19+
.buildSelectModel(eb as ExpressionBuilder<any, any>, this.model)
1820
.where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where));
21+
22+
if (parsedArgs?.select && typeof parsedArgs.select === 'object') {
23+
// select fields
24+
for (const [key, value] of Object.entries(parsedArgs.select)) {
25+
if (key !== '_all' && value === true) {
26+
subQuery = this.dialect.buildSelectField(subQuery, this.model, this.model, key);
27+
}
28+
}
29+
} else {
30+
// no field selection, just build a `select 1`
31+
subQuery = subQuery.select(() => eb.lit(1).as('_all'));
32+
}
33+
1934
subQuery = this.dialect.buildSkipTake(subQuery, parsedArgs?.skip, parsedArgs?.take);
20-
return subQuery.as('$sub');
35+
return subQuery.as(subQueryName);
2136
});
2237

2338
if (parsedArgs?.select && typeof parsedArgs.select === 'object') {
@@ -26,7 +41,7 @@ export class CountOperationHandler<Schema extends SchemaDef> extends BaseOperati
2641
Object.keys(parsedArgs.select!).map((key) =>
2742
key === '_all'
2843
? eb.cast(eb.fn.countAll(), 'integer').as('_all')
29-
: eb.cast(eb.fn.count(sql.ref(`$sub.${key}`)), 'integer').as(key),
44+
: eb.cast(eb.fn.count(sql.ref(`${subQueryName}.${key}`)), 'integer').as(key),
3045
),
3146
);
3247

packages/runtime/src/client/executor/name-mapper.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
263263
model = model ?? this.currentModel;
264264
const modelDef = requireModel(this.schema, model!);
265265
const scalarFields = Object.entries(modelDef.fields)
266-
.filter(([, fieldDef]) => !fieldDef.relation && !fieldDef.computed)
266+
.filter(([, fieldDef]) => !fieldDef.relation && !fieldDef.computed && !fieldDef.originModel)
267267
.map(([fieldName]) => fieldName);
268268
return scalarFields;
269269
}

packages/runtime/test/client-api/delegate.test.ts

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,5 +1070,146 @@ model Gallery {
10701070
await expect(client.asset.findMany()).toResolveWithLength(1);
10711071
});
10721072
});
1073+
1074+
describe('Delegate aggregation tests', () => {
1075+
beforeEach(async () => {
1076+
const u = await client.user.create({
1077+
data: {
1078+
id: 1,
1079+
1080+
},
1081+
});
1082+
await client.ratedVideo.create({
1083+
data: {
1084+
id: 1,
1085+
viewCount: 0,
1086+
duration: 100,
1087+
url: 'v1',
1088+
rating: 5,
1089+
owner: { connect: { id: u.id } },
1090+
user: { connect: { id: u.id } },
1091+
comments: { create: [{ content: 'c1' }, { content: 'c2' }] },
1092+
},
1093+
});
1094+
await client.ratedVideo.create({
1095+
data: {
1096+
id: 2,
1097+
viewCount: 2,
1098+
duration: 200,
1099+
url: 'v2',
1100+
rating: 3,
1101+
},
1102+
});
1103+
});
1104+
1105+
it('works with count', async () => {
1106+
await expect(
1107+
client.ratedVideo.count({
1108+
where: { rating: 5 },
1109+
}),
1110+
).resolves.toEqual(1);
1111+
await expect(
1112+
client.ratedVideo.count({
1113+
where: { duration: 100 },
1114+
}),
1115+
).resolves.toEqual(1);
1116+
await expect(
1117+
client.ratedVideo.count({
1118+
where: { viewCount: 2 },
1119+
}),
1120+
).resolves.toEqual(1);
1121+
1122+
await expect(
1123+
client.video.count({
1124+
where: { duration: 100 },
1125+
}),
1126+
).resolves.toEqual(1);
1127+
await expect(
1128+
client.asset.count({
1129+
where: { viewCount: { gt: 0 } },
1130+
}),
1131+
).resolves.toEqual(1);
1132+
1133+
// field selection
1134+
await expect(
1135+
client.ratedVideo.count({
1136+
select: { _all: true, viewCount: true, url: true, rating: true },
1137+
}),
1138+
).resolves.toMatchObject({
1139+
_all: 2,
1140+
viewCount: 2,
1141+
url: 2,
1142+
rating: 2,
1143+
});
1144+
await expect(
1145+
client.video.count({
1146+
select: { _all: true, viewCount: true, url: true },
1147+
}),
1148+
).resolves.toMatchObject({
1149+
_all: 2,
1150+
viewCount: 2,
1151+
url: 2,
1152+
});
1153+
await expect(
1154+
client.asset.count({
1155+
select: { _all: true, viewCount: true },
1156+
}),
1157+
).resolves.toMatchObject({
1158+
_all: 2,
1159+
viewCount: 2,
1160+
});
1161+
});
1162+
1163+
it('works with aggregate', async () => {
1164+
await expect(
1165+
client.ratedVideo.aggregate({
1166+
where: { viewCount: { gte: 0 }, duration: { gt: 0 }, rating: { gt: 0 } },
1167+
_avg: { viewCount: true, duration: true, rating: true },
1168+
_count: true,
1169+
}),
1170+
).resolves.toMatchObject({
1171+
_avg: {
1172+
viewCount: 1,
1173+
duration: 150,
1174+
rating: 4,
1175+
},
1176+
_count: 2,
1177+
});
1178+
await expect(
1179+
client.video.aggregate({
1180+
where: { viewCount: { gte: 0 }, duration: { gt: 0 } },
1181+
_avg: { viewCount: true, duration: true },
1182+
_count: true,
1183+
}),
1184+
).resolves.toMatchObject({
1185+
_avg: {
1186+
viewCount: 1,
1187+
duration: 150,
1188+
},
1189+
_count: 2,
1190+
});
1191+
await expect(
1192+
client.asset.aggregate({
1193+
where: { viewCount: { gte: 0 } },
1194+
_avg: { viewCount: true },
1195+
_count: true,
1196+
}),
1197+
).resolves.toMatchObject({
1198+
_avg: {
1199+
viewCount: 1,
1200+
},
1201+
_count: 2,
1202+
});
1203+
1204+
// just count
1205+
await expect(
1206+
client.ratedVideo.aggregate({
1207+
_count: true,
1208+
}),
1209+
).resolves.toMatchObject({
1210+
_count: 2,
1211+
});
1212+
});
1213+
});
10731214
},
10741215
);

0 commit comments

Comments
 (0)