Skip to content

Commit f01bcde

Browse files
authored
fix(policy): logical combination issue and more tests for update (#245)
1 parent 4ef27c7 commit f01bcde

File tree

4 files changed

+162
-17
lines changed

4 files changed

+162
-17
lines changed

packages/runtime/src/plugins/policy/expression-transformer.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
231231
});
232232

233233
if (expr.op === '!') {
234-
predicateFilter = logicalNot(predicateFilter);
234+
predicateFilter = logicalNot(this.dialect, predicateFilter);
235235
}
236236

237237
const count = FunctionNode.create('count', [ValueNode.createImmediate(1)]);
@@ -305,7 +305,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
305305
private _unary(expr: UnaryExpression, context: ExpressionTransformerContext<Schema>) {
306306
// only '!' operator for now
307307
invariant(expr.op === '!', 'only "!" operator is supported');
308-
return logicalNot(this.transform(expr.operand, context));
308+
return logicalNot(this.dialect, this.transform(expr.operand, context));
309309
}
310310

311311
private transformOperator(op: Exclude<BinaryOperator, '?' | '!' | '^'>) {

packages/runtime/src/plugins/policy/utils.ts

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,50 +50,65 @@ export function conjunction<Schema extends SchemaDef>(
5050
dialect: BaseCrudDialect<Schema>,
5151
nodes: OperationNode[],
5252
): OperationNode {
53+
if (nodes.length === 0) {
54+
return trueNode(dialect);
55+
}
56+
if (nodes.length === 1) {
57+
return nodes[0]!;
58+
}
5359
if (nodes.some(isFalseNode)) {
5460
return falseNode(dialect);
5561
}
5662
const items = nodes.filter((n) => !isTrueNode(n));
5763
if (items.length === 0) {
5864
return trueNode(dialect);
5965
}
60-
return items.reduce((acc, node) =>
61-
OrNode.is(node)
62-
? AndNode.create(acc, ParensNode.create(node)) // wraps parentheses
63-
: AndNode.create(acc, node),
64-
);
66+
return items.reduce((acc, node) => AndNode.create(wrapParensIf(acc, OrNode.is), wrapParensIf(node, OrNode.is)));
6567
}
6668

6769
export function disjunction<Schema extends SchemaDef>(
6870
dialect: BaseCrudDialect<Schema>,
6971
nodes: OperationNode[],
7072
): OperationNode {
73+
if (nodes.length === 0) {
74+
return falseNode(dialect);
75+
}
76+
if (nodes.length === 1) {
77+
return nodes[0]!;
78+
}
7179
if (nodes.some(isTrueNode)) {
7280
return trueNode(dialect);
7381
}
7482
const items = nodes.filter((n) => !isFalseNode(n));
7583
if (items.length === 0) {
7684
return falseNode(dialect);
7785
}
78-
return items.reduce((acc, node) =>
79-
AndNode.is(node)
80-
? OrNode.create(acc, ParensNode.create(node)) // wraps parentheses
81-
: OrNode.create(acc, node),
82-
);
86+
return items.reduce((acc, node) => OrNode.create(wrapParensIf(acc, AndNode.is), wrapParensIf(node, AndNode.is)));
8387
}
8488

8589
/**
8690
* Negates a logical expression.
8791
*/
88-
export function logicalNot(node: OperationNode): OperationNode {
92+
export function logicalNot<Schema extends SchemaDef>(
93+
dialect: BaseCrudDialect<Schema>,
94+
node: OperationNode,
95+
): OperationNode {
96+
if (isTrueNode(node)) {
97+
return falseNode(dialect);
98+
}
99+
if (isFalseNode(node)) {
100+
return trueNode(dialect);
101+
}
89102
return UnaryOperationNode.create(
90103
OperatorNode.create('not'),
91-
AndNode.is(node) || OrNode.is(node)
92-
? ParensNode.create(node) // wraps parentheses
93-
: node,
104+
wrapParensIf(node, (n) => AndNode.is(n) || OrNode.is(n)),
94105
);
95106
}
96107

108+
function wrapParensIf(node: OperationNode, predicate: (node: OperationNode) => boolean): OperationNode {
109+
return predicate(node) ? ParensNode.create(node) : node;
110+
}
111+
97112
/**
98113
* Builds an expression node that checks if a node is true.
99114
*/

packages/runtime/test/policy/crud/create.test.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ model Foo {
1515
);
1616
await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy();
1717
await expect(db.foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 });
18+
19+
await expect(
20+
db.$qb.insertInto('Foo').values({ x: 0 }).returningAll().executeTakeFirst(),
21+
).toBeRejectedByPolicy();
22+
await expect(
23+
db.$qb.insertInto('Foo').values({ x: 1 }).returningAll().executeTakeFirst(),
24+
).resolves.toMatchObject({ x: 1 });
1825
});
1926

2027
it('works with this scalar member check', async () => {
@@ -66,7 +73,7 @@ model Foo {
6673
id Int @id @default(autoincrement())
6774
x Int
6875
@@deny('create', x <= 0)
69-
@@allow('create', x > 1)
76+
@@allow('create', x <= 0 || x > 1)
7077
@@allow('read', true)
7178
}
7279
`,
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import { describe, expect, it } from 'vitest';
2+
import { createPolicyTestClient } from '../utils';
3+
4+
describe('Update policy tests', () => {
5+
it('works with scalar field check', async () => {
6+
const db = await createPolicyTestClient(
7+
`
8+
model Foo {
9+
id Int @id
10+
x Int
11+
@@allow('update', x > 0)
12+
@@allow('create,read', true)
13+
}
14+
`,
15+
);
16+
17+
await db.foo.create({ data: { id: 1, x: 0 } });
18+
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound();
19+
await db.foo.create({ data: { id: 2, x: 1 } });
20+
await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 });
21+
22+
await expect(
23+
db.$qb.updateTable('Foo').set({ x: 1 }).where('id', '=', 1).executeTakeFirst(),
24+
).resolves.toMatchObject({ numUpdatedRows: 0n });
25+
await expect(
26+
db.$qb.updateTable('Foo').set({ x: 3 }).where('id', '=', 2).returningAll().execute(),
27+
).resolves.toMatchObject([{ id: 2, x: 3 }]);
28+
});
29+
30+
it('works with this scalar member check', async () => {
31+
const db = await createPolicyTestClient(
32+
`
33+
model Foo {
34+
id Int @id
35+
x Int
36+
@@allow('update', this.x > 0)
37+
@@allow('create,read', true)
38+
}
39+
`,
40+
);
41+
42+
await db.foo.create({ data: { id: 1, x: 0 } });
43+
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound();
44+
await db.foo.create({ data: { id: 2, x: 1 } });
45+
await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 });
46+
});
47+
48+
it('denies by default', async () => {
49+
const db = await createPolicyTestClient(
50+
`
51+
model Foo {
52+
id Int @id
53+
x Int
54+
@@allow('create,read', true)
55+
}
56+
`,
57+
);
58+
59+
await db.foo.create({ data: { id: 1, x: 0 } });
60+
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound();
61+
});
62+
63+
it('works with deny rule', async () => {
64+
const db = await createPolicyTestClient(
65+
`
66+
model Foo {
67+
id Int @id
68+
x Int
69+
@@deny('update', x <= 0)
70+
@@allow('create,read,update', true)
71+
}
72+
`,
73+
);
74+
await db.foo.create({ data: { id: 1, x: 0 } });
75+
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound();
76+
await db.foo.create({ data: { id: 2, x: 1 } });
77+
await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 });
78+
});
79+
80+
it('works with mixed allow and deny rules', async () => {
81+
const db = await createPolicyTestClient(
82+
`
83+
model Foo {
84+
id Int @id
85+
x Int
86+
@@deny('update', x <= 0)
87+
@@allow('update', x <= 0 || x > 1)
88+
@@allow('create,read', true)
89+
}
90+
`,
91+
);
92+
93+
await db.foo.create({ data: { id: 1, x: 0 } });
94+
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound();
95+
await db.foo.create({ data: { id: 2, x: 1 } });
96+
await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).toBeRejectedNotFound();
97+
await db.foo.create({ data: { id: 3, x: 2 } });
98+
await expect(db.foo.update({ where: { id: 3 }, data: { x: 3 } })).resolves.toMatchObject({ x: 3 });
99+
});
100+
101+
it('works with auth check', async () => {
102+
const db = await createPolicyTestClient(
103+
`
104+
type Auth {
105+
x Int
106+
@@auth
107+
}
108+
109+
model Foo {
110+
id Int @id
111+
x Int
112+
@@allow('update', x == auth().x)
113+
@@allow('create,read', true)
114+
}
115+
`,
116+
);
117+
await db.foo.create({ data: { id: 1, x: 1 } });
118+
await expect(db.$setAuth({ x: 0 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).toBeRejectedNotFound();
119+
await expect(db.$setAuth({ x: 1 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).resolves.toMatchObject({
120+
x: 2,
121+
});
122+
});
123+
});

0 commit comments

Comments
 (0)