Skip to content

Commit e1bda19

Browse files
authored
feat(policy): implementing check function (#255)
* feat(policy): implementing `check` function * addressing PR comments
1 parent b17bf54 commit e1bda19

File tree

22 files changed

+937
-71
lines changed

22 files changed

+937
-71
lines changed

packages/language/res/stdlib.zmodel

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -594,16 +594,6 @@ function datetime(field: String): Boolean {
594594
function url(field: String): Boolean {
595595
} @@@expressionContext([ValidationRule])
596596

597-
/**
598-
* Checks if the current user can perform the given operation on the given field.
599-
*
600-
* @param field: The field to check access for
601-
* @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided,
602-
* it defaults the operation of the containing policy rule.
603-
*/
604-
function check(field: Any, operation: String?): Boolean {
605-
} @@@expressionContext([AccessPolicy])
606-
607597
//////////////////////////////////////////////
608598
// End validation attributes and functions
609599
//////////////////////////////////////////////

packages/language/src/utils.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,9 @@ export function getFieldReference(expr: Expression): DataField | undefined {
357357
}
358358
}
359359

360+
// TODO: move to policy plugin
360361
export function isCheckInvocation(node: AstNode) {
361-
return isInvocationExpr(node) && node.function.ref?.name === 'check' && isFromStdlib(node.function.ref);
362+
return isInvocationExpr(node) && node.function.ref?.name === 'check';
362363
}
363364

364365
export function resolveTransitiveImports(documents: LangiumDocuments, model: Model) {

packages/language/src/validators/function-invocation-validator.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ export default class FunctionInvocationValidator implements AstValidator<Express
170170
return true;
171171
}
172172

173+
// TODO: move this to policy plugin
173174
@func('check')
174175
// @ts-expect-error
175176
private _checkCheck(expr: InvocationExpr, accept: ValidationAcceptor) {

packages/runtime/src/client/contract.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,11 @@ export interface ClientConstructor {
213213
*/
214214
export type CRUD = 'create' | 'read' | 'update' | 'delete';
215215

216+
/**
217+
* CRUD operations.
218+
*/
219+
export const CRUD = ['create', 'read', 'update', 'delete'] as const;
220+
216221
//#region Model operations
217222

218223
export type AllModelOperations<Schema extends SchemaDef, Model extends GetModels<Schema>> = {

packages/runtime/src/client/executor/kysely-utils.ts

Lines changed: 0 additions & 12 deletions
This file was deleted.

packages/runtime/src/client/executor/name-mapper.ts

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ import {
1717
type OperationNode,
1818
} from 'kysely';
1919
import type { FieldDef, ModelDef, SchemaDef } from '../../schema';
20+
import { extractFieldName, extractModelName, stripAlias } from '../kysely-utils';
2021
import { getModel, requireModel } from '../query-utils';
21-
import { stripAlias } from './kysely-utils';
2222

2323
type Scope = {
2424
model?: string;
@@ -170,7 +170,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
170170
const scopes: Scope[] = node.from.froms.map((node) => {
171171
const { alias, node: innerNode } = stripAlias(node);
172172
return {
173-
model: this.extractModelName(innerNode),
173+
model: extractModelName(innerNode),
174174
alias,
175175
namesMapped: false,
176176
};
@@ -219,8 +219,8 @@ export class QueryNameMapper extends OperationNodeTransformer {
219219
selections.push(SelectionNode.create(transformed));
220220
} else {
221221
// otherwise use an alias to preserve the original field name
222-
const origFieldName = this.extractFieldName(selection.selection);
223-
const fieldName = this.extractFieldName(transformed);
222+
const origFieldName = extractFieldName(selection.selection);
223+
const fieldName = extractFieldName(transformed);
224224
if (fieldName !== origFieldName) {
225225
selections.push(
226226
SelectionNode.create(
@@ -425,7 +425,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
425425
private processSelection(node: AliasNode | ColumnNode | ReferenceNode) {
426426
let alias: string | undefined;
427427
if (!AliasNode.is(node)) {
428-
alias = this.extractFieldName(node);
428+
alias = extractFieldName(node);
429429
}
430430
const result = super.transformNode(node);
431431
return this.wrapAlias(result, alias ? IdentifierNode.create(alias) : undefined);
@@ -451,20 +451,5 @@ export class QueryNameMapper extends OperationNodeTransformer {
451451
});
452452
}
453453

454-
private extractModelName(node: OperationNode): string | undefined {
455-
const { node: innerNode } = stripAlias(node);
456-
return TableNode.is(innerNode!) ? innerNode!.table.identifier.name : undefined;
457-
}
458-
459-
private extractFieldName(node: ReferenceNode | ColumnNode) {
460-
if (ReferenceNode.is(node) && ColumnNode.is(node.column)) {
461-
return node.column.column.name;
462-
} else if (ColumnNode.is(node)) {
463-
return node.column.name;
464-
} else {
465-
return undefined;
466-
}
467-
}
468-
469454
// #endregion
470455
}

packages/runtime/src/client/executor/zenstack-query-executor.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ import type { GetModels, SchemaDef } from '../../schema';
2626
import { type ClientImpl } from '../client-impl';
2727
import { TransactionIsolationLevel, type ClientContract } from '../contract';
2828
import { InternalError, QueryError } from '../errors';
29+
import { stripAlias } from '../kysely-utils';
2930
import type { AfterEntityMutationCallback, OnKyselyQueryCallback } from '../plugin';
30-
import { stripAlias } from './kysely-utils';
3131
import { QueryNameMapper } from './name-mapper';
3232
import type { ZenStackDriver } from './zenstack-driver';
3333

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import { type OperationNode, AliasNode, ColumnNode, ReferenceNode, TableNode } from 'kysely';
2+
3+
/**
4+
* Strips alias from the node if it exists.
5+
*/
6+
export function stripAlias(node: OperationNode) {
7+
if (AliasNode.is(node)) {
8+
return { alias: node.alias, node: node.node };
9+
} else {
10+
return { alias: undefined, node };
11+
}
12+
}
13+
14+
/**
15+
* Extracts model name from an OperationNode.
16+
*/
17+
export function extractModelName(node: OperationNode) {
18+
const { node: innerNode } = stripAlias(node);
19+
return TableNode.is(innerNode!) ? innerNode!.table.identifier.name : undefined;
20+
}
21+
22+
/**
23+
* Extracts field name from an OperationNode.
24+
*/
25+
export function extractFieldName(node: OperationNode) {
26+
if (ReferenceNode.is(node) && ColumnNode.is(node.column)) {
27+
return node.column.column.name;
28+
} else if (ColumnNode.is(node)) {
29+
return node.column.name;
30+
} else {
31+
return undefined;
32+
}
33+
}

packages/runtime/src/client/options.ts

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,29 @@ import type { RuntimePlugin } from './plugin';
77
import type { ToKyselySchema } from './query-builder';
88

99
export type ZModelFunctionContext<Schema extends SchemaDef> = {
10+
/**
11+
* ZenStack client instance
12+
*/
13+
client: ClientContract<Schema>;
14+
15+
/**
16+
* Database dialect
17+
*/
1018
dialect: BaseCrudDialect<Schema>;
19+
20+
/**
21+
* The containing model name
22+
*/
1123
model: GetModels<Schema>;
24+
25+
/**
26+
* The alias name that can be used to refer to the containing model
27+
*/
28+
modelAlias: string;
29+
30+
/**
31+
* The CRUD operation being performed
32+
*/
1233
operation: CRUD;
1334
};
1435

packages/runtime/src/client/plugin.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import type { ClientContract } from '.';
33
import type { GetModels, SchemaDef } from '../schema';
44
import type { MaybePromise } from '../utils/type-utils';
55
import type { AllCrudOperation } from './crud/operations/base';
6+
import type { ZModelFunction } from './options';
67

78
/**
89
* ZenStack runtime plugin.
@@ -23,6 +24,11 @@ export interface RuntimePlugin<Schema extends SchemaDef = SchemaDef> {
2324
*/
2425
description?: string;
2526

27+
/**
28+
* Custom function implementations.
29+
*/
30+
functions?: Record<string, ZModelFunction<Schema>>;
31+
2632
/**
2733
* Intercepts an ORM query.
2834
*/

0 commit comments

Comments
 (0)