diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 0d8c4264..b419ab24 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -594,16 +594,6 @@ function datetime(field: String): Boolean { function url(field: String): Boolean { } @@@expressionContext([ValidationRule]) -/** - * Checks if the current user can perform the given operation on the given field. - * - * @param field: The field to check access for - * @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided, - * it defaults the operation of the containing policy rule. - */ -function check(field: Any, operation: String?): Boolean { -} @@@expressionContext([AccessPolicy]) - ////////////////////////////////////////////// // End validation attributes and functions ////////////////////////////////////////////// diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index 06677192..b4b5dd30 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -357,8 +357,9 @@ export function getFieldReference(expr: Expression): DataField | undefined { } } +// TODO: move to policy plugin export function isCheckInvocation(node: AstNode) { - return isInvocationExpr(node) && node.function.ref?.name === 'check' && isFromStdlib(node.function.ref); + return isInvocationExpr(node) && node.function.ref?.name === 'check'; } export function resolveTransitiveImports(documents: LangiumDocuments, model: Model) { diff --git a/packages/language/src/validators/function-invocation-validator.ts b/packages/language/src/validators/function-invocation-validator.ts index b640ad1b..e75c8e3d 100644 --- a/packages/language/src/validators/function-invocation-validator.ts +++ b/packages/language/src/validators/function-invocation-validator.ts @@ -170,6 +170,7 @@ export default class FunctionInvocationValidator implements AstValidator> = { diff --git a/packages/runtime/src/client/executor/kysely-utils.ts b/packages/runtime/src/client/executor/kysely-utils.ts deleted file mode 100644 index fb9ec845..00000000 --- a/packages/runtime/src/client/executor/kysely-utils.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { type OperationNode, AliasNode } from 'kysely'; - -/** - * Strips alias from the node if it exists. - */ -export function stripAlias(node: OperationNode) { - if (AliasNode.is(node)) { - return { alias: node.alias, node: node.node }; - } else { - return { alias: undefined, node }; - } -} diff --git a/packages/runtime/src/client/executor/name-mapper.ts b/packages/runtime/src/client/executor/name-mapper.ts index c839bc75..83ef8a33 100644 --- a/packages/runtime/src/client/executor/name-mapper.ts +++ b/packages/runtime/src/client/executor/name-mapper.ts @@ -17,8 +17,8 @@ import { type OperationNode, } from 'kysely'; import type { FieldDef, ModelDef, SchemaDef } from '../../schema'; +import { extractFieldName, extractModelName, stripAlias } from '../kysely-utils'; import { getModel, requireModel } from '../query-utils'; -import { stripAlias } from './kysely-utils'; type Scope = { model?: string; @@ -170,7 +170,7 @@ export class QueryNameMapper extends OperationNodeTransformer { const scopes: Scope[] = node.from.froms.map((node) => { const { alias, node: innerNode } = stripAlias(node); return { - model: this.extractModelName(innerNode), + model: extractModelName(innerNode), alias, namesMapped: false, }; @@ -219,8 +219,8 @@ export class QueryNameMapper extends OperationNodeTransformer { selections.push(SelectionNode.create(transformed)); } else { // otherwise use an alias to preserve the original field name - const origFieldName = this.extractFieldName(selection.selection); - const fieldName = this.extractFieldName(transformed); + const origFieldName = extractFieldName(selection.selection); + const fieldName = extractFieldName(transformed); if (fieldName !== origFieldName) { selections.push( SelectionNode.create( @@ -425,7 +425,7 @@ export class QueryNameMapper extends OperationNodeTransformer { private processSelection(node: AliasNode | ColumnNode | ReferenceNode) { let alias: string | undefined; if (!AliasNode.is(node)) { - alias = this.extractFieldName(node); + alias = extractFieldName(node); } const result = super.transformNode(node); return this.wrapAlias(result, alias ? IdentifierNode.create(alias) : undefined); @@ -451,20 +451,5 @@ export class QueryNameMapper extends OperationNodeTransformer { }); } - private extractModelName(node: OperationNode): string | undefined { - const { node: innerNode } = stripAlias(node); - return TableNode.is(innerNode!) ? innerNode!.table.identifier.name : undefined; - } - - private extractFieldName(node: ReferenceNode | ColumnNode) { - if (ReferenceNode.is(node) && ColumnNode.is(node.column)) { - return node.column.column.name; - } else if (ColumnNode.is(node)) { - return node.column.name; - } else { - return undefined; - } - } - // #endregion } diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index be317924..2d2395cb 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -26,8 +26,8 @@ import type { GetModels, SchemaDef } from '../../schema'; import { type ClientImpl } from '../client-impl'; import { TransactionIsolationLevel, type ClientContract } from '../contract'; import { InternalError, QueryError } from '../errors'; +import { stripAlias } from '../kysely-utils'; import type { AfterEntityMutationCallback, OnKyselyQueryCallback } from '../plugin'; -import { stripAlias } from './kysely-utils'; import { QueryNameMapper } from './name-mapper'; import type { ZenStackDriver } from './zenstack-driver'; diff --git a/packages/runtime/src/client/kysely-utils.ts b/packages/runtime/src/client/kysely-utils.ts new file mode 100644 index 00000000..a46464c3 --- /dev/null +++ b/packages/runtime/src/client/kysely-utils.ts @@ -0,0 +1,33 @@ +import { type OperationNode, AliasNode, ColumnNode, ReferenceNode, TableNode } from 'kysely'; + +/** + * Strips alias from the node if it exists. + */ +export function stripAlias(node: OperationNode) { + if (AliasNode.is(node)) { + return { alias: node.alias, node: node.node }; + } else { + return { alias: undefined, node }; + } +} + +/** + * Extracts model name from an OperationNode. + */ +export function extractModelName(node: OperationNode) { + const { node: innerNode } = stripAlias(node); + return TableNode.is(innerNode!) ? innerNode!.table.identifier.name : undefined; +} + +/** + * Extracts field name from an OperationNode. + */ +export function extractFieldName(node: OperationNode) { + if (ReferenceNode.is(node) && ColumnNode.is(node.column)) { + return node.column.column.name; + } else if (ColumnNode.is(node)) { + return node.column.name; + } else { + return undefined; + } +} diff --git a/packages/runtime/src/client/options.ts b/packages/runtime/src/client/options.ts index 7c90e330..ad7df8f0 100644 --- a/packages/runtime/src/client/options.ts +++ b/packages/runtime/src/client/options.ts @@ -7,8 +7,29 @@ import type { RuntimePlugin } from './plugin'; import type { ToKyselySchema } from './query-builder'; export type ZModelFunctionContext = { + /** + * ZenStack client instance + */ + client: ClientContract; + + /** + * Database dialect + */ dialect: BaseCrudDialect; + + /** + * The containing model name + */ model: GetModels; + + /** + * The alias name that can be used to refer to the containing model + */ + modelAlias: string; + + /** + * The CRUD operation being performed + */ operation: CRUD; }; diff --git a/packages/runtime/src/client/plugin.ts b/packages/runtime/src/client/plugin.ts index 62216a3d..eda9e4a7 100644 --- a/packages/runtime/src/client/plugin.ts +++ b/packages/runtime/src/client/plugin.ts @@ -3,6 +3,7 @@ import type { ClientContract } from '.'; import type { GetModels, SchemaDef } from '../schema'; import type { MaybePromise } from '../utils/type-utils'; import type { AllCrudOperation } from './crud/operations/base'; +import type { ZModelFunction } from './options'; /** * ZenStack runtime plugin. @@ -23,6 +24,11 @@ export interface RuntimePlugin { */ description?: string; + /** + * Custom function implementations. + */ + functions?: Record>; + /** * Intercepts an ORM query. */ diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index b574ed05..3fdd9858 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -194,7 +194,7 @@ export function buildFieldRef( if (!computer) { throw new QueryError(`Computed field "${field}" implementation not provided for model "${model}"`); } - return computer(eb, { currentModel: modelAlias }); + return computer(eb, { modelAlias }); } } diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index 04df8cca..d5b879b3 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -20,11 +20,10 @@ import { type OperationNode, } from 'kysely'; import { match } from 'ts-pattern'; -import type { CRUD } from '../../client/contract'; +import type { ClientContract, CRUD } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; -import type { ClientOptions } from '../../client/options'; import { getModel, getRelationForeignKeyFieldPairs, requireField, requireIdFields } from '../../client/query-utils'; import type { BinaryExpression, @@ -72,14 +71,22 @@ function expr(kind: Expression['kind']) { export class ExpressionTransformer { private readonly dialect: BaseCrudDialect; - constructor( - private readonly schema: Schema, - private readonly clientOptions: ClientOptions, - private readonly auth: unknown | undefined, - ) { + constructor(private readonly client: ClientContract) { this.dialect = getCrudDialect(this.schema, this.clientOptions); } + get schema() { + return this.client.$schema; + } + + get clientOptions() { + return this.client.$options; + } + + get auth() { + return this.client.$auth; + } + get authType() { if (!this.schema.authType) { throw new InternalError('Schema does not have an "authType" specified'); @@ -354,7 +361,7 @@ export class ExpressionTransformer { } private transformCall(expr: CallExpression, context: ExpressionTransformerContext) { - const func = this.clientOptions.functions?.[expr.function]; + const func = this.getFunctionImpl(expr.function); if (!func) { throw new QueryError(`Function not implemented: ${expr.function}`); } @@ -363,13 +370,30 @@ export class ExpressionTransformer { eb, (expr.args ?? []).map((arg) => this.transformCallArg(eb, arg, context)), { + client: this.client, dialect: this.dialect, model: context.model, + modelAlias: context.alias ?? context.model, operation: context.operation, }, ); } + private getFunctionImpl(functionName: string) { + // check built-in functions + let func = this.clientOptions.functions?.[functionName]; + if (!func) { + // check plugins + for (const plugin of this.clientOptions.plugins ?? []) { + if (plugin.functions?.[functionName]) { + func = plugin.functions[functionName]; + break; + } + } + } + return func; + } + private transformCallArg( eb: ExpressionBuilder, arg: Expression, diff --git a/packages/runtime/src/plugins/policy/functions.ts b/packages/runtime/src/plugins/policy/functions.ts new file mode 100644 index 00000000..c7fa09d7 --- /dev/null +++ b/packages/runtime/src/plugins/policy/functions.ts @@ -0,0 +1,62 @@ +import { invariant } from '@zenstackhq/common-helpers'; +import { ExpressionWrapper, ValueNode, type Expression, type ExpressionBuilder } from 'kysely'; +import { CRUD } from '../../client/contract'; +import { extractFieldName } from '../../client/kysely-utils'; +import type { ZModelFunction, ZModelFunctionContext } from '../../client/options'; +import { buildJoinPairs, requireField } from '../../client/query-utils'; +import { PolicyHandler } from './policy-handler'; + +/** + * Relation checker implementation. + */ +export const check: ZModelFunction = ( + eb: ExpressionBuilder, + args: Expression[], + { client, model, modelAlias, operation }: ZModelFunctionContext, +) => { + invariant(args.length === 1 || args.length === 2, '"check" function requires 1 or 2 arguments'); + + const arg1Node = args[0]!.toOperationNode(); + + const arg2Node = args.length === 2 ? args[1]!.toOperationNode() : undefined; + if (arg2Node) { + invariant( + ValueNode.is(arg2Node) && typeof arg2Node.value === 'string', + '"operation" parameter must be a string literal when provided', + ); + invariant( + CRUD.includes(arg2Node.value as CRUD), + '"operation" parameter must be one of "create", "read", "update", "delete"', + ); + } + + // first argument must be a field reference + const fieldName = extractFieldName(arg1Node); + invariant(fieldName, 'Failed to extract field name from the first argument of "check" function'); + const fieldDef = requireField(client.$schema, model, fieldName); + invariant(fieldDef.relation, `Field "${fieldName}" is not a relation field in model "${model}"`); + invariant(!fieldDef.array, `Field "${fieldName}" is a to-many relation, which is not supported by "check"`); + const relationModel = fieldDef.type; + + const op = arg2Node ? (arg2Node.value as CRUD) : operation; + + const policyHandler = new PolicyHandler(client); + + // join with parent model + const joinPairs = buildJoinPairs(client.$schema, model, modelAlias, fieldName, relationModel); + const joinCondition = + joinPairs.length === 1 + ? eb(eb.ref(joinPairs[0]![0]), '=', eb.ref(joinPairs[0]![1])) + : eb.and(joinPairs.map(([left, right]) => eb(eb.ref(left), '=', eb.ref(right)))); + + // policy condition of the related model + const policyCondition = policyHandler.buildPolicyFilter(relationModel, undefined, op); + + // build the final nested select that evaluates the policy condition + const result = eb + .selectFrom(relationModel) + .where(joinCondition) + .select(new ExpressionWrapper(policyCondition).as('$condition')); + + return result; +}; diff --git a/packages/runtime/src/plugins/policy/plugin.ts b/packages/runtime/src/plugins/policy/plugin.ts index e5b914d5..6af93353 100644 --- a/packages/runtime/src/plugins/policy/plugin.ts +++ b/packages/runtime/src/plugins/policy/plugin.ts @@ -1,5 +1,6 @@ import { type OnKyselyQueryArgs, type RuntimePlugin } from '../../client/plugin'; import type { SchemaDef } from '../../schema'; +import { check } from './functions'; import { PolicyHandler } from './policy-handler'; export class PolicyPlugin implements RuntimePlugin { @@ -15,6 +16,12 @@ export class PolicyPlugin implements RuntimePlugin) { const handler = new PolicyHandler(client); return handler.handle(query, proceed /*, transaction*/); diff --git a/packages/runtime/src/plugins/policy/plugin.zmodel b/packages/runtime/src/plugins/policy/plugin.zmodel index ecb39320..659705ce 100644 --- a/packages/runtime/src/plugins/policy/plugin.zmodel +++ b/packages/runtime/src/plugins/policy/plugin.zmodel @@ -31,3 +31,13 @@ attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "' * @param condition: a boolean expression that controls if the operation should be denied. */ attribute @deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean) + +/** + * Checks if the current user can perform the given operation on the given field. + * + * @param field: The field to check access for + * @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided, + * it defaults the operation of the containing policy rule. + */ +function check(field: Any, operation: String?): Boolean { +} @@@expressionContext([AccessPolicy]) diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index 029ec58c..7ab7d2f0 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -503,7 +503,7 @@ export class PolicyHandler extends OperationNodeTransf return InsertQueryNode.is(node) || UpdateQueryNode.is(node) || DeleteQueryNode.is(node); } - private buildPolicyFilter(model: GetModels, alias: string | undefined, operation: CRUD) { + buildPolicyFilter(model: GetModels, alias: string | undefined, operation: CRUD) { const policies = this.getModelPolicies(model, operation); if (policies.length === 0) { return falseNode(this.dialect); @@ -584,15 +584,12 @@ export class PolicyHandler extends OperationNodeTransf operation: CRUD, policy: Policy, ) { - return new ExpressionTransformer(this.client.$schema, this.client.$options, this.client.$auth).transform( - policy.condition, - { - model, - alias, - operation, - auth: this.client.$auth, - }, - ); + return new ExpressionTransformer(this.client).transform(policy.condition, { + model, + alias, + operation, + auth: this.client.$auth, + }); } private getModelPolicies(modelName: string, operation: PolicyOperation) { diff --git a/packages/runtime/test/client-api/computed-fields.test.ts b/packages/runtime/test/client-api/computed-fields.test.ts index 054997a3..5bf3c16a 100644 --- a/packages/runtime/test/client-api/computed-fields.test.ts +++ b/packages/runtime/test/client-api/computed-fields.test.ts @@ -237,10 +237,10 @@ model Post { dbName: TEST_DB, computedFields: { User: { - postCount: (eb: any, context: { currentModel: string }) => + postCount: (eb: any, context: { modelAlias: string }) => eb .selectFrom('Post') - .whereRef('Post.authorId', '=', sql.ref(`${context.currentModel}.id`)) + .whereRef('Post.authorId', '=', sql.ref(`${context.modelAlias}.id`)) .select(() => eb.fn.countAll().as('count')), }, }, diff --git a/packages/runtime/test/policy/migrated/relation-check.test.ts b/packages/runtime/test/policy/migrated/relation-check.test.ts new file mode 100644 index 00000000..be0947aa --- /dev/null +++ b/packages/runtime/test/policy/migrated/relation-check.test.ts @@ -0,0 +1,736 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Relation checker', () => { + it('should work for read', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', check(user, 'read')) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await expect(db.profile.findMany()).resolves.toHaveLength(1); + }); + + it('should work for simple create', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('create', check(user, 'read')) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + }, + }); + + await expect(db.profile.create({ data: { user: { connect: { id: 1 } }, age: 18 } })).toResolveTruthy(); + await expect(db.profile.create({ data: { user: { connect: { id: 2 } }, age: 18 } })).toBeRejectedByPolicy(); + }); + + it('should work for nested create', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('create', age < 30 && check(user, 'read')) + } + `, + ); + + await expect( + db.user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 18 }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + id: 3, + public: true, + profile: { + create: { age: 30 }, + }, + }, + }), + ).toBeRejectedByPolicy(); + }); + + it('should work for update', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('update', check(user, 'read') && age < 30) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { id: 1, age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { id: 2, age: 20 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + profile: { + create: { id: 3, age: 30 }, + }, + }, + }); + + await expect(db.profile.update({ where: { id: 1 }, data: { age: 21 } })).toResolveTruthy(); + await expect(db.profile.update({ where: { id: 2 }, data: { age: 21 } })).toBeRejectedNotFound(); + await expect(db.profile.update({ where: { id: 3 }, data: { age: 21 } })).toBeRejectedNotFound(); + }); + + it('should work for delete', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('delete', check(user, 'read') && age < 30) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { id: 1, age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { id: 2, age: 20 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + profile: { + create: { id: 3, age: 30 }, + }, + }, + }); + + await expect(db.profile.delete({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.profile.delete({ where: { id: 2 } })).toBeRejectedNotFound(); + await expect(db.profile.delete({ where: { id: 3 } })).toBeRejectedNotFound(); + }); + + // TODO: field-level policy support + it.skip('should work for field-level', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int @allow('read', age < 30 && check(user, 'read')) + @@allow('all', true) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + profile: { + create: { age: 30 }, + }, + }, + }); + + const p1 = await db.profile.findUnique({ where: { id: 1 } }); + expect(p1.age).toBe(18); + const p2 = await db.profile.findUnique({ where: { id: 2 } }); + expect(p2.age).toBeUndefined(); + const p3 = await db.profile.findUnique({ where: { id: 3 } }); + expect(p3.age).toBeUndefined(); + }); + + // TODO: field-level policy support + it.skip('should work for field-level with override', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int @allow('read', age < 30 && check(user, 'read'), true) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + profile: { + create: { age: 30 }, + }, + }, + }); + + const p1 = await db.profile.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(p1.age).toBe(18); + const p2 = await db.profile.findUnique({ where: { id: 2 }, select: { age: true } }); + expect(p2).toBeNull(); + const p3 = await db.profile.findUnique({ where: { id: 3 }, select: { age: true } }); + expect(p3).toBeNull(); + }); + + it('should work for cross-model field comparison', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + age Int + @@allow('read', true) + @@allow('update', age == profile.age) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('update', check(user, 'update') && age < 30) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + age: 18, + profile: { + create: { id: 1, age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + age: 18, + profile: { + create: { id: 2, age: 20 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 3, + age: 30, + profile: { + create: { id: 3, age: 30 }, + }, + }, + }); + + await expect(db.profile.update({ where: { id: 1 }, data: { age: 21 } })).toResolveTruthy(); + await expect(db.profile.update({ where: { id: 2 }, data: { age: 21 } })).toBeRejectedNotFound(); + await expect(db.profile.update({ where: { id: 3 }, data: { age: 21 } })).toBeRejectedNotFound(); + }); + + it('should work for implicit specific operations', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + @@allow('create', true) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', check(user)) + @@allow('create', check(user)) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await expect(db.profile.findMany()).resolves.toHaveLength(1); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 3 } }, age: 18 } })).toResolveTruthy(); + + await db.$unuseAll().user.create({ + data: { + id: 4, + public: false, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 4 } }, age: 18 } })).toBeRejectedByPolicy(); + }); + + it('should work for implicit all operations', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('all', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('all', check(user)) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await expect(db.profile.findMany()).resolves.toHaveLength(1); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 3 } }, age: 18 } })).toResolveTruthy(); + + await db.$unuseAll().user.create({ + data: { + id: 4, + public: false, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 4 } }, age: 18 } })).toBeRejectedByPolicy(); + }); + + it('should report error for invalid args', async () => { + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + public Boolean + @@allow('read', check(public)) + } + `, + ), + ).rejects.toThrow(/argument must be a relation field/); + + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + posts Post[] + @@allow('read', check(posts)) + } + model Post { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + } + `, + ), + ).rejects.toThrow(/argument cannot be an array field/); + + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + @@allow('read', check(profile.details)) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + details ProfileDetails? + } + + model ProfileDetails { + id Int @id @default(autoincrement()) + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int + age Int + } + `, + ), + ).rejects.toThrow(/argument must be a relation field/); + + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + posts Post[] + @@allow('read', check(posts, 'all')) + } + model Post { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + } + `, + ), + ).rejects.toThrow(/argument must be a "read", "create", "update", or "delete"/); + }); + + it('should report error for cyclic relation check', async () => { + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + profileDetails ProfileDetails? + public Boolean + @@allow('read', check(profile)) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + details ProfileDetails? + @@allow('read', check(details)) + } + + model ProfileDetails { + id Int @id @default(autoincrement()) + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', check(user)) + } + `, + ), + ).rejects.toThrow(/cyclic dependency/); + }); + + it('should report error for cyclic relation check indirect', async () => { + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', check(profile)) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + details ProfileDetails? + @@allow('read', check(details)) + } + + model ProfileDetails { + id Int @id @default(autoincrement()) + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int + @@allow('read', check(profile)) + } + `, + ), + ).rejects.toThrow(/cyclic dependency/); + }); + + it('should work for query builder', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', check(user)) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await expect(db.$qb.selectFrom('Profile as p').selectAll('p').execute()).resolves.toHaveLength(1); + }); +}); diff --git a/packages/runtime/test/schemas/typing/schema.ts b/packages/runtime/test/schemas/typing/schema.ts index 90a532e0..18270ceb 100644 --- a/packages/runtime/test/schemas/typing/schema.ts +++ b/packages/runtime/test/schemas/typing/schema.ts @@ -86,7 +86,7 @@ export const schema = { }, computedFields: { postCount(_context: { - currentModel: string; + modelAlias: string; }): OperandExpression { throw new Error("This is a stub for computed field"); } diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index 1d558300..5fc89046 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -376,7 +376,7 @@ export class TsSchemaGenerator { undefined, undefined, [ - // parameter: `context: { currentModel: string }` + // parameter: `context: { modelAlias: string }` ts.factory.createParameterDeclaration( undefined, undefined, @@ -385,7 +385,7 @@ export class TsSchemaGenerator { ts.factory.createTypeLiteralNode([ ts.factory.createPropertySignature( undefined, - 'currentModel', + 'modelAlias', undefined, ts.factory.createKeywordTypeNode(ts.SyntaxKind.StringKeyword), ), diff --git a/samples/blog/main.ts b/samples/blog/main.ts index dfaa6c04..8bbfb5bf 100644 --- a/samples/blog/main.ts +++ b/samples/blog/main.ts @@ -8,10 +8,10 @@ async function main() { dialect: new SqliteDialect({ database: new SQLite('./zenstack/dev.db') }), computedFields: { User: { - postCount: (eb, { currentModel }) => + postCount: (eb, { modelAlias }) => eb .selectFrom('Post') - .whereRef('Post.authorId', '=', sql.ref(`${currentModel}.id`)) + .whereRef('Post.authorId', '=', sql.ref(`${modelAlias}.id`)) .select(({ fn }) => fn.countAll().as('postCount')), }, }, diff --git a/samples/blog/zenstack/schema.ts b/samples/blog/zenstack/schema.ts index 95f2e4a8..4ca14e3e 100644 --- a/samples/blog/zenstack/schema.ts +++ b/samples/blog/zenstack/schema.ts @@ -76,7 +76,7 @@ export const schema = { }, computedFields: { postCount(_context: { - currentModel: string; + modelAlias: string; }): OperandExpression { throw new Error("This is a stub for computed field"); }