Skip to content

Commit dc18713

Browse files
authored
feat(policy): support "insert on conflict update" (#251)
* feat(policy): support "insert on conflict update" * address pr comments
1 parent a7b9ad3 commit dc18713

File tree

3 files changed

+245
-64
lines changed

3 files changed

+245
-64
lines changed

packages/runtime/src/plugins/policy/policy-handler.ts

Lines changed: 98 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,102 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
134134
// return result;
135135
}
136136

137+
// #region overrides
138+
139+
protected override transformSelectQuery(node: SelectQueryNode) {
140+
let whereNode = node.where;
141+
142+
node.from?.froms.forEach((from) => {
143+
const extractResult = this.extractTableName(from);
144+
if (extractResult) {
145+
const { model, alias } = extractResult;
146+
const filter = this.buildPolicyFilter(model, alias, 'read');
147+
whereNode = WhereNode.create(
148+
whereNode?.where ? conjunction(this.dialect, [whereNode.where, filter]) : filter,
149+
);
150+
}
151+
});
152+
153+
const baseResult = super.transformSelectQuery({
154+
...node,
155+
where: undefined,
156+
});
157+
158+
return {
159+
...baseResult,
160+
where: whereNode,
161+
};
162+
}
163+
164+
protected override transformInsertQuery(node: InsertQueryNode) {
165+
// pre-insert check is done in `handle()`
166+
167+
let onConflict = node.onConflict;
168+
169+
if (onConflict?.updates) {
170+
// for "on conflict do update", we need to apply policy filter to the "where" clause
171+
const mutationModel = this.getMutationModel(node);
172+
const filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
173+
if (onConflict.updateWhere) {
174+
onConflict = {
175+
...onConflict,
176+
updateWhere: WhereNode.create(conjunction(this.dialect, [onConflict.updateWhere.where, filter])),
177+
};
178+
} else {
179+
onConflict = {
180+
...onConflict,
181+
updateWhere: WhereNode.create(filter),
182+
};
183+
}
184+
}
185+
186+
// merge updated onConflict
187+
const processedNode = onConflict ? { ...node, onConflict } : node;
188+
189+
const result = super.transformInsertQuery(processedNode);
190+
191+
if (!node.returning) {
192+
return result;
193+
}
194+
195+
if (this.onlyReturningId(node)) {
196+
return result;
197+
} else {
198+
// only return ID fields, that's enough for reading back the inserted row
199+
const idFields = getIdFields(this.client.$schema, this.getMutationModel(node));
200+
return {
201+
...result,
202+
returning: ReturningNode.create(
203+
idFields.map((field) => SelectionNode.create(ColumnNode.create(field))),
204+
),
205+
};
206+
}
207+
}
208+
209+
protected override transformUpdateQuery(node: UpdateQueryNode) {
210+
const result = super.transformUpdateQuery(node);
211+
const mutationModel = this.getMutationModel(node);
212+
const filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
213+
return {
214+
...result,
215+
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
216+
};
217+
}
218+
219+
protected override transformDeleteQuery(node: DeleteQueryNode) {
220+
const result = super.transformDeleteQuery(node);
221+
const mutationModel = this.getMutationModel(node);
222+
const filter = this.buildPolicyFilter(mutationModel, undefined, 'delete');
223+
return {
224+
...result,
225+
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
226+
};
227+
}
228+
229+
// #endregion
230+
231+
// #region helpers
232+
137233
private onlyReturningId(node: MutationQueryNode) {
138234
if (!node.returning) {
139235
return true;
@@ -397,70 +493,6 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
397493
return combinedPolicy;
398494
}
399495

400-
protected override transformSelectQuery(node: SelectQueryNode) {
401-
let whereNode = node.where;
402-
403-
node.from?.froms.forEach((from) => {
404-
const extractResult = this.extractTableName(from);
405-
if (extractResult) {
406-
const { model, alias } = extractResult;
407-
const filter = this.buildPolicyFilter(model, alias, 'read');
408-
whereNode = WhereNode.create(
409-
whereNode?.where ? conjunction(this.dialect, [whereNode.where, filter]) : filter,
410-
);
411-
}
412-
});
413-
414-
const baseResult = super.transformSelectQuery({
415-
...node,
416-
where: undefined,
417-
});
418-
419-
return {
420-
...baseResult,
421-
where: whereNode,
422-
};
423-
}
424-
425-
protected override transformInsertQuery(node: InsertQueryNode) {
426-
const result = super.transformInsertQuery(node);
427-
if (!node.returning) {
428-
return result;
429-
}
430-
if (this.onlyReturningId(node)) {
431-
return result;
432-
} else {
433-
// only return ID fields, that's enough for reading back the inserted row
434-
const idFields = getIdFields(this.client.$schema, this.getMutationModel(node));
435-
return {
436-
...result,
437-
returning: ReturningNode.create(
438-
idFields.map((field) => SelectionNode.create(ColumnNode.create(field))),
439-
),
440-
};
441-
}
442-
}
443-
444-
protected override transformUpdateQuery(node: UpdateQueryNode) {
445-
const result = super.transformUpdateQuery(node);
446-
const mutationModel = this.getMutationModel(node);
447-
const filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
448-
return {
449-
...result,
450-
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
451-
};
452-
}
453-
454-
protected override transformDeleteQuery(node: DeleteQueryNode) {
455-
const result = super.transformDeleteQuery(node);
456-
const mutationModel = this.getMutationModel(node);
457-
const filter = this.buildPolicyFilter(mutationModel, undefined, 'delete');
458-
return {
459-
...result,
460-
where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter),
461-
};
462-
}
463-
464496
private extractTableName(from: OperationNode): { model: GetModels<Schema>; alias?: string } | undefined {
465497
if (TableNode.is(from)) {
466498
return { model: from.table.identifier.name as GetModels<Schema> };
@@ -528,4 +560,6 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
528560
}
529561
return result;
530562
}
563+
564+
// #endregion
531565
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import { describe, expect, it } from 'vitest';
2+
import { createPolicyTestClient } from '../utils';
3+
4+
describe('Delete policy tests', () => {
5+
it('works with top-level delete/deleteMany', async () => {
6+
const db = await createPolicyTestClient(
7+
`
8+
model Foo {
9+
id Int @id
10+
x Int
11+
@@allow('create,read', true)
12+
@@allow('delete', x > 0)
13+
}
14+
`,
15+
);
16+
17+
await db.foo.create({ data: { id: 1, x: 0 } });
18+
await expect(db.foo.delete({ where: { id: 1 } })).toBeRejectedNotFound();
19+
20+
await db.foo.create({ data: { id: 2, x: 1 } });
21+
await expect(db.foo.delete({ where: { id: 2 } })).toResolveTruthy();
22+
await expect(db.foo.count()).resolves.toBe(1);
23+
24+
await db.foo.create({ data: { id: 3, x: 1 } });
25+
await expect(db.foo.deleteMany()).resolves.toMatchObject({ count: 1 });
26+
await expect(db.foo.count()).resolves.toBe(1);
27+
});
28+
29+
it('works with query builder delete', async () => {
30+
const db = await createPolicyTestClient(
31+
`
32+
model Foo {
33+
id Int @id
34+
x Int
35+
@@allow('create,read', true)
36+
@@allow('delete', x > 0)
37+
}
38+
`,
39+
);
40+
await db.foo.create({ data: { id: 1, x: 0 } });
41+
await db.foo.create({ data: { id: 2, x: 1 } });
42+
43+
await expect(db.$qb.deleteFrom('Foo').where('id', '=', 1).executeTakeFirst()).resolves.toMatchObject({
44+
numDeletedRows: 0n,
45+
});
46+
await expect(db.foo.count()).resolves.toBe(2);
47+
48+
await expect(db.$qb.deleteFrom('Foo').executeTakeFirst()).resolves.toMatchObject({ numDeletedRows: 1n });
49+
await expect(db.foo.count()).resolves.toBe(1);
50+
});
51+
});

packages/runtime/test/policy/crud/update.test.ts

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,4 +953,100 @@ model Foo {
953953
);
954954
});
955955
});
956+
957+
describe('Query builder tests', () => {
958+
it('works with simple update', async () => {
959+
const db = await createPolicyTestClient(
960+
`
961+
model Foo {
962+
id Int @id
963+
x Int
964+
@@allow('create', true)
965+
@@allow('update', x > 1)
966+
@@allow('read', true)
967+
}
968+
`,
969+
);
970+
971+
await db.foo.createMany({
972+
data: [
973+
{ id: 1, x: 1 },
974+
{ id: 2, x: 2 },
975+
{ id: 3, x: 3 },
976+
],
977+
});
978+
979+
// not updatable
980+
await expect(
981+
db.$qb.updateTable('Foo').set({ x: 5 }).where('id', '=', 1).executeTakeFirst(),
982+
).resolves.toMatchObject({ numUpdatedRows: 0n });
983+
984+
// with where
985+
await expect(
986+
db.$qb.updateTable('Foo').set({ x: 5 }).where('id', '=', 2).executeTakeFirst(),
987+
).resolves.toMatchObject({ numUpdatedRows: 1n });
988+
await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 5 });
989+
990+
// without where
991+
await expect(db.$qb.updateTable('Foo').set({ x: 6 }).executeTakeFirst()).resolves.toMatchObject({
992+
numUpdatedRows: 2n,
993+
});
994+
await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 });
995+
});
996+
997+
it('works with insert on conflict do update', async () => {
998+
const db = await createPolicyTestClient(
999+
`
1000+
model Foo {
1001+
id Int @id
1002+
x Int
1003+
@@allow('create', true)
1004+
@@allow('update', x > 1)
1005+
@@allow('read', true)
1006+
}
1007+
`,
1008+
);
1009+
1010+
await db.foo.createMany({
1011+
data: [
1012+
{ id: 1, x: 1 },
1013+
{ id: 2, x: 2 },
1014+
{ id: 3, x: 3 },
1015+
],
1016+
});
1017+
1018+
// #1 not updatable
1019+
await expect(
1020+
db.$qb
1021+
.insertInto('Foo')
1022+
.values({ id: 1, x: 5 })
1023+
.onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 }))
1024+
.executeTakeFirst(),
1025+
).resolves.toMatchObject({ numInsertedOrUpdatedRows: 0n });
1026+
await expect(db.foo.count()).resolves.toBe(3);
1027+
await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 });
1028+
1029+
// with where, #1 not updatable
1030+
await expect(
1031+
db.$qb
1032+
.insertInto('Foo')
1033+
.values({ id: 1, x: 5 })
1034+
.onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 }).where('id', '=', 1))
1035+
.executeTakeFirst(),
1036+
).resolves.toMatchObject({ numInsertedOrUpdatedRows: 0n });
1037+
await expect(db.foo.count()).resolves.toBe(3);
1038+
await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 });
1039+
1040+
// with where, #2 updatable
1041+
await expect(
1042+
db.$qb
1043+
.insertInto('Foo')
1044+
.values({ id: 2, x: 5 })
1045+
.onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 6 }).where('id', '=', 2))
1046+
.executeTakeFirst(),
1047+
).resolves.toMatchObject({ numInsertedOrUpdatedRows: 1n });
1048+
await expect(db.foo.count()).resolves.toBe(3);
1049+
await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 6 });
1050+
});
1051+
});
9561052
});

0 commit comments

Comments
 (0)