diff --git a/TODO.md b/TODO.md index cd66cb46..c92d4fc6 100644 --- a/TODO.md +++ b/TODO.md @@ -99,9 +99,8 @@ - [ ] Validation - [ ] Access Policy - [ ] Short-circuit pre-create check for scalar-field only policies - - [ ] Inject "replace into" - - [ ] Inject "on conflict do update" - - [ ] Inject "insert into select from" + - [x] Inject "on conflict do update" + - [x] `check` function - [x] Migration - [ ] Databases - [x] SQLite diff --git a/packages/common-helpers/src/index.ts b/packages/common-helpers/src/index.ts index 7f9c421b..5b63ae85 100644 --- a/packages/common-helpers/src/index.ts +++ b/packages/common-helpers/src/index.ts @@ -4,3 +4,4 @@ export * from './param-case'; export * from './sleep'; export * from './tiny-invariant'; export * from './upper-case-first'; +export * from './zip'; diff --git a/packages/common-helpers/src/zip.ts b/packages/common-helpers/src/zip.ts new file mode 100644 index 00000000..35d4981b --- /dev/null +++ b/packages/common-helpers/src/zip.ts @@ -0,0 +1,11 @@ +/** + * Zips two arrays into an array of tuples. + */ +export function zip(arr1: T[], arr2: U[]): Array<[T, U]> { + const length = Math.min(arr1.length, arr2.length); + const result: Array<[T, U]> = []; + for (let i = 0; i < length; i++) { + result.push([arr1[i]!, arr2[i]!]); + } + return result; +} diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index 12c2a64e..fd9714ed 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -452,7 +452,7 @@ export type OmitInput> export type SelectIncludeOmit, AllowCount extends boolean> = { select?: SelectInput; - include?: IncludeInput; + include?: IncludeInput; omit?: OmitInput; }; @@ -463,14 +463,7 @@ export type SelectInput< AllowRelation extends boolean = true, > = { [Key in NonRelationFields]?: boolean; -} & (AllowRelation extends true ? IncludeInput : {}) & // relation fields - // relation count - (AllowCount extends true - ? // _count is only allowed if the model has to-many relations - HasToManyRelations extends true - ? { _count?: SelectCount } - : {} - : {}); +} & (AllowRelation extends true ? IncludeInput : {}); type SelectCount> = | boolean @@ -484,7 +477,11 @@ type SelectCount> = }; }; -export type IncludeInput> = { +export type IncludeInput< + Schema extends SchemaDef, + Model extends GetModels, + AllowCount extends boolean = true, +> = { [Key in RelationFields]?: | boolean | FindArgs< @@ -498,7 +495,12 @@ export type IncludeInput; -}; +} & (AllowCount extends true + ? // _count is only allowed if the model has to-many relations + HasToManyRelations extends true + ? { _count?: SelectCount } + : {} + : {}); export type Subset = { [key in keyof T]: key extends keyof U ? T[key] : never; @@ -674,7 +676,7 @@ export type FindUniqueArgs> = { data: CreateInput; - select?: SelectInput; + select?: SelectInput; include?: IncludeInput; omit?: OmitInput; }; @@ -813,7 +815,7 @@ type NestedCreateManyInput< export type UpdateArgs> = { data: UpdateInput; where: WhereUniqueInput; - select?: SelectInput; + select?: SelectInput; include?: IncludeInput; omit?: OmitInput; }; @@ -841,7 +843,7 @@ export type UpsertArgs create: CreateInput; update: UpdateInput; where: WhereUniqueInput; - select?: SelectInput; + select?: SelectInput; include?: IncludeInput; omit?: OmitInput; }; @@ -958,7 +960,7 @@ type ToOneRelationUpdateInput< export type DeleteArgs> = { where: WhereUniqueInput; - select?: SelectInput; + select?: SelectInput; include?: IncludeInput; omit?: OmitInput; }; diff --git a/packages/runtime/src/client/crud/dialects/base-dialect.ts b/packages/runtime/src/client/crud/dialects/base-dialect.ts index 602a3a4f..a8a1243c 100644 --- a/packages/runtime/src/client/crud/dialects/base-dialect.ts +++ b/packages/runtime/src/client/crud/dialects/base-dialect.ts @@ -1048,14 +1048,29 @@ export abstract class BaseCrudDialect { for (const [field, value] of Object.entries(selections.select)) { const fieldDef = requireField(this.schema, model, field); const fieldModel = fieldDef.type; - const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel); - - // build a nested query to count the number of records in the relation - let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`)); + let fieldCountQuery: SelectQueryBuilder; // join conditions - for (const [left, right] of joinPairs) { - fieldCountQuery = fieldCountQuery.whereRef(left, '=', right); + const m2m = getManyToManyRelation(this.schema, model, field); + if (m2m) { + // many-to-many relation, count the join table + fieldCountQuery = eb + .selectFrom(fieldModel) + .innerJoin(m2m.joinTable, (join) => + join + .onRef(`${m2m.joinTable}.${m2m.otherFkName}`, '=', `${fieldModel}.${m2m.otherPKName}`) + .onRef(`${m2m.joinTable}.${m2m.parentFkName}`, '=', `${parentAlias}.${m2m.parentPKName}`), + ) + .select(eb.fn.countAll().as(`_count$${field}`)); + } else { + // build a nested query to count the number of records in the relation + fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`)); + + // join conditions + const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel); + for (const [left, right] of joinPairs) { + fieldCountQuery = fieldCountQuery.whereRef(left, '=', right); + } } // merge _count filter diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index a4ce1c52..eb17167d 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -475,7 +475,7 @@ export abstract class BaseOperationHandler { entity: rightEntity, }, ].sort((a, b) => - // the implement m2m join table's "A", "B" fk fields' order is determined + // the implicit m2m join table's "A", "B" fk fields' order is determined // by model name's sort order, and when identical (for self-relations), // field name's sort order a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field), diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index 3a8bdf3e..7c8b3d5f 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -3,7 +3,14 @@ import Decimal from 'decimal.js'; import stableStringify from 'json-stable-stringify'; import { match, P } from 'ts-pattern'; import { z, ZodType } from 'zod'; -import { type BuiltinType, type EnumDef, type FieldDef, type GetModels, type SchemaDef } from '../../schema'; +import { + type BuiltinType, + type EnumDef, + type FieldDef, + type GetModels, + type ModelDef, + type SchemaDef, +} from '../../schema'; import { enumerate } from '../../utils/enumerate'; import { extractFields } from '../../utils/object-utils'; import { formatError } from '../../utils/zod-utils'; @@ -595,10 +602,18 @@ export class InputValidator { } } - const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array); + const _countSchema = this.makeCountSelectionSchema(modelDef); + if (_countSchema) { + fields['_count'] = _countSchema; + } + + return z.strictObject(fields); + } + private makeCountSelectionSchema(modelDef: ModelDef) { + const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array); if (toManyRelations.length > 0) { - fields['_count'] = z + return z .union([ z.literal(true), z.strictObject({ @@ -621,9 +636,9 @@ export class InputValidator { }), ]) .optional(); + } else { + return undefined; } - - return z.strictObject(fields); } private makeRelationSelectIncludeSchema(fieldDef: FieldDef) { @@ -677,6 +692,11 @@ export class InputValidator { } } + const _countSchema = this.makeCountSelectionSchema(modelDef); + if (_countSchema) { + fields['_count'] = _countSchema; + } + return z.strictObject(fields); } diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 3fdd9858..1cfbdd14 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -1,3 +1,4 @@ +import { invariant } from '@zenstackhq/common-helpers'; import type { Expression, ExpressionBuilder, ExpressionWrapper } from 'kysely'; import { match } from 'ts-pattern'; import { ExpressionUtils, type FieldDef, type GetModels, type ModelDef, type SchemaDef } from '../schema'; @@ -259,11 +260,18 @@ export function getManyToManyRelation(schema: SchemaDef, model: string, field: s orderedFK = sortedFieldNames[0] === field ? ['A', 'B'] : ['B', 'A']; } + const modelIdFields = requireIdFields(schema, model); + invariant(modelIdFields.length === 1, 'Only single-field ID is supported for many-to-many relation'); + const otherIdFields = requireIdFields(schema, fieldDef.type); + invariant(otherIdFields.length === 1, 'Only single-field ID is supported for many-to-many relation'); + return { parentFkName: orderedFK[0], + parentPKName: modelIdFields[0]!, otherModel: fieldDef.type, otherField: fieldDef.relation.opposite, otherFkName: orderedFK[1], + otherPKName: otherIdFields[0]!, joinTable: fieldDef.relation.name ? `_${fieldDef.relation.name}` : `_${sortedModelNames[0]}To${sortedModelNames[1]}`, diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index d5b879b3..8502e33b 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -24,7 +24,13 @@ import type { ClientContract, CRUD } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; -import { getModel, getRelationForeignKeyFieldPairs, requireField, requireIdFields } from '../../client/query-utils'; +import { + getManyToManyRelation, + getModel, + getRelationForeignKeyFieldPairs, + requireField, + requireIdFields, +} from '../../client/query-utils'; import type { BinaryExpression, BinaryOperator, @@ -44,7 +50,7 @@ import { type SchemaDef, } from '../../schema'; import { ExpressionEvaluator } from './expression-evaluator'; -import { conjunction, disjunction, logicalNot, trueNode } from './utils'; +import { conjunction, disjunction, falseNode, logicalNot, trueNode } from './utils'; export type ExpressionTransformerContext = { model: GetModels; @@ -335,7 +341,13 @@ export class ExpressionTransformer { } private transformValue(value: unknown, type: BuiltinType) { - return ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null); + if (value === true) { + return trueNode(this.dialect); + } else if (value === false) { + return falseNode(this.dialect); + } else { + return ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null); + } } @expr('unary') @@ -537,6 +549,11 @@ export class ExpressionTransformer { relationModel: string, context: ExpressionTransformerContext, ): SelectQueryNode { + const m2m = getManyToManyRelation(this.schema, context.model, field); + if (m2m) { + return this.transformManyToManyRelationAccess(m2m, context); + } + const fromModel = context.model; const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs(this.schema, fromModel, field); @@ -574,6 +591,28 @@ export class ExpressionTransformer { }; } + private transformManyToManyRelationAccess( + m2m: NonNullable>, + context: ExpressionTransformerContext, + ) { + const eb = expressionBuilder(); + const relationQuery = eb + .selectFrom(m2m.otherModel) + // inner join with join table and additionally filter by the parent model + .innerJoin(m2m.joinTable, (join) => + join + // relation model pk to join table fk + .onRef(`${m2m.otherModel}.${m2m.otherPKName}`, '=', `${m2m.joinTable}.${m2m.otherFkName}`) + // parent model pk to join table fk + .onRef( + `${m2m.joinTable}.${m2m.parentFkName}`, + '=', + `${context.alias ?? context.model}.${m2m.parentPKName}`, + ), + ); + return relationQuery.toOperationNode(); + } + private createColumnRef(column: string, context: ExpressionTransformerContext): ReferenceNode { return ReferenceNode.create(ColumnNode.create(column), TableNode.create(context.alias ?? context.model)); } diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index 7ab7d2f0..7b55b334 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -4,6 +4,8 @@ import { BinaryOperationNode, ColumnNode, DeleteQueryNode, + expressionBuilder, + ExpressionWrapper, FromNode, FunctionNode, IdentifierNode, @@ -34,7 +36,7 @@ import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; import type { ProceedKyselyQueryFunction } from '../../client/plugin'; -import { requireField, requireIdFields, requireModel } from '../../client/query-utils'; +import { getManyToManyRelation, requireField, requireIdFields, requireModel } from '../../client/query-utils'; import { ExpressionUtils, type BuiltinType, type Expression, type GetModels, type SchemaDef } from '../../schema'; import { ColumnCollector } from './column-collector'; import { RejectedByPolicyError } from './errors'; @@ -68,72 +70,46 @@ export class PolicyHandler extends OperationNodeTransf } if (!this.isMutationQueryNode(node)) { - // transform and proceed read without transaction + // transform and proceed with read directly return proceed(this.transformNode(node)); } - let mutationRequiresTransaction = false; const { mutationModel } = this.getMutationModel(node); if (InsertQueryNode.is(node)) { - // reject create if unconditional deny - const constCondition = this.tryGetConstantPolicy(mutationModel, 'create'); - if (constCondition === false) { - throw new RejectedByPolicyError(mutationModel); - } else if (constCondition === undefined) { - mutationRequiresTransaction = true; + // pre-create policy evaluation happens before execution of the query + const isManyToManyJoinTable = this.isManyToManyJoinTable(mutationModel); + let needCheckPreCreate = true; + + // many-to-many join table is not a model so can't have policies on it + if (!isManyToManyJoinTable) { + // check constant policies + const constCondition = this.tryGetConstantPolicy(mutationModel, 'create'); + if (constCondition === true) { + needCheckPreCreate = false; + } else if (constCondition === false) { + throw new RejectedByPolicyError(mutationModel); + } } - } - if (!mutationRequiresTransaction && !node.returning) { - // transform and proceed mutation without transaction - return proceed(this.transformNode(node)); + if (needCheckPreCreate) { + await this.enforcePreCreatePolicy(node, mutationModel, isManyToManyJoinTable, proceed); + } } - if (InsertQueryNode.is(node)) { - await this.enforcePreCreatePolicy(node, proceed); - } - const transformedNode = this.transformNode(node); - const result = await proceed(transformedNode); + // proceed with query - if (!this.onlyReturningId(node)) { + const result = await proceed(this.transformNode(node)); + + if (!node.returning || this.onlyReturningId(node)) { + return result; + } else { const readBackResult = await this.processReadBack(node, result, proceed); if (readBackResult.rows.length !== result.rows.length) { throw new RejectedByPolicyError(mutationModel, 'result is not allowed to be read back'); } return readBackResult; - } else { - // reading id fields bypasses policy - return result; } - - // TODO: run in transaction - // let readBackError = false; - - // transform and post-process in a transaction - // const result = await transaction(async (txProceed) => { - // if (InsertQueryNode.is(node)) { - // await this.enforcePreCreatePolicy(node, txProceed); - // } - // const transformedNode = this.transformNode(node); - // const result = await txProceed(transformedNode); - - // if (!this.onlyReturningId(node)) { - // const readBackResult = await this.processReadBack(node, result, txProceed); - // if (readBackResult.rows.length !== result.rows.length) { - // readBackError = true; - // } - // return readBackResult; - // } else { - // return result; - // } - // }); - - // if (readBackError) { - // throw new RejectedByPolicyError(mutationModel, 'result is not allowed to be read back'); - // } - - // return result; } // #region overrides @@ -280,16 +256,92 @@ export class PolicyHandler extends OperationNodeTransf return selectedColumns.every((c) => idFields.includes(c)); } - private async enforcePreCreatePolicy(node: InsertQueryNode, proceed: ProceedKyselyQueryFunction) { - const { mutationModel } = this.getMutationModel(node); + private async enforcePreCreatePolicy( + node: InsertQueryNode, + mutationModel: GetModels, + isManyToManyJoinTable: boolean, + proceed: ProceedKyselyQueryFunction, + ) { const fields = node.columns?.map((c) => c.column.name) ?? []; - const valueRows = node.values ? this.unwrapCreateValueRows(node.values, mutationModel, fields) : [[]]; + const valueRows = node.values + ? this.unwrapCreateValueRows(node.values, mutationModel, fields, isManyToManyJoinTable) + : [[]]; for (const values of valueRows) { - await this.enforcePreCreatePolicyForOne( - mutationModel, - fields, - values.map((v) => v.node), - proceed, + if (isManyToManyJoinTable) { + await this.enforcePreCreatePolicyForManyToManyJoinTable( + mutationModel, + fields, + values.map((v) => v.node), + proceed, + ); + } else { + await this.enforcePreCreatePolicyForOne( + mutationModel, + fields, + values.map((v) => v.node), + proceed, + ); + } + } + } + + private async enforcePreCreatePolicyForManyToManyJoinTable( + tableName: GetModels, + fields: string[], + values: OperationNode[], + proceed: ProceedKyselyQueryFunction, + ) { + const m2m = this.resolveManyToManyJoinTable(tableName); + invariant(m2m); + + // m2m create requires both sides to be updatable + invariant(fields.includes('A') && fields.includes('B'), 'many-to-many join table must have A and B fk fields'); + + const aIndex = fields.indexOf('A'); + const aNode = values[aIndex]!; + const bIndex = fields.indexOf('B'); + const bNode = values[bIndex]!; + invariant(ValueNode.is(aNode) && ValueNode.is(bNode), 'A and B values must be ValueNode'); + + const aValue = aNode.value; + const bValue = bNode.value; + invariant(aValue !== null && aValue !== undefined, 'A value cannot be null or undefined'); + invariant(bValue !== null && bValue !== undefined, 'B value cannot be null or undefined'); + + const eb = expressionBuilder(); + + const filterA = this.buildPolicyFilter(m2m.firstModel as GetModels, undefined, 'update'); + const queryA = eb + .selectFrom(m2m.firstModel) + .where(eb(eb.ref(`${m2m.firstModel}.${m2m.firstIdField}`), '=', aValue)) + .select(() => new ExpressionWrapper(filterA).as('$t')); + + const filterB = this.buildPolicyFilter(m2m.secondModel as GetModels, undefined, 'update'); + const queryB = eb + .selectFrom(m2m.secondModel) + .where(eb(eb.ref(`${m2m.secondModel}.${m2m.secondIdField}`), '=', bValue)) + .select(() => new ExpressionWrapper(filterB).as('$t')); + + // select both conditions in one query + const queryNode: SelectQueryNode = { + kind: 'SelectQueryNode', + selections: [ + SelectionNode.create(AliasNode.create(queryA.toOperationNode(), IdentifierNode.create('$conditionA'))), + SelectionNode.create(AliasNode.create(queryB.toOperationNode(), IdentifierNode.create('$conditionB'))), + ], + }; + + const result = await proceed(queryNode); + if (!result.rows[0]?.$conditionA) { + throw new RejectedByPolicyError( + m2m.firstModel as GetModels, + `many-to-many relation participant model "${m2m.firstModel}" not updatable`, + ); + } + if (!result.rows[0]?.$conditionB) { + throw new RejectedByPolicyError( + m2m.secondModel as GetModels, + `many-to-many relation participant model "${m2m.secondModel}" not updatable`, ); } } @@ -355,23 +407,33 @@ export class PolicyHandler extends OperationNodeTransf } } - private unwrapCreateValueRows(node: OperationNode, model: GetModels, fields: string[]) { + private unwrapCreateValueRows( + node: OperationNode, + model: GetModels, + fields: string[], + isManyToManyJoinTable: boolean, + ) { if (ValuesNode.is(node)) { - return node.values.map((v) => this.unwrapCreateValueRow(v.values, model, fields)); + return node.values.map((v) => this.unwrapCreateValueRow(v.values, model, fields, isManyToManyJoinTable)); } else if (PrimitiveValueListNode.is(node)) { - return [this.unwrapCreateValueRow(node.values, model, fields)]; + return [this.unwrapCreateValueRow(node.values, model, fields, isManyToManyJoinTable)]; } else { throw new InternalError(`Unexpected node kind: ${node.kind} for unwrapping create values`); } } - private unwrapCreateValueRow(data: readonly unknown[], model: GetModels, fields: string[]) { + private unwrapCreateValueRow( + data: readonly unknown[], + model: GetModels, + fields: string[], + isImplicitManyToManyJoinTable: boolean, + ) { invariant(data.length === fields.length, 'data length must match fields length'); const result: { node: OperationNode; raw: unknown }[] = []; for (let i = 0; i < data.length; i++) { const item = data[i]!; - const fieldDef = requireField(this.client.$schema, model, fields[i]!); if (typeof item === 'object' && item && 'kind' in item) { + const fieldDef = requireField(this.client.$schema, model, fields[i]!); invariant(item.kind === 'ValueNode', 'expecting a ValueNode'); result.push({ node: ValueNode.create( @@ -384,7 +446,15 @@ export class PolicyHandler extends OperationNodeTransf raw: (item as ValueNode).value, }); } else { - const value = this.dialect.transformPrimitive(item, fieldDef.type as BuiltinType, !!fieldDef.array); + let value: unknown = item; + + // many-to-many join table is not a model so we don't have field definitions, + // but there's no need to transform values anyway because they're the fields + // are all foreign keys + if (!isImplicitManyToManyJoinTable) { + const fieldDef = requireField(this.client.$schema, model, fields[i]!); + value = this.dialect.transformPrimitive(item, fieldDef.type as BuiltinType, !!fieldDef.array); + } if (Array.isArray(value)) { result.push({ node: RawNode.createWithSql(this.dialect.buildArrayLiteralSQL(value)), @@ -504,6 +574,12 @@ export class PolicyHandler extends OperationNodeTransf } buildPolicyFilter(model: GetModels, alias: string | undefined, operation: CRUD) { + // first check if it's a many-to-many join table, and if so, handle specially + const m2mFilter = this.getModelPolicyFilterForManyToManyJoinTable(model, alias, operation); + if (m2mFilter) { + return m2mFilter; + } + const policies = this.getModelPolicies(model, operation); if (policies.length === 0) { return falseNode(this.dialect); @@ -592,8 +668,8 @@ export class PolicyHandler extends OperationNodeTransf }); } - private getModelPolicies(modelName: string, operation: PolicyOperation) { - const modelDef = requireModel(this.client.$schema, modelName); + private getModelPolicies(model: string, operation: PolicyOperation) { + const modelDef = requireModel(this.client.$schema, model); const result: Policy[] = []; const extractOperations = (expr: Expression) => { @@ -623,5 +699,92 @@ export class PolicyHandler extends OperationNodeTransf return result; } + private resolveManyToManyJoinTable(tableName: string) { + for (const model of Object.values(this.client.$schema.models)) { + for (const field of Object.values(model.fields)) { + const m2m = getManyToManyRelation(this.client.$schema, model.name, field.name); + if (m2m?.joinTable === tableName) { + const sortedRecord = [ + { + model: model.name, + field: field.name, + }, + { + model: m2m.otherModel, + field: m2m.otherField, + }, + ].sort(this.manyToManySorter); + + const firstIdFields = requireIdFields(this.client.$schema, sortedRecord[0]!.model); + const secondIdFields = requireIdFields(this.client.$schema, sortedRecord[1]!.model); + invariant( + firstIdFields.length === 1 && secondIdFields.length === 1, + 'only single-field id is supported for implicit many-to-many join table', + ); + + return { + firstModel: sortedRecord[0]!.model, + firstField: sortedRecord[0]!.field, + firstIdField: firstIdFields[0]!, + secondModel: sortedRecord[1]!.model, + secondField: sortedRecord[1]!.field, + secondIdField: secondIdFields[0]!, + }; + } + } + } + return undefined; + } + + private manyToManySorter(a: { model: string; field: string }, b: { model: string; field: string }): number { + // the implicit m2m join table's "A", "B" fk fields' order is determined + // by model name's sort order, and when identical (for self-relations), + // field name's sort order + return a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field); + } + + private isManyToManyJoinTable(tableName: string) { + return !!this.resolveManyToManyJoinTable(tableName); + } + + private getModelPolicyFilterForManyToManyJoinTable( + tableName: string, + alias: string | undefined, + operation: PolicyOperation, + ): OperationNode | undefined { + const m2m = this.resolveManyToManyJoinTable(tableName); + if (!m2m) { + return undefined; + } + + // join table's permission: + // - read: requires both sides to be readable + // - mutation: requires both sides to be updatable + + const checkForOperation = operation === 'read' ? 'read' : 'update'; + const eb = expressionBuilder(); + const joinTable = alias ?? tableName; + + const aQuery = eb + .selectFrom(m2m.firstModel) + .whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, '=', `${joinTable}.A`) + .select(() => + new ExpressionWrapper( + this.buildPolicyFilter(m2m.firstModel as GetModels, undefined, checkForOperation), + ).as('$conditionA'), + ); + + const bQuery = eb + .selectFrom(m2m.secondModel) + .whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, '=', `${joinTable}.B`) + .select(() => + new ExpressionWrapper( + this.buildPolicyFilter(m2m.secondModel as GetModels, undefined, checkForOperation), + ).as('$conditionB'), + ); + + return eb.and([aQuery, bQuery]).toOperationNode(); + } + // #endregion } diff --git a/packages/runtime/test/policy/crud/create.test.ts b/packages/runtime/test/policy/crud/create.test.ts index dbd7a414..d5eb0657 100644 --- a/packages/runtime/test/policy/crud/create.test.ts +++ b/packages/runtime/test/policy/crud/create.test.ts @@ -273,4 +273,98 @@ model Profile { }, }); }); + + it('works with unnamed many-to-many relation', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + groups Group[] + private Boolean + @@allow('create,read', true) + @@allow('update', !private) +} + +model Group { + id Int @id + private Boolean + users User[] + @@allow('create,read', true) + @@allow('update', !private) +} + `, + { usePrismaPush: true }, + ); + + await expect( + db.user.create({ + data: { id: 1, private: false, groups: { create: [{ id: 1, private: false }] } }, + }), + ).toResolveTruthy(); + + await expect( + db.user.create({ + data: { id: 2, private: true, groups: { create: [{ id: 2, private: false }] } }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { id: 2, private: false, groups: { create: [{ id: 2, private: true }] } }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { id: 2, private: true, groups: { create: [{ id: 2, private: true }] } }, + }), + ).toBeRejectedByPolicy(); + }); + + it('works with named many-to-many relation', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + groups Group[] @relation("UserGroups") + private Boolean + @@allow('create,read', true) + @@allow('update', !private) +} + +model Group { + id Int @id + private Boolean + users User[] @relation("UserGroups") + @@allow('create,read', true) + @@allow('update', !private) +} + `, + { usePrismaPush: true }, + ); + + await expect( + db.user.create({ + data: { id: 1, private: false, groups: { create: [{ id: 1, private: false }] } }, + }), + ).toResolveTruthy(); + + await expect( + db.user.create({ + data: { id: 2, private: true, groups: { create: [{ id: 2, private: false }] } }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { id: 2, private: false, groups: { create: [{ id: 2, private: true }] } }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { id: 2, private: true, groups: { create: [{ id: 2, private: true }] } }, + }), + ).toBeRejectedByPolicy(); + }); }); diff --git a/packages/runtime/test/policy/crud/read.test.ts b/packages/runtime/test/policy/crud/read.test.ts index 467abea6..46f4e38b 100644 --- a/packages/runtime/test/policy/crud/read.test.ts +++ b/packages/runtime/test/policy/crud/read.test.ts @@ -165,6 +165,86 @@ model Bar { }); }); + it('works with unnamed many-to-many relation read', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + groups Group[] + @@allow('all', true) +} + +model Group { + id Int @id + private Boolean + users User[] + @@allow('read', !private) +} +`, + { usePrismaPush: true }, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + groups: { + create: [ + { id: 1, private: true }, + { id: 2, private: false }, + ], + }, + }, + }); + await expect(db.user.findFirst({ include: { groups: true } })).resolves.toMatchObject({ + groups: [{ id: 2 }], + }); + await expect( + db.user.findFirst({ where: { id: 1 }, select: { _count: { select: { groups: true } } } }), + ).resolves.toMatchObject({ + _count: { groups: 1 }, + }); + }); + + it('works with named many-to-many relation read', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + groups Group[] @relation("UserGroups") + @@allow('all', true) +} + +model Group { + id Int @id + private Boolean + users User[] @relation("UserGroups") + @@allow('read', !private) +} +`, + { usePrismaPush: true }, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + groups: { + create: [ + { id: 1, private: true }, + { id: 2, private: false }, + ], + }, + }, + }); + await expect(db.user.findFirst({ include: { groups: true } })).resolves.toMatchObject({ + groups: [{ id: 2 }], + }); + await expect( + db.user.findFirst({ where: { id: 1 }, select: { _count: { select: { groups: true } } } }), + ).resolves.toMatchObject({ + _count: { groups: 1 }, + }); + }); + it('works with filtered by to-one relation field', async () => { const db = await createPolicyTestClient( ` diff --git a/packages/runtime/test/policy/crud/update.test.ts b/packages/runtime/test/policy/crud/update.test.ts index f7b2b820..975000f5 100644 --- a/packages/runtime/test/policy/crud/update.test.ts +++ b/packages/runtime/test/policy/crud/update.test.ts @@ -338,6 +338,88 @@ model Post { }); await expect(db.user.update({ where: { id: 3 }, data: { name: 'UpdatedUser3' } })).toResolveTruthy(); }); + + it('works with unnamed many-to-many relation check', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + groups Group[] + @@allow('create,read', true) + @@allow('update', groups?[!private]) +} + +model Group { + id Int @id + private Boolean + members User[] + @@allow('all', true) +} +`, + { usePrismaPush: true }, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + name: 'User1', + groups: { + create: [ + { id: 1, private: true }, + { id: 2, private: false }, + ], + }, + }, + }); + + await expect(db.user.update({ where: { id: 1 }, data: { name: 'User2' } })).toResolveTruthy(); + + await db.$unuseAll().group.update({ where: { id: 2 }, data: { private: true } }); + // not satisfying update policy anymore + await expect(db.user.update({ where: { id: 1 }, data: { name: 'User3' } })).toBeRejectedNotFound(); + }); + + it('works with named many-to-many relation check', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + groups Group[] @relation("UserGroups") + @@allow('create,read', true) + @@allow('update', groups?[!private]) +} + +model Group { + id Int @id + private Boolean + members User[] @relation("UserGroups") + @@allow('all', true) +} +`, + { usePrismaPush: true }, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + name: 'User1', + groups: { + create: [ + { id: 1, private: true }, + { id: 2, private: false }, + ], + }, + }, + }); + + await expect(db.user.update({ where: { id: 1 }, data: { name: 'User2' } })).toResolveTruthy(); + + await db.$unuseAll().group.update({ where: { id: 2 }, data: { private: true } }); + // not satisfying update policy anymore + await expect(db.user.update({ where: { id: 1 }, data: { name: 'User3' } })).toBeRejectedNotFound(); + }); }); describe('Nested create tests', () => { @@ -888,6 +970,135 @@ model Profile { }), ).toResolveTruthy(); }); + + it('works with many-to-many relation manipulation', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + private Boolean + groups Group[] @relation("UserGroups") + @@allow('create,read', true) + @@allow('update,delete', !private) +} + +model Group { + id Int @id + private Boolean + members User[] @relation("UserGroups") + @@allow('create,read', true) + @@allow('update,delete', !private) +} +`, + { usePrismaPush: true }, + ); + + await db.$unuseAll().user.create({ data: { id: 1, private: true } }); + await db.$unuseAll().user.create({ data: { id: 2, private: false } }); + + // user not updatable + await expect( + db.user.update({ where: { id: 1 }, data: { groups: { create: { id: 1, private: false } } } }), + ).toBeRejectedByPolicy(); + + // group not updatable + await expect( + db.user.update({ where: { id: 2 }, data: { groups: { create: { id: 1, private: true } } } }), + ).toBeRejectedByPolicy(); + + // both updatable + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { create: { id: 1, private: false } } }, + include: { groups: true }, + }), + ).toResolveTruthy(); + + // disconnect + await expect( + db.user.update({ where: { id: 2 }, data: { groups: { disconnect: { id: 1 } } } }), + ).toResolveTruthy(); + + // set + await expect( + db.user.update({ where: { id: 2 }, data: { groups: { set: [{ id: 1 }] } } }), + ).toResolveTruthy(); + + // delete + await expect( + db.user.update({ where: { id: 2 }, data: { groups: { delete: { id: 1 } } } }), + ).toResolveTruthy(); + + // recreate group as private + await db.$unuseAll().group.create({ data: { id: 2, private: true } }); + + // connect rejected + await expect( + db.user.update({ where: { id: 2 }, data: { groups: { connect: { id: 2 } } } }), + ).toBeRejectedByPolicy(); + + // disconnect rejected + await db.$unuseAll().user.update({ where: { id: 2 }, data: { groups: { connect: { id: 2 } } } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { disconnect: { id: 2 } } }, + include: { groups: true }, + }), + ).resolves.toMatchObject({ + groups: [{ id: 2 }], // verify not disconnected + }); + + // delete rejected + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { delete: { id: 2 } } }, + include: { groups: true }, + }), + ).toBeRejectedNotFound(); + + // set rejected + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { set: [] } }, + include: { groups: true }, + }), + ).resolves.toMatchObject({ + groups: [{ id: 2 }], // verify not disconnected + }); + + await db.$unuseAll().group.update({ where: { id: 2 }, data: { private: false } }); + await db.$unuseAll().group.create({ data: { id: 3, private: true } }); + + // set rejected + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { set: [{ id: 3 }] } }, + include: { groups: true }, + }), + ).toBeRejectedByPolicy(); + + // relation unchanged + await expect(db.user.findUnique({ where: { id: 2 }, include: { groups: true } })).resolves.toMatchObject({ + groups: [{ id: 2 }], + }); + + // set success + await db.$unuseAll().group.update({ where: { id: 3 }, data: { private: false } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { set: [{ id: 3 }] } }, + include: { groups: true }, + }), + ).resolves.toMatchObject({ + groups: [{ id: 3 }], + }); + }); }); describe('Upsert tests', () => { diff --git a/packages/runtime/test/policy/migrated/relation-many-to-many-filter.test.ts b/packages/runtime/test/policy/migrated/relation-many-to-many-filter.test.ts new file mode 100644 index 00000000..916f8c50 --- /dev/null +++ b/packages/runtime/test/policy/migrated/relation-many-to-many-filter.test.ts @@ -0,0 +1,280 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy many-to-many relation tests', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m2 M2[] + + @@allow('read', !deleted) + @@allow('create,update', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m1 M1[] + + @@allow('read', !deleted) + @@allow('create,update', true) + } + `; + + it('some filter', async () => { + const db = await createPolicyTestClient(model, { usePrismaPush: true }); + + await db.m1.create({ + data: { + id: '1', + value: 1, + m2: { + create: [ + { + id: '1', + value: 1, + }, + { + id: '2', + value: 2, + deleted: true, + }, + ], + }, + }, + }); + + // m1 -> m2 lookup + const r = await db.m1.findFirst({ + where: { + id: '1', + m2: { + some: {}, + }, + }, + include: { + _count: { select: { m2: true } }, + }, + }); + expect(r._count.m2).toBe(1); + + // m2 -> m1 lookup + await expect( + db.m2.findFirst({ + where: { + id: '1', + m1: { + some: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + value: 1, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + some: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('none filter', async () => { + const db = await createPolicyTestClient(model, { usePrismaPush: true }); + + await db.m1.create({ + data: { + id: '1', + value: 1, + m2: { + create: [ + { id: '1', value: 1 }, + { id: '2', value: 2, deleted: true }, + ], + }, + }, + }); + + // m1 -> m2 lookup + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: {}, + }, + }, + }), + ).toResolveFalsy(); + + // m2 -> m1 lookup + await expect( + db.m2.findFirst({ + where: { + m1: { + none: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + value: 2, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + none: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('every filter', async () => { + const db = await createPolicyTestClient(model, { usePrismaPush: true }); + + await db.m1.create({ + data: { + id: '1', + value: 1, + m2: { + create: [ + { id: '1', value: 1 }, + { id: '2', value: 2, deleted: true }, + ], + }, + }, + }); + + // m1 -> m2 lookup + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + // m2 -> m1 lookup + await expect( + db.m2.findFirst({ + where: { + id: '1', + m1: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + value: 2, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + }); +}); diff --git a/packages/runtime/test/policy/migrated/relation-one-to-many-filter.test.ts b/packages/runtime/test/policy/migrated/relation-one-to-many-filter.test.ts new file mode 100644 index 00000000..4330c008 --- /dev/null +++ b/packages/runtime/test/policy/migrated/relation-one-to-many-filter.test.ts @@ -0,0 +1,1009 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Relation one-to-many filter', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + m3 M3[] + + @@allow('read', !deleted) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String + + @@allow('read', !deleted) + @@allow('create', true) + } + `; + + it('some filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + { + value: 2, + deleted: true, + m3: { + create: { + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + some: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(1); + + const r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + some: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(0); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + some: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('none filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + { + value: 2, + deleted: true, + m3: { + create: { + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + none: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(0); + + const r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + none: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(1); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + none: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('every filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + { + value: 2, + deleted: true, + m3: { + create: { + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + every: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(1); + + const r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + every: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(0); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('_count filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + { + value: 2, + deleted: true, + m3: { + create: { + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect(db.m1.findFirst({ include: { _count: true } })).resolves.toMatchObject({ _count: { m2: 1 } }); + await expect(db.m1.findFirst({ include: { _count: { select: { m2: true } } } })).resolves.toMatchObject({ + _count: { m2: 1 }, + }); + await expect( + db.m1.findFirst({ include: { _count: { select: { m2: { where: { value: { gt: 0 } } } } } } }), + ).resolves.toMatchObject({ _count: { m2: 1 } }); + await expect( + db.m1.findFirst({ include: { _count: { select: { m2: { where: { value: { gt: 1 } } } } } } }), + ).resolves.toMatchObject({ _count: { m2: 0 } }); + + await expect(db.m1.findFirst({ include: { m2: { select: { _count: true } } } })).resolves.toMatchObject({ + m2: [{ _count: { m3: 1 } }], + }); + await expect( + db.m1.findFirst({ include: { m2: { select: { _count: { select: { m3: true } } } } } }), + ).resolves.toMatchObject({ m2: [{ _count: { m3: 1 } }] }); + await expect( + db.m1.findFirst({ + include: { m2: { select: { _count: { select: { m3: { where: { value: { gt: 1 } } } } } } } }, + }), + ).resolves.toMatchObject({ m2: [{ _count: { m3: 0 } }] }); + }); +}); + +// TODO: field-level policy support +describe.skip('Relation one-to-many filter with field-level rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + m3 M3[] + + @@allow('read', true) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) + value Int @deny('read', deleted) + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String + + @@allow('read', true) + @@allow('create', true) + } + `; + + it('some filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + m3: { + create: { + id: '3-1', + value: 1, + }, + }, + }, + { + id: '2-2', + value: 2, + deleted: true, + m3: { + create: { + id: '3-2', + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { id: '2-2' }, + }, + }, + }), + ).toResolveTruthy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + some: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(2); + + let r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + some: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(0); + + r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + some: { id: { equals: '3-2' } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(1); + }); + + it('none filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + m3: { + create: { + id: '3-1', + value: 1, + }, + }, + }, + { + id: '2-2', + value: 2, + deleted: true, + m3: { + create: { + id: '3-2', + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { id: '2-1' }, + }, + }, + }), + ).toResolveFalsy(); + + // include clause + + let r = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + none: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(2); + + r = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + none: { id: { equals: '3-2' } }, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(1); + }); + + it('every filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + m3: { + create: { + id: '3-1', + value: 1, + }, + }, + }, + { + id: '2-2', + value: 2, + deleted: true, + m3: { + create: { + id: '3-2', + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { id: { contains: '2' } }, + }, + }, + }), + ).toResolveTruthy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + every: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(2); + + let r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + every: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(1); + + r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + every: { id: { contains: '3' } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(2); + }); +}); + +// TODO: field-level policy support +describe.skip('Relation one-to-many filter with field-level override rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) @allow('read', true, true) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('read', !deleted) + @@allow('create', true) + } + `; + + it('some filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + }, + { + id: '2-2', + value: 2, + deleted: true, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { id: '2-2' }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('none filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + }, + { + id: '2-2', + value: 2, + deleted: true, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { id: '2-1' }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('every filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + }, + { + id: '2-2', + value: 2, + deleted: true, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { id: { contains: '2' } }, + }, + }, + }), + ).toResolveTruthy(); + }); +}); diff --git a/packages/runtime/test/policy/migrated/relation-one-to-one-filter.test.ts b/packages/runtime/test/policy/migrated/relation-one-to-one-filter.test.ts new file mode 100644 index 00000000..060eea77 --- /dev/null +++ b/packages/runtime/test/policy/migrated/relation-one-to-one-filter.test.ts @@ -0,0 +1,1096 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Relation one-to-one filter', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + m3 M3? + + @@allow('read', !deleted) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String @unique + + @@allow('read', !deleted) + @@allow('create', true) + } + `; + + it('is filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { value: 1 }, + }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with null m2 + await db.m1.create({ + data: { + id: '4', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '4', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('isNot filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + isNot: { + m3: { + isNot: { value: 0 }, + }, + }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + isNot: { + m3: { + isNot: { value: 1 }, + }, + }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with null m2 + await db.m1.create({ + data: { + id: '4', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '4', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('direct object filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + value: 1, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + value: 1, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with null m2 + await db.m1.create({ + data: { + id: '4', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '4', + m2: { + value: 1, + }, + }, + }), + ).toResolveFalsy(); + }); +}); + +// TODO: field-level policy support +describe.skip('Relation one-to-one filter with field-level rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + m3 M3? + + @@allow('read', true) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String @unique + + @@allow('read', true) + @@allow('create', true) + } + `; + + it('is filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { id: '2' }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { value: 1 }, + }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { id: '3' }, + }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('isNot filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { id: '2' }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('direct object filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + value: 1, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + value: 1, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + id: '2', + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { id: '3' }, + }, + }, + }), + ).toResolveTruthy(); + }); +}); + +// TODO: field-level policy support +describe.skip('Relation one-to-one filter with field-level override rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) @allow('read', true, true) + value Int + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + m3 M3? + + @@allow('read', !deleted) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) @allow('read', true, true) + value Int + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String @unique + + @@allow('read', !deleted) + @@allow('create', true) + } + `; + + it('is filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { id: '2' }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { value: 1 }, + }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { id: '3' }, + }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('isNot filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { id: '2' }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('direct object filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + value: 1, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + value: 1, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + id: '2', + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { id: '3' }, + }, + }, + }), + ).toResolveTruthy(); + }); +}); diff --git a/packages/runtime/test/policy/migrated/self-relation.test.ts b/packages/runtime/test/policy/migrated/self-relation.test.ts new file mode 100644 index 00000000..f06c34d2 --- /dev/null +++ b/packages/runtime/test/policy/migrated/self-relation.test.ts @@ -0,0 +1,201 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy self relations tests', () => { + it('one-to-one', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + value Int + successorId Int? @unique + successor User? @relation("BlogOwnerHistory", fields: [successorId], references: [id]) + predecessor User? @relation("BlogOwnerHistory") + + @@allow('create,update', value > 0) + @@allow('read', true) + } + `, + { usePrismaPush: true }, + ); + + // create denied + await expect( + db.user.create({ + data: { + value: 0, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + successor: { + create: { + value: 0, + }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + successor: { + create: { + value: 1, + }, + }, + predecessor: { + create: { + value: 0, + }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + successor: { + create: { + value: 1, + }, + }, + predecessor: { + create: { + value: 1, + }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('one-to-many', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + value Int + teacherId Int? + teacher User? @relation("TeacherStudents", fields: [teacherId], references: [id]) + students User[] @relation("TeacherStudents") + + @@allow('create,update', value > 0) + @@allow('read', true) + } + `, + { usePrismaPush: true }, + ); + + // create denied + await expect( + db.user.create({ + data: { + value: 0, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + teacher: { + create: { value: 0 }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + teacher: { + create: { value: 1 }, + }, + students: { + create: [{ value: 0 }, { value: 1 }], + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + teacher: { + create: { value: 1 }, + }, + students: { + create: [{ value: 1 }, { value: 2 }], + }, + }, + }), + ).toResolveTruthy(); + }); + + it('many-to-many', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + value Int + followedBy User[] @relation("UserFollows") + following User[] @relation("UserFollows") + + @@allow('create,update', value > 0) + @@allow('read', true) + } + `, + { usePrismaPush: true }, + ); + + // create denied + await expect( + db.user.create({ + data: { + value: 0, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + followedBy: { create: { value: 0 } }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + followedBy: { create: { value: 1 } }, + following: { create: [{ value: 0 }, { value: 1 }] }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + followedBy: { create: { value: 1 } }, + following: { create: [{ value: 1 }, { value: 2 }] }, + }, + }), + ).toResolveTruthy(); + }); +}); diff --git a/packages/runtime/test/utils.ts b/packages/runtime/test/utils.ts index f07a2a27..279b95d8 100644 --- a/packages/runtime/test/utils.ts +++ b/packages/runtime/test/utils.ts @@ -1,7 +1,8 @@ import { invariant } from '@zenstackhq/common-helpers'; import { loadDocument } from '@zenstackhq/language'; +import type { Model } from '@zenstackhq/language/ast'; import { PrismaSchemaGenerator } from '@zenstackhq/sdk'; -import { createTestProject, generateTsSchema } from '@zenstackhq/testtools'; +import { createTestProject, generateTsSchema, getPluginModules } from '@zenstackhq/testtools'; import SQLite from 'better-sqlite3'; import { PostgresDialect, SqliteDialect, type LogEvent } from 'kysely'; import { execSync } from 'node:child_process'; @@ -98,9 +99,12 @@ export async function createTestClient( ? `file:${dbName}` : `postgres://${TEST_PG_CONFIG.user}:${TEST_PG_CONFIG.password}@${TEST_PG_CONFIG.host}:${TEST_PG_CONFIG.port}/${dbName}`; + let model: Model | undefined; + if (typeof schema === 'string') { const generated = await generateTsSchema(schema, provider, dbUrl, options?.extraSourceFiles); workDir = generated.workDir; + model = generated.model; // replace schema's provider _schema = { ...generated.schema, @@ -143,11 +147,14 @@ export async function createTestClient( if (options?.usePrismaPush) { invariant(typeof schema === 'string' || schemaFile, 'a schema file must be provided when using prisma db push'); - const r = await loadDocument(path.resolve(workDir!, 'schema.zmodel')); - if (!r.success) { - throw new Error(r.errors.join('\n')); + if (!model) { + const r = await loadDocument(path.join(workDir, 'schema.zmodel'), getPluginModules()); + if (!r.success) { + throw new Error(r.errors.join('\n')); + } + model = r.model; } - const prismaSchema = new PrismaSchemaGenerator(r.model); + const prismaSchema = new PrismaSchemaGenerator(model); const prismaSchemaText = await prismaSchema.generate(); fs.writeFileSync(path.resolve(workDir!, 'schema.prisma'), prismaSchemaText); execSync('npx prisma db push --schema ./schema.prisma --skip-generate --force-reset', { diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 788f092c..b4f5386e 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -41,7 +41,7 @@ export async function generateTsSchema( const noPrelude = schemaText.includes('datasource '); fs.writeFileSync(zmodelPath, `${noPrelude ? '' : makePrelude(provider, dbUrl)}\n\n${schemaText}`); - const pluginModelFiles = glob.sync(path.resolve(__dirname, '../../runtime/src/plugins/**/plugin.zmodel')); + const pluginModelFiles = getPluginModules(); const result = await loadDocument(zmodelPath, pluginModelFiles); if (!result.success) { throw new Error(`Failed to load schema from ${zmodelPath}: ${result.errors}`); @@ -59,7 +59,11 @@ export async function generateTsSchema( } // compile the generated TS schema - return compileAndLoad(workDir); + return { ...(await compileAndLoad(workDir)), model: result.model }; +} + +export function getPluginModules() { + return glob.sync(path.resolve(__dirname, '../../runtime/src/plugins/**/plugin.zmodel')); } async function compileAndLoad(workDir: string) {