Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/language/src/ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ declare module './ast' {
$resolvedParam?: AttributeParam;
}

interface BinaryExpr {
/**
* Optional iterator binding for collection predicates
*/
binding?: string;
}

export interface DataModel {
/**
* All fields including those marked with `@ignore`
Expand Down
8 changes: 6 additions & 2 deletions packages/language/src/generated/ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ export function isMemberAccessTarget(item: unknown): item is MemberAccessTarget
return reflection.isInstance(item, MemberAccessTarget);
}

export type ReferenceTarget = DataField | EnumField | FunctionParam;
export type ReferenceTarget = BinaryExpr | DataField | EnumField | FunctionParam;

export const ReferenceTarget = 'ReferenceTarget';

Expand Down Expand Up @@ -256,6 +256,7 @@ export function isAttributeParamType(item: unknown): item is AttributeParamType
export interface BinaryExpr extends langium.AstNode {
readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | FieldInitializer | FunctionDecl | MemberAccessExpr | ReferenceArg | UnaryExpr;
readonly $type: 'BinaryExpr';
binding?: RegularID;
left: Expression;
operator: '!' | '!=' | '&&' | '<' | '<=' | '==' | '>' | '>=' | '?' | '^' | 'in' | '||';
right: Expression;
Expand Down Expand Up @@ -826,7 +827,6 @@ export class ZModelAstReflection extends langium.AbstractAstReflection {
protected override computeIsSubtype(subtype: string, supertype: string): boolean {
switch (subtype) {
case ArrayExpr:
case BinaryExpr:
case MemberAccessExpr:
case NullExpr:
case ObjectExpr:
Expand All @@ -843,6 +843,9 @@ export class ZModelAstReflection extends langium.AbstractAstReflection {
case Procedure: {
return this.isSubtype(AbstractDeclaration, supertype);
}
case BinaryExpr: {
return this.isSubtype(Expression, supertype) || this.isSubtype(ReferenceTarget, supertype);
}
case BooleanLiteral:
case NumberLiteral:
case StringLiteral: {
Expand Down Expand Up @@ -973,6 +976,7 @@ export class ZModelAstReflection extends langium.AbstractAstReflection {
return {
name: BinaryExpr,
properties: [
{ name: 'binding' },
{ name: 'left' },
{ name: 'operator' },
{ name: 'right' }
Expand Down
28 changes: 28 additions & 0 deletions packages/language/src/generated/grammar.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,28 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"$type": "Keyword",
"value": "["
},
{
"$type": "Group",
"elements": [
{
"$type": "Assignment",
"feature": "binding",
"operator": "=",
"terminal": {
"$type": "RuleCall",
"rule": {
"$ref": "#/rules@51"
},
"arguments": []
}
},
{
"$type": "Keyword",
"value": ","
}
],
"cardinality": "?"
},
{
"$type": "Assignment",
"feature": "right",
Expand Down Expand Up @@ -3996,6 +4018,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"typeRef": {
"$ref": "#/rules@45"
}
},
{
"$type": "SimpleType",
"typeRef": {
"$ref": "#/rules@29/definition/elements@1/elements@0/inferredType"
}
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import {
DataFieldAttribute,
DataModelAttribute,
InternalAttribute,
ReferenceExpr,
isArrayExpr,
isAttribute,
isConfigArrayExpr,
Expand Down Expand Up @@ -491,9 +490,16 @@ function isValidAttributeTarget(attrDecl: Attribute, targetDecl: DataField) {
return true;
}

const fieldTypes = (targetField.args[0].value as ArrayExpr).items.map(
(item) => (item as ReferenceExpr).target.ref?.name,
);
const fieldTypes = (targetField.args[0].value as ArrayExpr).items
.map((item) => {
if (!isReferenceExpr(item)) {
return undefined;
}

const ref = item.target.ref;
return ref && 'name' in ref && typeof ref.name === 'string' ? ref.name : undefined;
})
.filter((name): name is string => !!name);

let allowed = false;
for (const allowedType of fieldTypes) {
Expand Down
8 changes: 5 additions & 3 deletions packages/language/src/zmodel-code-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,15 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${

const { left: isLeftParenthesis, right: isRightParenthesis } = this.isParenthesesNeededForBinaryExpr(ast);

const collectionPredicate = isCollectionPredicate
? `[${ast.binding ? `${ast.binding}, ${rightExpr}` : rightExpr}]`
: rightExpr;

return `${isLeftParenthesis ? '(' : ''}${this.generate(ast.left)}${
isLeftParenthesis ? ')' : ''
}${isCollectionPredicate ? '' : this.binaryExprSpace}${operator}${
isCollectionPredicate ? '' : this.binaryExprSpace
}${isRightParenthesis ? '(' : ''}${
isCollectionPredicate ? `[${rightExpr}]` : rightExpr
}${isRightParenthesis ? ')' : ''}`;
}${isRightParenthesis ? '(' : ''}${collectionPredicate}${isRightParenthesis ? ')' : ''}`;
}

@gen(ReferenceExpr)
Expand Down
37 changes: 29 additions & 8 deletions packages/language/src/zmodel-linker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
DataModel,
Enum,
EnumField,
isBinaryExpr,
type ExpressionType,
FunctionDecl,
FunctionParam,
Expand Down Expand Up @@ -121,7 +122,13 @@ export class ZModelLinker extends DefaultLinker {
const target = provider(reference.$refText);
if (target) {
reference._ref = target;
reference._nodeDescription = this.descriptions.createDescription(target, target.name, document);
let targetName = reference.$refText;
if ('name' in target && typeof target.name === 'string') {
targetName = target.name;
} else if ('binding' in target && typeof (target as { binding?: unknown }).binding === 'string') {
targetName = (target as { binding: string }).binding;
}
reference._nodeDescription = this.descriptions.createDescription(target, targetName, document);

// Add the reference to the document's array of references
document.references.push(reference);
Expand Down Expand Up @@ -249,13 +256,24 @@ export class ZModelLinker extends DefaultLinker {

private resolveReference(node: ReferenceExpr, document: LangiumDocument<AstNode>, extraScopes: ScopeProvider[]) {
this.resolveDefault(node, document, extraScopes);

if (node.target.ref) {
// resolve type
if (node.target.ref.$type === EnumField) {
this.resolveToBuiltinTypeOrDecl(node, node.target.ref.$container);
} else {
this.resolveToDeclaredType(node, (node.target.ref as DataField | FunctionParam).type);
const target = node.target.ref;

if (target) {
if (isBinaryExpr(target) && ['?', '!', '^'].includes(target.operator)) {
const collectionType = target.left.$resolvedType;
if (collectionType?.decl) {
node.$resolvedType = {
decl: collectionType.decl,
array: false,
nullable: collectionType.nullable,
};
}
} else if (target.$type === EnumField) {
this.resolveToBuiltinTypeOrDecl(node, target.$container);
} else if (isDataField(target)) {
this.resolveToDeclaredType(node, target.type);
} else if (target.$type === FunctionParam && (target as FunctionParam).type) {
this.resolveToDeclaredType(node, (target as FunctionParam).type);
}
}
}
Expand Down Expand Up @@ -506,6 +524,9 @@ export class ZModelLinker extends DefaultLinker {
//#region Utils

private resolveToDeclaredType(node: AstNode, type: FunctionParamType | DataFieldType) {
if (!type) {
return;
}
let nullable = false;
if (isDataFieldType(type)) {
nullable = type.optional;
Expand Down
21 changes: 21 additions & 0 deletions packages/language/src/zmodel-scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
StreamScope,
UriUtils,
interruptAndCheck,
stream,
type AstNode,
type AstNodeDescription,
type LangiumCoreServices,
Expand All @@ -18,7 +19,9 @@ import {
import { match } from 'ts-pattern';
import {
BinaryExpr,
Expression,
MemberAccessExpr,
isBinaryExpr,
isDataField,
isDataModel,
isEnumField,
Expand Down Expand Up @@ -145,6 +148,9 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
.when(isReferenceExpr, (operand) => {
// operand is a reference, it can only be a model/type-def field
const ref = operand.target.ref;
if (isBinaryExpr(ref) && isCollectionPredicate(ref)) {
return this.createScopeForCollectionElement(ref.left, globalScope, allowTypeDefScope);
}
if (isDataField(ref)) {
return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope);
}
Expand Down Expand Up @@ -188,6 +194,21 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
// // typedef's fields are only added to the scope if the access starts with `auth().`
const allowTypeDefScope = isAuthOrAuthMemberAccess(collection);

const collectionScope = this.createScopeForCollectionElement(collection, globalScope, allowTypeDefScope);

if (collectionPredicate.binding) {
const description = this.descriptions.createDescription(
collectionPredicate,
collectionPredicate.binding,
collectionPredicate.$document!,
);
return new StreamScope(stream([description]), collectionScope);
}

return collectionScope;
}

private createScopeForCollectionElement(collection: Expression, globalScope: Scope, allowTypeDefScope: boolean) {
return match(collection)
.when(isReferenceExpr, (expr) => {
// collection is a reference - model or typedef field
Expand Down
4 changes: 2 additions & 2 deletions packages/language/src/zmodel.langium
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ ConfigArrayExpr:
ConfigExpr:
LiteralExpr | InvocationExpr | ConfigArrayExpr;

type ReferenceTarget = FunctionParam | DataField | EnumField;
type ReferenceTarget = FunctionParam | DataField | EnumField | BinaryExpr;

ThisExpr:
value='this';
Expand Down Expand Up @@ -113,7 +113,7 @@ CollectionPredicateExpr infers Expression:
MemberAccessExpr (
{infer BinaryExpr.left=current}
operator=('?'|'!'|'^')
'[' right=Expression ']'
'[' (binding=RegularID ',')? right=Expression ']'
)*;

InExpr infers Expression:
Expand Down
44 changes: 44 additions & 0 deletions packages/language/test/expression-validation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,48 @@ describe('Expression Validation Tests', () => {
'incompatible operand types',
);
});

it('should allow collection predicate with iterator binding', async () => {
await loadSchema(`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id Int @id
memberships Membership[]
@@allow('read', memberships?[m, m.tenantId == id])
}

model Membership {
id Int @id
tenantId Int
user User @relation(fields: [userId], references: [id])
userId Int
}
`);
});

it('should keep supporting unbound collection predicate syntax', async () => {
await loadSchema(`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id Int @id
memberships Membership[]
@@allow('read', memberships?[tenantId == id])
}

model Membership {
id Int @id
tenantId Int
user User @relation(fields: [userId], references: [id])
userId Int
}
`);
});
});
31 changes: 29 additions & 2 deletions packages/plugins/policy/src/expression-evaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
type ExpressionEvaluatorContext = {
auth?: any;
thisValue?: any;
scope?: Record<string, any>;
};

/**
Expand Down Expand Up @@ -64,6 +65,9 @@ export class ExpressionEvaluator {
}

private evaluateField(expr: FieldExpression, context: ExpressionEvaluatorContext): any {
if (context.scope && expr.field in context.scope) {
return context.scope[expr.field];
}
return context.thisValue?.[expr.field];
}

Expand Down Expand Up @@ -113,15 +117,38 @@ export class ExpressionEvaluator {
invariant(Array.isArray(left), 'expected array');

return match(op)
.with('?', () => left.some((item: any) => this.evaluate(expr.right, { ...context, thisValue: item })))
.with('!', () => left.every((item: any) => this.evaluate(expr.right, { ...context, thisValue: item })))
.with('?', () =>
left.some((item: any) =>
this.evaluate(expr.right, {
...context,
thisValue: item,
scope: expr.binding
? { ...(context.scope ?? {}), [expr.binding]: item }
: context.scope,
}),
),
)
.with('!', () =>
left.every((item: any) =>
this.evaluate(expr.right, {
...context,
thisValue: item,
scope: expr.binding
? { ...(context.scope ?? {}), [expr.binding]: item }
: context.scope,
}),
),
)
.with(
'^',
() =>
!left.some((item: any) =>
this.evaluate(expr.right, {
...context,
thisValue: item,
scope: expr.binding
? { ...(context.scope ?? {}), [expr.binding]: item }
: context.scope,
}),
),
)
Expand Down
Loading
Loading