Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
- [x] Error system
- [x] Custom table name
- [x] Custom field name
- [ ] Global omit
- [ ] DbNull vs JsonNull
- [ ] Migrate to tsdown
- [ ] Benchmark
Expand Down
108 changes: 91 additions & 17 deletions packages/language/src/validators/attribute-application-validator.ts
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
import { AstUtils, type ValidationAcceptor } from 'langium';
import pluralize from 'pluralize';
import type { BinaryExpr, DataModel, Expression } from '../ast';
import {
ArrayExpr,
Attribute,
AttributeArg,
AttributeParam,
DataModelAttribute,
DataField,
DataFieldAttribute,
DataModelAttribute,
InternalAttribute,
ReferenceExpr,
isArrayExpr,
isAttribute,
isDataModel,
isDataField,
isDataModel,
isEnum,
isReferenceExpr,
isTypeDef,
} from '../generated/ast';
import {
getAllAttributes,
getStringLiteral,
hasAttribute,
isAuthOrAuthMemberAccess,
isCollectionPredicate,
isDataFieldReference,
isDelegateModel,
isFutureExpr,
Expand All @@ -31,7 +33,6 @@ import {
typeAssignable,
} from '../utils';
import type { AstValidator } from './common';
import type { DataModel } from '../ast';

// a registry of function handlers marked with @check
const attributeCheckers = new Map<string, PropertyDescriptor>();
Expand Down Expand Up @@ -153,6 +154,25 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
}
}

@check('@default')
// @ts-expect-error
private _checkDefault(attr: AttributeApplication, accept: ValidationAcceptor) {
if (attr.$container && isDataField(attr.$container) && attr.$container.type.type === 'Json') {
// Json field default value must be a valid JSON string
const value = getStringLiteral(attr.args[0]?.value);
if (!value) {
accept('error', 'value must be a valid JSON string', { node: attr });
return;
}
try {
JSON.parse(value);
} catch {
accept('error', 'value is not a valid JSON string', { node: attr });
}
}
}

// TODO: design a way to let plugin register validation
@check('@@allow')
@check('@@deny')
// @ts-expect-error
Expand All @@ -166,10 +186,75 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
}
this.validatePolicyKinds(kind, ['create', 'read', 'update', 'delete', 'all'], attr, accept);

// @encrypted fields cannot be used in policy rules
this.rejectEncryptedFields(attr, accept);
if ((kind === 'create' || kind === 'all') && attr.args[1]?.value) {
// "create" rules cannot access non-owned relations because the entity does not exist yet, so
// there can't possibly be a fk that points to it
this.rejectNonOwnedRelationInExpression(attr.args[1].value, accept);
}
}

private rejectNonOwnedRelationInExpression(expr: Expression, accept: ValidationAcceptor) {
const contextModel = AstUtils.getContainerOfType(expr, isDataModel);
if (!contextModel) {
return;
}

if (
AstUtils.streamAst(expr).some((node) => {
if (!isDataFieldReference(node)) {
// not a field reference, skip
return false;
}

// referenced field is not a member of the context model, skip
if (node.target.ref?.$container !== contextModel) {
return false;
}

const field = node.target.ref as DataField;
if (!isRelationshipField(field)) {
// not a relation, skip
return false;
}

if (isAuthOrAuthMemberAccess(node)) {
// field reference is from auth() or access from auth(), not a relation query
return false;
}

// check if the the node is a reference inside a collection predicate scope by auth access,
// e.g., `auth().foo?[x > 0]`

// make sure to skip the current level if the node is already an LHS of a collection predicate,
// otherwise we're just circling back to itself when visiting the parent
const startNode =
isCollectionPredicate(node.$container) && (node.$container as BinaryExpr).left === node
? node.$container
: node;
const collectionPredicate = AstUtils.getContainerOfType(startNode.$container, isCollectionPredicate);
if (collectionPredicate && isAuthOrAuthMemberAccess(collectionPredicate.left)) {
return false;
}

const relationAttr = field.attributes.find((attr) => attr.decl.ref?.name === '@relation');
if (!relationAttr) {
// no "@relation", not owner side of the relation, match
return true;
}

if (!relationAttr.args.some((arg) => arg.name === 'fields')) {
// no "fields" argument, can't be owner side of the relation, match
return true;
}

return false;
})
) {
accept('error', `non-owned relation fields are not allowed in "create" rules`, { node: expr });
}
}

// TODO: design a way to let plugin register validation
@check('@allow')
@check('@deny')
// @ts-expect-error
Expand Down Expand Up @@ -199,9 +284,6 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
);
}
}

// @encrypted fields cannot be used in policy rules
this.rejectEncryptedFields(attr, accept);
}

@check('@@validate')
Expand Down Expand Up @@ -261,14 +343,6 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
}
}

private rejectEncryptedFields(attr: AttributeApplication, accept: ValidationAcceptor) {
AstUtils.streamAllContents(attr).forEach((node) => {
if (isDataFieldReference(node) && hasAttribute(node.target.ref as DataField, '@encrypted')) {
accept('error', `Encrypted fields cannot be used in policy rules`, { node });
}
});
}

private validatePolicyKinds(
kind: string,
candidates: string[],
Expand Down
17 changes: 14 additions & 3 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
createFields = baseCreateResult.remainingFields;
}

const updatedData = this.fillGeneratedValues(modelDef, createFields);
const updatedData = this.fillGeneratedAndDefaultValues(modelDef, createFields);
const idFields = getIdFields(this.schema, model);
const query = kysely
.insertInto(model)
Expand Down Expand Up @@ -722,7 +722,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
newItem[fk] = fromRelation.ids[pk];
}
}
return this.fillGeneratedValues(modelDef, newItem);
return this.fillGeneratedAndDefaultValues(modelDef, newItem);
});

if (!this.dialect.supportInsertWithDefault) {
Expand Down Expand Up @@ -841,7 +841,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return { baseEntities, remainingFieldRows };
}

private fillGeneratedValues(modelDef: ModelDef, data: object) {
private fillGeneratedAndDefaultValues(modelDef: ModelDef, data: object) {
const fields = modelDef.fields;
const values: any = clone(data);
for (const [field, fieldDef] of Object.entries(fields)) {
Expand All @@ -858,6 +858,17 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
} else if (fields[field]?.updatedAt) {
// TODO: should this work at kysely level instead?
values[field] = this.dialect.transformPrimitive(new Date(), 'DateTime', false);
} else if (fields[field]?.default !== undefined) {
let value = fields[field].default;
if (fieldDef.type === 'Json' && typeof value === 'string') {
// Schema uses JSON string for default value of Json fields
value = JSON.parse(value);
}
values[field] = this.dialect.transformPrimitive(
value,
fields[field].type as BuiltinType,
!!fields[field].array,
);
}
}
}
Expand Down
8 changes: 2 additions & 6 deletions packages/runtime/src/client/crud/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ import {
type UpdateManyArgs,
type UpsertArgs,
} from '../crud-types';
import { InputValidationError, InternalError, QueryError } from '../errors';
import { InputValidationError, InternalError } from '../errors';
import {
fieldHasDefaultValue,
getDiscriminatorField,
getEnum,
getModel,
getUniqueFields,
requireField,
requireModel,
Expand Down Expand Up @@ -279,10 +278,7 @@ export class InputValidator<Schema extends SchemaDef> {
withoutRelationFields = false,
withAggregations = false,
): ZodType {
const modelDef = getModel(this.schema, model);
if (!modelDef) {
throw new QueryError(`Model "${model}" not found in schema`);
}
const modelDef = requireModel(this.schema, model);

const fields: Record<string, any> = {};
for (const field of Object.keys(modelDef.fields)) {
Expand Down
6 changes: 2 additions & 4 deletions packages/runtime/src/client/executor/kysely-utils.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import { invariant } from '@zenstackhq/common-helpers';
import { type OperationNode, AliasNode, IdentifierNode } from 'kysely';
import { type OperationNode, AliasNode } from 'kysely';

/**
* Strips alias from the node if it exists.
*/
export function stripAlias(node: OperationNode) {
if (AliasNode.is(node)) {
invariant(IdentifierNode.is(node.alias), 'Expected identifier as alias');
return { alias: node.alias.name, node: node.node };
return { alias: node.alias, node: node.node };
} else {
return { alias: undefined, node };
}
Expand Down
33 changes: 21 additions & 12 deletions packages/runtime/src/client/executor/name-mapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import { stripAlias } from './kysely-utils';

type Scope = {
model?: string;
alias?: string;
alias?: OperationNode;
namesMapped?: boolean; // true means fields referring to this scope have their names already mapped
};

Expand Down Expand Up @@ -120,7 +120,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
// map table name depending on how it is resolved
let mappedTableName = node.table?.table.identifier.name;
if (mappedTableName) {
if (scope.alias === mappedTableName) {
if (scope.alias && IdentifierNode.is(scope.alias) && scope.alias.name === mappedTableName) {
// table name is resolved to an alias, no mapping needed
} else if (scope.model === mappedTableName) {
// table name is resolved to a model, map the name as needed
Expand Down Expand Up @@ -222,7 +222,14 @@ export class QueryNameMapper extends OperationNodeTransformer {
const origFieldName = this.extractFieldName(selection.selection);
const fieldName = this.extractFieldName(transformed);
if (fieldName !== origFieldName) {
selections.push(SelectionNode.create(this.wrapAlias(transformed, origFieldName)));
selections.push(
SelectionNode.create(
this.wrapAlias(
transformed,
origFieldName ? IdentifierNode.create(origFieldName) : undefined,
),
),
);
} else {
selections.push(SelectionNode.create(transformed));
}
Expand All @@ -241,7 +248,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
// if the field as a qualifier, the qualifier must match the scope's
// alias if any, or model if no alias
if (scope.alias) {
if (scope.alias === qualifier) {
if (scope.alias && IdentifierNode.is(scope.alias) && scope.alias.name === qualifier) {
// scope has an alias that matches the qualifier
return scope;
} else {
Expand Down Expand Up @@ -295,8 +302,8 @@ export class QueryNameMapper extends OperationNodeTransformer {
}
}

private wrapAlias<T extends OperationNode>(node: T, alias: string | undefined) {
return alias ? AliasNode.create(node, IdentifierNode.create(alias)) : node;
private wrapAlias<T extends OperationNode>(node: T, alias: OperationNode | undefined) {
return alias ? AliasNode.create(node, alias) : node;
}

private processTableRef(node: TableNode) {
Expand Down Expand Up @@ -351,11 +358,11 @@ export class QueryNameMapper extends OperationNodeTransformer {
// inner transformations will map column names
const modelName = innerNode.table.identifier.name;
const mappedName = this.mapTableName(modelName);
const finalAlias = alias ?? (mappedName !== modelName ? modelName : undefined);
const finalAlias = alias ?? (mappedName !== modelName ? IdentifierNode.create(modelName) : undefined);
return {
node: this.wrapAlias(TableNode.create(mappedName), finalAlias),
scope: {
alias: alias ?? modelName,
alias: alias ?? IdentifierNode.create(modelName),
model: modelName,
namesMapped: !this.hasMappedColumns(modelName),
},
Expand All @@ -374,13 +381,13 @@ export class QueryNameMapper extends OperationNodeTransformer {
}
}

private createSelectAllFields(model: string, alias: string | undefined) {
private createSelectAllFields(model: string, alias: OperationNode | undefined) {
const modelDef = requireModel(this.schema, model);
return this.getModelFields(modelDef).map((fieldDef) => {
const columnName = this.mapFieldName(model, fieldDef.name);
const columnRef = ReferenceNode.create(
ColumnNode.create(columnName),
alias ? TableNode.create(alias) : undefined,
alias && IdentifierNode.is(alias) ? TableNode.create(alias.name) : undefined,
);
if (columnName !== fieldDef.name) {
const aliased = AliasNode.create(columnRef, IdentifierNode.create(fieldDef.name));
Expand Down Expand Up @@ -421,7 +428,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
alias = this.extractFieldName(node);
}
const result = super.transformNode(node);
return this.wrapAlias(result, alias);
return this.wrapAlias(result, alias ? IdentifierNode.create(alias) : undefined);
}

private processSelectAll(node: SelectAllNode) {
Expand All @@ -438,7 +445,9 @@ export class QueryNameMapper extends OperationNodeTransformer {
return this.getModelFields(modelDef).map((fieldDef) => {
const columnName = this.mapFieldName(modelDef.name, fieldDef.name);
const columnRef = ReferenceNode.create(ColumnNode.create(columnName));
return columnName !== fieldDef.name ? this.wrapAlias(columnRef, fieldDef.name) : columnRef;
return columnName !== fieldDef.name
? this.wrapAlias(columnRef, IdentifierNode.create(fieldDef.name))
: columnRef;
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
const hookResult = await hook!({
client: this.client as ClientContract<Schema>,
schema: this.client.$schema,
kysely: this.kysely,
query,
proceed: _p,
});
Expand Down
3 changes: 1 addition & 2 deletions packages/runtime/src/client/plugin.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { OperationNode, QueryResult, RootOperationNode, UnknownRow } from 'kysely';
import type { ClientContract, ToKysely } from '.';
import type { ClientContract } from '.';
import type { GetModels, SchemaDef } from '../schema';
import type { MaybePromise } from '../utils/type-utils';
import type { AllCrudOperation } from './crud/operations/base';
Expand Down Expand Up @@ -180,7 +180,6 @@ export type PluginAfterEntityMutationArgs<Schema extends SchemaDef> = MutationHo
// #region OnKyselyQuery hooks

export type OnKyselyQueryArgs<Schema extends SchemaDef> = {
kysely: ToKysely<Schema>;
schema: SchemaDef;
client: ClientContract<Schema>;
query: RootOperationNode;
Expand Down
Loading
Loading