diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index 495e1853d..b4fd9204c 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -77,5 +77,6 @@ export const ACTIONS_WITH_WRITE_PAYLOAD = [ 'createManyAndReturn', 'update', 'updateMany', + 'updateManyAndReturn', 'upsert', ]; diff --git a/packages/runtime/src/cross/nested-write-visitor.ts b/packages/runtime/src/cross/nested-write-visitor.ts index ba4b232a6..cff7f8143 100644 --- a/packages/runtime/src/cross/nested-write-visitor.ts +++ b/packages/runtime/src/cross/nested-write-visitor.ts @@ -247,6 +247,7 @@ export class NestedWriteVisitor { break; case 'updateMany': + case 'updateManyAndReturn': for (const item of this.enumerateReverse(data)) { const newContext = pushNewContext(field, model, item.where); let callbackResult: any; diff --git a/packages/runtime/src/cross/types.ts b/packages/runtime/src/cross/types.ts index 0466df447..50d4f1e02 100644 --- a/packages/runtime/src/cross/types.ts +++ b/packages/runtime/src/cross/types.ts @@ -8,6 +8,7 @@ export const PrismaWriteActions = [ 'connectOrCreate', 'update', 'updateMany', + 'updateManyAndReturn', 'upsert', 'connect', 'disconnect', diff --git a/packages/runtime/src/enhancements/node/policy/handler.ts b/packages/runtime/src/enhancements/node/policy/handler.ts index 098a87a67..673665dd5 100644 --- a/packages/runtime/src/enhancements/node/policy/handler.ts +++ b/packages/runtime/src/enhancements/node/policy/handler.ts @@ -511,7 +511,7 @@ export class PolicyProxyHandler implements Pr }); } - // throw read-back error if any of create result read-back fails + // throw read-back error if any of the create result read-back fails const error = result.find((r) => !!r.error)?.error; if (error) { throw error; @@ -1268,6 +1268,14 @@ export class PolicyProxyHandler implements Pr } updateMany(args: any) { + return this.doUpdateMany(args, 'updateMany'); + } + + updateManyAndReturn(args: any): Promise { + return this.doUpdateMany(args, 'updateManyAndReturn'); + } + + private doUpdateMany(args: any, action: 'updateMany' | 'updateManyAndReturn'): Promise { if (!args) { throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } @@ -1279,9 +1287,10 @@ export class PolicyProxyHandler implements Pr ); } - return createDeferredPromise(() => { + return createDeferredPromise(async () => { this.policyUtils.tryReject(this.prisma, this.model, 'update'); + const origArgs = args; args = this.policyUtils.safeClone(args); this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update'); @@ -1302,13 +1311,37 @@ export class PolicyProxyHandler implements Pr if (this.shouldLogQuery) { this.logger.info(`[policy] \`updateMany\` ${this.model}: ${formatObject(args)}`); } - return this.modelClient.updateMany(args); + if (action === 'updateMany') { + return this.modelClient.updateMany(args); + } else { + // make sure only id fields are returned so we can directly use the result + // for read-back check + const updatedArg = { + ...args, + select: this.policyUtils.makeIdSelection(this.model), + include: undefined, + }; + const updated = await this.modelClient.updateManyAndReturn(updatedArg); + // process read-back + const result = await Promise.all( + updated.map((item) => + this.policyUtils.readBack(this.prisma, this.model, 'update', origArgs, item) + ) + ); + // throw read-back error if any of create result read-back fails + const error = result.find((r) => !!r.error)?.error; + if (error) { + throw error; + } else { + return result.map((r) => r.result); + } + } } // collect post-update checks const postWriteChecks: PostWriteCheckRecord[] = []; - return this.queryUtils.transaction(this.prisma, async (tx) => { + const result = await this.queryUtils.transaction(this.prisma, async (tx) => { // collect pre-update values let select = this.policyUtils.makeIdSelection(this.model); const preValueSelect = this.policyUtils.getPreValueSelect(this.model); @@ -1352,13 +1385,45 @@ export class PolicyProxyHandler implements Pr if (this.shouldLogQuery) { this.logger.info(`[policy] \`updateMany\` in tx for ${this.model}: ${formatObject(args)}`); } - const result = await tx[this.model].updateMany(args); - // run post-write checks - await this.runPostWriteChecks(postWriteChecks, tx); + if (action === 'updateMany') { + const result = await tx[this.model].updateMany(args); + // run post-write checks + await this.runPostWriteChecks(postWriteChecks, tx); + return result; + } else { + // make sure only id fields are returned so we can directly use the result + // for read-back check + const updatedArg = { + ...args, + select: this.policyUtils.makeIdSelection(this.model), + include: undefined, + }; + const result = await tx[this.model].updateManyAndReturn(updatedArg); + // run post-write checks + await this.runPostWriteChecks(postWriteChecks, tx); + return result; + } + }); + if (action === 'updateMany') { + // no further processing needed return result; - }); + } else { + // process read-back + const readBackResult = await Promise.all( + (result as unknown[]).map((item) => + this.policyUtils.readBack(this.prisma, this.model, 'update', origArgs, item) + ) + ); + // throw read-back error if any of the update result read-back fails + const error = readBackResult.find((r) => !!r.error)?.error; + if (error) { + throw error; + } else { + return readBackResult.map((r) => r.result); + } + } }); } diff --git a/packages/runtime/src/enhancements/node/proxy.ts b/packages/runtime/src/enhancements/node/proxy.ts index cfbc0eb7c..e063f002b 100644 --- a/packages/runtime/src/enhancements/node/proxy.ts +++ b/packages/runtime/src/enhancements/node/proxy.ts @@ -35,6 +35,8 @@ export interface PrismaProxyHandler { updateMany(args: any): Promise; + updateManyAndReturn(args: any): Promise; + upsert(args: any): Promise; delete(args: any): Promise; @@ -132,6 +134,10 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { return this.deferred<{ count: number }>('updateMany', args, false); } + updateManyAndReturn(args: any) { + return this.deferred('updateManyAndReturn', args); + } + upsert(args: any) { return this.deferred('upsert', args); } diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index fe31a5058..5027fb5c6 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -19,6 +19,7 @@ export interface DbOperations { createManyAndReturn(args: unknown): Promise; update(args: unknown): Promise; updateMany(args: unknown): Promise<{ count: number }>; + updateManyAndReturn(args: unknown): Promise; upsert(args: unknown): Promise; delete(args: unknown): Promise; deleteMany(args?: unknown): Promise<{ count: number }>; diff --git a/tests/integration/tests/enhancements/with-policy/update-many-and-return.test.ts b/tests/integration/tests/enhancements/with-policy/update-many-and-return.test.ts new file mode 100644 index 000000000..4797cf670 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/update-many-and-return.test.ts @@ -0,0 +1,140 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Test API updateManyAndReturn', () => { + it('model-level policies', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + posts Post[] + level Int + + @@allow('read', level > 0) + } + + model Post { + id Int @id @default(autoincrement()) + title String + published Boolean @default(false) + userId Int + user User @relation(fields: [userId], references: [id]) + + @@allow('read', published) + @@allow('update', contains(title, 'hello')) + } + ` + ); + + await prisma.user.createMany({ + data: [{ id: 1, level: 1 }], + }); + await prisma.user.createMany({ + data: [{ id: 2, level: 0 }], + }); + + await prisma.post.createMany({ + data: [ + { id: 1, title: 'hello1', userId: 1, published: true }, + { id: 2, title: 'world1', userId: 1, published: false }, + ], + }); + + const db = enhance(); + + // only post#1 is updated + let r = await db.post.updateManyAndReturn({ + data: { title: 'foo' }, + }); + expect(r).toHaveLength(1); + expect(r[0].id).toBe(1); + + // post#2 is excluded from update + await expect( + db.post.updateManyAndReturn({ + where: { id: 2 }, + data: { title: 'foo' }, + }) + ).resolves.toHaveLength(0); + + // reset + await prisma.post.update({ where: { id: 1 }, data: { title: 'hello1' } }); + + // post#1 is updated + await expect( + db.post.updateManyAndReturn({ + where: { id: 1 }, + data: { title: 'foo' }, + }) + ).resolves.toHaveLength(1); + + // reset + await prisma.post.update({ where: { id: 1 }, data: { title: 'hello1' } }); + + // read-back check + // post#1 updated but can't be read back + await expect( + db.post.updateManyAndReturn({ + data: { published: false }, + }) + ).toBeRejectedByPolicy(['result is not allowed to be read back']); + // but the update should have been applied + await expect(prisma.post.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ published: false }); + + // reset + await prisma.post.update({ where: { id: 1 }, data: { published: true } }); + + // return relation + r = await db.post.updateManyAndReturn({ + include: { user: true }, + data: { title: 'hello2' }, + }); + expect(r[0]).toMatchObject({ user: { id: 1 } }); + + // relation filtered + await prisma.post.create({ data: { id: 3, title: 'hello3', userId: 2, published: true } }); + await expect( + db.post.updateManyAndReturn({ + where: { id: 3 }, + include: { user: true }, + data: { title: 'hello4' }, + }) + ).toBeRejectedByPolicy(['result is not allowed to be read back']); + // update is applied + await expect(prisma.post.findUnique({ where: { id: 3 } })).resolves.toMatchObject({ title: 'hello4' }); + }); + + it('field-level policies', async () => { + const { prisma, enhance } = await loadSchema( + ` + model Post { + id Int @id @default(autoincrement()) + title String @allow('read', published) + published Boolean @default(false) + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + // update should succeed but one result's title field can't be read back + await prisma.post.createMany({ + data: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: false }, + ], + }); + + const r = await db.post.updateManyAndReturn({ + data: { title: 'foo' }, + }); + + expect(r.length).toBe(2); + expect(r[0].title).toBeTruthy(); + expect(r[1].title).toBeUndefined(); + + // check posts are updated + await expect(prisma.post.findMany({ where: { title: 'foo' } })).resolves.toHaveLength(2); + }); +}); diff --git a/tests/regression/tests/issue-1955.test.ts b/tests/regression/tests/issue-1955.test.ts index 3b9d116f9..703dd0f44 100644 --- a/tests/regression/tests/issue-1955.test.ts +++ b/tests/regression/tests/issue-1955.test.ts @@ -21,6 +21,7 @@ describe('issue 1955', () => { _prisma = prisma; const db = enhance(); + await expect( db.post.createManyAndReturn({ data: [ @@ -38,6 +39,17 @@ describe('issue 1955', () => { expect.objectContaining({ name: 'blu' }), ]) ); + + await expect( + db.post.updateManyAndReturn({ + data: { name: 'foo' }, + }) + ).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ name: 'foo' }), + expect.objectContaining({ name: 'foo' }), + ]) + ); } finally { await _prisma.$disconnect(); await dropPostgresDb('issue-1955-1'); @@ -72,6 +84,7 @@ describe('issue 1955', () => { _prisma = prisma; const db = enhance(); + await expect( db.post.createManyAndReturn({ data: [ @@ -89,6 +102,17 @@ describe('issue 1955', () => { expect.objectContaining({ name: 'blu' }), ]) ); + + await expect( + db.post.updateManyAndReturn({ + data: { name: 'foo' }, + }) + ).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ name: 'foo' }), + expect.objectContaining({ name: 'foo' }), + ]) + ); } finally { await _prisma.$disconnect(); await dropPostgresDb('issue-1955-2');