Skip to content

Commit 4f3c15b

Browse files
committed
feat(policy): many-to-many and self-relation support
1 parent e1bda19 commit 4f3c15b

File tree

15 files changed

+2812
-52
lines changed

15 files changed

+2812
-52
lines changed

TODO.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,8 @@
9999
- [ ] Validation
100100
- [ ] Access Policy
101101
- [ ] Short-circuit pre-create check for scalar-field only policies
102-
- [ ] Inject "replace into"
103-
- [ ] Inject "on conflict do update"
104-
- [ ] Inject "insert into select from"
102+
- [x] Inject "on conflict do update"
103+
- [x] `check` function
105104
- [x] Migration
106105
- [ ] Databases
107106
- [x] SQLite

packages/common-helpers/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ export * from './param-case';
44
export * from './sleep';
55
export * from './tiny-invariant';
66
export * from './upper-case-first';
7+
export * from './zip';

packages/common-helpers/src/zip.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
/**
2+
* Zipped two arrays into an array of tuples.
3+
*/
4+
export function zip<T, U>(arr1: T[], arr2: U[]): Array<[T, U]> {
5+
const length = Math.min(arr1.length, arr2.length);
6+
const result: Array<[T, U]> = [];
7+
for (let i = 0; i < length; i++) {
8+
result.push([arr1[i]!, arr2[i]!]);
9+
}
10+
return result;
11+
}

packages/runtime/src/client/crud-types.ts

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ export type OmitInput<Schema extends SchemaDef, Model extends GetModels<Schema>>
452452

453453
export type SelectIncludeOmit<Schema extends SchemaDef, Model extends GetModels<Schema>, AllowCount extends boolean> = {
454454
select?: SelectInput<Schema, Model, AllowCount, boolean>;
455-
include?: IncludeInput<Schema, Model>;
455+
include?: IncludeInput<Schema, Model, AllowCount>;
456456
omit?: OmitInput<Schema, Model>;
457457
};
458458

@@ -463,14 +463,7 @@ export type SelectInput<
463463
AllowRelation extends boolean = true,
464464
> = {
465465
[Key in NonRelationFields<Schema, Model>]?: boolean;
466-
} & (AllowRelation extends true ? IncludeInput<Schema, Model> : {}) & // relation fields
467-
// relation count
468-
(AllowCount extends true
469-
? // _count is only allowed if the model has to-many relations
470-
HasToManyRelations<Schema, Model> extends true
471-
? { _count?: SelectCount<Schema, Model> }
472-
: {}
473-
: {});
466+
} & (AllowRelation extends true ? IncludeInput<Schema, Model, AllowCount> : {});
474467

475468
type SelectCount<Schema extends SchemaDef, Model extends GetModels<Schema>> =
476469
| boolean
@@ -484,7 +477,11 @@ type SelectCount<Schema extends SchemaDef, Model extends GetModels<Schema>> =
484477
};
485478
};
486479

487-
export type IncludeInput<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
480+
export type IncludeInput<
481+
Schema extends SchemaDef,
482+
Model extends GetModels<Schema>,
483+
AllowCount extends boolean = true,
484+
> = {
488485
[Key in RelationFields<Schema, Model>]?:
489486
| boolean
490487
| FindArgs<
@@ -498,7 +495,12 @@ export type IncludeInput<Schema extends SchemaDef, Model extends GetModels<Schem
498495
? true
499496
: false
500497
>;
501-
};
498+
} & (AllowCount extends true
499+
? // _count is only allowed if the model has to-many relations
500+
HasToManyRelations<Schema, Model> extends true
501+
? { _count?: SelectCount<Schema, Model> }
502+
: {}
503+
: {});
502504

503505
export type Subset<T, U> = {
504506
[key in keyof T]: key extends keyof U ? T[key] : never;
@@ -674,7 +676,7 @@ export type FindUniqueArgs<Schema extends SchemaDef, Model extends GetModels<Sch
674676

675677
export type CreateArgs<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
676678
data: CreateInput<Schema, Model>;
677-
select?: SelectInput<Schema, Model, true>;
679+
select?: SelectInput<Schema, Model>;
678680
include?: IncludeInput<Schema, Model>;
679681
omit?: OmitInput<Schema, Model>;
680682
};
@@ -813,7 +815,7 @@ type NestedCreateManyInput<
813815
export type UpdateArgs<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
814816
data: UpdateInput<Schema, Model>;
815817
where: WhereUniqueInput<Schema, Model>;
816-
select?: SelectInput<Schema, Model, true>;
818+
select?: SelectInput<Schema, Model>;
817819
include?: IncludeInput<Schema, Model>;
818820
omit?: OmitInput<Schema, Model>;
819821
};
@@ -841,7 +843,7 @@ export type UpsertArgs<Schema extends SchemaDef, Model extends GetModels<Schema>
841843
create: CreateInput<Schema, Model>;
842844
update: UpdateInput<Schema, Model>;
843845
where: WhereUniqueInput<Schema, Model>;
844-
select?: SelectInput<Schema, Model, true>;
846+
select?: SelectInput<Schema, Model>;
845847
include?: IncludeInput<Schema, Model>;
846848
omit?: OmitInput<Schema, Model>;
847849
};
@@ -958,7 +960,7 @@ type ToOneRelationUpdateInput<
958960

959961
export type DeleteArgs<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
960962
where: WhereUniqueInput<Schema, Model>;
961-
select?: SelectInput<Schema, Model, true>;
963+
select?: SelectInput<Schema, Model>;
962964
include?: IncludeInput<Schema, Model>;
963965
omit?: OmitInput<Schema, Model>;
964966
};

packages/runtime/src/client/crud/dialects/base-dialect.ts

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,14 +1048,29 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
10481048
for (const [field, value] of Object.entries(selections.select)) {
10491049
const fieldDef = requireField(this.schema, model, field);
10501050
const fieldModel = fieldDef.type;
1051-
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel);
1052-
1053-
// build a nested query to count the number of records in the relation
1054-
let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));
1051+
let fieldCountQuery: SelectQueryBuilder<any, any, any>;
10551052

10561053
// join conditions
1057-
for (const [left, right] of joinPairs) {
1058-
fieldCountQuery = fieldCountQuery.whereRef(left, '=', right);
1054+
const m2m = getManyToManyRelation(this.schema, model, field);
1055+
if (m2m) {
1056+
// many-to-many relation, count the join table
1057+
fieldCountQuery = eb
1058+
.selectFrom(m2m.joinTable)
1059+
.select(eb.fn.countAll().as(`_count$${field}`))
1060+
.whereRef(
1061+
eb.ref(`${parentAlias}.${m2m.parentPKName}`),
1062+
'=',
1063+
eb.ref(`${m2m.joinTable}.${m2m.parentFkName}`),
1064+
);
1065+
} else {
1066+
// build a nested query to count the number of records in the relation
1067+
fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));
1068+
1069+
// join conditions
1070+
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel);
1071+
for (const [left, right] of joinPairs) {
1072+
fieldCountQuery = fieldCountQuery.whereRef(left, '=', right);
1073+
}
10591074
}
10601075

10611076
// merge _count filter

packages/runtime/src/client/crud/operations/base.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
475475
entity: rightEntity,
476476
},
477477
].sort((a, b) =>
478-
// the implement m2m join table's "A", "B" fk fields' order is determined
478+
// the implicit m2m join table's "A", "B" fk fields' order is determined
479479
// by model name's sort order, and when identical (for self-relations),
480480
// field name's sort order
481481
a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field),

packages/runtime/src/client/crud/validator.ts

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@ import Decimal from 'decimal.js';
33
import stableStringify from 'json-stable-stringify';
44
import { match, P } from 'ts-pattern';
55
import { z, ZodType } from 'zod';
6-
import { type BuiltinType, type EnumDef, type FieldDef, type GetModels, type SchemaDef } from '../../schema';
6+
import {
7+
type BuiltinType,
8+
type EnumDef,
9+
type FieldDef,
10+
type GetModels,
11+
type ModelDef,
12+
type SchemaDef,
13+
} from '../../schema';
714
import { enumerate } from '../../utils/enumerate';
815
import { extractFields } from '../../utils/object-utils';
916
import { formatError } from '../../utils/zod-utils';
@@ -595,10 +602,18 @@ export class InputValidator<Schema extends SchemaDef> {
595602
}
596603
}
597604

598-
const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array);
605+
const _countSchema = this.makeCountSelectionSchema(modelDef);
606+
if (_countSchema) {
607+
fields['_count'] = _countSchema;
608+
}
609+
610+
return z.strictObject(fields);
611+
}
599612

613+
private makeCountSelectionSchema(modelDef: ModelDef) {
614+
const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array);
600615
if (toManyRelations.length > 0) {
601-
fields['_count'] = z
616+
return z
602617
.union([
603618
z.literal(true),
604619
z.strictObject({
@@ -621,9 +636,9 @@ export class InputValidator<Schema extends SchemaDef> {
621636
}),
622637
])
623638
.optional();
639+
} else {
640+
return undefined;
624641
}
625-
626-
return z.strictObject(fields);
627642
}
628643

629644
private makeRelationSelectIncludeSchema(fieldDef: FieldDef) {
@@ -677,6 +692,11 @@ export class InputValidator<Schema extends SchemaDef> {
677692
}
678693
}
679694

695+
const _countSchema = this.makeCountSelectionSchema(modelDef);
696+
if (_countSchema) {
697+
fields['_count'] = _countSchema;
698+
}
699+
680700
return z.strictObject(fields);
681701
}
682702

packages/runtime/src/client/query-utils.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { invariant } from '@zenstackhq/common-helpers';
12
import type { Expression, ExpressionBuilder, ExpressionWrapper } from 'kysely';
23
import { match } from 'ts-pattern';
34
import { ExpressionUtils, type FieldDef, type GetModels, type ModelDef, type SchemaDef } from '../schema';
@@ -259,11 +260,18 @@ export function getManyToManyRelation(schema: SchemaDef, model: string, field: s
259260
orderedFK = sortedFieldNames[0] === field ? ['A', 'B'] : ['B', 'A'];
260261
}
261262

263+
const modelIdFields = requireIdFields(schema, model);
264+
invariant(modelIdFields.length === 1, 'Only single-field ID is supported for many-to-many relation');
265+
const otherIdFields = requireIdFields(schema, fieldDef.type);
266+
invariant(otherIdFields.length === 1, 'Only single-field ID is supported for many-to-many relation');
267+
262268
return {
263269
parentFkName: orderedFK[0],
270+
parentPKName: modelIdFields[0]!,
264271
otherModel: fieldDef.type,
265272
otherField: fieldDef.relation.opposite,
266273
otherFkName: orderedFK[1],
274+
otherPKName: otherIdFields[0]!,
267275
joinTable: fieldDef.relation.name
268276
? `_${fieldDef.relation.name}`
269277
: `_${sortedModelNames[0]}To${sortedModelNames[1]}`,

0 commit comments

Comments
 (0)