Skip to content

Commit 6439fd6

Browse files
authored
fix(cli): generated TS typing for auth() access is too strong (#1589)
1 parent 3140d9b commit 6439fd6

File tree

2 files changed

+45
-42
lines changed

2 files changed

+45
-42
lines changed

packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts

Lines changed: 19 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { getIdFields, hasAttribute, isAuthInvocation, isDataModelFieldReference } from '@zenstackhq/sdk';
1+
import { getIdFields, isAuthInvocation, isDataModelFieldReference } from '@zenstackhq/sdk';
22
import {
33
DataModel,
44
DataModelField,
@@ -18,41 +18,27 @@ export function generateAuthType(model: Model, authModel: DataModel) {
1818
const types = new Map<
1919
string,
2020
{
21-
// scalar fields to directly pick from Prisma-generated type
22-
pickFields: string[];
23-
24-
// relation fields to include
25-
addFields: { name: string; type: string }[];
21+
// relation fields to require
22+
requiredRelations: { name: string; type: string }[];
2623
}
2724
>();
2825

29-
types.set(authModel.name, { pickFields: getIdFields(authModel).map((f) => f.name), addFields: [] });
26+
types.set(authModel.name, { requiredRelations: [] });
3027

3128
const ensureType = (model: string) => {
3229
if (!types.has(model)) {
33-
types.set(model, { pickFields: [], addFields: [] });
34-
}
35-
};
36-
37-
const addPickField = (model: string, field: string) => {
38-
let fields = types.get(model);
39-
if (!fields) {
40-
fields = { pickFields: [], addFields: [] };
41-
types.set(model, fields);
42-
}
43-
if (!fields.pickFields.includes(field)) {
44-
fields.pickFields.push(field);
30+
types.set(model, { requiredRelations: [] });
4531
}
4632
};
4733

4834
const addAddField = (model: string, name: string, type: string, array: boolean) => {
4935
let fields = types.get(model);
5036
if (!fields) {
51-
fields = { pickFields: [], addFields: [] };
37+
fields = { requiredRelations: [] };
5238
types.set(model, fields);
5339
}
54-
if (!fields.addFields.find((f) => f.name === name)) {
55-
fields.addFields.push({ name, type: array ? `${type}[]` : type });
40+
if (!fields.requiredRelations.find((f) => f.name === name)) {
41+
fields.requiredRelations.push({ name, type: array ? `${type}[]` : type });
5642
}
5743
};
5844

@@ -71,11 +57,6 @@ export function generateAuthType(model: Model, authModel: DataModel) {
7157
const fieldType = memberDecl.type.reference.ref.name;
7258
ensureType(fieldType);
7359
addAddField(exprType.name, memberDecl.name, fieldType, memberDecl.type.array);
74-
} else {
75-
// member is a scalar
76-
if (!isIgnoredField(node.member.ref)) {
77-
addPickField(exprType.name, node.member.$refText);
78-
}
7960
}
8061
}
8162
}
@@ -88,11 +69,6 @@ export function generateAuthType(model: Model, authModel: DataModel) {
8869
// field is a relation
8970
ensureType(fieldType.name);
9071
addAddField(fieldDecl.$container.name, node.target.$refText, fieldType.name, fieldDecl.type.array);
91-
} else {
92-
if (!isIgnoredField(fieldDecl)) {
93-
// field is a scalar
94-
addPickField(fieldDecl.$container.name, node.target.$refText);
95-
}
9672
}
9773
}
9874
});
@@ -112,16 +88,21 @@ ${Array.from(types.entries())
11288
.map(([model, fields]) => {
11389
let result = `Partial<_P.${model}>`;
11490
115-
if (fields.pickFields.length > 0) {
116-
result = `WithRequired<${result}, ${fields.pickFields
117-
.map((f) => `'${f}'`)
118-
.join('|')}> & Record<string, unknown>`;
91+
if (model === authModel.name) {
92+
// auth model's id fields are always required
93+
const idFields = getIdFields(authModel).map((f) => f.name);
94+
if (idFields.length > 0) {
95+
result = `WithRequired<${result}, ${idFields.map((f) => `'${f}'`).join('|')}>`;
96+
}
11997
}
12098
121-
if (fields.addFields.length > 0) {
122-
result = `${result} & { ${fields.addFields.map(({ name, type }) => `${name}: ${type}`).join('; ')} }`;
99+
if (fields.requiredRelations.length > 0) {
100+
// merge required relation fields
101+
result = `${result} & { ${fields.requiredRelations.map((f) => `${f.name}: ${f.type}`).join('; ')} }`;
123102
}
124103
104+
result = `${result} & Record<string, unknown>`;
105+
125106
return ` export type ${model} = ${result};`;
126107
})
127108
.join('\n')}
@@ -145,7 +126,3 @@ function isAuthAccess(node: AstNode): node is Expression {
145126

146127
return false;
147128
}
148-
149-
function isIgnoredField(field: DataModelField | undefined) {
150-
return !!(field && hasAttribute(field, '@ignore'));
151-
}

tests/integration/tests/enhancements/with-policy/auth.test.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,4 +809,30 @@ describe('auth() compile-time test', () => {
809809
}
810810
);
811811
});
812+
813+
it('optional field stays optional', async () => {
814+
await loadSchema(
815+
`
816+
model User {
817+
id Int @id
818+
age Int?
819+
820+
@@allow('all', auth().age > 0)
821+
}
822+
`,
823+
{
824+
compile: true,
825+
extraSourceFiles: [
826+
{
827+
name: 'main.ts',
828+
content: `
829+
import { enhance } from ".zenstack/enhance";
830+
import { PrismaClient } from '@prisma/client';
831+
enhance(new PrismaClient(), { user: { id: 1 } });
832+
`,
833+
},
834+
],
835+
}
836+
);
837+
});
812838
});

0 commit comments

Comments
 (0)