Skip to content

Commit 2e95aa5

Browse files
authored
fix: support using aggregations inside orderBy and having of groupBy (#152)
* fix: support using aggregations inside `orderBy` and `having` of `groupBy` * update * update
1 parent 8833aa7 commit 2e95aa5

File tree

12 files changed

+445
-153
lines changed

12 files changed

+445
-153
lines changed

packages/language/src/validators/datamodel-validator.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ export default class DataModelValidator implements AstValidator<DataModel> {
107107
if (field.type.array && !isDataModel(field.type.reference?.ref)) {
108108
const provider = this.getDataSourceProvider(AstUtils.getContainerOfType(field, isModel)!);
109109
if (provider === 'sqlite') {
110-
accept('error', `Array type is not supported for "${provider}" provider.`, { node: field.type });
110+
accept('error', `List type is not supported for "${provider}" provider.`, { node: field.type });
111111
}
112112
}
113113

packages/runtime/src/client/client-impl.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,12 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
556556
},
557557

558558
groupBy: (args: unknown) => {
559-
return createPromise('groupBy', args, new GroupByOperationHandler<Schema>(client, model, inputValidator));
559+
return createPromise(
560+
'groupBy',
561+
args,
562+
new GroupByOperationHandler<Schema>(client, model, inputValidator),
563+
true,
564+
);
560565
},
561566
} as ModelOperations<Schema, Model>;
562567
}

packages/runtime/src/client/constants.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,14 @@ export const TRANSACTION_UNSUPPORTED_METHODS = ['$transaction', '$disconnect', '
1717
* Prefix for JSON field used to store joined delegate rows.
1818
*/
1919
export const DELEGATE_JOINED_FIELD_PREFIX = '$delegate$';
20+
21+
/**
22+
* Logical combinators used in filters.
23+
*/
24+
export const LOGICAL_COMBINATORS = ['AND', 'OR', 'NOT'] as const;
25+
26+
/**
27+
* Aggregation operators.
28+
*/
29+
export const AGGREGATE_OPERATORS = ['_count', '_sum', '_avg', '_min', '_max'] as const;
30+
export type AGGREGATE_OPERATORS = (typeof AGGREGATE_OPERATORS)[number];

packages/runtime/src/client/contract.ts

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -752,17 +752,18 @@ export type ModelOperations<Schema extends SchemaDef, Model extends GetModels<Sc
752752
* _count: true
753753
* }); // result: `Array<{ country: string, city: string, _count: number }>`
754754
*
755-
* // group by with sorting, the `orderBy` fields must be in the `by` list
755+
* // group by with sorting, the `orderBy` fields must be either an aggregation
756+
* // or a field used in the `by` list
756757
* await db.profile.groupBy({
757758
* by: 'country',
758759
* orderBy: { country: 'desc' }
759760
* });
760761
*
761-
* // group by with having (post-aggregation filter), the `having` fields must
762-
* // be in the `by` list
762+
* // group by with having (post-aggregation filter), the fields used in `having` must
763+
* // be either an aggregation, or a field used in the `by` list
763764
* await db.profile.groupBy({
764765
* by: 'country',
765-
* having: { country: 'US' }
766+
* having: { country: 'US', age: { _avg: { gte: 18 } } }
766767
* });
767768
*/
768769
groupBy<T extends GroupByArgs<Schema, Model>>(

packages/runtime/src/client/crud-types.ts

Lines changed: 104 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ export type WhereInput<
209209
Schema extends SchemaDef,
210210
Model extends GetModels<Schema>,
211211
ScalarOnly extends boolean = false,
212+
WithAggregations extends boolean = false,
212213
> = {
213214
[Key in GetModelFields<Schema, Model> as ScalarOnly extends true
214215
? Key extends RelationFields<Schema, Model>
@@ -223,7 +224,12 @@ export type WhereInput<
223224
: FieldIsArray<Schema, Model, Key> extends true
224225
? ArrayFilter<GetModelFieldType<Schema, Model, Key>>
225226
: // primitive
226-
PrimitiveFilter<Schema, GetModelFieldType<Schema, Model, Key>, ModelFieldIsOptional<Schema, Model, Key>>;
227+
PrimitiveFilter<
228+
Schema,
229+
GetModelFieldType<Schema, Model, Key>,
230+
ModelFieldIsOptional<Schema, Model, Key>,
231+
WithAggregations
232+
>;
227233
} & {
228234
$expr?: (eb: ExpressionBuilder<ToKyselySchema<Schema>, Model>) => OperandExpression<SqlBool>;
229235
} & {
@@ -249,38 +255,56 @@ type ArrayFilter<T extends string> = {
249255
isEmpty?: boolean;
250256
};
251257

252-
type PrimitiveFilter<Schema extends SchemaDef, T extends string, Nullable extends boolean> = T extends 'String'
253-
? StringFilter<Schema, Nullable>
258+
type PrimitiveFilter<
259+
Schema extends SchemaDef,
260+
T extends string,
261+
Nullable extends boolean,
262+
WithAggregations extends boolean,
263+
> = T extends 'String'
264+
? StringFilter<Schema, Nullable, WithAggregations>
254265
: T extends 'Int' | 'Float' | 'Decimal' | 'BigInt'
255-
? NumberFilter<Schema, T, Nullable>
266+
? NumberFilter<Schema, T, Nullable, WithAggregations>
256267
: T extends 'Boolean'
257-
? BooleanFilter<Nullable>
268+
? BooleanFilter<Schema, Nullable, WithAggregations>
258269
: T extends 'DateTime'
259-
? DateTimeFilter<Schema, Nullable>
270+
? DateTimeFilter<Schema, Nullable, WithAggregations>
260271
: T extends 'Bytes'
261-
? BytesFilter<Nullable>
272+
? BytesFilter<Schema, Nullable, WithAggregations>
262273
: T extends 'Json'
263274
? 'Not implemented yet' // TODO: Json filter
264275
: never;
265276

266-
type CommonPrimitiveFilter<Schema extends SchemaDef, DataType, T extends BuiltinType, Nullable extends boolean> = {
277+
type CommonPrimitiveFilter<
278+
Schema extends SchemaDef,
279+
DataType,
280+
T extends BuiltinType,
281+
Nullable extends boolean,
282+
WithAggregations extends boolean,
283+
> = {
267284
equals?: NullableIf<DataType, Nullable>;
268285
in?: DataType[];
269286
notIn?: DataType[];
270287
lt?: DataType;
271288
lte?: DataType;
272289
gt?: DataType;
273290
gte?: DataType;
274-
not?: PrimitiveFilter<Schema, T, Nullable>;
291+
not?: PrimitiveFilter<Schema, T, Nullable, WithAggregations>;
275292
};
276293

277-
export type StringFilter<Schema extends SchemaDef, Nullable extends boolean> =
294+
export type StringFilter<Schema extends SchemaDef, Nullable extends boolean, WithAggregations extends boolean> =
278295
| NullableIf<string, Nullable>
279-
| (CommonPrimitiveFilter<Schema, string, 'String', Nullable> & {
296+
| (CommonPrimitiveFilter<Schema, string, 'String', Nullable, WithAggregations> & {
280297
contains?: string;
281298
startsWith?: string;
282299
endsWith?: string;
283-
} & (ProviderSupportsCaseSensitivity<Schema> extends true
300+
} & (WithAggregations extends true
301+
? {
302+
_count?: NumberFilter<Schema, 'Int', false, false>;
303+
_min?: StringFilter<Schema, false, false>;
304+
_max?: StringFilter<Schema, false, false>;
305+
}
306+
: {}) &
307+
(ProviderSupportsCaseSensitivity<Schema> extends true
284308
? {
285309
mode?: 'default' | 'insensitive';
286310
}
@@ -290,27 +314,58 @@ export type NumberFilter<
290314
Schema extends SchemaDef,
291315
T extends 'Int' | 'Float' | 'Decimal' | 'BigInt',
292316
Nullable extends boolean,
293-
> = NullableIf<number | bigint, Nullable> | CommonPrimitiveFilter<Schema, number, T, Nullable>;
317+
WithAggregations extends boolean,
318+
> =
319+
| NullableIf<number | bigint, Nullable>
320+
| (CommonPrimitiveFilter<Schema, number, T, Nullable, WithAggregations> &
321+
(WithAggregations extends true
322+
? {
323+
_count?: NumberFilter<Schema, 'Int', false, false>;
324+
_avg?: NumberFilter<Schema, T, false, false>;
325+
_sum?: NumberFilter<Schema, T, false, false>;
326+
_min?: NumberFilter<Schema, T, false, false>;
327+
_max?: NumberFilter<Schema, T, false, false>;
328+
}
329+
: {}));
294330

295-
export type DateTimeFilter<Schema extends SchemaDef, Nullable extends boolean> =
331+
export type DateTimeFilter<Schema extends SchemaDef, Nullable extends boolean, WithAggregations extends boolean> =
296332
| NullableIf<Date | string, Nullable>
297-
| CommonPrimitiveFilter<Schema, Date | string, 'DateTime', Nullable>;
333+
| (CommonPrimitiveFilter<Schema, Date | string, 'DateTime', Nullable, WithAggregations> &
334+
(WithAggregations extends true
335+
? {
336+
_count?: NumberFilter<Schema, 'Int', false, false>;
337+
_min?: DateTimeFilter<Schema, false, false>;
338+
_max?: DateTimeFilter<Schema, false, false>;
339+
}
340+
: {}));
298341

299-
export type BytesFilter<Nullable extends boolean> =
342+
export type BytesFilter<Schema extends SchemaDef, Nullable extends boolean, WithAggregations extends boolean> =
300343
| NullableIf<Uint8Array | Buffer, Nullable>
301-
| {
344+
| ({
302345
equals?: NullableIf<Uint8Array, Nullable>;
303346
in?: Uint8Array[];
304347
notIn?: Uint8Array[];
305-
not?: BytesFilter<Nullable>;
306-
};
348+
not?: BytesFilter<Schema, Nullable, WithAggregations>;
349+
} & (WithAggregations extends true
350+
? {
351+
_count?: NumberFilter<Schema, 'Int', false, false>;
352+
_min?: BytesFilter<Schema, false, false>;
353+
_max?: BytesFilter<Schema, false, false>;
354+
}
355+
: {}));
307356

308-
export type BooleanFilter<Nullable extends boolean> =
357+
export type BooleanFilter<Schema extends SchemaDef, Nullable extends boolean, WithAggregations extends boolean> =
309358
| NullableIf<boolean, Nullable>
310-
| {
359+
| ({
311360
equals?: NullableIf<boolean, Nullable>;
312-
not?: BooleanFilter<Nullable>;
313-
};
361+
not?: BooleanFilter<Schema, Nullable, WithAggregations>;
362+
} & (WithAggregations extends true
363+
? {
364+
_count?: NumberFilter<Schema, 'Int', false, false>;
365+
_min?: BooleanFilter<Schema, false, false>;
366+
_max?: BooleanFilter<Schema, false, false>;
367+
}
368+
: {}));
314369

315370
export type SortOrder = 'asc' | 'desc';
316371
export type NullsOrder = 'first' | 'last';
@@ -340,14 +395,15 @@ export type OrderBy<
340395
: {}) &
341396
(WithAggregation extends true
342397
? {
343-
_count?: OrderBy<Schema, Model, WithRelation, false>;
398+
_count?: OrderBy<Schema, Model, false, false>;
399+
_min?: MinMaxInput<Schema, Model, SortOrder>;
400+
_max?: MinMaxInput<Schema, Model, SortOrder>;
344401
} & (NumericFields<Schema, Model> extends never
345402
? {}
346403
: {
347-
_avg?: SumAvgInput<Schema, Model>;
348-
_sum?: SumAvgInput<Schema, Model>;
349-
_min?: MinMaxInput<Schema, Model>;
350-
_max?: MinMaxInput<Schema, Model>;
404+
// aggregations specific to numeric fields
405+
_avg?: SumAvgInput<Schema, Model, SortOrder>;
406+
_sum?: SumAvgInput<Schema, Model, SortOrder>;
351407
})
352408
: {});
353409

@@ -931,13 +987,13 @@ export type AggregateArgs<Schema extends SchemaDef, Model extends GetModels<Sche
931987
orderBy?: OrArray<OrderBy<Schema, Model, true, false>>;
932988
} & {
933989
_count?: true | CountAggregateInput<Schema, Model>;
990+
_min?: MinMaxInput<Schema, Model, true>;
991+
_max?: MinMaxInput<Schema, Model, true>;
934992
} & (NumericFields<Schema, Model> extends never
935993
? {}
936994
: {
937-
_avg?: SumAvgInput<Schema, Model>;
938-
_sum?: SumAvgInput<Schema, Model>;
939-
_min?: MinMaxInput<Schema, Model>;
940-
_max?: MinMaxInput<Schema, Model>;
995+
_avg?: SumAvgInput<Schema, Model, true>;
996+
_sum?: SumAvgInput<Schema, Model, true>;
941997
});
942998

943999
type NumericFields<Schema extends SchemaDef, Model extends GetModels<Schema>> = keyof {
@@ -952,16 +1008,16 @@ type NumericFields<Schema extends SchemaDef, Model extends GetModels<Schema>> =
9521008
: never]: GetModelField<Schema, Model, Key>;
9531009
};
9541010

955-
type SumAvgInput<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
956-
[Key in NumericFields<Schema, Model>]?: true;
1011+
type SumAvgInput<Schema extends SchemaDef, Model extends GetModels<Schema>, ValueType> = {
1012+
[Key in NumericFields<Schema, Model>]?: ValueType;
9571013
};
9581014

959-
type MinMaxInput<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
1015+
type MinMaxInput<Schema extends SchemaDef, Model extends GetModels<Schema>, ValueType> = {
9601016
[Key in GetModelFields<Schema, Model> as FieldIsArray<Schema, Model, Key> extends true
9611017
? never
9621018
: FieldIsRelation<Schema, Model, Key> extends true
9631019
? never
964-
: Key]?: true;
1020+
: Key]?: ValueType;
9651021
};
9661022

9671023
export type AggregateResult<
@@ -1006,21 +1062,28 @@ type AggCommonOutput<Input> = Input extends true
10061062

10071063
// #region GroupBy
10081064

1065+
type GroupByHaving<Schema extends SchemaDef, Model extends GetModels<Schema>> = Omit<
1066+
WhereInput<Schema, Model, true, true>,
1067+
'$expr'
1068+
>;
1069+
10091070
export type GroupByArgs<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
10101071
where?: WhereInput<Schema, Model>;
10111072
orderBy?: OrArray<OrderBy<Schema, Model, false, true>>;
10121073
by: NonRelationFields<Schema, Model> | NonEmptyArray<NonRelationFields<Schema, Model>>;
1013-
having?: WhereInput<Schema, Model, true>;
1074+
having?: GroupByHaving<Schema, Model>;
10141075
take?: number;
10151076
skip?: number;
1077+
// aggregations
10161078
_count?: true | CountAggregateInput<Schema, Model>;
1079+
_min?: MinMaxInput<Schema, Model, true>;
1080+
_max?: MinMaxInput<Schema, Model, true>;
10171081
} & (NumericFields<Schema, Model> extends never
10181082
? {}
10191083
: {
1020-
_avg?: SumAvgInput<Schema, Model>;
1021-
_sum?: SumAvgInput<Schema, Model>;
1022-
_min?: MinMaxInput<Schema, Model>;
1023-
_max?: MinMaxInput<Schema, Model>;
1084+
// aggregations specific to numeric fields
1085+
_avg?: SumAvgInput<Schema, Model, true>;
1086+
_sum?: SumAvgInput<Schema, Model, true>;
10241087
});
10251088

10261089
export type GroupByResult<

0 commit comments

Comments
 (0)