Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions packages/language/src/validators/expression-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,21 @@ export default class ExpressionValidator implements AstValidator<Expression> {
supportedShapes = ['Boolean', 'Any'];
}

const leftResolvedDecl = expr.left.$resolvedType?.decl;
const rightResolvedDecl = expr.right.$resolvedType?.decl;

if (
typeof expr.left.$resolvedType?.decl !== 'string' ||
!supportedShapes.includes(expr.left.$resolvedType.decl)
leftResolvedDecl &&
(typeof leftResolvedDecl !== 'string' || !supportedShapes.includes(leftResolvedDecl))
) {
accept('error', `invalid operand type for "${expr.operator}" operator`, {
node: expr.left,
});
return;
}
if (
typeof expr.right.$resolvedType?.decl !== 'string' ||
!supportedShapes.includes(expr.right.$resolvedType.decl)
rightResolvedDecl &&
(typeof rightResolvedDecl !== 'string' || !supportedShapes.includes(rightResolvedDecl))
) {
accept('error', `invalid operand type for "${expr.operator}" operator`, {
node: expr.right,
Expand All @@ -128,14 +131,11 @@ export default class ExpressionValidator implements AstValidator<Expression> {
}

// DateTime comparison is only allowed between two DateTime values
if (expr.left.$resolvedType.decl === 'DateTime' && expr.right.$resolvedType.decl !== 'DateTime') {
if (leftResolvedDecl === 'DateTime' && rightResolvedDecl && rightResolvedDecl !== 'DateTime') {
accept('error', 'incompatible operand types', {
node: expr,
});
} else if (
expr.right.$resolvedType.decl === 'DateTime' &&
expr.left.$resolvedType.decl !== 'DateTime'
) {
} else if (rightResolvedDecl === 'DateTime' && leftResolvedDecl && leftResolvedDecl !== 'DateTime') {
accept('error', 'incompatible operand types', {
node: expr,
});
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/src/client/query-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export function requireField(schema: SchemaDef, modelOrType: string, field: stri

export function getIdFields<Schema extends SchemaDef>(schema: SchemaDef, model: GetModels<Schema>) {
const modelDef = requireModel(schema, model);
return modelDef?.idFields as GetModels<Schema>[];
return modelDef?.idFields;
}

export function requireIdFields(schema: SchemaDef, model: string) {
Expand Down
154 changes: 109 additions & 45 deletions packages/runtime/src/plugins/policy/policy-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
FunctionNode,
IdentifierNode,
InsertQueryNode,
JoinNode,
OperationNodeTransformer,
OperatorNode,
ParensNode,
Expand All @@ -31,7 +32,7 @@ import type { ClientContract } from '../../client';
import type { CRUD } from '../../client/contract';
import { getCrudDialect } from '../../client/crud/dialects';
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
import { InternalError } from '../../client/errors';
import { InternalError, QueryError } from '../../client/errors';
import type { ProceedKyselyQueryFunction } from '../../client/plugin';
import { getIdFields, requireField, requireModel } from '../../client/query-utils';
import { ExpressionUtils, type BuiltinType, type Expression, type GetModels, type SchemaDef } from '../../schema';
Expand Down Expand Up @@ -72,7 +73,7 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
}

let mutationRequiresTransaction = false;
const mutationModel = this.getMutationModel(node);
const { mutationModel } = this.getMutationModel(node);

if (InsertQueryNode.is(node)) {
// reject create if unconditional deny
Expand Down Expand Up @@ -138,18 +139,15 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
// #region overrides

protected override transformSelectQuery(node: SelectQueryNode) {
let whereNode = node.where;
let whereNode = this.transformNode(node.where);

node.from?.froms.forEach((from) => {
const extractResult = this.extractTableName(from);
if (extractResult) {
const { model, alias } = extractResult;
const filter = this.buildPolicyFilter(model, alias, 'read');
whereNode = WhereNode.create(
whereNode?.where ? conjunction(this.dialect, [whereNode.where, filter]) : filter,
);
}
});
// get combined policy filter for all froms, and merge into where clause
const policyFilter = this.createPolicyFilterForFrom(node.from);
if (policyFilter) {
whereNode = WhereNode.create(
whereNode?.where ? conjunction(this.dialect, [whereNode.where, policyFilter]) : policyFilter,
);
}

const baseResult = super.transformSelectQuery({
...node,
Expand All @@ -162,15 +160,36 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
};
}

protected override transformJoin(node: JoinNode) {
const table = this.extractTableName(node.table);
if (!table) {
// unable to extract table name, can be a subquery, which will be handled when nested transformation happens
return super.transformJoin(node);
}

// build a nested query with policy filter applied
const filter = this.buildPolicyFilter(table.model, table.alias, 'read');
const nestedSelect: SelectQueryNode = {
kind: 'SelectQueryNode',
from: FromNode.create([node.table]),
selections: [SelectionNode.createSelectAll()],
where: WhereNode.create(filter),
};
return {
...node,
table: AliasNode.create(ParensNode.create(nestedSelect), IdentifierNode.create(table.alias ?? table.model)),
};
}

protected override transformInsertQuery(node: InsertQueryNode) {
// pre-insert check is done in `handle()`

let onConflict = node.onConflict;

if (onConflict?.updates) {
// for "on conflict do update", we need to apply policy filter to the "where" clause
const mutationModel = this.getMutationModel(node);
const filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
const { mutationModel, alias } = this.getMutationModel(node);
const filter = this.buildPolicyFilter(mutationModel, alias, 'update');
if (onConflict.updateWhere) {
onConflict = {
...onConflict,
Expand All @@ -197,7 +216,8 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
return result;
} else {
// only return ID fields, that's enough for reading back the inserted row
const idFields = getIdFields(this.client.$schema, this.getMutationModel(node));
const { mutationModel } = this.getMutationModel(node);
const idFields = getIdFields(this.client.$schema, mutationModel);
return {
...result,
returning: ReturningNode.create(
Expand All @@ -209,8 +229,17 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf

protected override transformUpdateQuery(node: UpdateQueryNode) {
const result = super.transformUpdateQuery(node);
const mutationModel = this.getMutationModel(node);
const filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
const { mutationModel, alias } = this.getMutationModel(node);
let filter = this.buildPolicyFilter(mutationModel, alias, 'update');

if (node.from) {
// for update with from (join), we need to merge join tables' policy filters to the "where" clause
const joinFilter = this.createPolicyFilterForFrom(node.from);
if (joinFilter) {
filter = conjunction(this.dialect, [filter, joinFilter]);
}
}

return {
...result,
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
Expand All @@ -219,8 +248,17 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf

protected override transformDeleteQuery(node: DeleteQueryNode) {
const result = super.transformDeleteQuery(node);
const mutationModel = this.getMutationModel(node);
const filter = this.buildPolicyFilter(mutationModel, undefined, 'delete');
const { mutationModel, alias } = this.getMutationModel(node);
let filter = this.buildPolicyFilter(mutationModel, alias, 'delete');

if (node.using) {
// for delete with using (join), we need to merge join tables' policy filters to the "where" clause
const joinFilter = this.createPolicyFilterForTables(node.using.tables);
if (joinFilter) {
filter = conjunction(this.dialect, [filter, joinFilter]);
}
}

return {
...result,
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
Expand All @@ -235,19 +273,20 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
if (!node.returning) {
return true;
}
const idFields = getIdFields(this.client.$schema, this.getMutationModel(node));
const { mutationModel } = this.getMutationModel(node);
const idFields = getIdFields(this.client.$schema, mutationModel);
const collector = new ColumnCollector();
const selectedColumns = collector.collect(node.returning);
return selectedColumns.every((c) => idFields.includes(c));
}

private async enforcePreCreatePolicy(node: InsertQueryNode, proceed: ProceedKyselyQueryFunction) {
const model = this.getMutationModel(node);
const { mutationModel } = this.getMutationModel(node);
const fields = node.columns?.map((c) => c.column.name) ?? [];
const valueRows = node.values ? this.unwrapCreateValueRows(node.values, model, fields) : [[]];
const valueRows = node.values ? this.unwrapCreateValueRows(node.values, mutationModel, fields) : [[]];
for (const values of valueRows) {
await this.enforcePreCreatePolicyForOne(
model,
mutationModel,
fields,
values.map((v) => v.node),
proceed,
Expand Down Expand Up @@ -394,17 +433,13 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
}

// do a select (with policy) in place of returning
const table = this.getMutationModel(node);
if (!table) {
throw new InternalError(`Unable to get table name for query node: ${node}`);
}

const idConditions = this.buildIdConditions(table, result.rows);
const policyFilter = this.buildPolicyFilter(table, undefined, 'read');
const { mutationModel } = this.getMutationModel(node);
const idConditions = this.buildIdConditions(mutationModel, result.rows);
const policyFilter = this.buildPolicyFilter(mutationModel, undefined, 'read');

const select: SelectQueryNode = {
kind: 'SelectQueryNode',
from: FromNode.create([TableNode.create(table)]),
from: FromNode.create([TableNode.create(mutationModel)]),
where: WhereNode.create(conjunction(this.dialect, [idConditions, policyFilter])),
selections: node.returning.selections,
};
Expand Down Expand Up @@ -433,13 +468,23 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf

private getMutationModel(node: InsertQueryNode | UpdateQueryNode | DeleteQueryNode) {
const r = match(node)
.when(InsertQueryNode.is, (node) => getTableName(node.into) as GetModels<Schema>)
.when(UpdateQueryNode.is, (node) => getTableName(node.table) as GetModels<Schema>)
.when(InsertQueryNode.is, (node) => ({
mutationModel: getTableName(node.into) as GetModels<Schema>,
alias: undefined,
}))
.when(UpdateQueryNode.is, (node) => {
if (!node.table) {
throw new QueryError('Update query must have a table');
}
const r = this.extractTableName(node.table);
return r ? { mutationModel: r.model, alias: r.alias } : undefined;
})
.when(DeleteQueryNode.is, (node) => {
if (node.from.froms.length !== 1) {
throw new InternalError('Only one from table is supported for delete');
throw new QueryError('Only one from table is supported for delete');
}
return getTableName(node.from.froms[0]) as GetModels<Schema>;
const r = this.extractTableName(node.from.froms[0]!);
return r ? { mutationModel: r.model, alias: r.alias } : undefined;
})
.exhaustive();
if (!r) {
Expand All @@ -466,11 +511,11 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf

const allows = policies
.filter((policy) => policy.kind === 'allow')
.map((policy) => this.transformPolicyCondition(model, alias, operation, policy));
.map((policy) => this.compilePolicyCondition(model, alias, operation, policy));

const denies = policies
.filter((policy) => policy.kind === 'deny')
.map((policy) => this.transformPolicyCondition(model, alias, operation, policy));
.map((policy) => this.compilePolicyCondition(model, alias, operation, policy));

let combinedPolicy: OperationNode;

Expand All @@ -494,18 +539,18 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
return combinedPolicy;
}

private extractTableName(from: OperationNode): { model: GetModels<Schema>; alias?: string } | undefined {
if (TableNode.is(from)) {
return { model: from.table.identifier.name as GetModels<Schema> };
private extractTableName(node: OperationNode): { model: GetModels<Schema>; alias?: string } | undefined {
if (TableNode.is(node)) {
return { model: node.table.identifier.name as GetModels<Schema> };
}
if (AliasNode.is(from)) {
const inner = this.extractTableName(from.node);
if (AliasNode.is(node)) {
const inner = this.extractTableName(node.node);
if (!inner) {
return undefined;
}
return {
model: inner.model,
alias: IdentifierNode.is(from.alias) ? from.alias.name : undefined,
alias: IdentifierNode.is(node.alias) ? node.alias.name : undefined,
};
} else {
// this can happen for subqueries, which will be handled when nested
Expand All @@ -514,7 +559,26 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
}
}

private transformPolicyCondition(
private createPolicyFilterForFrom(node: FromNode | undefined) {
if (!node) {
return undefined;
}
return this.createPolicyFilterForTables(node.froms);
}

private createPolicyFilterForTables(tables: readonly OperationNode[]) {
return tables.reduce<OperationNode | undefined>((acc, table) => {
const extractResult = this.extractTableName(table);
if (extractResult) {
const { model, alias } = extractResult;
const filter = this.buildPolicyFilter(model, alias, 'read');
return acc ? conjunction(this.dialect, [acc, filter]) : filter;
}
return acc;
}, undefined);
}

private compilePolicyCondition(
model: GetModels<Schema>,
alias: string | undefined,
operation: CRUD,
Expand Down
Loading