Skip to content

Commit 67d35d5

Browse files
authored
feat: "create" access policy implementation (#242)
* feat: "create" access policy implementation * fix test * update
1 parent 0f9764f commit 67d35d5

File tree

19 files changed

+615
-182
lines changed

19 files changed

+615
-182
lines changed

TODO.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,10 @@
8383
- [x] Error system
8484
- [x] Custom table name
8585
- [x] Custom field name
86+
- [ ] Global omit
8687
- [ ] DbNull vs JsonNull
8788
- [ ] Migrate to tsdown
89+
- [ ] @default validation
8890
- [ ] Benchmark
8991
- [x] Plugin
9092
- [x] Post-mutation hooks should be called after transaction is committed

packages/language/src/validators/attribute-application-validator.ts

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
import { AstUtils, type ValidationAcceptor } from 'langium';
22
import pluralize from 'pluralize';
3+
import type { BinaryExpr, DataModel, Expression } from '../ast';
34
import {
45
ArrayExpr,
56
Attribute,
67
AttributeArg,
78
AttributeParam,
8-
DataModelAttribute,
99
DataField,
1010
DataFieldAttribute,
11+
DataModelAttribute,
1112
InternalAttribute,
1213
ReferenceExpr,
1314
isArrayExpr,
1415
isAttribute,
15-
isDataModel,
1616
isDataField,
17+
isDataModel,
1718
isEnum,
1819
isReferenceExpr,
1920
isTypeDef,
2021
} from '../generated/ast';
2122
import {
2223
getAllAttributes,
2324
getStringLiteral,
24-
hasAttribute,
25+
isAuthOrAuthMemberAccess,
26+
isCollectionPredicate,
2527
isDataFieldReference,
2628
isDelegateModel,
2729
isFutureExpr,
@@ -31,7 +33,6 @@ import {
3133
typeAssignable,
3234
} from '../utils';
3335
import type { AstValidator } from './common';
34-
import type { DataModel } from '../ast';
3536

3637
// a registry of function handlers marked with @check
3738
const attributeCheckers = new Map<string, PropertyDescriptor>();
@@ -153,6 +154,7 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
153154
}
154155
}
155156

157+
// TODO: design a way to let plugin register validation
156158
@check('@@allow')
157159
@check('@@deny')
158160
// @ts-expect-error
@@ -166,10 +168,75 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
166168
}
167169
this.validatePolicyKinds(kind, ['create', 'read', 'update', 'delete', 'all'], attr, accept);
168170

169-
// @encrypted fields cannot be used in policy rules
170-
this.rejectEncryptedFields(attr, accept);
171+
if ((kind === 'create' || kind === 'all') && attr.args[1]?.value) {
172+
// "create" rules cannot access non-owned relations because the entity does not exist yet, so
173+
// there can't possibly be a fk that points to it
174+
this.rejectNonOwnedRelationInExpression(attr.args[1].value, accept);
175+
}
171176
}
172177

178+
private rejectNonOwnedRelationInExpression(expr: Expression, accept: ValidationAcceptor) {
179+
const contextModel = AstUtils.getContainerOfType(expr, isDataModel);
180+
if (!contextModel) {
181+
return;
182+
}
183+
184+
if (
185+
AstUtils.streamAst(expr).some((node) => {
186+
if (!isDataFieldReference(node)) {
187+
// not a field reference, skip
188+
return false;
189+
}
190+
191+
// referenced field is not a member of the context model, skip
192+
if (node.target.ref?.$container !== contextModel) {
193+
return false;
194+
}
195+
196+
const field = node.target.ref as DataField;
197+
if (!isRelationshipField(field)) {
198+
// not a relation, skip
199+
return false;
200+
}
201+
202+
if (isAuthOrAuthMemberAccess(node)) {
203+
// field reference is from auth() or access from auth(), not a relation query
204+
return false;
205+
}
206+
207+
// check if the the node is a reference inside a collection predicate scope by auth access,
208+
// e.g., `auth().foo?[x > 0]`
209+
210+
// make sure to skip the current level if the node is already an LHS of a collection predicate,
211+
// otherwise we're just circling back to itself when visiting the parent
212+
const startNode =
213+
isCollectionPredicate(node.$container) && (node.$container as BinaryExpr).left === node
214+
? node.$container
215+
: node;
216+
const collectionPredicate = AstUtils.getContainerOfType(startNode.$container, isCollectionPredicate);
217+
if (collectionPredicate && isAuthOrAuthMemberAccess(collectionPredicate.left)) {
218+
return false;
219+
}
220+
221+
const relationAttr = field.attributes.find((attr) => attr.decl.ref?.name === '@relation');
222+
if (!relationAttr) {
223+
// no "@relation", not owner side of the relation, match
224+
return true;
225+
}
226+
227+
if (!relationAttr.args.some((arg) => arg.name === 'fields')) {
228+
// no "fields" argument, can't be owner side of the relation, match
229+
return true;
230+
}
231+
232+
return false;
233+
})
234+
) {
235+
accept('error', `non-owned relation fields are not allowed in "create" rules`, { node: expr });
236+
}
237+
}
238+
239+
// TODO: design a way to let plugin register validation
173240
@check('@allow')
174241
@check('@deny')
175242
// @ts-expect-error
@@ -199,9 +266,6 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
199266
);
200267
}
201268
}
202-
203-
// @encrypted fields cannot be used in policy rules
204-
this.rejectEncryptedFields(attr, accept);
205269
}
206270

207271
@check('@@validate')
@@ -261,14 +325,6 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
261325
}
262326
}
263327

264-
private rejectEncryptedFields(attr: AttributeApplication, accept: ValidationAcceptor) {
265-
AstUtils.streamAllContents(attr).forEach((node) => {
266-
if (isDataFieldReference(node) && hasAttribute(node.target.ref as DataField, '@encrypted')) {
267-
accept('error', `Encrypted fields cannot be used in policy rules`, { node });
268-
}
269-
});
270-
}
271-
272328
private validatePolicyKinds(
273329
kind: string,
274330
candidates: string[],

packages/runtime/src/client/crud/operations/base.ts

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
353353
createFields = baseCreateResult.remainingFields;
354354
}
355355

356-
const updatedData = this.fillGeneratedValues(modelDef, createFields);
356+
const updatedData = this.fillGeneratedAndDefaultValues(modelDef, createFields);
357357
const idFields = getIdFields(this.schema, model);
358358
const query = kysely
359359
.insertInto(model)
@@ -722,7 +722,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
722722
newItem[fk] = fromRelation.ids[pk];
723723
}
724724
}
725-
return this.fillGeneratedValues(modelDef, newItem);
725+
return this.fillGeneratedAndDefaultValues(modelDef, newItem);
726726
});
727727

728728
if (!this.dialect.supportInsertWithDefault) {
@@ -841,7 +841,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
841841
return { baseEntities, remainingFieldRows };
842842
}
843843

844-
private fillGeneratedValues(modelDef: ModelDef, data: object) {
844+
private fillGeneratedAndDefaultValues(modelDef: ModelDef, data: object) {
845845
const fields = modelDef.fields;
846846
const values: any = clone(data);
847847
for (const [field, fieldDef] of Object.entries(fields)) {
@@ -858,6 +858,21 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
858858
} else if (fields[field]?.updatedAt) {
859859
// TODO: should this work at kysely level instead?
860860
values[field] = this.dialect.transformPrimitive(new Date(), 'DateTime', false);
861+
} else if (fields[field]?.default !== undefined) {
862+
let value = fields[field].default;
863+
if (fieldDef.type === 'Json') {
864+
// Schema uses JSON string for default value of Json fields
865+
if (fieldDef.array && Array.isArray(value)) {
866+
value = value.map((v) => (typeof v === 'string' ? JSON.parse(v) : v));
867+
} else if (typeof value === 'string') {
868+
value = JSON.parse(value);
869+
}
870+
}
871+
values[field] = this.dialect.transformPrimitive(
872+
value,
873+
fields[field].type as BuiltinType,
874+
!!fields[field].array,
875+
);
861876
}
862877
}
863878
}

packages/runtime/src/client/crud/validator.ts

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,11 @@ import {
2222
type UpdateManyArgs,
2323
type UpsertArgs,
2424
} from '../crud-types';
25-
import { InputValidationError, InternalError, QueryError } from '../errors';
25+
import { InputValidationError, InternalError } from '../errors';
2626
import {
2727
fieldHasDefaultValue,
2828
getDiscriminatorField,
2929
getEnum,
30-
getModel,
3130
getUniqueFields,
3231
requireField,
3332
requireModel,
@@ -279,10 +278,7 @@ export class InputValidator<Schema extends SchemaDef> {
279278
withoutRelationFields = false,
280279
withAggregations = false,
281280
): ZodType {
282-
const modelDef = getModel(this.schema, model);
283-
if (!modelDef) {
284-
throw new QueryError(`Model "${model}" not found in schema`);
285-
}
281+
const modelDef = requireModel(this.schema, model);
286282

287283
const fields: Record<string, any> = {};
288284
for (const field of Object.keys(modelDef.fields)) {

packages/runtime/src/client/executor/kysely-utils.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
import { invariant } from '@zenstackhq/common-helpers';
2-
import { type OperationNode, AliasNode, IdentifierNode } from 'kysely';
1+
import { type OperationNode, AliasNode } from 'kysely';
32

43
/**
54
* Strips alias from the node if it exists.
65
*/
76
export function stripAlias(node: OperationNode) {
87
if (AliasNode.is(node)) {
9-
invariant(IdentifierNode.is(node.alias), 'Expected identifier as alias');
10-
return { alias: node.alias.name, node: node.node };
8+
return { alias: node.alias, node: node.node };
119
} else {
1210
return { alias: undefined, node };
1311
}

packages/runtime/src/client/executor/name-mapper.ts

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import { stripAlias } from './kysely-utils';
2222

2323
type Scope = {
2424
model?: string;
25-
alias?: string;
25+
alias?: OperationNode;
2626
namesMapped?: boolean; // true means fields referring to this scope have their names already mapped
2727
};
2828

@@ -120,7 +120,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
120120
// map table name depending on how it is resolved
121121
let mappedTableName = node.table?.table.identifier.name;
122122
if (mappedTableName) {
123-
if (scope.alias === mappedTableName) {
123+
if (scope.alias && IdentifierNode.is(scope.alias) && scope.alias.name === mappedTableName) {
124124
// table name is resolved to an alias, no mapping needed
125125
} else if (scope.model === mappedTableName) {
126126
// table name is resolved to a model, map the name as needed
@@ -222,7 +222,14 @@ export class QueryNameMapper extends OperationNodeTransformer {
222222
const origFieldName = this.extractFieldName(selection.selection);
223223
const fieldName = this.extractFieldName(transformed);
224224
if (fieldName !== origFieldName) {
225-
selections.push(SelectionNode.create(this.wrapAlias(transformed, origFieldName)));
225+
selections.push(
226+
SelectionNode.create(
227+
this.wrapAlias(
228+
transformed,
229+
origFieldName ? IdentifierNode.create(origFieldName) : undefined,
230+
),
231+
),
232+
);
226233
} else {
227234
selections.push(SelectionNode.create(transformed));
228235
}
@@ -241,7 +248,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
241248
// if the field as a qualifier, the qualifier must match the scope's
242249
// alias if any, or model if no alias
243250
if (scope.alias) {
244-
if (scope.alias === qualifier) {
251+
if (scope.alias && IdentifierNode.is(scope.alias) && scope.alias.name === qualifier) {
245252
// scope has an alias that matches the qualifier
246253
return scope;
247254
} else {
@@ -295,8 +302,8 @@ export class QueryNameMapper extends OperationNodeTransformer {
295302
}
296303
}
297304

298-
private wrapAlias<T extends OperationNode>(node: T, alias: string | undefined) {
299-
return alias ? AliasNode.create(node, IdentifierNode.create(alias)) : node;
305+
private wrapAlias<T extends OperationNode>(node: T, alias: OperationNode | undefined) {
306+
return alias ? AliasNode.create(node, alias) : node;
300307
}
301308

302309
private processTableRef(node: TableNode) {
@@ -351,11 +358,11 @@ export class QueryNameMapper extends OperationNodeTransformer {
351358
// inner transformations will map column names
352359
const modelName = innerNode.table.identifier.name;
353360
const mappedName = this.mapTableName(modelName);
354-
const finalAlias = alias ?? (mappedName !== modelName ? modelName : undefined);
361+
const finalAlias = alias ?? (mappedName !== modelName ? IdentifierNode.create(modelName) : undefined);
355362
return {
356363
node: this.wrapAlias(TableNode.create(mappedName), finalAlias),
357364
scope: {
358-
alias: alias ?? modelName,
365+
alias: alias ?? IdentifierNode.create(modelName),
359366
model: modelName,
360367
namesMapped: !this.hasMappedColumns(modelName),
361368
},
@@ -374,13 +381,13 @@ export class QueryNameMapper extends OperationNodeTransformer {
374381
}
375382
}
376383

377-
private createSelectAllFields(model: string, alias: string | undefined) {
384+
private createSelectAllFields(model: string, alias: OperationNode | undefined) {
378385
const modelDef = requireModel(this.schema, model);
379386
return this.getModelFields(modelDef).map((fieldDef) => {
380387
const columnName = this.mapFieldName(model, fieldDef.name);
381388
const columnRef = ReferenceNode.create(
382389
ColumnNode.create(columnName),
383-
alias ? TableNode.create(alias) : undefined,
390+
alias && IdentifierNode.is(alias) ? TableNode.create(alias.name) : undefined,
384391
);
385392
if (columnName !== fieldDef.name) {
386393
const aliased = AliasNode.create(columnRef, IdentifierNode.create(fieldDef.name));
@@ -421,7 +428,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
421428
alias = this.extractFieldName(node);
422429
}
423430
const result = super.transformNode(node);
424-
return this.wrapAlias(result, alias);
431+
return this.wrapAlias(result, alias ? IdentifierNode.create(alias) : undefined);
425432
}
426433

427434
private processSelectAll(node: SelectAllNode) {
@@ -438,7 +445,9 @@ export class QueryNameMapper extends OperationNodeTransformer {
438445
return this.getModelFields(modelDef).map((fieldDef) => {
439446
const columnName = this.mapFieldName(modelDef.name, fieldDef.name);
440447
const columnRef = ReferenceNode.create(ColumnNode.create(columnName));
441-
return columnName !== fieldDef.name ? this.wrapAlias(columnRef, fieldDef.name) : columnRef;
448+
return columnName !== fieldDef.name
449+
? this.wrapAlias(columnRef, IdentifierNode.create(fieldDef.name))
450+
: columnRef;
442451
});
443452
}
444453

packages/runtime/src/client/executor/zenstack-query-executor.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
100100
const hookResult = await hook!({
101101
client: this.client as ClientContract<Schema>,
102102
schema: this.client.$schema,
103-
kysely: this.kysely,
104103
query,
105104
proceed: _p,
106105
});

packages/runtime/src/client/plugin.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import type { OperationNode, QueryResult, RootOperationNode, UnknownRow } from 'kysely';
2-
import type { ClientContract, ToKysely } from '.';
2+
import type { ClientContract } from '.';
33
import type { GetModels, SchemaDef } from '../schema';
44
import type { MaybePromise } from '../utils/type-utils';
55
import type { AllCrudOperation } from './crud/operations/base';
@@ -180,7 +180,6 @@ export type PluginAfterEntityMutationArgs<Schema extends SchemaDef> = MutationHo
180180
// #region OnKyselyQuery hooks
181181

182182
export type OnKyselyQueryArgs<Schema extends SchemaDef> = {
183-
kysely: ToKysely<Schema>;
184183
schema: SchemaDef;
185184
client: ClientContract<Schema>;
186185
query: RootOperationNode;

0 commit comments

Comments
 (0)