Skip to content

Commit 023f192

Browse files
authored
feat(policy): support arbitrary collection traversal from auth() (#439)
* feat(policy): support arbitrary collection traversal from `auth()` * address PR comments
1 parent 7e57864 commit 023f192

File tree

11 files changed

+746
-80
lines changed

11 files changed

+746
-80
lines changed

TODO.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
- [ ] CLI
44
- [x] generate
55
- [x] migrate
6-
- [ ] db
6+
- [x] db
77
- [x] push
8-
- [ ] seed
8+
- [x] seed
99
- [x] info
1010
- [x] init
1111
- [x] validate
12-
- [ ] format
12+
- [x] format
1313
- [ ] repl
1414
- [x] plugin mechanism
1515
- [x] built-in plugins

packages/language/src/validators/function-invocation-validator.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,17 @@ export default class FunctionInvocationValidator implements AstValidator<Express
179179
return true;
180180
}
181181

182+
@func('auth')
183+
private _checkAuth(expr: InvocationExpr, accept: ValidationAcceptor) {
184+
if (!expr.$resolvedType) {
185+
accept(
186+
'error',
187+
'cannot resolve `auth()` - make sure you have a model or type with `@auth` attribute or named "User"',
188+
{ node: expr },
189+
);
190+
}
191+
}
192+
182193
@func('length')
183194
private _checkLength(expr: InvocationExpr, accept: ValidationAcceptor) {
184195
const msg = 'argument must be a string or list field';

packages/language/src/zmodel-scope.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import {
3636
getAuthDecl,
3737
getRecursiveBases,
3838
isAuthInvocation,
39+
isAuthOrAuthMemberAccess,
3940
isBeforeInvocation,
4041
isCollectionPredicate,
4142
resolveImportUri,
@@ -138,8 +139,7 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
138139
// typedef's fields are only added to the scope if the access starts with `auth().`
139140
// or the member access resides inside a typedef
140141
const allowTypeDefScope =
141-
// isAuthOrAuthMemberAccess(node.operand) ||
142-
!!AstUtils.getContainerOfType(node, isTypeDef);
142+
isAuthOrAuthMemberAccess(node.operand) || !!AstUtils.getContainerOfType(node, isTypeDef);
143143

144144
return match(node.operand)
145145
.when(isReferenceExpr, (operand) => {
@@ -184,10 +184,9 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
184184
const globalScope = this.getGlobalScope(referenceType, context);
185185
const collection = collectionPredicate.left;
186186

187-
// TODO: generalize it
187+
// TODO: full support of typedef member access
188188
// // typedef's fields are only added to the scope if the access starts with `auth().`
189-
// const allowTypeDefScope = isAuthOrAuthMemberAccess(collection);
190-
const allowTypeDefScope = false;
189+
const allowTypeDefScope = isAuthOrAuthMemberAccess(collection);
191190

192191
return match(collection)
193192
.when(isReferenceExpr, (expr) => {

packages/orm/src/utils/schema-utils.ts

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ import type {
1212
UnaryExpression,
1313
} from '../schema';
1414

15+
export type VisitResult = void | { abort: true };
16+
1517
export class ExpressionVisitor {
16-
visit(expr: Expression): void {
17-
match(expr)
18+
visit(expr: Expression): VisitResult {
19+
return match(expr)
1820
.with({ kind: 'literal' }, (e) => this.visitLiteral(e))
1921
.with({ kind: 'array' }, (e) => this.visitArray(e))
2022
.with({ kind: 'field' }, (e) => this.visitField(e))
@@ -27,32 +29,68 @@ export class ExpressionVisitor {
2729
.exhaustive();
2830
}
2931

30-
protected visitLiteral(_e: LiteralExpression) {}
32+
protected visitLiteral(_e: LiteralExpression): VisitResult {}
3133

32-
protected visitArray(e: ArrayExpression) {
33-
e.items.forEach((item) => this.visit(item));
34+
protected visitArray(e: ArrayExpression): VisitResult {
35+
for (const item of e.items) {
36+
const result = this.visit(item);
37+
if (result?.abort) {
38+
return result;
39+
}
40+
}
3441
}
3542

36-
protected visitField(_e: FieldExpression) {}
43+
protected visitField(_e: FieldExpression): VisitResult {}
3744

38-
protected visitMember(e: MemberExpression) {
39-
this.visit(e.receiver);
45+
protected visitMember(e: MemberExpression): VisitResult {
46+
return this.visit(e.receiver);
4047
}
4148

42-
protected visitBinary(e: BinaryExpression) {
43-
this.visit(e.left);
44-
this.visit(e.right);
49+
protected visitBinary(e: BinaryExpression): VisitResult {
50+
const l = this.visit(e.left);
51+
if (l?.abort) {
52+
return l;
53+
} else {
54+
return this.visit(e.right);
55+
}
4556
}
4657

47-
protected visitUnary(e: UnaryExpression) {
48-
this.visit(e.operand);
58+
protected visitUnary(e: UnaryExpression): VisitResult {
59+
return this.visit(e.operand);
4960
}
5061

51-
protected visitCall(e: CallExpression) {
52-
e.args?.forEach((arg) => this.visit(arg));
62+
protected visitCall(e: CallExpression): VisitResult {
63+
for (const arg of e.args ?? []) {
64+
const r = this.visit(arg);
65+
if (r?.abort) {
66+
return r;
67+
}
68+
}
5369
}
5470

55-
protected visitThis(_e: ThisExpression) {}
71+
protected visitThis(_e: ThisExpression): VisitResult {}
72+
73+
protected visitNull(_e: NullExpression): VisitResult {}
74+
}
5675

57-
protected visitNull(_e: NullExpression) {}
76+
export class MatchingExpressionVisitor extends ExpressionVisitor {
77+
private found = false;
78+
79+
constructor(private predicate: (expr: Expression) => boolean) {
80+
super();
81+
}
82+
83+
find(expr: Expression) {
84+
this.visit(expr);
85+
return this.found;
86+
}
87+
88+
override visit(expr: Expression) {
89+
if (this.predicate(expr)) {
90+
this.found = true;
91+
return { abort: true } as const;
92+
} else {
93+
return super.visit(expr);
94+
}
95+
}
5896
}

packages/plugins/policy/src/expression-evaluator.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ export class ExpressionEvaluator {
7979
const left = this.evaluate(expr.left, context);
8080
const right = this.evaluate(expr.right, context);
8181

82+
if (!['==', '!='].includes(expr.op) && (left === null || right === null)) {
83+
// non-equality comparison with null always yields null (follow SQL logic)
84+
return null;
85+
}
86+
8287
return match(expr.op)
8388
.with('==', () => left === right)
8489
.with('!=', () => left !== right)
@@ -102,7 +107,7 @@ export class ExpressionEvaluator {
102107

103108
const left = this.evaluate(expr.left, context);
104109
if (!left) {
105-
return false;
110+
return null;
106111
}
107112

108113
invariant(Array.isArray(left), 'expected array');

0 commit comments

Comments
 (0)