diff --git a/TODO.md b/TODO.md index 8ccc6729..cd66cb46 100644 --- a/TODO.md +++ b/TODO.md @@ -83,8 +83,10 @@ - [x] Error system - [x] Custom table name - [x] Custom field name + - [ ] Global omit - [ ] DbNull vs JsonNull - [ ] Migrate to tsdown + - [ ] @default validation - [ ] Benchmark - [x] Plugin - [x] Post-mutation hooks should be called after transaction is committed diff --git a/packages/language/src/validators/attribute-application-validator.ts b/packages/language/src/validators/attribute-application-validator.ts index c04666df..dc376036 100644 --- a/packages/language/src/validators/attribute-application-validator.ts +++ b/packages/language/src/validators/attribute-application-validator.ts @@ -1,19 +1,20 @@ import { AstUtils, type ValidationAcceptor } from 'langium'; import pluralize from 'pluralize'; +import type { BinaryExpr, DataModel, Expression } from '../ast'; import { ArrayExpr, Attribute, AttributeArg, AttributeParam, - DataModelAttribute, DataField, DataFieldAttribute, + DataModelAttribute, InternalAttribute, ReferenceExpr, isArrayExpr, isAttribute, - isDataModel, isDataField, + isDataModel, isEnum, isReferenceExpr, isTypeDef, @@ -21,7 +22,8 @@ import { import { getAllAttributes, getStringLiteral, - hasAttribute, + isAuthOrAuthMemberAccess, + isCollectionPredicate, isDataFieldReference, isDelegateModel, isFutureExpr, @@ -31,7 +33,6 @@ import { typeAssignable, } from '../utils'; import type { AstValidator } from './common'; -import type { DataModel } from '../ast'; // a registry of function handlers marked with @check const attributeCheckers = new Map(); @@ -153,6 +154,7 @@ export default class AttributeApplicationValidator implements AstValidator { + if (!isDataFieldReference(node)) { + // not a field reference, skip + return false; + } + + // referenced field is not a member of the context model, skip + if (node.target.ref?.$container !== contextModel) { + return false; + } + + const field = node.target.ref as DataField; + if (!isRelationshipField(field)) { + // not a relation, skip + return false; + } + + if (isAuthOrAuthMemberAccess(node)) { + // field reference is from auth() or access from auth(), not a relation query + return false; + } + + // check if the the node is a reference inside a collection predicate scope by auth access, + // e.g., `auth().foo?[x > 0]` + + // make sure to skip the current level if the node is already an LHS of a collection predicate, + // otherwise we're just circling back to itself when visiting the parent + const startNode = + isCollectionPredicate(node.$container) && (node.$container as BinaryExpr).left === node + ? node.$container + : node; + const collectionPredicate = AstUtils.getContainerOfType(startNode.$container, isCollectionPredicate); + if (collectionPredicate && isAuthOrAuthMemberAccess(collectionPredicate.left)) { + return false; + } + + const relationAttr = field.attributes.find((attr) => attr.decl.ref?.name === '@relation'); + if (!relationAttr) { + // no "@relation", not owner side of the relation, match + return true; + } + + if (!relationAttr.args.some((arg) => arg.name === 'fields')) { + // no "fields" argument, can't be owner side of the relation, match + return true; + } + + return false; + }) + ) { + accept('error', `non-owned relation fields are not allowed in "create" rules`, { node: expr }); + } + } + + // TODO: design a way to let plugin register validation @check('@allow') @check('@deny') // @ts-expect-error @@ -199,9 +266,6 @@ export default class AttributeApplicationValidator implements AstValidator { - if (isDataFieldReference(node) && hasAttribute(node.target.ref as DataField, '@encrypted')) { - accept('error', `Encrypted fields cannot be used in policy rules`, { node }); - } - }); - } - private validatePolicyKinds( kind: string, candidates: string[], diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index c3bab79d..781e6468 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -353,7 +353,7 @@ export abstract class BaseOperationHandler { createFields = baseCreateResult.remainingFields; } - const updatedData = this.fillGeneratedValues(modelDef, createFields); + const updatedData = this.fillGeneratedAndDefaultValues(modelDef, createFields); const idFields = getIdFields(this.schema, model); const query = kysely .insertInto(model) @@ -722,7 +722,7 @@ export abstract class BaseOperationHandler { newItem[fk] = fromRelation.ids[pk]; } } - return this.fillGeneratedValues(modelDef, newItem); + return this.fillGeneratedAndDefaultValues(modelDef, newItem); }); if (!this.dialect.supportInsertWithDefault) { @@ -841,7 +841,7 @@ export abstract class BaseOperationHandler { return { baseEntities, remainingFieldRows }; } - private fillGeneratedValues(modelDef: ModelDef, data: object) { + private fillGeneratedAndDefaultValues(modelDef: ModelDef, data: object) { const fields = modelDef.fields; const values: any = clone(data); for (const [field, fieldDef] of Object.entries(fields)) { @@ -858,6 +858,21 @@ export abstract class BaseOperationHandler { } else if (fields[field]?.updatedAt) { // TODO: should this work at kysely level instead? values[field] = this.dialect.transformPrimitive(new Date(), 'DateTime', false); + } else if (fields[field]?.default !== undefined) { + let value = fields[field].default; + if (fieldDef.type === 'Json') { + // Schema uses JSON string for default value of Json fields + if (fieldDef.array && Array.isArray(value)) { + value = value.map((v) => (typeof v === 'string' ? JSON.parse(v) : v)); + } else if (typeof value === 'string') { + value = JSON.parse(value); + } + } + values[field] = this.dialect.transformPrimitive( + value, + fields[field].type as BuiltinType, + !!fields[field].array, + ); } } } diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index b4097dea..372129ff 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -22,12 +22,11 @@ import { type UpdateManyArgs, type UpsertArgs, } from '../crud-types'; -import { InputValidationError, InternalError, QueryError } from '../errors'; +import { InputValidationError, InternalError } from '../errors'; import { fieldHasDefaultValue, getDiscriminatorField, getEnum, - getModel, getUniqueFields, requireField, requireModel, @@ -279,10 +278,7 @@ export class InputValidator { withoutRelationFields = false, withAggregations = false, ): ZodType { - const modelDef = getModel(this.schema, model); - if (!modelDef) { - throw new QueryError(`Model "${model}" not found in schema`); - } + const modelDef = requireModel(this.schema, model); const fields: Record = {}; for (const field of Object.keys(modelDef.fields)) { diff --git a/packages/runtime/src/client/executor/kysely-utils.ts b/packages/runtime/src/client/executor/kysely-utils.ts index 5ae92d39..fb9ec845 100644 --- a/packages/runtime/src/client/executor/kysely-utils.ts +++ b/packages/runtime/src/client/executor/kysely-utils.ts @@ -1,13 +1,11 @@ -import { invariant } from '@zenstackhq/common-helpers'; -import { type OperationNode, AliasNode, IdentifierNode } from 'kysely'; +import { type OperationNode, AliasNode } from 'kysely'; /** * Strips alias from the node if it exists. */ export function stripAlias(node: OperationNode) { if (AliasNode.is(node)) { - invariant(IdentifierNode.is(node.alias), 'Expected identifier as alias'); - return { alias: node.alias.name, node: node.node }; + return { alias: node.alias, node: node.node }; } else { return { alias: undefined, node }; } diff --git a/packages/runtime/src/client/executor/name-mapper.ts b/packages/runtime/src/client/executor/name-mapper.ts index cc8163c1..c839bc75 100644 --- a/packages/runtime/src/client/executor/name-mapper.ts +++ b/packages/runtime/src/client/executor/name-mapper.ts @@ -22,7 +22,7 @@ import { stripAlias } from './kysely-utils'; type Scope = { model?: string; - alias?: string; + alias?: OperationNode; namesMapped?: boolean; // true means fields referring to this scope have their names already mapped }; @@ -120,7 +120,7 @@ export class QueryNameMapper extends OperationNodeTransformer { // map table name depending on how it is resolved let mappedTableName = node.table?.table.identifier.name; if (mappedTableName) { - if (scope.alias === mappedTableName) { + if (scope.alias && IdentifierNode.is(scope.alias) && scope.alias.name === mappedTableName) { // table name is resolved to an alias, no mapping needed } else if (scope.model === mappedTableName) { // table name is resolved to a model, map the name as needed @@ -222,7 +222,14 @@ export class QueryNameMapper extends OperationNodeTransformer { const origFieldName = this.extractFieldName(selection.selection); const fieldName = this.extractFieldName(transformed); if (fieldName !== origFieldName) { - selections.push(SelectionNode.create(this.wrapAlias(transformed, origFieldName))); + selections.push( + SelectionNode.create( + this.wrapAlias( + transformed, + origFieldName ? IdentifierNode.create(origFieldName) : undefined, + ), + ), + ); } else { selections.push(SelectionNode.create(transformed)); } @@ -241,7 +248,7 @@ export class QueryNameMapper extends OperationNodeTransformer { // if the field as a qualifier, the qualifier must match the scope's // alias if any, or model if no alias if (scope.alias) { - if (scope.alias === qualifier) { + if (scope.alias && IdentifierNode.is(scope.alias) && scope.alias.name === qualifier) { // scope has an alias that matches the qualifier return scope; } else { @@ -295,8 +302,8 @@ export class QueryNameMapper extends OperationNodeTransformer { } } - private wrapAlias(node: T, alias: string | undefined) { - return alias ? AliasNode.create(node, IdentifierNode.create(alias)) : node; + private wrapAlias(node: T, alias: OperationNode | undefined) { + return alias ? AliasNode.create(node, alias) : node; } private processTableRef(node: TableNode) { @@ -351,11 +358,11 @@ export class QueryNameMapper extends OperationNodeTransformer { // inner transformations will map column names const modelName = innerNode.table.identifier.name; const mappedName = this.mapTableName(modelName); - const finalAlias = alias ?? (mappedName !== modelName ? modelName : undefined); + const finalAlias = alias ?? (mappedName !== modelName ? IdentifierNode.create(modelName) : undefined); return { node: this.wrapAlias(TableNode.create(mappedName), finalAlias), scope: { - alias: alias ?? modelName, + alias: alias ?? IdentifierNode.create(modelName), model: modelName, namesMapped: !this.hasMappedColumns(modelName), }, @@ -374,13 +381,13 @@ export class QueryNameMapper extends OperationNodeTransformer { } } - private createSelectAllFields(model: string, alias: string | undefined) { + private createSelectAllFields(model: string, alias: OperationNode | undefined) { const modelDef = requireModel(this.schema, model); return this.getModelFields(modelDef).map((fieldDef) => { const columnName = this.mapFieldName(model, fieldDef.name); const columnRef = ReferenceNode.create( ColumnNode.create(columnName), - alias ? TableNode.create(alias) : undefined, + alias && IdentifierNode.is(alias) ? TableNode.create(alias.name) : undefined, ); if (columnName !== fieldDef.name) { const aliased = AliasNode.create(columnRef, IdentifierNode.create(fieldDef.name)); @@ -421,7 +428,7 @@ export class QueryNameMapper extends OperationNodeTransformer { alias = this.extractFieldName(node); } const result = super.transformNode(node); - return this.wrapAlias(result, alias); + return this.wrapAlias(result, alias ? IdentifierNode.create(alias) : undefined); } private processSelectAll(node: SelectAllNode) { @@ -438,7 +445,9 @@ export class QueryNameMapper extends OperationNodeTransformer { return this.getModelFields(modelDef).map((fieldDef) => { const columnName = this.mapFieldName(modelDef.name, fieldDef.name); const columnRef = ReferenceNode.create(ColumnNode.create(columnName)); - return columnName !== fieldDef.name ? this.wrapAlias(columnRef, fieldDef.name) : columnRef; + return columnName !== fieldDef.name + ? this.wrapAlias(columnRef, IdentifierNode.create(fieldDef.name)) + : columnRef; }); } diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index 768f65ae..be317924 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -100,7 +100,6 @@ export class ZenStackQueryExecutor extends DefaultQuer const hookResult = await hook!({ client: this.client as ClientContract, schema: this.client.$schema, - kysely: this.kysely, query, proceed: _p, }); diff --git a/packages/runtime/src/client/plugin.ts b/packages/runtime/src/client/plugin.ts index 0a4c4a7f..62216a3d 100644 --- a/packages/runtime/src/client/plugin.ts +++ b/packages/runtime/src/client/plugin.ts @@ -1,5 +1,5 @@ import type { OperationNode, QueryResult, RootOperationNode, UnknownRow } from 'kysely'; -import type { ClientContract, ToKysely } from '.'; +import type { ClientContract } from '.'; import type { GetModels, SchemaDef } from '../schema'; import type { MaybePromise } from '../utils/type-utils'; import type { AllCrudOperation } from './crud/operations/base'; @@ -180,7 +180,6 @@ export type PluginAfterEntityMutationArgs = MutationHo // #region OnKyselyQuery hooks export type OnKyselyQueryArgs = { - kysely: ToKysely; schema: SchemaDef; client: ClientContract; query: RootOperationNode; diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 6f961029..fdce2aaf 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -14,15 +14,19 @@ export function hasModel(schema: SchemaDef, model: string) { } export function getModel(schema: SchemaDef, model: string) { - return schema.models[model]; + return Object.values(schema.models).find((m) => m.name.toLowerCase() === model.toLowerCase()); +} + +export function getTypeDef(schema: SchemaDef, type: string) { + return schema.typeDefs?.[type]; } export function requireModel(schema: SchemaDef, model: string) { - const matchedName = Object.keys(schema.models).find((k) => k.toLowerCase() === model.toLowerCase()); - if (!matchedName) { + const modelDef = getModel(schema, model); + if (!modelDef) { throw new QueryError(`Model "${model}" not found in schema`); } - return schema.models[matchedName]!; + return modelDef; } export function getField(schema: SchemaDef, model: string, field: string) { @@ -30,12 +34,24 @@ export function getField(schema: SchemaDef, model: string, field: string) { return modelDef?.fields[field]; } -export function requireField(schema: SchemaDef, model: string, field: string) { - const modelDef = requireModel(schema, model); - if (!modelDef.fields[field]) { - throw new QueryError(`Field "${field}" not found in model "${model}"`); +export function requireField(schema: SchemaDef, modelOrType: string, field: string) { + const modelDef = getModel(schema, modelOrType); + if (modelDef) { + if (!modelDef.fields[field]) { + throw new QueryError(`Field "${field}" not found in model "${modelOrType}"`); + } else { + return modelDef.fields[field]; + } + } + const typeDef = getTypeDef(schema, modelOrType); + if (typeDef) { + if (!typeDef.fields[field]) { + throw new QueryError(`Field "${field}" not found in type "${modelOrType}"`); + } else { + return typeDef.fields[field]; + } } - return modelDef.fields[field]; + throw new QueryError(`Model or type "${modelOrType}" not found in schema`); } export function getIdFields(schema: SchemaDef, model: GetModels) { diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index bbc98881..c43a6fb7 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -25,7 +25,7 @@ import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base'; import { InternalError, QueryError } from '../../client/errors'; import type { ClientOptions } from '../../client/options'; -import { getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils'; +import { getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils'; import type { BinaryExpression, BinaryOperator, @@ -51,8 +51,6 @@ export type ExpressionTransformerContext = { model: GetModels; alias?: string; operation: CRUD; - thisEntity?: Record; - thisEntityRaw?: Record; auth?: any; memberFilter?: OperationNode; memberSelect?: SelectionNode; @@ -86,7 +84,7 @@ export class ExpressionTransformer { if (!this.schema.authType) { throw new InternalError('Schema does not have an "authType" specified'); } - return this.schema.authType; + return this.schema.authType!; } transform(expression: Expression, context: ExpressionTransformerContext): OperationNode { @@ -117,11 +115,7 @@ export class ExpressionTransformer { private _field(expr: FieldExpression, context: ExpressionTransformerContext) { const fieldDef = requireField(this.schema, context.model, expr.field); if (!fieldDef.relation) { - if (context.thisEntity) { - return context.thisEntity[expr.field]; - } else { - return this.createColumnRef(expr.field, context); - } + return this.createColumnRef(expr.field, context); } else { const { memberFilter, memberSelect, ...restContext } = context; const relation = this.transformRelationAccess(expr.field, fieldDef.type, restContext); @@ -159,7 +153,7 @@ export class ExpressionTransformer { } if (this.isAuthCall(expr.left) || this.isAuthCall(expr.right)) { - return this.transformAuthBinary(expr); + return this.transformAuthBinary(expr, context); } const op = expr.op; @@ -234,7 +228,6 @@ export class ExpressionTransformer { ...context, model: newContextModel as GetModels, alias: undefined, - thisEntity: undefined, }); if (expr.op === '!') { @@ -256,21 +249,50 @@ export class ExpressionTransformer { }); } - private transformAuthBinary(expr: BinaryExpression) { + private transformAuthBinary(expr: BinaryExpression, context: ExpressionTransformerContext) { if (expr.op !== '==' && expr.op !== '!=') { - throw new Error(`Unsupported operator for auth call: ${expr.op}`); + throw new QueryError( + `Unsupported operator for \`auth()\` in policy of model "${context.model}": ${expr.op}`, + ); } + + let authExpr: Expression; let other: Expression; if (this.isAuthCall(expr.left)) { + authExpr = expr.left; other = expr.right; } else { + authExpr = expr.right; other = expr.left; } if (ExpressionUtils.isNull(other)) { return this.transformValue(expr.op === '==' ? !this.auth : !!this.auth, 'Boolean'); } else { - throw new Error('Unsupported binary expression with `auth()`'); + const authModel = getModel(this.schema, this.authType); + if (!authModel) { + throw new QueryError( + `Unsupported use of \`auth()\` in policy of model "${context.model}", comparing with \`auth()\` is only possible when auth type is a model`, + ); + } + + const idFields = Object.values(authModel.fields) + .filter((f) => f.id) + .map((f) => f.name); + invariant(idFields.length > 0, 'auth type model must have at least one id field'); + + const conditions = idFields.map((fieldName) => + ExpressionUtils.binary( + ExpressionUtils.member(authExpr, [fieldName]), + '==', + ExpressionUtils.member(other, [fieldName]), + ), + ); + let result = this.buildAnd(conditions); + if (expr.op === '!=') { + result = this.buildLogicalNot(result); + } + return this.transform(result, context); } } @@ -331,7 +353,7 @@ export class ExpressionTransformer { } if (ExpressionUtils.isField(arg)) { - return context.thisEntityRaw ? eb.val(context.thisEntityRaw[arg.field]) : eb.ref(arg.field); + return eb.ref(arg.field); } if (ExpressionUtils.isCall(arg)) { @@ -358,20 +380,46 @@ export class ExpressionTransformer { return this.valueMemberAccess(this.auth, expr, this.authType); } - invariant(ExpressionUtils.isField(expr.receiver), 'expect receiver to be field expression'); + invariant( + ExpressionUtils.isField(expr.receiver) || ExpressionUtils.isThis(expr.receiver), + 'expect receiver to be field expression or "this"', + ); + let members = expr.members; + let receiver: OperationNode; const { memberFilter, memberSelect, ...restContext } = context; - const receiver = this.transform(expr.receiver, restContext); + if (ExpressionUtils.isThis(expr.receiver)) { + if (expr.members.length === 1) { + // optimize for the simple this.scalar case + const fieldDef = requireField(this.schema, context.model, expr.members[0]!); + invariant(!fieldDef.relation, 'this.relation access should have been transformed into relation access'); + return this.createColumnRef(expr.members[0]!, restContext); + } + + // transform the first segment into a relation access, then continue with the rest of the members + const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!); + receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext); + members = expr.members.slice(1); + } else { + receiver = this.transform(expr.receiver, restContext); + } + invariant(SelectQueryNode.is(receiver), 'expected receiver to be select query'); - // relation member access - const receiverField = requireField(this.schema, context.model, expr.receiver.field); + let startType: string; + if (ExpressionUtils.isField(expr.receiver)) { + const receiverField = requireField(this.schema, context.model, expr.receiver.field); + startType = receiverField.type; + } else { + // "this." case, start type is the model of the context + startType = context.model; + } // traverse forward to collect member types const memberFields: { fromModel: string; fieldDef: FieldDef }[] = []; - let currType = receiverField.type; - for (const member of expr.members) { + let currType = startType; + for (const member of members) { const fieldDef = requireField(this.schema, currType, member); memberFields.push({ fieldDef, fromModel: currType }); currType = fieldDef.type; @@ -379,8 +427,8 @@ export class ExpressionTransformer { let currNode: SelectQueryNode | ColumnNode | ReferenceNode | undefined = undefined; - for (let i = expr.members.length - 1; i >= 0; i--) { - const member = expr.members[i]!; + for (let i = members.length - 1; i >= 0; i--) { + const member = members[i]!; const { fieldDef, fromModel } = memberFields[i]!; if (fieldDef.relation) { @@ -388,7 +436,6 @@ export class ExpressionTransformer { ...restContext, model: fromModel as GetModels, alias: undefined, - thisEntity: undefined, }); if (currNode) { @@ -396,9 +443,7 @@ export class ExpressionTransformer { currNode = { ...relation, selections: [ - SelectionNode.create( - AliasNode.create(currNode, IdentifierNode.create(expr.members[i + 1]!)), - ), + SelectionNode.create(AliasNode.create(currNode, IdentifierNode.create(members[i + 1]!))), ], }; } else { @@ -410,7 +455,7 @@ export class ExpressionTransformer { }; } } else { - invariant(i === expr.members.length - 1, 'plain field access must be the last segment'); + invariant(i === members.length - 1, 'plain field access must be the last segment'); invariant(!currNode, 'plain field access must be the last segment'); currNode = ColumnNode.create(member); @@ -446,71 +491,38 @@ export class ExpressionTransformer { const fromModel = context.model; const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs(this.schema, fromModel, field); - if (context.thisEntity) { - let condition: OperationNode; - if (ownedByModel) { - condition = conjunction( - this.dialect, - keyPairs.map(({ fk, pk }) => - BinaryOperationNode.create( - ReferenceNode.create(ColumnNode.create(pk), TableNode.create(relationModel)), - OperatorNode.create('='), - context.thisEntity![fk]!, - ), - ), - ); - } else { - condition = conjunction( - this.dialect, - keyPairs.map(({ fk, pk }) => - BinaryOperationNode.create( - ReferenceNode.create(ColumnNode.create(fk), TableNode.create(relationModel)), - OperatorNode.create('='), - context.thisEntity![pk]!, - ), + let condition: OperationNode; + if (ownedByModel) { + // `fromModel` owns the fk + condition = conjunction( + this.dialect, + keyPairs.map(({ fk, pk }) => + BinaryOperationNode.create( + ReferenceNode.create(ColumnNode.create(fk), TableNode.create(context.alias ?? fromModel)), + OperatorNode.create('='), + ReferenceNode.create(ColumnNode.create(pk), TableNode.create(relationModel)), ), - ); - } - - return { - kind: 'SelectQueryNode', - from: FromNode.create([TableNode.create(relationModel)]), - where: WhereNode.create(condition), - }; + ), + ); } else { - let condition: OperationNode; - if (ownedByModel) { - // `fromModel` owns the fk - condition = conjunction( - this.dialect, - keyPairs.map(({ fk, pk }) => - BinaryOperationNode.create( - ReferenceNode.create(ColumnNode.create(fk), TableNode.create(context.alias ?? fromModel)), - OperatorNode.create('='), - ReferenceNode.create(ColumnNode.create(pk), TableNode.create(relationModel)), - ), - ), - ); - } else { - // `relationModel` owns the fk - condition = conjunction( - this.dialect, - keyPairs.map(({ fk, pk }) => - BinaryOperationNode.create( - ReferenceNode.create(ColumnNode.create(pk), TableNode.create(context.alias ?? fromModel)), - OperatorNode.create('='), - ReferenceNode.create(ColumnNode.create(fk), TableNode.create(relationModel)), - ), + // `relationModel` owns the fk + condition = conjunction( + this.dialect, + keyPairs.map(({ fk, pk }) => + BinaryOperationNode.create( + ReferenceNode.create(ColumnNode.create(pk), TableNode.create(context.alias ?? fromModel)), + OperatorNode.create('='), + ReferenceNode.create(ColumnNode.create(fk), TableNode.create(relationModel)), ), - ); - } - - return { - kind: 'SelectQueryNode', - from: FromNode.create([TableNode.create(relationModel)]), - where: WhereNode.create(condition), - }; + ), + ); } + + return { + kind: 'SelectQueryNode', + from: FromNode.create([TableNode.create(relationModel)]), + where: WhereNode.create(condition), + }; } private createColumnRef(column: string, context: ExpressionTransformerContext): ReferenceNode { @@ -528,4 +540,18 @@ export class ExpressionTransformer { private isNullNode(node: OperationNode) { return ValueNode.is(node) && node.value === null; } + + private buildLogicalNot(result: Expression): Expression { + return ExpressionUtils.unary('!', result); + } + + private buildAnd(conditions: BinaryExpression[]): Expression { + if (conditions.length === 0) { + return ExpressionUtils.literal(true); + } else if (conditions.length === 1) { + return conditions[0]!; + } else { + return conditions.reduce((acc, condition) => ExpressionUtils.binary(acc, '&&', condition)); + } + } } diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index 7cb672c2..54e4ff7d 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -5,10 +5,12 @@ import { ColumnNode, DeleteQueryNode, FromNode, + FunctionNode, IdentifierNode, InsertQueryNode, OperationNodeTransformer, OperatorNode, + ParensNode, PrimitiveValueListNode, RawNode, ReturningNode, @@ -16,6 +18,7 @@ import { SelectQueryNode, TableNode, UpdateQueryNode, + ValueListNode, ValueNode, ValuesNode, WhereNode, @@ -103,7 +106,7 @@ export class PolicyHandler extends OperationNodeTransf } // TODO: run in transaction - //let readBackError = false; + // let readBackError = false; // transform and post-process in a transaction // const result = await transaction(async (txProceed) => { @@ -142,19 +145,14 @@ export class PolicyHandler extends OperationNodeTransf } private async enforcePreCreatePolicy(node: InsertQueryNode, proceed: ProceedKyselyQueryFunction) { - if (!node.columns || !node.values) { - return; - } - const model = this.getMutationModel(node); - const fields = node.columns.map((c) => c.column.name); - const valueRows = this.unwrapCreateValueRows(node.values, model, fields); + const fields = node.columns?.map((c) => c.column.name) ?? []; + const valueRows = node.values ? this.unwrapCreateValueRows(node.values, model, fields) : [[]]; for (const values of valueRows) { await this.enforcePreCreatePolicyForOne( model, fields, values.map((v) => v.node), - values.map((v) => v.raw), proceed, ); } @@ -164,23 +162,54 @@ export class PolicyHandler extends OperationNodeTransf model: GetModels, fields: string[], values: OperationNode[], - valuesRaw: unknown[], proceed: ProceedKyselyQueryFunction, ) { - const thisEntity: Record = {}; - const thisEntityRaw: Record = {}; - for (let i = 0; i < fields.length; i++) { - thisEntity[fields[i]!] = values[i]!; - thisEntityRaw[fields[i]!] = valuesRaw[i]!; + const allFields = Object.keys(requireModel(this.client.$schema, model).fields); + const allValues: OperationNode[] = []; + + for (const fieldName of allFields) { + const index = fields.indexOf(fieldName); + if (index >= 0) { + allValues.push(values[index]!); + } else { + // set non-provided fields to null + allValues.push(ValueNode.createImmediate(null)); + } } - const filter = this.buildPolicyFilter(model, undefined, 'create', thisEntity, thisEntityRaw); + // create a `SELECT column1 as field1, column2 as field2, ... FROM (VALUES (...))` table for policy evaluation + const constTable: SelectQueryNode = { + kind: 'SelectQueryNode', + from: FromNode.create([ParensNode.create(ValuesNode.create([ValueListNode.create(allValues)]))]), + selections: allFields.map((field, index) => + SelectionNode.create( + AliasNode.create(ColumnNode.create(`column${index + 1}`), IdentifierNode.create(field)), + ), + ), + }; + + const filter = this.buildPolicyFilter(model, undefined, 'create'); + const preCreateCheck: SelectQueryNode = { kind: 'SelectQueryNode', - selections: [SelectionNode.create(AliasNode.create(filter, IdentifierNode.create('$condition')))], + from: FromNode.create([AliasNode.create(constTable, IdentifierNode.create(model))]), + selections: [ + SelectionNode.create( + AliasNode.create( + BinaryOperationNode.create( + FunctionNode.create('COUNT', [ValueNode.createImmediate(1)]), + OperatorNode.create('>'), + ValueNode.createImmediate(0), + ), + IdentifierNode.create('$condition'), + ), + ), + ], + where: WhereNode.create(filter), }; + const result = await proceed(preCreateCheck); - if (!(result.rows[0] as any)?.$condition) { + if (!result.rows[0]?.$condition) { throw new RejectedByPolicyError(model); } } @@ -327,13 +356,7 @@ export class PolicyHandler extends OperationNodeTransf return InsertQueryNode.is(node) || UpdateQueryNode.is(node) || DeleteQueryNode.is(node); } - private buildPolicyFilter( - model: GetModels, - alias: string | undefined, - operation: CRUD, - thisEntity?: Record, - thisEntityRaw?: Record, - ) { + private buildPolicyFilter(model: GetModels, alias: string | undefined, operation: CRUD) { const policies = this.getModelPolicies(model, operation); if (policies.length === 0) { return falseNode(this.dialect); @@ -341,11 +364,11 @@ export class PolicyHandler extends OperationNodeTransf const allows = policies .filter((policy) => policy.kind === 'allow') - .map((policy) => this.transformPolicyCondition(model, alias, operation, policy, thisEntity, thisEntityRaw)); + .map((policy) => this.transformPolicyCondition(model, alias, operation, policy)); const denies = policies .filter((policy) => policy.kind === 'deny') - .map((policy) => this.transformPolicyCondition(model, alias, operation, policy, thisEntity, thisEntityRaw)); + .map((policy) => this.transformPolicyCondition(model, alias, operation, policy)); let combinedPolicy: OperationNode; @@ -458,8 +481,6 @@ export class PolicyHandler extends OperationNodeTransf alias: string | undefined, operation: CRUD, policy: Policy, - thisEntity?: Record, - thisEntityRaw?: Record, ) { return new ExpressionTransformer(this.client.$schema, this.client.$options, this.client.$auth).transform( policy.condition, @@ -467,8 +488,6 @@ export class PolicyHandler extends OperationNodeTransf model, alias, operation, - thisEntity, - thisEntityRaw, auth: this.client.$auth, }, ); diff --git a/packages/runtime/src/schema/expression.ts b/packages/runtime/src/schema/expression.ts index a650391a..2e2337fa 100644 --- a/packages/runtime/src/schema/expression.ts +++ b/packages/runtime/src/schema/expression.ts @@ -88,6 +88,10 @@ export const ExpressionUtils = { return expressions.reduce((acc, exp) => ExpressionUtils.binary(acc, '||', exp), expr); }, + not: (expr: Expression) => { + return ExpressionUtils.unary('!', expr); + }, + is: (value: unknown, kind: Expression['kind']): value is Expression => { return !!value && typeof value === 'object' && 'kind' in value && value.kind === kind; }, diff --git a/packages/runtime/test/plugin/on-kysely-query.test.ts b/packages/runtime/test/plugin/on-kysely-query.test.ts index 7e0ac024..75105927 100644 --- a/packages/runtime/test/plugin/on-kysely-query.test.ts +++ b/packages/runtime/test/plugin/on-kysely-query.test.ts @@ -84,7 +84,7 @@ describe('On kysely query tests', () => { it('supports spawning multiple queries', async () => { const client = _client.$use({ id: 'test-plugin', - async onKyselyQuery({ kysely, proceed, query }) { + async onKyselyQuery({ client, proceed, query }) { if (query.kind !== 'InsertQueryNode') { return proceed(query); } @@ -92,7 +92,7 @@ describe('On kysely query tests', () => { const result = await proceed(query); // create a post for the user - await proceed(createPost(kysely, result)); + await proceed(createPost(client.$qb, result)); return result; }, diff --git a/packages/runtime/test/policy/crud/create.test.ts b/packages/runtime/test/policy/crud/create.test.ts new file mode 100644 index 00000000..a9bacb01 --- /dev/null +++ b/packages/runtime/test/policy/crud/create.test.ts @@ -0,0 +1,202 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy create tests', () => { + it('works with scalar field check', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int + @@allow('create', x > 0) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); + }); + + it('works with this scalar member check', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int + @@allow('create', this.x > 0) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); + }); + + it('denies by default', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + }); + + it('works with deny rule', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int + @@deny('create', x <= 0) + @@allow('create,read', true) +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); + }); + + it('works with mixed allow and deny rules', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int + @@deny('create', x <= 0) + @@allow('create', x > 1) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 1 } })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + }); + + it('works with non-provided fields', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + x Int @default(0) + @@allow('create', x > 0) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { x: 1 } })).toResolveTruthy(); + }); + + it('works with db-generated fields', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id @default(autoincrement()) + @@allow('create', id > 0) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { id: 1 } })).toResolveTruthy(); + }); + + it('rejects non-owned relation reference', async () => { + await expect( + createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('create', profile == null) + @@allow('read', true) +} + +model Profile { + id Int @id + name String + user User @relation(fields: [userId], references: [id]) + userId Int @unique +} + `, + ), + ).rejects.toThrow('non-owned relation fields are not allowed in "create" rules'); + }); + + it('works with auth check', async () => { + const db = await createPolicyTestClient( + ` +type Auth { + x Int + @@auth +} + +model Foo { + id Int @id @default(autoincrement()) + x Int + @@allow('create', x == auth().x) + @@allow('read', true) +} +`, + ); + await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.$setAuth({ x: 0 }).foo.create({ data: { x: 1 } })).toBeRejectedByPolicy(); + await expect(db.$setAuth({ x: 1 }).foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); + }); + + it('works with owned to-one relation reference', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + + @@deny('all', auth() == null) + @@allow('create', user.id == auth().id) + @@allow('read', true) +} + `, + ); + + await db.user.create({ data: { id: 1 } }); + await expect(db.profile.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + await expect(db.$setAuth({ id: 0 }).profile.create({ data: { id: 1, userId: 1 } })).toBeRejectedByPolicy(); + await expect(db.$setAuth({ id: 1 }).profile.create({ data: { id: 1, userId: 1 } })).resolves.toMatchObject({ + id: 1, + }); + + await expect(db.profile.create({ data: { id: 2, user: { create: { id: 2 } } } })).toBeRejectedByPolicy(); + await expect(db.user.findUnique({ where: { id: 2 } })).toResolveNull(); + await expect( + db + .$setAuth({ id: 2 }) + .profile.create({ data: { id: 2, user: { create: { id: 2 } } }, include: { user: true } }), + ).resolves.toMatchObject({ + id: 2, + user: { + id: 2, + }, + }); + + await db.user.create({ data: { id: 3 } }); + await expect( + db.$setAuth({ id: 2 }).profile.create({ data: { id: 3, user: { connect: { id: 3 } } } }), + ).toBeRejectedByPolicy(); + await expect( + db.$setAuth({ id: 3 }).profile.create({ data: { id: 3, user: { connect: { id: 3 } } } }), + ).toResolveTruthy(); + + await expect(db.$setAuth({ id: 4 }).profile.create({ data: { id: 2, userId: 4 } })).toBeRejectedByPolicy(); + }); +}); diff --git a/packages/runtime/test/policy/crud/dumb-rules.test.ts b/packages/runtime/test/policy/crud/dumb-rules.test.ts new file mode 100644 index 00000000..b169e3a0 --- /dev/null +++ b/packages/runtime/test/policy/crud/dumb-rules.test.ts @@ -0,0 +1,42 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy dumb rules tests', () => { + it('works with create dumb rules', async () => { + const db = await createPolicyTestClient( + ` +model A { + id Int @id @default(autoincrement()) + x Int + @@allow('create', 1 > 0) + @@allow('read', true) +} + +model B { + id Int @id @default(autoincrement()) + x Int + @@allow('create', 0 > 1) + @@allow('read', true) +} + +model C { + id Int @id @default(autoincrement()) + x Int + @@allow('create', true) + @@allow('read', true) +} + +model D { + id Int @id @default(autoincrement()) + x Int + @@allow('create', false) + @@allow('read', true) +} +`, + ); + await expect(db.a.create({ data: { x: 0 } })).resolves.toMatchObject({ x: 0 }); + await expect(db.b.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + await expect(db.c.create({ data: { x: 0 } })).resolves.toMatchObject({ x: 0 }); + await expect(db.d.create({ data: { x: 0 } })).toBeRejectedByPolicy(); + }); +}); diff --git a/packages/runtime/test/policy/deep-nested.test.ts b/packages/runtime/test/policy/deep-nested.test.ts index 0be59e24..a35e34b8 100644 --- a/packages/runtime/test/policy/deep-nested.test.ts +++ b/packages/runtime/test/policy/deep-nested.test.ts @@ -7,7 +7,8 @@ describe('deep nested operations tests', () => { // -* M4 model M1 { myId String @id @default(cuid()) - m2 M2? + m2 M2? @relation(fields: [m2Id], references: [id], onDelete: Cascade) + m2Id Int? @unique value Int @default(0) @@allow('all', true) @@ -19,8 +20,7 @@ describe('deep nested operations tests', () => { model M2 { id Int @id @default(autoincrement()) value Int - m1 M1 @relation(fields: [m1Id], references: [myId], onDelete: Cascade) - m1Id String @unique + m1 M1? m3 M3? m4 M4[] @@ -616,7 +616,8 @@ describe('deep nested operations tests', () => { myId: '1', m2: { create: { - value: 1, + id: 1, + value: 3, m4: { create: [{ value: 200 }, { value: 22 }], }, @@ -628,10 +629,14 @@ describe('deep nested operations tests', () => { // delete read-back filtered: M4 @@deny('read', value == 200) const r = await db.m1.delete({ where: { myId: '1' }, - include: { m2: { select: { m4: true } } }, + include: { m2: { select: { id: true, m4: true } } }, }); expect(r.m2.m4).toHaveLength(1); + await expect(db.m2.findMany()).resolves.toHaveLength(1); + await expect(db.m4.findMany()).resolves.toHaveLength(1); + + await db.m2.delete({ where: { id: 1 } }); await expect(db.m4.findMany()).resolves.toHaveLength(0); await db.m1.create({ diff --git a/packages/runtime/test/policy/ref-equality.test.ts b/packages/runtime/test/policy/ref-equality.test.ts new file mode 100644 index 00000000..3196f52b --- /dev/null +++ b/packages/runtime/test/policy/ref-equality.test.ts @@ -0,0 +1,40 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from './utils'; + +describe('Reference Equality Tests', () => { + it('works with auth equality', async () => { + const db = await createPolicyTestClient( + ` +model User { + id1 Int + id2 Int + posts Post[] + @@id([id1, id2]) + @@allow('all', auth() == this) +} + +model Post { + id Int @id @default(autoincrement()) + title String + authorId1 Int + authorId2 Int + author User @relation(fields: [authorId1, authorId2], references: [id1, id2]) + @@allow('all', auth() == author) +} + `, + { log: ['query'] }, + ); + + await expect( + db.user.create({ + data: { id1: 1, id2: 2 }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.$setAuth({ id1: 1, id2: 2 }).user.create({ + data: { id1: 1, id2: 2 }, + }), + ).resolves.toMatchObject({ id1: 1, id2: 2 }); + }); +}); diff --git a/packages/runtime/test/utils.ts b/packages/runtime/test/utils.ts index 4654fccc..64484593 100644 --- a/packages/runtime/test/utils.ts +++ b/packages/runtime/test/utils.ts @@ -3,7 +3,7 @@ import { loadDocument } from '@zenstackhq/language'; import { PrismaSchemaGenerator } from '@zenstackhq/sdk'; import { createTestProject, generateTsSchema } from '@zenstackhq/testtools'; import SQLite from 'better-sqlite3'; -import { PostgresDialect, SqliteDialect } from 'kysely'; +import { PostgresDialect, SqliteDialect, type LogEvent } from 'kysely'; import { execSync } from 'node:child_process'; import fs from 'node:fs'; import path from 'node:path'; @@ -192,3 +192,7 @@ export async function createTestClient( return client; } + +export function testLogger(e: LogEvent) { + console.log(e.query.sql, e.query.parameters); +} diff --git a/packages/sdk/src/model-utils.ts b/packages/sdk/src/model-utils.ts index 3ab4a01e..7b54aa96 100644 --- a/packages/sdk/src/model-utils.ts +++ b/packages/sdk/src/model-utils.ts @@ -2,6 +2,7 @@ import { isDataModel, isLiteralExpr, isModel, + isTypeDef, Model, type AstNode, type Attribute, @@ -102,7 +103,7 @@ export function resolved(ref: Reference): T { export function getAuthDecl(model: Model) { let found = model.declarations.find( - (d) => isDataModel(d) && d.attributes.some((attr) => attr.decl.$refText === '@@auth'), + (d) => (isDataModel(d) || isTypeDef(d)) && d.attributes.some((attr) => attr.decl.$refText === '@@auth'), ); if (!found) { found = model.declarations.find((d) => isDataModel(d) && d.name === 'User');