Skip to content

Commit be82307

Browse files
authored
feat: implement relation check() function in ZModel (#1556)
1 parent 4cc0326 commit be82307

File tree

14 files changed

+1115
-47
lines changed

14 files changed

+1115
-47
lines changed

packages/schema/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
"change-case": "^4.1.2",
9898
"colors": "1.4.0",
9999
"commander": "^8.3.0",
100+
"deepmerge": "^4.3.1",
100101
"get-latest-version": "^5.0.1",
101102
"langium": "1.3.1",
102103
"lower-case-first": "^2.0.2",

packages/schema/src/language-server/validator/attribute-application-validator.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import pluralize from 'pluralize';
2121
import { AstValidator } from '../types';
2222
import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils';
2323

24-
// a registry of function handlers marked with @func
24+
// a registry of function handlers marked with @check
2525
const attributeCheckers = new Map<string, PropertyDescriptor>();
2626

2727
// function handler decorator

packages/schema/src/language-server/validator/function-invocation-validator.ts

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import {
22
Argument,
3+
DataModel,
34
DataModelAttribute,
45
DataModelFieldAttribute,
56
Expression,
67
FunctionDecl,
78
FunctionParam,
89
InvocationExpr,
910
isArrayExpr,
11+
isDataModel,
1012
isDataModelAttribute,
1113
isDataModelFieldAttribute,
1214
isLiteralExpr,
@@ -15,14 +17,29 @@ import {
1517
ExpressionContext,
1618
getDataModelFieldReference,
1719
getFunctionExpressionContext,
20+
getLiteral,
21+
isDataModelFieldReference,
1822
isEnumFieldReference,
1923
isFromStdlib,
2024
} from '@zenstackhq/sdk';
21-
import { AstNode, ValidationAcceptor } from 'langium';
22-
import { P, match } from 'ts-pattern';
25+
import { AstNode, streamAst, ValidationAcceptor } from 'langium';
26+
import { match, P } from 'ts-pattern';
27+
import { isCheckInvocation } from '../../utils/ast-utils';
2328
import { AstValidator } from '../types';
2429
import { typeAssignable } from './utils';
2530

31+
// a registry of function handlers marked with @func
32+
const invocationCheckers = new Map<string, PropertyDescriptor>();
33+
34+
// function handler decorator
35+
function func(name: string) {
36+
return function (_target: unknown, _propertyKey: string, descriptor: PropertyDescriptor) {
37+
if (!invocationCheckers.get(name)) {
38+
invocationCheckers.set(name, descriptor);
39+
}
40+
return descriptor;
41+
};
42+
}
2643
/**
2744
* InvocationExpr validation
2845
*/
@@ -104,6 +121,12 @@ export default class FunctionInvocationValidator implements AstValidator<Express
104121
}
105122
}
106123
}
124+
125+
// run checkers for specific functions
126+
const checker = invocationCheckers.get(expr.function.$refText);
127+
if (checker) {
128+
checker.value.call(this, expr, accept);
129+
}
107130
}
108131

109132
private validateArgs(funcDecl: FunctionDecl, args: Argument[], accept: ValidationAcceptor) {
@@ -167,4 +190,76 @@ export default class FunctionInvocationValidator implements AstValidator<Express
167190

168191
return true;
169192
}
193+
194+
@func('check')
195+
private _checkCheck(expr: InvocationExpr, accept: ValidationAcceptor) {
196+
let valid = true;
197+
198+
const fieldArg = expr.args[0].value;
199+
if (!isDataModelFieldReference(fieldArg) || !isDataModel(fieldArg.$resolvedType?.decl)) {
200+
accept('error', 'argument must be a relation field', { node: expr.args[0] });
201+
valid = false;
202+
}
203+
204+
if (fieldArg.$resolvedType?.array) {
205+
accept('error', 'argument cannot be an array field', { node: expr.args[0] });
206+
valid = false;
207+
}
208+
209+
const opArg = expr.args[1]?.value;
210+
if (opArg) {
211+
const operation = getLiteral<string>(opArg);
212+
if (!operation || !['read', 'create', 'update', 'delete'].includes(operation)) {
213+
accept('error', 'argument must be a "read", "create", "update", or "delete"', { node: expr.args[1] });
214+
valid = false;
215+
}
216+
}
217+
218+
if (!valid) {
219+
return;
220+
}
221+
222+
// check for cyclic relation checking
223+
const start = fieldArg.$resolvedType?.decl as DataModel;
224+
const tasks = [expr];
225+
const seen = new Set<DataModel>();
226+
227+
while (tasks.length > 0) {
228+
const currExpr = tasks.pop()!;
229+
const arg = currExpr.args[0]?.value;
230+
231+
if (!isDataModel(arg?.$resolvedType?.decl)) {
232+
continue;
233+
}
234+
235+
const currModel = arg.$resolvedType.decl;
236+
237+
if (seen.has(currModel)) {
238+
if (currModel === start) {
239+
accept('error', 'cyclic dependency detected when following the `check()` call', { node: expr });
240+
} else {
241+
// a cycle is detected but it doesn't start from the invocation expression we're checking,
242+
// just break here and the cycle will be reported when we validate the start of it
243+
}
244+
break;
245+
} else {
246+
seen.add(currModel);
247+
}
248+
249+
const policyAttrs = currModel.attributes.filter(
250+
(attr) => attr.decl.$refText === '@@allow' || attr.decl.$refText === '@@deny'
251+
);
252+
for (const attr of policyAttrs) {
253+
const rule = attr.args[1];
254+
if (!rule) {
255+
continue;
256+
}
257+
streamAst(rule).forEach((node) => {
258+
if (isCheckInvocation(node)) {
259+
tasks.push(node as InvocationExpr);
260+
}
261+
});
262+
}
263+
}
264+
}
170265
}

packages/schema/src/plugins/enhancer/policy/expression-writer.ts

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@ import {
1919
StringLiteral,
2020
UnaryExpr,
2121
} from '@zenstackhq/language/ast';
22-
import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime';
22+
import { DELEGATE_AUX_RELATION_PREFIX, PolicyOperationKind } from '@zenstackhq/runtime';
2323
import {
2424
ExpressionContext,
2525
getFunctionExpressionContext,
2626
getIdFields,
2727
getLiteral,
28+
getQueryGuardFunctionName,
2829
isAuthInvocation,
2930
isDataModelFieldReference,
3031
isDelegateModel,
32+
isFromStdlib,
3133
isFutureExpr,
3234
PluginError,
3335
TypeScriptExpressionTransformer,
@@ -37,6 +39,7 @@ import { lowerCaseFirst } from 'lower-case-first';
3739
import invariant from 'tiny-invariant';
3840
import { CodeBlockWriter } from 'ts-morph';
3941
import { name } from '..';
42+
import { isCheckInvocation } from '../../../utils/ast-utils';
4043

4144
type ComparisonOperator = '==' | '!=' | '>' | '>=' | '<' | '<=';
4245

@@ -60,6 +63,11 @@ type FilterOperators =
6063
export const TRUE = '{ AND: [] }';
6164
export const FALSE = '{ OR: [] }';
6265

66+
export type ExpressionWriterOptions = {
67+
isPostGuard?: boolean;
68+
operationContext: PolicyOperationKind;
69+
};
70+
6371
/**
6472
* Utility for writing ZModel expression as Prisma query argument objects into a ts-morph writer
6573
*/
@@ -68,15 +76,14 @@ export class ExpressionWriter {
6876

6977
/**
7078
* Constructs a new ExpressionWriter
71-
*
72-
* @param isPostGuard indicates if we're writing for post-update conditions
7379
*/
74-
constructor(private readonly writer: CodeBlockWriter, private readonly isPostGuard = false) {
80+
constructor(private readonly writer: CodeBlockWriter, private readonly options: ExpressionWriterOptions) {
7581
this.plainExprBuilder = new TypeScriptExpressionTransformer({
7682
context: ExpressionContext.AccessPolicy,
77-
isPostGuard: this.isPostGuard,
83+
isPostGuard: this.options.isPostGuard,
7884
// in post-guard context, `this` references pre-update value
79-
thisExprContext: this.isPostGuard ? 'context.preValue' : undefined,
85+
thisExprContext: this.options.isPostGuard ? 'context.preValue' : undefined,
86+
operationContext: this.options.operationContext,
8087
});
8188
}
8289

@@ -269,17 +276,20 @@ export class ExpressionWriter {
269276
// expression rooted to `auth()` is always compiled to plain expression
270277
!this.isAuthOrAuthMemberAccess(expr.left) &&
271278
// `future()` in post-update context
272-
((this.isPostGuard && this.isFutureMemberAccess(expr.left)) ||
279+
((this.options.isPostGuard && this.isFutureMemberAccess(expr.left)) ||
273280
// non-`future()` in pre-update context
274-
(!this.isPostGuard && !this.isFutureMemberAccess(expr.left)));
281+
(!this.options.isPostGuard && !this.isFutureMemberAccess(expr.left)));
275282

276283
if (compileToRelationQuery) {
277284
this.block(() => {
278285
this.writeFieldCondition(
279286
expr.left,
280287
() => {
281288
// inner scope of collection expression is always compiled as non-post-guard
282-
const innerWriter = new ExpressionWriter(this.writer, false);
289+
const innerWriter = new ExpressionWriter(this.writer, {
290+
isPostGuard: false,
291+
operationContext: this.options.operationContext,
292+
});
283293
innerWriter.write(expr.right);
284294
},
285295
operator === '?' ? 'some' : operator === '!' ? 'every' : 'none'
@@ -297,14 +307,14 @@ export class ExpressionWriter {
297307
}
298308

299309
if (isMemberAccessExpr(expr)) {
300-
if (isFutureExpr(expr.operand) && this.isPostGuard) {
310+
if (isFutureExpr(expr.operand) && this.options.isPostGuard) {
301311
// when writing for post-update, future().field.x is a field access
302312
return true;
303313
} else {
304314
return this.isFieldAccess(expr.operand);
305315
}
306316
}
307-
if (isDataModelFieldReference(expr) && !this.isPostGuard) {
317+
if (isDataModelFieldReference(expr) && !this.options.isPostGuard) {
308318
return true;
309319
}
310320
return false;
@@ -437,7 +447,7 @@ export class ExpressionWriter {
437447
this.writer.write(operator === '!=' ? TRUE : FALSE);
438448
} else {
439449
this.writeOperator(operator, fieldAccess, () => {
440-
if (isDataModelFieldReference(operand) && !this.isPostGuard) {
450+
if (isDataModelFieldReference(operand) && !this.options.isPostGuard) {
441451
// if operand is a field reference and we're not generating for post-update guard,
442452
// we should generate a field reference (comparing fields in the same model)
443453
this.writeFieldReference(operand);
@@ -735,6 +745,11 @@ export class ExpressionWriter {
735745
functionAllowedContext.includes(ExpressionContext.AccessPolicy) ||
736746
functionAllowedContext.includes(ExpressionContext.ValidationRule)
737747
) {
748+
if (isCheckInvocation(expr)) {
749+
this.writeRelationCheck(expr);
750+
return;
751+
}
752+
738753
if (!expr.args.some((arg) => this.isFieldAccess(arg.value))) {
739754
// filter functions without referencing fields
740755
this.guard(() => this.plain(expr));
@@ -744,13 +759,13 @@ export class ExpressionWriter {
744759
let valueArg = expr.args[1]?.value;
745760

746761
// isEmpty function is zero arity, it's mapped to a boolean literal
747-
if (funcDecl.name === 'isEmpty') {
762+
if (isFromStdlib(funcDecl) && funcDecl.name === 'isEmpty') {
748763
valueArg = { $type: BooleanLiteral, value: true } as LiteralExpr;
749764
}
750765

751766
// contains function has a 3rd argument that indicates whether the comparison should be case-insensitive
752767
let extraArgs: Record<string, Expression> | undefined = undefined;
753-
if (funcDecl.name === 'contains') {
768+
if (isFromStdlib(funcDecl) && funcDecl.name === 'contains') {
754769
if (getLiteral<boolean>(expr.args[2]?.value) === true) {
755770
extraArgs = { mode: { $type: StringLiteral, value: 'insensitive' } as LiteralExpr };
756771
}
@@ -770,4 +785,38 @@ export class ExpressionWriter {
770785
throw new PluginError(name, `Unsupported function ${funcDecl.name}`);
771786
}
772787
}
788+
789+
private writeRelationCheck(expr: InvocationExpr) {
790+
if (!isDataModelFieldReference(expr.args[0].value)) {
791+
throw new PluginError(name, `First argument of check() must be a field`);
792+
}
793+
if (!isDataModel(expr.args[0].value.$resolvedType?.decl)) {
794+
throw new PluginError(name, `First argument of check() must be a relation field`);
795+
}
796+
797+
const fieldRef = expr.args[0].value;
798+
const targetModel = fieldRef.$resolvedType?.decl as DataModel;
799+
800+
let operation: string;
801+
if (expr.args[1]) {
802+
const literal = getLiteral<string>(expr.args[1].value);
803+
if (!literal) {
804+
throw new TypeScriptExpressionTransformerError(`Second argument of check() must be a string literal`);
805+
}
806+
if (!['read', 'create', 'update', 'delete'].includes(literal)) {
807+
throw new TypeScriptExpressionTransformerError(`Invalid check() operation "${literal}"`);
808+
}
809+
operation = literal;
810+
} else {
811+
if (!this.options.operationContext) {
812+
throw new TypeScriptExpressionTransformerError('Unable to determine CRUD operation from context');
813+
}
814+
operation = this.options.operationContext;
815+
}
816+
817+
this.block(() => {
818+
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
819+
this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`);
820+
});
821+
}
773822
}

0 commit comments

Comments
 (0)