Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions packages/language/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -594,16 +594,6 @@ function datetime(field: String): Boolean {
function url(field: String): Boolean {
} @@@expressionContext([ValidationRule])

/**
* Checks if the current user can perform the given operation on the given field.
*
* @param field: The field to check access for
* @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided,
* it defaults the operation of the containing policy rule.
*/
function check(field: Any, operation: String?): Boolean {
} @@@expressionContext([AccessPolicy])

//////////////////////////////////////////////
// End validation attributes and functions
//////////////////////////////////////////////
Expand Down
3 changes: 2 additions & 1 deletion packages/language/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,9 @@ export function getFieldReference(expr: Expression): DataField | undefined {
}
}

// TODO: move to policy plugin
export function isCheckInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'check' && isFromStdlib(node.function.ref);
return isInvocationExpr(node) && node.function.ref?.name === 'check';
}

export function resolveTransitiveImports(documents: LangiumDocuments, model: Model) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ export default class FunctionInvocationValidator implements AstValidator<Express
return true;
}

// TODO: move this to policy plugin
@func('check')
// @ts-expect-error
private _checkCheck(expr: InvocationExpr, accept: ValidationAcceptor) {
Expand Down
5 changes: 5 additions & 0 deletions packages/runtime/src/client/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ export interface ClientConstructor {
*/
export type CRUD = 'create' | 'read' | 'update' | 'delete';

/**
* CRUD operations.
*/
export const CRUD = ['create', 'read', 'update', 'delete'] as const;

//#region Model operations

export type AllModelOperations<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
Expand Down
12 changes: 0 additions & 12 deletions packages/runtime/src/client/executor/kysely-utils.ts

This file was deleted.

25 changes: 5 additions & 20 deletions packages/runtime/src/client/executor/name-mapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import {
type OperationNode,
} from 'kysely';
import type { FieldDef, ModelDef, SchemaDef } from '../../schema';
import { extractFieldName, extractModelName, stripAlias } from '../kysely-utils';
import { getModel, requireModel } from '../query-utils';
import { stripAlias } from './kysely-utils';

type Scope = {
model?: string;
Expand Down Expand Up @@ -170,7 +170,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
const scopes: Scope[] = node.from.froms.map((node) => {
const { alias, node: innerNode } = stripAlias(node);
return {
model: this.extractModelName(innerNode),
model: extractModelName(innerNode),
alias,
namesMapped: false,
};
Expand Down Expand Up @@ -219,8 +219,8 @@ export class QueryNameMapper extends OperationNodeTransformer {
selections.push(SelectionNode.create(transformed));
} else {
// otherwise use an alias to preserve the original field name
const origFieldName = this.extractFieldName(selection.selection);
const fieldName = this.extractFieldName(transformed);
const origFieldName = extractFieldName(selection.selection);
const fieldName = extractFieldName(transformed);
if (fieldName !== origFieldName) {
selections.push(
SelectionNode.create(
Expand Down Expand Up @@ -425,7 +425,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
private processSelection(node: AliasNode | ColumnNode | ReferenceNode) {
let alias: string | undefined;
if (!AliasNode.is(node)) {
alias = this.extractFieldName(node);
alias = extractFieldName(node);
}
const result = super.transformNode(node);
return this.wrapAlias(result, alias ? IdentifierNode.create(alias) : undefined);
Expand All @@ -451,20 +451,5 @@ export class QueryNameMapper extends OperationNodeTransformer {
});
}

private extractModelName(node: OperationNode): string | undefined {
const { node: innerNode } = stripAlias(node);
return TableNode.is(innerNode!) ? innerNode!.table.identifier.name : undefined;
}

private extractFieldName(node: ReferenceNode | ColumnNode) {
if (ReferenceNode.is(node) && ColumnNode.is(node.column)) {
return node.column.column.name;
} else if (ColumnNode.is(node)) {
return node.column.name;
} else {
return undefined;
}
}

// #endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import type { GetModels, SchemaDef } from '../../schema';
import { type ClientImpl } from '../client-impl';
import { TransactionIsolationLevel, type ClientContract } from '../contract';
import { InternalError, QueryError } from '../errors';
import { stripAlias } from '../kysely-utils';
import type { AfterEntityMutationCallback, OnKyselyQueryCallback } from '../plugin';
import { stripAlias } from './kysely-utils';
import { QueryNameMapper } from './name-mapper';
import type { ZenStackDriver } from './zenstack-driver';

Expand Down
33 changes: 33 additions & 0 deletions packages/runtime/src/client/kysely-utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import { type OperationNode, AliasNode, ColumnNode, ReferenceNode, TableNode } from 'kysely';

/**
* Strips alias from the node if it exists.
*/
export function stripAlias(node: OperationNode) {
if (AliasNode.is(node)) {
return { alias: node.alias, node: node.node };
} else {
return { alias: undefined, node };
}
}

/**
* Extracts model name from an OperationNode.
*/
export function extractModelName(node: OperationNode) {
const { node: innerNode } = stripAlias(node);
return TableNode.is(innerNode!) ? innerNode!.table.identifier.name : undefined;
}

/**
* Extracts field name from an OperationNode.
*/
export function extractFieldName(node: OperationNode) {
if (ReferenceNode.is(node) && ColumnNode.is(node.column)) {
return node.column.column.name;
} else if (ColumnNode.is(node)) {
return node.column.name;
} else {
return undefined;
}
}
21 changes: 21 additions & 0 deletions packages/runtime/src/client/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,29 @@ import type { RuntimePlugin } from './plugin';
import type { ToKyselySchema } from './query-builder';

export type ZModelFunctionContext<Schema extends SchemaDef> = {
/**
* ZenStack client instance
*/
client: ClientContract<Schema>;

/**
* Database dialect
*/
dialect: BaseCrudDialect<Schema>;

/**
* The containing model name
*/
model: GetModels<Schema>;

/**
* The alias name that can be used to refer to the containing model
*/
modelAlias: string;

/**
* The CRUD operation being performed
*/
operation: CRUD;
};

Expand Down
6 changes: 6 additions & 0 deletions packages/runtime/src/client/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { ClientContract } from '.';
import type { GetModels, SchemaDef } from '../schema';
import type { MaybePromise } from '../utils/type-utils';
import type { AllCrudOperation } from './crud/operations/base';
import type { ZModelFunction } from './options';

/**
* ZenStack runtime plugin.
Expand All @@ -23,6 +24,11 @@ export interface RuntimePlugin<Schema extends SchemaDef = SchemaDef> {
*/
description?: string;

/**
* Custom function implementations.
*/
functions?: Record<string, ZModelFunction<Schema>>;

/**
* Intercepts an ORM query.
*/
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/src/client/query-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ export function buildFieldRef<Schema extends SchemaDef>(
if (!computer) {
throw new QueryError(`Computed field "${field}" implementation not provided for model "${model}"`);
}
return computer(eb, { currentModel: modelAlias });
return computer(eb, { modelAlias });
}
}

Expand Down
40 changes: 32 additions & 8 deletions packages/runtime/src/plugins/policy/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ import {
type OperationNode,
} from 'kysely';
import { match } from 'ts-pattern';
import type { CRUD } from '../../client/contract';
import type { ClientContract, CRUD } from '../../client/contract';
import { getCrudDialect } from '../../client/crud/dialects';
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
import { InternalError, QueryError } from '../../client/errors';
import type { ClientOptions } from '../../client/options';
import { getModel, getRelationForeignKeyFieldPairs, requireField, requireIdFields } from '../../client/query-utils';
import type {
BinaryExpression,
Expand Down Expand Up @@ -72,14 +71,22 @@ function expr(kind: Expression['kind']) {
export class ExpressionTransformer<Schema extends SchemaDef> {
private readonly dialect: BaseCrudDialect<Schema>;

constructor(
private readonly schema: Schema,
private readonly clientOptions: ClientOptions<Schema>,
private readonly auth: unknown | undefined,
) {
constructor(private readonly client: ClientContract<Schema>) {
this.dialect = getCrudDialect(this.schema, this.clientOptions);
}

get schema() {
return this.client.$schema;
}

get clientOptions() {
return this.client.$options;
}

get auth() {
return this.client.$auth;
}

get authType() {
if (!this.schema.authType) {
throw new InternalError('Schema does not have an "authType" specified');
Expand Down Expand Up @@ -354,7 +361,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

private transformCall(expr: CallExpression, context: ExpressionTransformerContext<Schema>) {
const func = this.clientOptions.functions?.[expr.function];
const func = this.getFunctionImpl(expr.function);
if (!func) {
throw new QueryError(`Function not implemented: ${expr.function}`);
}
Expand All @@ -363,13 +370,30 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
eb,
(expr.args ?? []).map((arg) => this.transformCallArg(eb, arg, context)),
{
client: this.client,
dialect: this.dialect,
model: context.model,
modelAlias: context.alias ?? context.model,
operation: context.operation,
},
);
}

private getFunctionImpl(functionName: string) {
// check built-in functions
let func = this.clientOptions.functions?.[functionName];
if (!func) {
// check plugins
for (const plugin of this.clientOptions.plugins ?? []) {
if (plugin.functions?.[functionName]) {
func = plugin.functions[functionName];
break;
}
}
}
return func;
}

private transformCallArg(
eb: ExpressionBuilder<any, any>,
arg: Expression,
Expand Down
62 changes: 62 additions & 0 deletions packages/runtime/src/plugins/policy/functions.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import { invariant } from '@zenstackhq/common-helpers';
import { ExpressionWrapper, ValueNode, type Expression, type ExpressionBuilder } from 'kysely';
import { CRUD } from '../../client/contract';
import { extractFieldName } from '../../client/kysely-utils';
import type { ZModelFunction, ZModelFunctionContext } from '../../client/options';
import { buildJoinPairs, requireField } from '../../client/query-utils';
import { PolicyHandler } from './policy-handler';

/**
* Relation checker implementation.
*/
export const check: ZModelFunction<any> = (
eb: ExpressionBuilder<any, any>,
args: Expression<any>[],
{ client, model, modelAlias, operation }: ZModelFunctionContext<any>,
) => {
invariant(args.length === 1 || args.length === 2, '"check" function requires 1 or 2 arguments');

const arg1Node = args[0]!.toOperationNode();

const arg2Node = args.length === 2 ? args[1]!.toOperationNode() : undefined;
if (arg2Node) {
invariant(
ValueNode.is(arg2Node) && typeof arg2Node.value === 'string',
'"operation" parameter must be a string literal when provided',
);
invariant(
CRUD.includes(arg2Node.value as CRUD),
'"operation" parameter must be one of "create", "read", "update", "delete"',
);
}

// first argument must be a field reference
const fieldName = extractFieldName(arg1Node);
invariant(fieldName, 'Failed to extract field name from the first argument of "check" function');
const fieldDef = requireField(client.$schema, model, fieldName);
invariant(fieldDef.relation, `Field "${fieldName}" is not a relation field in model "${model}"`);
invariant(!fieldDef.array, `Field "${fieldName}" is a to-many relation, which is not supported by "check"`);
const relationModel = fieldDef.type;

const op = arg2Node ? (arg2Node.value as CRUD) : operation;

const policyHandler = new PolicyHandler(client);

// join with parent model
const joinPairs = buildJoinPairs(client.$schema, model, modelAlias, fieldName, relationModel);
const joinCondition =
joinPairs.length === 1
? eb(eb.ref(joinPairs[0]![0]), '=', eb.ref(joinPairs[0]![1]))
: eb.and(joinPairs.map(([left, right]) => eb(eb.ref(left), '=', eb.ref(right))));

// policy condition of the related model
const policyCondition = policyHandler.buildPolicyFilter(relationModel, undefined, op);

// build the final nested select that evaluates the policy condition
const result = eb
.selectFrom(relationModel)
.where(joinCondition)
.select(new ExpressionWrapper(policyCondition).as('$condition'));

return result;
};
7 changes: 7 additions & 0 deletions packages/runtime/src/plugins/policy/plugin.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { type OnKyselyQueryArgs, type RuntimePlugin } from '../../client/plugin';
import type { SchemaDef } from '../../schema';
import { check } from './functions';
import { PolicyHandler } from './policy-handler';

export class PolicyPlugin<Schema extends SchemaDef> implements RuntimePlugin<Schema> {
Expand All @@ -15,6 +16,12 @@ export class PolicyPlugin<Schema extends SchemaDef> implements RuntimePlugin<Sch
return 'Enforces access policies defined in the schema.';
}

get functions() {
return {
check,
};
}

onKyselyQuery({ query, client, proceed /*, transaction*/ }: OnKyselyQueryArgs<Schema>) {
const handler = new PolicyHandler<Schema>(client);
return handler.handle(query, proceed /*, transaction*/);
Expand Down
Loading