diff --git a/packages/runtime/src/enhancements/node/default-auth.ts b/packages/runtime/src/enhancements/node/default-auth.ts index e6162a2d2..f151d014f 100644 --- a/packages/runtime/src/enhancements/node/default-auth.ts +++ b/packages/runtime/src/enhancements/node/default-auth.ts @@ -5,10 +5,12 @@ import { ACTIONS_WITH_WRITE_PAYLOAD } from '../../constants'; import { FieldInfo, NestedWriteVisitor, + NestedWriteVisitorContext, PrismaWriteActionType, clone, enumerate, getFields, + getModelInfo, getTypeDefInfo, requireField, } from '../../cross'; @@ -61,7 +63,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { const newArgs = clone(args); - const processCreatePayload = (model: string, data: any) => { + const processCreatePayload = (model: string, data: any, context: NestedWriteVisitorContext) => { const fields = getFields(this.options.modelMeta, model); for (const fieldInfo of Object.values(fields)) { if (fieldInfo.isTypeDef) { @@ -82,24 +84,24 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { const defaultValue = this.getDefaultValue(fieldInfo); if (defaultValue !== undefined) { // set field value extracted from `auth()` - this.setDefaultValueForModelData(fieldInfo, model, data, defaultValue); + this.setDefaultValueForModelData(fieldInfo, model, data, defaultValue, context); } } }; // visit create payload and set default value to fields using `auth()` in `@default()` const visitor = new NestedWriteVisitor(this.options.modelMeta, { - create: (model, data) => { - processCreatePayload(model, data); + create: (model, data, context) => { + processCreatePayload(model, data, context); }, - upsert: (model, data) => { - processCreatePayload(model, data.create); + upsert: (model, data, context) => { + processCreatePayload(model, data.create, context); }, - createMany: (model, args) => { + createMany: (model, args, context) => { for (const item of enumerate(args.data)) { - processCreatePayload(model, item); + processCreatePayload(model, item, context); } }, }); @@ -108,42 +110,82 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { return newArgs; } - private setDefaultValueForModelData(fieldInfo: FieldInfo, model: string, data: any, authDefaultValue: unknown) { - if (fieldInfo.isForeignKey && fieldInfo.relationField && fieldInfo.relationField in data) { + private setDefaultValueForModelData( + fieldInfo: FieldInfo, + model: string, + data: any, + authDefaultValue: unknown, + context: NestedWriteVisitorContext + ) { + if (fieldInfo.isForeignKey) { + // if the field being inspected is a fk field, there are several cases we should not + // set the default value or should not set directly + // if the field is a fk, and the relation field is already set, we should not override it - return; - } + if (fieldInfo.relationField && fieldInfo.relationField in data) { + return; + } - if (fieldInfo.isForeignKey && !isUnsafeMutate(model, data, this.options.modelMeta)) { - // if the field is a fk, and the create payload is not unsafe, we need to translate - // the fk field setting to a `connect` of the corresponding relation field - const relFieldName = fieldInfo.relationField; - if (!relFieldName) { - throw new Error( - `Field \`${fieldInfo.name}\` is a foreign key field but no corresponding relation field is found` + if (context.field?.backLink && context.nestingPath.length > 1) { + // if the fk field is in a creation context where its implied by the parent, + // we should not set the default value, e.g.: + // + // ``` + // parent.create({ data: { child: { create: {} } } }) + // ``` + // + // event if child's fk to parent has a default value, we should not set default + // value here + + // fetch parent model from the parent context + const parentModel = getModelInfo( + this.options.modelMeta, + context.nestingPath[context.nestingPath.length - 2].model ); - } - const relationField = requireField(this.options.modelMeta, model, relFieldName); - // construct a `{ connect: { ... } }` payload - let connect = data[relationField.name]?.connect; - if (!connect) { - connect = {}; - data[relationField.name] = { connect }; + if (parentModel) { + // get the opposite side of the relation for the current create context + const oppositeRelationField = requireField(this.options.modelMeta, model, context.field.backLink); + if (parentModel.name === oppositeRelationField.type) { + // if the opposite side matches the parent model, it means we currently in a creation context + // that implicitly sets this fk field + return; + } + } } - // sets the opposite fk field to value `authDefaultValue` - const oppositeFkFieldName = this.getOppositeFkFieldName(relationField, fieldInfo); - if (!oppositeFkFieldName) { - throw new Error( - `Cannot find opposite foreign key field for \`${fieldInfo.name}\` in relation field \`${relFieldName}\`` - ); + if (!isUnsafeMutate(model, data, this.options.modelMeta)) { + // if the field is a fk, and the create payload is not unsafe, we need to translate + // the fk field setting to a `connect` of the corresponding relation field + const relFieldName = fieldInfo.relationField; + if (!relFieldName) { + throw new Error( + `Field \`${fieldInfo.name}\` is a foreign key field but no corresponding relation field is found` + ); + } + const relationField = requireField(this.options.modelMeta, model, relFieldName); + + // construct a `{ connect: { ... } }` payload + let connect = data[relationField.name]?.connect; + if (!connect) { + connect = {}; + data[relationField.name] = { connect }; + } + + // sets the opposite fk field to value `authDefaultValue` + const oppositeFkFieldName = this.getOppositeFkFieldName(relationField, fieldInfo); + if (!oppositeFkFieldName) { + throw new Error( + `Cannot find opposite foreign key field for \`${fieldInfo.name}\` in relation field \`${relFieldName}\`` + ); + } + connect[oppositeFkFieldName] = authDefaultValue; + return; } - connect[oppositeFkFieldName] = authDefaultValue; - } else { - // set default value directly - data[fieldInfo.name] = authDefaultValue; } + + // set default value directly + data[fieldInfo.name] = authDefaultValue; } private getOppositeFkFieldName(relationField: FieldInfo, fieldInfo: FieldInfo) { diff --git a/tests/regression/tests/issue-1997.test.ts b/tests/regression/tests/issue-1997.test.ts new file mode 100644 index 000000000..3153c26c6 --- /dev/null +++ b/tests/regression/tests/issue-1997.test.ts @@ -0,0 +1,131 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1997', () => { + it('regression', async () => { + const { prisma, enhance } = await loadSchema( + ` + model Tenant { + id String @id @default(uuid()) + + users User[] + posts Post[] + comments Comment[] + postUserLikes PostUserLikes[] + } + + model User { + id String @id @default(uuid()) + tenantId String @default(auth().tenantId) + tenant Tenant @relation(fields: [tenantId], references: [id]) + posts Post[] + likes PostUserLikes[] + + @@allow('all', true) + } + + model Post { + tenantId String @default(auth().tenantId) + tenant Tenant @relation(fields: [tenantId], references: [id]) + id String @default(uuid()) + author User @relation(fields: [authorId], references: [id]) + authorId String @default(auth().id) + + comments Comment[] + likes PostUserLikes[] + + @@id([tenantId, id]) + + @@allow('all', true) + } + + model PostUserLikes { + tenantId String @default(auth().tenantId) + tenant Tenant @relation(fields: [tenantId], references: [id]) + id String @default(uuid()) + + userId String + user User @relation(fields: [userId], references: [id]) + + postId String + post Post @relation(fields: [tenantId, postId], references: [tenantId, id]) + + @@id([tenantId, id]) + @@unique([tenantId, userId, postId]) + + @@allow('all', true) + } + + model Comment { + tenantId String @default(auth().tenantId) + tenant Tenant @relation(fields: [tenantId], references: [id]) + id String @default(uuid()) + postId String + post Post @relation(fields: [tenantId, postId], references: [tenantId, id]) + + @@id([tenantId, id]) + + @@allow('all', true) + } + `, + { logPrismaQuery: true } + ); + + const tenant = await prisma.tenant.create({ + data: {}, + }); + const user = await prisma.user.create({ + data: { tenantId: tenant.id }, + }); + + const db = enhance({ id: user.id, tenantId: tenant.id }); + + await expect( + db.post.create({ + data: { + likes: { + createMany: { + data: [ + { + userId: user.id, + }, + ], + }, + }, + }, + include: { + likes: true, + }, + }) + ).resolves.toMatchObject({ + authorId: user.id, + likes: [ + { + tenantId: tenant.id, + userId: user.id, + }, + ], + }); + + await expect( + db.post.create({ + data: { + comments: { + createMany: { + data: [{}], + }, + }, + }, + include: { + comments: true, + }, + }) + ).resolves.toMatchObject({ + authorId: user.id, + comments: [ + { + tenantId: tenant.id, + }, + ], + }); + }); +});