Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions packages/language/src/validators/expression-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,19 @@ export default class ExpressionValidator implements AstValidator<Expression> {
}

if (
typeof expr.left.$resolvedType?.decl !== 'string' ||
!supportedShapes.includes(expr.left.$resolvedType.decl)
expr.left.$resolvedType &&
(typeof expr.left.$resolvedType?.decl !== 'string' ||
!supportedShapes.includes(expr.left.$resolvedType.decl))
) {
accept('error', `invalid operand type for "${expr.operator}" operator`, {
node: expr.left,
});
return;
}
if (
typeof expr.right.$resolvedType?.decl !== 'string' ||
!supportedShapes.includes(expr.right.$resolvedType.decl)
expr.right.$resolvedType &&
(typeof expr.right.$resolvedType?.decl !== 'string' ||
!supportedShapes.includes(expr.right.$resolvedType.decl))
) {
accept('error', `invalid operand type for "${expr.operator}" operator`, {
node: expr.right,
Expand All @@ -128,13 +130,13 @@ export default class ExpressionValidator implements AstValidator<Expression> {
}

// DateTime comparison is only allowed between two DateTime values
if (expr.left.$resolvedType.decl === 'DateTime' && expr.right.$resolvedType.decl !== 'DateTime') {
if (expr.left.$resolvedType?.decl === 'DateTime' && expr.right.$resolvedType?.decl !== 'DateTime') {
accept('error', 'incompatible operand types', {
node: expr,
});
} else if (
expr.right.$resolvedType.decl === 'DateTime' &&
expr.left.$resolvedType.decl !== 'DateTime'
expr.right.$resolvedType?.decl === 'DateTime' &&
expr.left.$resolvedType?.decl !== 'DateTime'
) {
accept('error', 'incompatible operand types', {
node: expr,
Expand Down
88 changes: 72 additions & 16 deletions packages/runtime/src/plugins/policy/policy-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
FunctionNode,
IdentifierNode,
InsertQueryNode,
JoinNode,
OperationNodeTransformer,
OperatorNode,
ParensNode,
Expand Down Expand Up @@ -138,18 +139,15 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
// #region overrides

protected override transformSelectQuery(node: SelectQueryNode) {
let whereNode = node.where;
let whereNode = this.transformNode(node.where);

node.from?.froms.forEach((from) => {
const extractResult = this.extractTableName(from);
if (extractResult) {
const { model, alias } = extractResult;
const filter = this.buildPolicyFilter(model, alias, 'read');
whereNode = WhereNode.create(
whereNode?.where ? conjunction(this.dialect, [whereNode.where, filter]) : filter,
);
}
});
// get combined policy filter for all froms, and merge into where clause
const policyFilter = this.createPolicyFilterForFrom(node.from);
if (policyFilter) {
whereNode = WhereNode.create(
whereNode?.where ? conjunction(this.dialect, [whereNode.where, policyFilter]) : policyFilter,
);
}

const baseResult = super.transformSelectQuery({
...node,
Expand All @@ -162,6 +160,27 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
};
}

protected override transformJoin(node: JoinNode) {
const table = this.extractTableName(node.table);
if (!table) {
// unable to extract table name, can be a subquery, which will be handled when nested transformation happens
return super.transformJoin(node);
}

// build a nested query with policy filter applied
const filter = this.buildPolicyFilter(table.model, undefined, 'read');
const nestedSelect: SelectQueryNode = {
kind: 'SelectQueryNode',
from: FromNode.create([node.table]),
selections: [SelectionNode.createSelectAll()],
where: WhereNode.create(filter),
};
return {
...node,
table: AliasNode.create(ParensNode.create(nestedSelect), IdentifierNode.create(table.alias ?? table.model)),
};
}

protected override transformInsertQuery(node: InsertQueryNode) {
// pre-insert check is done in `handle()`

Expand Down Expand Up @@ -210,7 +229,16 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
protected override transformUpdateQuery(node: UpdateQueryNode) {
const result = super.transformUpdateQuery(node);
const mutationModel = this.getMutationModel(node);
const filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
let filter = this.buildPolicyFilter(mutationModel, undefined, 'update');

if (node.from) {
// for update with from (join), we need to merge join tables' policy filters to the "where" clause
const joinFilter = this.createPolicyFilterForFrom(node.from);
if (joinFilter) {
filter = conjunction(this.dialect, [filter, joinFilter]);
}
}

return {
...result,
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
Expand All @@ -220,7 +248,16 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
protected override transformDeleteQuery(node: DeleteQueryNode) {
const result = super.transformDeleteQuery(node);
const mutationModel = this.getMutationModel(node);
const filter = this.buildPolicyFilter(mutationModel, undefined, 'delete');
let filter = this.buildPolicyFilter(mutationModel, undefined, 'delete');

if (node.using) {
// for delete with using (join), we need to merge join tables' policy filters to the "where" clause
const joinFilter = this.createPolicyFilterForTables(node.using.tables);
if (joinFilter) {
filter = conjunction(this.dialect, [filter, joinFilter]);
}
}

return {
...result,
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
Expand Down Expand Up @@ -466,11 +503,11 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf

const allows = policies
.filter((policy) => policy.kind === 'allow')
.map((policy) => this.transformPolicyCondition(model, alias, operation, policy));
.map((policy) => this.compilePolicyCondition(model, alias, operation, policy));

const denies = policies
.filter((policy) => policy.kind === 'deny')
.map((policy) => this.transformPolicyCondition(model, alias, operation, policy));
.map((policy) => this.compilePolicyCondition(model, alias, operation, policy));

let combinedPolicy: OperationNode;

Expand Down Expand Up @@ -514,7 +551,26 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
}
}

private transformPolicyCondition(
private createPolicyFilterForFrom(node: FromNode | undefined) {
if (!node) {
return undefined;
}
return this.createPolicyFilterForTables(node.froms);
}

private createPolicyFilterForTables(tables: readonly OperationNode[]) {
return tables.reduce<OperationNode | undefined>((acc, table) => {
const extractResult = this.extractTableName(table);
if (extractResult) {
const { model, alias } = extractResult;
const filter = this.buildPolicyFilter(model, alias, 'read');
return acc ? conjunction(this.dialect, [acc, filter]) : filter;
}
return acc;
}, undefined);
}

private compilePolicyCondition(
model: GetModels<Schema>,
alias: string | undefined,
operation: CRUD,
Expand Down
Loading