Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@
- [x] Error system
- [x] Custom table name
- [x] Custom field name
- [ ] Global omit
- [ ] DbNull vs JsonNull
- [ ] Migrate to tsdown
- [ ] @default validation
- [ ] Benchmark
- [x] Plugin
- [x] Post-mutation hooks should be called after transaction is committed
Expand Down
90 changes: 73 additions & 17 deletions packages/language/src/validators/attribute-application-validator.ts
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
import { AstUtils, type ValidationAcceptor } from 'langium';
import pluralize from 'pluralize';
import type { BinaryExpr, DataModel, Expression } from '../ast';
import {
ArrayExpr,
Attribute,
AttributeArg,
AttributeParam,
DataModelAttribute,
DataField,
DataFieldAttribute,
DataModelAttribute,
InternalAttribute,
ReferenceExpr,
isArrayExpr,
isAttribute,
isDataModel,
isDataField,
isDataModel,
isEnum,
isReferenceExpr,
isTypeDef,
} from '../generated/ast';
import {
getAllAttributes,
getStringLiteral,
hasAttribute,
isAuthOrAuthMemberAccess,
isCollectionPredicate,
isDataFieldReference,
isDelegateModel,
isFutureExpr,
Expand All @@ -31,7 +33,6 @@ import {
typeAssignable,
} from '../utils';
import type { AstValidator } from './common';
import type { DataModel } from '../ast';

// a registry of function handlers marked with @check
const attributeCheckers = new Map<string, PropertyDescriptor>();
Expand Down Expand Up @@ -153,6 +154,7 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
}
}

// TODO: design a way to let plugin register validation
@check('@@allow')
@check('@@deny')
// @ts-expect-error
Expand All @@ -166,10 +168,75 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
}
this.validatePolicyKinds(kind, ['create', 'read', 'update', 'delete', 'all'], attr, accept);

// @encrypted fields cannot be used in policy rules
this.rejectEncryptedFields(attr, accept);
if ((kind === 'create' || kind === 'all') && attr.args[1]?.value) {
// "create" rules cannot access non-owned relations because the entity does not exist yet, so
// there can't possibly be a fk that points to it
this.rejectNonOwnedRelationInExpression(attr.args[1].value, accept);
}
}

private rejectNonOwnedRelationInExpression(expr: Expression, accept: ValidationAcceptor) {
const contextModel = AstUtils.getContainerOfType(expr, isDataModel);
if (!contextModel) {
return;
}

if (
AstUtils.streamAst(expr).some((node) => {
if (!isDataFieldReference(node)) {
// not a field reference, skip
return false;
}

// referenced field is not a member of the context model, skip
if (node.target.ref?.$container !== contextModel) {
return false;
}

const field = node.target.ref as DataField;
if (!isRelationshipField(field)) {
// not a relation, skip
return false;
}

if (isAuthOrAuthMemberAccess(node)) {
// field reference is from auth() or access from auth(), not a relation query
return false;
}

// check if the the node is a reference inside a collection predicate scope by auth access,
// e.g., `auth().foo?[x > 0]`

// make sure to skip the current level if the node is already an LHS of a collection predicate,
// otherwise we're just circling back to itself when visiting the parent
const startNode =
isCollectionPredicate(node.$container) && (node.$container as BinaryExpr).left === node
? node.$container
: node;
const collectionPredicate = AstUtils.getContainerOfType(startNode.$container, isCollectionPredicate);
if (collectionPredicate && isAuthOrAuthMemberAccess(collectionPredicate.left)) {
return false;
}

const relationAttr = field.attributes.find((attr) => attr.decl.ref?.name === '@relation');
if (!relationAttr) {
// no "@relation", not owner side of the relation, match
return true;
}

if (!relationAttr.args.some((arg) => arg.name === 'fields')) {
// no "fields" argument, can't be owner side of the relation, match
return true;
}

return false;
})
) {
accept('error', `non-owned relation fields are not allowed in "create" rules`, { node: expr });
}
}

// TODO: design a way to let plugin register validation
@check('@allow')
@check('@deny')
// @ts-expect-error
Expand Down Expand Up @@ -199,9 +266,6 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
);
}
}

// @encrypted fields cannot be used in policy rules
this.rejectEncryptedFields(attr, accept);
}

@check('@@validate')
Expand Down Expand Up @@ -261,14 +325,6 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
}
}

private rejectEncryptedFields(attr: AttributeApplication, accept: ValidationAcceptor) {
AstUtils.streamAllContents(attr).forEach((node) => {
if (isDataFieldReference(node) && hasAttribute(node.target.ref as DataField, '@encrypted')) {
accept('error', `Encrypted fields cannot be used in policy rules`, { node });
}
});
}

private validatePolicyKinds(
kind: string,
candidates: string[],
Expand Down
21 changes: 18 additions & 3 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
createFields = baseCreateResult.remainingFields;
}

const updatedData = this.fillGeneratedValues(modelDef, createFields);
const updatedData = this.fillGeneratedAndDefaultValues(modelDef, createFields);
const idFields = getIdFields(this.schema, model);
const query = kysely
.insertInto(model)
Expand Down Expand Up @@ -722,7 +722,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
newItem[fk] = fromRelation.ids[pk];
}
}
return this.fillGeneratedValues(modelDef, newItem);
return this.fillGeneratedAndDefaultValues(modelDef, newItem);
});

if (!this.dialect.supportInsertWithDefault) {
Expand Down Expand Up @@ -841,7 +841,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return { baseEntities, remainingFieldRows };
}

private fillGeneratedValues(modelDef: ModelDef, data: object) {
private fillGeneratedAndDefaultValues(modelDef: ModelDef, data: object) {
const fields = modelDef.fields;
const values: any = clone(data);
for (const [field, fieldDef] of Object.entries(fields)) {
Expand All @@ -858,6 +858,21 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
} else if (fields[field]?.updatedAt) {
// TODO: should this work at kysely level instead?
values[field] = this.dialect.transformPrimitive(new Date(), 'DateTime', false);
} else if (fields[field]?.default !== undefined) {
let value = fields[field].default;
if (fieldDef.type === 'Json') {
// Schema uses JSON string for default value of Json fields
if (fieldDef.array && Array.isArray(value)) {
value = value.map((v) => (typeof v === 'string' ? JSON.parse(v) : v));
} else if (typeof value === 'string') {
value = JSON.parse(value);
}
}
values[field] = this.dialect.transformPrimitive(
value,
fields[field].type as BuiltinType,
!!fields[field].array,
);
}
}
}
Expand Down
8 changes: 2 additions & 6 deletions packages/runtime/src/client/crud/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ import {
type UpdateManyArgs,
type UpsertArgs,
} from '../crud-types';
import { InputValidationError, InternalError, QueryError } from '../errors';
import { InputValidationError, InternalError } from '../errors';
import {
fieldHasDefaultValue,
getDiscriminatorField,
getEnum,
getModel,
getUniqueFields,
requireField,
requireModel,
Expand Down Expand Up @@ -279,10 +278,7 @@ export class InputValidator<Schema extends SchemaDef> {
withoutRelationFields = false,
withAggregations = false,
): ZodType {
const modelDef = getModel(this.schema, model);
if (!modelDef) {
throw new QueryError(`Model "${model}" not found in schema`);
}
const modelDef = requireModel(this.schema, model);

const fields: Record<string, any> = {};
for (const field of Object.keys(modelDef.fields)) {
Expand Down
6 changes: 2 additions & 4 deletions packages/runtime/src/client/executor/kysely-utils.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import { invariant } from '@zenstackhq/common-helpers';
import { type OperationNode, AliasNode, IdentifierNode } from 'kysely';
import { type OperationNode, AliasNode } from 'kysely';

/**
* Strips alias from the node if it exists.
*/
export function stripAlias(node: OperationNode) {
if (AliasNode.is(node)) {
invariant(IdentifierNode.is(node.alias), 'Expected identifier as alias');
return { alias: node.alias.name, node: node.node };
return { alias: node.alias, node: node.node };
} else {
return { alias: undefined, node };
}
Expand Down
33 changes: 21 additions & 12 deletions packages/runtime/src/client/executor/name-mapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import { stripAlias } from './kysely-utils';

type Scope = {
model?: string;
alias?: string;
alias?: OperationNode;
namesMapped?: boolean; // true means fields referring to this scope have their names already mapped
};

Expand Down Expand Up @@ -120,7 +120,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
// map table name depending on how it is resolved
let mappedTableName = node.table?.table.identifier.name;
if (mappedTableName) {
if (scope.alias === mappedTableName) {
if (scope.alias && IdentifierNode.is(scope.alias) && scope.alias.name === mappedTableName) {
// table name is resolved to an alias, no mapping needed
} else if (scope.model === mappedTableName) {
// table name is resolved to a model, map the name as needed
Expand Down Expand Up @@ -222,7 +222,14 @@ export class QueryNameMapper extends OperationNodeTransformer {
const origFieldName = this.extractFieldName(selection.selection);
const fieldName = this.extractFieldName(transformed);
if (fieldName !== origFieldName) {
selections.push(SelectionNode.create(this.wrapAlias(transformed, origFieldName)));
selections.push(
SelectionNode.create(
this.wrapAlias(
transformed,
origFieldName ? IdentifierNode.create(origFieldName) : undefined,
),
),
);
} else {
selections.push(SelectionNode.create(transformed));
}
Expand All @@ -241,7 +248,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
// if the field as a qualifier, the qualifier must match the scope's
// alias if any, or model if no alias
if (scope.alias) {
if (scope.alias === qualifier) {
if (scope.alias && IdentifierNode.is(scope.alias) && scope.alias.name === qualifier) {
// scope has an alias that matches the qualifier
return scope;
} else {
Expand Down Expand Up @@ -295,8 +302,8 @@ export class QueryNameMapper extends OperationNodeTransformer {
}
}

private wrapAlias<T extends OperationNode>(node: T, alias: string | undefined) {
return alias ? AliasNode.create(node, IdentifierNode.create(alias)) : node;
private wrapAlias<T extends OperationNode>(node: T, alias: OperationNode | undefined) {
return alias ? AliasNode.create(node, alias) : node;
}

private processTableRef(node: TableNode) {
Expand Down Expand Up @@ -351,11 +358,11 @@ export class QueryNameMapper extends OperationNodeTransformer {
// inner transformations will map column names
const modelName = innerNode.table.identifier.name;
const mappedName = this.mapTableName(modelName);
const finalAlias = alias ?? (mappedName !== modelName ? modelName : undefined);
const finalAlias = alias ?? (mappedName !== modelName ? IdentifierNode.create(modelName) : undefined);
return {
node: this.wrapAlias(TableNode.create(mappedName), finalAlias),
scope: {
alias: alias ?? modelName,
alias: alias ?? IdentifierNode.create(modelName),
model: modelName,
namesMapped: !this.hasMappedColumns(modelName),
},
Expand All @@ -374,13 +381,13 @@ export class QueryNameMapper extends OperationNodeTransformer {
}
}

private createSelectAllFields(model: string, alias: string | undefined) {
private createSelectAllFields(model: string, alias: OperationNode | undefined) {
const modelDef = requireModel(this.schema, model);
return this.getModelFields(modelDef).map((fieldDef) => {
const columnName = this.mapFieldName(model, fieldDef.name);
const columnRef = ReferenceNode.create(
ColumnNode.create(columnName),
alias ? TableNode.create(alias) : undefined,
alias && IdentifierNode.is(alias) ? TableNode.create(alias.name) : undefined,
);
if (columnName !== fieldDef.name) {
const aliased = AliasNode.create(columnRef, IdentifierNode.create(fieldDef.name));
Expand Down Expand Up @@ -421,7 +428,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
alias = this.extractFieldName(node);
}
const result = super.transformNode(node);
return this.wrapAlias(result, alias);
return this.wrapAlias(result, alias ? IdentifierNode.create(alias) : undefined);
}

private processSelectAll(node: SelectAllNode) {
Expand All @@ -438,7 +445,9 @@ export class QueryNameMapper extends OperationNodeTransformer {
return this.getModelFields(modelDef).map((fieldDef) => {
const columnName = this.mapFieldName(modelDef.name, fieldDef.name);
const columnRef = ReferenceNode.create(ColumnNode.create(columnName));
return columnName !== fieldDef.name ? this.wrapAlias(columnRef, fieldDef.name) : columnRef;
return columnName !== fieldDef.name
? this.wrapAlias(columnRef, IdentifierNode.create(fieldDef.name))
: columnRef;
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
const hookResult = await hook!({
client: this.client as ClientContract<Schema>,
schema: this.client.$schema,
kysely: this.kysely,
query,
proceed: _p,
});
Expand Down
3 changes: 1 addition & 2 deletions packages/runtime/src/client/plugin.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { OperationNode, QueryResult, RootOperationNode, UnknownRow } from 'kysely';
import type { ClientContract, ToKysely } from '.';
import type { ClientContract } from '.';
import type { GetModels, SchemaDef } from '../schema';
import type { MaybePromise } from '../utils/type-utils';
import type { AllCrudOperation } from './crud/operations/base';
Expand Down Expand Up @@ -180,7 +180,6 @@ export type PluginAfterEntityMutationArgs<Schema extends SchemaDef> = MutationHo
// #region OnKyselyQuery hooks

export type OnKyselyQueryArgs<Schema extends SchemaDef> = {
kysely: ToKysely<Schema>;
schema: SchemaDef;
client: ClientContract<Schema>;
query: RootOperationNode;
Expand Down
Loading