Skip to content

Commit d7fbd6d

Browse files
committed
feat(policy): support read filtering for update with "from" and delete with "using"
1 parent 50e92e0 commit d7fbd6d

File tree

3 files changed

+448
-23
lines changed

3 files changed

+448
-23
lines changed

packages/language/src/validators/expression-validator.ts

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,17 +109,19 @@ export default class ExpressionValidator implements AstValidator<Expression> {
109109
}
110110

111111
if (
112-
typeof expr.left.$resolvedType?.decl !== 'string' ||
113-
!supportedShapes.includes(expr.left.$resolvedType.decl)
112+
expr.left.$resolvedType &&
113+
(typeof expr.left.$resolvedType?.decl !== 'string' ||
114+
!supportedShapes.includes(expr.left.$resolvedType.decl))
114115
) {
115116
accept('error', `invalid operand type for "${expr.operator}" operator`, {
116117
node: expr.left,
117118
});
118119
return;
119120
}
120121
if (
121-
typeof expr.right.$resolvedType?.decl !== 'string' ||
122-
!supportedShapes.includes(expr.right.$resolvedType.decl)
122+
expr.right.$resolvedType &&
123+
(typeof expr.right.$resolvedType?.decl !== 'string' ||
124+
!supportedShapes.includes(expr.right.$resolvedType.decl))
123125
) {
124126
accept('error', `invalid operand type for "${expr.operator}" operator`, {
125127
node: expr.right,
@@ -128,13 +130,13 @@ export default class ExpressionValidator implements AstValidator<Expression> {
128130
}
129131

130132
// DateTime comparison is only allowed between two DateTime values
131-
if (expr.left.$resolvedType.decl === 'DateTime' && expr.right.$resolvedType.decl !== 'DateTime') {
133+
if (expr.left.$resolvedType?.decl === 'DateTime' && expr.right.$resolvedType?.decl !== 'DateTime') {
132134
accept('error', 'incompatible operand types', {
133135
node: expr,
134136
});
135137
} else if (
136-
expr.right.$resolvedType.decl === 'DateTime' &&
137-
expr.left.$resolvedType.decl !== 'DateTime'
138+
expr.right.$resolvedType?.decl === 'DateTime' &&
139+
expr.left.$resolvedType?.decl !== 'DateTime'
138140
) {
139141
accept('error', 'incompatible operand types', {
140142
node: expr,

packages/runtime/src/plugins/policy/policy-handler.ts

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
FunctionNode,
99
IdentifierNode,
1010
InsertQueryNode,
11+
JoinNode,
1112
OperationNodeTransformer,
1213
OperatorNode,
1314
ParensNode,
@@ -138,18 +139,15 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
138139
// #region overrides
139140

140141
protected override transformSelectQuery(node: SelectQueryNode) {
141-
let whereNode = node.where;
142+
let whereNode = this.transformNode(node.where);
142143

143-
node.from?.froms.forEach((from) => {
144-
const extractResult = this.extractTableName(from);
145-
if (extractResult) {
146-
const { model, alias } = extractResult;
147-
const filter = this.buildPolicyFilter(model, alias, 'read');
148-
whereNode = WhereNode.create(
149-
whereNode?.where ? conjunction(this.dialect, [whereNode.where, filter]) : filter,
150-
);
151-
}
152-
});
144+
// get combined policy filter for all froms, and merge into where clause
145+
const policyFilter = this.createPolicyFilterForFrom(node.from);
146+
if (policyFilter) {
147+
whereNode = WhereNode.create(
148+
whereNode?.where ? conjunction(this.dialect, [whereNode.where, policyFilter]) : policyFilter,
149+
);
150+
}
153151

154152
const baseResult = super.transformSelectQuery({
155153
...node,
@@ -162,6 +160,27 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
162160
};
163161
}
164162

163+
protected override transformJoin(node: JoinNode) {
164+
const table = this.extractTableName(node.table);
165+
if (!table) {
166+
// unable to extract table name, can be a subquery, which will be handled when nested transformation happens
167+
return super.transformJoin(node);
168+
}
169+
170+
// build a nested query with policy filter applied
171+
const filter = this.buildPolicyFilter(table.model, undefined, 'read');
172+
const nestedSelect: SelectQueryNode = {
173+
kind: 'SelectQueryNode',
174+
from: FromNode.create([node.table]),
175+
selections: [SelectionNode.createSelectAll()],
176+
where: WhereNode.create(filter),
177+
};
178+
return {
179+
...node,
180+
table: AliasNode.create(ParensNode.create(nestedSelect), IdentifierNode.create(table.alias ?? table.model)),
181+
};
182+
}
183+
165184
protected override transformInsertQuery(node: InsertQueryNode) {
166185
// pre-insert check is done in `handle()`
167186

@@ -210,7 +229,16 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
210229
protected override transformUpdateQuery(node: UpdateQueryNode) {
211230
const result = super.transformUpdateQuery(node);
212231
const mutationModel = this.getMutationModel(node);
213-
const filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
232+
let filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
233+
234+
if (node.from) {
235+
// for update with from (join), we need to merge join tables' policy filters to the "where" clause
236+
const joinFilter = this.createPolicyFilterForFrom(node.from);
237+
if (joinFilter) {
238+
filter = conjunction(this.dialect, [filter, joinFilter]);
239+
}
240+
}
241+
214242
return {
215243
...result,
216244
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
@@ -220,7 +248,16 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
220248
protected override transformDeleteQuery(node: DeleteQueryNode) {
221249
const result = super.transformDeleteQuery(node);
222250
const mutationModel = this.getMutationModel(node);
223-
const filter = this.buildPolicyFilter(mutationModel, undefined, 'delete');
251+
let filter = this.buildPolicyFilter(mutationModel, undefined, 'delete');
252+
253+
if (node.using) {
254+
// for delete with using (join), we need to merge join tables' policy filters to the "where" clause
255+
const joinFilter = this.createPolicyFilterForTables(node.using.tables);
256+
if (joinFilter) {
257+
filter = conjunction(this.dialect, [filter, joinFilter]);
258+
}
259+
}
260+
224261
return {
225262
...result,
226263
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
@@ -466,11 +503,11 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
466503

467504
const allows = policies
468505
.filter((policy) => policy.kind === 'allow')
469-
.map((policy) => this.transformPolicyCondition(model, alias, operation, policy));
506+
.map((policy) => this.compilePolicyCondition(model, alias, operation, policy));
470507

471508
const denies = policies
472509
.filter((policy) => policy.kind === 'deny')
473-
.map((policy) => this.transformPolicyCondition(model, alias, operation, policy));
510+
.map((policy) => this.compilePolicyCondition(model, alias, operation, policy));
474511

475512
let combinedPolicy: OperationNode;
476513

@@ -514,7 +551,26 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
514551
}
515552
}
516553

517-
private transformPolicyCondition(
554+
private createPolicyFilterForFrom(node: FromNode | undefined) {
555+
if (!node) {
556+
return undefined;
557+
}
558+
return this.createPolicyFilterForTables(node.froms);
559+
}
560+
561+
private createPolicyFilterForTables(tables: readonly OperationNode[]) {
562+
return tables.reduce<OperationNode | undefined>((acc, table) => {
563+
const extractResult = this.extractTableName(table);
564+
if (extractResult) {
565+
const { model, alias } = extractResult;
566+
const filter = this.buildPolicyFilter(model, alias, 'read');
567+
return acc ? conjunction(this.dialect, [acc, filter]) : filter;
568+
}
569+
return acc;
570+
}, undefined);
571+
}
572+
573+
private compilePolicyCondition(
518574
model: GetModels<Schema>,
519575
alias: string | undefined,
520576
operation: CRUD,

0 commit comments

Comments
 (0)