Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 97 additions & 64 deletions packages/runtime/src/plugins/policy/policy-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,101 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
// return result;
}

// #region overrides

protected override transformSelectQuery(node: SelectQueryNode) {
let whereNode = node.where;

node.from?.froms.forEach((from) => {
const extractResult = this.extractTableName(from);
if (extractResult) {
const { model, alias } = extractResult;
const filter = this.buildPolicyFilter(model, alias, 'read');
whereNode = WhereNode.create(
whereNode?.where ? conjunction(this.dialect, [whereNode.where, filter]) : filter,
);
}
});

const baseResult = super.transformSelectQuery({
...node,
where: undefined,
});

return {
...baseResult,
where: whereNode,
};
}

protected override transformInsertQuery(node: InsertQueryNode) {
// pre-insert check is done in `handle()`

let onConflict = node.onConflict;

if (onConflict?.updates) {
// for "on conflict do update", we need to apply policy filter to the "where" clause
const mutationModel = this.getMutationModel(node);
const filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
if (onConflict.updateWhere) {
onConflict = {
...onConflict,
updateWhere: WhereNode.create(conjunction(this.dialect, [onConflict.updateWhere.where, filter])),
};
} else {
onConflict = {
...onConflict,
updateWhere: WhereNode.create(filter),
};
}
}

let result = super.transformInsertQuery(node);
// merge updated onConflict
result = onConflict ? { ...result, onConflict } : result;

if (!node.returning) {
return result;
}

if (this.onlyReturningId(node)) {
return result;
} else {
// only return ID fields, that's enough for reading back the inserted row
const idFields = getIdFields(this.client.$schema, this.getMutationModel(node));
return {
...result,
returning: ReturningNode.create(
idFields.map((field) => SelectionNode.create(ColumnNode.create(field))),
),
};
}
}

protected override transformUpdateQuery(node: UpdateQueryNode) {
const result = super.transformUpdateQuery(node);
const mutationModel = this.getMutationModel(node);
const filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
return {
...result,
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
};
}

protected override transformDeleteQuery(node: DeleteQueryNode) {
const result = super.transformDeleteQuery(node);
const mutationModel = this.getMutationModel(node);
const filter = this.buildPolicyFilter(mutationModel, undefined, 'delete');
return {
...result,
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
};
}

// #endregion

// #region helpers

private onlyReturningId(node: MutationQueryNode) {
if (!node.returning) {
return true;
Expand Down Expand Up @@ -397,70 +492,6 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
return combinedPolicy;
}

protected override transformSelectQuery(node: SelectQueryNode) {
let whereNode = node.where;

node.from?.froms.forEach((from) => {
const extractResult = this.extractTableName(from);
if (extractResult) {
const { model, alias } = extractResult;
const filter = this.buildPolicyFilter(model, alias, 'read');
whereNode = WhereNode.create(
whereNode?.where ? conjunction(this.dialect, [whereNode.where, filter]) : filter,
);
}
});

const baseResult = super.transformSelectQuery({
...node,
where: undefined,
});

return {
...baseResult,
where: whereNode,
};
}

protected override transformInsertQuery(node: InsertQueryNode) {
const result = super.transformInsertQuery(node);
if (!node.returning) {
return result;
}
if (this.onlyReturningId(node)) {
return result;
} else {
// only return ID fields, that's enough for reading back the inserted row
const idFields = getIdFields(this.client.$schema, this.getMutationModel(node));
return {
...result,
returning: ReturningNode.create(
idFields.map((field) => SelectionNode.create(ColumnNode.create(field))),
),
};
}
}

protected override transformUpdateQuery(node: UpdateQueryNode) {
const result = super.transformUpdateQuery(node);
const mutationModel = this.getMutationModel(node);
const filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
return {
...result,
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
};
}

protected override transformDeleteQuery(node: DeleteQueryNode) {
const result = super.transformDeleteQuery(node);
const mutationModel = this.getMutationModel(node);
const filter = this.buildPolicyFilter(mutationModel, undefined, 'delete');
return {
...result,
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
};
}

private extractTableName(from: OperationNode): { model: GetModels<Schema>; alias?: string } | undefined {
if (TableNode.is(from)) {
return { model: from.table.identifier.name as GetModels<Schema> };
Expand Down Expand Up @@ -528,4 +559,6 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
}
return result;
}

// #endregion
}
51 changes: 51 additions & 0 deletions packages/runtime/test/policy/crud/delete.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { describe, expect, it } from 'vitest';
import { createPolicyTestClient } from '../utils';

describe('Delete policy tests', () => {
it('works with top-level delete/deleteMany', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
x Int
@@allow('create,read', true)
@@allow('delete', x > 0)
}
`,
);

await db.foo.create({ data: { id: 1, x: 0 } });
await expect(db.foo.delete({ where: { id: 1 } })).toBeRejectedNotFound();

await db.foo.create({ data: { id: 2, x: 1 } });
await expect(db.foo.delete({ where: { id: 2 } })).toResolveTruthy();
await expect(db.foo.count()).resolves.toBe(1);

await db.foo.create({ data: { id: 3, x: 1 } });
await expect(db.foo.deleteMany()).resolves.toMatchObject({ count: 1 });
await expect(db.foo.count()).resolves.toBe(1);
});

it('works with query builder delete', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
x Int
@@allow('create,read', true)
@@allow('delete', x > 0)
}
`,
);
await db.foo.create({ data: { id: 1, x: 0 } });
await db.foo.create({ data: { id: 2, x: 1 } });

await expect(db.$qb.deleteFrom('Foo').where('id', '=', 1).executeTakeFirst()).resolves.toMatchObject({
numDeletedRows: 0n,
});
await expect(db.foo.count()).resolves.toBe(2);

await expect(db.$qb.deleteFrom('Foo').executeTakeFirst()).resolves.toMatchObject({ numDeletedRows: 1n });
await expect(db.foo.count()).resolves.toBe(1);
});
});
96 changes: 96 additions & 0 deletions packages/runtime/test/policy/crud/update.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -953,4 +953,100 @@ model Foo {
);
});
});

describe('Query builder tests', () => {
it('works with simple update', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
x Int
@@allow('create', true)
@@allow('update', x > 1)
@@allow('read', true)
}
`,
);

await db.foo.createMany({
data: [
{ id: 1, x: 1 },
{ id: 2, x: 2 },
{ id: 3, x: 3 },
],
});

// not updatable
await expect(
db.$qb.updateTable('Foo').set({ x: 5 }).where('id', '=', 1).executeTakeFirst(),
).resolves.toMatchObject({ numUpdatedRows: 0n });

// with where
await expect(
db.$qb.updateTable('Foo').set({ x: 5 }).where('id', '=', 2).executeTakeFirst(),
).resolves.toMatchObject({ numUpdatedRows: 1n });
await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 5 });

// without where
await expect(db.$qb.updateTable('Foo').set({ x: 6 }).executeTakeFirst()).resolves.toMatchObject({
numUpdatedRows: 2n,
});
await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 });
});

it('works with insert on conflict do update', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
x Int
@@allow('create', true)
@@allow('update', x > 1)
@@allow('read', true)
}
`,
);

await db.foo.createMany({
data: [
{ id: 1, x: 1 },
{ id: 2, x: 2 },
{ id: 3, x: 3 },
],
});

// #1 not updatable
await expect(
db.$qb
.insertInto('Foo')
.values({ id: 1, x: 5 })
.onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 }))
.executeTakeFirst(),
).resolves.toMatchObject({ numInsertedOrUpdatedRows: 0n });
await expect(db.foo.count()).resolves.toBe(3);
await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 });

// with where, #1 not updatable
await expect(
db.$qb
.insertInto('Foo')
.values({ id: 1, x: 5 })
.onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 }).where('id', '=', 1))
.executeTakeFirst(),
).resolves.toMatchObject({ numInsertedOrUpdatedRows: 0n });
await expect(db.foo.count()).resolves.toBe(3);
await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 });

// with where, #2 updatable
await expect(
db.$qb
.insertInto('Foo')
.values({ id: 2, x: 5 })
.onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 6 }).where('id', '=', 2))
.executeTakeFirst(),
).resolves.toMatchObject({ numInsertedOrUpdatedRows: 1n });
await expect(db.foo.count()).resolves.toBe(3);
await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 6 });
});
});
});