diff --git a/.vscode/launch.json b/.vscode/launch.json index 886089d1..09ccbd59 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,7 +11,7 @@ "skipFiles": ["/**"], "type": "node", "args": ["generate"], - "cwd": "${workspaceFolder}/samples/blog/zenstack" + "cwd": "${workspaceFolder}/samples/blog" }, { "name": "Debug with TSX", diff --git a/package.json b/package.json index fc9d2920..d46cef02 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-v3", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "description": "ZenStack", "packageManager": "pnpm@10.12.1", "scripts": { diff --git a/packages/cli/package.json b/packages/cli/package.json index b21a97e4..c89bd768 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack CLI", "description": "FullStack database toolkit with built-in access control and automatic API generation.", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "type": "module", "author": { "name": "ZenStack Team" diff --git a/packages/cli/src/actions/generate.ts b/packages/cli/src/actions/generate.ts index e80c219d..2db079e1 100644 --- a/packages/cli/src/actions/generate.ts +++ b/packages/cli/src/actions/generate.ts @@ -64,7 +64,7 @@ async function runPlugins(schemaFile: string, model: Model, outputPath: string, for (const plugin of plugins) { const provider = getPluginProvider(plugin); - let cliPlugin: CliPlugin; + let cliPlugin: CliPlugin | undefined; if (provider.startsWith('@core/')) { cliPlugin = (corePlugins as any)[provider.slice('@core/'.length)]; if (!cliPlugin) { @@ -78,12 +78,14 @@ async function runPlugins(schemaFile: string, model: Model, outputPath: string, } try { cliPlugin = (await import(moduleSpec)).default as CliPlugin; - } catch (error) { - throw new CliError(`Failed to load plugin ${provider}: ${error}`); + } catch { + // plugin may not export a generator so we simply ignore the error here } } - processedPlugins.push({ cliPlugin, pluginOptions: getPluginOptions(plugin) }); + if (cliPlugin) { + processedPlugins.push({ cliPlugin, pluginOptions: getPluginOptions(plugin) }); + } } const defaultPlugins = [corePlugins['typescript']].reverse(); diff --git a/packages/cli/src/actions/migrate.ts b/packages/cli/src/actions/migrate.ts index 896cc991..19f94ce7 100644 --- a/packages/cli/src/actions/migrate.ts +++ b/packages/cli/src/actions/migrate.ts @@ -82,9 +82,12 @@ async function runDev(prismaSchemaFile: string, options: DevOptions) { async function runReset(prismaSchemaFile: string, options: ResetOptions) { try { - const cmd = ['prisma migrate reset', ` --schema "${prismaSchemaFile}"`, options.force ? ' --force' : ''].join( - '', - ); + const cmd = [ + 'prisma migrate reset', + ` --schema "${prismaSchemaFile}"`, + ' --skip-generate', + options.force ? ' --force' : '' + ].join(''); await execPackage(cmd); } catch (err) { diff --git a/packages/cli/src/constants.ts b/packages/cli/src/constants.ts index 586537af..72463404 100644 --- a/packages/cli/src/constants.ts +++ b/packages/cli/src/constants.ts @@ -1,2 +1,5 @@ // replaced at build time export const TELEMETRY_TRACKING_TOKEN = ''; + +// plugin-contributed model file name +export const PLUGIN_MODULE_NAME = 'plugin.zmodel'; diff --git a/packages/common-helpers/package.json b/packages/common-helpers/package.json index 958daf1b..d5ebe7fc 100644 --- a/packages/common-helpers/package.json +++ b/packages/common-helpers/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/common-helpers", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "description": "ZenStack Common Helpers", "type": "module", "scripts": { diff --git a/packages/config/eslint-config/package.json b/packages/config/eslint-config/package.json index c1421f8d..64b951cf 100644 --- a/packages/config/eslint-config/package.json +++ b/packages/config/eslint-config/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/eslint-config", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "type": "module", "private": true, "license": "MIT" diff --git a/packages/config/typescript-config/package.json b/packages/config/typescript-config/package.json index 53638e67..ee134446 100644 --- a/packages/config/typescript-config/package.json +++ b/packages/config/typescript-config/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/typescript-config", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "private": true, "license": "MIT" } diff --git a/packages/config/vitest-config/package.json b/packages/config/vitest-config/package.json index 35f3a279..e7686f38 100644 --- a/packages/config/vitest-config/package.json +++ b/packages/config/vitest-config/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/vitest-config", "type": "module", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "private": true, "license": "MIT", "exports": { diff --git a/packages/create-zenstack/package.json b/packages/create-zenstack/package.json index 4288b173..924924fa 100644 --- a/packages/create-zenstack/package.json +++ b/packages/create-zenstack/package.json @@ -1,6 +1,6 @@ { "name": "create-zenstack", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "description": "Create a new ZenStack project", "type": "module", "scripts": { diff --git a/packages/dialects/sql.js/package.json b/packages/dialects/sql.js/package.json index bb1ab3c1..acb53278 100644 --- a/packages/dialects/sql.js/package.json +++ b/packages/dialects/sql.js/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/kysely-sql-js", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "description": "Kysely dialect for sql.js", "type": "module", "scripts": { diff --git a/packages/ide/vscode/package.json b/packages/ide/vscode/package.json index ac3667af..44f22fcc 100644 --- a/packages/ide/vscode/package.json +++ b/packages/ide/vscode/package.json @@ -11,6 +11,7 @@ }, "scripts": { "build": "tsc --noEmit && tsup", + "watch": "tsup --watch", "lint": "eslint src --ext ts", "vscode:publish": "pnpm build && vsce publish --no-dependencies --follow-symlinks", "vscode:package": "pnpm build && vsce package --no-dependencies --follow-symlinks" diff --git a/packages/ide/vscode/src/language-server/main.ts b/packages/ide/vscode/src/language-server/main.ts index b9ac6998..efa21569 100644 --- a/packages/ide/vscode/src/language-server/main.ts +++ b/packages/ide/vscode/src/language-server/main.ts @@ -7,10 +7,13 @@ import { createConnection, ProposedFeatures } from 'vscode-languageserver/node.j const connection = createConnection(ProposedFeatures.all); // Inject the shared services and language-specific services -const { shared } = createZModelLanguageServices({ - connection, - ...NodeFileSystem, -}); +const { shared } = createZModelLanguageServices( + { + connection, + ...NodeFileSystem, + }, + true, +); // Start the language server with the shared services startLanguageServer(shared); diff --git a/packages/language/package.json b/packages/language/package.json index c6b487f2..b20e9815 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/language", "description": "ZenStack ZModel language specification", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "license": "MIT", "author": "ZenStack Team", "files": [ diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 7ac57ba3..f1b46e84 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -174,29 +174,6 @@ function hasSome(field: Any[], search: Any[]): Boolean { function isEmpty(field: Any[]): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) -/** - * The name of the model for which the policy rule is defined. If the rule is - * inherited to a sub model, this function returns the name of the sub model. - * - * @param optional parameter to control the casing of the returned value. Valid - * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults - * to "original". - */ -function currentModel(casing: String?): String { -} @@@expressionContext([AccessPolicy]) - -/** - * The operation for which the policy rule is defined for. Note that a rule with - * "all" operation is expanded to "create", "read", "update", and "delete" rules, - * and the function returns corresponding value for each expanded version. - * - * @param optional parameter to control the casing of the returned value. Valid - * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults - * to "original". - */ -function currentOperation(casing: String?): String { -} @@@expressionContext([AccessPolicy]) - /** * Marks an attribute to be only applicable to certain field types. */ @@ -658,56 +635,3 @@ attribute @meta(_ name: String, _ value: Any) * Marks an attribute as deprecated. */ attribute @@@deprecated(_ message: String) - -/* --- Policy Plugin --- */ - -/** - * Defines an access policy that allows a set of operations when the given condition is true. - * - * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. - * @param condition: a boolean expression that controls if the operation should be allowed. - */ -attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) - -/** - * Defines an access policy that allows the annotated field to be read or updated. - * You can pass a third argument as `true` to make it override the model-level policies. - * - * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. - * @param condition: a boolean expression that controls if the operation should be allowed. - * @param override: a boolean value that controls if the field-level policy should override the model-level policy. - */ -// attribute @allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean, _ override: Boolean?) - -/** - * Defines an access policy that denies a set of operations when the given condition is true. - * - * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. - * @param condition: a boolean expression that controls if the operation should be denied. - */ -attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) - -/** - * Defines an access policy that denies the annotated field to be read or updated. - * - * @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations. - * @param condition: a boolean expression that controls if the operation should be denied. - */ -// attribute @deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean) - -/** - * Checks if the current user can perform the given operation on the given field. - * - * @param field: The field to check access for - * @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided, - * it defaults the operation of the containing policy rule. - */ -function check(field: Any, operation: String?): Boolean { -} @@@expressionContext([AccessPolicy]) - -/** - * Gets entity's value before an update. Only valid when used in a "post-update" policy rule. - */ -function before(): Any { -} @@@expressionContext([AccessPolicy]) - diff --git a/packages/language/src/document.ts b/packages/language/src/document.ts new file mode 100644 index 00000000..b8405c48 --- /dev/null +++ b/packages/language/src/document.ts @@ -0,0 +1,202 @@ +import { isAstNode, URI, type AstNode, type LangiumDocument, type LangiumDocuments, type Mutable } from 'langium'; +import fs from 'node:fs'; +import path from 'node:path'; +import { fileURLToPath } from 'node:url'; +import { isDataSource, type Model } from './ast'; +import { STD_LIB_MODULE_NAME } from './constants'; +import { createZModelServices } from './module'; +import { getDataModelAndTypeDefs, getDocument, hasAttribute, resolveImport, resolveTransitiveImports } from './utils'; + +/** + * Loads ZModel document from the given file name. Include the additional document + * files if given. + */ +export async function loadDocument( + fileName: string, + additionalModelFiles: string[] = [], +): Promise< + { success: true; model: Model; warnings: string[] } | { success: false; errors: string[]; warnings: string[] } +> { + const { ZModelLanguage: services } = createZModelServices(false); + const extensions = services.LanguageMetaData.fileExtensions; + if (!extensions.includes(path.extname(fileName))) { + return { + success: false, + errors: ['invalid schema file extension'], + warnings: [], + }; + } + + if (!fs.existsSync(fileName)) { + return { + success: false, + errors: ['schema file does not exist'], + warnings: [], + }; + } + + // load standard library + + // isomorphic __dirname + const _dirname = typeof __dirname !== 'undefined' ? __dirname : path.dirname(fileURLToPath(import.meta.url)); + const stdLib = await services.shared.workspace.LangiumDocuments.getOrCreateDocument( + URI.file(path.resolve(path.join(_dirname, '../res', STD_LIB_MODULE_NAME))), + ); + + // load the document + const langiumDocuments = services.shared.workspace.LangiumDocuments; + const document = await langiumDocuments.getOrCreateDocument(URI.file(path.resolve(fileName))); + + // load imports + const importedURIs = await loadImports(document, langiumDocuments); + const importedDocuments: LangiumDocument[] = []; + for (const uri of importedURIs) { + importedDocuments.push(await langiumDocuments.getOrCreateDocument(uri)); + } + + // build the document together with standard library, additional modules, and imported documents + + // load additional model files + const additionalDocs = await Promise.all( + additionalModelFiles.map((file) => + services.shared.workspace.LangiumDocuments.getOrCreateDocument(URI.file(path.resolve(file))), + ), + ); + + await services.shared.workspace.DocumentBuilder.build([stdLib, ...additionalDocs, document, ...importedDocuments], { + validation: { + stopAfterLexingErrors: true, + stopAfterParsingErrors: true, + stopAfterLinkingErrors: true, + }, + }); + + const diagnostics = langiumDocuments.all + .flatMap((doc) => (doc.diagnostics ?? []).map((diag) => ({ doc, diag }))) + .filter(({ diag }) => diag.severity === 1 || diag.severity === 2) + .toArray(); + + const errors: string[] = []; + const warnings: string[] = []; + + if (diagnostics.length > 0) { + for (const { doc, diag } of diagnostics) { + const message = `${path.relative(process.cwd(), doc.uri.fsPath)}:${ + diag.range.start.line + 1 + }:${diag.range.start.character + 1} - ${diag.message}`; + + if (diag.severity === 1) { + errors.push(message); + } else { + warnings.push(message); + } + } + } + + if (errors.length > 0) { + return { + success: false, + errors, + warnings, + }; + } + + const model = document.parseResult.value as Model; + + // merge all declarations into the main document + const imported = mergeImportsDeclarations(langiumDocuments, model); + + // remove imported documents + imported.forEach((model) => { + langiumDocuments.deleteDocument(model.$document!.uri); + services.shared.workspace.IndexManager.remove(model.$document!.uri); + }); + + // extra validation after merging imported declarations + const additionalErrors = validationAfterImportMerge(model); + if (additionalErrors.length > 0) { + return { + success: false, + errors: additionalErrors, + warnings, + }; + } + + return { + success: true, + model: document.parseResult.value as Model, + warnings, + }; +} + +async function loadImports(document: LangiumDocument, documents: LangiumDocuments, uris: Set = new Set()) { + const uriString = document.uri.toString(); + if (!uris.has(uriString)) { + uris.add(uriString); + const model = document.parseResult.value as Model; + for (const imp of model.imports) { + const importedModel = resolveImport(documents, imp); + if (importedModel) { + const importedDoc = getDocument(importedModel); + await loadImports(importedDoc, documents, uris); + } + } + } + return Array.from(uris) + .filter((x) => uriString != x) + .map((e) => URI.parse(e)); +} + +function mergeImportsDeclarations(documents: LangiumDocuments, model: Model) { + const importedModels = resolveTransitiveImports(documents, model); + + const importedDeclarations = importedModels.flatMap((m) => m.declarations); + model.declarations.push(...importedDeclarations); + + // remove import directives + model.imports = []; + + // fix $container, $containerIndex, and $containerProperty + linkContentToContainer(model); + + return importedModels; +} + +function linkContentToContainer(node: AstNode): void { + for (const [name, value] of Object.entries(node)) { + if (!name.startsWith('$')) { + if (Array.isArray(value)) { + value.forEach((item, index) => { + if (isAstNode(item)) { + (item as Mutable).$container = node; + (item as Mutable).$containerProperty = name; + (item as Mutable).$containerIndex = index; + } + }); + } else if (isAstNode(value)) { + (value as Mutable).$container = node; + (value as Mutable).$containerProperty = name; + } + } + } +} + +function validationAfterImportMerge(model: Model) { + const errors: string[] = []; + const dataSources = model.declarations.filter((d) => isDataSource(d)); + if (dataSources.length === 0) { + errors.push('Validation error: schema must have a datasource declaration'); + } else { + if (dataSources.length > 1) { + errors.push('Validation error: multiple datasource declarations are not allowed'); + } + } + + // at most one `@@auth` model + const decls = getDataModelAndTypeDefs(model, true); + const authDecls = decls.filter((d) => hasAttribute(d, '@@auth')); + if (authDecls.length > 1) { + errors.push('Validation error: Multiple `@@auth` declarations are not allowed'); + } + return errors; +} diff --git a/packages/language/src/index.ts b/packages/language/src/index.ts index ab577c7f..6edc0494 100644 --- a/packages/language/src/index.ts +++ b/packages/language/src/index.ts @@ -1,210 +1,2 @@ -import { isAstNode, URI, type LangiumDocument, type LangiumDocuments, type Mutable } from 'langium'; -import { NodeFileSystem } from 'langium/node'; -import fs from 'node:fs'; -import path from 'node:path'; -import { fileURLToPath } from 'node:url'; -import { isDataSource, type AstNode, type Model } from './ast'; -import { STD_LIB_MODULE_NAME } from './constants'; -import { createZModelLanguageServices } from './module'; -import { getDataModelAndTypeDefs, getDocument, hasAttribute, resolveImport, resolveTransitiveImports } from './utils'; - -export function createZModelServices() { - return createZModelLanguageServices(NodeFileSystem); -} - -export class DocumentLoadError extends Error { - constructor(message: string) { - super(message); - } -} - -export async function loadDocument( - fileName: string, - additionalModelFiles: string[] = [], -): Promise< - { success: true; model: Model; warnings: string[] } | { success: false; errors: string[]; warnings: string[] } -> { - const { ZModelLanguage: services } = createZModelServices(); - const extensions = services.LanguageMetaData.fileExtensions; - if (!extensions.includes(path.extname(fileName))) { - return { - success: false, - errors: ['invalid schema file extension'], - warnings: [], - }; - } - - if (!fs.existsSync(fileName)) { - return { - success: false, - errors: ['schema file does not exist'], - warnings: [], - }; - } - - // load standard library - - // isomorphic __dirname - const _dirname = typeof __dirname !== 'undefined' ? __dirname : path.dirname(fileURLToPath(import.meta.url)); - const stdLib = await services.shared.workspace.LangiumDocuments.getOrCreateDocument( - URI.file(path.resolve(path.join(_dirname, '../res', STD_LIB_MODULE_NAME))), - ); - - // load additional model files - const pluginDocs = await Promise.all( - additionalModelFiles.map((file) => - services.shared.workspace.LangiumDocuments.getOrCreateDocument(URI.file(path.resolve(file))), - ), - ); - - // load the document - const langiumDocuments = services.shared.workspace.LangiumDocuments; - const document = await langiumDocuments.getOrCreateDocument(URI.file(path.resolve(fileName))); - - // load imports - const importedURIs = await loadImports(document, langiumDocuments); - const importedDocuments: LangiumDocument[] = []; - for (const uri of importedURIs) { - importedDocuments.push(await langiumDocuments.getOrCreateDocument(uri)); - } - - // build the document together with standard library, plugin modules, and imported documents - await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document, ...importedDocuments], { - validation: { - stopAfterLexingErrors: true, - stopAfterParsingErrors: true, - stopAfterLinkingErrors: true, - }, - }); - - const diagnostics = langiumDocuments.all - .flatMap((doc) => (doc.diagnostics ?? []).map((diag) => ({ doc, diag }))) - .filter(({ diag }) => diag.severity === 1 || diag.severity === 2) - .toArray(); - - const errors: string[] = []; - const warnings: string[] = []; - - if (diagnostics.length > 0) { - for (const { doc, diag } of diagnostics) { - const message = `${path.relative(process.cwd(), doc.uri.fsPath)}:${ - diag.range.start.line + 1 - }:${diag.range.start.character + 1} - ${diag.message}`; - - if (diag.severity === 1) { - errors.push(message); - } else { - warnings.push(message); - } - } - } - - if (errors.length > 0) { - return { - success: false, - errors, - warnings, - }; - } - - const model = document.parseResult.value as Model; - - // merge all declarations into the main document - const imported = mergeImportsDeclarations(langiumDocuments, model); - - // remove imported documents - imported.forEach((model) => { - langiumDocuments.deleteDocument(model.$document!.uri); - services.shared.workspace.IndexManager.remove(model.$document!.uri); - }); - - // extra validation after merging imported declarations - const additionalErrors = validationAfterImportMerge(model); - if (additionalErrors.length > 0) { - return { - success: false, - errors: additionalErrors, - warnings, - }; - } - - return { - success: true, - model: document.parseResult.value as Model, - warnings, - }; -} - -async function loadImports(document: LangiumDocument, documents: LangiumDocuments, uris: Set = new Set()) { - const uriString = document.uri.toString(); - if (!uris.has(uriString)) { - uris.add(uriString); - const model = document.parseResult.value as Model; - for (const imp of model.imports) { - const importedModel = resolveImport(documents, imp); - if (importedModel) { - const importedDoc = getDocument(importedModel); - await loadImports(importedDoc, documents, uris); - } - } - } - return Array.from(uris) - .filter((x) => uriString != x) - .map((e) => URI.parse(e)); -} - -function mergeImportsDeclarations(documents: LangiumDocuments, model: Model) { - const importedModels = resolveTransitiveImports(documents, model); - - const importedDeclarations = importedModels.flatMap((m) => m.declarations); - model.declarations.push(...importedDeclarations); - - // remove import directives - model.imports = []; - - // fix $container, $containerIndex, and $containerProperty - linkContentToContainer(model); - - return importedModels; -} - -function linkContentToContainer(node: AstNode): void { - for (const [name, value] of Object.entries(node)) { - if (!name.startsWith('$')) { - if (Array.isArray(value)) { - value.forEach((item, index) => { - if (isAstNode(item)) { - (item as Mutable).$container = node; - (item as Mutable).$containerProperty = name; - (item as Mutable).$containerIndex = index; - } - }); - } else if (isAstNode(value)) { - (value as Mutable).$container = node; - (value as Mutable).$containerProperty = name; - } - } - } -} - -function validationAfterImportMerge(model: Model) { - const errors: string[] = []; - const dataSources = model.declarations.filter((d) => isDataSource(d)); - if (dataSources.length === 0) { - errors.push('Validation error: schema must have a datasource declaration'); - } else { - if (dataSources.length > 1) { - errors.push('Validation error: multiple datasource declarations are not allowed'); - } - } - - // at most one `@@auth` model - const decls = getDataModelAndTypeDefs(model, true); - const authDecls = decls.filter((d) => hasAttribute(d, '@@auth')); - if (authDecls.length > 1) { - errors.push('Validation error: Multiple `@@auth` declarations are not allowed'); - } - return errors; -} - +export { loadDocument } from './document'; export * from './module'; diff --git a/packages/language/src/module.ts b/packages/language/src/module.ts index 1b853945..cff4ac0e 100644 --- a/packages/language/src/module.ts +++ b/packages/language/src/module.ts @@ -1,4 +1,4 @@ -import { inject, type DeepPartial, type Module } from 'langium'; +import { DocumentState, inject, URI, type DeepPartial, type Module } from 'langium'; import { createDefaultModule, createDefaultSharedModule, @@ -7,8 +7,13 @@ import { type LangiumSharedServices, type PartialLangiumServices, } from 'langium/lsp'; +import { NodeFileSystem } from 'langium/node'; +import path from 'node:path'; +import { fileURLToPath } from 'node:url'; +import type { Model } from './ast'; import { ZModelGeneratedModule, ZModelGeneratedSharedModule, ZModelLanguageMetaData } from './generated/module'; -import { ZModelValidator, registerValidationChecks } from './validator'; +import { getPluginDocuments } from './utils'; +import { registerValidationChecks, ZModelValidator } from './validator'; import { ZModelDocumentBuilder } from './zmodel-document-builder'; import { ZModelLinker } from './zmodel-linker'; import { ZModelScopeComputation, ZModelScopeProvider } from './zmodel-scope'; @@ -70,7 +75,10 @@ export const ZModelSharedModule: Module { + for (const doc of documents) { + if (doc.parseResult.lexerErrors.length > 0 || doc.parseResult.parserErrors.length > 0) { + // balk if there are lexer or parser errors + continue; + } + + const schemaPath = fileURLToPath(doc.uri.toString()); + const pluginSchemas = getPluginDocuments(doc.parseResult.value as Model, schemaPath); + for (const plugin of pluginSchemas) { + // load the plugin model document + const pluginDoc = await shared.workspace.LangiumDocuments.getOrCreateDocument( + URI.file(path.resolve(plugin)), + ); + // add to indexer so the plugin model's definitions are globally visible + shared.workspace.IndexManager.updateContent(pluginDoc); + if (logToConsole) { + console.log(`Loaded plugin model: ${plugin}`); + } + } + } + }); + return { shared, ZModelLanguage }; } + +// TODO: proper logging system +export function createZModelServices(logToConsole = false) { + return createZModelLanguageServices(NodeFileSystem, logToConsole); +} diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index 762187b3..c361feee 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -1,10 +1,10 @@ import { AstUtils, URI, type AstNode, type LangiumDocument, type LangiumDocuments, type Reference } from 'langium'; import fs from 'node:fs'; -import path from 'path'; -import { STD_LIB_MODULE_NAME, type ExpressionContext } from './constants'; +import { createRequire } from 'node:module'; +import path from 'node:path'; +import { fileURLToPath, pathToFileURL } from 'node:url'; +import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME, type ExpressionContext } from './constants'; import { - BinaryExpr, - ConfigExpr, isArrayExpr, isBinaryExpr, isConfigArrayExpr, @@ -17,15 +17,15 @@ import { isMemberAccessExpr, isModel, isObjectExpr, + isPlugin, isReferenceExpr, isStringLiteral, isTypeDef, - Model, - ModelImport, - ReferenceExpr, type Attribute, type AttributeParam, + type BinaryExpr, type BuiltinType, + type ConfigExpr, type DataField, type DataFieldAttribute, type DataModel, @@ -35,6 +35,9 @@ import { type Expression, type ExpressionType, type FunctionDecl, + type Model, + type ModelImport, + type ReferenceExpr, type TypeDef, } from './generated/ast'; @@ -447,8 +450,9 @@ export function getAuthDecl(decls: (DataModel | TypeDef)[]) { return authModel; } +// TODO: move to policy plugin export function isBeforeInvocation(node: AstNode) { - return isInvocationExpr(node) && node.function.ref?.name === 'before' && isFromStdlib(node.function.ref); + return isInvocationExpr(node) && node.function.ref?.name === 'before'; } export function isCollectionPredicate(node: AstNode): node is BinaryExpr { @@ -572,6 +576,91 @@ export function getDocument(node: AstNode): Langium return result as LangiumDocument; } +export function getPluginDocuments(model: Model, schemaPath: string): string[] { + // traverse plugins and collect "plugin.zmodel" documents + const result: string[] = []; + for (const decl of model.declarations.filter(isPlugin)) { + const providerField = decl.fields.find((f) => f.name === 'provider'); + if (!providerField) { + continue; + } + + const provider = getLiteral(providerField.value); + if (!provider) { + continue; + } + + let pluginModelFile: string | undefined; + + // first try to treat provider as a path + let providerPath = path.resolve(path.dirname(schemaPath), provider); + if (fs.existsSync(providerPath)) { + if (fs.statSync(providerPath).isDirectory()) { + providerPath = path.join(providerPath, 'index.js'); + } + + // try plugin.zmodel next to the provider file + pluginModelFile = path.resolve(path.dirname(providerPath), PLUGIN_MODULE_NAME); + if (!fs.existsSync(pluginModelFile)) { + // try to find upwards + pluginModelFile = findUp([PLUGIN_MODULE_NAME], path.dirname(providerPath)); + } + } + + if (!pluginModelFile) { + if (typeof import.meta.resolve === 'function') { + try { + // try loading as a ESM module + const resolvedUrl = import.meta.resolve(`${provider}/${PLUGIN_MODULE_NAME}`); + pluginModelFile = fileURLToPath(resolvedUrl); + } catch { + // noop + } + } + } + + if (!pluginModelFile) { + // try loading as a CJS module + try { + const require = createRequire(pathToFileURL(schemaPath)); + pluginModelFile = require.resolve(`${provider}/${PLUGIN_MODULE_NAME}`); + } catch { + // noop + } + } + + if (pluginModelFile && fs.existsSync(pluginModelFile)) { + result.push(pluginModelFile); + } + } + return result; +} + +type FindUpResult = Multiple extends true ? string[] | undefined : string | undefined; + +function findUp( + names: string[], + cwd: string = process.cwd(), + multiple: Multiple = false as Multiple, + result: string[] = [], +): FindUpResult { + if (!names.some((name) => !!name)) { + return undefined; + } + const target = names.find((name) => fs.existsSync(path.join(cwd, name))); + if (multiple === false && target) { + return path.join(cwd, target) as FindUpResult; + } + if (target) { + result.push(path.join(cwd, target)); + } + const up = path.resolve(cwd, '..'); + if (up === cwd) { + return (multiple && result.length > 0 ? result : undefined) as FindUpResult; + } + return findUp(names, up, multiple, result); +} + /** * Returns the root node of the given AST node by following the `$container` references. */ diff --git a/packages/language/src/validators/function-invocation-validator.ts b/packages/language/src/validators/function-invocation-validator.ts index a740b86e..9e75a18a 100644 --- a/packages/language/src/validators/function-invocation-validator.ts +++ b/packages/language/src/validators/function-invocation-validator.ts @@ -20,7 +20,6 @@ import { getLiteral, isCheckInvocation, isDataFieldReference, - isFromStdlib, typeAssignable, } from '../utils'; import type { AstValidator } from './common'; @@ -52,43 +51,39 @@ export default class FunctionInvocationValidator implements AstValidator(expr.args[0]?.value); - if (arg && !allCasing.includes(arg)) { - accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, { - node: expr.args[0]!, - }); - } + const allCasing = ['original', 'upper', 'lower', 'capitalize', 'uncapitalize']; + if (['currentModel', 'currentOperation'].includes(funcDecl.name)) { + const arg = getLiteral(expr.args[0]?.value); + if (arg && !allCasing.includes(arg)) { + accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, { + node: expr.args[0]!, + }); } } diff --git a/packages/language/src/zmodel-workspace-manager.ts b/packages/language/src/zmodel-workspace-manager.ts index 7b21b56b..f21db797 100644 --- a/packages/language/src/zmodel-workspace-manager.ts +++ b/packages/language/src/zmodel-workspace-manager.ts @@ -1,7 +1,6 @@ import { DefaultWorkspaceManager, URI, - UriUtils, type AstNode, type LangiumDocument, type LangiumDocumentFactory, @@ -11,9 +10,7 @@ import type { LangiumSharedServices } from 'langium/lsp'; import fs from 'node:fs'; import path from 'node:path'; import { fileURLToPath } from 'node:url'; -import { isPlugin, type Model } from './ast'; -import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from './constants'; -import { getLiteral } from './utils'; +import { STD_LIB_MODULE_NAME } from './constants'; export class ZModelWorkspaceManager extends DefaultWorkspaceManager { private documentFactory: LangiumDocumentFactory; @@ -71,87 +68,5 @@ export class ZModelWorkspaceManager extends DefaultWorkspaceManager { const stdlib = await this.documentFactory.fromUri(URI.file(stdLibPath)); collector(stdlib); - - const documents = this.langiumDocuments.all; - const pluginModels = new Set(); - - // find plugin models - documents.forEach((doc) => { - const parsed = doc.parseResult.value as Model; - parsed.declarations.forEach((decl) => { - if (isPlugin(decl)) { - const providerField = decl.fields.find((f) => f.name === 'provider'); - if (providerField) { - const provider = getLiteral(providerField.value); - if (provider) { - pluginModels.add(provider); - } - } - } - }); - }); - - if (pluginModels.size > 0) { - console.log(`Used plugin modules: ${Array.from(pluginModels)}`); - - // the loaded plugin models would be removed from the set - const pendingPluginModules = new Set(pluginModels); - - await Promise.all( - folders - .map((wf) => [wf, this.getRootFolder(wf)] as [WorkspaceFolder, URI]) - .map(async (entry) => this.loadPluginModels(...entry, pendingPluginModules, collector)), - ); - } - } - - protected async loadPluginModels( - workspaceFolder: WorkspaceFolder, - folderPath: URI, - pendingPluginModels: Set, - collector: (document: LangiumDocument) => void, - ): Promise { - const content = (await this.fileSystemProvider.readDirectory(folderPath)).sort((a, b) => { - // make sure the node_modules folder is always the first one to be checked - // so we can exit early if the plugin is found - if (a.isDirectory && b.isDirectory) { - const aName = UriUtils.basename(a.uri); - if (aName === 'node_modules') { - return -1; - } else { - return 1; - } - } else { - return 0; - } - }); - - for (const entry of content) { - if (entry.isDirectory) { - const name = UriUtils.basename(entry.uri); - if (name === 'node_modules') { - for (const plugin of Array.from(pendingPluginModels)) { - const path = UriUtils.joinPath(entry.uri, plugin, PLUGIN_MODULE_NAME); - try { - await this.fileSystemProvider.readFile(path); - const document = await this.langiumDocuments.getOrCreateDocument(path); - collector(document); - console.log(`Adding plugin document from ${path.path}`); - - pendingPluginModels.delete(plugin); - // early exit if all plugins are loaded - if (pendingPluginModels.size === 0) { - return; - } - } catch { - // no-op. The module might be found in another node_modules folder - // will show the warning message eventually if not found - } - } - } else { - await this.loadPluginModels(workspaceFolder, entry.uri, pendingPluginModels, collector); - } - } - } } } diff --git a/packages/language/test/utils.ts b/packages/language/test/utils.ts index 4b60ce42..690b2490 100644 --- a/packages/language/test/utils.ts +++ b/packages/language/test/utils.ts @@ -5,11 +5,13 @@ import path from 'node:path'; import { expect } from 'vitest'; import { loadDocument } from '../src'; +const pluginDocs = [path.resolve(__dirname, '../../plugins/policy/plugin.zmodel')]; + export async function loadSchema(schema: string) { // create a temp file const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`); fs.writeFileSync(tempFile, schema); - const r = await loadDocument(tempFile); + const r = await loadDocument(tempFile, pluginDocs); expect(r).toSatisfy( (r) => r.success, `Failed to load schema: ${(r as any).errors?.map((e) => e.toString()).join(', ')}`, @@ -22,7 +24,8 @@ export async function loadSchemaWithError(schema: string, error: string | RegExp // create a temp file const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`); fs.writeFileSync(tempFile, schema); - const r = await loadDocument(tempFile); + + const r = await loadDocument(tempFile, pluginDocs); expect(r.success).toBe(false); invariant(!r.success); if (typeof error === 'string') { diff --git a/packages/plugins/policy/package.json b/packages/plugins/policy/package.json index 4c60ab7d..08ffed5c 100644 --- a/packages/plugins/policy/package.json +++ b/packages/plugins/policy/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/plugin-policy", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "description": "ZenStack Policy Plugin", "type": "module", "scripts": { @@ -13,7 +13,8 @@ "author": "ZenStack Team", "license": "MIT", "files": [ - "dist" + "dist", + "plugin.zmodel" ], "exports": { ".": { @@ -26,6 +27,10 @@ "default": "./dist/index.cjs" } }, + "./plugin.zmodel": { + "import": "./plugin.zmodel", + "require": "./plugin.zmodel" + }, "./package.json": { "import": "./package.json", "require": "./package.json" @@ -33,7 +38,6 @@ }, "dependencies": { "@zenstackhq/common-helpers": "workspace:*", - "@zenstackhq/sdk": "workspace:*", "@zenstackhq/runtime": "workspace:*", "ts-pattern": "catalog:" }, diff --git a/packages/plugins/policy/plugin.zmodel b/packages/plugins/policy/plugin.zmodel new file mode 100644 index 00000000..ae5e11cf --- /dev/null +++ b/packages/plugins/policy/plugin.zmodel @@ -0,0 +1,72 @@ +/** + * Defines an access policy that allows a set of operations when the given condition is true. + * + * @param operation: comma-separated list of "create", "read", "update", "post-update", "delete". Use "all" to denote all operations. + * @param condition: a boolean expression that controls if the operation should be allowed. + */ +attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) + +/** + * Defines an access policy that allows the annotated field to be read or updated. + * You can pass a third argument as `true` to make it override the model-level policies. + * + * @param operation: comma-separated list of "create", "read", "update", "post-update", "delete". Use "all" to denote all operations. + * @param condition: a boolean expression that controls if the operation should be allowed. + * @param override: a boolean value that controls if the field-level policy should override the model-level policy. + */ +// attribute @allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'", "'delete'", "'all'"]), _ condition: Boolean, _ override: Boolean?) + +/** + * Defines an access policy that denies a set of operations when the given condition is true. + * + * @param operation: comma-separated list of "create", "read", "update", "post-update", "delete". Use "all" to denote all operations. + * @param condition: a boolean expression that controls if the operation should be denied. + */ +attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean) + +/** + * Defines an access policy that denies the annotated field to be read or updated. + * + * @param operation: comma-separated list of "create", "read", "update", "post-update", "delete". Use "all" to denote all operations. + * @param condition: a boolean expression that controls if the operation should be denied. + */ +// attribute @deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean) + +/** + * Checks if the current user can perform the given operation on the given field. + * + * @param field: The field to check access for + * @param operation: The operation to check access for. Can be "read", "create", "update", "post-update", or "delete". If the operation is not provided, + * it defaults the operation of the containing policy rule. + */ +function check(field: Any, operation: String?): Boolean { +} @@@expressionContext([AccessPolicy]) + +/** + * Gets entity's value before an update. Only valid when used in a "post-update" policy rule. + */ +function before(): Any { +} @@@expressionContext([AccessPolicy]) + +/** + * The name of the model for which the policy rule is defined. If the rule is + * inherited to a sub model, this function returns the name of the sub model. + * + * @param optional parameter to control the casing of the returned value. Valid + * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults + * to "original". + */ +function currentModel(casing: String?): String { +} @@@expressionContext([AccessPolicy]) + +/** + * The operation for which the policy rule is defined for. Note that a rule with + * "all" operation is expanded to "create", "read", "update", and "delete" rules, + * and the function returns corresponding value for each expanded version. + * + * @param optional parameter to control the casing of the returned value. Valid + * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults + * to "original". + */ +function currentOperation(casing: String?): String { +} @@@expressionContext([AccessPolicy]) diff --git a/packages/plugins/policy/src/column-collector.ts b/packages/plugins/policy/src/column-collector.ts index 37d9df11..ed68c76d 100644 --- a/packages/plugins/policy/src/column-collector.ts +++ b/packages/plugins/policy/src/column-collector.ts @@ -1,10 +1,10 @@ +import { KyselyUtils } from '@zenstackhq/runtime'; import type { ColumnNode, OperationNode } from 'kysely'; -import { DefaultOperationNodeVisitor } from '@zenstackhq/sdk'; /** * Collects all column names from a query. */ -export class ColumnCollector extends DefaultOperationNodeVisitor { +export class ColumnCollector extends KyselyUtils.DefaultOperationNodeVisitor { private columns: string[] = []; collect(node: OperationNode) { diff --git a/packages/plugins/policy/src/policy-handler.ts b/packages/plugins/policy/src/policy-handler.ts index 92f5e74c..c90bb563 100644 --- a/packages/plugins/policy/src/policy-handler.ts +++ b/packages/plugins/policy/src/policy-handler.ts @@ -7,6 +7,7 @@ import { QueryUtils, RejectedByPolicyError, RejectedByPolicyReason, + SchemaUtils, type CRUD_EXT, } from '@zenstackhq/runtime'; import { @@ -17,7 +18,6 @@ import { type MemberExpression, type SchemaDef, } from '@zenstackhq/runtime/schema'; -import { ExpressionVisitor } from '@zenstackhq/sdk'; import { AliasNode, BinaryOperationNode, @@ -270,7 +270,7 @@ export class PolicyHandler extends OperationNodeTransf } const fields = new Set(); - const fieldCollector = new (class extends ExpressionVisitor { + const fieldCollector = new (class extends SchemaUtils.ExpressionVisitor { protected override visitMember(e: MemberExpression): void { if (isBeforeInvocation(e.receiver)) { invariant(e.members.length === 1, 'before() can only be followed by a scalar field access'); diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 8d2cdc86..4d1db5d6 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/runtime", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "description": "ZenStack Runtime", "type": "module", "scripts": { diff --git a/packages/runtime/src/index.ts b/packages/runtime/src/index.ts index b5018f05..cff4960f 100644 --- a/packages/runtime/src/index.ts +++ b/packages/runtime/src/index.ts @@ -1,2 +1,4 @@ export * from './client'; +export * as KyselyUtils from './utils/kysely-utils'; +export * as SchemaUtils from './utils/schema-utils'; export type { JsonArray, JsonObject, JsonValue } from './utils/type-utils'; diff --git a/packages/sdk/src/default-operation-node-visitor.ts b/packages/runtime/src/utils/kysely-utils.ts similarity index 100% rename from packages/sdk/src/default-operation-node-visitor.ts rename to packages/runtime/src/utils/kysely-utils.ts diff --git a/packages/sdk/src/expression-utils.ts b/packages/runtime/src/utils/schema-utils.ts similarity index 98% rename from packages/sdk/src/expression-utils.ts rename to packages/runtime/src/utils/schema-utils.ts index ec423767..8c0824d4 100644 --- a/packages/sdk/src/expression-utils.ts +++ b/packages/runtime/src/utils/schema-utils.ts @@ -10,7 +10,7 @@ import type { NullExpression, ThisExpression, UnaryExpression, -} from './schema'; +} from '../schema'; export class ExpressionVisitor { visit(expr: Expression): void { diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 895e42bc..daf9e27e 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "description": "ZenStack SDK", "type": "module", "scripts": { diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index c74b1419..649a7201 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -1,7 +1,5 @@ import * as ModelUtils from './model-utils'; export * from './cli-plugin'; -export * from './expression-utils'; -export * from './default-operation-node-visitor'; export { PrismaSchemaGenerator } from './prisma/prisma-schema-generator'; export * from './ts-schema-generator'; export * from './zmodel-code-generator'; diff --git a/packages/sdk/src/model-utils.ts b/packages/sdk/src/model-utils.ts index 7b54aa96..198e59ac 100644 --- a/packages/sdk/src/model-utils.ts +++ b/packages/sdk/src/model-utils.ts @@ -3,7 +3,6 @@ import { isLiteralExpr, isModel, isTypeDef, - Model, type AstNode, type Attribute, type AttributeParam, @@ -14,6 +13,7 @@ import { type Enum, type EnumField, type FunctionDecl, + type Model, type Reference, type TypeDef, } from '@zenstackhq/language/ast'; diff --git a/packages/tanstack-query/package.json b/packages/tanstack-query/package.json index f96af3b9..92cb74de 100644 --- a/packages/tanstack-query/package.json +++ b/packages/tanstack-query/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/tanstack-query", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "description": "", "main": "index.js", "type": "module", diff --git a/packages/testtools/package.json b/packages/testtools/package.json index 29d574d9..4c0f2ba3 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "description": "ZenStack Test Tools", "type": "module", "scripts": { diff --git a/packages/testtools/src/client.ts b/packages/testtools/src/client.ts index 4a23280e..27133f83 100644 --- a/packages/testtools/src/client.ts +++ b/packages/testtools/src/client.ts @@ -1,5 +1,4 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { loadDocument } from '@zenstackhq/language'; import type { Model } from '@zenstackhq/language/ast'; import { PolicyPlugin } from '@zenstackhq/plugin-policy'; import { ZenStackClient, type ClientContract, type ClientOptions } from '@zenstackhq/runtime'; @@ -15,6 +14,7 @@ import { Client as PGClient, Pool } from 'pg'; import { expect } from 'vitest'; import { createTestProject } from './project'; import { generateTsSchema } from './schema'; +import { loadDocumentWithPlugins } from './utils'; export function getTestDbProvider() { const val = process.env['TEST_DB_PROVIDER'] ?? 'sqlite'; @@ -116,7 +116,7 @@ export async function createTestClient( if (options?.usePrismaPush) { invariant(typeof schema === 'string' || schemaFile, 'a schema file must be provided when using prisma db push'); if (!model) { - const r = await loadDocument(path.join(workDir, 'schema.zmodel')); + const r = await loadDocumentWithPlugins(path.join(workDir, 'schema.zmodel')); if (!r.success) { throw new Error(r.errors.join('\n')); } diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 516d445a..b78bcbc8 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -1,5 +1,4 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { loadDocument } from '@zenstackhq/language'; import { TsSchemaGenerator } from '@zenstackhq/sdk'; import type { SchemaDef } from '@zenstackhq/sdk/schema'; import { execSync } from 'node:child_process'; @@ -10,6 +9,7 @@ import path from 'node:path'; import { match } from 'ts-pattern'; import { expect } from 'vitest'; import { createTestProject } from './project'; +import { loadDocumentWithPlugins } from './utils'; function makePrelude(provider: 'sqlite' | 'postgresql', dbUrl?: string) { return match(provider) @@ -44,7 +44,7 @@ export async function generateTsSchema( const noPrelude = schemaText.includes('datasource '); fs.writeFileSync(zmodelPath, `${noPrelude ? '' : makePrelude(provider, dbUrl)}\n\n${schemaText}`); - const result = await loadDocument(zmodelPath); + const result = await loadDocumentWithPlugins(zmodelPath); if (!result.success) { throw new Error(`Failed to load schema from ${zmodelPath}: ${result.errors}`); } @@ -82,7 +82,7 @@ export function generateTsSchemaFromFile(filePath: string) { export async function generateTsSchemaInPlace(schemaPath: string) { const workDir = path.dirname(schemaPath); - const result = await loadDocument(schemaPath); + const result = await loadDocumentWithPlugins(schemaPath); if (!result.success) { throw new Error(`Failed to load schema from ${schemaPath}: ${result.errors}`); } @@ -114,7 +114,7 @@ export async function loadSchema(schema: string, additionalSchemas?: Record r.success, `Failed to load schema: ${(r as any).errors?.map((e: any) => e.toString()).join(', ')}`, @@ -131,7 +131,7 @@ export async function loadSchemaWithError(schema: string, error: string | RegExp // create a temp file const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`); fs.writeFileSync(tempFile, schema); - const r = await loadDocument(tempFile); + const r = await loadDocumentWithPlugins(tempFile); expect(r.success).toBe(false); invariant(!r.success); if (typeof error === 'string') { diff --git a/packages/testtools/src/utils.ts b/packages/testtools/src/utils.ts new file mode 100644 index 00000000..1f8119fe --- /dev/null +++ b/packages/testtools/src/utils.ts @@ -0,0 +1,6 @@ +import { loadDocument } from '@zenstackhq/language'; + +export function loadDocumentWithPlugins(filePath: string) { + const pluginModelFiles = [require.resolve('@zenstackhq/plugin-policy/plugin.zmodel')]; + return loadDocument(filePath, pluginModelFiles); +} diff --git a/packages/zod/package.json b/packages/zod/package.json index 0de8db0e..0ee27d46 100644 --- a/packages/zod/package.json +++ b/packages/zod/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/zod", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "description": "", "type": "module", "main": "index.js", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index eff0e66f..aa32c4a2 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -282,9 +282,6 @@ importers: '@zenstackhq/runtime': specifier: workspace:* version: link:../../runtime - '@zenstackhq/sdk': - specifier: workspace:* - version: link:../../sdk kysely: specifier: 'catalog:' version: 0.27.6 @@ -510,6 +507,9 @@ importers: samples/blog: dependencies: + '@zenstackhq/plugin-policy': + specifier: workspace:* + version: link:../../packages/plugins/policy '@zenstackhq/runtime': specifier: workspace:* version: link:../../packages/runtime diff --git a/samples/blog/main.ts b/samples/blog/main.ts index 8bbfb5bf..d0c82fc0 100644 --- a/samples/blog/main.ts +++ b/samples/blog/main.ts @@ -1,3 +1,4 @@ +import { PolicyPlugin } from '@zenstackhq/plugin-policy'; import { ZenStackClient } from '@zenstackhq/runtime'; import SQLite from 'better-sqlite3'; import { sql, SqliteDialect } from 'kysely'; @@ -89,6 +90,17 @@ async function main() { }, }); console.log('User found with computed field:', userWithMorePosts); + + // policy-enabled read + const authDb = db.$use(new PolicyPlugin()); + const user1Db = authDb.$setAuth({ id: user1.id }); + const user2Db = authDb.$setAuth({ id: user2.id }); + + console.log('Posts readable to', user1.email); + console.table(await user1Db.post.findMany({ select: { title: true, published: true } })); + + console.log('Posts readable to', user2.email); + console.table(await user2Db.post.findMany({ select: { title: true, published: true } })); } main(); diff --git a/samples/blog/package.json b/samples/blog/package.json index 9298036c..b68731ef 100644 --- a/samples/blog/package.json +++ b/samples/blog/package.json @@ -1,6 +1,6 @@ { "name": "sample-blog", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "description": "", "main": "index.js", "scripts": { @@ -15,6 +15,7 @@ "license": "MIT", "dependencies": { "@zenstackhq/runtime": "workspace:*", + "@zenstackhq/plugin-policy": "workspace:*", "better-sqlite3": "^12.2.0", "kysely": "catalog:" }, diff --git a/samples/blog/zenstack/models.ts b/samples/blog/zenstack/models.ts index 86b941a1..608e9397 100644 --- a/samples/blog/zenstack/models.ts +++ b/samples/blog/zenstack/models.ts @@ -9,8 +9,6 @@ import { schema as $schema, type SchemaType as $Schema } from "./schema"; import { type ModelResult as $ModelResult, type TypeDefResult as $TypeDefResult } from "@zenstackhq/runtime"; /** * User model - * - * Represents a user of the blog. */ export type User = $ModelResult<$Schema, "User">; /** diff --git a/samples/blog/zenstack/schema.ts b/samples/blog/zenstack/schema.ts index 4ca14e3e..fa51f40f 100644 --- a/samples/blog/zenstack/schema.ts +++ b/samples/blog/zenstack/schema.ts @@ -69,6 +69,10 @@ export const schema = { relation: { opposite: "user" } } }, + attributes: [ + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("read,create") }, { name: "condition", value: ExpressionUtils.literal(true) }] }, + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("all") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils._this(), "==", ExpressionUtils.call("auth")) }] } + ], idFields: ["id"], uniqueFields: { id: { type: "String" }, @@ -132,6 +136,9 @@ export const schema = { ] } }, + attributes: [ + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("all") }, { name: "condition", value: ExpressionUtils.call("check", [ExpressionUtils.field("user")]) }] } + ], idFields: ["id"], uniqueFields: { id: { type: "String" }, @@ -188,6 +195,10 @@ export const schema = { ] } }, + attributes: [ + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("read") }, { name: "condition", value: ExpressionUtils.field("published") }] }, + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("all") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.field("author"), "==", ExpressionUtils.call("auth")) }] } + ], idFields: ["id"], uniqueFields: { id: { type: "String" } diff --git a/samples/blog/zenstack/schema.zmodel b/samples/blog/zenstack/schema.zmodel index aeccb56f..6cf112e2 100644 --- a/samples/blog/zenstack/schema.zmodel +++ b/samples/blog/zenstack/schema.zmodel @@ -9,6 +9,11 @@ enum Role { USER } +plugin policy { + // due to pnpm layout we can't directly use package name here + provider = '../node_modules/@zenstackhq/plugin-policy/dist/index.js' +} + type CommonFields { id String @id @default(cuid()) createdAt DateTime @default(now()) @@ -16,8 +21,6 @@ type CommonFields { } /// User model -/// -/// Represents a user of the blog. model User with CommonFields { email String @unique name String? @@ -25,6 +28,9 @@ model User with CommonFields { role Role @default(USER) posts Post[] profile Profile? + + @@allow('read,create', true) + @@allow('all', this == auth()) } /// Profile model @@ -33,6 +39,7 @@ model Profile with CommonFields { age Int? user User? @relation(fields: [userId], references: [id]) userId String? @unique + @@allow('all', check(user)) } /// Post model @@ -42,4 +49,7 @@ model Post with CommonFields { published Boolean @default(false) author User @relation(fields: [authorId], references: [id]) authorId String + + @@allow('read', published) + @@allow('all', author == auth()) } diff --git a/tests/e2e/orm/policy/migrated/current-model.test.ts b/tests/e2e/orm/policy/migrated/current-model.test.ts index 04c9a30a..0ceb96c3 100644 --- a/tests/e2e/orm/policy/migrated/current-model.test.ts +++ b/tests/e2e/orm/policy/migrated/current-model.test.ts @@ -158,7 +158,7 @@ describe('currentModel tests', () => { createPolicyTestClient( ` model User { - id String @default(currentModel()) + id String @id @default(currentModel()) } `, ), diff --git a/tests/e2e/orm/policy/migrated/current-operation.test.ts b/tests/e2e/orm/policy/migrated/current-operation.test.ts index 2b1610c9..3cbae4ca 100644 --- a/tests/e2e/orm/policy/migrated/current-operation.test.ts +++ b/tests/e2e/orm/policy/migrated/current-operation.test.ts @@ -127,7 +127,7 @@ describe('currentOperation tests', () => { createPolicyTestClient( ` model User { - id String @default(currentOperation()) + id String @id @default(currentOperation()) } `, ), diff --git a/tests/e2e/orm/scripts/generate.ts b/tests/e2e/orm/scripts/generate.ts index b4f78958..9d59db73 100644 --- a/tests/e2e/orm/scripts/generate.ts +++ b/tests/e2e/orm/scripts/generate.ts @@ -1,4 +1,5 @@ import { loadDocument } from '@zenstackhq/language'; +import type { Model } from '@zenstackhq/language/ast'; import { TsSchemaGenerator } from '@zenstackhq/sdk'; import { glob } from 'glob'; import path from 'node:path'; @@ -17,11 +18,18 @@ async function main() { async function generate(schemaPath: string) { const generator = new TsSchemaGenerator(); const outputDir = path.dirname(schemaPath); - const result = await loadDocument(schemaPath); + + // isomorphic __dirname + const _dirname = typeof __dirname !== 'undefined' ? __dirname : path.dirname(fileURLToPath(import.meta.url)); + + // plugin models + const pluginDocs = [path.resolve(_dirname, '../../node_modules/@zenstackhq/plugin-policy/plugin.zmodel')]; + + const result = await loadDocument(schemaPath, pluginDocs); if (!result.success) { throw new Error(`Failed to load schema from ${schemaPath}: ${result.errors}`); } - await generator.generate(result.model, outputDir); + await generator.generate(result.model as Model, outputDir); } main(); diff --git a/tests/e2e/package.json b/tests/e2e/package.json index bcf2a10f..101221c2 100644 --- a/tests/e2e/package.json +++ b/tests/e2e/package.json @@ -1,6 +1,6 @@ { "name": "e2e", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "private": true, "type": "module", "scripts": { diff --git a/tests/regression/package.json b/tests/regression/package.json index 9bf71b47..773a689d 100644 --- a/tests/regression/package.json +++ b/tests/regression/package.json @@ -1,6 +1,6 @@ { "name": "regression", - "version": "3.0.0-beta.9", + "version": "3.0.0-beta.10", "private": true, "type": "module", "scripts": {