From f13b92fd390e81d95c5964ffc0c0d0500eaa15d6 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:44:51 -0700 Subject: [PATCH 1/3] chore(policy): more test cases and update --- README.md | 2 +- .../src/client/crud/operations/base.ts | 33 +- .../runtime/test/policy/crud/create.test.ts | 67 ++ .../runtime/test/policy/crud/update.test.ts | 585 ++++++++++++++++-- 4 files changed, 616 insertions(+), 71 deletions(-) diff --git a/README.md b/README.md index f134a133..7f6dfbe0 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ -> V3 is currently in alpha phase and not ready for production use. Feedback and bug reports are greatly appreciated. Please visit this dedicated [discord channel](https://discord.com/channels/1035538056146595961/1352359627525718056) for chat and support. +> V3 is currently in beta phase and not ready for production use. Feedback and bug reports are greatly appreciated. Please visit this dedicated [discord channel](https://discord.com/channels/1035538056146595961/1352359627525718056) for chat and support. # What's ZenStack diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 7dd4626e..14b85f55 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -281,7 +281,7 @@ export abstract class BaseOperationHandler { ); Object.assign(createFields, parentFkFields); } else { - parentUpdateTask = (entity) => { + parentUpdateTask = async (entity) => { const query = kysely .updateTable(fromRelation.model) .set( @@ -300,7 +300,10 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - return this.executeQuery(kysely, query, 'update'); + const result = await this.executeQuery(kysely, query, 'update'); + if (!result.numAffectedRows) { + throw new NotFoundError(fromRelation.model); + } }; } } @@ -1551,8 +1554,11 @@ export abstract class BaseOperationHandler { fromRelation.field, ); let updateResult: QueryResult; + let updateModel: GetModels; if (ownedByModel) { + updateModel = fromRelation.model; + // set parent fk directly invariant(_data.length === 1, 'only one entity can be connected'); const target = await this.readUnique(kysely, model, { @@ -1581,6 +1587,8 @@ export abstract class BaseOperationHandler { ); updateResult = await this.executeQuery(kysely, query, 'connect'); } else { + updateModel = model; + // disconnect current if it's a one-one relation const relationFieldDef = this.requireField(fromRelation.model, fromRelation.field); @@ -1621,9 +1629,9 @@ export abstract class BaseOperationHandler { } // validate connect result - if (_data.length > updateResult.numAffectedRows!) { + if (!updateResult.numAffectedRows || _data.length > updateResult.numAffectedRows) { // some entities were not connected - throw new NotFoundError(model); + throw new NotFoundError(updateModel); } } } @@ -1735,7 +1743,10 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - await this.executeQuery(kysely, query, 'disconnect'); + const result = await this.executeQuery(kysely, query, 'disconnect'); + if (!result.numAffectedRows) { + throw new NotFoundError(fromRelation.model); + } } else { // disconnect const query = kysely @@ -1859,7 +1870,7 @@ export abstract class BaseOperationHandler { const r = await this.executeQuery(kysely, query, 'connect'); // validate result - if (_data.length > r.numAffectedRows!) { + if (!r.numAffectedRows || _data.length > r.numAffectedRows) { // some entities were not connected throw new NotFoundError(model); } @@ -1892,9 +1903,12 @@ export abstract class BaseOperationHandler { } let deleteResult: { count: number }; + let deleteFromModel: GetModels; const m2m = getManyToManyRelation(this.schema, fromRelation.model, fromRelation.field); if (m2m) { + deleteFromModel = model; + // handle many-to-many relation const fieldDef = this.requireField(fromRelation.model, fromRelation.field); invariant(fieldDef.relation?.opposite); @@ -1919,11 +1933,13 @@ export abstract class BaseOperationHandler { ); if (ownedByModel) { + deleteFromModel = fromRelation.model; + const fromEntity = await this.readUnique(kysely, fromRelation.model as GetModels, { where: fromRelation.ids, }); if (!fromEntity) { - throw new NotFoundError(model); + throw new NotFoundError(fromRelation.model); } const fieldDef = this.requireField(fromRelation.model, fromRelation.field); @@ -1938,6 +1954,7 @@ export abstract class BaseOperationHandler { ], }); } else { + deleteFromModel = model; deleteResult = await this.delete(kysely, model, { AND: [ Object.fromEntries(keyPairs.map(({ fk, pk }) => [fk, fromRelation.ids[pk]])), @@ -1952,7 +1969,7 @@ export abstract class BaseOperationHandler { // validate result if (throwForNotFound && expectedDeleteCount > deleteResult.count) { // some entities were not deleted - throw new NotFoundError(model); + throw new NotFoundError(deleteFromModel); } } diff --git a/packages/runtime/test/policy/crud/create.test.ts b/packages/runtime/test/policy/crud/create.test.ts index be8c82da..dbd7a414 100644 --- a/packages/runtime/test/policy/crud/create.test.ts +++ b/packages/runtime/test/policy/crud/create.test.ts @@ -206,4 +206,71 @@ model Profile { await expect(db.$setAuth({ id: 4 }).profile.create({ data: { id: 2, userId: 4 } })).toBeRejectedByPolicy(); }); + + it('works with nested create owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + + @@deny('all', auth() == null) + @@allow('create', user.id == auth().id) + @@allow('read', true) +} + `, + ); + + await expect(db.user.create({ data: { id: 1, profile: { create: { id: 1 } } } })).toBeRejectedByPolicy(); + await expect( + db + .$setAuth({ id: 1 }) + .user.create({ data: { id: 1, profile: { create: { id: 1 } } }, include: { profile: true } }), + ).resolves.toMatchObject({ + id: 1, + profile: { + id: 1, + }, + }); + }); + + it('works with nested create non-owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@deny('all', auth() == null) + @@allow('create', this.id == auth().id) + @@allow('read', true) +} + +model Profile { + id Int @id + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + @@allow('all', true) +} + `, + ); + + await expect(db.profile.create({ data: { id: 1, user: { create: { id: 1 } } } })).toBeRejectedByPolicy(); + await expect( + db + .$setAuth({ id: 1 }) + .profile.create({ data: { id: 1, user: { create: { id: 1 } } }, include: { user: true } }), + ).resolves.toMatchObject({ + id: 1, + user: { + id: 1, + }, + }); + }); }); diff --git a/packages/runtime/test/policy/crud/update.test.ts b/packages/runtime/test/policy/crud/update.test.ts index eb5735a2..b7277015 100644 --- a/packages/runtime/test/policy/crud/update.test.ts +++ b/packages/runtime/test/policy/crud/update.test.ts @@ -2,9 +2,10 @@ import { describe, expect, it } from 'vitest'; import { createPolicyTestClient } from '../utils'; describe('Update policy tests', () => { - it('works with scalar field check', async () => { - const db = await createPolicyTestClient( - ` + describe('Scala condition tests', () => { + it('works with scalar field check', async () => { + const db = await createPolicyTestClient( + ` model Foo { id Int @id x Int @@ -12,24 +13,24 @@ model Foo { @@allow('create,read', true) } `, - ); - - await db.foo.create({ data: { id: 1, x: 0 } }); - await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); - await db.foo.create({ data: { id: 2, x: 1 } }); - await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); - - await expect( - db.$qb.updateTable('Foo').set({ x: 1 }).where('id', '=', 1).executeTakeFirst(), - ).resolves.toMatchObject({ numUpdatedRows: 0n }); - await expect( - db.$qb.updateTable('Foo').set({ x: 3 }).where('id', '=', 2).returningAll().execute(), - ).resolves.toMatchObject([{ id: 2, x: 3 }]); - }); + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + + await expect( + db.$qb.updateTable('Foo').set({ x: 1 }).where('id', '=', 1).executeTakeFirst(), + ).resolves.toMatchObject({ numUpdatedRows: 0n }); + await expect( + db.$qb.updateTable('Foo').set({ x: 3 }).where('id', '=', 2).returningAll().execute(), + ).resolves.toMatchObject([{ id: 2, x: 3 }]); + }); - it('works with this scalar member check', async () => { - const db = await createPolicyTestClient( - ` + it('works with this scalar member check', async () => { + const db = await createPolicyTestClient( + ` model Foo { id Int @id x Int @@ -37,32 +38,32 @@ model Foo { @@allow('create,read', true) } `, - ); + ); - await db.foo.create({ data: { id: 1, x: 0 } }); - await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); - await db.foo.create({ data: { id: 2, x: 1 } }); - await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); - }); + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + }); - it('denies by default', async () => { - const db = await createPolicyTestClient( - ` + it('denies by default', async () => { + const db = await createPolicyTestClient( + ` model Foo { id Int @id x Int @@allow('create,read', true) } `, - ); + ); - await db.foo.create({ data: { id: 1, x: 0 } }); - await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); - }); + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + }); - it('works with deny rule', async () => { - const db = await createPolicyTestClient( - ` + it('works with deny rule', async () => { + const db = await createPolicyTestClient( + ` model Foo { id Int @id x Int @@ -70,16 +71,16 @@ model Foo { @@allow('create,read,update', true) } `, - ); - await db.foo.create({ data: { id: 1, x: 0 } }); - await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); - await db.foo.create({ data: { id: 2, x: 1 } }); - await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); - }); + ); + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + }); - it('works with mixed allow and deny rules', async () => { - const db = await createPolicyTestClient( - ` + it('works with mixed allow and deny rules', async () => { + const db = await createPolicyTestClient( + ` model Foo { id Int @id x Int @@ -88,19 +89,19 @@ model Foo { @@allow('create,read', true) } `, - ); - - await db.foo.create({ data: { id: 1, x: 0 } }); - await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); - await db.foo.create({ data: { id: 2, x: 1 } }); - await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).toBeRejectedNotFound(); - await db.foo.create({ data: { id: 3, x: 2 } }); - await expect(db.foo.update({ where: { id: 3 }, data: { x: 3 } })).resolves.toMatchObject({ x: 3 }); - }); + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 3, x: 2 } }); + await expect(db.foo.update({ where: { id: 3 }, data: { x: 3 } })).resolves.toMatchObject({ x: 3 }); + }); - it('works with auth check', async () => { - const db = await createPolicyTestClient( - ` + it('works with auth check', async () => { + const db = await createPolicyTestClient( + ` type Auth { x Int @@auth @@ -113,11 +114,471 @@ model Foo { @@allow('create,read', true) } `, - ); - await db.foo.create({ data: { id: 1, x: 1 } }); - await expect(db.$setAuth({ x: 0 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).toBeRejectedNotFound(); - await expect(db.$setAuth({ x: 1 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).resolves.toMatchObject({ - x: 2, + ); + await db.foo.create({ data: { id: 1, x: 1 } }); + await expect(db.$setAuth({ x: 0 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).toBeRejectedNotFound(); + await expect(db.$setAuth({ x: 1 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).resolves.toMatchObject( + { + x: 2, + }, + ); }); }); + + describe('Relation condition tests', () => { + it('works with to-one relation check owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + bio String + user User @relation(fields: [userId], references: [id]) + userId Int @unique + @@allow('create,read', true) + @@allow('update', user.name == 'User2') +} +`, + ); + + await db.user.create({ data: { id: 1, name: 'User1', profile: { create: { id: 1, bio: 'Bio1' } } } }); + await expect(db.profile.update({ where: { id: 1 }, data: { bio: 'UpdatedBio1' } })).toBeRejectedNotFound(); + + await db.user.create({ data: { id: 2, name: 'User2', profile: { create: { id: 2, bio: 'Bio2' } } } }); + await expect(db.profile.update({ where: { id: 2 }, data: { bio: 'UpdatedBio2' } })).resolves.toMatchObject({ + bio: 'UpdatedBio2', + }); + }); + + it('works with to-one relation check owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + @@allow('all', true) +} + +model Profile { + id Int @id + bio String + user User? + @@allow('create,read', true) + @@allow('update', user.name == 'User2') +} +`, + ); + + await db.user.create({ data: { id: 1, name: 'User1', profile: { create: { id: 1, bio: 'Bio1' } } } }); + await expect(db.profile.update({ where: { id: 1 }, data: { bio: 'UpdatedBio1' } })).toBeRejectedNotFound(); + + await db.user.create({ data: { id: 2, name: 'User2', profile: { create: { id: 2, bio: 'Bio2' } } } }); + await expect(db.profile.update({ where: { id: 2 }, data: { bio: 'UpdatedBio2' } })).resolves.toMatchObject({ + bio: 'UpdatedBio2', + }); + }); + + it('works with to-many relation check some', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + posts Post[] + @@allow('create,read', true) + @@allow('update', posts?[published]) +} + +model Post { + id Int @id + title String + published Boolean + author User @relation(fields: [authorId], references: [id]) + authorId Int + @@allow('all', true) +} +`, + ); + + await db.user.create({ data: { id: 1, name: 'User1' } }); + await expect(db.user.update({ where: { id: 1 }, data: { name: 'UpdatedUser1' } })).toBeRejectedNotFound(); + + await db.user.create({ + data: { id: 2, name: 'User2', posts: { create: { id: 1, title: 'Post1', published: false } } }, + }); + await expect(db.user.update({ where: { id: 2 }, data: { name: 'UpdatedUser2' } })).toBeRejectedNotFound(); + + await db.user.create({ + data: { + id: 3, + name: 'User3', + posts: { + create: [ + { id: 2, title: 'Post2', published: false }, + { id: 3, title: 'Post3', published: true }, + ], + }, + }, + }); + await expect(db.user.update({ where: { id: 3 }, data: { name: 'UpdatedUser3' } })).toResolveTruthy(); + }); + + it('works with to-many relation check all', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + posts Post[] + @@allow('create,read', true) + @@allow('update', posts![published]) +} + +model Post { + id Int @id + title String + published Boolean + author User @relation(fields: [authorId], references: [id]) + authorId Int + @@allow('all', true) +} +`, + ); + + await db.user.create({ data: { id: 1, name: 'User1' } }); + await expect(db.user.update({ where: { id: 1 }, data: { name: 'UpdatedUser1' } })).toResolveTruthy(); + + await db.user.create({ + data: { + id: 2, + name: 'User2', + posts: { + create: [ + { id: 1, title: 'Post1', published: false }, + { id: 2, title: 'Post2', published: true }, + ], + }, + }, + }); + await expect(db.user.update({ where: { id: 2 }, data: { name: 'UpdatedUser2' } })).toBeRejectedNotFound(); + + await db.user.create({ + data: { + id: 3, + name: 'User3', + posts: { + create: [ + { id: 3, title: 'Post3', published: true }, + { id: 4, title: 'Post4', published: true }, + ], + }, + }, + }); + await expect(db.user.update({ where: { id: 3 }, data: { name: 'UpdatedUser3' } })).toResolveTruthy(); + }); + + it('works with to-many relation check none', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + posts Post[] + @@allow('create,read', true) + @@allow('update', posts^[published]) +} + +model Post { + id Int @id + title String + published Boolean + author User @relation(fields: [authorId], references: [id]) + authorId Int + @@allow('all', true) +} +`, + ); + + await db.user.create({ data: { id: 1, name: 'User1' } }); + await expect(db.user.update({ where: { id: 1 }, data: { name: 'UpdatedUser1' } })).toResolveTruthy(); + + await db.user.create({ + data: { + id: 2, + name: 'User2', + posts: { + create: [ + { id: 1, title: 'Post1', published: false }, + { id: 2, title: 'Post2', published: true }, + ], + }, + }, + }); + await expect(db.user.update({ where: { id: 2 }, data: { name: 'UpdatedUser2' } })).toBeRejectedNotFound(); + + await db.user.create({ + data: { + id: 3, + name: 'User3', + posts: { + create: [ + { id: 3, title: 'Post3', published: false }, + { id: 4, title: 'Post4', published: false }, + ], + }, + }, + }); + await expect(db.user.update({ where: { id: 3 }, data: { name: 'UpdatedUser3' } })).toResolveTruthy(); + }); + }); + + describe('Nested update tests', () => { + it('works with nested update owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + bio String + private Boolean + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + @@allow('create,read', true) + @@allow('update', !private) +} +`, + ); + + await db.user.create({ data: { id: 1, profile: { create: { id: 1, bio: 'Bio1', private: true } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { update: { bio: 'UpdatedBio1' } } }, + }), + ).toBeRejectedNotFound(); + + await db.user.create({ data: { id: 2, profile: { create: { id: 2, bio: 'Bio2', private: false } } } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { profile: { update: { bio: 'UpdatedBio2' } } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: { + bio: 'UpdatedBio2', + }, + }); + }); + + it('works with nested update non-owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + @@allow('all', true) +} + +model Profile { + id Int @id + bio String + private Boolean + user User? + @@allow('create,read', true) + @@allow('update', !private) +} +`, + ); + + await db.user.create({ data: { id: 1, profile: { create: { id: 1, bio: 'Bio1', private: true } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { update: { bio: 'UpdatedBio1' } } }, + }), + ).toBeRejectedNotFound(); + + await db.user.create({ data: { id: 2, profile: { create: { id: 2, bio: 'Bio2', private: false } } } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { profile: { update: { bio: 'UpdatedBio2' } } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: { + bio: 'UpdatedBio2', + }, + }); + }); + }); + + describe('Relation manipulation tests', () => { + it('works with connect/disconnect/create owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + private Boolean + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + @@allow('create,read', true) + @@allow('update', !private) +} +`, + ); + + await db.user.create({ data: { id: 1 } }); + + await db.profile.create({ data: { id: 1, private: true } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { connect: { id: 1 } } }, + include: { profile: true }, + }), + ).toBeRejectedNotFound(); + + await db.profile.create({ data: { id: 2, private: false } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { connect: { id: 2 } } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: { + id: 2, + }, + }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { disconnect: true } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: null, + }); + // reconnect + await db.user.update({ where: { id: 1 }, data: { profile: { connect: { id: 2 } } } }); + // set private + await db.profile.update({ where: { id: 2 }, data: { private: true } }); + // disconnect should have no effect since update is not allowed + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { disconnect: true } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ profile: { id: 2 } }); + + await db.profile.create({ data: { id: 3, private: true } }); + await expect( + db.profile.update({ + where: { id: 3 }, + data: { user: { create: { id: 2 } } }, + }), + ).toBeRejectedNotFound(); + }); + + it('works with connect/disconnect/create non-owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? @relation(fields: [profileId], references: [id]) + profileId Int? @unique + private Boolean + @@allow('create,read', true) + @@allow('update', !private) +} + +model Profile { + id Int @id + user User? + @@allow('all', true) +} +`, + ); + + await db.user.create({ data: { id: 1, private: true } }); + await db.profile.create({ data: { id: 1 } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { connect: { id: 1 } } }, + include: { profile: true }, + }), + ).toBeRejectedNotFound(); + + await db.user.create({ data: { id: 2, private: false } }); + await db.profile.create({ data: { id: 2 } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { profile: { connect: { id: 2 } } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: { + id: 2, + }, + }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { profile: { disconnect: true } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: null, + }); + // reconnect + await db.user.update({ where: { id: 2 }, data: { profile: { connect: { id: 2 } } } }); + // set private + await db.user.update({ where: { id: 2 }, data: { private: true } }); + // disconnect should be rejected since update is not allowed + await expect( + db.user.update({ + where: { id: 2 }, + data: { profile: { disconnect: true } }, + include: { profile: true }, + }), + ).toBeRejectedNotFound(); + + await db.profile.create({ data: { id: 3 } }); + await expect( + db.profile.update({ + where: { id: 3 }, + data: { user: { create: { id: 3, private: true } } }, + }), + ).toResolveTruthy(); + }); + }); + + // describe('Upsert tests', () => {}); + + // describe('Update many tests', () => {}); }); From ca061694e613b2f725da3aa940ef3c2f1a3eb084 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:17:26 -0700 Subject: [PATCH 2/3] fix update regression --- packages/runtime/src/client/crud/operations/base.ts | 8 +++++++- packages/runtime/test/policy/crud/update.test.ts | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 14b85f55..ff5685c2 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -1745,7 +1745,13 @@ export abstract class BaseOperationHandler { ); const result = await this.executeQuery(kysely, query, 'disconnect'); if (!result.numAffectedRows) { - throw new NotFoundError(fromRelation.model); + // determine if the parent entity doesn't exist, or the relation entity to be disconnected doesn't exist + const parentExists = await this.exists(kysely, fromRelation.model, fromRelation.ids); + if (!parentExists) { + throw new NotFoundError(fromRelation.model); + } else { + // silently ignore + } } } else { // disconnect diff --git a/packages/runtime/test/policy/crud/update.test.ts b/packages/runtime/test/policy/crud/update.test.ts index b7277015..e0082a49 100644 --- a/packages/runtime/test/policy/crud/update.test.ts +++ b/packages/runtime/test/policy/crud/update.test.ts @@ -2,7 +2,7 @@ import { describe, expect, it } from 'vitest'; import { createPolicyTestClient } from '../utils'; describe('Update policy tests', () => { - describe('Scala condition tests', () => { + describe('Scalar condition tests', () => { it('works with scalar field check', async () => { const db = await createPolicyTestClient( ` From 9ef7e186be908526c1f86a14679b3f509e0943bd Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Thu, 11 Sep 2025 22:13:39 -0700 Subject: [PATCH 3/3] optimize nested relation manipulation --- .../dialects/{base.ts => base-dialect.ts} | 6 +- .../runtime/src/client/crud/dialects/index.ts | 2 +- .../src/client/crud/dialects/postgresql.ts | 2 +- .../src/client/crud/dialects/sqlite.ts | 2 +- .../src/client/crud/operations/base.ts | 187 +++++++----------- packages/runtime/src/client/options.ts | 2 +- .../plugins/policy/expression-transformer.ts | 2 +- .../src/plugins/policy/policy-handler.ts | 9 +- packages/runtime/src/plugins/policy/utils.ts | 2 +- .../runtime/test/client-api/update.test.ts | 18 +- 10 files changed, 106 insertions(+), 126 deletions(-) rename packages/runtime/src/client/crud/dialects/{base.ts => base-dialect.ts} (99%) diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base-dialect.ts similarity index 99% rename from packages/runtime/src/client/crud/dialects/base.ts rename to packages/runtime/src/client/crud/dialects/base-dialect.ts index 8afec156..9f314bf9 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base-dialect.ts @@ -1104,7 +1104,7 @@ export abstract class BaseCrudDialect { return (node as ValueNode).value === false || (node as ValueNode).value === 0; } - protected and(eb: ExpressionBuilder, ...args: Expression[]) { + and(eb: ExpressionBuilder, ...args: Expression[]) { const nonTrueArgs = args.filter((arg) => !this.isTrue(arg)); if (nonTrueArgs.length === 0) { return this.true(eb); @@ -1115,7 +1115,7 @@ export abstract class BaseCrudDialect { } } - protected or(eb: ExpressionBuilder, ...args: Expression[]) { + or(eb: ExpressionBuilder, ...args: Expression[]) { const nonFalseArgs = args.filter((arg) => !this.isFalse(arg)); if (nonFalseArgs.length === 0) { return this.false(eb); @@ -1126,7 +1126,7 @@ export abstract class BaseCrudDialect { } } - protected not(eb: ExpressionBuilder, ...args: Expression[]) { + not(eb: ExpressionBuilder, ...args: Expression[]) { return eb.not(this.and(eb, ...args)); } diff --git a/packages/runtime/src/client/crud/dialects/index.ts b/packages/runtime/src/client/crud/dialects/index.ts index 9d67009e..ede19cdd 100644 --- a/packages/runtime/src/client/crud/dialects/index.ts +++ b/packages/runtime/src/client/crud/dialects/index.ts @@ -1,7 +1,7 @@ import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; import type { ClientOptions } from '../../options'; -import type { BaseCrudDialect } from './base'; +import type { BaseCrudDialect } from './base-dialect'; import { PostgresCrudDialect } from './postgresql'; import { SqliteCrudDialect } from './sqlite'; diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index 93722037..a71e987d 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -20,7 +20,7 @@ import { requireField, requireModel, } from '../../query-utils'; -import { BaseCrudDialect } from './base'; +import { BaseCrudDialect } from './base-dialect'; export class PostgresCrudDialect extends BaseCrudDialect { override get provider() { diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 34ece56e..69de608d 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -20,7 +20,7 @@ import { requireField, requireModel, } from '../../query-utils'; -import { BaseCrudDialect } from './base'; +import { BaseCrudDialect } from './base-dialect'; export class SqliteCrudDialect extends BaseCrudDialect { override get provider() { diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index ff5685c2..65d0d32b 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -7,7 +7,6 @@ import { UpdateResult, type Compilable, type IsolationLevel, - type QueryResult, type SelectQueryBuilder, } from 'kysely'; import { nanoid } from 'nanoid'; @@ -44,7 +43,7 @@ import { requireModel, } from '../../query-utils'; import { getCrudDialect } from '../dialects'; -import type { BaseCrudDialect } from '../dialects/base'; +import type { BaseCrudDialect } from '../dialects/base-dialect'; import { InputValidator } from '../validator'; export type CoreCrudOperation = @@ -66,10 +65,16 @@ export type CoreCrudOperation = export type AllCrudOperation = CoreCrudOperation | 'findUniqueOrThrow' | 'findFirstOrThrow'; +// context for nested relation operations export type FromRelationContext = { + // the model where the relation field is defined model: GetModels; + // the relation field name field: string; + // the parent entity's id fields and values ids: any; + // for relations owned by model, record the parent updates needed after the relation is processed + parentUpdates: Record; }; export abstract class BaseOperationHandler { @@ -258,7 +263,7 @@ export abstract class BaseOperationHandler { } let createFields: any = {}; - let parentUpdateTask: ((entity: any) => Promise) | undefined = undefined; + let updateParent: ((entity: any) => void) | undefined = undefined; let m2m: ReturnType = undefined; @@ -281,28 +286,10 @@ export abstract class BaseOperationHandler { ); Object.assign(createFields, parentFkFields); } else { - parentUpdateTask = async (entity) => { - const query = kysely - .updateTable(fromRelation.model) - .set( - keyPairs.reduce( - (acc, { fk, pk }) => ({ - ...acc, - [fk]: entity[pk], - }), - {} as any, - ), - ) - .where((eb) => eb.and(fromRelation.ids)) - .modifyEnd( - this.makeContextComment({ - model: fromRelation.model, - operation: 'update', - }), - ); - const result = await this.executeQuery(kysely, query, 'update'); - if (!result.numAffectedRows) { - throw new NotFoundError(fromRelation.model); + // record parent fk update after entity is created + updateParent = (entity) => { + for (const { fk, pk } of keyPairs) { + fromRelation.parentUpdates[fk] = entity[pk]; } }; } @@ -406,8 +393,8 @@ export abstract class BaseOperationHandler { } // finally update parent if needed - if (parentUpdateTask) { - await parentUpdateTask(createdEntity); + if (updateParent) { + updateParent(createdEntity); } return createdEntity; @@ -611,10 +598,11 @@ export abstract class BaseOperationHandler { const relationFieldDef = this.requireField(contextModel, relationFieldName); const relationModel = relationFieldDef.type as GetModels; const tasks: Promise[] = []; - const fromRelationContext = { + const fromRelationContext: FromRelationContext = { model: contextModel, field: relationFieldName, ids: parentEntity, + parentUpdates: {}, }; for (const [action, subPayload] of Object.entries(payload)) { @@ -647,13 +635,7 @@ export abstract class BaseOperationHandler { } case 'connect': { - tasks.push( - this.connectRelation(kysely, relationModel, subPayload, { - model: contextModel, - field: relationFieldName, - ids: parentEntity, - }), - ); + tasks.push(this.connectRelation(kysely, relationModel, subPayload, fromRelationContext)); break; } @@ -662,16 +644,8 @@ export abstract class BaseOperationHandler { ...enumerate(subPayload).map((item) => this.exists(kysely, relationModel, item.where).then((found) => !found - ? this.create(kysely, relationModel, item.create, { - model: contextModel, - field: relationFieldName, - ids: parentEntity, - }) - : this.connectRelation(kysely, relationModel, found, { - model: contextModel, - field: relationFieldName, - ids: parentEntity, - }), + ? this.create(kysely, relationModel, item.create, fromRelationContext) + : this.connectRelation(kysely, relationModel, found, fromRelationContext), ), ), ); @@ -1047,7 +1021,7 @@ export abstract class BaseOperationHandler { } } } - await this.processRelationUpdates( + const parentUpdates = await this.processRelationUpdates( kysely, model, field, @@ -1056,6 +1030,11 @@ export abstract class BaseOperationHandler { finalData[field], throwIfNotFound, ); + + if (Object.keys(parentUpdates).length > 0) { + // merge field updates propagated from nested relation processing + Object.assign(updateFields, parentUpdates); + } } } @@ -1375,10 +1354,11 @@ export abstract class BaseOperationHandler { ) { const tasks: Promise[] = []; const fieldModel = fieldDef.type as GetModels; - const fromRelationContext = { + const fromRelationContext: FromRelationContext = { model, field, ids: parentIds, + parentUpdates: {}, }; for (const [key, value] of Object.entries(args)) { @@ -1509,6 +1489,8 @@ export abstract class BaseOperationHandler { } await Promise.all(tasks); + + return fromRelationContext.parentUpdates; } // #region relation manipulation @@ -1553,13 +1535,9 @@ export abstract class BaseOperationHandler { fromRelation.model, fromRelation.field, ); - let updateResult: QueryResult; - let updateModel: GetModels; if (ownedByModel) { - updateModel = fromRelation.model; - - // set parent fk directly + // record parent fk update invariant(_data.length === 1, 'only one entity can be connected'); const target = await this.readUnique(kysely, model, { where: _data[0], @@ -1567,28 +1545,11 @@ export abstract class BaseOperationHandler { if (!target) { throw new NotFoundError(model); } - const query = kysely - .updateTable(fromRelation.model) - .where((eb) => eb.and(fromRelation.ids)) - .set( - keyPairs.reduce( - (acc, { fk, pk }) => ({ - ...acc, - [fk]: target[pk], - }), - {} as any, - ), - ) - .modifyEnd( - this.makeContextComment({ - model: fromRelation.model, - operation: 'update', - }), - ); - updateResult = await this.executeQuery(kysely, query, 'connect'); - } else { - updateModel = model; + for (const { fk, pk } of keyPairs) { + fromRelation.parentUpdates[fk] = target[pk]; + } + } else { // disconnect current if it's a one-one relation const relationFieldDef = this.requireField(fromRelation.model, fromRelation.field); @@ -1625,13 +1586,13 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - updateResult = await this.executeQuery(kysely, query, 'connect'); - } + const updateResult = await this.executeQuery(kysely, query, 'connect'); - // validate connect result - if (!updateResult.numAffectedRows || _data.length > updateResult.numAffectedRows) { - // some entities were not connected - throw new NotFoundError(updateModel); + // validate connect result + if (!updateResult.numAffectedRows || _data.length > updateResult.numAffectedRows) { + // some entities were not connected + throw new NotFoundError(model); + } } } } @@ -1715,42 +1676,42 @@ export abstract class BaseOperationHandler { const eb = expressionBuilder(); if (ownedByModel) { - // set parent fk directly + // record parent fk update invariant(disconnectConditions.length === 1, 'only one entity can be disconnected'); const condition = disconnectConditions[0]; - const query = kysely - .updateTable(fromRelation.model) - // id filter - .where(eb.and(fromRelation.ids)) - // merge extra disconnect conditions - .$if(condition !== true, (qb) => - qb.where( - eb( - // @ts-ignore - eb.refTuple(...keyPairs.map(({ fk }) => fk)), - 'in', - eb - .selectFrom(model) - .select(keyPairs.map(({ pk }) => pk)) - .where(this.dialect.buildFilter(eb, model, model, condition)), - ), - ), - ) - .set(keyPairs.reduce((acc, { fk }) => ({ ...acc, [fk]: null }), {} as any)) - .modifyEnd( - this.makeContextComment({ - model: fromRelation.model, - operation: 'update', - }), - ); - const result = await this.executeQuery(kysely, query, 'disconnect'); - if (!result.numAffectedRows) { - // determine if the parent entity doesn't exist, or the relation entity to be disconnected doesn't exist - const parentExists = await this.exists(kysely, fromRelation.model, fromRelation.ids); - if (!parentExists) { - throw new NotFoundError(fromRelation.model); - } else { - // silently ignore + + if (condition === true) { + // just disconnect, record parent fk update + for (const { fk } of keyPairs) { + fromRelation.parentUpdates[fk] = null; + } + } else { + // disconnect with a filter + + // read parent's fk + const fromEntity = await this.readUnique(kysely, fromRelation.model, { + where: fromRelation.ids, + select: fieldsToSelectObject(keyPairs.map(({ fk }) => fk)), + }); + if (!fromEntity || keyPairs.some(({ fk }) => fromEntity[fk] == null)) { + return; + } + + // check if the disconnect target exists under parent fk and the filter condition + const relationFilter = { + AND: [condition, Object.fromEntries(keyPairs.map(({ fk, pk }) => [pk, fromEntity[fk]]))], + }; + + // if the target exists, record parent fk update, otherwise do nothing + const targetExists = await this.read(kysely, model, { + where: relationFilter, + take: 1, + select: this.makeIdSelect(model), + } as any); + if (targetExists.length > 0) { + for (const { fk } of keyPairs) { + fromRelation.parentUpdates[fk] = null; + } } } } else { diff --git a/packages/runtime/src/client/options.ts b/packages/runtime/src/client/options.ts index 3146a402..7c90e330 100644 --- a/packages/runtime/src/client/options.ts +++ b/packages/runtime/src/client/options.ts @@ -2,7 +2,7 @@ import type { Dialect, Expression, ExpressionBuilder, KyselyConfig } from 'kysel import type { GetModel, GetModels, ProcedureDef, SchemaDef } from '../schema'; import type { PrependParameter } from '../utils/type-utils'; import type { ClientContract, CRUD, ProcedureFunc } from './contract'; -import type { BaseCrudDialect } from './crud/dialects/base'; +import type { BaseCrudDialect } from './crud/dialects/base-dialect'; import type { RuntimePlugin } from './plugin'; import type { ToKyselySchema } from './query-builder'; diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index 68af823b..9cf81ccc 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -22,7 +22,7 @@ import { import { match } from 'ts-pattern'; import type { CRUD } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; -import type { BaseCrudDialect } from '../../client/crud/dialects/base'; +import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; import type { ClientOptions } from '../../client/options'; import { getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils'; diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index 54e4ff7d..f26c2038 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -30,7 +30,7 @@ import { match } from 'ts-pattern'; import type { ClientContract } from '../../client'; import type { CRUD } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; -import type { BaseCrudDialect } from '../../client/crud/dialects/base'; +import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError } from '../../client/errors'; import type { ProceedKyselyQueryFunction } from '../../client/plugin'; import { getIdFields, requireField, requireModel } from '../../client/query-utils'; @@ -180,7 +180,12 @@ export class PolicyHandler extends OperationNodeTransf // create a `SELECT column1 as field1, column2 as field2, ... FROM (VALUES (...))` table for policy evaluation const constTable: SelectQueryNode = { kind: 'SelectQueryNode', - from: FromNode.create([ParensNode.create(ValuesNode.create([ValueListNode.create(allValues)]))]), + from: FromNode.create([ + AliasNode.create( + ParensNode.create(ValuesNode.create([ValueListNode.create(allValues)])), + IdentifierNode.create('$t'), + ), + ]), selections: allFields.map((field, index) => SelectionNode.create( AliasNode.create(ColumnNode.create(`column${index + 1}`), IdentifierNode.create(field)), diff --git a/packages/runtime/src/plugins/policy/utils.ts b/packages/runtime/src/plugins/policy/utils.ts index a86b4857..1113cb7e 100644 --- a/packages/runtime/src/plugins/policy/utils.ts +++ b/packages/runtime/src/plugins/policy/utils.ts @@ -12,7 +12,7 @@ import { UnaryOperationNode, ValueNode, } from 'kysely'; -import type { BaseCrudDialect } from '../../client/crud/dialects/base'; +import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import type { SchemaDef } from '../../schema'; /** diff --git a/packages/runtime/test/client-api/update.test.ts b/packages/runtime/test/client-api/update.test.ts index 2fc75fb8..a82a87bc 100644 --- a/packages/runtime/test/client-api/update.test.ts +++ b/packages/runtime/test/client-api/update.test.ts @@ -1815,6 +1815,21 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client update tests', ({ createCli user: { connect: { id: '1' } }, }, }); + // not matching filter, no-op + await expect( + client.profile.update({ + where: { id: profile.id }, + data: { + user: { + disconnect: { id: '2' }, + }, + }, + include: { user: true }, + }), + ).resolves.toMatchObject({ + user: { id: '1' }, + }); + // connected, disconnect await expect( client.profile.update({ where: { id: profile.id }, @@ -1828,8 +1843,7 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client update tests', ({ createCli ).resolves.toMatchObject({ user: null, }); - - // non-existing + // not connected, no-op await expect( client.profile.update({ where: { id: profile.id },