Skip to content

Commit 79dcba9

Browse files
committed
address PR comments, refactor m2m check
1 parent 4f3c15b commit 79dcba9

File tree

6 files changed

+201
-130
lines changed

6 files changed

+201
-130
lines changed

packages/common-helpers/src/zip.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Zipped two arrays into an array of tuples.
2+
* Zips two arrays into an array of tuples.
33
*/
44
export function zip<T, U>(arr1: T[], arr2: U[]): Array<[T, U]> {
55
const length = Math.min(arr1.length, arr2.length);

packages/runtime/src/client/crud/dialects/base-dialect.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,13 +1055,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
10551055
if (m2m) {
10561056
// many-to-many relation, count the join table
10571057
fieldCountQuery = eb
1058-
.selectFrom(m2m.joinTable)
1059-
.select(eb.fn.countAll().as(`_count$${field}`))
1060-
.whereRef(
1061-
eb.ref(`${parentAlias}.${m2m.parentPKName}`),
1062-
'=',
1063-
eb.ref(`${m2m.joinTable}.${m2m.parentFkName}`),
1064-
);
1058+
.selectFrom(fieldModel)
1059+
.innerJoin(m2m.joinTable, (join) =>
1060+
join
1061+
.onRef(`${m2m.joinTable}.${m2m.otherFkName}`, '=', `${fieldModel}.${m2m.otherPKName}`)
1062+
.onRef(`${m2m.joinTable}.${m2m.parentFkName}`, '=', `${parentAlias}.${m2m.parentPKName}`),
1063+
)
1064+
.select(eb.fn.countAll().as(`_count$${field}`));
10651065
} else {
10661066
// build a nested query to count the number of records in the relation
10671067
fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));

packages/runtime/src/plugins/policy/expression-transformer.ts

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import {
4444
type SchemaDef,
4545
} from '../../schema';
4646
import { ExpressionEvaluator } from './expression-evaluator';
47-
import { conjunction, disjunction, logicalNot, trueNode } from './utils';
47+
import { conjunction, disjunction, falseNode, logicalNot, trueNode } from './utils';
4848

4949
export type ExpressionTransformerContext<Schema extends SchemaDef> = {
5050
model: GetModels<Schema>;
@@ -335,7 +335,13 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
335335
}
336336

337337
private transformValue(value: unknown, type: BuiltinType) {
338-
return ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null);
338+
if (value === true) {
339+
return trueNode(this.dialect);
340+
} else if (value === false) {
341+
return falseNode(this.dialect);
342+
} else {
343+
return ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null);
344+
}
339345
}
340346

341347
@expr('unary')

packages/runtime/src/plugins/policy/policy-handler.ts

Lines changed: 180 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -71,74 +71,46 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
7171
}
7272

7373
if (!this.isMutationQueryNode(node)) {
74-
// transform and proceed read without transaction
74+
// transform and proceed with read directly
7575
return proceed(this.transformNode(node));
7676
}
7777

78-
let mutationRequiresTransaction = false;
7978
const { mutationModel } = this.getMutationModel(node);
8079

81-
const isManyToManyJoinTable = this.isManyToManyJoinTable(mutationModel);
80+
if (InsertQueryNode.is(node)) {
81+
// pre-create policy evaluation happens before execution of the query
82+
const isManyToManyJoinTable = this.isManyToManyJoinTable(mutationModel);
83+
let needCheckPreCreate = true;
84+
85+
// many-to-many join table is not a model so can't have policies on it
86+
if (!isManyToManyJoinTable) {
87+
// check constant policies
88+
const constCondition = this.tryGetConstantPolicy(mutationModel, 'create');
89+
if (constCondition === true) {
90+
needCheckPreCreate = false;
91+
} else if (constCondition === false) {
92+
throw new RejectedByPolicyError(mutationModel);
93+
}
94+
}
8295

83-
if (InsertQueryNode.is(node) && !isManyToManyJoinTable) {
84-
// reject create if unconditional deny
85-
const constCondition = this.tryGetConstantPolicy(mutationModel, 'create');
86-
if (constCondition === false) {
87-
throw new RejectedByPolicyError(mutationModel);
88-
} else if (constCondition === undefined) {
89-
mutationRequiresTransaction = true;
96+
if (needCheckPreCreate) {
97+
await this.enforcePreCreatePolicy(node, mutationModel, isManyToManyJoinTable, proceed);
9098
}
9199
}
92100

93-
if (!mutationRequiresTransaction && !node.returning) {
94-
// transform and proceed mutation without transaction
95-
return proceed(this.transformNode(node));
96-
}
101+
// proceed with query
97102

98-
if (InsertQueryNode.is(node)) {
99-
await this.enforcePreCreatePolicy(node, mutationModel, isManyToManyJoinTable, proceed);
100-
}
101-
const transformedNode = this.transformNode(node);
102-
const result = await proceed(transformedNode);
103+
const result = await proceed(this.transformNode(node));
103104

104-
if (!this.onlyReturningId(node)) {
105+
if (!node.returning || this.onlyReturningId(node)) {
106+
return result;
107+
} else {
105108
const readBackResult = await this.processReadBack(node, result, proceed);
106109
if (readBackResult.rows.length !== result.rows.length) {
107110
throw new RejectedByPolicyError(mutationModel, 'result is not allowed to be read back');
108111
}
109112
return readBackResult;
110-
} else {
111-
// reading id fields bypasses policy
112-
return result;
113113
}
114-
115-
// TODO: run in transaction
116-
// let readBackError = false;
117-
118-
// transform and post-process in a transaction
119-
// const result = await transaction(async (txProceed) => {
120-
// if (InsertQueryNode.is(node)) {
121-
// await this.enforcePreCreatePolicy(node, txProceed);
122-
// }
123-
// const transformedNode = this.transformNode(node);
124-
// const result = await txProceed(transformedNode);
125-
126-
// if (!this.onlyReturningId(node)) {
127-
// const readBackResult = await this.processReadBack(node, result, txProceed);
128-
// if (readBackResult.rows.length !== result.rows.length) {
129-
// readBackError = true;
130-
// }
131-
// return readBackResult;
132-
// } else {
133-
// return result;
134-
// }
135-
// });
136-
137-
// if (readBackError) {
138-
// throw new RejectedByPolicyError(mutationModel, 'result is not allowed to be read back');
139-
// }
140-
141-
// return result;
142114
}
143115

144116
// #region overrides
@@ -296,11 +268,81 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
296268
? this.unwrapCreateValueRows(node.values, mutationModel, fields, isManyToManyJoinTable)
297269
: [[]];
298270
for (const values of valueRows) {
299-
await this.enforcePreCreatePolicyForOne(
300-
mutationModel,
301-
fields,
302-
values.map((v) => v.node),
303-
proceed,
271+
if (isManyToManyJoinTable) {
272+
await this.enforcePreCreatePolicyForManyToManyJoinTable(
273+
mutationModel,
274+
fields,
275+
values.map((v) => v.node),
276+
proceed,
277+
);
278+
} else {
279+
await this.enforcePreCreatePolicyForOne(
280+
mutationModel,
281+
fields,
282+
values.map((v) => v.node),
283+
proceed,
284+
);
285+
}
286+
}
287+
}
288+
289+
private async enforcePreCreatePolicyForManyToManyJoinTable(
290+
tableName: GetModels<Schema>,
291+
fields: string[],
292+
values: OperationNode[],
293+
proceed: ProceedKyselyQueryFunction,
294+
) {
295+
const m2m = this.resolveManyToManyJoinTable(tableName);
296+
invariant(m2m);
297+
298+
// m2m create requires both sides to be updatable
299+
invariant(fields.includes('A') && fields.includes('B'), 'many-to-many join table must have A and B fk fields');
300+
301+
const aIndex = fields.indexOf('A');
302+
const aNode = values[aIndex]!;
303+
const bIndex = fields.indexOf('B');
304+
const bNode = values[bIndex]!;
305+
invariant(ValueNode.is(aNode) && ValueNode.is(bNode), 'A and B values must be ValueNode');
306+
307+
const aValue = aNode.value;
308+
const bValue = bNode.value;
309+
invariant(aValue !== null && aValue !== undefined, 'A value cannot be null or undefined');
310+
invariant(bValue !== null && bValue !== undefined, 'B value cannot be null or undefined');
311+
312+
const eb = expressionBuilder<any, any>();
313+
314+
const filterA = this.buildPolicyFilter(m2m.firstModel as GetModels<Schema>, undefined, 'update');
315+
const queryA = eb
316+
.selectFrom(m2m.firstModel)
317+
.where(eb(eb.ref(`${m2m.firstModel}.${m2m.firstIdField}`), '=', aValue))
318+
.select(() => new ExpressionWrapper(filterA).as('$t'));
319+
320+
const filterB = this.buildPolicyFilter(m2m.secondModel as GetModels<Schema>, undefined, 'update');
321+
const queryB = eb
322+
.selectFrom(m2m.secondModel)
323+
.where(eb(eb.ref(`${m2m.secondModel}.${m2m.secondIdField}`), '=', bValue))
324+
.select(() => new ExpressionWrapper(filterB).as('$t'));
325+
326+
// select both conditions in one query
327+
const queryNode: SelectQueryNode = {
328+
kind: 'SelectQueryNode',
329+
selections: [
330+
SelectionNode.create(AliasNode.create(queryA.toOperationNode(), IdentifierNode.create('$conditionA'))),
331+
SelectionNode.create(AliasNode.create(queryB.toOperationNode(), IdentifierNode.create('$conditionB'))),
332+
],
333+
};
334+
335+
const result = await proceed(queryNode);
336+
if (!result.rows[0]?.$conditionA) {
337+
throw new RejectedByPolicyError(
338+
m2m.firstModel as GetModels<Schema>,
339+
`many-to-many relation participant model "${m2m.firstModel}" not updatable`,
340+
);
341+
}
342+
if (!result.rows[0]?.$conditionB) {
343+
throw new RejectedByPolicyError(
344+
m2m.secondModel as GetModels<Schema>,
345+
`many-to-many relation participant model "${m2m.secondModel}" not updatable`,
304346
);
305347
}
306348
}
@@ -658,77 +700,100 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
658700
return result;
659701
}
660702

661-
private isManyToManyJoinTable(tableName: string) {
662-
return Object.values(this.client.$schema.models).some((modelDef) => {
663-
return Object.values(modelDef.fields).some((field) => {
664-
const m2m = getManyToManyRelation(this.client.$schema, modelDef.name, field.name);
665-
return m2m?.joinTable === tableName;
666-
});
667-
});
668-
}
669-
670-
private getModelPolicyFilterForManyToManyJoinTable(
671-
tableName: string,
672-
alias: string | undefined,
673-
operation: PolicyOperation,
674-
): OperationNode | undefined {
675-
// find the m2m relation for this join table
703+
private resolveManyToManyJoinTable(tableName: string) {
676704
for (const model of Object.values(this.client.$schema.models)) {
677705
for (const field of Object.values(model.fields)) {
678706
const m2m = getManyToManyRelation(this.client.$schema, model.name, field.name);
679-
if (m2m?.joinTable !== tableName) {
680-
continue;
681-
}
682-
683-
// determine A/B side
684-
const sortedRecords = [
685-
{
686-
model: model.name,
687-
field: field.name,
688-
},
689-
{
690-
model: m2m.otherModel,
691-
field: m2m.otherField,
692-
},
693-
].sort((a, b) =>
694-
// the implicit m2m join table's "A", "B" fk fields' order is determined
695-
// by model name's sort order, and when identical (for self-relations),
696-
// field name's sort order
697-
a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field),
698-
);
699-
700-
// join table's permission:
701-
// - read: requires both sides to be readable
702-
// - mutation: requires both sides to be updatable
703-
704-
const queries: SelectQueryBuilder<any, any, any>[] = [];
705-
const eb = expressionBuilder<any, any>();
706-
707-
for (const [fk, entry] of zip(['A', 'B'], sortedRecords)) {
708-
const idFields = requireIdFields(this.client.$schema, entry.model);
707+
if (m2m?.joinTable === tableName) {
708+
const sortedRecord = [
709+
{
710+
model: model.name,
711+
field: field.name,
712+
},
713+
{
714+
model: m2m.otherModel,
715+
field: m2m.otherField,
716+
},
717+
].sort(this.manyToManySorter);
718+
719+
const firstIdFields = requireIdFields(this.client.$schema, sortedRecord[0]!.model);
720+
const secondIdFields = requireIdFields(this.client.$schema, sortedRecord[1]!.model);
709721
invariant(
710-
idFields.length === 1,
722+
firstIdFields.length === 1 && secondIdFields.length === 1,
711723
'only single-field id is supported for implicit many-to-many join table',
712724
);
713725

714-
const policyFilter = this.buildPolicyFilter(
715-
entry.model as GetModels<Schema>,
716-
undefined,
717-
operation === 'read' ? 'read' : 'update',
718-
);
719-
const query = eb
720-
.selectFrom(entry.model)
721-
.whereRef(`${entry.model}.${idFields[0]}`, '=', `${alias ?? tableName}.${fk}`)
722-
.select(new ExpressionWrapper(policyFilter).as(`$condition${fk}`));
723-
queries.push(query);
726+
return {
727+
firstModel: sortedRecord[0]!.model,
728+
firstField: sortedRecord[0]!.field,
729+
firstIdField: firstIdFields[0]!,
730+
secondModel: sortedRecord[1]!.model,
731+
secondField: sortedRecord[1]!.field,
732+
secondIdField: secondIdFields[0]!,
733+
};
724734
}
725-
726-
return eb.and(queries).toOperationNode();
727735
}
728736
}
729-
730737
return undefined;
731738
}
732739

740+
private manyToManySorter(a: { model: string; field: string }, b: { model: string; field: string }): number {
741+
// the implicit m2m join table's "A", "B" fk fields' order is determined
742+
// by model name's sort order, and when identical (for self-relations),
743+
// field name's sort order
744+
return a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field);
745+
}
746+
747+
private isManyToManyJoinTable(tableName: string) {
748+
return !!this.resolveManyToManyJoinTable(tableName);
749+
}
750+
751+
private getModelPolicyFilterForManyToManyJoinTable(
752+
tableName: string,
753+
alias: string | undefined,
754+
operation: PolicyOperation,
755+
): OperationNode | undefined {
756+
const m2m = this.resolveManyToManyJoinTable(tableName);
757+
if (!m2m) {
758+
return undefined;
759+
}
760+
761+
const sortedRecords = [
762+
{
763+
model: m2m.firstModel,
764+
field: m2m.firstField,
765+
},
766+
{
767+
model: m2m.secondModel,
768+
field: m2m.secondField,
769+
},
770+
];
771+
772+
// join table's permission:
773+
// - read: requires both sides to be readable
774+
// - mutation: requires both sides to be updatable
775+
776+
const queries: SelectQueryBuilder<any, any, any>[] = [];
777+
const eb = expressionBuilder<any, any>();
778+
779+
for (const [fk, entry] of zip(['A', 'B'], sortedRecords)) {
780+
const idFields = requireIdFields(this.client.$schema, entry.model);
781+
invariant(idFields.length === 1, 'only single-field id is supported for implicit many-to-many join table');
782+
783+
const policyFilter = this.buildPolicyFilter(
784+
entry.model as GetModels<Schema>,
785+
undefined,
786+
operation === 'read' ? 'read' : 'update',
787+
);
788+
const query = eb
789+
.selectFrom(entry.model)
790+
.whereRef(`${entry.model}.${idFields[0]}`, '=', `${alias ?? tableName}.${fk}`)
791+
.select(new ExpressionWrapper(policyFilter).as(`$condition${fk}`));
792+
queries.push(query);
793+
}
794+
795+
return eb.and(queries).toOperationNode();
796+
}
797+
733798
// #endregion
734799
}

0 commit comments

Comments
 (0)