Skip to content

Commit 514f8f9

Browse files
authored
feat(policy): many-to-many and self-relation support (#256)
* feat(policy): many-to-many and self-relation support * address PR comments, refactor m2m check * extra fixes and tests
1 parent e1bda19 commit 514f8f9

File tree

19 files changed

+3346
-106
lines changed

19 files changed

+3346
-106
lines changed

TODO.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,8 @@
9999
- [ ] Validation
100100
- [ ] Access Policy
101101
- [ ] Short-circuit pre-create check for scalar-field only policies
102-
- [ ] Inject "replace into"
103-
- [ ] Inject "on conflict do update"
104-
- [ ] Inject "insert into select from"
102+
- [x] Inject "on conflict do update"
103+
- [x] `check` function
105104
- [x] Migration
106105
- [ ] Databases
107106
- [x] SQLite

packages/common-helpers/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ export * from './param-case';
44
export * from './sleep';
55
export * from './tiny-invariant';
66
export * from './upper-case-first';
7+
export * from './zip';

packages/common-helpers/src/zip.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
/**
2+
* Zips two arrays into an array of tuples.
3+
*/
4+
export function zip<T, U>(arr1: T[], arr2: U[]): Array<[T, U]> {
5+
const length = Math.min(arr1.length, arr2.length);
6+
const result: Array<[T, U]> = [];
7+
for (let i = 0; i < length; i++) {
8+
result.push([arr1[i]!, arr2[i]!]);
9+
}
10+
return result;
11+
}

packages/runtime/src/client/crud-types.ts

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ export type OmitInput<Schema extends SchemaDef, Model extends GetModels<Schema>>
452452

453453
export type SelectIncludeOmit<Schema extends SchemaDef, Model extends GetModels<Schema>, AllowCount extends boolean> = {
454454
select?: SelectInput<Schema, Model, AllowCount, boolean>;
455-
include?: IncludeInput<Schema, Model>;
455+
include?: IncludeInput<Schema, Model, AllowCount>;
456456
omit?: OmitInput<Schema, Model>;
457457
};
458458

@@ -463,14 +463,7 @@ export type SelectInput<
463463
AllowRelation extends boolean = true,
464464
> = {
465465
[Key in NonRelationFields<Schema, Model>]?: boolean;
466-
} & (AllowRelation extends true ? IncludeInput<Schema, Model> : {}) & // relation fields
467-
// relation count
468-
(AllowCount extends true
469-
? // _count is only allowed if the model has to-many relations
470-
HasToManyRelations<Schema, Model> extends true
471-
? { _count?: SelectCount<Schema, Model> }
472-
: {}
473-
: {});
466+
} & (AllowRelation extends true ? IncludeInput<Schema, Model, AllowCount> : {});
474467

475468
type SelectCount<Schema extends SchemaDef, Model extends GetModels<Schema>> =
476469
| boolean
@@ -484,7 +477,11 @@ type SelectCount<Schema extends SchemaDef, Model extends GetModels<Schema>> =
484477
};
485478
};
486479

487-
export type IncludeInput<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
480+
export type IncludeInput<
481+
Schema extends SchemaDef,
482+
Model extends GetModels<Schema>,
483+
AllowCount extends boolean = true,
484+
> = {
488485
[Key in RelationFields<Schema, Model>]?:
489486
| boolean
490487
| FindArgs<
@@ -498,7 +495,12 @@ export type IncludeInput<Schema extends SchemaDef, Model extends GetModels<Schem
498495
? true
499496
: false
500497
>;
501-
};
498+
} & (AllowCount extends true
499+
? // _count is only allowed if the model has to-many relations
500+
HasToManyRelations<Schema, Model> extends true
501+
? { _count?: SelectCount<Schema, Model> }
502+
: {}
503+
: {});
502504

503505
export type Subset<T, U> = {
504506
[key in keyof T]: key extends keyof U ? T[key] : never;
@@ -674,7 +676,7 @@ export type FindUniqueArgs<Schema extends SchemaDef, Model extends GetModels<Sch
674676

675677
export type CreateArgs<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
676678
data: CreateInput<Schema, Model>;
677-
select?: SelectInput<Schema, Model, true>;
679+
select?: SelectInput<Schema, Model>;
678680
include?: IncludeInput<Schema, Model>;
679681
omit?: OmitInput<Schema, Model>;
680682
};
@@ -813,7 +815,7 @@ type NestedCreateManyInput<
813815
export type UpdateArgs<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
814816
data: UpdateInput<Schema, Model>;
815817
where: WhereUniqueInput<Schema, Model>;
816-
select?: SelectInput<Schema, Model, true>;
818+
select?: SelectInput<Schema, Model>;
817819
include?: IncludeInput<Schema, Model>;
818820
omit?: OmitInput<Schema, Model>;
819821
};
@@ -841,7 +843,7 @@ export type UpsertArgs<Schema extends SchemaDef, Model extends GetModels<Schema>
841843
create: CreateInput<Schema, Model>;
842844
update: UpdateInput<Schema, Model>;
843845
where: WhereUniqueInput<Schema, Model>;
844-
select?: SelectInput<Schema, Model, true>;
846+
select?: SelectInput<Schema, Model>;
845847
include?: IncludeInput<Schema, Model>;
846848
omit?: OmitInput<Schema, Model>;
847849
};
@@ -958,7 +960,7 @@ type ToOneRelationUpdateInput<
958960

959961
export type DeleteArgs<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
960962
where: WhereUniqueInput<Schema, Model>;
961-
select?: SelectInput<Schema, Model, true>;
963+
select?: SelectInput<Schema, Model>;
962964
include?: IncludeInput<Schema, Model>;
963965
omit?: OmitInput<Schema, Model>;
964966
};

packages/runtime/src/client/crud/dialects/base-dialect.ts

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,14 +1048,29 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
10481048
for (const [field, value] of Object.entries(selections.select)) {
10491049
const fieldDef = requireField(this.schema, model, field);
10501050
const fieldModel = fieldDef.type;
1051-
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel);
1052-
1053-
// build a nested query to count the number of records in the relation
1054-
let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));
1051+
let fieldCountQuery: SelectQueryBuilder<any, any, any>;
10551052

10561053
// join conditions
1057-
for (const [left, right] of joinPairs) {
1058-
fieldCountQuery = fieldCountQuery.whereRef(left, '=', right);
1054+
const m2m = getManyToManyRelation(this.schema, model, field);
1055+
if (m2m) {
1056+
// many-to-many relation, count the join table
1057+
fieldCountQuery = eb
1058+
.selectFrom(fieldModel)
1059+
.innerJoin(m2m.joinTable, (join) =>
1060+
join
1061+
.onRef(`${m2m.joinTable}.${m2m.otherFkName}`, '=', `${fieldModel}.${m2m.otherPKName}`)
1062+
.onRef(`${m2m.joinTable}.${m2m.parentFkName}`, '=', `${parentAlias}.${m2m.parentPKName}`),
1063+
)
1064+
.select(eb.fn.countAll().as(`_count$${field}`));
1065+
} else {
1066+
// build a nested query to count the number of records in the relation
1067+
fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));
1068+
1069+
// join conditions
1070+
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel);
1071+
for (const [left, right] of joinPairs) {
1072+
fieldCountQuery = fieldCountQuery.whereRef(left, '=', right);
1073+
}
10591074
}
10601075

10611076
// merge _count filter

packages/runtime/src/client/crud/operations/base.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
475475
entity: rightEntity,
476476
},
477477
].sort((a, b) =>
478-
// the implement m2m join table's "A", "B" fk fields' order is determined
478+
// the implicit m2m join table's "A", "B" fk fields' order is determined
479479
// by model name's sort order, and when identical (for self-relations),
480480
// field name's sort order
481481
a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field),

packages/runtime/src/client/crud/validator.ts

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@ import Decimal from 'decimal.js';
33
import stableStringify from 'json-stable-stringify';
44
import { match, P } from 'ts-pattern';
55
import { z, ZodType } from 'zod';
6-
import { type BuiltinType, type EnumDef, type FieldDef, type GetModels, type SchemaDef } from '../../schema';
6+
import {
7+
type BuiltinType,
8+
type EnumDef,
9+
type FieldDef,
10+
type GetModels,
11+
type ModelDef,
12+
type SchemaDef,
13+
} from '../../schema';
714
import { enumerate } from '../../utils/enumerate';
815
import { extractFields } from '../../utils/object-utils';
916
import { formatError } from '../../utils/zod-utils';
@@ -595,10 +602,18 @@ export class InputValidator<Schema extends SchemaDef> {
595602
}
596603
}
597604

598-
const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array);
605+
const _countSchema = this.makeCountSelectionSchema(modelDef);
606+
if (_countSchema) {
607+
fields['_count'] = _countSchema;
608+
}
609+
610+
return z.strictObject(fields);
611+
}
599612

613+
private makeCountSelectionSchema(modelDef: ModelDef) {
614+
const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array);
600615
if (toManyRelations.length > 0) {
601-
fields['_count'] = z
616+
return z
602617
.union([
603618
z.literal(true),
604619
z.strictObject({
@@ -621,9 +636,9 @@ export class InputValidator<Schema extends SchemaDef> {
621636
}),
622637
])
623638
.optional();
639+
} else {
640+
return undefined;
624641
}
625-
626-
return z.strictObject(fields);
627642
}
628643

629644
private makeRelationSelectIncludeSchema(fieldDef: FieldDef) {
@@ -677,6 +692,11 @@ export class InputValidator<Schema extends SchemaDef> {
677692
}
678693
}
679694

695+
const _countSchema = this.makeCountSelectionSchema(modelDef);
696+
if (_countSchema) {
697+
fields['_count'] = _countSchema;
698+
}
699+
680700
return z.strictObject(fields);
681701
}
682702

packages/runtime/src/client/query-utils.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { invariant } from '@zenstackhq/common-helpers';
12
import type { Expression, ExpressionBuilder, ExpressionWrapper } from 'kysely';
23
import { match } from 'ts-pattern';
34
import { ExpressionUtils, type FieldDef, type GetModels, type ModelDef, type SchemaDef } from '../schema';
@@ -259,11 +260,18 @@ export function getManyToManyRelation(schema: SchemaDef, model: string, field: s
259260
orderedFK = sortedFieldNames[0] === field ? ['A', 'B'] : ['B', 'A'];
260261
}
261262

263+
const modelIdFields = requireIdFields(schema, model);
264+
invariant(modelIdFields.length === 1, 'Only single-field ID is supported for many-to-many relation');
265+
const otherIdFields = requireIdFields(schema, fieldDef.type);
266+
invariant(otherIdFields.length === 1, 'Only single-field ID is supported for many-to-many relation');
267+
262268
return {
263269
parentFkName: orderedFK[0],
270+
parentPKName: modelIdFields[0]!,
264271
otherModel: fieldDef.type,
265272
otherField: fieldDef.relation.opposite,
266273
otherFkName: orderedFK[1],
274+
otherPKName: otherIdFields[0]!,
267275
joinTable: fieldDef.relation.name
268276
? `_${fieldDef.relation.name}`
269277
: `_${sortedModelNames[0]}To${sortedModelNames[1]}`,

packages/runtime/src/plugins/policy/expression-transformer.ts

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@ import type { ClientContract, CRUD } from '../../client/contract';
2424
import { getCrudDialect } from '../../client/crud/dialects';
2525
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
2626
import { InternalError, QueryError } from '../../client/errors';
27-
import { getModel, getRelationForeignKeyFieldPairs, requireField, requireIdFields } from '../../client/query-utils';
27+
import {
28+
getManyToManyRelation,
29+
getModel,
30+
getRelationForeignKeyFieldPairs,
31+
requireField,
32+
requireIdFields,
33+
} from '../../client/query-utils';
2834
import type {
2935
BinaryExpression,
3036
BinaryOperator,
@@ -44,7 +50,7 @@ import {
4450
type SchemaDef,
4551
} from '../../schema';
4652
import { ExpressionEvaluator } from './expression-evaluator';
47-
import { conjunction, disjunction, logicalNot, trueNode } from './utils';
53+
import { conjunction, disjunction, falseNode, logicalNot, trueNode } from './utils';
4854

4955
export type ExpressionTransformerContext<Schema extends SchemaDef> = {
5056
model: GetModels<Schema>;
@@ -335,7 +341,13 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
335341
}
336342

337343
private transformValue(value: unknown, type: BuiltinType) {
338-
return ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null);
344+
if (value === true) {
345+
return trueNode(this.dialect);
346+
} else if (value === false) {
347+
return falseNode(this.dialect);
348+
} else {
349+
return ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null);
350+
}
339351
}
340352

341353
@expr('unary')
@@ -537,6 +549,11 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
537549
relationModel: string,
538550
context: ExpressionTransformerContext<Schema>,
539551
): SelectQueryNode {
552+
const m2m = getManyToManyRelation(this.schema, context.model, field);
553+
if (m2m) {
554+
return this.transformManyToManyRelationAccess(m2m, context);
555+
}
556+
540557
const fromModel = context.model;
541558
const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs(this.schema, fromModel, field);
542559

@@ -574,6 +591,28 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
574591
};
575592
}
576593

594+
private transformManyToManyRelationAccess(
595+
m2m: NonNullable<ReturnType<typeof getManyToManyRelation>>,
596+
context: ExpressionTransformerContext<Schema>,
597+
) {
598+
const eb = expressionBuilder<any, any>();
599+
const relationQuery = eb
600+
.selectFrom(m2m.otherModel)
601+
// inner join with join table and additionally filter by the parent model
602+
.innerJoin(m2m.joinTable, (join) =>
603+
join
604+
// relation model pk to join table fk
605+
.onRef(`${m2m.otherModel}.${m2m.otherPKName}`, '=', `${m2m.joinTable}.${m2m.otherFkName}`)
606+
// parent model pk to join table fk
607+
.onRef(
608+
`${m2m.joinTable}.${m2m.parentFkName}`,
609+
'=',
610+
`${context.alias ?? context.model}.${m2m.parentPKName}`,
611+
),
612+
);
613+
return relationQuery.toOperationNode();
614+
}
615+
577616
private createColumnRef(column: string, context: ExpressionTransformerContext<Schema>): ReferenceNode {
578617
return ReferenceNode.create(ColumnNode.create(column), TableNode.create(context.alias ?? context.model));
579618
}

0 commit comments

Comments
 (0)