diff --git a/.vscode/launch.json b/.vscode/launch.json index 886089d1..09ccbd59 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,7 +11,7 @@ "skipFiles": ["/**"], "type": "node", "args": ["generate"], - "cwd": "${workspaceFolder}/samples/blog/zenstack" + "cwd": "${workspaceFolder}/samples/blog" }, { "name": "Debug with TSX", diff --git a/packages/cli/src/actions/action-utils.ts b/packages/cli/src/actions/action-utils.ts index 287c5593..49655622 100644 --- a/packages/cli/src/actions/action-utils.ts +++ b/packages/cli/src/actions/action-utils.ts @@ -1,10 +1,13 @@ -import { loadDocument } from '@zenstackhq/language'; -import { isDataSource } from '@zenstackhq/language/ast'; +import { createZModelServices, loadDocument, type ZModelServices } from '@zenstackhq/language'; +import { isDataSource, isPlugin, Model } from '@zenstackhq/language/ast'; +import { getLiteral } from '@zenstackhq/language/utils'; import { PrismaSchemaGenerator } from '@zenstackhq/sdk'; import colors from 'colors'; import fs from 'node:fs'; import path from 'node:path'; +import { fileURLToPath } from 'node:url'; import { CliError } from '../cli-error'; +import { PLUGIN_MODULE_NAME } from '../constants'; export function getSchemaFile(file?: string) { if (file) { @@ -34,7 +37,9 @@ export function getSchemaFile(file?: string) { } export async function loadSchemaDocument(schemaFile: string) { - const loadResult = await loadDocument(schemaFile); + const { ZModelLanguage: services } = createZModelServices(); + const pluginDocs = await getPluginDocuments(services, schemaFile); + const loadResult = await loadDocument(schemaFile, pluginDocs); if (!loadResult.success) { loadResult.errors.forEach((err) => { console.error(colors.red(err)); @@ -47,6 +52,63 @@ export async function loadSchemaDocument(schemaFile: string) { return loadResult.model; } +export async function getPluginDocuments(services: ZModelServices, fileName: string): Promise { + // parse the user document (without validation) + const parseResult = services.parser.LangiumParser.parse(fs.readFileSync(fileName, { encoding: 'utf-8' })); + const parsed = parseResult.value as Model; + + // balk if there are syntax errors + if (parseResult.lexerErrors.length > 0 || parseResult.parserErrors.length > 0) { + return []; + } + + // traverse plugins and collect "plugin.zmodel" documents + const result: string[] = []; + for (const decl of parsed.declarations.filter(isPlugin)) { + const providerField = decl.fields.find((f) => f.name === 'provider'); + if (!providerField) { + continue; + } + + const provider = getLiteral(providerField.value); + if (!provider) { + continue; + } + + let pluginModelFile: string | undefined; + + // first try to treat provider as a path + let providerPath = path.resolve(path.dirname(fileName), provider); + if (fs.existsSync(providerPath)) { + if (fs.statSync(providerPath).isDirectory()) { + providerPath = path.join(providerPath, 'index.js'); + } + + // try plugin.zmodel next to the provider file + pluginModelFile = path.resolve(path.dirname(providerPath), PLUGIN_MODULE_NAME); + if (!fs.existsSync(pluginModelFile)) { + // try to find upwards + pluginModelFile = findUp([PLUGIN_MODULE_NAME], path.dirname(providerPath)); + } + } + + if (!pluginModelFile) { + // try loading it as a ESM module + try { + const resolvedUrl = import.meta.resolve(`${provider}/${PLUGIN_MODULE_NAME}`); + pluginModelFile = fileURLToPath(resolvedUrl); + } catch { + // noop + } + } + + if (pluginModelFile && fs.existsSync(pluginModelFile)) { + result.push(pluginModelFile); + } + } + return result; +} + export function handleSubProcessError(err: unknown) { if (err instanceof Error && 'status' in err && typeof err.status === 'number') { process.exit(err.status); diff --git a/packages/cli/src/actions/generate.ts b/packages/cli/src/actions/generate.ts index e80c219d..2db079e1 100644 --- a/packages/cli/src/actions/generate.ts +++ b/packages/cli/src/actions/generate.ts @@ -64,7 +64,7 @@ async function runPlugins(schemaFile: string, model: Model, outputPath: string, for (const plugin of plugins) { const provider = getPluginProvider(plugin); - let cliPlugin: CliPlugin; + let cliPlugin: CliPlugin | undefined; if (provider.startsWith('@core/')) { cliPlugin = (corePlugins as any)[provider.slice('@core/'.length)]; if (!cliPlugin) { @@ -78,12 +78,14 @@ async function runPlugins(schemaFile: string, model: Model, outputPath: string, } try { cliPlugin = (await import(moduleSpec)).default as CliPlugin; - } catch (error) { - throw new CliError(`Failed to load plugin ${provider}: ${error}`); + } catch { + // plugin may not export a generator so we simply ignore the error here } } - processedPlugins.push({ cliPlugin, pluginOptions: getPluginOptions(plugin) }); + if (cliPlugin) { + processedPlugins.push({ cliPlugin, pluginOptions: getPluginOptions(plugin) }); + } } const defaultPlugins = [corePlugins['typescript']].reverse(); diff --git a/packages/cli/src/constants.ts b/packages/cli/src/constants.ts index 586537af..72463404 100644 --- a/packages/cli/src/constants.ts +++ b/packages/cli/src/constants.ts @@ -1,2 +1,5 @@ // replaced at build time export const TELEMETRY_TRACKING_TOKEN = ''; + +// plugin-contributed model file name +export const PLUGIN_MODULE_NAME = 'plugin.zmodel'; diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 7ac57ba3..f1b46e84 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -174,29 +174,6 @@ function hasSome(field: Any[], search: Any[]): Boolean { function isEmpty(field: Any[]): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) -/** - * The name of the model for which the policy rule is defined. If the rule is - * inherited to a sub model, this function returns the name of the sub model. - * - * @param optional parameter to control the casing of the returned value. Valid - * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults - * to "original". - */ -function currentModel(casing: String?): String { -} @@@expressionContext([AccessPolicy]) - -/** - * The operation for which the policy rule is defined for. Note that a rule with - * "all" operation is expanded to "create", "read", "update", and "delete" rules, - * and the function returns corresponding value for each expanded version. - * - * @param optional parameter to control the casing of the returned value. Valid - * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults - * to "original". - */ -function currentOperation(casing: String?): String { -} @@@expressionContext([AccessPolicy]) - /** * Marks an attribute to be only applicable to certain field types. */ @@ -658,56 +635,3 @@ attribute @meta(_ name: String, _ value: Any) * Marks an attribute as deprecated. */ attribute @@@deprecated(_ message: String) - -/* --- Policy Plugin --- */ - -/** - * Defines an access policy that allows a set of operations when the given condition is true. - * - * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. - * @param condition: a boolean expression that controls if the operation should be allowed. - */ -attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) - -/** - * Defines an access policy that allows the annotated field to be read or updated. - * You can pass a third argument as `true` to make it override the model-level policies. - * - * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. - * @param condition: a boolean expression that controls if the operation should be allowed. - * @param override: a boolean value that controls if the field-level policy should override the model-level policy. - */ -// attribute @allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean, _ override: Boolean?) - -/** - * Defines an access policy that denies a set of operations when the given condition is true. - * - * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. - * @param condition: a boolean expression that controls if the operation should be denied. - */ -attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) - -/** - * Defines an access policy that denies the annotated field to be read or updated. - * - * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. - * @param condition: a boolean expression that controls if the operation should be denied. - */ -// attribute @deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean) - -/** - * Checks if the current user can perform the given operation on the given field. - * - * @param field: The field to check access for - * @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided, - * it defaults the operation of the containing policy rule. - */ -function check(field: Any, operation: String?): Boolean { -} @@@expressionContext([AccessPolicy]) - -/** - * Gets entity's value before an update. Only valid when used in a "post-update" policy rule. - */ -function before(): Any { -} @@@expressionContext([AccessPolicy]) - diff --git a/packages/language/src/index.ts b/packages/language/src/index.ts index ab577c7f..d6e0d72f 100644 --- a/packages/language/src/index.ts +++ b/packages/language/src/index.ts @@ -51,7 +51,7 @@ export async function loadDocument( ); // load additional model files - const pluginDocs = await Promise.all( + const additionalDocs = await Promise.all( additionalModelFiles.map((file) => services.shared.workspace.LangiumDocuments.getOrCreateDocument(URI.file(path.resolve(file))), ), @@ -69,7 +69,7 @@ export async function loadDocument( } // build the document together with standard library, plugin modules, and imported documents - await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document, ...importedDocuments], { + await services.shared.workspace.DocumentBuilder.build([stdLib, ...additionalDocs, document, ...importedDocuments], { validation: { stopAfterLexingErrors: true, stopAfterParsingErrors: true, diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index 762187b3..3e24285f 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -447,8 +447,9 @@ export function getAuthDecl(decls: (DataModel | TypeDef)[]) { return authModel; } +// TODO: move to policy plugin export function isBeforeInvocation(node: AstNode) { - return isInvocationExpr(node) && node.function.ref?.name === 'before' && isFromStdlib(node.function.ref); + return isInvocationExpr(node) && node.function.ref?.name === 'before'; } export function isCollectionPredicate(node: AstNode): node is BinaryExpr { diff --git a/packages/language/src/validators/function-invocation-validator.ts b/packages/language/src/validators/function-invocation-validator.ts index a740b86e..9e75a18a 100644 --- a/packages/language/src/validators/function-invocation-validator.ts +++ b/packages/language/src/validators/function-invocation-validator.ts @@ -20,7 +20,6 @@ import { getLiteral, isCheckInvocation, isDataFieldReference, - isFromStdlib, typeAssignable, } from '../utils'; import type { AstValidator } from './common'; @@ -52,43 +51,39 @@ export default class FunctionInvocationValidator implements AstValidator(expr.args[0]?.value); - if (arg && !allCasing.includes(arg)) { - accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, { - node: expr.args[0]!, - }); - } + const allCasing = ['original', 'upper', 'lower', 'capitalize', 'uncapitalize']; + if (['currentModel', 'currentOperation'].includes(funcDecl.name)) { + const arg = getLiteral(expr.args[0]?.value); + if (arg && !allCasing.includes(arg)) { + accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, { + node: expr.args[0]!, + }); } } diff --git a/packages/language/test/utils.ts b/packages/language/test/utils.ts index 4b60ce42..690b2490 100644 --- a/packages/language/test/utils.ts +++ b/packages/language/test/utils.ts @@ -5,11 +5,13 @@ import path from 'node:path'; import { expect } from 'vitest'; import { loadDocument } from '../src'; +const pluginDocs = [path.resolve(__dirname, '../../plugins/policy/plugin.zmodel')]; + export async function loadSchema(schema: string) { // create a temp file const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`); fs.writeFileSync(tempFile, schema); - const r = await loadDocument(tempFile); + const r = await loadDocument(tempFile, pluginDocs); expect(r).toSatisfy( (r) => r.success, `Failed to load schema: ${(r as any).errors?.map((e) => e.toString()).join(', ')}`, @@ -22,7 +24,8 @@ export async function loadSchemaWithError(schema: string, error: string | RegExp // create a temp file const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`); fs.writeFileSync(tempFile, schema); - const r = await loadDocument(tempFile); + + const r = await loadDocument(tempFile, pluginDocs); expect(r.success).toBe(false); invariant(!r.success); if (typeof error === 'string') { diff --git a/packages/plugins/policy/package.json b/packages/plugins/policy/package.json index 4c60ab7d..4187bd9f 100644 --- a/packages/plugins/policy/package.json +++ b/packages/plugins/policy/package.json @@ -13,7 +13,8 @@ "author": "ZenStack Team", "license": "MIT", "files": [ - "dist" + "dist", + "plugin.zmodel" ], "exports": { ".": { @@ -26,6 +27,10 @@ "default": "./dist/index.cjs" } }, + "./plugin.zmodel": { + "import": "./plugin.zmodel", + "require": "./plugin.zmodel" + }, "./package.json": { "import": "./package.json", "require": "./package.json" @@ -33,7 +38,6 @@ }, "dependencies": { "@zenstackhq/common-helpers": "workspace:*", - "@zenstackhq/sdk": "workspace:*", "@zenstackhq/runtime": "workspace:*", "ts-pattern": "catalog:" }, diff --git a/packages/plugins/policy/plugin.zmodel b/packages/plugins/policy/plugin.zmodel new file mode 100644 index 00000000..ae04a41b --- /dev/null +++ b/packages/plugins/policy/plugin.zmodel @@ -0,0 +1,72 @@ +/** + * Defines an access policy that allows a set of operations when the given condition is true. + * + * @param operation: comma-separated list of "create", "read", "update", "post-update", " "delete". Use "all" to denote all operations. + * @param condition: a boolean expression that controls if the operation should be allowed. + */ +attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) + +/** + * Defines an access policy that allows the annotated field to be read or updated. + * You can pass a third argument as `true` to make it override the model-level policies. + * + * @param operation: comma-separated list of "create", "read", "update", "post-update", "delete". Use "all" to denote all operations. + * @param condition: a boolean expression that controls if the operation should be allowed. + * @param override: a boolean value that controls if the field-level policy should override the model-level policy. + */ +// attribute @allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'", "'delete'", "'all'"]), _ condition: Boolean, _ override: Boolean?) + +/** + * Defines an access policy that denies a set of operations when the given condition is true. + * + * @param operation: comma-separated list of "create", "read", "update", "post-update", "delete". Use "all" to denote all operations. + * @param condition: a boolean expression that controls if the operation should be denied. + */ +attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) + +/** + * Defines an access policy that denies the annotated field to be read or updated. + * + * @param operation: comma-separated list of "create", "read", "update", "post-update", "delete". Use "all" to denote all operations. + * @param condition: a boolean expression that controls if the operation should be denied. + */ +// attribute @deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean) + +/** + * Checks if the current user can perform the given operation on the given field. + * + * @param field: The field to check access for + * @param operation: The operation to check access for. Can be "read", "create", "update", "post-update", or "delete". If the operation is not provided, + * it defaults the operation of the containing policy rule. + */ +function check(field: Any, operation: String?): Boolean { +} @@@expressionContext([AccessPolicy]) + +/** + * Gets entity's value before an update. Only valid when used in a "post-update" policy rule. + */ +function before(): Any { +} @@@expressionContext([AccessPolicy]) + +/** + * The name of the model for which the policy rule is defined. If the rule is + * inherited to a sub model, this function returns the name of the sub model. + * + * @param optional parameter to control the casing of the returned value. Valid + * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults + * to "original". + */ +function currentModel(casing: String?): String { +} @@@expressionContext([AccessPolicy]) + +/** + * The operation for which the policy rule is defined for. Note that a rule with + * "all" operation is expanded to "create", "read", "update", and "delete" rules, + * and the function returns corresponding value for each expanded version. + * + * @param optional parameter to control the casing of the returned value. Valid + * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults + * to "original". + */ +function currentOperation(casing: String?): String { +} @@@expressionContext([AccessPolicy]) diff --git a/packages/plugins/policy/src/column-collector.ts b/packages/plugins/policy/src/column-collector.ts index 37d9df11..ed68c76d 100644 --- a/packages/plugins/policy/src/column-collector.ts +++ b/packages/plugins/policy/src/column-collector.ts @@ -1,10 +1,10 @@ +import { KyselyUtils } from '@zenstackhq/runtime'; import type { ColumnNode, OperationNode } from 'kysely'; -import { DefaultOperationNodeVisitor } from '@zenstackhq/sdk'; /** * Collects all column names from a query. */ -export class ColumnCollector extends DefaultOperationNodeVisitor { +export class ColumnCollector extends KyselyUtils.DefaultOperationNodeVisitor { private columns: string[] = []; collect(node: OperationNode) { diff --git a/packages/plugins/policy/src/policy-handler.ts b/packages/plugins/policy/src/policy-handler.ts index 92f5e74c..c90bb563 100644 --- a/packages/plugins/policy/src/policy-handler.ts +++ b/packages/plugins/policy/src/policy-handler.ts @@ -7,6 +7,7 @@ import { QueryUtils, RejectedByPolicyError, RejectedByPolicyReason, + SchemaUtils, type CRUD_EXT, } from '@zenstackhq/runtime'; import { @@ -17,7 +18,6 @@ import { type MemberExpression, type SchemaDef, } from '@zenstackhq/runtime/schema'; -import { ExpressionVisitor } from '@zenstackhq/sdk'; import { AliasNode, BinaryOperationNode, @@ -270,7 +270,7 @@ export class PolicyHandler extends OperationNodeTransf } const fields = new Set(); - const fieldCollector = new (class extends ExpressionVisitor { + const fieldCollector = new (class extends SchemaUtils.ExpressionVisitor { protected override visitMember(e: MemberExpression): void { if (isBeforeInvocation(e.receiver)) { invariant(e.members.length === 1, 'before() can only be followed by a scalar field access'); diff --git a/packages/runtime/src/index.ts b/packages/runtime/src/index.ts index b5018f05..cff4960f 100644 --- a/packages/runtime/src/index.ts +++ b/packages/runtime/src/index.ts @@ -1,2 +1,4 @@ export * from './client'; +export * as KyselyUtils from './utils/kysely-utils'; +export * as SchemaUtils from './utils/schema-utils'; export type { JsonArray, JsonObject, JsonValue } from './utils/type-utils'; diff --git a/packages/sdk/src/default-operation-node-visitor.ts b/packages/runtime/src/utils/kysely-utils.ts similarity index 100% rename from packages/sdk/src/default-operation-node-visitor.ts rename to packages/runtime/src/utils/kysely-utils.ts diff --git a/packages/sdk/src/expression-utils.ts b/packages/runtime/src/utils/schema-utils.ts similarity index 98% rename from packages/sdk/src/expression-utils.ts rename to packages/runtime/src/utils/schema-utils.ts index ec423767..8c0824d4 100644 --- a/packages/sdk/src/expression-utils.ts +++ b/packages/runtime/src/utils/schema-utils.ts @@ -10,7 +10,7 @@ import type { NullExpression, ThisExpression, UnaryExpression, -} from './schema'; +} from '../schema'; export class ExpressionVisitor { visit(expr: Expression): void { diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index c74b1419..649a7201 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -1,7 +1,5 @@ import * as ModelUtils from './model-utils'; export * from './cli-plugin'; -export * from './expression-utils'; -export * from './default-operation-node-visitor'; export { PrismaSchemaGenerator } from './prisma/prisma-schema-generator'; export * from './ts-schema-generator'; export * from './zmodel-code-generator'; diff --git a/packages/sdk/src/model-utils.ts b/packages/sdk/src/model-utils.ts index 7b54aa96..198e59ac 100644 --- a/packages/sdk/src/model-utils.ts +++ b/packages/sdk/src/model-utils.ts @@ -3,7 +3,6 @@ import { isLiteralExpr, isModel, isTypeDef, - Model, type AstNode, type Attribute, type AttributeParam, @@ -14,6 +13,7 @@ import { type Enum, type EnumField, type FunctionDecl, + type Model, type Reference, type TypeDef, } from '@zenstackhq/language/ast'; diff --git a/packages/testtools/src/client.ts b/packages/testtools/src/client.ts index 4a23280e..27133f83 100644 --- a/packages/testtools/src/client.ts +++ b/packages/testtools/src/client.ts @@ -1,5 +1,4 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { loadDocument } from '@zenstackhq/language'; import type { Model } from '@zenstackhq/language/ast'; import { PolicyPlugin } from '@zenstackhq/plugin-policy'; import { ZenStackClient, type ClientContract, type ClientOptions } from '@zenstackhq/runtime'; @@ -15,6 +14,7 @@ import { Client as PGClient, Pool } from 'pg'; import { expect } from 'vitest'; import { createTestProject } from './project'; import { generateTsSchema } from './schema'; +import { loadDocumentWithPlugins } from './utils'; export function getTestDbProvider() { const val = process.env['TEST_DB_PROVIDER'] ?? 'sqlite'; @@ -116,7 +116,7 @@ export async function createTestClient( if (options?.usePrismaPush) { invariant(typeof schema === 'string' || schemaFile, 'a schema file must be provided when using prisma db push'); if (!model) { - const r = await loadDocument(path.join(workDir, 'schema.zmodel')); + const r = await loadDocumentWithPlugins(path.join(workDir, 'schema.zmodel')); if (!r.success) { throw new Error(r.errors.join('\n')); } diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 516d445a..b78bcbc8 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -1,5 +1,4 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { loadDocument } from '@zenstackhq/language'; import { TsSchemaGenerator } from '@zenstackhq/sdk'; import type { SchemaDef } from '@zenstackhq/sdk/schema'; import { execSync } from 'node:child_process'; @@ -10,6 +9,7 @@ import path from 'node:path'; import { match } from 'ts-pattern'; import { expect } from 'vitest'; import { createTestProject } from './project'; +import { loadDocumentWithPlugins } from './utils'; function makePrelude(provider: 'sqlite' | 'postgresql', dbUrl?: string) { return match(provider) @@ -44,7 +44,7 @@ export async function generateTsSchema( const noPrelude = schemaText.includes('datasource '); fs.writeFileSync(zmodelPath, `${noPrelude ? '' : makePrelude(provider, dbUrl)}\n\n${schemaText}`); - const result = await loadDocument(zmodelPath); + const result = await loadDocumentWithPlugins(zmodelPath); if (!result.success) { throw new Error(`Failed to load schema from ${zmodelPath}: ${result.errors}`); } @@ -82,7 +82,7 @@ export function generateTsSchemaFromFile(filePath: string) { export async function generateTsSchemaInPlace(schemaPath: string) { const workDir = path.dirname(schemaPath); - const result = await loadDocument(schemaPath); + const result = await loadDocumentWithPlugins(schemaPath); if (!result.success) { throw new Error(`Failed to load schema from ${schemaPath}: ${result.errors}`); } @@ -114,7 +114,7 @@ export async function loadSchema(schema: string, additionalSchemas?: Record r.success, `Failed to load schema: ${(r as any).errors?.map((e: any) => e.toString()).join(', ')}`, @@ -131,7 +131,7 @@ export async function loadSchemaWithError(schema: string, error: string | RegExp // create a temp file const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`); fs.writeFileSync(tempFile, schema); - const r = await loadDocument(tempFile); + const r = await loadDocumentWithPlugins(tempFile); expect(r.success).toBe(false); invariant(!r.success); if (typeof error === 'string') { diff --git a/packages/testtools/src/utils.ts b/packages/testtools/src/utils.ts new file mode 100644 index 00000000..75f48ee8 --- /dev/null +++ b/packages/testtools/src/utils.ts @@ -0,0 +1,7 @@ +import { loadDocument } from '@zenstackhq/language'; +import path from 'node:path'; + +export function loadDocumentWithPlugins(filePath: string) { + const pluginModelFiles = [path.resolve(__dirname, '../node_modules/@zenstackhq/plugin-policy/plugin.zmodel')]; + return loadDocument(filePath, pluginModelFiles); +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index eff0e66f..aa32c4a2 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -282,9 +282,6 @@ importers: '@zenstackhq/runtime': specifier: workspace:* version: link:../../runtime - '@zenstackhq/sdk': - specifier: workspace:* - version: link:../../sdk kysely: specifier: 'catalog:' version: 0.27.6 @@ -510,6 +507,9 @@ importers: samples/blog: dependencies: + '@zenstackhq/plugin-policy': + specifier: workspace:* + version: link:../../packages/plugins/policy '@zenstackhq/runtime': specifier: workspace:* version: link:../../packages/runtime diff --git a/samples/blog/main.ts b/samples/blog/main.ts index 8bbfb5bf..d0c82fc0 100644 --- a/samples/blog/main.ts +++ b/samples/blog/main.ts @@ -1,3 +1,4 @@ +import { PolicyPlugin } from '@zenstackhq/plugin-policy'; import { ZenStackClient } from '@zenstackhq/runtime'; import SQLite from 'better-sqlite3'; import { sql, SqliteDialect } from 'kysely'; @@ -89,6 +90,17 @@ async function main() { }, }); console.log('User found with computed field:', userWithMorePosts); + + // policy-enabled read + const authDb = db.$use(new PolicyPlugin()); + const user1Db = authDb.$setAuth({ id: user1.id }); + const user2Db = authDb.$setAuth({ id: user2.id }); + + console.log('Posts readable to', user1.email); + console.table(await user1Db.post.findMany({ select: { title: true, published: true } })); + + console.log('Posts readable to', user2.email); + console.table(await user2Db.post.findMany({ select: { title: true, published: true } })); } main(); diff --git a/samples/blog/package.json b/samples/blog/package.json index 9298036c..8e56eacb 100644 --- a/samples/blog/package.json +++ b/samples/blog/package.json @@ -15,6 +15,7 @@ "license": "MIT", "dependencies": { "@zenstackhq/runtime": "workspace:*", + "@zenstackhq/plugin-policy": "workspace:*", "better-sqlite3": "^12.2.0", "kysely": "catalog:" }, diff --git a/samples/blog/zenstack/models.ts b/samples/blog/zenstack/models.ts index 86b941a1..608e9397 100644 --- a/samples/blog/zenstack/models.ts +++ b/samples/blog/zenstack/models.ts @@ -9,8 +9,6 @@ import { schema as $schema, type SchemaType as $Schema } from "./schema"; import { type ModelResult as $ModelResult, type TypeDefResult as $TypeDefResult } from "@zenstackhq/runtime"; /** * User model - * - * Represents a user of the blog. */ export type User = $ModelResult<$Schema, "User">; /** diff --git a/samples/blog/zenstack/schema.ts b/samples/blog/zenstack/schema.ts index 4ca14e3e..fa51f40f 100644 --- a/samples/blog/zenstack/schema.ts +++ b/samples/blog/zenstack/schema.ts @@ -69,6 +69,10 @@ export const schema = { relation: { opposite: "user" } } }, + attributes: [ + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("read,create") }, { name: "condition", value: ExpressionUtils.literal(true) }] }, + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("all") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils._this(), "==", ExpressionUtils.call("auth")) }] } + ], idFields: ["id"], uniqueFields: { id: { type: "String" }, @@ -132,6 +136,9 @@ export const schema = { ] } }, + attributes: [ + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("all") }, { name: "condition", value: ExpressionUtils.call("check", [ExpressionUtils.field("user")]) }] } + ], idFields: ["id"], uniqueFields: { id: { type: "String" }, @@ -188,6 +195,10 @@ export const schema = { ] } }, + attributes: [ + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("read") }, { name: "condition", value: ExpressionUtils.field("published") }] }, + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("all") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.field("author"), "==", ExpressionUtils.call("auth")) }] } + ], idFields: ["id"], uniqueFields: { id: { type: "String" } diff --git a/samples/blog/zenstack/schema.zmodel b/samples/blog/zenstack/schema.zmodel index aeccb56f..708afb9e 100644 --- a/samples/blog/zenstack/schema.zmodel +++ b/samples/blog/zenstack/schema.zmodel @@ -9,6 +9,10 @@ enum Role { USER } +plugin policy { + provider = '../node_modules/@zenstackhq/plugin-policy/dist/index.js' +} + type CommonFields { id String @id @default(cuid()) createdAt DateTime @default(now()) @@ -16,8 +20,6 @@ type CommonFields { } /// User model -/// -/// Represents a user of the blog. model User with CommonFields { email String @unique name String? @@ -25,6 +27,9 @@ model User with CommonFields { role Role @default(USER) posts Post[] profile Profile? + + @@allow('read,create', true) + @@allow('all', this == auth()) } /// Profile model @@ -33,6 +38,7 @@ model Profile with CommonFields { age Int? user User? @relation(fields: [userId], references: [id]) userId String? @unique + @@allow('all', check(user)) } /// Post model @@ -42,4 +48,7 @@ model Post with CommonFields { published Boolean @default(false) author User @relation(fields: [authorId], references: [id]) authorId String + + @@allow('read', published) + @@allow('all', author == auth()) } diff --git a/tests/e2e/orm/policy/migrated/current-model.test.ts b/tests/e2e/orm/policy/migrated/current-model.test.ts index 04c9a30a..0ceb96c3 100644 --- a/tests/e2e/orm/policy/migrated/current-model.test.ts +++ b/tests/e2e/orm/policy/migrated/current-model.test.ts @@ -158,7 +158,7 @@ describe('currentModel tests', () => { createPolicyTestClient( ` model User { - id String @default(currentModel()) + id String @id @default(currentModel()) } `, ), diff --git a/tests/e2e/orm/policy/migrated/current-operation.test.ts b/tests/e2e/orm/policy/migrated/current-operation.test.ts index 2b1610c9..3cbae4ca 100644 --- a/tests/e2e/orm/policy/migrated/current-operation.test.ts +++ b/tests/e2e/orm/policy/migrated/current-operation.test.ts @@ -127,7 +127,7 @@ describe('currentOperation tests', () => { createPolicyTestClient( ` model User { - id String @default(currentOperation()) + id String @id @default(currentOperation()) } `, ), diff --git a/tests/e2e/orm/scripts/generate.ts b/tests/e2e/orm/scripts/generate.ts index b4f78958..9d59db73 100644 --- a/tests/e2e/orm/scripts/generate.ts +++ b/tests/e2e/orm/scripts/generate.ts @@ -1,4 +1,5 @@ import { loadDocument } from '@zenstackhq/language'; +import type { Model } from '@zenstackhq/language/ast'; import { TsSchemaGenerator } from '@zenstackhq/sdk'; import { glob } from 'glob'; import path from 'node:path'; @@ -17,11 +18,18 @@ async function main() { async function generate(schemaPath: string) { const generator = new TsSchemaGenerator(); const outputDir = path.dirname(schemaPath); - const result = await loadDocument(schemaPath); + + // isomorphic __dirname + const _dirname = typeof __dirname !== 'undefined' ? __dirname : path.dirname(fileURLToPath(import.meta.url)); + + // plugin models + const pluginDocs = [path.resolve(_dirname, '../../node_modules/@zenstackhq/plugin-policy/plugin.zmodel')]; + + const result = await loadDocument(schemaPath, pluginDocs); if (!result.success) { throw new Error(`Failed to load schema from ${schemaPath}: ${result.errors}`); } - await generator.generate(result.model, outputDir); + await generator.generate(result.model as Model, outputDir); } main();