Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"skipFiles": ["<node_internals>/**"],
"type": "node",
"args": ["generate"],
"cwd": "${workspaceFolder}/samples/blog/zenstack"
"cwd": "${workspaceFolder}/samples/blog"
},
{
"name": "Debug with TSX",
Expand Down
68 changes: 65 additions & 3 deletions packages/cli/src/actions/action-utils.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import { loadDocument } from '@zenstackhq/language';
import { isDataSource } from '@zenstackhq/language/ast';
import { createZModelServices, loadDocument, type ZModelServices } from '@zenstackhq/language';
import { isDataSource, isPlugin, Model } from '@zenstackhq/language/ast';
import { getLiteral } from '@zenstackhq/language/utils';
import { PrismaSchemaGenerator } from '@zenstackhq/sdk';
import colors from 'colors';
import fs from 'node:fs';
import path from 'node:path';
import { fileURLToPath } from 'node:url';
import { CliError } from '../cli-error';
import { PLUGIN_MODULE_NAME } from '../constants';

export function getSchemaFile(file?: string) {
if (file) {
Expand Down Expand Up @@ -34,7 +37,9 @@ export function getSchemaFile(file?: string) {
}

export async function loadSchemaDocument(schemaFile: string) {
const loadResult = await loadDocument(schemaFile);
const { ZModelLanguage: services } = createZModelServices();
const pluginDocs = await getPluginDocuments(services, schemaFile);
const loadResult = await loadDocument(schemaFile, pluginDocs);
if (!loadResult.success) {
loadResult.errors.forEach((err) => {
console.error(colors.red(err));
Expand All @@ -47,6 +52,63 @@ export async function loadSchemaDocument(schemaFile: string) {
return loadResult.model;
}

export async function getPluginDocuments(services: ZModelServices, fileName: string): Promise<string[]> {
// parse the user document (without validation)
const parseResult = services.parser.LangiumParser.parse(fs.readFileSync(fileName, { encoding: 'utf-8' }));
const parsed = parseResult.value as Model;

// balk if there are syntax errors
if (parseResult.lexerErrors.length > 0 || parseResult.parserErrors.length > 0) {
return [];
}

// traverse plugins and collect "plugin.zmodel" documents
const result: string[] = [];
for (const decl of parsed.declarations.filter(isPlugin)) {
const providerField = decl.fields.find((f) => f.name === 'provider');
if (!providerField) {
continue;
}

const provider = getLiteral<string>(providerField.value);
if (!provider) {
continue;
}

let pluginModelFile: string | undefined;

// first try to treat provider as a path
let providerPath = path.resolve(path.dirname(fileName), provider);
if (fs.existsSync(providerPath)) {
if (fs.statSync(providerPath).isDirectory()) {
providerPath = path.join(providerPath, 'index.js');
}

// try plugin.zmodel next to the provider file
pluginModelFile = path.resolve(path.dirname(providerPath), PLUGIN_MODULE_NAME);
if (!fs.existsSync(pluginModelFile)) {
// try to find upwards
pluginModelFile = findUp([PLUGIN_MODULE_NAME], path.dirname(providerPath));
}
}

if (!pluginModelFile) {
// try loading it as a ESM module
try {
const resolvedUrl = import.meta.resolve(`${provider}/${PLUGIN_MODULE_NAME}`);
pluginModelFile = fileURLToPath(resolvedUrl);
} catch {
// noop
}
}

if (pluginModelFile && fs.existsSync(pluginModelFile)) {
result.push(pluginModelFile);
}
}
return result;
}

export function handleSubProcessError(err: unknown) {
if (err instanceof Error && 'status' in err && typeof err.status === 'number') {
process.exit(err.status);
Expand Down
10 changes: 6 additions & 4 deletions packages/cli/src/actions/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async function runPlugins(schemaFile: string, model: Model, outputPath: string,
for (const plugin of plugins) {
const provider = getPluginProvider(plugin);

let cliPlugin: CliPlugin;
let cliPlugin: CliPlugin | undefined;
if (provider.startsWith('@core/')) {
cliPlugin = (corePlugins as any)[provider.slice('@core/'.length)];
if (!cliPlugin) {
Expand All @@ -78,12 +78,14 @@ async function runPlugins(schemaFile: string, model: Model, outputPath: string,
}
try {
cliPlugin = (await import(moduleSpec)).default as CliPlugin;
} catch (error) {
throw new CliError(`Failed to load plugin ${provider}: ${error}`);
} catch {
// plugin may not export a generator so we simply ignore the error here
}
}

processedPlugins.push({ cliPlugin, pluginOptions: getPluginOptions(plugin) });
if (cliPlugin) {
processedPlugins.push({ cliPlugin, pluginOptions: getPluginOptions(plugin) });
}
}

const defaultPlugins = [corePlugins['typescript']].reverse();
Expand Down
3 changes: 3 additions & 0 deletions packages/cli/src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
// replaced at build time
export const TELEMETRY_TRACKING_TOKEN = '<TELEMETRY_TRACKING_TOKEN>';

// plugin-contributed model file name
export const PLUGIN_MODULE_NAME = 'plugin.zmodel';
76 changes: 0 additions & 76 deletions packages/language/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -174,29 +174,6 @@ function hasSome(field: Any[], search: Any[]): Boolean {
function isEmpty(field: Any[]): Boolean {
} @@@expressionContext([AccessPolicy, ValidationRule])

/**
* The name of the model for which the policy rule is defined. If the rule is
* inherited to a sub model, this function returns the name of the sub model.
*
* @param optional parameter to control the casing of the returned value. Valid
* values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults
* to "original".
*/
function currentModel(casing: String?): String {
} @@@expressionContext([AccessPolicy])

/**
* The operation for which the policy rule is defined for. Note that a rule with
* "all" operation is expanded to "create", "read", "update", and "delete" rules,
* and the function returns corresponding value for each expanded version.
*
* @param optional parameter to control the casing of the returned value. Valid
* values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults
* to "original".
*/
function currentOperation(casing: String?): String {
} @@@expressionContext([AccessPolicy])

/**
* Marks an attribute to be only applicable to certain field types.
*/
Expand Down Expand Up @@ -658,56 +635,3 @@ attribute @meta(_ name: String, _ value: Any)
* Marks an attribute as deprecated.
*/
attribute @@@deprecated(_ message: String)

/* --- Policy Plugin --- */

/**
* Defines an access policy that allows a set of operations when the given condition is true.
*
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be allowed.
*/
attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean)

/**
* Defines an access policy that allows the annotated field to be read or updated.
* You can pass a third argument as `true` to make it override the model-level policies.
*
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be allowed.
* @param override: a boolean value that controls if the field-level policy should override the model-level policy.
*/
// attribute @allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean, _ override: Boolean?)

/**
* Defines an access policy that denies a set of operations when the given condition is true.
*
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be denied.
*/
attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean)

/**
* Defines an access policy that denies the annotated field to be read or updated.
*
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be denied.
*/
// attribute @deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean)

/**
* Checks if the current user can perform the given operation on the given field.
*
* @param field: The field to check access for
* @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided,
* it defaults the operation of the containing policy rule.
*/
function check(field: Any, operation: String?): Boolean {
} @@@expressionContext([AccessPolicy])

/**
* Gets entity's value before an update. Only valid when used in a "post-update" policy rule.
*/
function before(): Any {
} @@@expressionContext([AccessPolicy])

4 changes: 2 additions & 2 deletions packages/language/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export async function loadDocument(
);

// load additional model files
const pluginDocs = await Promise.all(
const additionalDocs = await Promise.all(
additionalModelFiles.map((file) =>
services.shared.workspace.LangiumDocuments.getOrCreateDocument(URI.file(path.resolve(file))),
),
Expand All @@ -69,7 +69,7 @@ export async function loadDocument(
}

// build the document together with standard library, plugin modules, and imported documents
await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document, ...importedDocuments], {
await services.shared.workspace.DocumentBuilder.build([stdLib, ...additionalDocs, document, ...importedDocuments], {
validation: {
stopAfterLexingErrors: true,
stopAfterParsingErrors: true,
Expand Down
3 changes: 2 additions & 1 deletion packages/language/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,9 @@ export function getAuthDecl(decls: (DataModel | TypeDef)[]) {
return authModel;
}

// TODO: move to policy plugin
export function isBeforeInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'before' && isFromStdlib(node.function.ref);
return isInvocationExpr(node) && node.function.ref?.name === 'before';
}

export function isCollectionPredicate(node: AstNode): node is BinaryExpr {
Expand Down
59 changes: 27 additions & 32 deletions packages/language/src/validators/function-invocation-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import {
getLiteral,
isCheckInvocation,
isDataFieldReference,
isFromStdlib,
typeAssignable,
} from '../utils';
import type { AstValidator } from './common';
Expand Down Expand Up @@ -52,43 +51,39 @@ export default class FunctionInvocationValidator implements AstValidator<Express
return;
}

if (isFromStdlib(funcDecl)) {
// validate standard library functions

// find the containing attribute context for the invocation
let curr: AstNode | undefined = expr.$container;
let containerAttribute: DataModelAttribute | DataFieldAttribute | undefined;
while (curr) {
if (isDataModelAttribute(curr) || isDataFieldAttribute(curr)) {
containerAttribute = curr;
break;
}
curr = curr.$container;
// find the containing attribute context for the invocation
let curr: AstNode | undefined = expr.$container;
let containerAttribute: DataModelAttribute | DataFieldAttribute | undefined;
while (curr) {
if (isDataModelAttribute(curr) || isDataFieldAttribute(curr)) {
containerAttribute = curr;
break;
}
curr = curr.$container;
}

// validate the context allowed for the function
const exprContext = this.getExpressionContext(containerAttribute);
// validate the context allowed for the function
const exprContext = this.getExpressionContext(containerAttribute);

// get the context allowed for the function
const funcAllowedContext = getFunctionExpressionContext(funcDecl);
// get the context allowed for the function
const funcAllowedContext = getFunctionExpressionContext(funcDecl);

if (exprContext && !funcAllowedContext.includes(exprContext)) {
accept('error', `function "${funcDecl.name}" is not allowed in the current context: ${exprContext}`, {
node: expr,
});
return;
}
if (exprContext && !funcAllowedContext.includes(exprContext)) {
accept('error', `function "${funcDecl.name}" is not allowed in the current context: ${exprContext}`, {
node: expr,
});
return;
}

// TODO: express function validation rules declaratively in ZModel
// TODO: express function validation rules declaratively in ZModel

const allCasing = ['original', 'upper', 'lower', 'capitalize', 'uncapitalize'];
if (['currentModel', 'currentOperation'].includes(funcDecl.name)) {
const arg = getLiteral<string>(expr.args[0]?.value);
if (arg && !allCasing.includes(arg)) {
accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, {
node: expr.args[0]!,
});
}
const allCasing = ['original', 'upper', 'lower', 'capitalize', 'uncapitalize'];
if (['currentModel', 'currentOperation'].includes(funcDecl.name)) {
const arg = getLiteral<string>(expr.args[0]?.value);
if (arg && !allCasing.includes(arg)) {
accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, {
node: expr.args[0]!,
});
}
}

Expand Down
7 changes: 5 additions & 2 deletions packages/language/test/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import path from 'node:path';
import { expect } from 'vitest';
import { loadDocument } from '../src';

const pluginDocs = [path.resolve(__dirname, '../../plugins/policy/plugin.zmodel')];

export async function loadSchema(schema: string) {
// create a temp file
const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`);
fs.writeFileSync(tempFile, schema);
const r = await loadDocument(tempFile);
const r = await loadDocument(tempFile, pluginDocs);
expect(r).toSatisfy(
(r) => r.success,
`Failed to load schema: ${(r as any).errors?.map((e) => e.toString()).join(', ')}`,
Expand All @@ -22,7 +24,8 @@ export async function loadSchemaWithError(schema: string, error: string | RegExp
// create a temp file
const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`);
fs.writeFileSync(tempFile, schema);
const r = await loadDocument(tempFile);

const r = await loadDocument(tempFile, pluginDocs);
expect(r.success).toBe(false);
invariant(!r.success);
if (typeof error === 'string') {
Expand Down
8 changes: 6 additions & 2 deletions packages/plugins/policy/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
"author": "ZenStack Team",
"license": "MIT",
"files": [
"dist"
"dist",
"plugin.zmodel"
],
"exports": {
".": {
Expand All @@ -26,14 +27,17 @@
"default": "./dist/index.cjs"
}
},
"./plugin.zmodel": {
"import": "./plugin.zmodel",
"require": "./plugin.zmodel"
},
"./package.json": {
"import": "./package.json",
"require": "./package.json"
}
},
"dependencies": {
"@zenstackhq/common-helpers": "workspace:*",
"@zenstackhq/sdk": "workspace:*",
"@zenstackhq/runtime": "workspace:*",
"ts-pattern": "catalog:"
},
Expand Down
Loading
Loading