Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1009,10 +1009,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
throw new QueryError(`Relation update not allowed for field "${field}"`);
}
if (!thisEntity) {
thisEntity = await this.readUnique(kysely, model, {
where: combinedWhere,
select: this.makeIdSelect(model),
});
thisEntity = await this.getEntityIds(kysely, model, combinedWhere);
if (!thisEntity) {
if (throwIfNotFound) {
throw new NotFoundError(model);
Expand Down
69 changes: 53 additions & 16 deletions packages/runtime/src/plugins/policy/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import { getCrudDialect } from '../../client/crud/dialects';
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
import { InternalError, QueryError } from '../../client/errors';
import type { ClientOptions } from '../../client/options';
import { getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils';
import { getIdFields, getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils';
import type {
BinaryExpression,
BinaryOperator,
Expand Down Expand Up @@ -111,7 +111,6 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

@expr('field')
// @ts-expect-error
private _field(expr: FieldExpression, context: ExpressionTransformerContext<Schema>) {
const fieldDef = requireField(this.schema, context.model, expr.field);
if (!fieldDef.relation) {
Expand Down Expand Up @@ -162,8 +161,9 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
return this.transformCollectionPredicate(expr, context);
}

const left = this.transform(expr.left, context);
const right = this.transform(expr.right, context);
const { normalizedLeft, normalizedRight } = this.normalizeBinaryOperationOperands(expr, context);
const left = this.transform(normalizedLeft, context);
const right = this.transform(normalizedRight, context);

if (op === 'in') {
if (this.isNullNode(left)) {
Expand Down Expand Up @@ -195,6 +195,22 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
return BinaryOperationNode.create(left, this.transformOperator(op), right);
}

private normalizeBinaryOperationOperands(expr: BinaryExpression, context: ExpressionTransformerContext<Schema>) {
let normalizedLeft: Expression = expr.left;
if (this.isRelationField(expr.left, context.model)) {
invariant(ExpressionUtils.isNull(expr.right));
const idFields = getIdFields(this.schema, context.model);
normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!);
}
let normalizedRight: Expression = expr.right;
if (this.isRelationField(expr.right, context.model)) {
invariant(ExpressionUtils.isNull(expr.left));
const idFields = getIdFields(this.schema, context.model);
normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!);
}
return { normalizedLeft, normalizedRight };
}

private transformCollectionPredicate(expr: BinaryExpression, context: ExpressionTransformerContext<Schema>) {
invariant(expr.op === '?' || expr.op === '!' || expr.op === '^', 'expected "?" or "!" or "^" operator');

Expand All @@ -211,11 +227,15 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
);

let newContextModel: string;
if (ExpressionUtils.isField(expr.left)) {
const fieldDef = requireField(this.schema, context.model, expr.left.field);
const fieldDef = this.getFieldDefFromFieldRef(expr.left, context.model);
if (fieldDef) {
invariant(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr.left)}`);
newContextModel = fieldDef.type;
} else {
invariant(ExpressionUtils.isField(expr.left.receiver));
invariant(
ExpressionUtils.isMember(expr.left) && ExpressionUtils.isField(expr.left.receiver),
'left operand must be member access with field receiver',
);
const fieldDef = requireField(this.schema, context.model, expr.left.receiver.field);
newContextModel = fieldDef.type;
for (const member of expr.left.members) {
Expand Down Expand Up @@ -396,16 +416,14 @@ export class ExpressionTransformer<Schema extends SchemaDef> {

if (ExpressionUtils.isThis(expr.receiver)) {
if (expr.members.length === 1) {
// optimize for the simple this.scalar case
const fieldDef = requireField(this.schema, context.model, expr.members[0]!);
invariant(!fieldDef.relation, 'this.relation access should have been transformed into relation access');
return this.createColumnRef(expr.members[0]!, restContext);
// `this.relation` case, equivalent to field access
return this._field(ExpressionUtils.field(expr.members[0]!), context);
} else {
// transform the first segment into a relation access, then continue with the rest of the members
const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!);
receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext);
members = expr.members.slice(1);
}

// transform the first segment into a relation access, then continue with the rest of the members
const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!);
receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext);
members = expr.members.slice(1);
} else {
receiver = this.transform(expr.receiver, restContext);
}
Expand Down Expand Up @@ -559,4 +577,23 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
return conditions.reduce((acc, condition) => ExpressionUtils.binary(acc, '&&', condition));
}
}

private isRelationField(expr: Expression, model: GetModels<Schema>) {
const fieldDef = this.getFieldDefFromFieldRef(expr, model);
return !!fieldDef?.relation;
}

private getFieldDefFromFieldRef(expr: Expression, model: GetModels<Schema>): FieldDef | undefined {
if (ExpressionUtils.isField(expr)) {
return requireField(this.schema, model, expr.field);
} else if (
ExpressionUtils.isMember(expr) &&
expr.members.length === 1 &&
ExpressionUtils.isThis(expr.receiver)
) {
return requireField(this.schema, model, expr.members[0]!);
} else {
return undefined;
}
}
}
1 change: 1 addition & 0 deletions packages/runtime/src/plugins/policy/policy-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
}
return readBackResult;
} else {
// reading id fields bypasses policy
return result;
}

Expand Down
248 changes: 248 additions & 0 deletions packages/runtime/test/policy/crud/read.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
import { describe, expect, it } from 'vitest';
import { createPolicyTestClient } from '../utils';

describe('Read policy tests', () => {
describe('Find tests', () => {
it('works with top-level find', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
x Int
@@allow('create', true)
@@allow('read', x > 0)
}
`,
);

await db.$unuseAll().foo.create({ data: { id: 1, x: 0 } });
await expect(db.foo.findUnique({ where: { id: 1 } })).toResolveNull();

await db.$unuseAll().foo.update({ where: { id: 1 }, data: { x: 1 } });
await expect(db.foo.findUnique({ where: { id: 1 } })).toResolveTruthy();
});

it('works with mutation read-back', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
x Int
@@allow('create,update', true)
@@allow('read', x > 0)
}
`,
);

await expect(db.foo.create({ data: { id: 1, x: 0 } })).toBeRejectedByPolicy();
await expect(db.$unuseAll().foo.count()).resolves.toBe(1);
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).resolves.toMatchObject({ x: 1 });
});

it('works with to-one relation optional owner-side read', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
bar Bar? @relation(fields: [barId], references: [id])
barId Int? @unique
@@allow('all', true)
}

model Bar {
id Int @id
y Int
foo Foo?
@@allow('create,update', true)
@@allow('read', y > 0)
}
`,
);

await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } });
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null });
await db.bar.update({ where: { id: 1 }, data: { y: 1 } });
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({
id: 1,
bar: { id: 1 },
});
});

// TODO: check if we should be consistent with v2 and filter out the parent entity
// if a non-optional child relation is included but not readable
it('works with to-one relation non-optional owner-side read', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
bar Bar @relation(fields: [barId], references: [id])
barId Int @unique
@@allow('all', true)
}

model Bar {
id Int @id
y Int
foo Foo?
@@allow('create,update', true)
@@allow('read', y > 0)
}
`,
);

await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } });
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null });
await db.bar.update({ where: { id: 1 }, data: { y: 1 } });
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({
id: 1,
bar: { id: 1 },
});
});

it('works with to-one relation non-owner-side read', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
bar Bar?
@@allow('all', true)
}

model Bar {
id Int @id
y Int
foo Foo @relation(fields: [fooId], references: [id])
fooId Int @unique
@@allow('create,update', true)
@@allow('read', y > 0)
}
`,
);

await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } });
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null });
await db.bar.update({ where: { id: 1 }, data: { y: 1 } });
await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({
id: 1,
bar: { id: 1 },
});
});

it('works with to-many relation read', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
bars Bar[]
@@allow('all', true)
}

model Bar {
id Int @id
y Int
foo Foo? @relation(fields: [fooId], references: [id])
fooId Int?
@@allow('create,update', true)
@@allow('read', y > 0)
}
`,
);

await db.foo.create({
data: {
id: 1,
bars: {
create: [
{ id: 1, y: 0 },
{ id: 2, y: 1 },
],
},
},
});
await expect(db.foo.findFirst({ include: { bars: true } })).resolves.toMatchObject({
id: 1,
bars: [{ id: 2 }],
});
});

it('works with filtered by to-one relation field', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
bar Bar? @relation(fields: [barId], references: [id])
barId Int? @unique
@@allow('create', true)
@@allow('read', bar.y > 0)
}

model Bar {
id Int @id
y Int
foo Foo?
@@allow('all', true)
}
`,
);

await db.$unuseAll().foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } });
await expect(db.foo.findMany()).resolves.toHaveLength(0);
await db.bar.update({ where: { id: 1 }, data: { y: 1 } });
await expect(db.foo.findMany()).resolves.toHaveLength(1);
});

it('works with filtered by to-one relation non-null', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
bar Bar? @relation(fields: [barId], references: [id])
barId Int? @unique
@@allow('create,update', true)
@@allow('read', bar != null)
@@allow('read', this.bar != null)
}

model Bar {
id Int @id
y Int
foo Foo?
@@allow('all', true)
}
`,
);

await db.$unuseAll().foo.create({ data: { id: 1 } });
await expect(db.foo.findMany()).resolves.toHaveLength(0);
await db.foo.update({ where: { id: 1 }, data: { bar: { create: { id: 1, y: 0 } } } });
await expect(db.foo.findMany()).resolves.toHaveLength(1);
});

it('works with filtered by to-many relation', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
bars Bar[]
@@allow('create,update', true)
@@allow('read', bars?[y > 0])
@@allow('read', this.bars?[y > 0])
}

model Bar {
id Int @id
y Int
foo Foo? @relation(fields: [fooId], references: [id])
fooId Int?
@@allow('all', true)
}
`,
);

await db.$unuseAll().foo.create({ data: { id: 1, bars: { create: [{ id: 1, y: 0 }] } } });
await expect(db.foo.findMany()).resolves.toHaveLength(0);
await db.foo.update({ where: { id: 1 }, data: { bars: { create: { id: 2, y: 1 } } } });
await expect(db.foo.findMany()).resolves.toHaveLength(1);
});
});
});