Skip to content

Commit 2704643

Browse files
committed
fix: using computed fields for orderBy and aggregations
1 parent 2e95aa5 commit 2704643

File tree

7 files changed

+81
-49
lines changed

7 files changed

+81
-49
lines changed

packages/runtime/src/client/crud/dialects/base.ts

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { invariant, isPlainObject } from '@zenstackhq/common-helpers';
22
import type { Expression, ExpressionBuilder, ExpressionWrapper, SqlBool, ValueNode } from 'kysely';
3-
import { sql, type SelectQueryBuilder } from 'kysely';
3+
import { expressionBuilder, sql, type SelectQueryBuilder } from 'kysely';
44
import { match, P } from 'ts-pattern';
55
import type { BuiltinType, DataSourceProviderType, FieldDef, GetModels, SchemaDef } from '../../../schema';
66
import { enumerate } from '../../../utils/enumerate';
@@ -95,11 +95,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
9595
result = this.and(eb, result, this.buildRelationFilter(eb, model, modelAlias, key, fieldDef, payload));
9696
} else {
9797
// if the field is from a base model, build a reference from that model
98-
const fieldRef = buildFieldRef(
99-
this.schema,
98+
const fieldRef = this.fieldRef(
10099
fieldDef.originModel ?? model,
101100
key,
102-
this.options,
103101
eb,
104102
fieldDef.originModel ?? modelAlias,
105103
);
@@ -727,7 +725,8 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
727725
for (const [k, v] of Object.entries<string>(value)) {
728726
invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`);
729727
result = result.orderBy(
730-
(eb) => aggregate(eb, sql.ref(`${modelAlias}.${k}`), field as AGGREGATE_OPERATORS),
728+
(eb) =>
729+
aggregate(eb, this.fieldRef(model, k, eb, modelAlias), field as AGGREGATE_OPERATORS),
731730
sql.raw(this.negateSort(v, negated)),
732731
);
733732
}
@@ -740,7 +739,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
740739
for (const [k, v] of Object.entries<string>(value)) {
741740
invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`);
742741
result = result.orderBy(
743-
(eb) => eb.fn.count(sql.ref(k)),
742+
(eb) => eb.fn.count(this.fieldRef(model, k, eb, modelAlias)),
744743
sql.raw(this.negateSort(v, negated)),
745744
);
746745
}
@@ -753,8 +752,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
753752
const fieldDef = requireField(this.schema, model, field);
754753

755754
if (!fieldDef.relation) {
755+
const fieldRef = this.fieldRef(model, field, expressionBuilder(), modelAlias);
756756
if (value === 'asc' || value === 'desc') {
757-
result = result.orderBy(sql.ref(`${modelAlias}.${field}`), this.negateSort(value, negated));
757+
result = result.orderBy(fieldRef, this.negateSort(value, negated));
758758
} else if (
759759
value &&
760760
typeof value === 'object' &&
@@ -764,7 +764,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
764764
(value.nulls === 'first' || value.nulls === 'last')
765765
) {
766766
result = result.orderBy(
767-
sql.ref(`${modelAlias}.${field}`),
767+
fieldRef,
768768
sql.raw(`${this.negateSort(value.sort, negated)} nulls ${value.nulls}`),
769769
);
770770
}
@@ -865,7 +865,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
865865
const fieldDef = requireField(this.schema, model, field);
866866
if (fieldDef.computed) {
867867
// TODO: computed field from delegate base?
868-
return query.select((eb) => buildFieldRef(this.schema, model, field, this.options, eb).as(field));
868+
return query.select((eb) => this.fieldRef(model, field, eb, modelAlias).as(field));
869869
} else if (!fieldDef.originModel) {
870870
// regular field
871871
return query.select(sql.ref(`${modelAlias}.${field}`).as(field));
@@ -993,6 +993,10 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
993993
return eb.not(this.and(eb, ...args));
994994
}
995995

996+
fieldRef(model: string, field: string, eb: ExpressionBuilder<any, any>, modelAlias?: string) {
997+
return buildFieldRef(this.schema, model, field, this.options, eb, modelAlias);
998+
}
999+
9961000
// #endregion
9971001

9981002
// #region abstract methods

packages/runtime/src/client/crud/dialects/postgresql.ts

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schem
1212
import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants';
1313
import type { FindArgs } from '../../crud-types';
1414
import {
15-
buildFieldRef,
1615
buildJoinPairs,
1716
getDelegateDescendantModels,
1817
getIdFields,
@@ -227,10 +226,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
227226
...Object.entries(relationModelDef.fields)
228227
.filter(([, value]) => !value.relation)
229228
.filter(([name]) => !(typeof payload === 'object' && (payload.omit as any)?.[name] === true))
230-
.map(([field]) => [
231-
sql.lit(field),
232-
buildFieldRef(this.schema, relationModel, field, this.options, eb),
233-
])
229+
.map(([field]) => [sql.lit(field), this.fieldRef(relationModel, field, eb)])
234230
.flatMap((v) => v),
235231
);
236232
} else if (payload.select) {
@@ -253,7 +249,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
253249
? // reference the synthesized JSON field
254250
eb.ref(`${parentAlias}$${relationField}$${field}.$j`)
255251
: // reference a plain field
256-
buildFieldRef(this.schema, relationModel, field, this.options, eb);
252+
this.fieldRef(relationModel, field, eb);
257253
return [sql.lit(field), fieldValue];
258254
}
259255
})

packages/runtime/src/client/crud/dialects/sqlite.ts

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import type { BuiltinType, GetModels, SchemaDef } from '../../../schema';
1313
import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants';
1414
import type { FindArgs } from '../../crud-types';
1515
import {
16-
buildFieldRef,
1716
getDelegateDescendantModels,
1817
getIdFields,
1918
getManyToManyRelation,
@@ -171,10 +170,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
171170
...Object.entries(relationModelDef.fields)
172171
.filter(([, value]) => !value.relation)
173172
.filter(([name]) => !(typeof payload === 'object' && (payload.omit as any)?.[name] === true))
174-
.map(([field]) => [
175-
sql.lit(field),
176-
buildFieldRef(this.schema, relationModel, field, this.options, eb),
177-
])
173+
.map(([field]) => [sql.lit(field), this.fieldRef(relationModel, field, eb)])
178174
.flatMap((v) => v),
179175
);
180176
} else if (payload.select) {
@@ -203,10 +199,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
203199
);
204200
return [sql.lit(field), subJson];
205201
} else {
206-
return [
207-
sql.lit(field),
208-
buildFieldRef(this.schema, relationModel, field, this.options, eb) as ArgsType,
209-
];
202+
return [sql.lit(field), this.fieldRef(relationModel, field, eb) as ArgsType];
210203
}
211204
}
212205
})

packages/runtime/src/client/crud/operations/base.ts

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ import type { FindArgs, SelectIncludeOmit, SortOrder, WhereInput } from '../../c
2929
import { InternalError, NotFoundError, QueryError } from '../../errors';
3030
import type { ToKysely } from '../../query-builder';
3131
import {
32-
buildFieldRef,
3332
ensureArray,
3433
extractIdFields,
3534
flattenCompoundUniqueFilters,
@@ -187,9 +186,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
187186
// make sure distinct fields are selected
188187
query = distinct.reduce(
189188
(acc, field) =>
190-
acc.select((eb) =>
191-
buildFieldRef(this.schema, model, field, this.options, eb).as(`$distinct$${field}`),
192-
),
189+
acc.select((eb) => this.dialect.fieldRef(model, field, eb).as(`$distinct$${field}`)),
193190
query,
194191
);
195192
}
@@ -1267,7 +1264,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
12671264
const key = Object.keys(payload)[0];
12681265
const value = this.dialect.transformPrimitive(payload[key!], fieldDef.type as BuiltinType, false);
12691266
const eb = expressionBuilder<any, any>();
1270-
const fieldRef = buildFieldRef(this.schema, model, field, this.options, eb);
1267+
const fieldRef = this.dialect.fieldRef(model, field, eb);
12711268

12721269
return match(key)
12731270
.with('set', () => value)
@@ -1290,7 +1287,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
12901287
const key = Object.keys(payload)[0];
12911288
const value = this.dialect.transformPrimitive(payload[key!], fieldDef.type as BuiltinType, true);
12921289
const eb = expressionBuilder<any, any>();
1293-
const fieldRef = buildFieldRef(this.schema, model, field, this.options, eb);
1290+
const fieldRef = this.dialect.fieldRef(model, field, eb);
12941291

12951292
return match(key)
12961293
.with('set', () => value)

packages/runtime/src/client/crud/operations/group-by.ts

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import { sql } from 'kysely';
1+
import { expressionBuilder } from 'kysely';
22
import { match } from 'ts-pattern';
33
import type { SchemaDef } from '../../../schema';
4-
import { getField } from '../../query-utils';
4+
import { aggregate, getField } from '../../query-utils';
55
import { BaseOperationHandler } from './base';
66

77
export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOperationHandler<Schema> {
@@ -44,9 +44,11 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
4444
return subQuery.as('$sub');
4545
});
4646

47+
const fieldRef = (field: string) => this.dialect.fieldRef(this.model, field, expressionBuilder(), '$sub');
48+
4749
// groupBy
4850
const bys = typeof parsedArgs.by === 'string' ? [parsedArgs.by] : (parsedArgs.by as string[]);
49-
query = query.groupBy(bys.map((by) => sql.ref(`$sub.${by}`)));
51+
query = query.groupBy(bys.map((by) => fieldRef(by)));
5052

5153
// orderBy
5254
if (parsedArgs.orderBy) {
@@ -59,7 +61,7 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
5961

6062
// select all by fields
6163
for (const by of bys) {
62-
query = query.select(() => sql.ref(`$sub.${by}`).as(by));
64+
query = query.select(() => fieldRef(by).as(by));
6365
}
6466

6567
// aggregations
@@ -77,7 +79,7 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
7779
);
7880
} else {
7981
query = query.select((eb) =>
80-
eb.cast(eb.fn.count(sql.ref(`$sub.${field}`)), 'integer').as(`${key}.${field}`),
82+
eb.cast(eb.fn.count(fieldRef(field)), 'integer').as(`${key}.${field}`),
8183
);
8284
}
8385
}
@@ -92,15 +94,7 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
9294
case '_min': {
9395
Object.entries(value).forEach(([field, val]) => {
9496
if (val === true) {
95-
query = query.select((eb) => {
96-
const fn = match(key)
97-
.with('_sum', () => eb.fn.sum)
98-
.with('_avg', () => eb.fn.avg)
99-
.with('_max', () => eb.fn.max)
100-
.with('_min', () => eb.fn.min)
101-
.exhaustive();
102-
return fn(sql.ref(`$sub.${field}`)).as(`${key}.${field}`);
103-
});
97+
query = query.select((eb) => aggregate(eb, fieldRef(field), key).as(`${key}.${field}`));
10498
}
10599
});
106100
break;

packages/runtime/src/client/query-utils.ts

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,11 +318,7 @@ export function getDelegateDescendantModels(
318318
return [...collected];
319319
}
320320

321-
export function aggregate(
322-
eb: ExpressionBuilder<any, any>,
323-
expr: Expression<any>,
324-
op: AGGREGATE_OPERATORS,
325-
): Expression<any> {
321+
export function aggregate(eb: ExpressionBuilder<any, any>, expr: Expression<any>, op: AGGREGATE_OPERATORS) {
326322
return match(op)
327323
.with('_count', () => eb.fn.count(expr))
328324
.with('_sum', () => eb.fn.sum(expr))

packages/runtime/test/client-api/computed-fields.test.ts

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,58 @@ model User {
3636
).resolves.toMatchObject({
3737
upperName: 'ALEX',
3838
});
39+
40+
await expect(
41+
db.user.findFirst({
42+
where: { upperName: 'ALEX' },
43+
}),
44+
).resolves.toMatchObject({
45+
upperName: 'ALEX',
46+
});
47+
48+
await expect(
49+
db.user.findFirst({
50+
where: { upperName: 'Alex' },
51+
}),
52+
).toResolveNull();
53+
54+
await expect(
55+
db.user.findFirst({
56+
orderBy: { upperName: 'desc' },
57+
}),
58+
).resolves.toMatchObject({
59+
upperName: 'ALEX',
60+
});
61+
62+
await expect(
63+
db.user.findFirst({
64+
orderBy: { upperName: 'desc' },
65+
take: -1,
66+
}),
67+
).resolves.toMatchObject({
68+
upperName: 'ALEX',
69+
});
70+
71+
await expect(
72+
db.user.aggregate({
73+
_count: { upperName: true },
74+
}),
75+
).resolves.toMatchObject({
76+
_count: { upperName: 1 },
77+
});
78+
79+
await expect(
80+
db.user.groupBy({
81+
by: ['upperName'],
82+
_count: { upperName: true },
83+
_max: { upperName: true },
84+
}),
85+
).resolves.toEqual([
86+
expect.objectContaining({
87+
_count: { upperName: 1 },
88+
_max: { upperName: 'ALEX' },
89+
}),
90+
]);
3991
});
4092

4193
it('is typed correctly for non-optional fields', async () => {

0 commit comments

Comments
 (0)