Skip to content

Commit 1956bdb

Browse files
authored
fix(delegate): delegate model's guards are not properly including concrete models (#1932)
1 parent 2eecae5 commit 1956bdb

File tree

7 files changed

+207
-29
lines changed

7 files changed

+207
-29
lines changed

packages/schema/src/plugins/enhancer/enhance/index.ts

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import {
2424
isArrayExpr,
2525
isDataModel,
2626
isGeneratorDecl,
27-
isReferenceExpr,
2827
isTypeDef,
2928
type Model,
3029
} from '@zenstackhq/sdk/ast';
@@ -45,6 +44,7 @@ import {
4544
} from 'ts-morph';
4645
import { upperCaseFirst } from 'upper-case-first';
4746
import { name } from '..';
47+
import { getConcreteModels, getDiscriminatorField } from '../../../utils/ast-utils';
4848
import { execPackage } from '../../../utils/exec-utils';
4949
import { CorePlugins, getPluginCustomOutputFolder } from '../../plugin-utils';
5050
import { trackPrismaSchemaError } from '../../prisma';
@@ -407,9 +407,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
407407
this.model.declarations
408408
.filter((d): d is DataModel => isDelegateModel(d))
409409
.forEach((dm) => {
410-
const concreteModels = this.model.declarations.filter(
411-
(d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm)
412-
);
410+
const concreteModels = getConcreteModels(dm);
413411
if (concreteModels.length > 0) {
414412
delegateInfo.push([dm, concreteModels]);
415413
}
@@ -579,7 +577,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
579577
const typeName = typeAlias.getName();
580578
const payloadRecord = delegateInfo.find(([delegate]) => `$${delegate.name}Payload` === typeName);
581579
if (payloadRecord) {
582-
const discriminatorDecl = this.getDiscriminatorField(payloadRecord[0]);
580+
const discriminatorDecl = getDiscriminatorField(payloadRecord[0]);
583581
if (discriminatorDecl) {
584582
source = `${payloadRecord[1]
585583
.map(
@@ -826,15 +824,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
826824
.filter((n) => n.getName().startsWith(DELEGATE_AUX_RELATION_PREFIX));
827825
}
828826

829-
private getDiscriminatorField(delegate: DataModel) {
830-
const delegateAttr = getAttribute(delegate, '@@delegate');
831-
if (!delegateAttr) {
832-
return undefined;
833-
}
834-
const arg = delegateAttr.args[0]?.value;
835-
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
836-
}
837-
838827
private saveSourceFile(sf: SourceFile) {
839828
if (this.options.preserveTsFiles) {
840829
saveSourceFile(sf);

packages/schema/src/plugins/enhancer/policy/expression-writer.ts

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -839,16 +839,18 @@ export class ExpressionWriter {
839839
operation = this.options.operationContext;
840840
}
841841

842-
this.block(() => {
843-
if (operation === 'postUpdate') {
844-
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
845-
// e.g.:
846-
// @@allow('all', check(author)) should not delegate "postUpdate" to author
847-
this.writer.write(`${fieldRef.target.$refText}: ${FALSE}`);
848-
} else {
849-
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
850-
this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`);
851-
}
852-
});
842+
this.block(() =>
843+
this.writeFieldCondition(fieldRef, () => {
844+
if (operation === 'postUpdate') {
845+
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
846+
// e.g.:
847+
// @@allow('all', check(author)) should not delegate "postUpdate" to author
848+
this.writer.write(FALSE);
849+
} else {
850+
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
851+
this.writer.write(`${targetGuardFunc}(context, db)`);
852+
}
853+
})
854+
);
853855
}
854856
}

packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ export class PolicyGenerator {
454454
writer: CodeBlockWriter,
455455
sourceFile: SourceFile
456456
) {
457+
// first handle several cases where a constant function can be used
458+
457459
if (kind === 'update' && allows.length === 0) {
458460
// no allow rule for 'update', policy is constant based on if there's
459461
// post-update counterpart

packages/schema/src/plugins/prisma/schema-generator.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ import path from 'path';
5757
import semver from 'semver';
5858
import { name } from '.';
5959
import { getStringLiteral } from '../../language-server/validator/utils';
60+
import { getConcreteModels } from '../../utils/ast-utils';
6061
import { execPackage } from '../../utils/exec-utils';
6162
import { isDefaultWithAuth } from '../enhancer/enhancer-utils';
6263
import {
@@ -320,9 +321,7 @@ export class PrismaSchemaGenerator {
320321
}
321322

322323
// collect concrete models inheriting this model
323-
const concreteModels = decl.$container.declarations.filter(
324-
(d) => isDataModel(d) && d !== decl && d.superTypes.some((base) => base.ref === decl)
325-
);
324+
const concreteModels = getConcreteModels(decl);
326325

327326
// generate an optional relation field in delegate base model to each concrete model
328327
concreteModels.forEach((concrete) => {

packages/schema/src/utils/ast-utils.ts

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@ import {
22
BinaryExpr,
33
DataModel,
44
DataModelAttribute,
5+
DataModelField,
56
Expression,
67
InheritableNode,
78
isBinaryExpr,
89
isDataModel,
910
isDataModelField,
1011
isInvocationExpr,
1112
isModel,
13+
isReferenceExpr,
1214
isTypeDef,
1315
Model,
1416
ModelImport,
1517
TypeDef,
1618
} from '@zenstackhq/language/ast';
17-
import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
19+
import { getAttribute, getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
1820
import {
1921
AstNode,
2022
copyAstNode,
@@ -310,3 +312,27 @@ export function findUpInheritance(start: DataModel, target: DataModel): DataMode
310312
}
311313
return undefined;
312314
}
315+
316+
/**
317+
* Gets all concrete models that inherit from the given delegate model
318+
*/
319+
export function getConcreteModels(dataModel: DataModel): DataModel[] {
320+
if (!isDelegateModel(dataModel)) {
321+
return [];
322+
}
323+
return dataModel.$container.declarations.filter(
324+
(d): d is DataModel => isDataModel(d) && d !== dataModel && d.superTypes.some((base) => base.ref === dataModel)
325+
);
326+
}
327+
328+
/**
329+
* Gets the discriminator field for the given delegate model
330+
*/
331+
export function getDiscriminatorField(dataModel: DataModel) {
332+
const delegateAttr = getAttribute(dataModel, '@@delegate');
333+
if (!delegateAttr) {
334+
return undefined;
335+
}
336+
const arg = delegateAttr.args[0]?.value;
337+
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
338+
}

tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,4 +571,84 @@ describe('Polymorphic Policy Test', () => {
571571
expect(foundPost2.foo).toBeUndefined();
572572
expect(foundPost2.bar).toBeUndefined();
573573
});
574+
575+
it('respects concrete policies when read as base optional relation', async () => {
576+
const { enhance } = await loadSchema(
577+
`
578+
model User {
579+
id Int @id @default(autoincrement())
580+
asset Asset?
581+
@@allow('all', true)
582+
}
583+
584+
model Asset {
585+
id Int @id @default(autoincrement())
586+
user User @relation(fields: [userId], references: [id])
587+
userId Int @unique
588+
type String
589+
590+
@@delegate(type)
591+
@@allow('all', true)
592+
}
593+
594+
model Post extends Asset {
595+
title String
596+
private Boolean
597+
@@allow('create', true)
598+
@@deny('read', private)
599+
}
600+
`
601+
);
602+
603+
const fullDb = enhance(undefined, { kinds: ['delegate'] });
604+
await fullDb.user.create({ data: { id: 1 } });
605+
await fullDb.post.create({ data: { title: 'Post1', private: true, user: { connect: { id: 1 } } } });
606+
await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({
607+
asset: expect.objectContaining({ type: 'Post' }),
608+
});
609+
610+
const db = enhance();
611+
const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } });
612+
expect(read.asset).toBeTruthy();
613+
expect(read.asset.title).toBeUndefined();
614+
});
615+
616+
it('respects concrete policies when read as base required relation', async () => {
617+
const { enhance } = await loadSchema(
618+
`
619+
model User {
620+
id Int @id @default(autoincrement())
621+
asset Asset @relation(fields: [assetId], references: [id])
622+
assetId Int @unique
623+
@@allow('all', true)
624+
}
625+
626+
model Asset {
627+
id Int @id @default(autoincrement())
628+
user User?
629+
type String
630+
631+
@@delegate(type)
632+
@@allow('all', true)
633+
}
634+
635+
model Post extends Asset {
636+
title String
637+
private Boolean
638+
@@deny('read', private)
639+
}
640+
`
641+
);
642+
643+
const fullDb = enhance(undefined, { kinds: ['delegate'] });
644+
await fullDb.post.create({ data: { id: 1, title: 'Post1', private: true, user: { create: { id: 1 } } } });
645+
await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({
646+
asset: expect.objectContaining({ type: 'Post' }),
647+
});
648+
649+
const db = enhance();
650+
const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } });
651+
expect(read).toBeTruthy();
652+
expect(read.asset.title).toBeUndefined();
653+
});
574654
});
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import { loadSchema } from '@zenstackhq/testtools';
2+
3+
describe('issue 1930', () => {
4+
it('regression', async () => {
5+
const { enhance } = await loadSchema(
6+
`
7+
model Organization {
8+
id String @id @default(cuid())
9+
entities Entity[]
10+
11+
@@allow('all', true)
12+
}
13+
14+
model Entity {
15+
id String @id @default(cuid())
16+
org Organization? @relation(fields: [orgId], references: [id])
17+
orgId String?
18+
contents EntityContent[]
19+
entityType String
20+
isDeleted Boolean @default(false)
21+
22+
@@delegate(entityType)
23+
24+
@@allow('all', !isDeleted)
25+
}
26+
27+
model EntityContent {
28+
id String @id @default(cuid())
29+
entity Entity @relation(fields: [entityId], references: [id])
30+
entityId String
31+
32+
entityContentType String
33+
34+
@@delegate(entityContentType)
35+
36+
@@allow('create', true)
37+
@@allow('read', check(entity))
38+
}
39+
40+
model Article extends Entity {
41+
}
42+
43+
model ArticleContent extends EntityContent {
44+
body String?
45+
}
46+
47+
model OtherContent extends EntityContent {
48+
data Int
49+
}
50+
`
51+
);
52+
53+
const fullDb = enhance(undefined, { kinds: ['delegate'] });
54+
const org = await fullDb.organization.create({ data: {} });
55+
const article = await fullDb.article.create({
56+
data: { org: { connect: { id: org.id } } },
57+
});
58+
59+
const db = enhance();
60+
61+
// normal create/read
62+
await expect(
63+
db.articleContent.create({
64+
data: { body: 'abc', entity: { connect: { id: article.id } } },
65+
})
66+
).toResolveTruthy();
67+
await expect(db.article.findFirst({ include: { contents: true } })).resolves.toMatchObject({
68+
contents: expect.arrayContaining([expect.objectContaining({ body: 'abc' })]),
69+
});
70+
71+
// deleted article's contents are not readable
72+
const deletedArticle = await fullDb.article.create({
73+
data: { org: { connect: { id: org.id } }, isDeleted: true },
74+
});
75+
const content1 = await fullDb.articleContent.create({
76+
data: { body: 'bcd', entity: { connect: { id: deletedArticle.id } } },
77+
});
78+
await expect(db.articleContent.findUnique({ where: { id: content1.id } })).toResolveNull();
79+
});
80+
});

0 commit comments

Comments
 (0)