Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 78 additions & 36 deletions packages/runtime/src/enhancements/node/default-auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ import { ACTIONS_WITH_WRITE_PAYLOAD } from '../../constants';
import {
FieldInfo,
NestedWriteVisitor,
NestedWriteVisitorContext,
PrismaWriteActionType,
clone,
enumerate,
getFields,
getModelInfo,
getTypeDefInfo,
requireField,
} from '../../cross';
Expand Down Expand Up @@ -61,7 +63,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler {
private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) {
const newArgs = clone(args);

const processCreatePayload = (model: string, data: any) => {
const processCreatePayload = (model: string, data: any, context: NestedWriteVisitorContext) => {
const fields = getFields(this.options.modelMeta, model);
for (const fieldInfo of Object.values(fields)) {
if (fieldInfo.isTypeDef) {
Expand All @@ -82,24 +84,24 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler {
const defaultValue = this.getDefaultValue(fieldInfo);
if (defaultValue !== undefined) {
// set field value extracted from `auth()`
this.setDefaultValueForModelData(fieldInfo, model, data, defaultValue);
this.setDefaultValueForModelData(fieldInfo, model, data, defaultValue, context);
}
}
};

// visit create payload and set default value to fields using `auth()` in `@default()`
const visitor = new NestedWriteVisitor(this.options.modelMeta, {
create: (model, data) => {
processCreatePayload(model, data);
create: (model, data, context) => {
processCreatePayload(model, data, context);
},

upsert: (model, data) => {
processCreatePayload(model, data.create);
upsert: (model, data, context) => {
processCreatePayload(model, data.create, context);
},

createMany: (model, args) => {
createMany: (model, args, context) => {
for (const item of enumerate(args.data)) {
processCreatePayload(model, item);
processCreatePayload(model, item, context);
}
},
});
Expand All @@ -108,42 +110,82 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler {
return newArgs;
}

private setDefaultValueForModelData(fieldInfo: FieldInfo, model: string, data: any, authDefaultValue: unknown) {
if (fieldInfo.isForeignKey && fieldInfo.relationField && fieldInfo.relationField in data) {
private setDefaultValueForModelData(
fieldInfo: FieldInfo,
model: string,
data: any,
authDefaultValue: unknown,
context: NestedWriteVisitorContext
) {
if (fieldInfo.isForeignKey) {
// if the field being inspected is a fk field, there are several cases we should not
// set the default value or should not set directly

// if the field is a fk, and the relation field is already set, we should not override it
return;
}
if (fieldInfo.relationField && fieldInfo.relationField in data) {
return;
}

if (fieldInfo.isForeignKey && !isUnsafeMutate(model, data, this.options.modelMeta)) {
// if the field is a fk, and the create payload is not unsafe, we need to translate
// the fk field setting to a `connect` of the corresponding relation field
const relFieldName = fieldInfo.relationField;
if (!relFieldName) {
throw new Error(
`Field \`${fieldInfo.name}\` is a foreign key field but no corresponding relation field is found`
if (context.field?.backLink && context.nestingPath.length > 1) {
// if the fk field is in a creation context where its implied by the parent,
// we should not set the default value, e.g.:
//
// ```
// parent.create({ data: { child: { create: {} } } })
// ```
//
// event if child's fk to parent has a default value, we should not set default
// value here

// fetch parent model from the parent context
const parentModel = getModelInfo(
this.options.modelMeta,
context.nestingPath[context.nestingPath.length - 2].model
);
}
const relationField = requireField(this.options.modelMeta, model, relFieldName);

// construct a `{ connect: { ... } }` payload
let connect = data[relationField.name]?.connect;
if (!connect) {
connect = {};
data[relationField.name] = { connect };
if (parentModel) {
// get the opposite side of the relation for the current create context
const oppositeRelationField = requireField(this.options.modelMeta, model, context.field.backLink);
if (parentModel.name === oppositeRelationField.type) {
// if the opposite side matches the parent model, it means we currently in a creation context
// that implicitly sets this fk field
return;
}
}
}

// sets the opposite fk field to value `authDefaultValue`
const oppositeFkFieldName = this.getOppositeFkFieldName(relationField, fieldInfo);
if (!oppositeFkFieldName) {
throw new Error(
`Cannot find opposite foreign key field for \`${fieldInfo.name}\` in relation field \`${relFieldName}\``
);
if (!isUnsafeMutate(model, data, this.options.modelMeta)) {
// if the field is a fk, and the create payload is not unsafe, we need to translate
// the fk field setting to a `connect` of the corresponding relation field
const relFieldName = fieldInfo.relationField;
if (!relFieldName) {
throw new Error(
`Field \`${fieldInfo.name}\` is a foreign key field but no corresponding relation field is found`
);
}
const relationField = requireField(this.options.modelMeta, model, relFieldName);

// construct a `{ connect: { ... } }` payload
let connect = data[relationField.name]?.connect;
if (!connect) {
connect = {};
data[relationField.name] = { connect };
}

// sets the opposite fk field to value `authDefaultValue`
const oppositeFkFieldName = this.getOppositeFkFieldName(relationField, fieldInfo);
if (!oppositeFkFieldName) {
throw new Error(
`Cannot find opposite foreign key field for \`${fieldInfo.name}\` in relation field \`${relFieldName}\``
);
}
connect[oppositeFkFieldName] = authDefaultValue;
return;
}
connect[oppositeFkFieldName] = authDefaultValue;
} else {
// set default value directly
data[fieldInfo.name] = authDefaultValue;
}

// set default value directly
data[fieldInfo.name] = authDefaultValue;
}

private getOppositeFkFieldName(relationField: FieldInfo, fieldInfo: FieldInfo) {
Expand Down
131 changes: 131 additions & 0 deletions tests/regression/tests/issue-1997.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import { loadSchema } from '@zenstackhq/testtools';

describe('issue 1997', () => {
it('regression', async () => {
const { prisma, enhance } = await loadSchema(
`
model Tenant {
id String @id @default(uuid())

users User[]
posts Post[]
comments Comment[]
postUserLikes PostUserLikes[]
}

model User {
id String @id @default(uuid())
tenantId String @default(auth().tenantId)
tenant Tenant @relation(fields: [tenantId], references: [id])
posts Post[]
likes PostUserLikes[]

@@allow('all', true)
}

model Post {
tenantId String @default(auth().tenantId)
tenant Tenant @relation(fields: [tenantId], references: [id])
id String @default(uuid())
author User @relation(fields: [authorId], references: [id])
authorId String @default(auth().id)

comments Comment[]
likes PostUserLikes[]

@@id([tenantId, id])

@@allow('all', true)
}

model PostUserLikes {
tenantId String @default(auth().tenantId)
tenant Tenant @relation(fields: [tenantId], references: [id])
id String @default(uuid())

userId String
user User @relation(fields: [userId], references: [id])

postId String
post Post @relation(fields: [tenantId, postId], references: [tenantId, id])

@@id([tenantId, id])
@@unique([tenantId, userId, postId])

@@allow('all', true)
}

model Comment {
tenantId String @default(auth().tenantId)
tenant Tenant @relation(fields: [tenantId], references: [id])
id String @default(uuid())
postId String
post Post @relation(fields: [tenantId, postId], references: [tenantId, id])

@@id([tenantId, id])

@@allow('all', true)
}
`,
{ logPrismaQuery: true }
);

const tenant = await prisma.tenant.create({
data: {},
});
const user = await prisma.user.create({
data: { tenantId: tenant.id },
});

const db = enhance({ id: user.id, tenantId: tenant.id });

await expect(
db.post.create({
data: {
likes: {
createMany: {
data: [
{
userId: user.id,
},
],
},
},
},
include: {
likes: true,
},
})
).resolves.toMatchObject({
authorId: user.id,
likes: [
{
tenantId: tenant.id,
userId: user.id,
},
],
});

await expect(
db.post.create({
data: {
comments: {
createMany: {
data: [{}],
},
},
},
include: {
comments: true,
},
})
).resolves.toMatchObject({
authorId: user.id,
comments: [
{
tenantId: tenant.id,
},
],
});
});
});
Loading