diff --git a/TODO.md b/TODO.md index c92d4fc6..0d4ad692 100644 --- a/TODO.md +++ b/TODO.md @@ -101,6 +101,7 @@ - [ ] Short-circuit pre-create check for scalar-field only policies - [x] Inject "on conflict do update" - [x] `check` function + - [ ] Accessing tables not in the schema - [x] Migration - [ ] Databases - [x] SQLite diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index e43c389c..52d34ae4 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -666,7 +666,7 @@ attribute @@@deprecated(_ message: String) * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. * @param condition: a boolean expression that controls if the operation should be allowed. */ -attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean) +attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) /** * Defines an access policy that allows the annotated field to be read or updated. @@ -684,7 +684,7 @@ attribute @allow(_ operation: String @@@completionHint(["'create'", "'read'", "' * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. * @param condition: a boolean expression that controls if the operation should be denied. */ -attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean) +attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) /** * Defines an access policy that denies the annotated field to be read or updated. @@ -705,8 +705,8 @@ function check(field: Any, operation: String?): Boolean { } @@@expressionContext([AccessPolicy]) /** - * Gets entities value before an update. Only valid when used in a "update" policy rule. + * Gets entity's value before an update. Only valid when used in a "post-update" policy rule. */ -function future(): Any { +function before(): Any { } @@@expressionContext([AccessPolicy]) diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index 2f9b4a75..19220f07 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -145,10 +145,6 @@ export function isRelationshipField(field: DataField) { return isDataModel(field.type.reference?.ref); } -export function isFutureExpr(node: AstNode) { - return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref); -} - export function isDelegateModel(node: AstNode) { return isDataModel(node) && hasAttribute(node, '@@delegate'); } @@ -450,8 +446,8 @@ export function getAuthDecl(decls: (DataModel | TypeDef)[]) { return authModel; } -export function isFutureInvocation(node: AstNode) { - return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref); +export function isBeforeInvocation(node: AstNode) { + return isInvocationExpr(node) && node.function.ref?.name === 'before' && isFromStdlib(node.function.ref); } export function isCollectionPredicate(node: AstNode): node is BinaryExpr { diff --git a/packages/language/src/validators/attribute-application-validator.ts b/packages/language/src/validators/attribute-application-validator.ts index dc376036..b5384196 100644 --- a/packages/language/src/validators/attribute-application-validator.ts +++ b/packages/language/src/validators/attribute-application-validator.ts @@ -23,10 +23,10 @@ import { getAllAttributes, getStringLiteral, isAuthOrAuthMemberAccess, + isBeforeInvocation, isCollectionPredicate, isDataFieldReference, isDelegateModel, - isFutureExpr, isRelationshipField, mapBuiltinTypeToExpressionType, resolved, @@ -166,13 +166,20 @@ export default class AttributeApplicationValidator implements AstValidator isFutureExpr(node))) { - accept('error', `"future()" is not allowed in field-level policy rules`, { node: expr }); + if (expr && AstUtils.streamAst(expr).some((node) => isBeforeInvocation(node))) { + accept('error', `"before()" is not allowed in field-level policy rules`, { node: expr }); } // 'update' rules are not allowed for relation fields diff --git a/packages/language/src/validators/expression-validator.ts b/packages/language/src/validators/expression-validator.ts index f8dc4930..f455753f 100644 --- a/packages/language/src/validators/expression-validator.ts +++ b/packages/language/src/validators/expression-validator.ts @@ -11,6 +11,7 @@ import { isNullExpr, isReferenceExpr, isThisExpr, + MemberAccessExpr, type ExpressionType, } from '../generated/ast'; @@ -18,6 +19,7 @@ import { findUpAst, isAuthInvocation, isAuthOrAuthMemberAccess, + isBeforeInvocation, isDataFieldReference, isEnumFieldReference, typeAssignable, @@ -59,12 +61,21 @@ export default class ExpressionValidator implements AstValidator { // extra validations by expression type switch (expr.$type) { + case 'MemberAccessExpr': + this.validateMemberAccessExpr(expr, accept); + break; case 'BinaryExpr': this.validateBinaryExpr(expr, accept); break; } } + private validateMemberAccessExpr(expr: MemberAccessExpr, accept: ValidationAcceptor) { + if (isBeforeInvocation(expr.operand) && isDataModel(expr.$resolvedType?.decl)) { + accept('error', 'relation fields cannot be accessed from `before()`', { node: expr }); + } + } + private validateBinaryExpr(expr: BinaryExpr, accept: ValidationAcceptor) { switch (expr.operator) { case 'in': { diff --git a/packages/language/src/zmodel-linker.ts b/packages/language/src/zmodel-linker.ts index 9b045283..3bb45134 100644 --- a/packages/language/src/zmodel-linker.ts +++ b/packages/language/src/zmodel-linker.ts @@ -57,7 +57,7 @@ import { getAuthDecl, getContainingDataModel, isAuthInvocation, - isFutureExpr, + isBeforeInvocation, isMemberContainer, mapBuiltinTypeToExpressionType, } from './utils'; @@ -292,8 +292,8 @@ export class ZModelLinker extends DefaultLinker { if (authDecl) { node.$resolvedType = { decl: authDecl, nullable: true }; } - } else if (isFutureExpr(node)) { - // future() function is resolved to current model + } else if (isBeforeInvocation(node)) { + // before() function is resolved to current model node.$resolvedType = { decl: getContainingDataModel(node) }; } else { this.resolveToDeclaredType(node, funcDecl.returnType); diff --git a/packages/language/src/zmodel-scope.ts b/packages/language/src/zmodel-scope.ts index e2b58f02..2fd8b37a 100644 --- a/packages/language/src/zmodel-scope.ts +++ b/packages/language/src/zmodel-scope.ts @@ -37,7 +37,7 @@ import { getRecursiveBases, isAuthInvocation, isCollectionPredicate, - isFutureInvocation, + isBeforeInvocation, resolveImportUri, } from './utils'; @@ -170,8 +170,8 @@ export class ZModelScopeProvider extends DefaultScopeProvider { return this.createScopeForAuth(node, globalScope); } - if (isFutureInvocation(operand)) { - // resolve `future()` to the containing model + if (isBeforeInvocation(operand)) { + // resolve `before()` to the containing model return this.createScopeForContainingModel(node, globalScope); } return EMPTY_SCOPE; diff --git a/packages/language/test/attribute-application.test.ts b/packages/language/test/attribute-application.test.ts new file mode 100644 index 00000000..71d39323 --- /dev/null +++ b/packages/language/test/attribute-application.test.ts @@ -0,0 +1,23 @@ +import { describe, it } from 'vitest'; +import { loadSchemaWithError } from './utils'; + +describe('Attribute application validation tests', () => { + it('rejects before in non-post-update policies', async () => { + await loadSchemaWithError( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + @@allow('all', true) + @@deny('update', before(x) > 2) + } + `, + `"before()" is only allowed in "post-update" policy rules`, + ); + }); +}); diff --git a/packages/runtime/src/client/contract.ts b/packages/runtime/src/client/contract.ts index 0d38e34e..002f478c 100644 --- a/packages/runtime/src/client/contract.ts +++ b/packages/runtime/src/client/contract.ts @@ -213,11 +213,21 @@ export interface ClientConstructor { */ export type CRUD = 'create' | 'read' | 'update' | 'delete'; +/** + * Extended CRUD operations including 'post-update'. + */ +export type CRUD_EXT = CRUD | 'post-update'; + /** * CRUD operations. */ export const CRUD = ['create', 'read', 'update', 'delete'] as const; +/** + * Extended CRUD operations including 'post-update'. + */ +export const CRUD_EXT = [...CRUD, 'post-update'] as const; + //#region Model operations export type AllModelOperations> = { diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 170b8d89..7a4c9a66 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -1296,8 +1296,9 @@ export abstract class BaseOperationHandler { return { count: Number(result.numAffectedRows) } as Result; } else { const idFields = requireIdFields(this.schema, model); - const result = await query.returning(idFields as any).execute(); - return result as Result; + const finalQuery = query.returning(idFields as any); + const result = await this.executeQuery(kysely, finalQuery, 'update'); + return result.rows as Result; } } diff --git a/packages/runtime/src/client/options.ts b/packages/runtime/src/client/options.ts index 7d1134a3..f09f44d6 100644 --- a/packages/runtime/src/client/options.ts +++ b/packages/runtime/src/client/options.ts @@ -1,7 +1,7 @@ import type { Dialect, Expression, ExpressionBuilder, KyselyConfig } from 'kysely'; import type { GetModel, GetModels, ProcedureDef, SchemaDef } from '../schema'; import type { PrependParameter } from '../utils/type-utils'; -import type { ClientContract, CRUD, ProcedureFunc } from './contract'; +import type { ClientContract, CRUD_EXT, ProcedureFunc } from './contract'; import type { BaseCrudDialect } from './crud/dialects/base-dialect'; import type { RuntimePlugin } from './plugin'; import type { ToKyselySchema } from './query-builder'; @@ -30,7 +30,7 @@ export type ZModelFunctionContext = { /** * The CRUD operation being performed */ - operation: CRUD; + operation: CRUD_EXT; }; export type ZModelFunction = ( diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index 414b72b4..d78c2803 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -20,7 +20,7 @@ import { type OperationNode, } from 'kysely'; import { match } from 'ts-pattern'; -import type { ClientContract, CRUD } from '../../client/contract'; +import type { ClientContract, CRUD_EXT } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; @@ -50,13 +50,12 @@ import { type SchemaDef, } from '../../schema'; import { ExpressionEvaluator } from './expression-evaluator'; -import { conjunction, disjunction, falseNode, logicalNot, trueNode } from './utils'; +import { conjunction, disjunction, falseNode, isBeforeInvocation, logicalNot, trueNode } from './utils'; export type ExpressionTransformerContext = { model: GetModels; alias?: string; - operation: CRUD; - auth?: any; + operation: CRUD_EXT; memberFilter?: OperationNode; memberSelect?: SelectionNode; }; @@ -439,7 +438,7 @@ export class ExpressionTransformer { } if (this.isAuthMember(arg)) { - const valNode = this.valueMemberAccess(context.auth, arg as MemberExpression, this.authType); + const valNode = this.valueMemberAccess(this.auth, arg as MemberExpression, this.authType); return valNode ? eb.val(valNode.value) : eb.val(null); } @@ -453,11 +452,20 @@ export class ExpressionTransformer { @expr('member') // @ts-ignore private _member(expr: MemberExpression, context: ExpressionTransformerContext) { - // auth() member access + // `auth()` member access if (this.isAuthCall(expr.receiver)) { return this.valueMemberAccess(this.auth, expr, this.authType); } + // `before()` member access + if (isBeforeInvocation(expr.receiver)) { + // policy handler creates a join table named `$before` using entity value before update, + // we can directly reference the column from there + invariant(context.operation === 'post-update', 'before() can only be used in post-update policy'); + invariant(expr.members.length === 1, 'before() can only be followed by a scalar field access'); + return ReferenceNode.create(ColumnNode.create(expr.members[0]!), TableNode.create('$before')); + } + invariant( ExpressionUtils.isField(expr.receiver) || ExpressionUtils.isThis(expr.receiver), 'expect receiver to be field expression or "this"', diff --git a/packages/runtime/src/plugins/policy/plugin.ts b/packages/runtime/src/plugins/policy/plugin.ts index 6af93353..7ebd2882 100644 --- a/packages/runtime/src/plugins/policy/plugin.ts +++ b/packages/runtime/src/plugins/policy/plugin.ts @@ -22,8 +22,8 @@ export class PolicyPlugin implements RuntimePlugin) { + onKyselyQuery({ query, client, proceed }: OnKyselyQueryArgs) { const handler = new PolicyHandler(client); - return handler.handle(query, proceed /*, transaction*/); + return handler.handle(query, proceed); } } diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index 50cfc835..597ac249 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -16,7 +16,9 @@ import { ParensNode, PrimitiveValueListNode, RawNode, + ReferenceNode, ReturningNode, + SelectAllNode, SelectionNode, SelectQueryNode, sql, @@ -32,18 +34,26 @@ import { } from 'kysely'; import { match } from 'ts-pattern'; import type { ClientContract } from '../../client'; -import type { CRUD } from '../../client/contract'; +import { type CRUD_EXT } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; import type { ProceedKyselyQueryFunction } from '../../client/plugin'; import { getManyToManyRelation, requireField, requireIdFields, requireModel } from '../../client/query-utils'; -import { ExpressionUtils, type BuiltinType, type Expression, type GetModels, type SchemaDef } from '../../schema'; +import { + ExpressionUtils, + type BuiltinType, + type Expression, + type GetModels, + type MemberExpression, + type SchemaDef, +} from '../../schema'; +import { ExpressionVisitor } from '../../utils/expression-utils'; import { ColumnCollector } from './column-collector'; import { RejectedByPolicyError, RejectedByPolicyReason } from './errors'; import { ExpressionTransformer } from './expression-transformer'; import type { Policy, PolicyOperation } from './types'; -import { buildIsFalse, conjunction, disjunction, falseNode, getTableName } from './utils'; +import { buildIsFalse, conjunction, disjunction, falseNode, getTableName, isBeforeInvocation, trueNode } from './utils'; export type CrudQueryNode = SelectQueryNode | InsertQueryNode | UpdateQueryNode | DeleteQueryNode; @@ -61,10 +71,7 @@ export class PolicyHandler extends OperationNodeTransf return this.client.$qb; } - async handle( - node: RootOperationNode, - proceed: ProceedKyselyQueryFunction /*, transaction: OnKyselyQueryTransaction*/, - ) { + async handle(node: RootOperationNode, proceed: ProceedKyselyQueryFunction) { if (!this.isCrudQueryNode(node)) { // non-CRUD queries are not allowed throw new RejectedByPolicyError( @@ -81,6 +88,8 @@ export class PolicyHandler extends OperationNodeTransf const { mutationModel } = this.getMutationModel(node); + // --- Pre mutation work --- + if (InsertQueryNode.is(node)) { // pre-create policy evaluation happens before execution of the query const isManyToManyJoinTable = this.isManyToManyJoinTable(mutationModel); @@ -102,12 +111,86 @@ export class PolicyHandler extends OperationNodeTransf } } + const hasPostUpdatePolicies = UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel); + + let beforeUpdateInfo: Awaited> | undefined; + if (hasPostUpdatePolicies) { + beforeUpdateInfo = await this.loadBeforeUpdateEntities(mutationModel, node.where, proceed); + } + // proceed with query const result = await proceed(this.transformNode(node)); + // --- Post mutation work --- + + if (hasPostUpdatePolicies) { + // entities updated filter + const idConditions = this.buildIdConditions(mutationModel, result.rows); + + // post-update policy filter + const postUpdateFilter = this.buildPolicyFilter(mutationModel, undefined, 'post-update'); + + // read the post-update row with filter applied + + const eb = expressionBuilder(); + + // create a `SELECT column1 as field1, column2 as field2, ... FROM (VALUES (...))` table for before-update rows + const beforeUpdateTable: SelectQueryNode | undefined = beforeUpdateInfo + ? { + kind: 'SelectQueryNode', + from: FromNode.create([ + ParensNode.create( + ValuesNode.create( + beforeUpdateInfo!.rows.map((r) => + PrimitiveValueListNode.create(beforeUpdateInfo!.fields.map((f) => r[f])), + ), + ), + ), + ]), + selections: beforeUpdateInfo.fields.map((name, index) => { + const def = requireField(this.client.$schema, mutationModel, name); + const castedColumnRef = + sql`CAST(${eb.ref(`column${index + 1}`)} as ${sql.raw(this.dialect.getFieldSqlType(def))})`.as( + name, + ); + return SelectionNode.create(castedColumnRef.toOperationNode()); + }), + } + : undefined; + + const postUpdateQuery = eb + .selectFrom(mutationModel) + .select(() => [eb(eb.fn('COUNT', [eb.lit(1)]), '=', result.rows.length).as('$condition')]) + .where(() => new ExpressionWrapper(conjunction(this.dialect, [idConditions, postUpdateFilter]))) + .$if(!!beforeUpdateInfo, (qb) => + qb.leftJoin( + () => new ExpressionWrapper(beforeUpdateTable!).as('$before'), + (join) => { + const idFields = requireIdFields(this.client.$schema, mutationModel); + return idFields.reduce( + (acc, f) => acc.onRef(`${mutationModel}.${f}`, '=', `$before.${f}`), + join, + ); + }, + ), + ); + + const postUpdateResult = await proceed(postUpdateQuery.toOperationNode()); + if (!postUpdateResult.rows[0]?.$condition) { + throw new RejectedByPolicyError( + mutationModel, + RejectedByPolicyReason.NO_ACCESS, + 'some or all updated rows failed to pass post-update policy check', + ); + } + } + + // --- Read back --- + if (!node.returning || this.onlyReturningId(node)) { - return result; + // no need to check read back + return this.postProcessMutationResult(result, node); } else { const readBackResult = await this.processReadBack(node, result, proceed); if (readBackResult.rows.length !== result.rows.length) { @@ -121,6 +204,75 @@ export class PolicyHandler extends OperationNodeTransf } } + // correction to kysely mutation result may be needed because we might have added + // returning clause to the query and caused changes to the result shape + private postProcessMutationResult(result: QueryResult, node: MutationQueryNode) { + if (node.returning) { + return result; + } else { + return { + ...result, + rows: [], + numAffectedRows: result.numAffectedRows ?? BigInt(result.rows.length), + }; + } + } + + hasPostUpdatePolicies(model: GetModels) { + const policies = this.getModelPolicies(model, 'post-update'); + return policies.length > 0; + } + + private async loadBeforeUpdateEntities( + model: GetModels, + where: WhereNode | undefined, + proceed: ProceedKyselyQueryFunction, + ) { + const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model); + if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) { + return undefined; + } + const query: SelectQueryNode = { + kind: 'SelectQueryNode', + from: FromNode.create([TableNode.create(model)]), + where, + selections: [...beforeUpdateAccessFields.map((f) => SelectionNode.create(ColumnNode.create(f)))], + }; + const result = await proceed(query); + return { fields: beforeUpdateAccessFields, rows: result.rows }; + } + + private getFieldsAccessForBeforeUpdatePolicies(model: GetModels) { + const policies = this.getModelPolicies(model, 'post-update'); + if (policies.length === 0) { + return undefined; + } + + const fields = new Set(); + const fieldCollector = new (class extends ExpressionVisitor { + protected override visitMember(e: MemberExpression): void { + if (isBeforeInvocation(e.receiver)) { + invariant(e.members.length === 1, 'before() can only be followed by a scalar field access'); + fields.add(e.members[0]!); + } + super.visitMember(e); + } + })(); + + for (const policy of policies) { + fieldCollector.visit(policy.condition); + } + + if (fields.size === 0) { + return undefined; + } + + // make sure id fields are included + requireIdFields(this.client.$schema, model).forEach((f) => fields.add(f)); + + return Array.from(fields).sort(); + } + // #region overrides protected override transformSelectQuery(node: SelectQueryNode) { @@ -193,23 +345,19 @@ export class PolicyHandler extends OperationNodeTransf const result = super.transformInsertQuery(processedNode); - if (!node.returning) { - return result; - } - - if (this.onlyReturningId(node)) { - return result; - } else { - // only return ID fields, that's enough for reading back the inserted row + // if any field is to be returned, we select ID fields here which will be used + // for reading back post-insert + let returning = result.returning; + if (returning) { const { mutationModel } = this.getMutationModel(node); const idFields = requireIdFields(this.client.$schema, mutationModel); - return { - ...result, - returning: ReturningNode.create( - idFields.map((field) => SelectionNode.create(ColumnNode.create(field))), - ), - }; + returning = ReturningNode.create(idFields.map((f) => SelectionNode.create(ColumnNode.create(f)))); } + + return { + ...result, + returning, + }; } protected override transformUpdateQuery(node: UpdateQueryNode) { @@ -225,9 +373,23 @@ export class PolicyHandler extends OperationNodeTransf } } + let returning = result.returning; + + // regarding returning: + // 1. if fields are to be returned, we only select id fields here which will be used for reading back + // post-update + // 2. if there are post-update policies, we need to make sure id fields are selected for joining with + // before-update rows + + if (returning || this.hasPostUpdatePolicies(mutationModel)) { + const idFields = requireIdFields(this.client.$schema, mutationModel); + returning = ReturningNode.create(idFields.map((f) => SelectionNode.create(ColumnNode.create(f)))); + } + return { ...result, where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), + returning, }; } @@ -260,6 +422,19 @@ export class PolicyHandler extends OperationNodeTransf } const { mutationModel } = this.getMutationModel(node); const idFields = requireIdFields(this.client.$schema, mutationModel); + + if (node.returning.selections.some((s) => SelectAllNode.is(s.selection))) { + const modelDef = requireModel(this.client.$schema, mutationModel); + if (Object.keys(modelDef.fields).some((f) => !idFields.includes(f))) { + // there are fields other than ID fields + return false; + } else { + // select all but model only has ID fields + return true; + } + } + + // analyze selected columns const collector = new ColumnCollector(); const selectedColumns = collector.collect(node.returning); return selectedColumns.every((c) => idFields.includes(c)); @@ -543,7 +718,7 @@ export class PolicyHandler extends OperationNodeTransf this.dialect, idFields.map((field) => BinaryOperationNode.create( - ColumnNode.create(field), + ReferenceNode.create(ColumnNode.create(field), TableNode.create(table)), OperatorNode.create('='), ValueNode.create(row[field]), ), @@ -590,7 +765,7 @@ export class PolicyHandler extends OperationNodeTransf return InsertQueryNode.is(node) || UpdateQueryNode.is(node) || DeleteQueryNode.is(node); } - buildPolicyFilter(model: GetModels, alias: string | undefined, operation: CRUD) { + buildPolicyFilter(model: GetModels, alias: string | undefined, operation: CRUD_EXT): OperationNode { // first check if it's a many-to-many join table, and if so, handle specially const m2mFilter = this.getModelPolicyFilterForManyToManyJoinTable(model, alias, operation); if (m2mFilter) { @@ -598,9 +773,6 @@ export class PolicyHandler extends OperationNodeTransf } const policies = this.getModelPolicies(model, operation); - if (policies.length === 0) { - return falseNode(this.dialect); - } const allows = policies .filter((policy) => policy.kind === 'allow') @@ -610,25 +782,33 @@ export class PolicyHandler extends OperationNodeTransf .filter((policy) => policy.kind === 'deny') .map((policy) => this.compilePolicyCondition(model, alias, operation, policy)); + // 'post-update' is by default allowed, other operations are by default denied let combinedPolicy: OperationNode; if (allows.length === 0) { - // constant false - combinedPolicy = falseNode(this.dialect); + // no allow rules + if (operation === 'post-update') { + // post-update is allowed if no allow rules are defined + combinedPolicy = trueNode(this.dialect); + } else { + // other operations are denied by default + combinedPolicy = falseNode(this.dialect); + } } else { // or(...allows) combinedPolicy = disjunction(this.dialect, allows); + } - // and(...!denies) - if (denies.length !== 0) { - const combinedDenies = conjunction( - this.dialect, - denies.map((d) => buildIsFalse(d, this.dialect)), - ); - // or(...allows) && and(...!denies) - combinedPolicy = conjunction(this.dialect, [combinedPolicy, combinedDenies]); - } + // and(...!denies) + if (denies.length !== 0) { + const combinedDenies = conjunction( + this.dialect, + denies.map((d) => buildIsFalse(d, this.dialect)), + ); + // or(...allows) && and(...!denies) + combinedPolicy = conjunction(this.dialect, [combinedPolicy, combinedDenies]); } + return combinedPolicy; } @@ -674,14 +854,13 @@ export class PolicyHandler extends OperationNodeTransf private compilePolicyCondition( model: GetModels, alias: string | undefined, - operation: CRUD, + operation: CRUD_EXT, policy: Policy, ) { return new ExpressionTransformer(this.client).transform(policy.condition, { model, alias, operation, - auth: this.client.$auth, }); } @@ -710,7 +889,11 @@ export class PolicyHandler extends OperationNodeTransf condition: attr.args![1]!.value, }) as const, ) - .filter((policy) => policy.operations.includes('all') || policy.operations.includes(operation)), + .filter( + (policy) => + (operation !== 'post-update' && policy.operations.includes('all')) || + policy.operations.includes(operation), + ), ); } return result; diff --git a/packages/runtime/src/plugins/policy/types.ts b/packages/runtime/src/plugins/policy/types.ts index cfe20871..74c49d85 100644 --- a/packages/runtime/src/plugins/policy/types.ts +++ b/packages/runtime/src/plugins/policy/types.ts @@ -1,4 +1,4 @@ -import type { CRUD } from '../../client/contract'; +import type { CRUD_EXT } from '../../client/contract'; import type { Expression } from '../../schema'; /** @@ -9,7 +9,7 @@ export type PolicyKind = 'allow' | 'deny'; /** * Access policy operation. */ -export type PolicyOperation = CRUD | 'all'; +export type PolicyOperation = CRUD_EXT | 'all'; /** * Access policy definition. diff --git a/packages/runtime/src/plugins/policy/utils.ts b/packages/runtime/src/plugins/policy/utils.ts index 1113cb7e..5fc11410 100644 --- a/packages/runtime/src/plugins/policy/utils.ts +++ b/packages/runtime/src/plugins/policy/utils.ts @@ -13,7 +13,7 @@ import { ValueNode, } from 'kysely'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; -import type { SchemaDef } from '../../schema'; +import { ExpressionUtils, type Expression, type SchemaDef } from '../../schema'; /** * Creates a `true` value node. @@ -154,3 +154,7 @@ export function getTableName(node: OperationNode | undefined) { } return undefined; } + +export function isBeforeInvocation(expr: Expression) { + return ExpressionUtils.isCall(expr) && expr.function === 'before'; +} diff --git a/packages/runtime/src/utils/expression-utils.ts b/packages/runtime/src/utils/expression-utils.ts new file mode 100644 index 00000000..8c0824d4 --- /dev/null +++ b/packages/runtime/src/utils/expression-utils.ts @@ -0,0 +1,58 @@ +import { match } from 'ts-pattern'; +import type { + ArrayExpression, + BinaryExpression, + CallExpression, + Expression, + FieldExpression, + LiteralExpression, + MemberExpression, + NullExpression, + ThisExpression, + UnaryExpression, +} from '../schema'; + +export class ExpressionVisitor { + visit(expr: Expression): void { + match(expr) + .with({ kind: 'literal' }, (e) => this.visitLiteral(e)) + .with({ kind: 'array' }, (e) => this.visitArray(e)) + .with({ kind: 'field' }, (e) => this.visitField(e)) + .with({ kind: 'member' }, (e) => this.visitMember(e)) + .with({ kind: 'binary' }, (e) => this.visitBinary(e)) + .with({ kind: 'unary' }, (e) => this.visitUnary(e)) + .with({ kind: 'call' }, (e) => this.visitCall(e)) + .with({ kind: 'this' }, (e) => this.visitThis(e)) + .with({ kind: 'null' }, (e) => this.visitNull(e)) + .exhaustive(); + } + + protected visitLiteral(_e: LiteralExpression) {} + + protected visitArray(e: ArrayExpression) { + e.items.forEach((item) => this.visit(item)); + } + + protected visitField(_e: FieldExpression) {} + + protected visitMember(e: MemberExpression) { + this.visit(e.receiver); + } + + protected visitBinary(e: BinaryExpression) { + this.visit(e.left); + this.visit(e.right); + } + + protected visitUnary(e: UnaryExpression) { + this.visit(e.operand); + } + + protected visitCall(e: CallExpression) { + e.args?.forEach((arg) => this.visit(arg)); + } + + protected visitThis(_e: ThisExpression) {} + + protected visitNull(_e: NullExpression) {} +} diff --git a/packages/runtime/test/plugin/entity-mutation-hooks.test.ts b/packages/runtime/test/plugin-infra/entity-mutation-hooks.test.ts similarity index 100% rename from packages/runtime/test/plugin/entity-mutation-hooks.test.ts rename to packages/runtime/test/plugin-infra/entity-mutation-hooks.test.ts diff --git a/packages/runtime/test/plugin/on-kysely-query.test.ts b/packages/runtime/test/plugin-infra/on-kysely-query.test.ts similarity index 100% rename from packages/runtime/test/plugin/on-kysely-query.test.ts rename to packages/runtime/test/plugin-infra/on-kysely-query.test.ts diff --git a/packages/runtime/test/plugin/on-query-hooks.test.ts b/packages/runtime/test/plugin-infra/on-query-hooks.test.ts similarity index 100% rename from packages/runtime/test/plugin/on-query-hooks.test.ts rename to packages/runtime/test/plugin-infra/on-query-hooks.test.ts diff --git a/packages/runtime/test/policy/crud/post-update.test.ts b/packages/runtime/test/policy/crud/post-update.test.ts new file mode 100644 index 00000000..585ee180 --- /dev/null +++ b/packages/runtime/test/policy/crud/post-update.test.ts @@ -0,0 +1,190 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy post-update tests', () => { + it('allows post-update by default', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id Int @id + x Int + @@allow('read,create,update', true) + } + `, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toResolveTruthy(); + }); + + it('works with simple post-update rules', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id Int @id + x Int + @@allow('all', true) + @@allow('post-update', x > 1) + @@deny('post-update', x > 2) + } + `, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + + // allow: x > 1 + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 0 }); + + // deny: x > 2 + await expect(db.foo.update({ where: { id: 1 }, data: { x: 3 } })).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 0 }); + + await expect(db.foo.update({ where: { id: 1 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + }); + + it('respect deny rules without allow', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id Int @id + x Int + @@allow('create,read,update', true) + @@deny('post-update', x > 1) + } + `, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 2 } })).toBeRejectedByPolicy(); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toResolveTruthy(); + }); + + it('works with relation conditions', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + age Int + profile Profile? + @@allow('all', true) + @@allow('post-update', profile == null || age == profile.age) + } + + model Profile { + id Int @id + age Int + userId Int @unique + user User @relation(fields: [userId], references: [id]) + @@allow('all', true) + } + `, + ); + + await db.user.create({ data: { id: 1, age: 20, profile: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 22 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { age: 18 } })).toResolveTruthy(); + + await db.user.create({ data: { id: 2, age: 20, profile: { create: { id: 2, age: 18 } } } }); + await expect( + db.user.update({ where: { id: 2 }, data: { age: 22, profile: { delete: true } } }), + ).toResolveTruthy(); + }); + + it('works with before function', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id Int @id + x Int + @@allow('all', true) + @@allow('post-update', x > before().x) + } + `, + ); + + await db.foo.create({ data: { id: 1, x: 1 } }); + await db.foo.create({ data: { id: 2, x: 2 } }); + + // update one + await expect(db.foo.update({ where: { id: 1 }, data: { x: 0 } })).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + + // update many + await expect(db.foo.updateMany({ data: { x: 0 } })).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 2 }); + + await expect(db.foo.update({ where: { id: 1 }, data: { x: 2 } })).toResolveTruthy(); + await expect(db.foo.updateMany({ data: { x: 3 } })).resolves.toMatchObject({ count: 2 }); + // check updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 3 }); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 3 }); + }); + + // TODO: fix transaction issue + it.skip('works with query builder API', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id Int @id + x Int + @@allow('all', true) + @@allow('post-update', x > before().x) + } + `, + ); + + await db.foo.create({ data: { id: 1, x: 1 } }); + await db.foo.create({ data: { id: 2, x: 2 } }); + + // update one + await expect(db.$qb.updateTable('Foo').set({ x: 0 }).where('id', '=', 1).execute()).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + + // update many + await expect(db.$qb.updateTable('Foo').set({ x: 0 }).execute()).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 2 }); + + await expect(db.$qb.updateTable('Foo').set({ x: 2 }).where('id', '=', 1).execute()).resolves.toMatchObject({ + numAffectedRows: 1n, + }); + // check updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 2 }); + + await expect(db.$qb.updateTable('Foo').set({ x: 3 }).execute()).resolves.toMatchObject({ + numAffectedRows: 2n, + }); + // check updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 3 }); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 3 }); + }); + + it('rejects accessing relation fields from before', async () => { + await expect( + createPolicyTestClient( + ` + model User { + id Int @id + name String + profile Profile? + } + + model Profile { + id Int @id + userId Int @unique + user User @relation(fields: [userId], references: [id]) + @@allow('post-update', before().user.name == 'a') + } + `, + ), + ).rejects.toThrow('relation fields cannot be accessed from `before()`'); + }); +}); diff --git a/packages/runtime/test/schemas/petstore/schema.ts b/packages/runtime/test/schemas/petstore/schema.ts index c6902c7e..59954144 100644 --- a/packages/runtime/test/schemas/petstore/schema.ts +++ b/packages/runtime/test/schemas/petstore/schema.ts @@ -92,7 +92,7 @@ export const schema = { }, attributes: [ { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("read") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.field("orderId"), "==", ExpressionUtils._null()), "||", ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.field("order"), ["user"]), "==", ExpressionUtils.call("auth"))) }] }, - { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("update") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.field("name"), "==", ExpressionUtils.member(ExpressionUtils.call("future"), ["name"])), "&&", ExpressionUtils.binary(ExpressionUtils.field("category"), "==", ExpressionUtils.member(ExpressionUtils.call("future"), ["category"]))), "&&", ExpressionUtils.binary(ExpressionUtils.field("orderId"), "==", ExpressionUtils._null())) }] } + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("post-update") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.call("before"), ["name"]), "==", ExpressionUtils.field("name")), "&&", ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.call("before"), ["category"]), "==", ExpressionUtils.field("category"))), "&&", ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.call("before"), ["orderId"]), "==", ExpressionUtils._null())) }] } ], idFields: ["id"], uniqueFields: { diff --git a/packages/runtime/test/schemas/petstore/schema.zmodel b/packages/runtime/test/schemas/petstore/schema.zmodel index 4a2442ca..a7a0f53b 100644 --- a/packages/runtime/test/schemas/petstore/schema.zmodel +++ b/packages/runtime/test/schemas/petstore/schema.zmodel @@ -36,7 +36,7 @@ model Pet { @@allow('read', orderId == null || order.user == auth()) // only allow update to 'orderId' field if it's not set yet (unsold) - @@allow('update', name == future().name && category == future().category && orderId == null ) + @@allow('post-update', before().name == name && before().category == category && before().orderId == null) } model Order {