Skip to content

Commit 50e92e0

Browse files
authored
fix(policy): relation access via this, more test cases (#252)
* fix(policy): relation access via `this`, more test cases * minor fixes
1 parent dc18713 commit 50e92e0

File tree

4 files changed

+303
-20
lines changed

4 files changed

+303
-20
lines changed

packages/runtime/src/client/crud/operations/base.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,10 +1009,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
10091009
throw new QueryError(`Relation update not allowed for field "${field}"`);
10101010
}
10111011
if (!thisEntity) {
1012-
thisEntity = await this.readUnique(kysely, model, {
1013-
where: combinedWhere,
1014-
select: this.makeIdSelect(model),
1015-
});
1012+
thisEntity = await this.getEntityIds(kysely, model, combinedWhere);
10161013
if (!thisEntity) {
10171014
if (throwIfNotFound) {
10181015
throw new NotFoundError(model);

packages/runtime/src/plugins/policy/expression-transformer.ts

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import { getCrudDialect } from '../../client/crud/dialects';
2525
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
2626
import { InternalError, QueryError } from '../../client/errors';
2727
import type { ClientOptions } from '../../client/options';
28-
import { getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils';
28+
import { getIdFields, getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils';
2929
import type {
3030
BinaryExpression,
3131
BinaryOperator,
@@ -111,7 +111,6 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
111111
}
112112

113113
@expr('field')
114-
// @ts-expect-error
115114
private _field(expr: FieldExpression, context: ExpressionTransformerContext<Schema>) {
116115
const fieldDef = requireField(this.schema, context.model, expr.field);
117116
if (!fieldDef.relation) {
@@ -162,8 +161,9 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
162161
return this.transformCollectionPredicate(expr, context);
163162
}
164163

165-
const left = this.transform(expr.left, context);
166-
const right = this.transform(expr.right, context);
164+
const { normalizedLeft, normalizedRight } = this.normalizeBinaryOperationOperands(expr, context);
165+
const left = this.transform(normalizedLeft, context);
166+
const right = this.transform(normalizedRight, context);
167167

168168
if (op === 'in') {
169169
if (this.isNullNode(left)) {
@@ -195,6 +195,22 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
195195
return BinaryOperationNode.create(left, this.transformOperator(op), right);
196196
}
197197

198+
private normalizeBinaryOperationOperands(expr: BinaryExpression, context: ExpressionTransformerContext<Schema>) {
199+
let normalizedLeft: Expression = expr.left;
200+
if (this.isRelationField(expr.left, context.model)) {
201+
invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field');
202+
const idFields = getIdFields(this.schema, context.model);
203+
normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!);
204+
}
205+
let normalizedRight: Expression = expr.right;
206+
if (this.isRelationField(expr.right, context.model)) {
207+
invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field');
208+
const idFields = getIdFields(this.schema, context.model);
209+
normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!);
210+
}
211+
return { normalizedLeft, normalizedRight };
212+
}
213+
198214
private transformCollectionPredicate(expr: BinaryExpression, context: ExpressionTransformerContext<Schema>) {
199215
invariant(expr.op === '?' || expr.op === '!' || expr.op === '^', 'expected "?" or "!" or "^" operator');
200216

@@ -211,11 +227,15 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
211227
);
212228

213229
let newContextModel: string;
214-
if (ExpressionUtils.isField(expr.left)) {
215-
const fieldDef = requireField(this.schema, context.model, expr.left.field);
230+
const fieldDef = this.getFieldDefFromFieldRef(expr.left, context.model);
231+
if (fieldDef) {
232+
invariant(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr.left)}`);
216233
newContextModel = fieldDef.type;
217234
} else {
218-
invariant(ExpressionUtils.isField(expr.left.receiver));
235+
invariant(
236+
ExpressionUtils.isMember(expr.left) && ExpressionUtils.isField(expr.left.receiver),
237+
'left operand must be member access with field receiver',
238+
);
219239
const fieldDef = requireField(this.schema, context.model, expr.left.receiver.field);
220240
newContextModel = fieldDef.type;
221241
for (const member of expr.left.members) {
@@ -396,16 +416,14 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
396416

397417
if (ExpressionUtils.isThis(expr.receiver)) {
398418
if (expr.members.length === 1) {
399-
// optimize for the simple this.scalar case
400-
const fieldDef = requireField(this.schema, context.model, expr.members[0]!);
401-
invariant(!fieldDef.relation, 'this.relation access should have been transformed into relation access');
402-
return this.createColumnRef(expr.members[0]!, restContext);
419+
// `this.relation` case, equivalent to field access
420+
return this._field(ExpressionUtils.field(expr.members[0]!), context);
421+
} else {
422+
// transform the first segment into a relation access, then continue with the rest of the members
423+
const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!);
424+
receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext);
425+
members = expr.members.slice(1);
403426
}
404-
405-
// transform the first segment into a relation access, then continue with the rest of the members
406-
const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!);
407-
receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext);
408-
members = expr.members.slice(1);
409427
} else {
410428
receiver = this.transform(expr.receiver, restContext);
411429
}
@@ -559,4 +577,23 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
559577
return conditions.reduce((acc, condition) => ExpressionUtils.binary(acc, '&&', condition));
560578
}
561579
}
580+
581+
private isRelationField(expr: Expression, model: GetModels<Schema>) {
582+
const fieldDef = this.getFieldDefFromFieldRef(expr, model);
583+
return !!fieldDef?.relation;
584+
}
585+
586+
private getFieldDefFromFieldRef(expr: Expression, model: GetModels<Schema>): FieldDef | undefined {
587+
if (ExpressionUtils.isField(expr)) {
588+
return requireField(this.schema, model, expr.field);
589+
} else if (
590+
ExpressionUtils.isMember(expr) &&
591+
expr.members.length === 1 &&
592+
ExpressionUtils.isThis(expr.receiver)
593+
) {
594+
return requireField(this.schema, model, expr.members[0]!);
595+
} else {
596+
return undefined;
597+
}
598+
}
562599
}

packages/runtime/src/plugins/policy/policy-handler.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
102102
}
103103
return readBackResult;
104104
} else {
105+
// reading id fields bypasses policy
105106
return result;
106107
}
107108

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
import { describe, expect, it } from 'vitest';
2+
import { createPolicyTestClient } from '../utils';
3+
4+
describe('Read policy tests', () => {
5+
describe('Find tests', () => {
6+
it('works with top-level find', async () => {
7+
const db = await createPolicyTestClient(
8+
`
9+
model Foo {
10+
id Int @id
11+
x Int
12+
@@allow('create', true)
13+
@@allow('read', x > 0)
14+
}
15+
`,
16+
);
17+
18+
await db.$unuseAll().foo.create({ data: { id: 1, x: 0 } });
19+
await expect(db.foo.findUnique({ where: { id: 1 } })).toResolveNull();
20+
21+
await db.$unuseAll().foo.update({ where: { id: 1 }, data: { x: 1 } });
22+
await expect(db.foo.findUnique({ where: { id: 1 } })).toResolveTruthy();
23+
});
24+
25+
it('works with mutation read-back', async () => {
26+
const db = await createPolicyTestClient(
27+
`
28+
model Foo {
29+
id Int @id
30+
x Int
31+
@@allow('create,update', true)
32+
@@allow('read', x > 0)
33+
}
34+
`,
35+
);
36+
37+
await expect(db.foo.create({ data: { id: 1, x: 0 } })).toBeRejectedByPolicy();
38+
await expect(db.$unuseAll().foo.count()).resolves.toBe(1);
39+
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).resolves.toMatchObject({ x: 1 });
40+
});
41+
42+
it('works with to-one relation optional owner-side read', async () => {
43+
const db = await createPolicyTestClient(
44+
`
45+
model Foo {
46+
id Int @id
47+
bar Bar? @relation(fields: [barId], references: [id])
48+
barId Int? @unique
49+
@@allow('all', true)
50+
}
51+
52+
model Bar {
53+
id Int @id
54+
y Int
55+
foo Foo?
56+
@@allow('create,update', true)
57+
@@allow('read', y > 0)
58+
}
59+
`,
60+
);
61+
62+
await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } });
63+
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null });
64+
await db.bar.update({ where: { id: 1 }, data: { y: 1 } });
65+
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({
66+
id: 1,
67+
bar: { id: 1 },
68+
});
69+
});
70+
71+
// TODO: check if we should be consistent with v2 and filter out the parent entity
72+
// if a non-optional child relation is included but not readable
73+
it('works with to-one relation non-optional owner-side read', async () => {
74+
const db = await createPolicyTestClient(
75+
`
76+
model Foo {
77+
id Int @id
78+
bar Bar @relation(fields: [barId], references: [id])
79+
barId Int @unique
80+
@@allow('all', true)
81+
}
82+
83+
model Bar {
84+
id Int @id
85+
y Int
86+
foo Foo?
87+
@@allow('create,update', true)
88+
@@allow('read', y > 0)
89+
}
90+
`,
91+
);
92+
93+
await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } });
94+
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null });
95+
await db.bar.update({ where: { id: 1 }, data: { y: 1 } });
96+
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({
97+
id: 1,
98+
bar: { id: 1 },
99+
});
100+
});
101+
102+
it('works with to-one relation non-owner-side read', async () => {
103+
const db = await createPolicyTestClient(
104+
`
105+
model Foo {
106+
id Int @id
107+
bar Bar?
108+
@@allow('all', true)
109+
}
110+
111+
model Bar {
112+
id Int @id
113+
y Int
114+
foo Foo @relation(fields: [fooId], references: [id])
115+
fooId Int @unique
116+
@@allow('create,update', true)
117+
@@allow('read', y > 0)
118+
}
119+
`,
120+
);
121+
122+
await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } });
123+
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null });
124+
await db.bar.update({ where: { id: 1 }, data: { y: 1 } });
125+
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({
126+
id: 1,
127+
bar: { id: 1 },
128+
});
129+
});
130+
131+
it('works with to-many relation read', async () => {
132+
const db = await createPolicyTestClient(
133+
`
134+
model Foo {
135+
id Int @id
136+
bars Bar[]
137+
@@allow('all', true)
138+
}
139+
140+
model Bar {
141+
id Int @id
142+
y Int
143+
foo Foo? @relation(fields: [fooId], references: [id])
144+
fooId Int?
145+
@@allow('create,update', true)
146+
@@allow('read', y > 0)
147+
}
148+
`,
149+
);
150+
151+
await db.foo.create({
152+
data: {
153+
id: 1,
154+
bars: {
155+
create: [
156+
{ id: 1, y: 0 },
157+
{ id: 2, y: 1 },
158+
],
159+
},
160+
},
161+
});
162+
await expect(db.foo.findFirst({ include: { bars: true } })).resolves.toMatchObject({
163+
id: 1,
164+
bars: [{ id: 2 }],
165+
});
166+
});
167+
168+
it('works with filtered by to-one relation field', async () => {
169+
const db = await createPolicyTestClient(
170+
`
171+
model Foo {
172+
id Int @id
173+
bar Bar? @relation(fields: [barId], references: [id])
174+
barId Int? @unique
175+
@@allow('create', true)
176+
@@allow('read', bar.y > 0)
177+
}
178+
179+
model Bar {
180+
id Int @id
181+
y Int
182+
foo Foo?
183+
@@allow('all', true)
184+
}
185+
`,
186+
);
187+
188+
await db.$unuseAll().foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } });
189+
await expect(db.foo.findMany()).resolves.toHaveLength(0);
190+
await db.bar.update({ where: { id: 1 }, data: { y: 1 } });
191+
await expect(db.foo.findMany()).resolves.toHaveLength(1);
192+
});
193+
194+
it('works with filtered by to-one relation non-null', async () => {
195+
const db = await createPolicyTestClient(
196+
`
197+
model Foo {
198+
id Int @id
199+
bar Bar? @relation(fields: [barId], references: [id])
200+
barId Int? @unique
201+
@@allow('create,update', true)
202+
@@allow('read', bar != null)
203+
@@allow('read', this.bar != null)
204+
}
205+
206+
model Bar {
207+
id Int @id
208+
y Int
209+
foo Foo?
210+
@@allow('all', true)
211+
}
212+
`,
213+
);
214+
215+
await db.$unuseAll().foo.create({ data: { id: 1 } });
216+
await expect(db.foo.findMany()).resolves.toHaveLength(0);
217+
await db.foo.update({ where: { id: 1 }, data: { bar: { create: { id: 1, y: 0 } } } });
218+
await expect(db.foo.findMany()).resolves.toHaveLength(1);
219+
});
220+
221+
it('works with filtered by to-many relation', async () => {
222+
const db = await createPolicyTestClient(
223+
`
224+
model Foo {
225+
id Int @id
226+
bars Bar[]
227+
@@allow('create,update', true)
228+
@@allow('read', bars?[y > 0])
229+
@@allow('read', this.bars?[y > 0])
230+
}
231+
232+
model Bar {
233+
id Int @id
234+
y Int
235+
foo Foo? @relation(fields: [fooId], references: [id])
236+
fooId Int?
237+
@@allow('all', true)
238+
}
239+
`,
240+
);
241+
242+
await db.$unuseAll().foo.create({ data: { id: 1, bars: { create: [{ id: 1, y: 0 }] } } });
243+
await expect(db.foo.findMany()).resolves.toHaveLength(0);
244+
await db.foo.update({ where: { id: 1 }, data: { bars: { create: { id: 2, y: 1 } } } });
245+
await expect(db.foo.findMany()).resolves.toHaveLength(1);
246+
});
247+
});
248+
});

0 commit comments

Comments
 (0)