diff --git a/TODO.md b/TODO.md index c92d4fc6..cd7e8eb8 100644 --- a/TODO.md +++ b/TODO.md @@ -101,6 +101,8 @@ - [ ] Short-circuit pre-create check for scalar-field only policies - [x] Inject "on conflict do update" - [x] `check` function + - [ ] Custom functions + - [ ] Accessing tables not in the schema - [x] Migration - [ ] Databases - [x] SQLite diff --git a/package.json b/package.json index ba77ebd4..7fb1f4c9 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-v3", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "description": "ZenStack", "packageManager": "pnpm@10.12.1", "scripts": { diff --git a/packages/cli/package.json b/packages/cli/package.json index 204f87e9..1275a901 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack CLI", "description": "FullStack database toolkit with built-in access control and automatic API generation.", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "type": "module", "author": { "name": "ZenStack Team" diff --git a/packages/common-helpers/package.json b/packages/common-helpers/package.json index 0537b870..48dfd85c 100644 --- a/packages/common-helpers/package.json +++ b/packages/common-helpers/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/common-helpers", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "description": "ZenStack Common Helpers", "type": "module", "scripts": { diff --git a/packages/create-zenstack/package.json b/packages/create-zenstack/package.json index 6af9f8b4..2f252d55 100644 --- a/packages/create-zenstack/package.json +++ b/packages/create-zenstack/package.json @@ -1,6 +1,6 @@ { "name": "create-zenstack", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "description": "Create a new ZenStack project", "type": "module", "scripts": { diff --git a/packages/dialects/sql.js/package.json b/packages/dialects/sql.js/package.json index 362e9ae4..b9f0a77d 100644 --- a/packages/dialects/sql.js/package.json +++ b/packages/dialects/sql.js/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/kysely-sql-js", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "description": "Kysely dialect for sql.js", "type": "module", "scripts": { diff --git a/packages/eslint-config/package.json b/packages/eslint-config/package.json index 18dc914a..1401d50b 100644 --- a/packages/eslint-config/package.json +++ b/packages/eslint-config/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/eslint-config", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "type": "module", "private": true, "license": "MIT" diff --git a/packages/ide/vscode/package.json b/packages/ide/vscode/package.json index 00efbd79..ac3667af 100644 --- a/packages/ide/vscode/package.json +++ b/packages/ide/vscode/package.json @@ -1,7 +1,7 @@ { "name": "zenstack-v3", "publisher": "zenstack", - "version": "3.0.8", + "version": "3.0.9", "displayName": "ZenStack V3 Language Tools", "description": "VSCode extension for ZenStack (v3) ZModel language", "private": true, diff --git a/packages/language/package.json b/packages/language/package.json index b089edce..b2b454b7 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/language", "description": "ZenStack ZModel language specification", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "license": "MIT", "author": "ZenStack Team", "files": [ diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index e43c389c..52d34ae4 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -666,7 +666,7 @@ attribute @@@deprecated(_ message: String) * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. * @param condition: a boolean expression that controls if the operation should be allowed. */ -attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean) +attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) /** * Defines an access policy that allows the annotated field to be read or updated. @@ -684,7 +684,7 @@ attribute @allow(_ operation: String @@@completionHint(["'create'", "'read'", "' * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. * @param condition: a boolean expression that controls if the operation should be denied. */ -attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean) +attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) /** * Defines an access policy that denies the annotated field to be read or updated. @@ -705,8 +705,8 @@ function check(field: Any, operation: String?): Boolean { } @@@expressionContext([AccessPolicy]) /** - * Gets entities value before an update. Only valid when used in a "update" policy rule. + * Gets entity's value before an update. Only valid when used in a "post-update" policy rule. */ -function future(): Any { +function before(): Any { } @@@expressionContext([AccessPolicy]) diff --git a/packages/language/src/index.ts b/packages/language/src/index.ts index 6b2cb56a..4b578f31 100644 --- a/packages/language/src/index.ts +++ b/packages/language/src/index.ts @@ -70,7 +70,11 @@ export async function loadDocument( // build the document together with standard library, plugin modules, and imported documents await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document, ...importedDocuments], { - validation: true, + validation: { + stopAfterLexingErrors: true, + stopAfterParsingErrors: true, + stopAfterLinkingErrors: true, + }, }); const diagnostics = langiumDocuments.all diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index 2f9b4a75..19220f07 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -145,10 +145,6 @@ export function isRelationshipField(field: DataField) { return isDataModel(field.type.reference?.ref); } -export function isFutureExpr(node: AstNode) { - return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref); -} - export function isDelegateModel(node: AstNode) { return isDataModel(node) && hasAttribute(node, '@@delegate'); } @@ -450,8 +446,8 @@ export function getAuthDecl(decls: (DataModel | TypeDef)[]) { return authModel; } -export function isFutureInvocation(node: AstNode) { - return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref); +export function isBeforeInvocation(node: AstNode) { + return isInvocationExpr(node) && node.function.ref?.name === 'before' && isFromStdlib(node.function.ref); } export function isCollectionPredicate(node: AstNode): node is BinaryExpr { diff --git a/packages/language/src/validators/attribute-application-validator.ts b/packages/language/src/validators/attribute-application-validator.ts index dc376036..b5384196 100644 --- a/packages/language/src/validators/attribute-application-validator.ts +++ b/packages/language/src/validators/attribute-application-validator.ts @@ -23,10 +23,10 @@ import { getAllAttributes, getStringLiteral, isAuthOrAuthMemberAccess, + isBeforeInvocation, isCollectionPredicate, isDataFieldReference, isDelegateModel, - isFutureExpr, isRelationshipField, mapBuiltinTypeToExpressionType, resolved, @@ -166,13 +166,20 @@ export default class AttributeApplicationValidator implements AstValidator isFutureExpr(node))) { - accept('error', `"future()" is not allowed in field-level policy rules`, { node: expr }); + if (expr && AstUtils.streamAst(expr).some((node) => isBeforeInvocation(node))) { + accept('error', `"before()" is not allowed in field-level policy rules`, { node: expr }); } // 'update' rules are not allowed for relation fields diff --git a/packages/language/src/validators/expression-validator.ts b/packages/language/src/validators/expression-validator.ts index f8dc4930..f455753f 100644 --- a/packages/language/src/validators/expression-validator.ts +++ b/packages/language/src/validators/expression-validator.ts @@ -11,6 +11,7 @@ import { isNullExpr, isReferenceExpr, isThisExpr, + MemberAccessExpr, type ExpressionType, } from '../generated/ast'; @@ -18,6 +19,7 @@ import { findUpAst, isAuthInvocation, isAuthOrAuthMemberAccess, + isBeforeInvocation, isDataFieldReference, isEnumFieldReference, typeAssignable, @@ -59,12 +61,21 @@ export default class ExpressionValidator implements AstValidator { // extra validations by expression type switch (expr.$type) { + case 'MemberAccessExpr': + this.validateMemberAccessExpr(expr, accept); + break; case 'BinaryExpr': this.validateBinaryExpr(expr, accept); break; } } + private validateMemberAccessExpr(expr: MemberAccessExpr, accept: ValidationAcceptor) { + if (isBeforeInvocation(expr.operand) && isDataModel(expr.$resolvedType?.decl)) { + accept('error', 'relation fields cannot be accessed from `before()`', { node: expr }); + } + } + private validateBinaryExpr(expr: BinaryExpr, accept: ValidationAcceptor) { switch (expr.operator) { case 'in': { diff --git a/packages/language/src/zmodel-document-builder.ts b/packages/language/src/zmodel-document-builder.ts index 8e55e267..cddfc80b 100644 --- a/packages/language/src/zmodel-document-builder.ts +++ b/packages/language/src/zmodel-document-builder.ts @@ -1,22 +1,26 @@ -import { DefaultDocumentBuilder, type BuildOptions, type LangiumDocument } from 'langium'; +import { DefaultDocumentBuilder, type LangiumSharedCoreServices } from 'langium'; export class ZModelDocumentBuilder extends DefaultDocumentBuilder { - override buildDocuments(documents: LangiumDocument[], options: BuildOptions, cancelToken: any): Promise { - return super.buildDocuments( - documents, - { - ...options, - validation: - // force overriding validation options - options.validation === false || options.validation === undefined - ? options.validation - : { - stopAfterLexingErrors: true, - stopAfterParsingErrors: true, - stopAfterLinkingErrors: true, - }, - }, - cancelToken, - ); + constructor(services: LangiumSharedCoreServices) { + super(services); + + // override update build options to skip validation when there are + // errors in the previous stages + let validationOptions = this.updateBuildOptions.validation; + const stopFlags = { + stopAfterLinkingErrors: true, + stopAfterLexingErrors: true, + stopAfterParsingErrors: true, + }; + if (validationOptions === true) { + validationOptions = stopFlags; + } else if (typeof validationOptions === 'object') { + validationOptions = { ...validationOptions, ...stopFlags }; + } + + this.updateBuildOptions = { + ...this.updateBuildOptions, + validation: validationOptions, + }; } } diff --git a/packages/language/src/zmodel-linker.ts b/packages/language/src/zmodel-linker.ts index 65a2cb84..3bb45134 100644 --- a/packages/language/src/zmodel-linker.ts +++ b/packages/language/src/zmodel-linker.ts @@ -10,8 +10,8 @@ import { type LangiumDocument, type LinkingError, type Reference, + type ReferenceInfo, interruptAndCheck, - isReference, } from 'langium'; import { match } from 'ts-pattern'; import { @@ -57,7 +57,7 @@ import { getAuthDecl, getContainingDataModel, isAuthInvocation, - isFutureExpr, + isBeforeInvocation, isMemberContainer, mapBuiltinTypeToExpressionType, } from './utils'; @@ -94,18 +94,18 @@ export class ZModelLinker extends DefaultLinker { document.state = DocumentState.Linked; } - private linkReference( - container: AstNode, - property: string, - document: LangiumDocument, - extraScopes: ScopeProvider[], - ) { - if (this.resolveFromScopeProviders(container, property, document, extraScopes)) { + private linkReference(refInfo: ReferenceInfo, document: LangiumDocument, extraScopes: ScopeProvider[]) { + const defaultRef = refInfo.reference as DefaultReference; + if (defaultRef._ref) { + // already linked return; } - - const reference: DefaultReference = (container as any)[property]; - this.doLink({ reference, container, property }, document); + if (this.resolveFromScopeProviders(refInfo.reference, document, extraScopes)) { + // resolved from additional scope provider + return; + } + // default linking + this.doLink(refInfo, document); } //#endregion @@ -113,12 +113,10 @@ export class ZModelLinker extends DefaultLinker { //#region Expression type resolving private resolveFromScopeProviders( - node: AstNode, - property: string, + reference: DefaultReference, document: LangiumDocument, providers: ScopeProvider[], ) { - const reference: DefaultReference = (node as any)[property]; for (const provider of providers) { const target = provider(reference.$refText); if (target) { @@ -276,7 +274,7 @@ export class ZModelLinker extends DefaultLinker { } private resolveInvocation(node: InvocationExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { - this.linkReference(node, 'function', document, extraScopes); + this.linkReference({ reference: node.function, container: node, property: 'function' }, document, extraScopes); node.args.forEach((arg) => this.resolve(arg, document, extraScopes)); if (node.function.ref) { const funcDecl = node.function.ref as FunctionDecl; @@ -294,8 +292,8 @@ export class ZModelLinker extends DefaultLinker { if (authDecl) { node.$resolvedType = { decl: authDecl, nullable: true }; } - } else if (isFutureExpr(node)) { - // future() function is resolved to current model + } else if (isBeforeInvocation(node)) { + // before() function is resolved to current model node.$resolvedType = { decl: getContainingDataModel(node) }; } else { this.resolveToDeclaredType(node, funcDecl.returnType); @@ -401,7 +399,7 @@ export class ZModelLinker extends DefaultLinker { if (isArrayExpr(node.value)) { node.value.items.forEach((item) => { if (isReferenceExpr(item)) { - const resolved = this.resolveFromScopeProviders(item, 'target', document, [scopeProvider]); + const resolved = this.resolveFromScopeProviders(item.target, document, [scopeProvider]); if (resolved) { this.resolveToDeclaredType(item, (resolved as DataField).type); } else { @@ -414,7 +412,7 @@ export class ZModelLinker extends DefaultLinker { this.resolveToBuiltinTypeOrDecl(node.value, node.value.items[0].$resolvedType.decl, true); } } else if (isReferenceExpr(node.value)) { - const resolved = this.resolveFromScopeProviders(node.value, 'target', document, [scopeProvider]); + const resolved = this.resolveFromScopeProviders(node.value.target, document, [scopeProvider]); if (resolved) { this.resolveToDeclaredType(node.value, (resolved as DataField).type); } else { @@ -495,13 +493,9 @@ export class ZModelLinker extends DefaultLinker { } private resolveDefault(node: AstNode, document: LangiumDocument, extraScopes: ScopeProvider[]) { - for (const [property, value] of Object.entries(node)) { - if (!property.startsWith('$')) { - if (isReference(value)) { - this.linkReference(node, property, document, extraScopes); - } - } - } + AstUtils.streamReferences(node).forEach((ref) => { + this.linkReference(ref, document, extraScopes); + }); for (const child of AstUtils.streamContents(node)) { this.resolve(child, document, extraScopes); } diff --git a/packages/language/src/zmodel-scope.ts b/packages/language/src/zmodel-scope.ts index e2b58f02..2fd8b37a 100644 --- a/packages/language/src/zmodel-scope.ts +++ b/packages/language/src/zmodel-scope.ts @@ -37,7 +37,7 @@ import { getRecursiveBases, isAuthInvocation, isCollectionPredicate, - isFutureInvocation, + isBeforeInvocation, resolveImportUri, } from './utils'; @@ -170,8 +170,8 @@ export class ZModelScopeProvider extends DefaultScopeProvider { return this.createScopeForAuth(node, globalScope); } - if (isFutureInvocation(operand)) { - // resolve `future()` to the containing model + if (isBeforeInvocation(operand)) { + // resolve `before()` to the containing model return this.createScopeForContainingModel(node, globalScope); } return EMPTY_SCOPE; diff --git a/packages/language/test/attribute-application.test.ts b/packages/language/test/attribute-application.test.ts new file mode 100644 index 00000000..71d39323 --- /dev/null +++ b/packages/language/test/attribute-application.test.ts @@ -0,0 +1,23 @@ +import { describe, it } from 'vitest'; +import { loadSchemaWithError } from './utils'; + +describe('Attribute application validation tests', () => { + it('rejects before in non-post-update policies', async () => { + await loadSchemaWithError( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + @@allow('all', true) + @@deny('update', before(x) > 2) + } + `, + `"before()" is only allowed in "post-update" policy rules`, + ); + }); +}); diff --git a/packages/runtime/package.json b/packages/runtime/package.json index ff7712da..9e563efa 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/runtime", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "description": "ZenStack Runtime", "type": "module", "scripts": { diff --git a/packages/runtime/src/client/contract.ts b/packages/runtime/src/client/contract.ts index 0d38e34e..002f478c 100644 --- a/packages/runtime/src/client/contract.ts +++ b/packages/runtime/src/client/contract.ts @@ -213,11 +213,21 @@ export interface ClientConstructor { */ export type CRUD = 'create' | 'read' | 'update' | 'delete'; +/** + * Extended CRUD operations including 'post-update'. + */ +export type CRUD_EXT = CRUD | 'post-update'; + /** * CRUD operations. */ export const CRUD = ['create', 'read', 'update', 'delete'] as const; +/** + * Extended CRUD operations including 'post-update'. + */ +export const CRUD_EXT = [...CRUD, 'post-update'] as const; + //#region Model operations export type AllModelOperations> = { diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 170b8d89..7a4c9a66 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -1296,8 +1296,9 @@ export abstract class BaseOperationHandler { return { count: Number(result.numAffectedRows) } as Result; } else { const idFields = requireIdFields(this.schema, model); - const result = await query.returning(idFields as any).execute(); - return result as Result; + const finalQuery = query.returning(idFields as any); + const result = await this.executeQuery(kysely, finalQuery, 'update'); + return result.rows as Result; } } diff --git a/packages/runtime/src/client/errors.ts b/packages/runtime/src/client/errors.ts index 1d6134e9..15961811 100644 --- a/packages/runtime/src/client/errors.ts +++ b/packages/runtime/src/client/errors.ts @@ -1,7 +1,12 @@ +/** + * Base for all ZenStack runtime errors. + */ +export class ZenStackError extends Error {} + /** * Error thrown when input validation fails. */ -export class InputValidationError extends Error { +export class InputValidationError extends ZenStackError { constructor(message: string, cause?: unknown) { super(message, { cause }); } @@ -10,7 +15,7 @@ export class InputValidationError extends Error { /** * Error thrown when a query fails. */ -export class QueryError extends Error { +export class QueryError extends ZenStackError { constructor(message: string, cause?: unknown) { super(message, { cause }); } @@ -19,12 +24,12 @@ export class QueryError extends Error { /** * Error thrown when an internal error occurs. */ -export class InternalError extends Error {} +export class InternalError extends ZenStackError {} /** * Error thrown when an entity is not found. */ -export class NotFoundError extends Error { +export class NotFoundError extends ZenStackError { constructor(model: string, details?: string) { super(`Entity not found for model "${model}"${details ? `: ${details}` : ''}`); } diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index 2d2395cb..c307bc4e 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -25,7 +25,7 @@ import { match } from 'ts-pattern'; import type { GetModels, SchemaDef } from '../../schema'; import { type ClientImpl } from '../client-impl'; import { TransactionIsolationLevel, type ClientContract } from '../contract'; -import { InternalError, QueryError } from '../errors'; +import { InternalError, QueryError, ZenStackError } from '../errors'; import { stripAlias } from '../kysely-utils'; import type { AfterEntityMutationCallback, OnKyselyQueryCallback } from '../plugin'; import { QueryNameMapper } from './name-mapper'; @@ -65,21 +65,53 @@ export class ZenStackQueryExecutor extends DefaultQuer return this.client.$options; } - override async executeQuery(compiledQuery: CompiledQuery, queryId: QueryId) { + override executeQuery(compiledQuery: CompiledQuery, queryId: QueryId) { // proceed with the query with kysely interceptors // if the query is a raw query, we need to carry over the parameters const queryParams = (compiledQuery as any).$raw ? compiledQuery.parameters : undefined; - const result = await this.proceedQueryWithKyselyInterceptors(compiledQuery.query, queryParams, queryId.queryId); - return result.result; + return this.provideConnection(async (connection) => { + let startedTx = false; + try { + // mutations are wrapped in tx if not already in one + if (this.isMutationNode(compiledQuery.query) && !this.driver.isTransactionConnection(connection)) { + await this.driver.beginTransaction(connection, { + isolationLevel: TransactionIsolationLevel.RepeatableRead, + }); + startedTx = true; + } + const result = await this.proceedQueryWithKyselyInterceptors( + connection, + compiledQuery.query, + queryParams, + queryId.queryId, + ); + if (startedTx) { + await this.driver.commitTransaction(connection); + } + return result; + } catch (err) { + if (startedTx) { + await this.driver.rollbackTransaction(connection); + } + if (err instanceof ZenStackError) { + throw err; + } else { + // wrap error + const message = `Failed to execute query: ${err}, sql: ${compiledQuery?.sql}`; + throw new QueryError(message, err); + } + } + }); } private async proceedQueryWithKyselyInterceptors( + connection: DatabaseConnection, queryNode: RootOperationNode, parameters: readonly unknown[] | undefined, queryId: string, ) { - let proceed = (q: RootOperationNode) => this.proceedQuery(q, parameters, queryId); + let proceed = (q: RootOperationNode) => this.proceedQuery(connection, q, parameters, queryId); const hooks: OnKyselyQueryCallback[] = []; // tsc perf @@ -92,18 +124,14 @@ export class ZenStackQueryExecutor extends DefaultQuer for (const hook of hooks) { const _proceed = proceed; proceed = async (query: RootOperationNode) => { - const _p = async (q: RootOperationNode) => { - const r = await _proceed(q); - return r.result; - }; - + const _p = (q: RootOperationNode) => _proceed(q); const hookResult = await hook!({ client: this.client as ClientContract, schema: this.client.$schema, query, proceed: _p, }); - return { result: hookResult }; + return hookResult; }; } @@ -132,161 +160,83 @@ export class ZenStackQueryExecutor extends DefaultQuer return { model, action, where }; } - private async proceedQuery(query: RootOperationNode, parameters: readonly unknown[] | undefined, queryId: string) { + private async proceedQuery( + connection: DatabaseConnection, + query: RootOperationNode, + parameters: readonly unknown[] | undefined, + queryId: string, + ) { let compiled: CompiledQuery | undefined; - try { - return await this.provideConnection(async (connection) => { - if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) { - // no need to handle mutation hooks, just proceed - const finalQuery = this.nameMapper.transformNode(query); - compiled = this.compileQuery(finalQuery); - if (parameters) { - compiled = { ...compiled, parameters }; - } - const result = await connection.executeQuery(compiled); - return { result }; - } + if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) { + // no need to handle mutation hooks, just proceed + const finalQuery = this.nameMapper.transformNode(query); + compiled = this.compileQuery(finalQuery); + if (parameters) { + compiled = { ...compiled, parameters }; + } + return connection.executeQuery(compiled); + } - if ( - (InsertQueryNode.is(query) || UpdateQueryNode.is(query)) && - this.hasEntityMutationPluginsWithAfterMutationHooks - ) { - // need to make sure the query node has "returnAll" for insert and update queries - // so that after-mutation hooks can get the mutated entities with all fields - query = { - ...query, - returning: ReturningNode.create([SelectionNode.createSelectAll()]), - }; - } - const finalQuery = this.nameMapper.transformNode(query); - compiled = this.compileQuery(finalQuery); - if (parameters) { - compiled = { ...compiled, parameters }; - } + if ( + (InsertQueryNode.is(query) || UpdateQueryNode.is(query)) && + this.hasEntityMutationPluginsWithAfterMutationHooks + ) { + // need to make sure the query node has "returnAll" for insert and update queries + // so that after-mutation hooks can get the mutated entities with all fields + query = { + ...query, + returning: ReturningNode.create([SelectionNode.createSelectAll()]), + }; + } + const finalQuery = this.nameMapper.transformNode(query); + compiled = this.compileQuery(finalQuery); + if (parameters) { + compiled = { ...compiled, parameters }; + } - // the client passed to hooks needs to be in sync with current in-transaction - // status so that it doesn't try to create a nested one - const currentlyInTx = this.driver.isTransactionConnection(connection); - - const connectionClient = this.createClientForConnection(connection, currentlyInTx); - - const mutationInfo = this.getMutationInfo(finalQuery); - - // cache already loaded before-mutation entities - let beforeMutationEntities: Record[] | undefined; - const loadBeforeMutationEntities = async () => { - if ( - beforeMutationEntities === undefined && - (UpdateQueryNode.is(query) || DeleteQueryNode.is(query)) - ) { - beforeMutationEntities = await this.loadEntities( - mutationInfo.model, - mutationInfo.where, - connection, - ); - } - return beforeMutationEntities; - }; - - // call before mutation hooks - await this.callBeforeMutationHooks( - finalQuery, - mutationInfo, - loadBeforeMutationEntities, - connectionClient, - queryId, - ); + // the client passed to hooks needs to be in sync with current in-transaction + // status so that it doesn't try to create a nested one + const currentlyInTx = this.driver.isTransactionConnection(connection); - // if mutation interceptor demands to run afterMutation hook in the transaction but we're not already - // inside one, we need to create one on the fly - const shouldCreateTx = - this.hasPluginRequestingAfterMutationWithinTransaction && - !this.driver.isTransactionConnection(connection); - - if (!shouldCreateTx) { - // if no on-the-fly tx is needed, just proceed with the query as is - const result = await connection.executeQuery(compiled); - - if (!this.driver.isTransactionConnection(connection)) { - // not in a transaction, just call all after-mutation hooks - await this.callAfterMutationHooks( - result, - finalQuery, - mutationInfo, - connectionClient, - 'all', - queryId, - ); - } else { - // run after-mutation hooks that are requested to be run inside tx - await this.callAfterMutationHooks( - result, - finalQuery, - mutationInfo, - connectionClient, - 'inTx', - queryId, - ); - - // register other after-mutation hooks to be run after the tx is committed - this.driver.registerTransactionCommitCallback(connection, () => - this.callAfterMutationHooks( - result, - finalQuery, - mutationInfo, - connectionClient, - 'outTx', - queryId, - ), - ); - } - - return { result }; - } else { - // if an on-the-fly tx is created, create one and wrap the query execution inside - await this.driver.beginTransaction(connection, { - isolationLevel: TransactionIsolationLevel.ReadCommitted, - }); - try { - // execute the query inside the on-the-fly transaction - const result = await connection.executeQuery(compiled); - - // run after-mutation hooks that are requested to be run inside tx - await this.callAfterMutationHooks( - result, - finalQuery, - mutationInfo, - connectionClient, - 'inTx', - queryId, - ); - - // commit the transaction - await this.driver.commitTransaction(connection); - - // run other after-mutation hooks after the tx is committed - await this.callAfterMutationHooks( - result, - finalQuery, - mutationInfo, - connectionClient, - 'outTx', - queryId, - ); - - return { result }; - } catch (err) { - // rollback the transaction - await this.driver.rollbackTransaction(connection); - throw err; - } - } - }); - } catch (err) { - const message = `Failed to execute query: ${err}, sql: ${compiled?.sql}`; - throw new QueryError(message, err); + const connectionClient = this.createClientForConnection(connection, currentlyInTx); + + const mutationInfo = this.getMutationInfo(finalQuery); + + // cache already loaded before-mutation entities + let beforeMutationEntities: Record[] | undefined; + const loadBeforeMutationEntities = async () => { + if (beforeMutationEntities === undefined && (UpdateQueryNode.is(query) || DeleteQueryNode.is(query))) { + beforeMutationEntities = await this.loadEntities(mutationInfo.model, mutationInfo.where, connection); + } + return beforeMutationEntities; + }; + + // call before mutation hooks + await this.callBeforeMutationHooks( + finalQuery, + mutationInfo, + loadBeforeMutationEntities, + connectionClient, + queryId, + ); + + const result = await connection.executeQuery(compiled); + + if (!this.driver.isTransactionConnection(connection)) { + // not in a transaction, just call all after-mutation hooks + await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'all', queryId); + } else { + // run after-mutation hooks that are requested to be run inside tx + await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'inTx', queryId); + + // register other after-mutation hooks to be run after the tx is committed + this.driver.registerTransactionCommitCallback(connection, () => + this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'outTx', queryId), + ); } + + return result; } private createClientForConnection(connection: DatabaseConnection, inTx: boolean) { @@ -307,12 +257,6 @@ export class ZenStackQueryExecutor extends DefaultQuer return (this.client.$options.plugins ?? []).some((plugin) => plugin.onEntityMutation?.afterEntityMutation); } - private get hasPluginRequestingAfterMutationWithinTransaction() { - return (this.client.$options.plugins ?? []).some( - (plugin) => plugin.onEntityMutation?.runAfterMutationWithinTransaction, - ); - } - private isMutationNode(queryNode: RootOperationNode): queryNode is MutationQueryNode { return InsertQueryNode.is(queryNode) || UpdateQueryNode.is(queryNode) || DeleteQueryNode.is(queryNode); } diff --git a/packages/runtime/src/client/options.ts b/packages/runtime/src/client/options.ts index 7d1134a3..f09f44d6 100644 --- a/packages/runtime/src/client/options.ts +++ b/packages/runtime/src/client/options.ts @@ -1,7 +1,7 @@ import type { Dialect, Expression, ExpressionBuilder, KyselyConfig } from 'kysely'; import type { GetModel, GetModels, ProcedureDef, SchemaDef } from '../schema'; import type { PrependParameter } from '../utils/type-utils'; -import type { ClientContract, CRUD, ProcedureFunc } from './contract'; +import type { ClientContract, CRUD_EXT, ProcedureFunc } from './contract'; import type { BaseCrudDialect } from './crud/dialects/base-dialect'; import type { RuntimePlugin } from './plugin'; import type { ToKyselySchema } from './query-builder'; @@ -30,7 +30,7 @@ export type ZModelFunctionContext = { /** * The CRUD operation being performed */ - operation: CRUD; + operation: CRUD_EXT; }; export type ZModelFunction = ( diff --git a/packages/runtime/src/plugins/policy/errors.ts b/packages/runtime/src/plugins/policy/errors.ts index 675506d6..42d57b18 100644 --- a/packages/runtime/src/plugins/policy/errors.ts +++ b/packages/runtime/src/plugins/policy/errors.ts @@ -1,3 +1,5 @@ +import { ZenStackError } from '../../client'; + /** * Reason code for policy rejection. */ @@ -21,7 +23,7 @@ export enum RejectedByPolicyReason { /** * Error thrown when an operation is rejected by access policy. */ -export class RejectedByPolicyError extends Error { +export class RejectedByPolicyError extends ZenStackError { constructor( public readonly model: string | undefined, public readonly reason: RejectedByPolicyReason = RejectedByPolicyReason.NO_ACCESS, diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index 414b72b4..1eca04fa 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -20,7 +20,7 @@ import { type OperationNode, } from 'kysely'; import { match } from 'ts-pattern'; -import type { ClientContract, CRUD } from '../../client/contract'; +import type { ClientContract, CRUD_EXT } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; @@ -50,13 +50,12 @@ import { type SchemaDef, } from '../../schema'; import { ExpressionEvaluator } from './expression-evaluator'; -import { conjunction, disjunction, falseNode, logicalNot, trueNode } from './utils'; +import { conjunction, disjunction, falseNode, isBeforeInvocation, logicalNot, trueNode } from './utils'; export type ExpressionTransformerContext = { model: GetModels; alias?: string; - operation: CRUD; - auth?: any; + operation: CRUD_EXT; memberFilter?: OperationNode; memberSelect?: SelectionNode; }; @@ -439,7 +438,7 @@ export class ExpressionTransformer { } if (this.isAuthMember(arg)) { - const valNode = this.valueMemberAccess(context.auth, arg as MemberExpression, this.authType); + const valNode = this.valueMemberAccess(this.auth, arg as MemberExpression, this.authType); return valNode ? eb.val(valNode.value) : eb.val(null); } @@ -453,11 +452,20 @@ export class ExpressionTransformer { @expr('member') // @ts-ignore private _member(expr: MemberExpression, context: ExpressionTransformerContext) { - // auth() member access + // `auth()` member access if (this.isAuthCall(expr.receiver)) { return this.valueMemberAccess(this.auth, expr, this.authType); } + // `before()` member access + if (isBeforeInvocation(expr.receiver)) { + // policy handler creates a join table named `$before` using entity value before update, + // we can directly reference the column from there + invariant(context.operation === 'post-update', 'before() can only be used in post-update policy'); + invariant(expr.members.length === 1, 'before() can only be followed by a scalar field access'); + return ReferenceNode.create(ColumnNode.create(expr.members[0]!), TableNode.create('$before')); + } + invariant( ExpressionUtils.isField(expr.receiver) || ExpressionUtils.isThis(expr.receiver), 'expect receiver to be field expression or "this"', @@ -515,7 +523,6 @@ export class ExpressionTransformer { }); if (currNode) { - invariant(SelectQueryNode.is(currNode), 'expected select query node'); currNode = { ...relation, selections: [ diff --git a/packages/runtime/src/plugins/policy/plugin.ts b/packages/runtime/src/plugins/policy/plugin.ts index 6af93353..7ebd2882 100644 --- a/packages/runtime/src/plugins/policy/plugin.ts +++ b/packages/runtime/src/plugins/policy/plugin.ts @@ -22,8 +22,8 @@ export class PolicyPlugin implements RuntimePlugin) { + onKyselyQuery({ query, client, proceed }: OnKyselyQueryArgs) { const handler = new PolicyHandler(client); - return handler.handle(query, proceed /*, transaction*/); + return handler.handle(query, proceed); } } diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index 50cfc835..49e5afd1 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -16,7 +16,9 @@ import { ParensNode, PrimitiveValueListNode, RawNode, + ReferenceNode, ReturningNode, + SelectAllNode, SelectionNode, SelectQueryNode, sql, @@ -32,18 +34,26 @@ import { } from 'kysely'; import { match } from 'ts-pattern'; import type { ClientContract } from '../../client'; -import type { CRUD } from '../../client/contract'; +import { type CRUD_EXT } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; import type { ProceedKyselyQueryFunction } from '../../client/plugin'; import { getManyToManyRelation, requireField, requireIdFields, requireModel } from '../../client/query-utils'; -import { ExpressionUtils, type BuiltinType, type Expression, type GetModels, type SchemaDef } from '../../schema'; +import { + ExpressionUtils, + type BuiltinType, + type Expression, + type GetModels, + type MemberExpression, + type SchemaDef, +} from '../../schema'; +import { ExpressionVisitor } from '../../utils/expression-utils'; import { ColumnCollector } from './column-collector'; import { RejectedByPolicyError, RejectedByPolicyReason } from './errors'; import { ExpressionTransformer } from './expression-transformer'; import type { Policy, PolicyOperation } from './types'; -import { buildIsFalse, conjunction, disjunction, falseNode, getTableName } from './utils'; +import { buildIsFalse, conjunction, disjunction, falseNode, getTableName, isBeforeInvocation, trueNode } from './utils'; export type CrudQueryNode = SelectQueryNode | InsertQueryNode | UpdateQueryNode | DeleteQueryNode; @@ -61,10 +71,7 @@ export class PolicyHandler extends OperationNodeTransf return this.client.$qb; } - async handle( - node: RootOperationNode, - proceed: ProceedKyselyQueryFunction /*, transaction: OnKyselyQueryTransaction*/, - ) { + async handle(node: RootOperationNode, proceed: ProceedKyselyQueryFunction) { if (!this.isCrudQueryNode(node)) { // non-CRUD queries are not allowed throw new RejectedByPolicyError( @@ -81,6 +88,8 @@ export class PolicyHandler extends OperationNodeTransf const { mutationModel } = this.getMutationModel(node); + // --- Pre mutation work --- + if (InsertQueryNode.is(node)) { // pre-create policy evaluation happens before execution of the query const isManyToManyJoinTable = this.isManyToManyJoinTable(mutationModel); @@ -102,12 +111,86 @@ export class PolicyHandler extends OperationNodeTransf } } + const hasPostUpdatePolicies = UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel); + + let beforeUpdateInfo: Awaited> | undefined; + if (hasPostUpdatePolicies) { + beforeUpdateInfo = await this.loadBeforeUpdateEntities(mutationModel, node.where, proceed); + } + // proceed with query const result = await proceed(this.transformNode(node)); + // --- Post mutation work --- + + if (hasPostUpdatePolicies && result.rows.length > 0) { + // entities updated filter + const idConditions = this.buildIdConditions(mutationModel, result.rows); + + // post-update policy filter + const postUpdateFilter = this.buildPolicyFilter(mutationModel, undefined, 'post-update'); + + // read the post-update row with filter applied + + const eb = expressionBuilder(); + + // create a `SELECT column1 as field1, column2 as field2, ... FROM (VALUES (...))` table for before-update rows + const beforeUpdateTable: SelectQueryNode | undefined = beforeUpdateInfo + ? { + kind: 'SelectQueryNode', + from: FromNode.create([ + ParensNode.create( + ValuesNode.create( + beforeUpdateInfo!.rows.map((r) => + PrimitiveValueListNode.create(beforeUpdateInfo!.fields.map((f) => r[f])), + ), + ), + ), + ]), + selections: beforeUpdateInfo.fields.map((name, index) => { + const def = requireField(this.client.$schema, mutationModel, name); + const castedColumnRef = + sql`CAST(${eb.ref(`column${index + 1}`)} as ${sql.raw(this.dialect.getFieldSqlType(def))})`.as( + name, + ); + return SelectionNode.create(castedColumnRef.toOperationNode()); + }), + } + : undefined; + + const postUpdateQuery = eb + .selectFrom(mutationModel) + .select(() => [eb(eb.fn('COUNT', [eb.lit(1)]), '=', result.rows.length).as('$condition')]) + .where(() => new ExpressionWrapper(conjunction(this.dialect, [idConditions, postUpdateFilter]))) + .$if(!!beforeUpdateInfo, (qb) => + qb.leftJoin( + () => new ExpressionWrapper(beforeUpdateTable!).as('$before'), + (join) => { + const idFields = requireIdFields(this.client.$schema, mutationModel); + return idFields.reduce( + (acc, f) => acc.onRef(`${mutationModel}.${f}`, '=', `$before.${f}`), + join, + ); + }, + ), + ); + + const postUpdateResult = await proceed(postUpdateQuery.toOperationNode()); + if (!postUpdateResult.rows[0]?.$condition) { + throw new RejectedByPolicyError( + mutationModel, + RejectedByPolicyReason.NO_ACCESS, + 'some or all updated rows failed to pass post-update policy check', + ); + } + } + + // --- Read back --- + if (!node.returning || this.onlyReturningId(node)) { - return result; + // no need to check read back + return this.postProcessMutationResult(result, node); } else { const readBackResult = await this.processReadBack(node, result, proceed); if (readBackResult.rows.length !== result.rows.length) { @@ -121,6 +204,75 @@ export class PolicyHandler extends OperationNodeTransf } } + // correction to kysely mutation result may be needed because we might have added + // returning clause to the query and caused changes to the result shape + private postProcessMutationResult(result: QueryResult, node: MutationQueryNode) { + if (node.returning) { + return result; + } else { + return { + ...result, + rows: [], + numAffectedRows: result.numAffectedRows ?? BigInt(result.rows.length), + }; + } + } + + hasPostUpdatePolicies(model: GetModels) { + const policies = this.getModelPolicies(model, 'post-update'); + return policies.length > 0; + } + + private async loadBeforeUpdateEntities( + model: GetModels, + where: WhereNode | undefined, + proceed: ProceedKyselyQueryFunction, + ) { + const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model); + if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) { + return undefined; + } + const query: SelectQueryNode = { + kind: 'SelectQueryNode', + from: FromNode.create([TableNode.create(model)]), + where, + selections: [...beforeUpdateAccessFields.map((f) => SelectionNode.create(ColumnNode.create(f)))], + }; + const result = await proceed(query); + return { fields: beforeUpdateAccessFields, rows: result.rows }; + } + + private getFieldsAccessForBeforeUpdatePolicies(model: GetModels) { + const policies = this.getModelPolicies(model, 'post-update'); + if (policies.length === 0) { + return undefined; + } + + const fields = new Set(); + const fieldCollector = new (class extends ExpressionVisitor { + protected override visitMember(e: MemberExpression): void { + if (isBeforeInvocation(e.receiver)) { + invariant(e.members.length === 1, 'before() can only be followed by a scalar field access'); + fields.add(e.members[0]!); + } + super.visitMember(e); + } + })(); + + for (const policy of policies) { + fieldCollector.visit(policy.condition); + } + + if (fields.size === 0) { + return undefined; + } + + // make sure id fields are included + requireIdFields(this.client.$schema, model).forEach((f) => fields.add(f)); + + return Array.from(fields).sort(); + } + // #region overrides protected override transformSelectQuery(node: SelectQueryNode) { @@ -193,23 +345,19 @@ export class PolicyHandler extends OperationNodeTransf const result = super.transformInsertQuery(processedNode); - if (!node.returning) { - return result; - } - - if (this.onlyReturningId(node)) { - return result; - } else { - // only return ID fields, that's enough for reading back the inserted row + // if any field is to be returned, we select ID fields here which will be used + // for reading back post-insert + let returning = result.returning; + if (returning) { const { mutationModel } = this.getMutationModel(node); const idFields = requireIdFields(this.client.$schema, mutationModel); - return { - ...result, - returning: ReturningNode.create( - idFields.map((field) => SelectionNode.create(ColumnNode.create(field))), - ), - }; + returning = ReturningNode.create(idFields.map((f) => SelectionNode.create(ColumnNode.create(f)))); } + + return { + ...result, + returning, + }; } protected override transformUpdateQuery(node: UpdateQueryNode) { @@ -225,9 +373,23 @@ export class PolicyHandler extends OperationNodeTransf } } + let returning = result.returning; + + // regarding returning: + // 1. if fields are to be returned, we only select id fields here which will be used for reading back + // post-update + // 2. if there are post-update policies, we need to make sure id fields are selected for joining with + // before-update rows + + if (returning || this.hasPostUpdatePolicies(mutationModel)) { + const idFields = requireIdFields(this.client.$schema, mutationModel); + returning = ReturningNode.create(idFields.map((f) => SelectionNode.create(ColumnNode.create(f)))); + } + return { ...result, where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), + returning, }; } @@ -260,6 +422,19 @@ export class PolicyHandler extends OperationNodeTransf } const { mutationModel } = this.getMutationModel(node); const idFields = requireIdFields(this.client.$schema, mutationModel); + + if (node.returning.selections.some((s) => SelectAllNode.is(s.selection))) { + const modelDef = requireModel(this.client.$schema, mutationModel); + if (Object.keys(modelDef.fields).some((f) => !idFields.includes(f))) { + // there are fields other than ID fields + return false; + } else { + // select all but model only has ID fields + return true; + } + } + + // analyze selected columns const collector = new ColumnCollector(); const selectedColumns = collector.collect(node.returning); return selectedColumns.every((c) => idFields.includes(c)); @@ -543,7 +718,7 @@ export class PolicyHandler extends OperationNodeTransf this.dialect, idFields.map((field) => BinaryOperationNode.create( - ColumnNode.create(field), + ReferenceNode.create(ColumnNode.create(field), TableNode.create(table)), OperatorNode.create('='), ValueNode.create(row[field]), ), @@ -590,7 +765,7 @@ export class PolicyHandler extends OperationNodeTransf return InsertQueryNode.is(node) || UpdateQueryNode.is(node) || DeleteQueryNode.is(node); } - buildPolicyFilter(model: GetModels, alias: string | undefined, operation: CRUD) { + buildPolicyFilter(model: GetModels, alias: string | undefined, operation: CRUD_EXT): OperationNode { // first check if it's a many-to-many join table, and if so, handle specially const m2mFilter = this.getModelPolicyFilterForManyToManyJoinTable(model, alias, operation); if (m2mFilter) { @@ -598,9 +773,6 @@ export class PolicyHandler extends OperationNodeTransf } const policies = this.getModelPolicies(model, operation); - if (policies.length === 0) { - return falseNode(this.dialect); - } const allows = policies .filter((policy) => policy.kind === 'allow') @@ -610,25 +782,33 @@ export class PolicyHandler extends OperationNodeTransf .filter((policy) => policy.kind === 'deny') .map((policy) => this.compilePolicyCondition(model, alias, operation, policy)); + // 'post-update' is by default allowed, other operations are by default denied let combinedPolicy: OperationNode; if (allows.length === 0) { - // constant false - combinedPolicy = falseNode(this.dialect); + // no allow rules + if (operation === 'post-update') { + // post-update is allowed if no allow rules are defined + combinedPolicy = trueNode(this.dialect); + } else { + // other operations are denied by default + combinedPolicy = falseNode(this.dialect); + } } else { // or(...allows) combinedPolicy = disjunction(this.dialect, allows); + } - // and(...!denies) - if (denies.length !== 0) { - const combinedDenies = conjunction( - this.dialect, - denies.map((d) => buildIsFalse(d, this.dialect)), - ); - // or(...allows) && and(...!denies) - combinedPolicy = conjunction(this.dialect, [combinedPolicy, combinedDenies]); - } + // and(...!denies) + if (denies.length !== 0) { + const combinedDenies = conjunction( + this.dialect, + denies.map((d) => buildIsFalse(d, this.dialect)), + ); + // or(...allows) && and(...!denies) + combinedPolicy = conjunction(this.dialect, [combinedPolicy, combinedDenies]); } + return combinedPolicy; } @@ -674,14 +854,13 @@ export class PolicyHandler extends OperationNodeTransf private compilePolicyCondition( model: GetModels, alias: string | undefined, - operation: CRUD, + operation: CRUD_EXT, policy: Policy, ) { return new ExpressionTransformer(this.client).transform(policy.condition, { model, alias, operation, - auth: this.client.$auth, }); } @@ -710,7 +889,11 @@ export class PolicyHandler extends OperationNodeTransf condition: attr.args![1]!.value, }) as const, ) - .filter((policy) => policy.operations.includes('all') || policy.operations.includes(operation)), + .filter( + (policy) => + (operation !== 'post-update' && policy.operations.includes('all')) || + policy.operations.includes(operation), + ), ); } return result; diff --git a/packages/runtime/src/plugins/policy/types.ts b/packages/runtime/src/plugins/policy/types.ts index cfe20871..74c49d85 100644 --- a/packages/runtime/src/plugins/policy/types.ts +++ b/packages/runtime/src/plugins/policy/types.ts @@ -1,4 +1,4 @@ -import type { CRUD } from '../../client/contract'; +import type { CRUD_EXT } from '../../client/contract'; import type { Expression } from '../../schema'; /** @@ -9,7 +9,7 @@ export type PolicyKind = 'allow' | 'deny'; /** * Access policy operation. */ -export type PolicyOperation = CRUD | 'all'; +export type PolicyOperation = CRUD_EXT | 'all'; /** * Access policy definition. diff --git a/packages/runtime/src/plugins/policy/utils.ts b/packages/runtime/src/plugins/policy/utils.ts index 1113cb7e..5fc11410 100644 --- a/packages/runtime/src/plugins/policy/utils.ts +++ b/packages/runtime/src/plugins/policy/utils.ts @@ -13,7 +13,7 @@ import { ValueNode, } from 'kysely'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; -import type { SchemaDef } from '../../schema'; +import { ExpressionUtils, type Expression, type SchemaDef } from '../../schema'; /** * Creates a `true` value node. @@ -154,3 +154,7 @@ export function getTableName(node: OperationNode | undefined) { } return undefined; } + +export function isBeforeInvocation(expr: Expression) { + return ExpressionUtils.isCall(expr) && expr.function === 'before'; +} diff --git a/packages/runtime/src/utils/expression-utils.ts b/packages/runtime/src/utils/expression-utils.ts new file mode 100644 index 00000000..8c0824d4 --- /dev/null +++ b/packages/runtime/src/utils/expression-utils.ts @@ -0,0 +1,58 @@ +import { match } from 'ts-pattern'; +import type { + ArrayExpression, + BinaryExpression, + CallExpression, + Expression, + FieldExpression, + LiteralExpression, + MemberExpression, + NullExpression, + ThisExpression, + UnaryExpression, +} from '../schema'; + +export class ExpressionVisitor { + visit(expr: Expression): void { + match(expr) + .with({ kind: 'literal' }, (e) => this.visitLiteral(e)) + .with({ kind: 'array' }, (e) => this.visitArray(e)) + .with({ kind: 'field' }, (e) => this.visitField(e)) + .with({ kind: 'member' }, (e) => this.visitMember(e)) + .with({ kind: 'binary' }, (e) => this.visitBinary(e)) + .with({ kind: 'unary' }, (e) => this.visitUnary(e)) + .with({ kind: 'call' }, (e) => this.visitCall(e)) + .with({ kind: 'this' }, (e) => this.visitThis(e)) + .with({ kind: 'null' }, (e) => this.visitNull(e)) + .exhaustive(); + } + + protected visitLiteral(_e: LiteralExpression) {} + + protected visitArray(e: ArrayExpression) { + e.items.forEach((item) => this.visit(item)); + } + + protected visitField(_e: FieldExpression) {} + + protected visitMember(e: MemberExpression) { + this.visit(e.receiver); + } + + protected visitBinary(e: BinaryExpression) { + this.visit(e.left); + this.visit(e.right); + } + + protected visitUnary(e: UnaryExpression) { + this.visit(e.operand); + } + + protected visitCall(e: CallExpression) { + e.args?.forEach((arg) => this.visit(arg)); + } + + protected visitThis(_e: ThisExpression) {} + + protected visitNull(_e: NullExpression) {} +} diff --git a/packages/runtime/test/plugin/entity-mutation-hooks.test.ts b/packages/runtime/test/plugin-infra/entity-mutation-hooks.test.ts similarity index 100% rename from packages/runtime/test/plugin/entity-mutation-hooks.test.ts rename to packages/runtime/test/plugin-infra/entity-mutation-hooks.test.ts diff --git a/packages/runtime/test/plugin/on-kysely-query.test.ts b/packages/runtime/test/plugin-infra/on-kysely-query.test.ts similarity index 100% rename from packages/runtime/test/plugin/on-kysely-query.test.ts rename to packages/runtime/test/plugin-infra/on-kysely-query.test.ts diff --git a/packages/runtime/test/plugin/on-query-hooks.test.ts b/packages/runtime/test/plugin-infra/on-query-hooks.test.ts similarity index 100% rename from packages/runtime/test/plugin/on-query-hooks.test.ts rename to packages/runtime/test/plugin-infra/on-query-hooks.test.ts diff --git a/packages/runtime/test/policy/crud/post-update.test.ts b/packages/runtime/test/policy/crud/post-update.test.ts new file mode 100644 index 00000000..98c2b5b5 --- /dev/null +++ b/packages/runtime/test/policy/crud/post-update.test.ts @@ -0,0 +1,191 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy post-update tests', () => { + it('allows post-update by default', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id Int @id + x Int + @@allow('read,create,update', true) + } + `, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toResolveTruthy(); + }); + + it('works with simple post-update rules', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id Int @id + x Int + @@allow('all', true) + @@allow('post-update', x > 1) + @@deny('post-update', x > 2) + } + `, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + + // allow: x > 1 + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 0 }); + + // deny: x > 2 + await expect(db.foo.update({ where: { id: 1 }, data: { x: 3 } })).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 0 }); + + await expect(db.foo.update({ where: { id: 1 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + }); + + it('respect deny rules without allow', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id Int @id + x Int + @@allow('create,read,update', true) + @@deny('post-update', x > 1) + } + `, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 2 } })).toBeRejectedByPolicy(); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toResolveTruthy(); + }); + + it('works with relation conditions', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + age Int + profile Profile? + @@allow('all', true) + @@allow('post-update', profile == null || age == profile.age) + } + + model Profile { + id Int @id + age Int + userId Int @unique + user User @relation(fields: [userId], references: [id]) + @@allow('all', true) + } + `, + ); + + await db.user.create({ data: { id: 1, age: 20, profile: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 22 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { age: 18 } })).toResolveTruthy(); + + await db.user.create({ data: { id: 2, age: 20, profile: { create: { id: 2, age: 18 } } } }); + await expect( + db.user.update({ where: { id: 2 }, data: { age: 22, profile: { delete: true } } }), + ).toResolveTruthy(); + }); + + it('works with before function', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id Int @id + x Int + @@allow('all', true) + @@allow('post-update', x > before().x) + } + `, + ); + + await db.foo.create({ data: { id: 1, x: 1 } }); + await db.foo.create({ data: { id: 2, x: 2 } }); + + // update one + await expect(db.foo.update({ where: { id: 1 }, data: { x: 0 } })).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + + // update many + await expect(db.foo.updateMany({ data: { x: 0 } })).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 2 }); + + await expect(db.foo.update({ where: { id: 1 }, data: { x: 2 } })).toResolveTruthy(); + await expect(db.foo.updateMany({ data: { x: 3 } })).resolves.toMatchObject({ count: 2 }); + // check updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 3 }); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 3 }); + }); + + it('works with query builder API', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id Int @id + x Int + @@allow('all', true) + @@allow('post-update', x > before().x) + } + `, + ); + + await db.foo.create({ data: { id: 1, x: 1 } }); + await db.foo.create({ data: { id: 2, x: 2 } }); + + // update one + await expect(db.$qb.updateTable('Foo').set({ x: 0 }).where('id', '=', 1).execute()).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + + // update many + await expect(db.$qb.updateTable('Foo').set({ x: 0 }).execute()).toBeRejectedByPolicy(); + // check not updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 2 }); + + await expect( + db.$qb.updateTable('Foo').set({ x: 2 }).where('id', '=', 1).executeTakeFirst(), + ).resolves.toMatchObject({ + numUpdatedRows: 1n, + }); + // check updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 2 }); + + await expect(db.$qb.updateTable('Foo').set({ x: 3 }).executeTakeFirst()).resolves.toMatchObject({ + numUpdatedRows: 2n, + }); + // check updated + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 3 }); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 3 }); + }); + + it('rejects accessing relation fields from before', async () => { + await expect( + createPolicyTestClient( + ` + model User { + id Int @id + name String + profile Profile? + } + + model Profile { + id Int @id + userId Int @unique + user User @relation(fields: [userId], references: [id]) + @@allow('post-update', before().user.name == 'a') + } + `, + ), + ).rejects.toThrow('relation fields cannot be accessed from `before()`'); + }); +}); diff --git a/packages/runtime/test/policy/migrated/multi-field-unique.test.ts b/packages/runtime/test/policy/migrated/multi-field-unique.test.ts index 7edbe019..fba22b09 100644 --- a/packages/runtime/test/policy/migrated/multi-field-unique.test.ts +++ b/packages/runtime/test/policy/migrated/multi-field-unique.test.ts @@ -1,19 +1,8 @@ -import path from 'path'; -import { afterEach, beforeAll, describe, expect, it } from 'vitest'; -import { createPolicyTestClient } from '../utils'; +import { describe, expect, it } from 'vitest'; import { QueryError } from '../../../src'; +import { createPolicyTestClient } from '../utils'; describe('Policy tests multi-field unique', () => { - let origDir: string; - - beforeAll(async () => { - origDir = path.resolve('.'); - }); - - afterEach(() => { - process.chdir(origDir); - }); - it('toplevel crud test unnamed constraint', async () => { const db = await createPolicyTestClient( ` diff --git a/packages/runtime/test/policy/migrated/multi-id-fields.test.ts b/packages/runtime/test/policy/migrated/multi-id-fields.test.ts index 56941f03..9444fe20 100644 --- a/packages/runtime/test/policy/migrated/multi-id-fields.test.ts +++ b/packages/runtime/test/policy/migrated/multi-id-fields.test.ts @@ -57,8 +57,7 @@ describe('Policy tests multiple id fields', () => { ).toResolveTruthy(); }); - // TODO: `future()` support - it.skip('multi-id fields id update', async () => { + it('multi-id fields id update', async () => { const db = await createPolicyTestClient( ` model A { @@ -70,7 +69,8 @@ describe('Policy tests multiple id fields', () => { @@allow('read', true) @@allow('create', value > 0) - @@allow('update', value > 0 && future().value > 1) + @@allow('update', value > 0) + @@allow('post-update', value > 1) } model B { @@ -319,8 +319,7 @@ describe('Policy tests multiple id fields', () => { expect(await db.c.findUnique({ where: { id: 1 } })).toEqual(expect.objectContaining({ v: 6 })); }); - // TODO: `future()` support - it.skip('multi-id fields nested id update', async () => { + it('multi-id fields nested id update', async () => { const db = await createPolicyTestClient( ` model A { @@ -333,7 +332,8 @@ describe('Policy tests multiple id fields', () => { @@allow('read', true) @@allow('create', value > 0) - @@allow('update', value > 0 && future().value > 1) + @@allow('update', value > 0) + @@allow('post-update', value > 1) } model B { @@ -369,7 +369,7 @@ describe('Policy tests multiple id fields', () => { upsert: { where: { x_y: { x: '2', y: 2 } }, update: { x: '3', y: 3, value: 0 }, - create: { x: '4', y: '4', value: 4 }, + create: { x: '4', y: 4, value: 4 }, }, }, }, @@ -384,7 +384,7 @@ describe('Policy tests multiple id fields', () => { upsert: { where: { x_y: { x: '2', y: 2 } }, update: { x: '3', y: 3, value: 3 }, - create: { x: '4', y: '4', value: 4 }, + create: { x: '4', y: 4, value: 4 }, }, }, }, diff --git a/packages/runtime/test/policy/migrated/nested-to-one.test.ts b/packages/runtime/test/policy/migrated/nested-to-one.test.ts index 5838cae8..432c8065 100644 --- a/packages/runtime/test/policy/migrated/nested-to-one.test.ts +++ b/packages/runtime/test/policy/migrated/nested-to-one.test.ts @@ -197,8 +197,7 @@ describe('With Policy:nested to-one', () => { ).toBeRejectedNotFound(); }); - // TODO: `future()` support - it.skip('nested update id tests', async () => { + it('nested update id tests', async () => { const db = await createPolicyTestClient( ` model M1 { @@ -216,7 +215,8 @@ describe('With Policy:nested to-one', () => { @@allow('read', true) @@allow('create', value > 0) - @@allow('update', value > 1 && future().value > 2) + @@allow('update', value > 1) + @@allow('post-update', value > 2) } `, ); diff --git a/packages/runtime/test/policy/migrated/petstore-sample.test.ts b/packages/runtime/test/policy/migrated/petstore-sample.test.ts index 99e5e8c7..2b210827 100644 --- a/packages/runtime/test/policy/migrated/petstore-sample.test.ts +++ b/packages/runtime/test/policy/migrated/petstore-sample.test.ts @@ -2,8 +2,7 @@ import { describe, expect, it } from 'vitest'; import { createPolicyTestClient } from '../utils'; import { schema } from '../../schemas/petstore/schema'; -// TODO: `future()` support -describe.skip('Pet Store Policy Tests', () => { +describe('Pet Store Policy Tests', () => { it('crud', async () => { const petData = [ { diff --git a/packages/runtime/test/policy/migrated/todo-sample.test.ts b/packages/runtime/test/policy/migrated/todo-sample.test.ts index c81ac3f7..541ca69b 100644 --- a/packages/runtime/test/policy/migrated/todo-sample.test.ts +++ b/packages/runtime/test/policy/migrated/todo-sample.test.ts @@ -370,8 +370,7 @@ describe('Todo Policy Tests', () => { expect(r1.lists).toHaveLength(1); }); - // TODO: `future()` support - it.skip('post-update checks', async () => { + it('post-update checks', async () => { await createSpaceAndUsers(db.$unuseAll()); const user1Db = db.$setAuth({ id: user1.id }); diff --git a/packages/runtime/test/policy/migrated/toplevel-operations.test.ts b/packages/runtime/test/policy/migrated/toplevel-operations.test.ts index f545148c..f427c4ad 100644 --- a/packages/runtime/test/policy/migrated/toplevel-operations.test.ts +++ b/packages/runtime/test/policy/migrated/toplevel-operations.test.ts @@ -133,8 +133,7 @@ describe('Policy toplevel operations tests', () => { ).toBeTruthy(); }); - // TODO: `future()` support - it.skip('update id tests', async () => { + it('update id tests', async () => { const db = await createPolicyTestClient( ` model Model { @@ -143,7 +142,8 @@ describe('Policy toplevel operations tests', () => { @@allow('read', value > 1) @@allow('create', value > 0) - @@allow('update', value > 1 && future().value > 2) + @@allow('update', value > 1) + @@allow('post-update', value > 2) } `, ); @@ -164,7 +164,7 @@ describe('Policy toplevel operations tests', () => { value: 1, }, }), - ).toBeRejectedNotFound(); + ).toBeRejectedByPolicy(); // update success await expect( diff --git a/packages/runtime/test/policy/todo-sample.test.ts b/packages/runtime/test/policy/todo-sample.test.ts index 83c812b5..a53c7466 100644 --- a/packages/runtime/test/policy/todo-sample.test.ts +++ b/packages/runtime/test/policy/todo-sample.test.ts @@ -383,8 +383,7 @@ describe('todo sample tests', () => { expect(r1?.lists).toHaveLength(1); }); - // TODO: `future()` support - it.skip('works with post-update checks', async () => { + it('works with post-update checks', async () => { const anonDb = await createPolicyTestClient(schema); await createSpaceAndUsers(anonDb.$unuseAll()); diff --git a/packages/runtime/test/schemas/petstore/schema.ts b/packages/runtime/test/schemas/petstore/schema.ts index c6902c7e..a2eb7d67 100644 --- a/packages/runtime/test/schemas/petstore/schema.ts +++ b/packages/runtime/test/schemas/petstore/schema.ts @@ -92,7 +92,8 @@ export const schema = { }, attributes: [ { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("read") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.field("orderId"), "==", ExpressionUtils._null()), "||", ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.field("order"), ["user"]), "==", ExpressionUtils.call("auth"))) }] }, - { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("update") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.field("name"), "==", ExpressionUtils.member(ExpressionUtils.call("future"), ["name"])), "&&", ExpressionUtils.binary(ExpressionUtils.field("category"), "==", ExpressionUtils.member(ExpressionUtils.call("future"), ["category"]))), "&&", ExpressionUtils.binary(ExpressionUtils.field("orderId"), "==", ExpressionUtils._null())) }] } + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("update") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.call("auth"), "!=", ExpressionUtils._null()) }] }, + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("post-update") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.call("before"), ["name"]), "==", ExpressionUtils.field("name")), "&&", ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.call("before"), ["category"]), "==", ExpressionUtils.field("category"))), "&&", ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.call("before"), ["orderId"]), "==", ExpressionUtils._null())) }] } ], idFields: ["id"], uniqueFields: { diff --git a/packages/runtime/test/schemas/petstore/schema.zmodel b/packages/runtime/test/schemas/petstore/schema.zmodel index 4a2442ca..8809445c 100644 --- a/packages/runtime/test/schemas/petstore/schema.zmodel +++ b/packages/runtime/test/schemas/petstore/schema.zmodel @@ -35,8 +35,10 @@ model Pet { // unsold pets are readable to all; sold ones are readable to buyers only @@allow('read', orderId == null || order.user == auth()) + @@allow('update', auth() != null) + // only allow update to 'orderId' field if it's not set yet (unsold) - @@allow('update', name == future().name && category == future().category && orderId == null ) + @@allow('post-update', before().name == name && before().category == category && before().orderId == null) } model Order { diff --git a/packages/runtime/test/schemas/todo/schema.ts b/packages/runtime/test/schemas/todo/schema.ts index 14ef60d1..f0ae9c26 100644 --- a/packages/runtime/test/schemas/todo/schema.ts +++ b/packages/runtime/test/schemas/todo/schema.ts @@ -311,6 +311,7 @@ export const schema = { { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("read") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.field("ownerId"), "==", ExpressionUtils.member(ExpressionUtils.call("auth"), ["id"])), "||", ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.field("space"), ["members"]), "?", ExpressionUtils.binary(ExpressionUtils.field("userId"), "==", ExpressionUtils.member(ExpressionUtils.call("auth"), ["id"]))), "&&", ExpressionUtils.unary("!", ExpressionUtils.field("private")))) }] }, { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("create") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.field("ownerId"), "==", ExpressionUtils.member(ExpressionUtils.call("auth"), ["id"])), "&&", ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.field("space"), ["members"]), "?", ExpressionUtils.binary(ExpressionUtils.field("userId"), "==", ExpressionUtils.member(ExpressionUtils.call("auth"), ["id"])))) }] }, { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("update") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.field("ownerId"), "==", ExpressionUtils.member(ExpressionUtils.call("auth"), ["id"])), "&&", ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.field("space"), ["members"]), "?", ExpressionUtils.binary(ExpressionUtils.field("userId"), "==", ExpressionUtils.member(ExpressionUtils.call("auth"), ["id"])))) }] }, + { name: "@@deny", args: [{ name: "operation", value: ExpressionUtils.literal("post-update") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.call("before"), ["ownerId"]), "!=", ExpressionUtils.field("ownerId")) }] }, { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("delete") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.field("ownerId"), "==", ExpressionUtils.member(ExpressionUtils.call("auth"), ["id"])) }] } ], idFields: ["id"], @@ -380,7 +381,8 @@ export const schema = { attributes: [ { name: "@@deny", args: [{ name: "operation", value: ExpressionUtils.literal("all") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.call("auth"), "==", ExpressionUtils._null()) }] }, { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("all") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.field("list"), ["ownerId"]), "==", ExpressionUtils.member(ExpressionUtils.call("auth"), ["id"])) }] }, - { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("all") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.field("list"), ["space", "members"]), "?", ExpressionUtils.binary(ExpressionUtils.field("userId"), "==", ExpressionUtils.member(ExpressionUtils.call("auth"), ["id"]))), "&&", ExpressionUtils.unary("!", ExpressionUtils.member(ExpressionUtils.field("list"), ["private"]))) }] } + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("all") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.field("list"), ["space", "members"]), "?", ExpressionUtils.binary(ExpressionUtils.field("userId"), "==", ExpressionUtils.member(ExpressionUtils.call("auth"), ["id"]))), "&&", ExpressionUtils.unary("!", ExpressionUtils.member(ExpressionUtils.field("list"), ["private"]))) }] }, + { name: "@@deny", args: [{ name: "operation", value: ExpressionUtils.literal("post-update") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.call("before"), ["ownerId"]), "!=", ExpressionUtils.field("ownerId")) }] } ], idFields: ["id"], uniqueFields: { diff --git a/packages/runtime/test/schemas/todo/todo.zmodel b/packages/runtime/test/schemas/todo/todo.zmodel index d91ed34a..faeaa660 100644 --- a/packages/runtime/test/schemas/todo/todo.zmodel +++ b/packages/runtime/test/schemas/todo/todo.zmodel @@ -117,10 +117,9 @@ model List { // when create, owner must be set to current user, and user must be in the space // update is not allowed to change owner - @@allow('update', ownerId == auth().id && space.members?[userId == auth().id] - // TODO: future() support - // && future().ownerId == ownerId - ) + @@allow('update', ownerId == auth().id && space.members?[userId == auth().id]) + + @@deny('post-update', before().ownerId != ownerId) // can be deleted by owner @@allow('delete', ownerId == auth().id) @@ -147,7 +146,6 @@ model Todo { @@allow('all', list.ownerId == auth().id) @@allow('all', list.space.members?[userId == auth().id] && !list.private) - // TODO: future() support - // // update is not allowed to change owner - // @@deny('update', future().owner != owner) + // update is not allowed to change owner + @@deny('post-update', before().ownerId != ownerId) } diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 1a065577..278b610a 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "description": "ZenStack SDK", "type": "module", "scripts": { diff --git a/packages/sdk/src/prisma/prisma-schema-generator.ts b/packages/sdk/src/prisma/prisma-schema-generator.ts index 116ee872..a02aba03 100644 --- a/packages/sdk/src/prisma/prisma-schema-generator.ts +++ b/packages/sdk/src/prisma/prisma-schema-generator.ts @@ -20,7 +20,6 @@ import { isDataModel, isInvocationExpr, isLiteralExpr, - isModel, isNullExpr, isReferenceExpr, isStringLiteral, @@ -31,7 +30,7 @@ import { StringLiteral, type AstNode, } from '@zenstackhq/language/ast'; -import { getAllAttributes, getAllFields, isDelegateModel } from '@zenstackhq/language/utils'; +import { getAllAttributes, getAllFields, isAuthInvocation, isDelegateModel } from '@zenstackhq/language/utils'; import { AstUtils } from 'langium'; import { match } from 'ts-pattern'; import { ModelUtils, ZModelCodeGenerator } from '..'; @@ -242,8 +241,8 @@ export class PrismaSchemaGenerator { const attributes = field.attributes .filter((attr) => this.isPrismaAttribute(attr)) - // `@default` with calling functions from plugin is handled outside Prisma - .filter((attr) => !this.isDefaultWithPluginInvocation(attr)) + // `@default` using `auth()` is handled outside Prisma + .filter((attr) => !this.isDefaultWithAuthInvocation(attr)) .filter( (attr) => // when building physical schema, exclude `@default` for id fields inherited from delegate base @@ -260,7 +259,7 @@ export class PrismaSchemaGenerator { return result; } - private isDefaultWithPluginInvocation(attr: DataFieldAttribute) { + private isDefaultWithAuthInvocation(attr: DataFieldAttribute) { if (attr.decl.ref?.name !== '@default') { return false; } @@ -270,12 +269,7 @@ export class PrismaSchemaGenerator { return false; } - return AstUtils.streamAst(expr).some((node) => isInvocationExpr(node) && this.isFromPlugin(node.function.ref)); - } - - private isFromPlugin(node: AstNode | undefined) { - const model = AstUtils.getContainerOfType(node, isModel); - return !!model && !!model.$document && model.$document.uri.path.endsWith('plugin.zmodel'); + return AstUtils.streamAst(expr).some(isAuthInvocation); } private isInheritedFromDelegate(field: DataField, contextModel: DataModel) { diff --git a/packages/tanstack-query/package.json b/packages/tanstack-query/package.json index 931b48d5..f6b036ef 100644 --- a/packages/tanstack-query/package.json +++ b/packages/tanstack-query/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/tanstack-query", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "description": "", "main": "index.js", "type": "module", diff --git a/packages/testtools/package.json b/packages/testtools/package.json index 84470c0f..be36abc2 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "description": "ZenStack Test Tools", "type": "module", "scripts": { diff --git a/packages/testtools/src/project.ts b/packages/testtools/src/project.ts index c3753cfb..0a795c6b 100644 --- a/packages/testtools/src/project.ts +++ b/packages/testtools/src/project.ts @@ -2,7 +2,7 @@ import fs from 'node:fs'; import path from 'node:path'; import tmp from 'tmp'; -export function createTestProject() { +export function createTestProject(zmodelContent?: string) { const { name: workDir } = tmp.dirSync({ unsafeCleanup: true }); fs.mkdirSync(path.join(workDir, 'node_modules')); @@ -63,5 +63,9 @@ export function createTestProject() { ), ); + if (zmodelContent) { + fs.writeFileSync(path.join(workDir, 'schema.zmodel'), zmodelContent); + } + return workDir; } diff --git a/packages/typescript-config/package.json b/packages/typescript-config/package.json index 81028a8d..aeea0618 100644 --- a/packages/typescript-config/package.json +++ b/packages/typescript-config/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/typescript-config", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "private": true, "license": "MIT" } diff --git a/packages/vitest-config/package.json b/packages/vitest-config/package.json index a053036c..1783b010 100644 --- a/packages/vitest-config/package.json +++ b/packages/vitest-config/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/vitest-config", "type": "module", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "private": true, "license": "MIT", "exports": { diff --git a/packages/zod/package.json b/packages/zod/package.json index 138ae529..f5f6a862 100644 --- a/packages/zod/package.json +++ b/packages/zod/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/zod", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "description": "", "type": "module", "main": "index.js", diff --git a/samples/blog/package.json b/samples/blog/package.json index c36aeb21..3c6353b6 100644 --- a/samples/blog/package.json +++ b/samples/blog/package.json @@ -1,6 +1,6 @@ { "name": "sample-blog", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "description": "", "main": "index.js", "scripts": { diff --git a/tests/e2e/package.json b/tests/e2e/package.json index afc5c930..f1b847a1 100644 --- a/tests/e2e/package.json +++ b/tests/e2e/package.json @@ -1,6 +1,6 @@ { "name": "e2e", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "private": true, "type": "module", "scripts": { diff --git a/tests/regression/package.json b/tests/regression/package.json index 46368cad..bb98998f 100644 --- a/tests/regression/package.json +++ b/tests/regression/package.json @@ -1,6 +1,6 @@ { "name": "regression", - "version": "3.0.0-beta.7", + "version": "3.0.0-beta.8", "private": true, "type": "module", "scripts": { diff --git a/tests/regression/test/issue-204/regression.test.ts b/tests/regression/test/issue-204/regression.test.ts index 24a43e3b..d7d78948 100644 --- a/tests/regression/test/issue-204/regression.test.ts +++ b/tests/regression/test/issue-204/regression.test.ts @@ -1,11 +1,9 @@ -import { describe, it } from 'vitest'; +import { it } from 'vitest'; import { type Configuration, ShirtColor } from './models'; -describe('Issue 204 regression tests', () => { - it('tests issue 204', () => { - const config: Configuration = { teamColors: [ShirtColor.Black, ShirtColor.Blue] }; - console.log(config.teamColors?.[0]); - const config1: Configuration = {}; - console.log(config1); - }); +it('tests issue 204', () => { + const config: Configuration = { teamColors: [ShirtColor.Black, ShirtColor.Blue] }; + console.log(config.teamColors?.[0]); + const config1: Configuration = {}; + console.log(config1); }); diff --git a/tests/regression/test/issue-274/regression.test.ts b/tests/regression/test/issue-274/regression.test.ts new file mode 100644 index 00000000..ffa11a09 --- /dev/null +++ b/tests/regression/test/issue-274/regression.test.ts @@ -0,0 +1,27 @@ +import { createTestProject } from '@zenstackhq/testtools'; +import { execSync } from 'child_process'; +import { it } from 'vitest'; + +it('tests issue 274', async () => { + const dir = await createTestProject(` + +datasource db { + provider = 'sqlite' + url = "file:./test.db" +} + +model Comment { + id String @id + author User? @relation(fields: [authorId], references: [id]) + authorId String? @default(auth().id) +} + +model User { + id String @id + email String + comments Comment[] +} +`); + + execSync('node node_modules/@zenstackhq/cli/dist/index.js migrate dev --name init', { cwd: dir }); +});