Skip to content

Commit aae9b60

Browse files
authored
fix(policy): allow auth(). calls in filter functions (#1771)
1 parent 374e962 commit aae9b60

File tree

6 files changed

+154
-14
lines changed

6 files changed

+154
-14
lines changed

packages/schema/src/language-server/validator/expression-validator.ts

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import {
2424
import { ValidationAcceptor, streamAst } from 'langium';
2525
import { findUpAst, getContainingDataModel } from '../../utils/ast-utils';
2626
import { AstValidator } from '../types';
27-
import { typeAssignable } from './utils';
27+
import { isAuthOrAuthMemberAccess, typeAssignable } from './utils';
2828

2929
/**
3030
* Validates expressions.
@@ -296,13 +296,9 @@ export default class ExpressionValidator implements AstValidator<Expression> {
296296
// null
297297
isNullExpr(expr) ||
298298
// `auth()` access
299-
this.isAuthOrAuthMemberAccess(expr) ||
299+
isAuthOrAuthMemberAccess(expr) ||
300300
// array
301301
(isArrayExpr(expr) && expr.items.every((item) => this.isNotModelFieldExpr(item)))
302302
);
303303
}
304-
305-
private isAuthOrAuthMemberAccess(expr: Expression) {
306-
return isAuthInvocation(expr) || (isMemberAccessExpr(expr) && isAuthInvocation(expr.operand));
307-
}
308304
}

packages/schema/src/language-server/validator/function-invocation-validator.ts

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import { AstNode, streamAst, ValidationAcceptor } from 'langium';
2626
import { match, P } from 'ts-pattern';
2727
import { isCheckInvocation } from '../../utils/ast-utils';
2828
import { AstValidator } from '../types';
29-
import { typeAssignable } from './utils';
29+
import { isAuthOrAuthMemberAccess, typeAssignable } from './utils';
3030

3131
// a registry of function handlers marked with @func
3232
const invocationCheckers = new Map<string, PropertyDescriptor>();
@@ -109,15 +109,24 @@ export default class FunctionInvocationValidator implements AstValidator<Express
109109
!isLiteralExpr(secondArg) &&
110110
// enum field
111111
!isEnumFieldReference(secondArg) &&
112+
// `auth()...` expression
113+
!isAuthOrAuthMemberAccess(secondArg) &&
112114
// array of literal/enum
113115
!(
114116
isArrayExpr(secondArg) &&
115-
secondArg.items.every((item) => isLiteralExpr(item) || isEnumFieldReference(item))
117+
secondArg.items.every(
118+
(item) =>
119+
isLiteralExpr(item) || isEnumFieldReference(item) || isAuthOrAuthMemberAccess(item)
120+
)
116121
)
117122
) {
118-
accept('error', 'second argument must be a literal, an enum, or an array of them', {
119-
node: secondArg,
120-
});
123+
accept(
124+
'error',
125+
'second argument must be a literal, an enum, an expression starting with `auth().`, or an array of them',
126+
{
127+
node: secondArg,
128+
}
129+
);
121130
}
122131
}
123132
}

packages/schema/src/language-server/validator/utils.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@ import {
1010
isArrayExpr,
1111
isDataModelField,
1212
isEnum,
13+
isMemberAccessExpr,
1314
isReferenceExpr,
1415
isStringLiteral,
1516
} from '@zenstackhq/language/ast';
16-
import { resolved } from '@zenstackhq/sdk';
17+
import { isAuthInvocation, resolved } from '@zenstackhq/sdk';
1718
import { AstNode, ValidationAcceptor } from 'langium';
1819

1920
/**
@@ -181,3 +182,7 @@ export function assignableToAttributeParam(
181182
return (dstRef?.ref === argResolvedType.decl || dstType === 'Any') && dstIsArray === argResolvedType.array;
182183
}
183184
}
185+
186+
export function isAuthOrAuthMemberAccess(expr: Expression): boolean {
187+
return isAuthInvocation(expr) || (isMemberAccessExpr(expr) && isAuthOrAuthMemberAccess(expr.operand));
188+
}

packages/schema/tests/schema/validation/attribute-validation.test.ts

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,11 @@ describe('Attribute tests', () => {
816816
E2
817817
}
818818
819+
model User {
820+
id String @id
821+
e E
822+
}
823+
819824
model N {
820825
id String @id
821826
e E
@@ -840,6 +845,7 @@ describe('Attribute tests', () => {
840845
@@allow('all', startsWith(s, 'a'))
841846
@@allow('all', endsWith(s, 'a'))
842847
@@allow('all', has(es, E1))
848+
@@allow('all', has(es, auth().e))
843849
@@allow('all', hasSome(es, [E1]))
844850
@@allow('all', hasEvery(es, [E1]))
845851
@@allow('all', isEmpty(es))
@@ -890,7 +896,9 @@ describe('Attribute tests', () => {
890896
@@allow('all', contains(s, s1))
891897
}
892898
`)
893-
).toContain('second argument must be a literal, an enum, or an array of them');
899+
).toContain(
900+
'second argument must be a literal, an enum, an expression starting with `auth().`, or an array of them'
901+
);
894902

895903
expect(
896904
await loadModelWithError(`
@@ -1022,7 +1030,9 @@ describe('Attribute tests', () => {
10221030
@@validate(contains(s, s1))
10231031
}
10241032
`)
1025-
).toContain('second argument must be a literal, an enum, or an array of them');
1033+
).toContain(
1034+
'second argument must be a literal, an enum, an expression starting with `auth().`, or an array of them'
1035+
);
10261036

10271037
expect(
10281038
await loadModelWithError(`

tests/integration/tests/e2e/filter-function-coverage.test.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,28 @@ describe('Filter Function Coverage Tests', () => {
3636
await expect(enhance({ id: 'user1', name: 'bac' }).foo.create({ data: {} })).toResolveTruthy();
3737
});
3838

39+
it('contains with auth()', async () => {
40+
const { enhance } = await loadSchema(
41+
`
42+
model User {
43+
id String @id
44+
name String
45+
}
46+
47+
model Foo {
48+
id String @id @default(cuid())
49+
string String
50+
@@allow('all', contains(string, auth().name))
51+
}
52+
`
53+
);
54+
55+
await expect(enhance().foo.create({ data: { string: 'abc' } })).toBeRejectedByPolicy();
56+
const db = enhance({ id: '1', name: 'a' });
57+
await expect(db.foo.create({ data: { string: 'bcd' } })).toBeRejectedByPolicy();
58+
await expect(db.foo.create({ data: { string: 'bac' } })).toResolveTruthy();
59+
});
60+
3961
it('startsWith field', async () => {
4062
const { enhance } = await loadSchema(
4163
`
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import { createPostgresDb, dropPostgresDb, loadSchema } from '@zenstackhq/testtools';
2+
3+
describe('issue 1745', () => {
4+
it('regression', async () => {
5+
const dbUrl = await createPostgresDb('issue-1745');
6+
7+
try {
8+
await loadSchema(
9+
`
10+
enum BuyerType {
11+
STORE
12+
RESTAURANT
13+
WHOLESALER
14+
}
15+
16+
enum ChainStore {
17+
ALL
18+
CHAINSTORE_1
19+
CHAINSTORE_2
20+
CHAINSTORE_3
21+
}
22+
23+
abstract model Id {
24+
id String @id @default(cuid())
25+
}
26+
27+
abstract model Base extends Id {
28+
createdAt DateTime @default(now())
29+
updatedAt DateTime @updatedAt
30+
}
31+
32+
model Ad extends Base {
33+
serial Int @unique @default(autoincrement())
34+
buyerTypes BuyerType[]
35+
chainStores ChainStore[]
36+
listPrice Float
37+
isSold Boolean @default(false)
38+
39+
supplier Supplier @relation(fields: [supplierId], references: [id])
40+
supplierId String @default(auth().companyId)
41+
42+
@@allow('all', auth().company.companyType == 'Buyer' && has(buyerTypes, auth().company.buyerType))
43+
@@allow('all', auth().company.companyType == 'Supplier' && auth().companyId == supplierId)
44+
@@allow('all', auth().isAdmin)
45+
}
46+
47+
model Company extends Base {
48+
name String @unique
49+
organizationNumber String @unique
50+
users User[]
51+
buyerType BuyerType
52+
53+
companyType String
54+
@@delegate(companyType)
55+
56+
@@allow('read, update', auth().companyId == id)
57+
@@allow('all', auth().isAdmin)
58+
}
59+
60+
model Buyer extends Company {
61+
storeName String
62+
type String
63+
chainStore ChainStore @default(ALL)
64+
65+
@@allow('read, update', auth().company.companyType == 'Buyer' && auth().companyId == id)
66+
@@allow('all', auth().isAdmin)
67+
}
68+
69+
model Supplier extends Company {
70+
ads Ad[]
71+
72+
@@allow('all', auth().company.companyType == 'Supplier' && auth().companyId == id)
73+
@@allow('all', auth().isAdmin)
74+
}
75+
76+
model User extends Base {
77+
firstName String
78+
lastName String
79+
email String @unique
80+
username String @unique
81+
password String @password @omit
82+
isAdmin Boolean @default(false)
83+
84+
company Company? @relation(fields: [companyId], references: [id])
85+
companyId String?
86+
87+
@@allow('read', auth().id == id)
88+
@@allow('read', auth().companyId == companyId)
89+
@@allow('all', auth().isAdmin)
90+
}
91+
`,
92+
{ provider: 'postgresql', dbUrl, pushDb: false }
93+
);
94+
} finally {
95+
dropPostgresDb('issue-1745');
96+
}
97+
});
98+
});

0 commit comments

Comments
 (0)