Skip to content

Commit 59dfa73

Browse files
authored
refactor: move policy attributes to its own zmodel (#307)
* refactor: move policy attributes to its own zmodel * update * update * fix tests
1 parent 2298fc9 commit 59dfa73

File tree

30 files changed

+261
-149
lines changed

30 files changed

+261
-149
lines changed

.vscode/launch.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"skipFiles": ["<node_internals>/**"],
1212
"type": "node",
1313
"args": ["generate"],
14-
"cwd": "${workspaceFolder}/samples/blog/zenstack"
14+
"cwd": "${workspaceFolder}/samples/blog"
1515
},
1616
{
1717
"name": "Debug with TSX",

packages/cli/src/actions/action-utils.ts

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
import { loadDocument } from '@zenstackhq/language';
2-
import { isDataSource } from '@zenstackhq/language/ast';
1+
import { createZModelServices, loadDocument, type ZModelServices } from '@zenstackhq/language';
2+
import { isDataSource, isPlugin, Model } from '@zenstackhq/language/ast';
3+
import { getLiteral } from '@zenstackhq/language/utils';
34
import { PrismaSchemaGenerator } from '@zenstackhq/sdk';
45
import colors from 'colors';
56
import fs from 'node:fs';
67
import path from 'node:path';
8+
import { fileURLToPath } from 'node:url';
79
import { CliError } from '../cli-error';
10+
import { PLUGIN_MODULE_NAME } from '../constants';
811

912
export function getSchemaFile(file?: string) {
1013
if (file) {
@@ -34,7 +37,9 @@ export function getSchemaFile(file?: string) {
3437
}
3538

3639
export async function loadSchemaDocument(schemaFile: string) {
37-
const loadResult = await loadDocument(schemaFile);
40+
const { ZModelLanguage: services } = createZModelServices();
41+
const pluginDocs = await getPluginDocuments(services, schemaFile);
42+
const loadResult = await loadDocument(schemaFile, pluginDocs);
3843
if (!loadResult.success) {
3944
loadResult.errors.forEach((err) => {
4045
console.error(colors.red(err));
@@ -47,6 +52,63 @@ export async function loadSchemaDocument(schemaFile: string) {
4752
return loadResult.model;
4853
}
4954

55+
export async function getPluginDocuments(services: ZModelServices, fileName: string): Promise<string[]> {
56+
// parse the user document (without validation)
57+
const parseResult = services.parser.LangiumParser.parse(fs.readFileSync(fileName, { encoding: 'utf-8' }));
58+
const parsed = parseResult.value as Model;
59+
60+
// balk if there are syntax errors
61+
if (parseResult.lexerErrors.length > 0 || parseResult.parserErrors.length > 0) {
62+
return [];
63+
}
64+
65+
// traverse plugins and collect "plugin.zmodel" documents
66+
const result: string[] = [];
67+
for (const decl of parsed.declarations.filter(isPlugin)) {
68+
const providerField = decl.fields.find((f) => f.name === 'provider');
69+
if (!providerField) {
70+
continue;
71+
}
72+
73+
const provider = getLiteral<string>(providerField.value);
74+
if (!provider) {
75+
continue;
76+
}
77+
78+
let pluginModelFile: string | undefined;
79+
80+
// first try to treat provider as a path
81+
let providerPath = path.resolve(path.dirname(fileName), provider);
82+
if (fs.existsSync(providerPath)) {
83+
if (fs.statSync(providerPath).isDirectory()) {
84+
providerPath = path.join(providerPath, 'index.js');
85+
}
86+
87+
// try plugin.zmodel next to the provider file
88+
pluginModelFile = path.resolve(path.dirname(providerPath), PLUGIN_MODULE_NAME);
89+
if (!fs.existsSync(pluginModelFile)) {
90+
// try to find upwards
91+
pluginModelFile = findUp([PLUGIN_MODULE_NAME], path.dirname(providerPath));
92+
}
93+
}
94+
95+
if (!pluginModelFile) {
96+
// try loading it as a ESM module
97+
try {
98+
const resolvedUrl = import.meta.resolve(`${provider}/${PLUGIN_MODULE_NAME}`);
99+
pluginModelFile = fileURLToPath(resolvedUrl);
100+
} catch {
101+
// noop
102+
}
103+
}
104+
105+
if (pluginModelFile && fs.existsSync(pluginModelFile)) {
106+
result.push(pluginModelFile);
107+
}
108+
}
109+
return result;
110+
}
111+
50112
export function handleSubProcessError(err: unknown) {
51113
if (err instanceof Error && 'status' in err && typeof err.status === 'number') {
52114
process.exit(err.status);

packages/cli/src/actions/generate.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ async function runPlugins(schemaFile: string, model: Model, outputPath: string,
6464
for (const plugin of plugins) {
6565
const provider = getPluginProvider(plugin);
6666

67-
let cliPlugin: CliPlugin;
67+
let cliPlugin: CliPlugin | undefined;
6868
if (provider.startsWith('@core/')) {
6969
cliPlugin = (corePlugins as any)[provider.slice('@core/'.length)];
7070
if (!cliPlugin) {
@@ -78,12 +78,14 @@ async function runPlugins(schemaFile: string, model: Model, outputPath: string,
7878
}
7979
try {
8080
cliPlugin = (await import(moduleSpec)).default as CliPlugin;
81-
} catch (error) {
82-
throw new CliError(`Failed to load plugin ${provider}: ${error}`);
81+
} catch {
82+
// plugin may not export a generator so we simply ignore the error here
8383
}
8484
}
8585

86-
processedPlugins.push({ cliPlugin, pluginOptions: getPluginOptions(plugin) });
86+
if (cliPlugin) {
87+
processedPlugins.push({ cliPlugin, pluginOptions: getPluginOptions(plugin) });
88+
}
8789
}
8890

8991
const defaultPlugins = [corePlugins['typescript']].reverse();

packages/cli/src/constants.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
// replaced at build time
22
export const TELEMETRY_TRACKING_TOKEN = '<TELEMETRY_TRACKING_TOKEN>';
3+
4+
// plugin-contributed model file name
5+
export const PLUGIN_MODULE_NAME = 'plugin.zmodel';

packages/language/res/stdlib.zmodel

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -174,29 +174,6 @@ function hasSome(field: Any[], search: Any[]): Boolean {
174174
function isEmpty(field: Any[]): Boolean {
175175
} @@@expressionContext([AccessPolicy, ValidationRule])
176176

177-
/**
178-
* The name of the model for which the policy rule is defined. If the rule is
179-
* inherited to a sub model, this function returns the name of the sub model.
180-
*
181-
* @param optional parameter to control the casing of the returned value. Valid
182-
* values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults
183-
* to "original".
184-
*/
185-
function currentModel(casing: String?): String {
186-
} @@@expressionContext([AccessPolicy])
187-
188-
/**
189-
* The operation for which the policy rule is defined for. Note that a rule with
190-
* "all" operation is expanded to "create", "read", "update", and "delete" rules,
191-
* and the function returns corresponding value for each expanded version.
192-
*
193-
* @param optional parameter to control the casing of the returned value. Valid
194-
* values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults
195-
* to "original".
196-
*/
197-
function currentOperation(casing: String?): String {
198-
} @@@expressionContext([AccessPolicy])
199-
200177
/**
201178
* Marks an attribute to be only applicable to certain field types.
202179
*/
@@ -658,56 +635,3 @@ attribute @meta(_ name: String, _ value: Any)
658635
* Marks an attribute as deprecated.
659636
*/
660637
attribute @@@deprecated(_ message: String)
661-
662-
/* --- Policy Plugin --- */
663-
664-
/**
665-
* Defines an access policy that allows a set of operations when the given condition is true.
666-
*
667-
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
668-
* @param condition: a boolean expression that controls if the operation should be allowed.
669-
*/
670-
attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean)
671-
672-
/**
673-
* Defines an access policy that allows the annotated field to be read or updated.
674-
* You can pass a third argument as `true` to make it override the model-level policies.
675-
*
676-
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
677-
* @param condition: a boolean expression that controls if the operation should be allowed.
678-
* @param override: a boolean value that controls if the field-level policy should override the model-level policy.
679-
*/
680-
// attribute @allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean, _ override: Boolean?)
681-
682-
/**
683-
* Defines an access policy that denies a set of operations when the given condition is true.
684-
*
685-
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
686-
* @param condition: a boolean expression that controls if the operation should be denied.
687-
*/
688-
attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean)
689-
690-
/**
691-
* Defines an access policy that denies the annotated field to be read or updated.
692-
*
693-
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
694-
* @param condition: a boolean expression that controls if the operation should be denied.
695-
*/
696-
// attribute @deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean)
697-
698-
/**
699-
* Checks if the current user can perform the given operation on the given field.
700-
*
701-
* @param field: The field to check access for
702-
* @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided,
703-
* it defaults the operation of the containing policy rule.
704-
*/
705-
function check(field: Any, operation: String?): Boolean {
706-
} @@@expressionContext([AccessPolicy])
707-
708-
/**
709-
* Gets entity's value before an update. Only valid when used in a "post-update" policy rule.
710-
*/
711-
function before(): Any {
712-
} @@@expressionContext([AccessPolicy])
713-

packages/language/src/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ export async function loadDocument(
5151
);
5252

5353
// load additional model files
54-
const pluginDocs = await Promise.all(
54+
const additionalDocs = await Promise.all(
5555
additionalModelFiles.map((file) =>
5656
services.shared.workspace.LangiumDocuments.getOrCreateDocument(URI.file(path.resolve(file))),
5757
),
@@ -69,7 +69,7 @@ export async function loadDocument(
6969
}
7070

7171
// build the document together with standard library, plugin modules, and imported documents
72-
await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document, ...importedDocuments], {
72+
await services.shared.workspace.DocumentBuilder.build([stdLib, ...additionalDocs, document, ...importedDocuments], {
7373
validation: {
7474
stopAfterLexingErrors: true,
7575
stopAfterParsingErrors: true,

packages/language/src/utils.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,9 @@ export function getAuthDecl(decls: (DataModel | TypeDef)[]) {
447447
return authModel;
448448
}
449449

450+
// TODO: move to policy plugin
450451
export function isBeforeInvocation(node: AstNode) {
451-
return isInvocationExpr(node) && node.function.ref?.name === 'before' && isFromStdlib(node.function.ref);
452+
return isInvocationExpr(node) && node.function.ref?.name === 'before';
452453
}
453454

454455
export function isCollectionPredicate(node: AstNode): node is BinaryExpr {

packages/language/src/validators/function-invocation-validator.ts

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import {
2020
getLiteral,
2121
isCheckInvocation,
2222
isDataFieldReference,
23-
isFromStdlib,
2423
typeAssignable,
2524
} from '../utils';
2625
import type { AstValidator } from './common';
@@ -52,43 +51,39 @@ export default class FunctionInvocationValidator implements AstValidator<Express
5251
return;
5352
}
5453

55-
if (isFromStdlib(funcDecl)) {
56-
// validate standard library functions
57-
58-
// find the containing attribute context for the invocation
59-
let curr: AstNode | undefined = expr.$container;
60-
let containerAttribute: DataModelAttribute | DataFieldAttribute | undefined;
61-
while (curr) {
62-
if (isDataModelAttribute(curr) || isDataFieldAttribute(curr)) {
63-
containerAttribute = curr;
64-
break;
65-
}
66-
curr = curr.$container;
54+
// find the containing attribute context for the invocation
55+
let curr: AstNode | undefined = expr.$container;
56+
let containerAttribute: DataModelAttribute | DataFieldAttribute | undefined;
57+
while (curr) {
58+
if (isDataModelAttribute(curr) || isDataFieldAttribute(curr)) {
59+
containerAttribute = curr;
60+
break;
6761
}
62+
curr = curr.$container;
63+
}
6864

69-
// validate the context allowed for the function
70-
const exprContext = this.getExpressionContext(containerAttribute);
65+
// validate the context allowed for the function
66+
const exprContext = this.getExpressionContext(containerAttribute);
7167

72-
// get the context allowed for the function
73-
const funcAllowedContext = getFunctionExpressionContext(funcDecl);
68+
// get the context allowed for the function
69+
const funcAllowedContext = getFunctionExpressionContext(funcDecl);
7470

75-
if (exprContext && !funcAllowedContext.includes(exprContext)) {
76-
accept('error', `function "${funcDecl.name}" is not allowed in the current context: ${exprContext}`, {
77-
node: expr,
78-
});
79-
return;
80-
}
71+
if (exprContext && !funcAllowedContext.includes(exprContext)) {
72+
accept('error', `function "${funcDecl.name}" is not allowed in the current context: ${exprContext}`, {
73+
node: expr,
74+
});
75+
return;
76+
}
8177

82-
// TODO: express function validation rules declaratively in ZModel
78+
// TODO: express function validation rules declaratively in ZModel
8379

84-
const allCasing = ['original', 'upper', 'lower', 'capitalize', 'uncapitalize'];
85-
if (['currentModel', 'currentOperation'].includes(funcDecl.name)) {
86-
const arg = getLiteral<string>(expr.args[0]?.value);
87-
if (arg && !allCasing.includes(arg)) {
88-
accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, {
89-
node: expr.args[0]!,
90-
});
91-
}
80+
const allCasing = ['original', 'upper', 'lower', 'capitalize', 'uncapitalize'];
81+
if (['currentModel', 'currentOperation'].includes(funcDecl.name)) {
82+
const arg = getLiteral<string>(expr.args[0]?.value);
83+
if (arg && !allCasing.includes(arg)) {
84+
accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, {
85+
node: expr.args[0]!,
86+
});
9287
}
9388
}
9489

packages/language/test/utils.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ import path from 'node:path';
55
import { expect } from 'vitest';
66
import { loadDocument } from '../src';
77

8+
const pluginDocs = [path.resolve(__dirname, '../../plugins/policy/plugin.zmodel')];
9+
810
export async function loadSchema(schema: string) {
911
// create a temp file
1012
const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`);
1113
fs.writeFileSync(tempFile, schema);
12-
const r = await loadDocument(tempFile);
14+
const r = await loadDocument(tempFile, pluginDocs);
1315
expect(r).toSatisfy(
1416
(r) => r.success,
1517
`Failed to load schema: ${(r as any).errors?.map((e) => e.toString()).join(', ')}`,
@@ -22,7 +24,8 @@ export async function loadSchemaWithError(schema: string, error: string | RegExp
2224
// create a temp file
2325
const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`);
2426
fs.writeFileSync(tempFile, schema);
25-
const r = await loadDocument(tempFile);
27+
28+
const r = await loadDocument(tempFile, pluginDocs);
2629
expect(r.success).toBe(false);
2730
invariant(!r.success);
2831
if (typeof error === 'string') {

packages/plugins/policy/package.json

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"author": "ZenStack Team",
1414
"license": "MIT",
1515
"files": [
16-
"dist"
16+
"dist",
17+
"plugin.zmodel"
1718
],
1819
"exports": {
1920
".": {
@@ -26,14 +27,17 @@
2627
"default": "./dist/index.cjs"
2728
}
2829
},
30+
"./plugin.zmodel": {
31+
"import": "./plugin.zmodel",
32+
"require": "./plugin.zmodel"
33+
},
2934
"./package.json": {
3035
"import": "./package.json",
3136
"require": "./package.json"
3237
}
3338
},
3439
"dependencies": {
3540
"@zenstackhq/common-helpers": "workspace:*",
36-
"@zenstackhq/sdk": "workspace:*",
3741
"@zenstackhq/runtime": "workspace:*",
3842
"ts-pattern": "catalog:"
3943
},

0 commit comments

Comments
 (0)