diff --git a/packages/language/src/ast.ts b/packages/language/src/ast.ts index 71d31d4d..c00cf647 100644 --- a/packages/language/src/ast.ts +++ b/packages/language/src/ast.ts @@ -46,6 +46,13 @@ declare module './ast' { $resolvedParam?: AttributeParam; } + interface BinaryExpr { + /** + * Optional iterator binding for collection predicates + */ + binding?: string; + } + export interface DataModel { /** * All fields including those marked with `@ignore` diff --git a/packages/language/src/generated/ast.ts b/packages/language/src/generated/ast.ts index e759aa1f..54a859ad 100644 --- a/packages/language/src/generated/ast.ts +++ b/packages/language/src/generated/ast.ts @@ -142,7 +142,7 @@ export function isMemberAccessTarget(item: unknown): item is MemberAccessTarget return reflection.isInstance(item, MemberAccessTarget); } -export type ReferenceTarget = DataField | EnumField | FunctionParam; +export type ReferenceTarget = BinaryExpr | DataField | EnumField | FunctionParam; export const ReferenceTarget = 'ReferenceTarget'; @@ -256,6 +256,7 @@ export function isAttributeParamType(item: unknown): item is AttributeParamType export interface BinaryExpr extends langium.AstNode { readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | FieldInitializer | FunctionDecl | MemberAccessExpr | ReferenceArg | UnaryExpr; readonly $type: 'BinaryExpr'; + binding?: RegularID; left: Expression; operator: '!' | '!=' | '&&' | '<' | '<=' | '==' | '>' | '>=' | '?' | '^' | 'in' | '||'; right: Expression; @@ -826,7 +827,6 @@ export class ZModelAstReflection extends langium.AbstractAstReflection { protected override computeIsSubtype(subtype: string, supertype: string): boolean { switch (subtype) { case ArrayExpr: - case BinaryExpr: case MemberAccessExpr: case NullExpr: case ObjectExpr: @@ -843,6 +843,9 @@ export class ZModelAstReflection extends langium.AbstractAstReflection { case Procedure: { return this.isSubtype(AbstractDeclaration, supertype); } + case BinaryExpr: { + return this.isSubtype(Expression, supertype) || this.isSubtype(ReferenceTarget, supertype); + } case BooleanLiteral: case NumberLiteral: case StringLiteral: { @@ -973,6 +976,7 @@ export class ZModelAstReflection extends langium.AbstractAstReflection { return { name: BinaryExpr, properties: [ + { name: 'binding' }, { name: 'left' }, { name: 'operator' }, { name: 'right' } diff --git a/packages/language/src/generated/grammar.ts b/packages/language/src/generated/grammar.ts index 02260ccd..6be9b88d 100644 --- a/packages/language/src/generated/grammar.ts +++ b/packages/language/src/generated/grammar.ts @@ -1418,6 +1418,28 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "$type": "Keyword", "value": "[" }, + { + "$type": "Group", + "elements": [ + { + "$type": "Assignment", + "feature": "binding", + "operator": "=", + "terminal": { + "$type": "RuleCall", + "rule": { + "$ref": "#/rules@51" + }, + "arguments": [] + } + }, + { + "$type": "Keyword", + "value": "," + } + ], + "cardinality": "?" + }, { "$type": "Assignment", "feature": "right", @@ -3996,6 +4018,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "typeRef": { "$ref": "#/rules@45" } + }, + { + "$type": "SimpleType", + "typeRef": { + "$ref": "#/rules@29/definition/elements@1/elements@0/inferredType" + } } ] } diff --git a/packages/language/src/validators/attribute-application-validator.ts b/packages/language/src/validators/attribute-application-validator.ts index 62df3a23..09124178 100644 --- a/packages/language/src/validators/attribute-application-validator.ts +++ b/packages/language/src/validators/attribute-application-validator.ts @@ -11,7 +11,6 @@ import { DataFieldAttribute, DataModelAttribute, InternalAttribute, - ReferenceExpr, isArrayExpr, isAttribute, isConfigArrayExpr, @@ -491,9 +490,16 @@ function isValidAttributeTarget(attrDecl: Attribute, targetDecl: DataField) { return true; } - const fieldTypes = (targetField.args[0].value as ArrayExpr).items.map( - (item) => (item as ReferenceExpr).target.ref?.name, - ); + const fieldTypes = (targetField.args[0].value as ArrayExpr).items + .map((item) => { + if (!isReferenceExpr(item)) { + return undefined; + } + + const ref = item.target.ref; + return ref && 'name' in ref && typeof ref.name === 'string' ? ref.name : undefined; + }) + .filter((name): name is string => !!name); let allowed = false; for (const allowedType of fieldTypes) { diff --git a/packages/language/src/zmodel-code-generator.ts b/packages/language/src/zmodel-code-generator.ts index 55efb5fc..1e0366ed 100644 --- a/packages/language/src/zmodel-code-generator.ts +++ b/packages/language/src/zmodel-code-generator.ts @@ -252,13 +252,15 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ const { left: isLeftParenthesis, right: isRightParenthesis } = this.isParenthesesNeededForBinaryExpr(ast); + const collectionPredicate = isCollectionPredicate + ? `[${ast.binding ? `${ast.binding}, ${rightExpr}` : rightExpr}]` + : rightExpr; + return `${isLeftParenthesis ? '(' : ''}${this.generate(ast.left)}${ isLeftParenthesis ? ')' : '' }${isCollectionPredicate ? '' : this.binaryExprSpace}${operator}${ isCollectionPredicate ? '' : this.binaryExprSpace - }${isRightParenthesis ? '(' : ''}${ - isCollectionPredicate ? `[${rightExpr}]` : rightExpr - }${isRightParenthesis ? ')' : ''}`; + }${isRightParenthesis ? '(' : ''}${collectionPredicate}${isRightParenthesis ? ')' : ''}`; } @gen(ReferenceExpr) diff --git a/packages/language/src/zmodel-linker.ts b/packages/language/src/zmodel-linker.ts index 3bb45134..ba8d9bf5 100644 --- a/packages/language/src/zmodel-linker.ts +++ b/packages/language/src/zmodel-linker.ts @@ -25,6 +25,7 @@ import { DataModel, Enum, EnumField, + isBinaryExpr, type ExpressionType, FunctionDecl, FunctionParam, @@ -121,7 +122,13 @@ export class ZModelLinker extends DefaultLinker { const target = provider(reference.$refText); if (target) { reference._ref = target; - reference._nodeDescription = this.descriptions.createDescription(target, target.name, document); + let targetName = reference.$refText; + if ('name' in target && typeof target.name === 'string') { + targetName = target.name; + } else if ('binding' in target && typeof (target as { binding?: unknown }).binding === 'string') { + targetName = (target as { binding: string }).binding; + } + reference._nodeDescription = this.descriptions.createDescription(target, targetName, document); // Add the reference to the document's array of references document.references.push(reference); @@ -249,13 +256,24 @@ export class ZModelLinker extends DefaultLinker { private resolveReference(node: ReferenceExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { this.resolveDefault(node, document, extraScopes); - - if (node.target.ref) { - // resolve type - if (node.target.ref.$type === EnumField) { - this.resolveToBuiltinTypeOrDecl(node, node.target.ref.$container); - } else { - this.resolveToDeclaredType(node, (node.target.ref as DataField | FunctionParam).type); + const target = node.target.ref; + + if (target) { + if (isBinaryExpr(target) && ['?', '!', '^'].includes(target.operator)) { + const collectionType = target.left.$resolvedType; + if (collectionType?.decl) { + node.$resolvedType = { + decl: collectionType.decl, + array: false, + nullable: collectionType.nullable, + }; + } + } else if (target.$type === EnumField) { + this.resolveToBuiltinTypeOrDecl(node, target.$container); + } else if (isDataField(target)) { + this.resolveToDeclaredType(node, target.type); + } else if (target.$type === FunctionParam && (target as FunctionParam).type) { + this.resolveToDeclaredType(node, (target as FunctionParam).type); } } } @@ -506,6 +524,9 @@ export class ZModelLinker extends DefaultLinker { //#region Utils private resolveToDeclaredType(node: AstNode, type: FunctionParamType | DataFieldType) { + if (!type) { + return; + } let nullable = false; if (isDataFieldType(type)) { nullable = type.optional; diff --git a/packages/language/src/zmodel-scope.ts b/packages/language/src/zmodel-scope.ts index 6fd866f0..4bd4c830 100644 --- a/packages/language/src/zmodel-scope.ts +++ b/packages/language/src/zmodel-scope.ts @@ -7,6 +7,7 @@ import { StreamScope, UriUtils, interruptAndCheck, + stream, type AstNode, type AstNodeDescription, type LangiumCoreServices, @@ -18,7 +19,9 @@ import { import { match } from 'ts-pattern'; import { BinaryExpr, + Expression, MemberAccessExpr, + isBinaryExpr, isDataField, isDataModel, isEnumField, @@ -145,6 +148,9 @@ export class ZModelScopeProvider extends DefaultScopeProvider { .when(isReferenceExpr, (operand) => { // operand is a reference, it can only be a model/type-def field const ref = operand.target.ref; + if (isBinaryExpr(ref) && isCollectionPredicate(ref)) { + return this.createScopeForCollectionElement(ref.left, globalScope, allowTypeDefScope); + } if (isDataField(ref)) { return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope); } @@ -188,6 +194,21 @@ export class ZModelScopeProvider extends DefaultScopeProvider { // // typedef's fields are only added to the scope if the access starts with `auth().` const allowTypeDefScope = isAuthOrAuthMemberAccess(collection); + const collectionScope = this.createScopeForCollectionElement(collection, globalScope, allowTypeDefScope); + + if (collectionPredicate.binding) { + const description = this.descriptions.createDescription( + collectionPredicate, + collectionPredicate.binding, + collectionPredicate.$document!, + ); + return new StreamScope(stream([description]), collectionScope); + } + + return collectionScope; + } + + private createScopeForCollectionElement(collection: Expression, globalScope: Scope, allowTypeDefScope: boolean) { return match(collection) .when(isReferenceExpr, (expr) => { // collection is a reference - model or typedef field diff --git a/packages/language/src/zmodel.langium b/packages/language/src/zmodel.langium index 8d279787..a80c0a1f 100644 --- a/packages/language/src/zmodel.langium +++ b/packages/language/src/zmodel.langium @@ -66,7 +66,7 @@ ConfigArrayExpr: ConfigExpr: LiteralExpr | InvocationExpr | ConfigArrayExpr; -type ReferenceTarget = FunctionParam | DataField | EnumField; +type ReferenceTarget = FunctionParam | DataField | EnumField | BinaryExpr; ThisExpr: value='this'; @@ -113,7 +113,7 @@ CollectionPredicateExpr infers Expression: MemberAccessExpr ( {infer BinaryExpr.left=current} operator=('?'|'!'|'^') - '[' right=Expression ']' + '[' (binding=RegularID ',')? right=Expression ']' )*; InExpr infers Expression: diff --git a/packages/language/test/expression-validation.test.ts b/packages/language/test/expression-validation.test.ts index 100f02b2..7976c9e9 100644 --- a/packages/language/test/expression-validation.test.ts +++ b/packages/language/test/expression-validation.test.ts @@ -98,4 +98,48 @@ describe('Expression Validation Tests', () => { 'incompatible operand types', ); }); + + it('should allow collection predicate with iterator binding', async () => { + await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + @@allow('read', memberships?[m, m.tenantId == id]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + } + `); + }); + + it('should keep supporting unbound collection predicate syntax', async () => { + await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + @@allow('read', memberships?[tenantId == id]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + } + `); + }); }); diff --git a/packages/plugins/policy/src/expression-evaluator.ts b/packages/plugins/policy/src/expression-evaluator.ts index 45c7b855..d1e1ebe0 100644 --- a/packages/plugins/policy/src/expression-evaluator.ts +++ b/packages/plugins/policy/src/expression-evaluator.ts @@ -15,6 +15,7 @@ import { type ExpressionEvaluatorContext = { auth?: any; thisValue?: any; + scope?: Record; }; /** @@ -64,6 +65,9 @@ export class ExpressionEvaluator { } private evaluateField(expr: FieldExpression, context: ExpressionEvaluatorContext): any { + if (context.scope && expr.field in context.scope) { + return context.scope[expr.field]; + } return context.thisValue?.[expr.field]; } @@ -113,8 +117,28 @@ export class ExpressionEvaluator { invariant(Array.isArray(left), 'expected array'); return match(op) - .with('?', () => left.some((item: any) => this.evaluate(expr.right, { ...context, thisValue: item }))) - .with('!', () => left.every((item: any) => this.evaluate(expr.right, { ...context, thisValue: item }))) + .with('?', () => + left.some((item: any) => + this.evaluate(expr.right, { + ...context, + thisValue: item, + scope: expr.binding + ? { ...(context.scope ?? {}), [expr.binding]: item } + : context.scope, + }), + ), + ) + .with('!', () => + left.every((item: any) => + this.evaluate(expr.right, { + ...context, + thisValue: item, + scope: expr.binding + ? { ...(context.scope ?? {}), [expr.binding]: item } + : context.scope, + }), + ), + ) .with( '^', () => @@ -122,6 +146,9 @@ export class ExpressionEvaluator { this.evaluate(expr.right, { ...context, thisValue: item, + scope: expr.binding + ? { ...(context.scope ?? {}), [expr.binding]: item } + : context.scope, }), ), ) diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index 0ea84a97..ce3f6d0e 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -58,6 +58,8 @@ import { trueNode, } from './utils'; +type BindingScope = Record; + /** * Context for transforming a policy expression */ @@ -92,6 +94,11 @@ export type ExpressionTransformerContext = { */ contextValue?: Record; + /** + * Additional named bindings available during transformation + */ + scope?: BindingScope; + /** * The model or type name that `this` keyword refers to */ @@ -310,7 +317,11 @@ export class ExpressionTransformer { // LHS of the expression is evaluated as a value const evaluator = new ExpressionEvaluator(); - const receiver = evaluator.evaluate(expr.left, { thisValue: context.contextValue, auth: this.auth }); + const receiver = evaluator.evaluate(expr.left, { + thisValue: context.contextValue, + auth: this.auth, + scope: this.getEvaluationScope(context.scope), + }); // get LHS's type const baseType = this.isAuthMember(expr.left) ? this.authType : context.modelOrType; @@ -345,10 +356,18 @@ export class ExpressionTransformer { } } + const bindingScope = expr.binding + ? { + ...(context.scope ?? {}), + [expr.binding]: { type: newContextModel, alias: context.alias ?? newContextModel }, + } + : context.scope; + let predicateFilter = this.transform(expr.right, { ...context, modelOrType: newContextModel, alias: undefined, + scope: bindingScope, }); if (expr.op === '!') { @@ -391,6 +410,7 @@ export class ExpressionTransformer { const value = new ExpressionEvaluator().evaluate(expr, { auth: this.auth, thisValue: context.contextValue, + scope: this.getEvaluationScope(context.scope), }); return this.transformValue(value, 'Boolean'); } else { @@ -402,15 +422,20 @@ export class ExpressionTransformer { // e.g.: `auth().profiles[age == this.age]`, each `auth().profiles` element (which is a value) // is used to build an expression for the RHS `age == this.age` // the transformation happens recursively for nested collection predicates - const components = receiver.map((item) => - this.transform(expr.right, { + const components = receiver.map((item) => { + const bindingScope = expr.binding + ? { ...(context.scope ?? {}), [expr.binding]: { type: context.modelOrType, value: item } } + : context.scope; + + return this.transform(expr.right, { operation: context.operation, thisType: context.thisType, thisAlias: context.thisAlias, modelOrType: context.modelOrType, contextValue: item, - }), - ); + scope: bindingScope, + }); + }); // compose the components based on the operator return ( @@ -600,6 +625,25 @@ export class ExpressionTransformer { @expr('member') // @ts-ignore private _member(expr: MemberExpression, context: ExpressionTransformerContext) { + const bindingReceiver = + ExpressionUtils.isField(expr.receiver) && context.scope ? context.scope[expr.receiver.field] : undefined; + + if (bindingReceiver) { + if (bindingReceiver.value !== undefined) { + return this.valueMemberAccess(bindingReceiver.value, expr, bindingReceiver.type); + } + + const rewritten = ExpressionUtils.member(ExpressionUtils._this(), expr.members); + return this._member(rewritten, { + ...context, + modelOrType: bindingReceiver.type, + alias: bindingReceiver.alias ?? bindingReceiver.type, + thisType: bindingReceiver.type, + thisAlias: bindingReceiver.alias ?? bindingReceiver.type, + contextValue: bindingReceiver.value, + }); + } + // `auth()` member access if (this.isAuthCall(expr.receiver)) { return this.valueMemberAccess(this.auth, expr, this.authType); @@ -833,6 +877,21 @@ export class ExpressionTransformer { return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel); } + private getEvaluationScope(scope?: BindingScope) { + if (!scope) { + return undefined; + } + + const result: Record = {}; + for (const [key, value] of Object.entries(scope)) { + if (value.value !== undefined) { + result[key] = value.value; + } + } + + return Object.keys(result).length > 0 ? result : undefined; + } + private buildDelegateBaseFieldSelect(model: string, modelAlias: string, field: string, baseModel: string) { const idFields = QueryUtils.requireIdFields(this.client.$schema, model); return { diff --git a/packages/schema/src/expression-utils.ts b/packages/schema/src/expression-utils.ts index ee48aecc..f7bd526d 100644 --- a/packages/schema/src/expression-utils.ts +++ b/packages/schema/src/expression-utils.ts @@ -39,12 +39,13 @@ export const ExpressionUtils = { }; }, - binary: (left: Expression, op: BinaryOperator, right: Expression): BinaryExpression => { + binary: (left: Expression, op: BinaryOperator, right: Expression, binding?: string): BinaryExpression => { return { kind: 'binary', op, left, right, + binding, }; }, diff --git a/packages/schema/src/expression.ts b/packages/schema/src/expression.ts index 3ce3c2d1..b3bb9c40 100644 --- a/packages/schema/src/expression.ts +++ b/packages/schema/src/expression.ts @@ -41,6 +41,7 @@ export type BinaryExpression = { op: BinaryOperator; left: Expression; right: Expression; + binding?: string; }; export type CallExpression = { diff --git a/packages/sdk/src/prisma/prisma-schema-generator.ts b/packages/sdk/src/prisma/prisma-schema-generator.ts index 553658ad..78a132c8 100644 --- a/packages/sdk/src/prisma/prisma-schema-generator.ts +++ b/packages/sdk/src/prisma/prisma-schema-generator.ts @@ -15,6 +15,7 @@ import { Enum, EnumField, Expression, + isBinaryExpr, GeneratorDecl, InvocationExpr, isArrayExpr, @@ -352,10 +353,20 @@ export class PrismaSchemaGenerator { new Array(...node.items.map((item) => this.makeAttributeArgValue(item))), ); } else if (isReferenceExpr(node)) { + const ref = node.target.ref!; + const refName = + ('name' in ref && typeof (ref as { name?: unknown }).name === 'string') + ? (ref as { name: string }).name + : isBinaryExpr(ref) && typeof ref.binding === 'string' + ? ref.binding + : undefined; + if (!refName) { + throw Error(`Unsupported reference expression target: ${ref.$type}`); + } return new PrismaAttributeArgValue( 'FieldReference', new PrismaFieldReference( - node.target.ref!.name, + refName, node.args.map((arg) => new PrismaFieldReferenceArg(arg.name, this.exprToText(arg.value))), ), ); diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index f68bb0bc..325926ac 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -1271,11 +1271,17 @@ export class TsSchemaGenerator { } private createBinaryExpression(expr: BinaryExpr) { - return this.createExpressionUtilsCall('binary', [ + const args = [ this.createExpression(expr.left), this.createLiteralNode(expr.operator), this.createExpression(expr.right), - ]); + ]; + + if (expr.binding) { + args.push(this.createLiteralNode(expr.binding)); + } + + return this.createExpressionUtilsCall('binary', args); } private createUnaryExpression(expr: UnaryExpr) { @@ -1292,13 +1298,28 @@ export class TsSchemaGenerator { } private createRefExpression(expr: ReferenceExpr): any { - if (isDataField(expr.target.ref)) { + const target = expr.target.ref; + if (isDataField(target)) { return this.createExpressionUtilsCall('field', [this.createLiteralNode(expr.target.$refText)]); - } else if (isEnumField(expr.target.ref)) { + } + + if (isEnumField(target)) { return this.createLiteralExpression('StringLiteral', expr.target.$refText); - } else { - throw new Error(`Unsupported reference type: ${expr.target.$refText}`); } + + const refName = + target && 'name' in target && typeof (target as { name?: unknown }).name === 'string' + ? (target as { name: string }).name + : isBinaryExpr(target) && typeof target.binding === 'string' + ? target.binding + : undefined; + + if (refName) { + return this.createExpressionUtilsCall('field', [this.createLiteralNode(refName)]); + } + + // Fallback: treat unknown reference targets (e.g. unresolved iterator bindings) as named fields + return this.createExpressionUtilsCall('field', [this.createLiteralNode(expr.target.$refText)]); } private createCallExpression(expr: InvocationExpr) { diff --git a/tests/e2e/orm/policy/auth-access.test.ts b/tests/e2e/orm/policy/auth-access.test.ts index b994324f..76e0c9f3 100644 --- a/tests/e2e/orm/policy/auth-access.test.ts +++ b/tests/e2e/orm/policy/auth-access.test.ts @@ -130,6 +130,50 @@ model Foo { await expect(db.$setAuth({ profiles: [{ age: 15 }, { age: 20 }] }).foo.findFirst()).toResolveTruthy(); }); + it('uses iterator binding inside collection predicate for auth model', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + tenantId Int + memberships Membership[] @relation("UserMemberships") +} + +model Membership { + id Int @id + tenantId Int + userId Int + user User @relation("UserMemberships", fields: [userId], references: [id]) +} + +model Foo { + id Int @id + tenantId Int + @@allow('read', auth().memberships?[m, m.tenantId == this.tenantId]) +} +`, + ); + + await db.$unuseAll().foo.createMany({ + data: [ + { id: 1, tenantId: 1 }, + { id: 2, tenantId: 2 }, + ], + }); + + // allowed because iterator binding matches tenantId = 1 + await expect( + db.$setAuth({ tenantId: 1, memberships: [{ id: 10, tenantId: 1 }] }).foo.findMany(), + ).resolves.toEqual([ + { id: 1, tenantId: 1 }, + ]); + + // denied because membership tenantId doesn't match + await expect( + db.$setAuth({ tenantId: 1, memberships: [{ id: 20, tenantId: 3 }] }).foo.findMany(), + ).resolves.toEqual([]); + }); + it('works with shallow auth model collection predicates involving fields - some', async () => { const db = await createPolicyTestClient( `