diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 944d4a59..f9f1eb9a 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -34,6 +34,7 @@ jobs: strategy: matrix: node-version: [20.x] + provider: [sqlite, postgresql] steps: - name: Checkout @@ -76,4 +77,4 @@ jobs: run: pnpm run lint - name: Test - run: pnpm run test + run: TEST_DB_PROVIDER=${{ matrix.provider }} pnpm run test diff --git a/TODO.md b/TODO.md index cd66cb46..c92d4fc6 100644 --- a/TODO.md +++ b/TODO.md @@ -99,9 +99,8 @@ - [ ] Validation - [ ] Access Policy - [ ] Short-circuit pre-create check for scalar-field only policies - - [ ] Inject "replace into" - - [ ] Inject "on conflict do update" - - [ ] Inject "insert into select from" + - [x] Inject "on conflict do update" + - [x] `check` function - [x] Migration - [ ] Databases - [x] SQLite diff --git a/package.json b/package.json index 039fb22b..88f7b917 100644 --- a/package.json +++ b/package.json @@ -1,13 +1,13 @@ { "name": "zenstack-v3", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "description": "ZenStack", "packageManager": "pnpm@10.12.1", "scripts": { "build": "turbo run build", "watch": "turbo run watch build", "lint": "turbo run lint", - "test": "turbo run test", + "test": "vitest run", "format": "prettier --write \"**/*.{ts,tsx,md}\"", "pr": "gh pr create --fill-first --base dev", "merge-main": "gh pr create --title \"merge dev to main\" --body \"\" --base main --head dev", diff --git a/packages/cli/package.json b/packages/cli/package.json index 63c4249f..17f38a48 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack CLI", "description": "FullStack database toolkit with built-in access control and automatic API generation.", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "type": "module", "author": { "name": "ZenStack Team" diff --git a/packages/common-helpers/package.json b/packages/common-helpers/package.json index 965c3dd9..78f842e2 100644 --- a/packages/common-helpers/package.json +++ b/packages/common-helpers/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/common-helpers", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "description": "ZenStack Common Helpers", "type": "module", "scripts": { diff --git a/packages/common-helpers/src/index.ts b/packages/common-helpers/src/index.ts index 7f9c421b..5b63ae85 100644 --- a/packages/common-helpers/src/index.ts +++ b/packages/common-helpers/src/index.ts @@ -4,3 +4,4 @@ export * from './param-case'; export * from './sleep'; export * from './tiny-invariant'; export * from './upper-case-first'; +export * from './zip'; diff --git a/packages/common-helpers/src/zip.ts b/packages/common-helpers/src/zip.ts new file mode 100644 index 00000000..35d4981b --- /dev/null +++ b/packages/common-helpers/src/zip.ts @@ -0,0 +1,11 @@ +/** + * Zips two arrays into an array of tuples. + */ +export function zip(arr1: T[], arr2: U[]): Array<[T, U]> { + const length = Math.min(arr1.length, arr2.length); + const result: Array<[T, U]> = []; + for (let i = 0; i < length; i++) { + result.push([arr1[i]!, arr2[i]!]); + } + return result; +} diff --git a/packages/create-zenstack/package.json b/packages/create-zenstack/package.json index edc194bc..adcac380 100644 --- a/packages/create-zenstack/package.json +++ b/packages/create-zenstack/package.json @@ -1,6 +1,6 @@ { "name": "create-zenstack", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "description": "Create a new ZenStack project", "type": "module", "scripts": { diff --git a/packages/dialects/sql.js/package.json b/packages/dialects/sql.js/package.json index be06b085..cc43bcd2 100644 --- a/packages/dialects/sql.js/package.json +++ b/packages/dialects/sql.js/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/kysely-sql-js", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "description": "Kysely dialect for sql.js", "type": "module", "scripts": { diff --git a/packages/eslint-config/package.json b/packages/eslint-config/package.json index 690197dc..5b8ec217 100644 --- a/packages/eslint-config/package.json +++ b/packages/eslint-config/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/eslint-config", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "type": "module", "private": true, "license": "MIT" diff --git a/packages/language/package.json b/packages/language/package.json index 2e130d7f..0754fdd8 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/language", "description": "ZenStack ZModel language specification", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "license": "MIT", "author": "ZenStack Team", "files": [ diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 0d8c4264..c248bde0 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -123,8 +123,10 @@ function future(): Any { } @@@expressionContext([AccessPolicy]) /** - * If the field value contains the search string. By default, the search is case-sensitive, - * but you can override the behavior with the "caseInSensitive" argument. + * Checks if the field value contains the search string. By default, the search is case-sensitive, and + * "LIKE" operator is used to match. If `caseInSensitive` is true, "ILIKE" operator is used if + * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's + * behavior is. */ function contains(field: String, search: String, caseInSensitive: Boolean?): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) @@ -136,15 +138,21 @@ function search(field: String, search: String): Boolean { } @@@expressionContext([AccessPolicy]) /** - * If the field value starts with the search string + * Checks the field value starts with the search string. By default, the search is case-sensitive, and + * "LIKE" operator is used to match. If `caseInSensitive` is true, "ILIKE" operator is used if + * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's + * behavior is. */ -function startsWith(field: String, search: String): Boolean { +function startsWith(field: String, search: String, caseInSensitive: Boolean?): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) /** - * If the field value ends with the search string + * Checks if the field value ends with the search string. By default, the search is case-sensitive, and + * "LIKE" operator is used to match. If `caseInSensitive` is true, "ILIKE" operator is used if + * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's + * behavior is. */ -function endsWith(field: String, search: String): Boolean { +function endsWith(field: String, search: String, caseInSensitive: Boolean?): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) /** @@ -594,16 +602,6 @@ function datetime(field: String): Boolean { function url(field: String): Boolean { } @@@expressionContext([ValidationRule]) -/** - * Checks if the current user can perform the given operation on the given field. - * - * @param field: The field to check access for - * @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided, - * it defaults the operation of the containing policy rule. - */ -function check(field: Any, operation: String?): Boolean { -} @@@expressionContext([AccessPolicy]) - ////////////////////////////////////////////// // End validation attributes and functions ////////////////////////////////////////////// diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index 06677192..b4b5dd30 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -357,8 +357,9 @@ export function getFieldReference(expr: Expression): DataField | undefined { } } +// TODO: move to policy plugin export function isCheckInvocation(node: AstNode) { - return isInvocationExpr(node) && node.function.ref?.name === 'check' && isFromStdlib(node.function.ref); + return isInvocationExpr(node) && node.function.ref?.name === 'check'; } export function resolveTransitiveImports(documents: LangiumDocuments, model: Model) { diff --git a/packages/language/src/validators/expression-validator.ts b/packages/language/src/validators/expression-validator.ts index cf74db06..f8dc4930 100644 --- a/packages/language/src/validators/expression-validator.ts +++ b/packages/language/src/validators/expression-validator.ts @@ -108,9 +108,12 @@ export default class ExpressionValidator implements AstValidator { supportedShapes = ['Boolean', 'Any']; } + const leftResolvedDecl = expr.left.$resolvedType?.decl; + const rightResolvedDecl = expr.right.$resolvedType?.decl; + if ( - typeof expr.left.$resolvedType?.decl !== 'string' || - !supportedShapes.includes(expr.left.$resolvedType.decl) + leftResolvedDecl && + (typeof leftResolvedDecl !== 'string' || !supportedShapes.includes(leftResolvedDecl)) ) { accept('error', `invalid operand type for "${expr.operator}" operator`, { node: expr.left, @@ -118,8 +121,8 @@ export default class ExpressionValidator implements AstValidator { return; } if ( - typeof expr.right.$resolvedType?.decl !== 'string' || - !supportedShapes.includes(expr.right.$resolvedType.decl) + rightResolvedDecl && + (typeof rightResolvedDecl !== 'string' || !supportedShapes.includes(rightResolvedDecl)) ) { accept('error', `invalid operand type for "${expr.operator}" operator`, { node: expr.right, @@ -128,14 +131,11 @@ export default class ExpressionValidator implements AstValidator { } // DateTime comparison is only allowed between two DateTime values - if (expr.left.$resolvedType.decl === 'DateTime' && expr.right.$resolvedType.decl !== 'DateTime') { + if (leftResolvedDecl === 'DateTime' && rightResolvedDecl && rightResolvedDecl !== 'DateTime') { accept('error', 'incompatible operand types', { node: expr, }); - } else if ( - expr.right.$resolvedType.decl === 'DateTime' && - expr.left.$resolvedType.decl !== 'DateTime' - ) { + } else if (rightResolvedDecl === 'DateTime' && leftResolvedDecl && leftResolvedDecl !== 'DateTime') { accept('error', 'incompatible operand types', { node: expr, }); diff --git a/packages/language/src/validators/function-invocation-validator.ts b/packages/language/src/validators/function-invocation-validator.ts index b640ad1b..e75c8e3d 100644 --- a/packages/language/src/validators/function-invocation-validator.ts +++ b/packages/language/src/validators/function-invocation-validator.ts @@ -170,6 +170,7 @@ export default class FunctionInvocationValidator implements AstValidator { it('supports inheriting from delegate', async () => { const model = await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + model A { id Int @id @default(autoincrement()) x String @@ -24,6 +29,11 @@ describe('Delegate Tests', () => { it('rejects inheriting from non-delegate models', async () => { await loadSchemaWithError( ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + model A { id Int @id @default(autoincrement()) x String @@ -40,6 +50,11 @@ describe('Delegate Tests', () => { it('can detect cyclic inherits', async () => { await loadSchemaWithError( ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + model A extends B { x String @@delegate(x) @@ -57,6 +72,11 @@ describe('Delegate Tests', () => { it('can detect duplicated fields from base model', async () => { await loadSchemaWithError( ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + model A { id String @id x String @@ -74,6 +94,11 @@ describe('Delegate Tests', () => { it('can detect duplicated attributes from base model', async () => { await loadSchemaWithError( ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + model A { id String @id x String diff --git a/packages/language/test/import.test.ts b/packages/language/test/import.test.ts index 48cec382..98d90d22 100644 --- a/packages/language/test/import.test.ts +++ b/packages/language/test/import.test.ts @@ -12,6 +12,11 @@ describe('Import tests', () => { fs.writeFileSync( path.join(name, 'a.zmodel'), ` +datasource db { + provider = 'sqlite' + url = 'file:./dev.db' +} + model A { id Int @id name String @@ -48,6 +53,12 @@ enum Role { path.join(name, 'b.zmodel'), ` import './a' + +datasource db { + provider = 'sqlite' + url = 'file:./dev.db' +} + model User { id Int @id role Role @@ -56,7 +67,7 @@ model User { ); const model = await expectLoaded(path.join(name, 'b.zmodel')); - expect((model.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('Role'); + expect((model.declarations[1] as DataModel).fields[1].type.reference?.ref?.name).toBe('Role'); }); it('supports cyclic imports', async () => { @@ -65,6 +76,12 @@ model User { path.join(name, 'a.zmodel'), ` import './b' + +datasource db { + provider = 'sqlite' + url = 'file:./dev.db' +} + model A { id Int @id b B? @@ -86,7 +103,7 @@ model B { const modelB = await expectLoaded(path.join(name, 'b.zmodel')); expect((modelB.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('A'); const modelA = await expectLoaded(path.join(name, 'a.zmodel')); - expect((modelA.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('B'); + expect((modelA.declarations[1] as DataModel).fields[1].type.reference?.ref?.name).toBe('B'); }); async function expectLoaded(file: string) { diff --git a/packages/language/test/mixin.test.ts b/packages/language/test/mixin.test.ts index 8e7bcd0a..8fa9e933 100644 --- a/packages/language/test/mixin.test.ts +++ b/packages/language/test/mixin.test.ts @@ -5,6 +5,11 @@ import { DataModel, TypeDef } from '../src/ast'; describe('Mixin Tests', () => { it('supports model mixing types to Model', async () => { const model = await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + type A { x String } @@ -25,6 +30,11 @@ describe('Mixin Tests', () => { it('supports model mixing types to type', async () => { const model = await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + type A { x String } @@ -52,6 +62,11 @@ describe('Mixin Tests', () => { it('can detect cyclic mixins', async () => { await loadSchemaWithError( ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + type A with B { x String } diff --git a/packages/runtime/package.json b/packages/runtime/package.json index da696705..561eb33f 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/runtime", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "description": "ZenStack Runtime", "type": "module", "scripts": { @@ -8,6 +8,8 @@ "watch": "tsup-node --watch", "lint": "eslint src --ext ts", "test": "vitest run && pnpm test:typecheck", + "test:sqlite": "TEST_DB_PROVIDER=sqlite vitest run", + "test:postgresql": "TEST_DB_PROVIDER=postgresql vitest run", "test:generate": "tsx test/scripts/generate.ts", "test:typecheck": "tsc --project tsconfig.test.json", "pack": "pnpm pack" @@ -73,7 +75,8 @@ "toposort": "^2.0.2", "ts-pattern": "catalog:", "ulid": "^3.0.0", - "uuid": "^11.0.5" + "uuid": "^11.0.5", + "zod-validation-error": "catalog:" }, "peerDependencies": { "better-sqlite3": "^12.2.0", diff --git a/packages/runtime/src/client/client-impl.ts b/packages/runtime/src/client/client-impl.ts index c762700f..c0585455 100644 --- a/packages/runtime/src/client/client-impl.ts +++ b/packages/runtime/src/client/client-impl.ts @@ -61,7 +61,7 @@ export class ClientImpl { executor?: QueryExecutor, ) { this.$schema = schema; - this.$options = options ?? ({} as ClientOptions); + this.$options = options; this.$options.functions = { ...BuiltinFunctions, @@ -326,7 +326,7 @@ export class ClientImpl { function createClientProxy(client: ClientImpl): ClientImpl { const inputValidator = new InputValidator(client.$schema); - const resultProcessor = new ResultProcessor(client.$schema); + const resultProcessor = new ResultProcessor(client.$schema, client.$options); return new Proxy(client, { get: (target, prop, receiver) => { diff --git a/packages/runtime/src/client/contract.ts b/packages/runtime/src/client/contract.ts index 94247354..0d38e34e 100644 --- a/packages/runtime/src/client/contract.ts +++ b/packages/runtime/src/client/contract.ts @@ -213,6 +213,11 @@ export interface ClientConstructor { */ export type CRUD = 'create' | 'read' | 'update' | 'delete'; +/** + * CRUD operations. + */ +export const CRUD = ['create', 'read', 'update', 'delete'] as const; + //#region Model operations export type AllModelOperations> = { diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index 12c2a64e..fd9714ed 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -452,7 +452,7 @@ export type OmitInput> export type SelectIncludeOmit, AllowCount extends boolean> = { select?: SelectInput; - include?: IncludeInput; + include?: IncludeInput; omit?: OmitInput; }; @@ -463,14 +463,7 @@ export type SelectInput< AllowRelation extends boolean = true, > = { [Key in NonRelationFields]?: boolean; -} & (AllowRelation extends true ? IncludeInput : {}) & // relation fields - // relation count - (AllowCount extends true - ? // _count is only allowed if the model has to-many relations - HasToManyRelations extends true - ? { _count?: SelectCount } - : {} - : {}); +} & (AllowRelation extends true ? IncludeInput : {}); type SelectCount> = | boolean @@ -484,7 +477,11 @@ type SelectCount> = }; }; -export type IncludeInput> = { +export type IncludeInput< + Schema extends SchemaDef, + Model extends GetModels, + AllowCount extends boolean = true, +> = { [Key in RelationFields]?: | boolean | FindArgs< @@ -498,7 +495,12 @@ export type IncludeInput; -}; +} & (AllowCount extends true + ? // _count is only allowed if the model has to-many relations + HasToManyRelations extends true + ? { _count?: SelectCount } + : {} + : {}); export type Subset = { [key in keyof T]: key extends keyof U ? T[key] : never; @@ -674,7 +676,7 @@ export type FindUniqueArgs> = { data: CreateInput; - select?: SelectInput; + select?: SelectInput; include?: IncludeInput; omit?: OmitInput; }; @@ -813,7 +815,7 @@ type NestedCreateManyInput< export type UpdateArgs> = { data: UpdateInput; where: WhereUniqueInput; - select?: SelectInput; + select?: SelectInput; include?: IncludeInput; omit?: OmitInput; }; @@ -841,7 +843,7 @@ export type UpsertArgs create: CreateInput; update: UpdateInput; where: WhereUniqueInput; - select?: SelectInput; + select?: SelectInput; include?: IncludeInput; omit?: OmitInput; }; @@ -958,7 +960,7 @@ type ToOneRelationUpdateInput< export type DeleteArgs> = { where: WhereUniqueInput; - select?: SelectInput; + select?: SelectInput; include?: IncludeInput; omit?: OmitInput; }; diff --git a/packages/runtime/src/client/crud/dialects/base-dialect.ts b/packages/runtime/src/client/crud/dialects/base-dialect.ts index 9f314bf9..7357c8f5 100644 --- a/packages/runtime/src/client/crud/dialects/base-dialect.ts +++ b/packages/runtime/src/client/crud/dialects/base-dialect.ts @@ -24,7 +24,6 @@ import { ensureArray, flattenCompoundUniqueFilters, getDelegateDescendantModels, - getIdFields, getManyToManyRelation, getRelationForeignKeyFieldPairs, isEnum, @@ -32,6 +31,7 @@ import { isRelationField, makeDefaultOrderBy, requireField, + requireIdFields, requireModel, } from '../../query-utils'; @@ -45,6 +45,10 @@ export abstract class BaseCrudDialect { return value; } + transformOutput(value: unknown, _type: BuiltinType) { + return value; + } + // #region common query builders buildSelectModel(eb: ExpressionBuilder, model: string, modelAlias: string) { @@ -366,10 +370,14 @@ export abstract class BaseCrudDialect { const m2m = getManyToManyRelation(this.schema, model, field); if (m2m) { // many-to-many relation - const modelIdField = getIdFields(this.schema, model)[0]!; - const relationIdField = getIdFields(this.schema, relationModel)[0]!; + + const modelIdFields = requireIdFields(this.schema, model); + invariant(modelIdFields.length === 1, 'many-to-many relation must have exactly one id field'); + const relationIdFields = requireIdFields(this.schema, relationModel); + invariant(relationIdFields.length === 1, 'many-to-many relation must have exactly one id field'); + return eb( - sql.ref(`${relationFilterSelectAlias}.${relationIdField}`), + sql.ref(`${relationFilterSelectAlias}.${relationIdFields[0]}`), 'in', eb .selectFrom(m2m.joinTable) @@ -377,7 +385,7 @@ export abstract class BaseCrudDialect { .whereRef( sql.ref(`${m2m.joinTable}.${m2m.parentFkName}`), '=', - sql.ref(`${modelAlias}.${modelIdField}`), + sql.ref(`${modelAlias}.${modelIdFields[0]}`), ), ); } else { @@ -1012,7 +1020,7 @@ export abstract class BaseCrudDialect { otherModelAlias: string, query: SelectQueryBuilder, ) { - const idFields = getIdFields(this.schema, thisModel); + const idFields = requireIdFields(this.schema, thisModel); query = query.leftJoin(otherModelAlias, (qb) => { for (const idField of idFields) { qb = qb.onRef(`${thisModelAlias}.${idField}`, '=', `${otherModelAlias}.${idField}`); @@ -1044,14 +1052,29 @@ export abstract class BaseCrudDialect { for (const [field, value] of Object.entries(selections.select)) { const fieldDef = requireField(this.schema, model, field); const fieldModel = fieldDef.type; - const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel); - - // build a nested query to count the number of records in the relation - let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`)); + let fieldCountQuery: SelectQueryBuilder; // join conditions - for (const [left, right] of joinPairs) { - fieldCountQuery = fieldCountQuery.whereRef(left, '=', right); + const m2m = getManyToManyRelation(this.schema, model, field); + if (m2m) { + // many-to-many relation, count the join table + fieldCountQuery = eb + .selectFrom(fieldModel) + .innerJoin(m2m.joinTable, (join) => + join + .onRef(`${m2m.joinTable}.${m2m.otherFkName}`, '=', `${fieldModel}.${m2m.otherPKName}`) + .onRef(`${m2m.joinTable}.${m2m.parentFkName}`, '=', `${parentAlias}.${m2m.parentPKName}`), + ) + .select(eb.fn.countAll().as(`_count$${field}`)); + } else { + // build a nested query to count the number of records in the relation + fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`)); + + // join conditions + const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel); + for (const [left, right] of joinPairs) { + fieldCountQuery = fieldCountQuery.whereRef(left, '=', right); + } } // merge _count filter @@ -1236,5 +1259,15 @@ export abstract class BaseCrudDialect { */ abstract get supportInsertWithDefault(): boolean; + /** + * Gets the SQL column type for the given field definition. + */ + abstract getFieldSqlType(fieldDef: FieldDef): string; + + /* + * Gets the string casing behavior for the dialect. + */ + abstract getStringCasingBehavior(): { supportsILike: boolean; likeCaseSensitive: boolean }; + // #endregion } diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index a71e987d..b6c40661 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -1,4 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; +import Decimal from 'decimal.js'; import { sql, type Expression, @@ -11,18 +12,24 @@ import { match } from 'ts-pattern'; import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schema'; import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; import type { FindArgs } from '../../crud-types'; +import { QueryError } from '../../errors'; +import type { ClientOptions } from '../../options'; import { buildJoinPairs, getDelegateDescendantModels, - getIdFields, getManyToManyRelation, isRelationField, requireField, + requireIdFields, requireModel, } from '../../query-utils'; import { BaseCrudDialect } from './base-dialect'; export class PostgresCrudDialect extends BaseCrudDialect { + constructor(schema: Schema, options: ClientOptions) { + super(schema, options); + } + override get provider() { return 'postgresql' as const; } @@ -44,13 +51,69 @@ export class PostgresCrudDialect extends BaseCrudDiale } else { return match(type) .with('DateTime', () => - value instanceof Date ? value : typeof value === 'string' ? new Date(value) : value, + value instanceof Date + ? value.toISOString() + : typeof value === 'string' + ? new Date(value).toISOString() + : value, ) .with('Decimal', () => (value !== null ? value.toString() : value)) .otherwise(() => value); } } + override transformOutput(value: unknown, type: BuiltinType) { + if (value === null || value === undefined) { + return value; + } + return match(type) + .with('DateTime', () => this.transformOutputDate(value)) + .with('Bytes', () => this.transformOutputBytes(value)) + .with('BigInt', () => this.transformOutputBigInt(value)) + .with('Decimal', () => this.transformDecimal(value)) + .otherwise(() => super.transformOutput(value, type)); + } + + private transformOutputBigInt(value: unknown) { + if (typeof value === 'bigint') { + return value; + } + invariant( + typeof value === 'string' || typeof value === 'number', + `Expected string or number, got ${typeof value}`, + ); + return BigInt(value); + } + + private transformDecimal(value: unknown) { + if (value instanceof Decimal) { + return value; + } + invariant( + typeof value === 'string' || typeof value === 'number' || value instanceof Decimal, + `Expected string, number or Decimal, got ${typeof value}`, + ); + return new Decimal(value); + } + + private transformOutputDate(value: unknown) { + if (typeof value === 'string') { + return new Date(value); + } else if (value instanceof Date && this.options.fixPostgresTimezone !== false) { + // SPECIAL NOTES: + // node-pg has a terrible quirk that it returns the date value in local timezone + // as a `Date` object although for `DateTime` field the data in DB is stored in UTC + // see: https://github.com/brianc/node-postgres/issues/429 + return new Date(value.getTime() - value.getTimezoneOffset() * 60 * 1000); + } else { + return value; + } + } + + private transformOutputBytes(value: unknown) { + return Buffer.isBuffer(value) ? Uint8Array.from(value) : value; + } + override buildRelationSelection( query: SelectQueryBuilder, model: string, @@ -157,8 +220,8 @@ export class PostgresCrudDialect extends BaseCrudDiale const m2m = getManyToManyRelation(this.schema, model, relationField); if (m2m) { // many-to-many relation - const parentIds = getIdFields(this.schema, model); - const relationIds = getIdFields(this.schema, relationModel); + const parentIds = requireIdFields(this.schema, model); + const relationIds = requireIdFields(this.schema, relationModel); invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field'); invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field'); query = query.where((eb) => @@ -370,4 +433,42 @@ export class PostgresCrudDialect extends BaseCrudDiale override get supportInsertWithDefault() { return true; } + + override getFieldSqlType(fieldDef: FieldDef) { + // TODO: respect `@db.x` attributes + if (fieldDef.relation) { + throw new QueryError('Cannot get SQL type of a relation field'); + } + + let result: string; + + if (this.schema.enums?.[fieldDef.type]) { + // enums are treated as text + result = 'text'; + } else { + result = match(fieldDef.type) + .with('String', () => 'text') + .with('Boolean', () => 'boolean') + .with('Int', () => 'integer') + .with('BigInt', () => 'bigint') + .with('Float', () => 'double precision') + .with('Decimal', () => 'decimal') + .with('DateTime', () => 'timestamp') + .with('Bytes', () => 'bytea') + .with('Json', () => 'jsonb') + // fallback to text + .otherwise(() => 'text'); + } + + if (fieldDef.array) { + result += '[]'; + } + + return result; + } + + override getStringCasingBehavior() { + // Postgres `LIKE` is case-sensitive, `ILIKE` is case-insensitive + return { supportsILike: true, likeCaseSensitive: true }; + } } diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 69de608d..5c024dfb 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -1,5 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; -import type Decimal from 'decimal.js'; +import Decimal from 'decimal.js'; import { ExpressionWrapper, sql, @@ -9,15 +9,16 @@ import { type SelectQueryBuilder, } from 'kysely'; import { match } from 'ts-pattern'; -import type { BuiltinType, GetModels, SchemaDef } from '../../../schema'; +import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schema'; import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; import type { FindArgs } from '../../crud-types'; +import { QueryError } from '../../errors'; import { getDelegateDescendantModels, - getIdFields, getManyToManyRelation, getRelationForeignKeyFieldPairs, requireField, + requireIdFields, requireModel, } from '../../query-utils'; import { BaseCrudDialect } from './base-dialect'; @@ -41,7 +42,13 @@ export class SqliteCrudDialect extends BaseCrudDialect } else { return match(type) .with('Boolean', () => (value ? 1 : 0)) - .with('DateTime', () => (value instanceof Date ? value.toISOString() : value)) + .with('DateTime', () => + value instanceof Date + ? value.toISOString() + : typeof value === 'string' + ? new Date(value).toISOString() + : value, + ) .with('Decimal', () => (value as Decimal).toString()) .with('Bytes', () => Buffer.from(value as Uint8Array)) .with('Json', () => JSON.stringify(value)) @@ -50,6 +57,76 @@ export class SqliteCrudDialect extends BaseCrudDialect } } + override transformOutput(value: unknown, type: BuiltinType) { + if (value === null || value === undefined) { + return value; + } else if (this.schema.typeDefs && type in this.schema.typeDefs) { + // typed JSON field + return this.transformOutputJson(value); + } else { + return match(type) + .with('Boolean', () => this.transformOutputBoolean(value)) + .with('DateTime', () => this.transformOutputDate(value)) + .with('Bytes', () => this.transformOutputBytes(value)) + .with('Decimal', () => this.transformOutputDecimal(value)) + .with('BigInt', () => this.transformOutputBigInt(value)) + .with('Json', () => this.transformOutputJson(value)) + .otherwise(() => super.transformOutput(value, type)); + } + } + + private transformOutputDecimal(value: unknown) { + if (value instanceof Decimal) { + return value; + } + invariant( + typeof value === 'string' || typeof value === 'number' || value instanceof Decimal, + `Expected string, number or Decimal, got ${typeof value}`, + ); + return new Decimal(value); + } + + private transformOutputBigInt(value: unknown) { + if (typeof value === 'bigint') { + return value; + } + invariant( + typeof value === 'string' || typeof value === 'number', + `Expected string or number, got ${typeof value}`, + ); + return BigInt(value); + } + + private transformOutputBoolean(value: unknown) { + return !!value; + } + + private transformOutputDate(value: unknown) { + if (typeof value === 'number') { + return new Date(value); + } else if (typeof value === 'string') { + return new Date(value); + } else { + return value; + } + } + + private transformOutputBytes(value: unknown) { + return Buffer.isBuffer(value) ? Uint8Array.from(value) : value; + } + + private transformOutputJson(value: unknown) { + // better-sqlite3 typically returns JSON as string; be tolerant + if (typeof value === 'string') { + try { + return JSON.parse(value); + } catch (e) { + throw new QueryError('Invalid JSON returned', e); + } + } + return value; + } + override buildRelationSelection( query: SelectQueryBuilder, model: string, @@ -213,8 +290,8 @@ export class SqliteCrudDialect extends BaseCrudDialect const m2m = getManyToManyRelation(this.schema, model, relationField); if (m2m) { // many-to-many relation - const parentIds = getIdFields(this.schema, model); - const relationIds = getIdFields(this.schema, relationModel); + const parentIds = requireIdFields(this.schema, model); + const relationIds = requireIdFields(this.schema, relationModel); invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field'); invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field'); selectModelQuery = selectModelQuery.where((eb) => @@ -301,4 +378,39 @@ export class SqliteCrudDialect extends BaseCrudDialect override get supportInsertWithDefault() { return false; } + + override getFieldSqlType(fieldDef: FieldDef) { + // TODO: respect `@db.x` attributes + if (fieldDef.relation) { + throw new QueryError('Cannot get SQL type of a relation field'); + } + if (fieldDef.array) { + throw new QueryError('SQLite does not support scalar list type'); + } + + if (this.schema.enums?.[fieldDef.type]) { + // enums are stored as text + return 'text'; + } + + return ( + match(fieldDef.type) + .with('String', () => 'text') + .with('Boolean', () => 'integer') + .with('Int', () => 'integer') + .with('BigInt', () => 'integer') + .with('Float', () => 'real') + .with('Decimal', () => 'decimal') + .with('DateTime', () => 'numeric') + .with('Bytes', () => 'blob') + .with('Json', () => 'jsonb') + // fallback to text + .otherwise(() => 'text') + ); + } + + override getStringCasingBehavior() { + // SQLite `LIKE` is case-insensitive, and there is no `ILIKE` + return { supportsILike: false, likeCaseSensitive: false }; + } } diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 65d0d32b..170b8d89 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -31,7 +31,6 @@ import { flattenCompoundUniqueFilters, getDiscriminatorField, getField, - getIdFields, getIdValues, getManyToManyRelation, getModel, @@ -40,6 +39,7 @@ import { isRelationField, isScalarField, requireField, + requireIdFields, requireModel, } from '../../query-utils'; import { getCrudDialect } from '../dialects'; @@ -132,7 +132,7 @@ export abstract class BaseOperationHandler { model: GetModels, filter: any, ): Promise { - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); const _filter = flattenCompoundUniqueFilters(this.schema, model, filter); const query = kysely .selectFrom(model) @@ -279,7 +279,8 @@ export abstract class BaseOperationHandler { if (!ownedByModel) { // assign fks from parent - const parentFkFields = this.buildFkAssignments( + const parentFkFields = await this.buildFkAssignments( + kysely, fromRelation.model, fromRelation.field, fromRelation.ids, @@ -344,7 +345,7 @@ export abstract class BaseOperationHandler { } const updatedData = this.fillGeneratedAndDefaultValues(modelDef, createFields); - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); const query = kysely .insertInto(model) .$if(Object.keys(updatedData).length === 0, (qb) => qb.defaultValues()) @@ -359,22 +360,11 @@ export abstract class BaseOperationHandler { const createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create'); - // let createdEntity: any; - // try { - // createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create'); - // } catch (err) { - // const { sql, parameters } = query.compile(); - // throw new QueryError(`Error during create: ${err}, sql: ${sql}, parameters: ${parameters}`); - // } - if (Object.keys(postCreateRelations).length > 0) { // process nested creates that need to happen after the current entity is created - const relationPromises = Object.entries(postCreateRelations).map(([field, subPayload]) => { - return this.processNoneOwnedRelationForCreate(kysely, model, field, subPayload, createdEntity); - }); - - // await relation creation - await Promise.all(relationPromises); + for (const [field, subPayload] of Object.entries(postCreateRelations)) { + await this.processNoneOwnedRelationForCreate(kysely, model, field, subPayload, createdEntity); + } } if (fromRelation && m2m) { @@ -433,7 +423,12 @@ export abstract class BaseOperationHandler { return { baseEntity, remainingFields }; } - private buildFkAssignments(model: string, relationField: string, entity: any) { + private async buildFkAssignments( + kysely: ToKysely, + model: GetModels, + relationField: string, + entity: any, + ) { const parentFkFields: any = {}; invariant(relationField, 'parentField must be defined if parentModel is defined'); @@ -443,7 +438,18 @@ export abstract class BaseOperationHandler { for (const pair of keyPairs) { if (!(pair.pk in entity)) { - throw new QueryError(`Field "${pair.pk}" not found in parent created data`); + // the relation may be using a non-id field as fk, so we read in-place + // to fetch that field + const extraRead = await this.readUnique(kysely, model, { + where: entity, + select: { [pair.pk]: true }, + } as any); + if (!extraRead) { + throw new QueryError(`Field "${pair.pk}" not found in parent created data`); + } else { + // update the parent entity + Object.assign(entity, extraRead); + } } Object.assign(parentFkFields, { [pair.fk]: (entity as any)[pair.pk], @@ -475,14 +481,14 @@ export abstract class BaseOperationHandler { entity: rightEntity, }, ].sort((a, b) => - // the implement m2m join table's "A", "B" fk fields' order is determined + // the implicit m2m join table's "A", "B" fk fields' order is determined // by model name's sort order, and when identical (for self-relations), // field name's sort order a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field), ); - const firstIds = getIdFields(this.schema, sortedRecords[0]!.model); - const secondIds = getIdFields(this.schema, sortedRecords[1]!.model); + const firstIds = requireIdFields(this.schema, sortedRecords[0]!.model); + const secondIds = requireIdFields(this.schema, sortedRecords[1]!.model); invariant(firstIds.length === 1, 'many-to-many relation must have exactly one id field'); invariant(secondIds.length === 1, 'many-to-many relation must have exactly one id field'); @@ -588,7 +594,7 @@ export abstract class BaseOperationHandler { return result; } - private processNoneOwnedRelationForCreate( + private async processNoneOwnedRelationForCreate( kysely: ToKysely, contextModel: GetModels, relationFieldName: string, @@ -597,7 +603,6 @@ export abstract class BaseOperationHandler { ) { const relationFieldDef = this.requireField(contextModel, relationFieldName); const relationModel = relationFieldDef.type as GetModels; - const tasks: Promise[] = []; const fromRelationContext: FromRelationContext = { model: contextModel, field: relationFieldName, @@ -612,43 +617,38 @@ export abstract class BaseOperationHandler { switch (action) { case 'create': { // create with a parent entity - tasks.push( - ...enumerate(subPayload).map((item) => - this.create(kysely, relationModel, item, fromRelationContext), - ), - ); + for (const item of enumerate(subPayload)) { + await this.create(kysely, relationModel, item, fromRelationContext); + } break; } case 'createMany': { invariant(relationFieldDef.array, 'relation must be an array for createMany'); - tasks.push( - this.createMany( - kysely, - relationModel, - subPayload as { data: any; skipDuplicates: boolean }, - false, - fromRelationContext, - ), + await this.createMany( + kysely, + relationModel, + subPayload as { data: any; skipDuplicates: boolean }, + false, + fromRelationContext, ); break; } case 'connect': { - tasks.push(this.connectRelation(kysely, relationModel, subPayload, fromRelationContext)); + await this.connectRelation(kysely, relationModel, subPayload, fromRelationContext); break; } case 'connectOrCreate': { - tasks.push( - ...enumerate(subPayload).map((item) => - this.exists(kysely, relationModel, item.where).then((found) => - !found - ? this.create(kysely, relationModel, item.create, fromRelationContext) - : this.connectRelation(kysely, relationModel, found, fromRelationContext), - ), - ), - ); + for (const item of enumerate(subPayload)) { + const found = await this.exists(kysely, relationModel, item.where); + if (!found) { + await this.create(kysely, relationModel, item.create, fromRelationContext); + } else { + await this.connectRelation(kysely, relationModel, found, fromRelationContext); + } + } break; } @@ -656,8 +656,6 @@ export abstract class BaseOperationHandler { throw new QueryError(`Invalid relation action: ${action}`); } } - - return Promise.all(tasks); } protected async createMany< @@ -771,7 +769,7 @@ export abstract class BaseOperationHandler { const result = await this.executeQuery(kysely, query, 'createMany'); return { count: Number(result.numAffectedRows) } as Result; } else { - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); const result = await query.returning(idFields as any).execute(); return result as Result; } @@ -1009,10 +1007,7 @@ export abstract class BaseOperationHandler { throw new QueryError(`Relation update not allowed for field "${field}"`); } if (!thisEntity) { - thisEntity = await this.readUnique(kysely, model, { - where: combinedWhere, - select: this.makeIdSelect(model), - }); + thisEntity = await this.getEntityIds(kysely, model, combinedWhere); if (!thisEntity) { if (throwIfNotFound) { throw new NotFoundError(model); @@ -1042,7 +1037,7 @@ export abstract class BaseOperationHandler { // nothing to update, return the filter so that the caller can identify the entity return combinedWhere; } else { - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); const query = kysely .updateTable(model) .where((eb) => this.dialect.buildFilter(eb, model, model, combinedWhere)) @@ -1107,7 +1102,7 @@ export abstract class BaseOperationHandler { if (!filter || typeof filter !== 'object') { return false; } - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); return idFields.length === Object.keys(filter).length && idFields.every((field) => field in filter); } @@ -1300,7 +1295,7 @@ export abstract class BaseOperationHandler { const result = await this.executeQuery(kysely, query, 'update'); return { count: Number(result.numAffectedRows) } as Result; } else { - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); const result = await query.returning(idFields as any).execute(); return result as Result; } @@ -1339,7 +1334,7 @@ export abstract class BaseOperationHandler { } private buildIdFieldRefs(kysely: ToKysely, model: GetModels) { - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); return idFields.map((f) => kysely.dynamic.ref(`${model}.${f}`)); } @@ -1352,7 +1347,6 @@ export abstract class BaseOperationHandler { args: any, throwIfNotFound: boolean, ) { - const tasks: Promise[] = []; const fieldModel = fieldDef.type as GetModels; const fromRelationContext: FromRelationContext = { model, @@ -1368,117 +1362,101 @@ export abstract class BaseOperationHandler { !Array.isArray(value) || fieldDef.array, 'relation must be an array if create is an array', ); - tasks.push( - ...enumerate(value).map((item) => this.create(kysely, fieldModel, item, fromRelationContext)), - ); + for (const item of enumerate(value)) { + await this.create(kysely, fieldModel, item, fromRelationContext); + } break; } case 'createMany': { invariant(fieldDef.array, 'relation must be an array for createMany'); - tasks.push( - this.createMany( - kysely, - fieldModel, - value as { data: any; skipDuplicates: boolean }, - false, - fromRelationContext, - ), + await this.createMany( + kysely, + fieldModel, + value as { data: any; skipDuplicates: boolean }, + false, + fromRelationContext, ); break; } case 'connect': { - tasks.push(this.connectRelation(kysely, fieldModel, value, fromRelationContext)); + await this.connectRelation(kysely, fieldModel, value, fromRelationContext); break; } case 'connectOrCreate': { - tasks.push(this.connectOrCreateRelation(kysely, fieldModel, value, fromRelationContext)); + await this.connectOrCreateRelation(kysely, fieldModel, value, fromRelationContext); break; } case 'disconnect': { - tasks.push(this.disconnectRelation(kysely, fieldModel, value, fromRelationContext)); + await this.disconnectRelation(kysely, fieldModel, value, fromRelationContext); break; } case 'set': { invariant(fieldDef.array, 'relation must be an array'); - tasks.push(this.setRelation(kysely, fieldModel, value, fromRelationContext)); + await this.setRelation(kysely, fieldModel, value, fromRelationContext); break; } case 'update': { - tasks.push( - ...(enumerate(value) as { where: any; data: any }[]).map((item) => { - let where; - let data; - if ('where' in item) { - where = item.where; - data = item.data; - } else { - where = undefined; - data = item; - } - return this.update( - kysely, - fieldModel, - where, - data, - fromRelationContext, - true, - throwIfNotFound, - ); - }), - ); + for (const _item of enumerate(value)) { + const item = _item as { where: any; data: any }; + let where; + let data; + if ('data' in item && typeof item.data === 'object') { + where = item.where; + data = item.data; + } else { + where = undefined; + data = item; + } + await this.update(kysely, fieldModel, where, data, fromRelationContext, true, throwIfNotFound); + } break; } case 'upsert': { - tasks.push( - ...( - enumerate(value) as { - where: any; - create: any; - update: any; - }[] - ).map(async (item) => { - const updated = await this.update( - kysely, - fieldModel, - item.where, - item.update, - fromRelationContext, - true, - false, - ); - if (updated) { - return updated; - } else { - return this.create(kysely, fieldModel, item.create, fromRelationContext); - } - }), - ); + for (const _item of enumerate(value)) { + const item = _item as { + where: any; + create: any; + update: any; + }; + + const updated = await this.update( + kysely, + fieldModel, + item.where, + item.update, + fromRelationContext, + true, + false, + ); + if (!updated) { + await this.create(kysely, fieldModel, item.create, fromRelationContext); + } + } break; } case 'updateMany': { - tasks.push( - ...(enumerate(value) as { where: any; data: any }[]).map((item) => - this.update(kysely, fieldModel, item.where, item.data, fromRelationContext, false, false), - ), - ); + for (const _item of enumerate(value)) { + const item = _item as { where: any; data: any }; + await this.update(kysely, fieldModel, item.where, item.data, fromRelationContext, false, false); + } break; } case 'delete': { - tasks.push(this.deleteRelation(kysely, fieldModel, value, fromRelationContext, true)); + await this.deleteRelation(kysely, fieldModel, value, fromRelationContext, true); break; } case 'deleteMany': { - tasks.push(this.deleteRelation(kysely, fieldModel, value, fromRelationContext, false)); + await this.deleteRelation(kysely, fieldModel, value, fromRelationContext, false); break; } @@ -1488,8 +1466,6 @@ export abstract class BaseOperationHandler { } } - await Promise.all(tasks); - return fromRelationContext.parentUpdates; } @@ -1509,9 +1485,13 @@ export abstract class BaseOperationHandler { const m2m = getManyToManyRelation(this.schema, fromRelation.model, fromRelation.field); if (m2m) { // handle many-to-many relation - const actions = _data.map(async (d) => { + const results: (unknown | undefined)[] = []; + for (const d of _data) { const ids = await this.getEntityIds(kysely, model, d); - return this.handleManyToManyRelation( + if (!ids) { + throw new NotFoundError(model); + } + const r = await this.handleManyToManyRelation( kysely, 'connect', fromRelation.model, @@ -1522,8 +1502,8 @@ export abstract class BaseOperationHandler { ids, m2m.joinTable, ); - }); - const results = await Promise.all(actions); + results.push(r); + } // validate connect result if (_data.length > results.filter((r) => !!r).length) { @@ -1608,16 +1588,14 @@ export abstract class BaseOperationHandler { return; } - return Promise.all( - _data.map(async ({ where, create }) => { - const existing = await this.exists(kysely, model, where); - if (existing) { - return this.connectRelation(kysely, model, [where], fromRelation); - } else { - return this.create(kysely, model, create, fromRelation); - } - }), - ); + for (const { where, create } of _data) { + const existing = await this.exists(kysely, model, where); + if (existing) { + await this.connectRelation(kysely, model, [where], fromRelation); + } else { + await this.create(kysely, model, create, fromRelation); + } + } } protected async disconnectRelation( @@ -1648,13 +1626,13 @@ export abstract class BaseOperationHandler { const m2m = getManyToManyRelation(this.schema, fromRelation.model, fromRelation.field); if (m2m) { // handle many-to-many relation - const actions = disconnectConditions.map(async (d) => { + for (const d of disconnectConditions) { const ids = await this.getEntityIds(kysely, model, d); if (!ids) { // not found return; } - return this.handleManyToManyRelation( + await this.handleManyToManyRelation( kysely, 'disconnect', fromRelation.model, @@ -1665,8 +1643,7 @@ export abstract class BaseOperationHandler { ids, m2m.joinTable, ); - }); - await Promise.all(actions); + } } else { const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( this.schema, @@ -1755,21 +1732,26 @@ export abstract class BaseOperationHandler { await this.resetManyToManyRelation(kysely, fromRelation.model, fromRelation.field, fromRelation.ids); // connect new entities - const actions = _data.map(async (d) => { + const results: (unknown | undefined)[] = []; + for (const d of _data) { const ids = await this.getEntityIds(kysely, model, d); - return this.handleManyToManyRelation( - kysely, - 'connect', - fromRelation.model, - fromRelation.field, - fromRelation.ids, - m2m.otherModel, - m2m.otherField, - ids, - m2m.joinTable, + if (!ids) { + throw new NotFoundError(model); + } + results.push( + await this.handleManyToManyRelation( + kysely, + 'connect', + fromRelation.model, + fromRelation.field, + fromRelation.ids, + m2m.otherModel, + m2m.otherField, + ids, + m2m.joinTable, + ), ); - }); - const results = await Promise.all(actions); + } // validate connect result if (_data.length > results.filter((r) => !!r).length) { @@ -2100,7 +2082,7 @@ export abstract class BaseOperationHandler { // reused the filter if it's a complete id filter (without extra fields) // otherwise, read the entity by the filter private getEntityIds(kysely: ToKysely, model: GetModels, uniqueFilter: any) { - const idFields: string[] = getIdFields(this.schema, model); + const idFields: string[] = requireIdFields(this.schema, model); if ( // all id fields are provided idFields.every((f) => f in uniqueFilter && uniqueFilter[f] !== undefined) && diff --git a/packages/runtime/src/client/crud/operations/create.ts b/packages/runtime/src/client/crud/operations/create.ts index bc15bb36..26206d99 100644 --- a/packages/runtime/src/client/crud/operations/create.ts +++ b/packages/runtime/src/client/crud/operations/create.ts @@ -1,5 +1,5 @@ import { match } from 'ts-pattern'; -import { RejectedByPolicyError } from '../../../plugins/policy/errors'; +import { RejectedByPolicyError, RejectedByPolicyReason } from '../../../plugins/policy/errors'; import type { GetModels, SchemaDef } from '../../../schema'; import type { CreateArgs, CreateManyAndReturnArgs, CreateManyArgs, WhereInput } from '../../crud-types'; import { getIdValues } from '../../query-utils'; @@ -40,7 +40,11 @@ export class CreateOperationHandler extends BaseOperat }); if (!result && this.hasPolicyEnabled) { - throw new RejectedByPolicyError(this.model, `result is not allowed to be read back`); + throw new RejectedByPolicyError( + this.model, + RejectedByPolicyReason.CANNOT_READ_BACK, + `result is not allowed to be read back`, + ); } return result; diff --git a/packages/runtime/src/client/crud/operations/delete.ts b/packages/runtime/src/client/crud/operations/delete.ts index 3ed17ce0..6eb1eca3 100644 --- a/packages/runtime/src/client/crud/operations/delete.ts +++ b/packages/runtime/src/client/crud/operations/delete.ts @@ -3,6 +3,7 @@ import type { SchemaDef } from '../../../schema'; import type { DeleteArgs, DeleteManyArgs } from '../../crud-types'; import { NotFoundError } from '../../errors'; import { BaseOperationHandler } from './base'; +import { RejectedByPolicyError, RejectedByPolicyReason } from '../../../plugins/policy'; export class DeleteOperationHandler extends BaseOperationHandler { async handle(operation: 'delete' | 'deleteMany', args: unknown | undefined) { @@ -24,9 +25,6 @@ export class DeleteOperationHandler extends BaseOperat omit: args.omit, where: args.where, }); - if (!existing) { - throw new NotFoundError(this.model); - } // TODO: avoid using transaction for simple delete await this.safeTransaction(async (tx) => { @@ -36,6 +34,14 @@ export class DeleteOperationHandler extends BaseOperat } }); + if (!existing && this.hasPolicyEnabled) { + throw new RejectedByPolicyError( + this.model, + RejectedByPolicyReason.CANNOT_READ_BACK, + 'result is not allowed to be read back', + ); + } + return existing; } diff --git a/packages/runtime/src/client/crud/operations/update.ts b/packages/runtime/src/client/crud/operations/update.ts index ea22c773..ad2fc613 100644 --- a/packages/runtime/src/client/crud/operations/update.ts +++ b/packages/runtime/src/client/crud/operations/update.ts @@ -1,5 +1,5 @@ import { match } from 'ts-pattern'; -import { RejectedByPolicyError } from '../../../plugins/policy/errors'; +import { RejectedByPolicyError, RejectedByPolicyReason } from '../../../plugins/policy/errors'; import type { GetModels, SchemaDef } from '../../../schema'; import type { UpdateArgs, UpdateManyAndReturnArgs, UpdateManyArgs, UpsertArgs, WhereInput } from '../../crud-types'; import { getIdValues } from '../../query-utils'; @@ -48,7 +48,11 @@ export class UpdateOperationHandler extends BaseOperat // update succeeded but result cannot be read back if (this.hasPolicyEnabled) { // if access policy is enabled, we assume it's due to read violation (not guaranteed though) - throw new RejectedByPolicyError(this.model, 'result is not allowed to be read back'); + throw new RejectedByPolicyError( + this.model, + RejectedByPolicyReason.CANNOT_READ_BACK, + 'result is not allowed to be read back', + ); } else { // this can happen if the entity is cascade deleted during the update, return null to // be consistent with Prisma even though it doesn't comply with the method signature @@ -71,16 +75,29 @@ export class UpdateOperationHandler extends BaseOperat return []; } - return this.safeTransaction(async (tx) => { + const { readBackResult, updateResult } = await this.safeTransaction(async (tx) => { const updateResult = await this.updateMany(tx, this.model, args.where, args.data, args.limit, true); - return this.read(tx, this.model, { + const readBackResult = await this.read(tx, this.model, { select: args.select, omit: args.omit, where: { OR: updateResult.map((item) => getIdValues(this.schema, this.model, item) as any), } as any, // TODO: fix type }); + + return { readBackResult, updateResult }; }); + + if (readBackResult.length < updateResult.length && this.hasPolicyEnabled) { + // some of the updated entities cannot be read back + throw new RejectedByPolicyError( + this.model, + RejectedByPolicyReason.CANNOT_READ_BACK, + 'result is not allowed to be read back', + ); + } + + return readBackResult; } private async runUpsert(args: UpsertArgs>) { @@ -113,7 +130,11 @@ export class UpdateOperationHandler extends BaseOperat }); if (!result && this.hasPolicyEnabled) { - throw new RejectedByPolicyError(this.model, 'result is not allowed to be read back'); + throw new RejectedByPolicyError( + this.model, + RejectedByPolicyReason.CANNOT_READ_BACK, + 'result is not allowed to be read back', + ); } return result; diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index 372129ff..beb31faf 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -2,10 +2,18 @@ import { invariant } from '@zenstackhq/common-helpers'; import Decimal from 'decimal.js'; import stableStringify from 'json-stable-stringify'; import { match, P } from 'ts-pattern'; -import { z, ZodType } from 'zod'; -import { type BuiltinType, type EnumDef, type FieldDef, type GetModels, type SchemaDef } from '../../schema'; +import { z, ZodSchema, ZodType } from 'zod'; +import { + type BuiltinType, + type EnumDef, + type FieldDef, + type GetModels, + type ModelDef, + type SchemaDef, +} from '../../schema'; import { enumerate } from '../../utils/enumerate'; import { extractFields } from '../../utils/object-utils'; +import { formatError } from '../../utils/zod-utils'; import { AGGREGATE_OPERATORS, LOGICAL_COMBINATORS, NUMERIC_FIELD_TYPES } from '../constants'; import { type AggregateArgs, @@ -185,7 +193,7 @@ export class InputValidator { } const { error } = schema.safeParse(args); if (error) { - throw new InputValidationError(`Invalid ${operation} args: ${error.message}`, error); + throw new InputValidationError(`Invalid ${operation} args: ${formatError(error)}`, error); } return args as T; } @@ -594,10 +602,18 @@ export class InputValidator { } } - const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array); + const _countSchema = this.makeCountSelectionSchema(modelDef); + if (_countSchema) { + fields['_count'] = _countSchema; + } + return z.strictObject(fields); + } + + private makeCountSelectionSchema(modelDef: ModelDef) { + const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array); if (toManyRelations.length > 0) { - fields['_count'] = z + return z .union([ z.literal(true), z.strictObject({ @@ -620,9 +636,9 @@ export class InputValidator { }), ]) .optional(); + } else { + return undefined; } - - return z.strictObject(fields); } private makeRelationSelectIncludeSchema(fieldDef: FieldDef) { @@ -676,6 +692,11 @@ export class InputValidator { } } + const _countSchema = this.makeCountSelectionSchema(modelDef); + if (_countSchema) { + fields['_count'] = _countSchema; + } + return z.strictObject(fields); } @@ -743,13 +764,15 @@ export class InputValidator { private makeCreateSchema(model: string) { const dataSchema = this.makeCreateDataSchema(model, false); - const schema = z.strictObject({ + let schema: ZodSchema = z.strictObject({ data: dataSchema, select: this.makeSelectSchema(model).optional(), include: this.makeIncludeSchema(model).optional(), omit: this.makeOmitSchema(model).optional(), }); - return this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectOmitMutuallyExclusive(schema); + return schema; } private makeCreateManySchema(model: string) { @@ -913,7 +936,7 @@ export class InputValidator { fields['update'] = array ? this.orArray( z.strictObject({ - where: this.makeWhereSchema(fieldType, true), + where: this.makeWhereSchema(fieldType, true).optional(), data: this.makeUpdateDataSchema(fieldType, withoutFields), }), true, @@ -921,7 +944,7 @@ export class InputValidator { : z .union([ z.strictObject({ - where: this.makeWhereSchema(fieldType, true), + where: this.makeWhereSchema(fieldType, true).optional(), data: this.makeUpdateDataSchema(fieldType, withoutFields), }), this.makeUpdateDataSchema(fieldType, withoutFields), @@ -1005,14 +1028,16 @@ export class InputValidator { // #region Update private makeUpdateSchema(model: string) { - const schema = z.strictObject({ + let schema: ZodSchema = z.strictObject({ where: this.makeWhereSchema(model, true), data: this.makeUpdateDataSchema(model), select: this.makeSelectSchema(model).optional(), include: this.makeIncludeSchema(model).optional(), omit: this.makeOmitSchema(model).optional(), }); - return this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectOmitMutuallyExclusive(schema); + return schema; } private makeUpdateManySchema(model: string) { @@ -1025,15 +1050,16 @@ export class InputValidator { private makeUpdateManyAndReturnSchema(model: string) { const base = this.makeUpdateManySchema(model); - const result = base.extend({ + let schema: ZodSchema = base.extend({ select: this.makeSelectSchema(model).optional(), omit: this.makeOmitSchema(model).optional(), }); - return this.refineForSelectOmitMutuallyExclusive(result); + schema = this.refineForSelectOmitMutuallyExclusive(schema); + return schema; } private makeUpsertSchema(model: string) { - const schema = z.strictObject({ + let schema: ZodSchema = z.strictObject({ where: this.makeWhereSchema(model, true), create: this.makeCreateDataSchema(model, false), update: this.makeUpdateDataSchema(model), @@ -1041,7 +1067,9 @@ export class InputValidator { include: this.makeIncludeSchema(model).optional(), omit: this.makeOmitSchema(model).optional(), }); - return this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectOmitMutuallyExclusive(schema); + return schema; } private makeUpdateDataSchema(model: string, withoutFields: string[] = [], withoutRelationFields = false) { @@ -1145,12 +1173,14 @@ export class InputValidator { // #region Delete private makeDeleteSchema(model: GetModels) { - const schema = z.strictObject({ + let schema: ZodSchema = z.strictObject({ where: this.makeWhereSchema(model, true), select: this.makeSelectSchema(model).optional(), include: this.makeIncludeSchema(model).optional(), }); - return this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectOmitMutuallyExclusive(schema); + return schema; } private makeDeleteManySchema(model: GetModels) { diff --git a/packages/runtime/src/client/executor/kysely-utils.ts b/packages/runtime/src/client/executor/kysely-utils.ts deleted file mode 100644 index fb9ec845..00000000 --- a/packages/runtime/src/client/executor/kysely-utils.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { type OperationNode, AliasNode } from 'kysely'; - -/** - * Strips alias from the node if it exists. - */ -export function stripAlias(node: OperationNode) { - if (AliasNode.is(node)) { - return { alias: node.alias, node: node.node }; - } else { - return { alias: undefined, node }; - } -} diff --git a/packages/runtime/src/client/executor/name-mapper.ts b/packages/runtime/src/client/executor/name-mapper.ts index c839bc75..83ef8a33 100644 --- a/packages/runtime/src/client/executor/name-mapper.ts +++ b/packages/runtime/src/client/executor/name-mapper.ts @@ -17,8 +17,8 @@ import { type OperationNode, } from 'kysely'; import type { FieldDef, ModelDef, SchemaDef } from '../../schema'; +import { extractFieldName, extractModelName, stripAlias } from '../kysely-utils'; import { getModel, requireModel } from '../query-utils'; -import { stripAlias } from './kysely-utils'; type Scope = { model?: string; @@ -170,7 +170,7 @@ export class QueryNameMapper extends OperationNodeTransformer { const scopes: Scope[] = node.from.froms.map((node) => { const { alias, node: innerNode } = stripAlias(node); return { - model: this.extractModelName(innerNode), + model: extractModelName(innerNode), alias, namesMapped: false, }; @@ -219,8 +219,8 @@ export class QueryNameMapper extends OperationNodeTransformer { selections.push(SelectionNode.create(transformed)); } else { // otherwise use an alias to preserve the original field name - const origFieldName = this.extractFieldName(selection.selection); - const fieldName = this.extractFieldName(transformed); + const origFieldName = extractFieldName(selection.selection); + const fieldName = extractFieldName(transformed); if (fieldName !== origFieldName) { selections.push( SelectionNode.create( @@ -425,7 +425,7 @@ export class QueryNameMapper extends OperationNodeTransformer { private processSelection(node: AliasNode | ColumnNode | ReferenceNode) { let alias: string | undefined; if (!AliasNode.is(node)) { - alias = this.extractFieldName(node); + alias = extractFieldName(node); } const result = super.transformNode(node); return this.wrapAlias(result, alias ? IdentifierNode.create(alias) : undefined); @@ -451,20 +451,5 @@ export class QueryNameMapper extends OperationNodeTransformer { }); } - private extractModelName(node: OperationNode): string | undefined { - const { node: innerNode } = stripAlias(node); - return TableNode.is(innerNode!) ? innerNode!.table.identifier.name : undefined; - } - - private extractFieldName(node: ReferenceNode | ColumnNode) { - if (ReferenceNode.is(node) && ColumnNode.is(node.column)) { - return node.column.column.name; - } else if (ColumnNode.is(node)) { - return node.column.name; - } else { - return undefined; - } - } - // #endregion } diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index be317924..2d2395cb 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -26,8 +26,8 @@ import type { GetModels, SchemaDef } from '../../schema'; import { type ClientImpl } from '../client-impl'; import { TransactionIsolationLevel, type ClientContract } from '../contract'; import { InternalError, QueryError } from '../errors'; +import { stripAlias } from '../kysely-utils'; import type { AfterEntityMutationCallback, OnKyselyQueryCallback } from '../plugin'; -import { stripAlias } from './kysely-utils'; import { QueryNameMapper } from './name-mapper'; import type { ZenStackDriver } from './zenstack-driver'; diff --git a/packages/runtime/src/client/functions.ts b/packages/runtime/src/client/functions.ts index 6b548e8d..35390916 100644 --- a/packages/runtime/src/client/functions.ts +++ b/packages/runtime/src/client/functions.ts @@ -1,49 +1,68 @@ import { invariant, lowerCaseFirst, upperCaseFirst } from '@zenstackhq/common-helpers'; -import { sql, ValueNode, type Expression, type ExpressionBuilder } from 'kysely'; +import { sql, ValueNode, type BinaryOperator, type Expression, type ExpressionBuilder } from 'kysely'; import { match } from 'ts-pattern'; import type { ZModelFunction, ZModelFunctionContext } from './options'; // TODO: migrate default value generation functions to here too -export const contains: ZModelFunction = (eb: ExpressionBuilder, args: Expression[]) => { - const [field, search, caseInsensitive = false] = args; - if (!field) { - throw new Error('"field" parameter is required'); - } - if (!search) { - throw new Error('"search" parameter is required'); - } - const searchExpr = eb.fn('CONCAT', [sql.lit('%'), search, sql.lit('%')]); - return eb(field, caseInsensitive ? 'ilike' : 'like', searchExpr); -}; +export const contains: ZModelFunction = (eb, args, context) => textMatch(eb, args, context, 'contains'); export const search: ZModelFunction = (_eb: ExpressionBuilder, _args: Expression[]) => { throw new Error(`"search" function is not implemented yet`); }; -export const startsWith: ZModelFunction = (eb: ExpressionBuilder, args: Expression[]) => { - const [field, search] = args; +export const startsWith: ZModelFunction = (eb, args, context) => textMatch(eb, args, context, 'startsWith'); + +export const endsWith: ZModelFunction = (eb, args, context) => textMatch(eb, args, context, 'endsWith'); + +const textMatch = ( + eb: ExpressionBuilder, + args: Expression[], + { dialect }: ZModelFunctionContext, + method: 'contains' | 'startsWith' | 'endsWith', +) => { + const [field, search, caseInsensitive = undefined] = args; if (!field) { throw new Error('"field" parameter is required'); } if (!search) { throw new Error('"search" parameter is required'); } - return eb(field, 'like', eb.fn('CONCAT', [search, sql.lit('%')])); -}; -export const endsWith: ZModelFunction = (eb: ExpressionBuilder, args: Expression[]) => { - const [field, search] = args; - if (!field) { - throw new Error('"field" parameter is required'); - } - if (!search) { - throw new Error('"search" parameter is required'); + const casingBehavior = dialect.getStringCasingBehavior(); + const caseInsensitiveValue = readBoolean(caseInsensitive, false); + let op: BinaryOperator; + let fieldExpr = field; + let searchExpr = search; + + if (caseInsensitiveValue) { + // case-insensitive search + if (casingBehavior.supportsILike) { + // use ILIKE if supported + op = 'ilike'; + } else { + // otherwise change both sides to lower case + op = 'like'; + if (casingBehavior.likeCaseSensitive === true) { + fieldExpr = eb.fn('LOWER', [fieldExpr]); + searchExpr = eb.fn('LOWER', [searchExpr]); + } + } + } else { + // case-sensitive search, just use LIKE and deliver whatever the database's behavior is + op = 'like'; } - return eb(field, 'like', eb.fn('CONCAT', [sql.lit('%'), search])); + + searchExpr = match(method) + .with('contains', () => eb.fn('CONCAT', [sql.lit('%'), sql`CAST(${searchExpr} as text)`, sql.lit('%')])) + .with('startsWith', () => eb.fn('CONCAT', [sql`CAST(${searchExpr} as text)`, sql.lit('%')])) + .with('endsWith', () => eb.fn('CONCAT', [sql.lit('%'), sql`CAST(${searchExpr} as text)`])) + .exhaustive(); + + return eb(fieldExpr, op, searchExpr); }; -export const has: ZModelFunction = (eb: ExpressionBuilder, args: Expression[]) => { +export const has: ZModelFunction = (eb, args) => { const [field, search] = args; if (!field) { throw new Error('"field" parameter is required'); @@ -65,7 +84,7 @@ export const hasEvery: ZModelFunction = (eb: ExpressionBuilder, a return eb(field, '@>', search); }; -export const hasSome: ZModelFunction = (eb: ExpressionBuilder, args: Expression[]) => { +export const hasSome: ZModelFunction = (eb, args) => { const [field, search] = args; if (!field) { throw new Error('"field" parameter is required'); @@ -76,11 +95,7 @@ export const hasSome: ZModelFunction = (eb: ExpressionBuilder, ar return eb(field, '&&', search); }; -export const isEmpty: ZModelFunction = ( - eb: ExpressionBuilder, - args: Expression[], - { dialect }: ZModelFunctionContext, -) => { +export const isEmpty: ZModelFunction = (eb, args, { dialect }: ZModelFunctionContext) => { const [field] = args; if (!field) { throw new Error('"field" parameter is required'); @@ -88,22 +103,9 @@ export const isEmpty: ZModelFunction = ( return eb(dialect.buildArrayLength(eb, field), '=', sql.lit(0)); }; -export const now: ZModelFunction = ( - eb: ExpressionBuilder, - _args: Expression[], - { dialect }: ZModelFunctionContext, -) => { - return match(dialect.provider) - .with('postgresql', () => eb.fn('now')) - .with('sqlite', () => sql.raw('CURRENT_TIMESTAMP')) - .exhaustive(); -}; +export const now: ZModelFunction = () => sql.raw('CURRENT_TIMESTAMP'); -export const currentModel: ZModelFunction = ( - _eb: ExpressionBuilder, - args: Expression[], - { model }: ZModelFunctionContext, -) => { +export const currentModel: ZModelFunction = (_eb, args, { model }: ZModelFunctionContext) => { let result = model; const [casing] = args; if (casing) { @@ -112,11 +114,7 @@ export const currentModel: ZModelFunction = ( return sql.lit(result); }; -export const currentOperation: ZModelFunction = ( - _eb: ExpressionBuilder, - args: Expression[], - { operation }: ZModelFunctionContext, -) => { +export const currentOperation: ZModelFunction = (_eb, args, { operation }: ZModelFunctionContext) => { let result: string = operation; const [casing] = args; if (casing) { @@ -141,3 +139,12 @@ function processCasing(casing: Expression, result: string, model: string) { }); return result; } + +function readBoolean(expr: Expression | undefined, defaultValue: boolean) { + if (expr === undefined) { + return defaultValue; + } + const opNode = expr.toOperationNode(); + invariant(ValueNode.is(opNode), 'expression must be a literal value'); + return !!opNode.value; +} diff --git a/packages/runtime/src/client/helpers/schema-db-pusher.ts b/packages/runtime/src/client/helpers/schema-db-pusher.ts index 781c131d..9e855398 100644 --- a/packages/runtime/src/client/helpers/schema-db-pusher.ts +++ b/packages/runtime/src/client/helpers/schema-db-pusher.ts @@ -29,7 +29,8 @@ export class SchemaDbPusher { } // sort models so that target of fk constraints are created first - const sortedModels = this.sortModels(this.schema.models); + const models = Object.values(this.schema.models).filter((m) => !m.isView); + const sortedModels = this.sortModels(models); for (const modelDef of sortedModels) { const createTable = this.createModelTable(tx, modelDef); await createTable.execute(); @@ -37,10 +38,10 @@ export class SchemaDbPusher { }); } - private sortModels(models: Record): ModelDef[] { + private sortModels(models: ModelDef[]): ModelDef[] { const graph: [ModelDef, ModelDef | undefined][] = []; - for (const model of Object.values(models)) { + for (const model of models) { let added = false; if (model.baseModel) { diff --git a/packages/runtime/src/client/kysely-utils.ts b/packages/runtime/src/client/kysely-utils.ts new file mode 100644 index 00000000..a46464c3 --- /dev/null +++ b/packages/runtime/src/client/kysely-utils.ts @@ -0,0 +1,33 @@ +import { type OperationNode, AliasNode, ColumnNode, ReferenceNode, TableNode } from 'kysely'; + +/** + * Strips alias from the node if it exists. + */ +export function stripAlias(node: OperationNode) { + if (AliasNode.is(node)) { + return { alias: node.alias, node: node.node }; + } else { + return { alias: undefined, node }; + } +} + +/** + * Extracts model name from an OperationNode. + */ +export function extractModelName(node: OperationNode) { + const { node: innerNode } = stripAlias(node); + return TableNode.is(innerNode!) ? innerNode!.table.identifier.name : undefined; +} + +/** + * Extracts field name from an OperationNode. + */ +export function extractFieldName(node: OperationNode) { + if (ReferenceNode.is(node) && ColumnNode.is(node.column)) { + return node.column.column.name; + } else if (ColumnNode.is(node)) { + return node.column.name; + } else { + return undefined; + } +} diff --git a/packages/runtime/src/client/options.ts b/packages/runtime/src/client/options.ts index 7c90e330..7d1134a3 100644 --- a/packages/runtime/src/client/options.ts +++ b/packages/runtime/src/client/options.ts @@ -7,8 +7,29 @@ import type { RuntimePlugin } from './plugin'; import type { ToKyselySchema } from './query-builder'; export type ZModelFunctionContext = { + /** + * ZenStack client instance + */ + client: ClientContract; + + /** + * Database dialect + */ dialect: BaseCrudDialect; + + /** + * The containing model name + */ model: GetModels; + + /** + * The alias name that can be used to refer to the containing model + */ + modelAlias: string; + + /** + * The CRUD operation being performed + */ operation: CRUD; }; @@ -41,6 +62,16 @@ export type ClientOptions = { * Logging configuration. */ log?: KyselyConfig['log']; + + /** + * Whether to automatically fix timezone for `DateTime` fields returned by node-pg. Defaults + * to `true`. + * + * Node-pg has a terrible quirk that it interprets the date value as local timezone (as a + * `Date` object) although for `DateTime` field the data in DB is stored in UTC. + * @see https://github.com/brianc/node-postgres/issues/429 + */ + fixPostgresTimezone?: boolean; } & (HasComputedFields extends true ? { /** diff --git a/packages/runtime/src/client/plugin.ts b/packages/runtime/src/client/plugin.ts index 62216a3d..eda9e4a7 100644 --- a/packages/runtime/src/client/plugin.ts +++ b/packages/runtime/src/client/plugin.ts @@ -3,6 +3,7 @@ import type { ClientContract } from '.'; import type { GetModels, SchemaDef } from '../schema'; import type { MaybePromise } from '../utils/type-utils'; import type { AllCrudOperation } from './crud/operations/base'; +import type { ZModelFunction } from './options'; /** * ZenStack runtime plugin. @@ -23,6 +24,11 @@ export interface RuntimePlugin { */ description?: string; + /** + * Custom function implementations. + */ + functions?: Record>; + /** * Intercepts an ORM query. */ diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index fdce2aaf..1cfbdd14 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -1,3 +1,4 @@ +import { invariant } from '@zenstackhq/common-helpers'; import type { Expression, ExpressionBuilder, ExpressionWrapper } from 'kysely'; import { match } from 'ts-pattern'; import { ExpressionUtils, type FieldDef, type GetModels, type ModelDef, type SchemaDef } from '../schema'; @@ -55,8 +56,8 @@ export function requireField(schema: SchemaDef, modelOrType: string, field: stri } export function getIdFields(schema: SchemaDef, model: GetModels) { - const modelDef = requireModel(schema, model); - return modelDef?.idFields as GetModels[]; + const modelDef = getModel(schema, model); + return modelDef?.idFields; } export function requireIdFields(schema: SchemaDef, model: string) { @@ -194,7 +195,7 @@ export function buildFieldRef( if (!computer) { throw new QueryError(`Computed field "${field}" implementation not provided for model "${model}"`); } - return computer(eb, { currentModel: modelAlias }); + return computer(eb, { modelAlias }); } } @@ -231,7 +232,7 @@ export function buildJoinPairs( } export function makeDefaultOrderBy(schema: SchemaDef, model: string) { - const idFields = getIdFields(schema, model); + const idFields = requireIdFields(schema, model); return idFields.map((f) => ({ [f]: 'asc' }) as OrderBy, true, false>); } @@ -259,11 +260,18 @@ export function getManyToManyRelation(schema: SchemaDef, model: string, field: s orderedFK = sortedFieldNames[0] === field ? ['A', 'B'] : ['B', 'A']; } + const modelIdFields = requireIdFields(schema, model); + invariant(modelIdFields.length === 1, 'Only single-field ID is supported for many-to-many relation'); + const otherIdFields = requireIdFields(schema, fieldDef.type); + invariant(otherIdFields.length === 1, 'Only single-field ID is supported for many-to-many relation'); + return { parentFkName: orderedFK[0], + parentPKName: modelIdFields[0]!, otherModel: fieldDef.type, otherField: fieldDef.relation.opposite, otherFkName: orderedFK[1], + otherPKName: otherIdFields[0]!, joinTable: fieldDef.relation.name ? `_${fieldDef.relation.name}` : `_${sortedModelNames[0]}To${sortedModelNames[1]}`, @@ -318,7 +326,7 @@ export function safeJSONStringify(value: unknown) { } export function extractIdFields(entity: any, schema: SchemaDef, model: string) { - const idFields = getIdFields(schema, model); + const idFields = requireIdFields(schema, model); return extractFields(entity, idFields); } diff --git a/packages/runtime/src/client/result-processor.ts b/packages/runtime/src/client/result-processor.ts index 96b3de64..a7870bab 100644 --- a/packages/runtime/src/client/result-processor.ts +++ b/packages/runtime/src/client/result-processor.ts @@ -1,12 +1,18 @@ -import { invariant } from '@zenstackhq/common-helpers'; -import Decimal from 'decimal.js'; -import { match } from 'ts-pattern'; import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../schema'; import { DELEGATE_JOINED_FIELD_PREFIX } from './constants'; +import { getCrudDialect } from './crud/dialects'; +import type { BaseCrudDialect } from './crud/dialects/base-dialect'; +import type { ClientOptions } from './options'; import { ensureArray, getField, getIdValues } from './query-utils'; export class ResultProcessor { - constructor(private readonly schema: Schema) {} + private dialect: BaseCrudDialect; + constructor( + private readonly schema: Schema, + options: ClientOptions, + ) { + this.dialect = getCrudDialect(schema, options); + } processResult(data: any, model: GetModels, args?: any) { const result = this.doProcessResult(data, model); @@ -43,7 +49,7 @@ export class ResultProcessor { // merge delegate descendant fields if (value) { // descendant fields are packed as JSON - const subRow = this.transformJson(value); + const subRow = this.dialect.transformOutput(value, 'Json'); // process the sub-row const subModel = key.slice(DELEGATE_JOINED_FIELD_PREFIX.length) as GetModels; @@ -87,10 +93,10 @@ export class ResultProcessor { private processFieldValue(value: unknown, fieldDef: FieldDef) { const type = fieldDef.type as BuiltinType; if (Array.isArray(value)) { - value.forEach((v, i) => (value[i] = this.transformScalar(v, type))); + value.forEach((v, i) => (value[i] = this.dialect.transformOutput(v, type))); return value; } else { - return this.transformScalar(value, type); + return this.dialect.transformOutput(value, type); } } @@ -107,62 +113,6 @@ export class ResultProcessor { return this.doProcessResult(relationData, fieldDef.type as GetModels); } - private transformScalar(value: unknown, type: BuiltinType) { - if (this.schema.typeDefs && type in this.schema.typeDefs) { - // typed JSON field - return this.transformJson(value); - } else { - return match(type) - .with('Boolean', () => this.transformBoolean(value)) - .with('DateTime', () => this.transformDate(value)) - .with('Bytes', () => this.transformBytes(value)) - .with('Decimal', () => this.transformDecimal(value)) - .with('BigInt', () => this.transformBigInt(value)) - .with('Json', () => this.transformJson(value)) - .otherwise(() => value); - } - } - - private transformDecimal(value: unknown) { - if (value instanceof Decimal) { - return value; - } - invariant( - typeof value === 'string' || typeof value === 'number' || value instanceof Decimal, - `Expected string, number or Decimal, got ${typeof value}`, - ); - return new Decimal(value); - } - - private transformBigInt(value: unknown) { - if (typeof value === 'bigint') { - return value; - } - invariant( - typeof value === 'string' || typeof value === 'number', - `Expected string or number, got ${typeof value}`, - ); - return BigInt(value); - } - - private transformBoolean(value: unknown) { - return !!value; - } - - private transformDate(value: unknown) { - if (typeof value === 'number') { - return new Date(value); - } else if (typeof value === 'string') { - return new Date(Date.parse(value)); - } else { - return value; - } - } - - private transformBytes(value: unknown) { - return Buffer.isBuffer(value) ? Uint8Array.from(value) : value; - } - private fixReversedResult(data: any, model: GetModels, args: any) { if (!data) { return; @@ -190,14 +140,4 @@ export class ResultProcessor { } } } - - private transformJson(value: unknown) { - return match(this.schema.provider.type) - .with('sqlite', () => { - // better-sqlite3 returns JSON as string - invariant(typeof value === 'string', 'Expected string, got ' + typeof value); - return JSON.parse(value as string); - }) - .otherwise(() => value); - } } diff --git a/packages/runtime/src/plugins/policy/errors.ts b/packages/runtime/src/plugins/policy/errors.ts index df1feab6..675506d6 100644 --- a/packages/runtime/src/plugins/policy/errors.ts +++ b/packages/runtime/src/plugins/policy/errors.ts @@ -1,11 +1,32 @@ +/** + * Reason code for policy rejection. + */ +export enum RejectedByPolicyReason { + /** + * Rejected because the operation is not allowed by policy. + */ + NO_ACCESS = 'no-access', + + /** + * Rejected because the result cannot be read back after mutation due to policy. + */ + CANNOT_READ_BACK = 'cannot-read-back', + + /** + * Other reasons. + */ + OTHER = 'other', +} + /** * Error thrown when an operation is rejected by access policy. */ export class RejectedByPolicyError extends Error { constructor( public readonly model: string | undefined, - public readonly reason?: string, + public readonly reason: RejectedByPolicyReason = RejectedByPolicyReason.NO_ACCESS, + message?: string, ) { - super(reason ?? `Operation rejected by policy${model ? ': ' + model : ''}`); + super(message ?? `Operation rejected by policy${model ? ': ' + model : ''}`); } } diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index 9cf81ccc..414b72b4 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -20,12 +20,17 @@ import { type OperationNode, } from 'kysely'; import { match } from 'ts-pattern'; -import type { CRUD } from '../../client/contract'; +import type { ClientContract, CRUD } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; -import type { ClientOptions } from '../../client/options'; -import { getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils'; +import { + getManyToManyRelation, + getModel, + getRelationForeignKeyFieldPairs, + requireField, + requireIdFields, +} from '../../client/query-utils'; import type { BinaryExpression, BinaryOperator, @@ -45,7 +50,7 @@ import { type SchemaDef, } from '../../schema'; import { ExpressionEvaluator } from './expression-evaluator'; -import { conjunction, disjunction, logicalNot, trueNode } from './utils'; +import { conjunction, disjunction, falseNode, logicalNot, trueNode } from './utils'; export type ExpressionTransformerContext = { model: GetModels; @@ -72,14 +77,22 @@ function expr(kind: Expression['kind']) { export class ExpressionTransformer { private readonly dialect: BaseCrudDialect; - constructor( - private readonly schema: Schema, - private readonly clientOptions: ClientOptions, - private readonly auth: unknown | undefined, - ) { + constructor(private readonly client: ClientContract) { this.dialect = getCrudDialect(this.schema, this.clientOptions); } + get schema() { + return this.client.$schema; + } + + get clientOptions() { + return this.client.$options; + } + + get auth() { + return this.client.$auth; + } + get authType() { if (!this.schema.authType) { throw new InternalError('Schema does not have an "authType" specified'); @@ -111,7 +124,6 @@ export class ExpressionTransformer { } @expr('field') - // @ts-expect-error private _field(expr: FieldExpression, context: ExpressionTransformerContext) { const fieldDef = requireField(this.schema, context.model, expr.field); if (!fieldDef.relation) { @@ -162,8 +174,9 @@ export class ExpressionTransformer { return this.transformCollectionPredicate(expr, context); } - const left = this.transform(expr.left, context); - const right = this.transform(expr.right, context); + const { normalizedLeft, normalizedRight } = this.normalizeBinaryOperationOperands(expr, context); + const left = this.transform(normalizedLeft, context); + const right = this.transform(normalizedRight, context); if (op === 'in') { if (this.isNullNode(left)) { @@ -183,16 +196,49 @@ export class ExpressionTransformer { } if (this.isNullNode(right)) { - return expr.op === '==' - ? BinaryOperationNode.create(left, OperatorNode.create('is'), right) - : BinaryOperationNode.create(left, OperatorNode.create('is not'), right); + return this.transformNullCheck(left, expr.op); } else if (this.isNullNode(left)) { - return expr.op === '==' - ? BinaryOperationNode.create(right, OperatorNode.create('is'), ValueNode.createImmediate(null)) - : BinaryOperationNode.create(right, OperatorNode.create('is not'), ValueNode.createImmediate(null)); + return this.transformNullCheck(right, expr.op); + } else { + return BinaryOperationNode.create(left, this.transformOperator(op), right); } + } - return BinaryOperationNode.create(left, this.transformOperator(op), right); + private transformNullCheck(expr: OperationNode, operator: BinaryOperator) { + invariant(operator === '==' || operator === '!=', 'operator must be "==" or "!=" for null comparison'); + if (ValueNode.is(expr)) { + if (expr.value === null) { + return operator === '==' ? trueNode(this.dialect) : falseNode(this.dialect); + } else { + return operator === '==' ? falseNode(this.dialect) : trueNode(this.dialect); + } + } else { + return operator === '==' + ? BinaryOperationNode.create(expr, OperatorNode.create('is'), ValueNode.createImmediate(null)) + : BinaryOperationNode.create(expr, OperatorNode.create('is not'), ValueNode.createImmediate(null)); + } + } + + private normalizeBinaryOperationOperands(expr: BinaryExpression, context: ExpressionTransformerContext) { + // if relation fields are used directly in comparison, it can only be compared with null, + // so we normalize the args with the id field (use the first id field if multiple) + let normalizedLeft: Expression = expr.left; + if (this.isRelationField(expr.left, context.model)) { + invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field'); + const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context.model); + invariant(leftRelDef, 'failed to get relation field definition'); + const idFields = requireIdFields(this.schema, leftRelDef.type); + normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!); + } + let normalizedRight: Expression = expr.right; + if (this.isRelationField(expr.right, context.model)) { + invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field'); + const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context.model); + invariant(rightRelDef, 'failed to get relation field definition'); + const idFields = requireIdFields(this.schema, rightRelDef.type); + normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!); + } + return { normalizedLeft, normalizedRight }; } private transformCollectionPredicate(expr: BinaryExpression, context: ExpressionTransformerContext) { @@ -211,11 +257,15 @@ export class ExpressionTransformer { ); let newContextModel: string; - if (ExpressionUtils.isField(expr.left)) { - const fieldDef = requireField(this.schema, context.model, expr.left.field); + const fieldDef = this.getFieldDefFromFieldRef(expr.left, context.model); + if (fieldDef) { + invariant(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr.left)}`); newContextModel = fieldDef.type; } else { - invariant(ExpressionUtils.isField(expr.left.receiver)); + invariant( + ExpressionUtils.isMember(expr.left) && ExpressionUtils.isField(expr.left.receiver), + 'left operand must be member access with field receiver', + ); const fieldDef = requireField(this.schema, context.model, expr.left.receiver.field); newContextModel = fieldDef.type; for (const member of expr.left.members) { @@ -281,11 +331,12 @@ export class ExpressionTransformer { .map((f) => f.name); invariant(idFields.length > 0, 'auth type model must have at least one id field'); + // convert `auth() == other` into `auth().id == other.id` const conditions = idFields.map((fieldName) => ExpressionUtils.binary( ExpressionUtils.member(authExpr, [fieldName]), '==', - ExpressionUtils.member(other, [fieldName]), + this.makeOrAppendMember(other, fieldName), ), ); let result = this.buildAnd(conditions); @@ -296,8 +347,22 @@ export class ExpressionTransformer { } } + private makeOrAppendMember(other: Expression, fieldName: string): Expression { + if (ExpressionUtils.isMember(other)) { + return ExpressionUtils.member(other.receiver, [...other.members, fieldName]); + } else { + return ExpressionUtils.member(other, [fieldName]); + } + } + private transformValue(value: unknown, type: BuiltinType) { - return ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null); + if (value === true) { + return trueNode(this.dialect); + } else if (value === false) { + return falseNode(this.dialect); + } else { + return ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null); + } } @expr('unary') @@ -323,7 +388,7 @@ export class ExpressionTransformer { } private transformCall(expr: CallExpression, context: ExpressionTransformerContext) { - const func = this.clientOptions.functions?.[expr.function]; + const func = this.getFunctionImpl(expr.function); if (!func) { throw new QueryError(`Function not implemented: ${expr.function}`); } @@ -332,13 +397,30 @@ export class ExpressionTransformer { eb, (expr.args ?? []).map((arg) => this.transformCallArg(eb, arg, context)), { + client: this.client, dialect: this.dialect, model: context.model, + modelAlias: context.alias ?? context.model, operation: context.operation, }, ); } + private getFunctionImpl(functionName: string) { + // check built-in functions + let func = this.clientOptions.functions?.[functionName]; + if (!func) { + // check plugins + for (const plugin of this.clientOptions.plugins ?? []) { + if (plugin.functions?.[functionName]) { + func = plugin.functions[functionName]; + break; + } + } + } + return func; + } + private transformCallArg( eb: ExpressionBuilder, arg: Expression, @@ -387,16 +469,14 @@ export class ExpressionTransformer { if (ExpressionUtils.isThis(expr.receiver)) { if (expr.members.length === 1) { - // optimize for the simple this.scalar case - const fieldDef = requireField(this.schema, context.model, expr.members[0]!); - invariant(!fieldDef.relation, 'this.relation access should have been transformed into relation access'); - return this.createColumnRef(expr.members[0]!, restContext); + // `this.relation` case, equivalent to field access + return this._field(ExpressionUtils.field(expr.members[0]!), context); + } else { + // transform the first segment into a relation access, then continue with the rest of the members + const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!); + receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext); + members = expr.members.slice(1); } - - // transform the first segment into a relation access, then continue with the rest of the members - const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!); - receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext); - members = expr.members.slice(1); } else { receiver = this.transform(expr.receiver, restContext); } @@ -484,6 +564,11 @@ export class ExpressionTransformer { relationModel: string, context: ExpressionTransformerContext, ): SelectQueryNode { + const m2m = getManyToManyRelation(this.schema, context.model, field); + if (m2m) { + return this.transformManyToManyRelationAccess(m2m, context); + } + const fromModel = context.model; const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs(this.schema, fromModel, field); @@ -521,6 +606,28 @@ export class ExpressionTransformer { }; } + private transformManyToManyRelationAccess( + m2m: NonNullable>, + context: ExpressionTransformerContext, + ) { + const eb = expressionBuilder(); + const relationQuery = eb + .selectFrom(m2m.otherModel) + // inner join with join table and additionally filter by the parent model + .innerJoin(m2m.joinTable, (join) => + join + // relation model pk to join table fk + .onRef(`${m2m.otherModel}.${m2m.otherPKName}`, '=', `${m2m.joinTable}.${m2m.otherFkName}`) + // parent model pk to join table fk + .onRef( + `${m2m.joinTable}.${m2m.parentFkName}`, + '=', + `${context.alias ?? context.model}.${m2m.parentPKName}`, + ), + ); + return relationQuery.toOperationNode(); + } + private createColumnRef(column: string, context: ExpressionTransformerContext): ReferenceNode { return ReferenceNode.create(ColumnNode.create(column), TableNode.create(context.alias ?? context.model)); } @@ -550,4 +657,23 @@ export class ExpressionTransformer { return conditions.reduce((acc, condition) => ExpressionUtils.binary(acc, '&&', condition)); } } + + private isRelationField(expr: Expression, model: GetModels) { + const fieldDef = this.getFieldDefFromFieldRef(expr, model); + return !!fieldDef?.relation; + } + + private getFieldDefFromFieldRef(expr: Expression, model: GetModels): FieldDef | undefined { + if (ExpressionUtils.isField(expr)) { + return requireField(this.schema, model, expr.field); + } else if ( + ExpressionUtils.isMember(expr) && + expr.members.length === 1 && + ExpressionUtils.isThis(expr.receiver) + ) { + return requireField(this.schema, model, expr.members[0]!); + } else { + return undefined; + } + } } diff --git a/packages/runtime/src/plugins/policy/functions.ts b/packages/runtime/src/plugins/policy/functions.ts new file mode 100644 index 00000000..c7fa09d7 --- /dev/null +++ b/packages/runtime/src/plugins/policy/functions.ts @@ -0,0 +1,62 @@ +import { invariant } from '@zenstackhq/common-helpers'; +import { ExpressionWrapper, ValueNode, type Expression, type ExpressionBuilder } from 'kysely'; +import { CRUD } from '../../client/contract'; +import { extractFieldName } from '../../client/kysely-utils'; +import type { ZModelFunction, ZModelFunctionContext } from '../../client/options'; +import { buildJoinPairs, requireField } from '../../client/query-utils'; +import { PolicyHandler } from './policy-handler'; + +/** + * Relation checker implementation. + */ +export const check: ZModelFunction = ( + eb: ExpressionBuilder, + args: Expression[], + { client, model, modelAlias, operation }: ZModelFunctionContext, +) => { + invariant(args.length === 1 || args.length === 2, '"check" function requires 1 or 2 arguments'); + + const arg1Node = args[0]!.toOperationNode(); + + const arg2Node = args.length === 2 ? args[1]!.toOperationNode() : undefined; + if (arg2Node) { + invariant( + ValueNode.is(arg2Node) && typeof arg2Node.value === 'string', + '"operation" parameter must be a string literal when provided', + ); + invariant( + CRUD.includes(arg2Node.value as CRUD), + '"operation" parameter must be one of "create", "read", "update", "delete"', + ); + } + + // first argument must be a field reference + const fieldName = extractFieldName(arg1Node); + invariant(fieldName, 'Failed to extract field name from the first argument of "check" function'); + const fieldDef = requireField(client.$schema, model, fieldName); + invariant(fieldDef.relation, `Field "${fieldName}" is not a relation field in model "${model}"`); + invariant(!fieldDef.array, `Field "${fieldName}" is a to-many relation, which is not supported by "check"`); + const relationModel = fieldDef.type; + + const op = arg2Node ? (arg2Node.value as CRUD) : operation; + + const policyHandler = new PolicyHandler(client); + + // join with parent model + const joinPairs = buildJoinPairs(client.$schema, model, modelAlias, fieldName, relationModel); + const joinCondition = + joinPairs.length === 1 + ? eb(eb.ref(joinPairs[0]![0]), '=', eb.ref(joinPairs[0]![1])) + : eb.and(joinPairs.map(([left, right]) => eb(eb.ref(left), '=', eb.ref(right)))); + + // policy condition of the related model + const policyCondition = policyHandler.buildPolicyFilter(relationModel, undefined, op); + + // build the final nested select that evaluates the policy condition + const result = eb + .selectFrom(relationModel) + .where(joinCondition) + .select(new ExpressionWrapper(policyCondition).as('$condition')); + + return result; +}; diff --git a/packages/runtime/src/plugins/policy/plugin.ts b/packages/runtime/src/plugins/policy/plugin.ts index e5b914d5..6af93353 100644 --- a/packages/runtime/src/plugins/policy/plugin.ts +++ b/packages/runtime/src/plugins/policy/plugin.ts @@ -1,5 +1,6 @@ import { type OnKyselyQueryArgs, type RuntimePlugin } from '../../client/plugin'; import type { SchemaDef } from '../../schema'; +import { check } from './functions'; import { PolicyHandler } from './policy-handler'; export class PolicyPlugin implements RuntimePlugin { @@ -15,6 +16,12 @@ export class PolicyPlugin implements RuntimePlugin) { const handler = new PolicyHandler(client); return handler.handle(query, proceed /*, transaction*/); diff --git a/packages/runtime/src/plugins/policy/plugin.zmodel b/packages/runtime/src/plugins/policy/plugin.zmodel index ecb39320..659705ce 100644 --- a/packages/runtime/src/plugins/policy/plugin.zmodel +++ b/packages/runtime/src/plugins/policy/plugin.zmodel @@ -31,3 +31,13 @@ attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "' * @param condition: a boolean expression that controls if the operation should be denied. */ attribute @deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean) + +/** + * Checks if the current user can perform the given operation on the given field. + * + * @param field: The field to check access for + * @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided, + * it defaults the operation of the containing policy rule. + */ +function check(field: Any, operation: String?): Boolean { +} @@@expressionContext([AccessPolicy]) diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index f26c2038..50cfc835 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -4,10 +4,13 @@ import { BinaryOperationNode, ColumnNode, DeleteQueryNode, + expressionBuilder, + ExpressionWrapper, FromNode, FunctionNode, IdentifierNode, InsertQueryNode, + JoinNode, OperationNodeTransformer, OperatorNode, ParensNode, @@ -16,6 +19,7 @@ import { ReturningNode, SelectionNode, SelectQueryNode, + sql, TableNode, UpdateQueryNode, ValueListNode, @@ -31,12 +35,12 @@ import type { ClientContract } from '../../client'; import type { CRUD } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; -import { InternalError } from '../../client/errors'; +import { InternalError, QueryError } from '../../client/errors'; import type { ProceedKyselyQueryFunction } from '../../client/plugin'; -import { getIdFields, requireField, requireModel } from '../../client/query-utils'; +import { getManyToManyRelation, requireField, requireIdFields, requireModel } from '../../client/query-utils'; import { ExpressionUtils, type BuiltinType, type Expression, type GetModels, type SchemaDef } from '../../schema'; import { ColumnCollector } from './column-collector'; -import { RejectedByPolicyError } from './errors'; +import { RejectedByPolicyError, RejectedByPolicyReason } from './errors'; import { ExpressionTransformer } from './expression-transformer'; import type { Policy, PolicyOperation } from './types'; import { buildIsFalse, conjunction, disjunction, falseNode, getTableName } from './utils'; @@ -63,97 +67,292 @@ export class PolicyHandler extends OperationNodeTransf ) { if (!this.isCrudQueryNode(node)) { // non-CRUD queries are not allowed - throw new RejectedByPolicyError(undefined, 'non-CRUD queries are not allowed'); + throw new RejectedByPolicyError( + undefined, + RejectedByPolicyReason.OTHER, + 'non-CRUD queries are not allowed', + ); } if (!this.isMutationQueryNode(node)) { - // transform and proceed read without transaction + // transform and proceed with read directly return proceed(this.transformNode(node)); } - let mutationRequiresTransaction = false; - const mutationModel = this.getMutationModel(node); + const { mutationModel } = this.getMutationModel(node); if (InsertQueryNode.is(node)) { - // reject create if unconditional deny - const constCondition = this.tryGetConstantPolicy(mutationModel, 'create'); - if (constCondition === false) { - throw new RejectedByPolicyError(mutationModel); - } else if (constCondition === undefined) { - mutationRequiresTransaction = true; + // pre-create policy evaluation happens before execution of the query + const isManyToManyJoinTable = this.isManyToManyJoinTable(mutationModel); + let needCheckPreCreate = true; + + // many-to-many join table is not a model so can't have policies on it + if (!isManyToManyJoinTable) { + // check constant policies + const constCondition = this.tryGetConstantPolicy(mutationModel, 'create'); + if (constCondition === true) { + needCheckPreCreate = false; + } else if (constCondition === false) { + throw new RejectedByPolicyError(mutationModel); + } } - } - if (!mutationRequiresTransaction && !node.returning) { - // transform and proceed mutation without transaction - return proceed(this.transformNode(node)); + if (needCheckPreCreate) { + await this.enforcePreCreatePolicy(node, mutationModel, isManyToManyJoinTable, proceed); + } } - if (InsertQueryNode.is(node)) { - await this.enforcePreCreatePolicy(node, proceed); - } - const transformedNode = this.transformNode(node); - const result = await proceed(transformedNode); + // proceed with query - if (!this.onlyReturningId(node)) { + const result = await proceed(this.transformNode(node)); + + if (!node.returning || this.onlyReturningId(node)) { + return result; + } else { const readBackResult = await this.processReadBack(node, result, proceed); if (readBackResult.rows.length !== result.rows.length) { - throw new RejectedByPolicyError(mutationModel, 'result is not allowed to be read back'); + throw new RejectedByPolicyError( + mutationModel, + RejectedByPolicyReason.CANNOT_READ_BACK, + 'result is not allowed to be read back', + ); } return readBackResult; - } else { + } + } + + // #region overrides + + protected override transformSelectQuery(node: SelectQueryNode) { + let whereNode = this.transformNode(node.where); + + // get combined policy filter for all froms, and merge into where clause + const policyFilter = this.createPolicyFilterForFrom(node.from); + if (policyFilter) { + whereNode = WhereNode.create( + whereNode?.where ? conjunction(this.dialect, [whereNode.where, policyFilter]) : policyFilter, + ); + } + + const baseResult = super.transformSelectQuery({ + ...node, + where: undefined, + }); + + return { + ...baseResult, + where: whereNode, + }; + } + + protected override transformJoin(node: JoinNode) { + const table = this.extractTableName(node.table); + if (!table) { + // unable to extract table name, can be a subquery, which will be handled when nested transformation happens + return super.transformJoin(node); + } + + // build a nested query with policy filter applied + const filter = this.buildPolicyFilter(table.model, table.alias, 'read'); + const nestedSelect: SelectQueryNode = { + kind: 'SelectQueryNode', + from: FromNode.create([node.table]), + selections: [SelectionNode.createSelectAll()], + where: WhereNode.create(filter), + }; + return { + ...node, + table: AliasNode.create(ParensNode.create(nestedSelect), IdentifierNode.create(table.alias ?? table.model)), + }; + } + + protected override transformInsertQuery(node: InsertQueryNode) { + // pre-insert check is done in `handle()` + + let onConflict = node.onConflict; + + if (onConflict?.updates) { + // for "on conflict do update", we need to apply policy filter to the "where" clause + const { mutationModel, alias } = this.getMutationModel(node); + const filter = this.buildPolicyFilter(mutationModel, alias, 'update'); + if (onConflict.updateWhere) { + onConflict = { + ...onConflict, + updateWhere: WhereNode.create(conjunction(this.dialect, [onConflict.updateWhere.where, filter])), + }; + } else { + onConflict = { + ...onConflict, + updateWhere: WhereNode.create(filter), + }; + } + } + + // merge updated onConflict + const processedNode = onConflict ? { ...node, onConflict } : node; + + const result = super.transformInsertQuery(processedNode); + + if (!node.returning) { + return result; + } + + if (this.onlyReturningId(node)) { return result; + } else { + // only return ID fields, that's enough for reading back the inserted row + const { mutationModel } = this.getMutationModel(node); + const idFields = requireIdFields(this.client.$schema, mutationModel); + return { + ...result, + returning: ReturningNode.create( + idFields.map((field) => SelectionNode.create(ColumnNode.create(field))), + ), + }; } + } - // TODO: run in transaction - // let readBackError = false; - - // transform and post-process in a transaction - // const result = await transaction(async (txProceed) => { - // if (InsertQueryNode.is(node)) { - // await this.enforcePreCreatePolicy(node, txProceed); - // } - // const transformedNode = this.transformNode(node); - // const result = await txProceed(transformedNode); - - // if (!this.onlyReturningId(node)) { - // const readBackResult = await this.processReadBack(node, result, txProceed); - // if (readBackResult.rows.length !== result.rows.length) { - // readBackError = true; - // } - // return readBackResult; - // } else { - // return result; - // } - // }); - - // if (readBackError) { - // throw new RejectedByPolicyError(mutationModel, 'result is not allowed to be read back'); - // } - - // return result; + protected override transformUpdateQuery(node: UpdateQueryNode) { + const result = super.transformUpdateQuery(node); + const { mutationModel, alias } = this.getMutationModel(node); + let filter = this.buildPolicyFilter(mutationModel, alias, 'update'); + + if (node.from) { + // for update with from (join), we need to merge join tables' policy filters to the "where" clause + const joinFilter = this.createPolicyFilterForFrom(node.from); + if (joinFilter) { + filter = conjunction(this.dialect, [filter, joinFilter]); + } + } + + return { + ...result, + where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), + }; } + protected override transformDeleteQuery(node: DeleteQueryNode) { + const result = super.transformDeleteQuery(node); + const { mutationModel, alias } = this.getMutationModel(node); + let filter = this.buildPolicyFilter(mutationModel, alias, 'delete'); + + if (node.using) { + // for delete with using (join), we need to merge join tables' policy filters to the "where" clause + const joinFilter = this.createPolicyFilterForTables(node.using.tables); + if (joinFilter) { + filter = conjunction(this.dialect, [filter, joinFilter]); + } + } + + return { + ...result, + where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), + }; + } + + // #endregion + + // #region helpers + private onlyReturningId(node: MutationQueryNode) { if (!node.returning) { return true; } - const idFields = getIdFields(this.client.$schema, this.getMutationModel(node)); + const { mutationModel } = this.getMutationModel(node); + const idFields = requireIdFields(this.client.$schema, mutationModel); const collector = new ColumnCollector(); const selectedColumns = collector.collect(node.returning); return selectedColumns.every((c) => idFields.includes(c)); } - private async enforcePreCreatePolicy(node: InsertQueryNode, proceed: ProceedKyselyQueryFunction) { - const model = this.getMutationModel(node); + private async enforcePreCreatePolicy( + node: InsertQueryNode, + mutationModel: GetModels, + isManyToManyJoinTable: boolean, + proceed: ProceedKyselyQueryFunction, + ) { const fields = node.columns?.map((c) => c.column.name) ?? []; - const valueRows = node.values ? this.unwrapCreateValueRows(node.values, model, fields) : [[]]; + const valueRows = node.values + ? this.unwrapCreateValueRows(node.values, mutationModel, fields, isManyToManyJoinTable) + : [[]]; for (const values of valueRows) { - await this.enforcePreCreatePolicyForOne( - model, - fields, - values.map((v) => v.node), - proceed, + if (isManyToManyJoinTable) { + await this.enforcePreCreatePolicyForManyToManyJoinTable( + mutationModel, + fields, + values.map((v) => v.node), + proceed, + ); + } else { + await this.enforcePreCreatePolicyForOne( + mutationModel, + fields, + values.map((v) => v.node), + proceed, + ); + } + } + } + + private async enforcePreCreatePolicyForManyToManyJoinTable( + tableName: GetModels, + fields: string[], + values: OperationNode[], + proceed: ProceedKyselyQueryFunction, + ) { + const m2m = this.resolveManyToManyJoinTable(tableName); + invariant(m2m); + + // m2m create requires both sides to be updatable + invariant(fields.includes('A') && fields.includes('B'), 'many-to-many join table must have A and B fk fields'); + + const aIndex = fields.indexOf('A'); + const aNode = values[aIndex]!; + const bIndex = fields.indexOf('B'); + const bNode = values[bIndex]!; + invariant(ValueNode.is(aNode) && ValueNode.is(bNode), 'A and B values must be ValueNode'); + + const aValue = aNode.value; + const bValue = bNode.value; + invariant(aValue !== null && aValue !== undefined, 'A value cannot be null or undefined'); + invariant(bValue !== null && bValue !== undefined, 'B value cannot be null or undefined'); + + const eb = expressionBuilder(); + + const filterA = this.buildPolicyFilter(m2m.firstModel as GetModels, undefined, 'update'); + const queryA = eb + .selectFrom(m2m.firstModel) + .where(eb(eb.ref(`${m2m.firstModel}.${m2m.firstIdField}`), '=', aValue)) + .select(() => new ExpressionWrapper(filterA).as('$t')); + + const filterB = this.buildPolicyFilter(m2m.secondModel as GetModels, undefined, 'update'); + const queryB = eb + .selectFrom(m2m.secondModel) + .where(eb(eb.ref(`${m2m.secondModel}.${m2m.secondIdField}`), '=', bValue)) + .select(() => new ExpressionWrapper(filterB).as('$t')); + + // select both conditions in one query + const queryNode: SelectQueryNode = { + kind: 'SelectQueryNode', + selections: [ + SelectionNode.create(AliasNode.create(queryA.toOperationNode(), IdentifierNode.create('$conditionA'))), + SelectionNode.create(AliasNode.create(queryB.toOperationNode(), IdentifierNode.create('$conditionB'))), + ], + }; + + const result = await proceed(queryNode); + if (!result.rows[0]?.$conditionA) { + throw new RejectedByPolicyError( + m2m.firstModel as GetModels, + RejectedByPolicyReason.CANNOT_READ_BACK, + `many-to-many relation participant model "${m2m.firstModel}" not updatable`, + ); + } + if (!result.rows[0]?.$conditionB) { + throw new RejectedByPolicyError( + m2m.secondModel as GetModels, + RejectedByPolicyReason.NO_ACCESS, + `many-to-many relation participant model "${m2m.secondModel}" not updatable`, ); } } @@ -164,11 +363,13 @@ export class PolicyHandler extends OperationNodeTransf values: OperationNode[], proceed: ProceedKyselyQueryFunction, ) { - const allFields = Object.keys(requireModel(this.client.$schema, model).fields); + const allFields = Object.entries(requireModel(this.client.$schema, model).fields).filter( + ([, def]) => !def.relation, + ); const allValues: OperationNode[] = []; - for (const fieldName of allFields) { - const index = fields.indexOf(fieldName); + for (const [name, _def] of allFields) { + const index = fields.indexOf(name); if (index >= 0) { allValues.push(values[index]!); } else { @@ -178,6 +379,8 @@ export class PolicyHandler extends OperationNodeTransf } // create a `SELECT column1 as field1, column2 as field2, ... FROM (VALUES (...))` table for policy evaluation + const eb = expressionBuilder(); + const constTable: SelectQueryNode = { kind: 'SelectQueryNode', from: FromNode.create([ @@ -186,11 +389,13 @@ export class PolicyHandler extends OperationNodeTransf IdentifierNode.create('$t'), ), ]), - selections: allFields.map((field, index) => - SelectionNode.create( - AliasNode.create(ColumnNode.create(`column${index + 1}`), IdentifierNode.create(field)), - ), - ), + selections: allFields.map(([name, def], index) => { + const castedColumnRef = + sql`CAST(${eb.ref(`column${index + 1}`)} as ${sql.raw(this.dialect.getFieldSqlType(def))})`.as( + name, + ); + return SelectionNode.create(castedColumnRef.toOperationNode()); + }), }; const filter = this.buildPolicyFilter(model, undefined, 'create'); @@ -219,23 +424,33 @@ export class PolicyHandler extends OperationNodeTransf } } - private unwrapCreateValueRows(node: OperationNode, model: GetModels, fields: string[]) { + private unwrapCreateValueRows( + node: OperationNode, + model: GetModels, + fields: string[], + isManyToManyJoinTable: boolean, + ) { if (ValuesNode.is(node)) { - return node.values.map((v) => this.unwrapCreateValueRow(v.values, model, fields)); + return node.values.map((v) => this.unwrapCreateValueRow(v.values, model, fields, isManyToManyJoinTable)); } else if (PrimitiveValueListNode.is(node)) { - return [this.unwrapCreateValueRow(node.values, model, fields)]; + return [this.unwrapCreateValueRow(node.values, model, fields, isManyToManyJoinTable)]; } else { throw new InternalError(`Unexpected node kind: ${node.kind} for unwrapping create values`); } } - private unwrapCreateValueRow(data: readonly unknown[], model: GetModels, fields: string[]) { + private unwrapCreateValueRow( + data: readonly unknown[], + model: GetModels, + fields: string[], + isImplicitManyToManyJoinTable: boolean, + ) { invariant(data.length === fields.length, 'data length must match fields length'); const result: { node: OperationNode; raw: unknown }[] = []; for (let i = 0; i < data.length; i++) { const item = data[i]!; - const fieldDef = requireField(this.client.$schema, model, fields[i]!); if (typeof item === 'object' && item && 'kind' in item) { + const fieldDef = requireField(this.client.$schema, model, fields[i]!); invariant(item.kind === 'ValueNode', 'expecting a ValueNode'); result.push({ node: ValueNode.create( @@ -248,7 +463,15 @@ export class PolicyHandler extends OperationNodeTransf raw: (item as ValueNode).value, }); } else { - const value = this.dialect.transformPrimitive(item, fieldDef.type as BuiltinType, !!fieldDef.array); + let value: unknown = item; + + // many-to-many join table is not a model so we don't have field definitions, + // but there's no need to transform values anyway because they're the fields + // are all foreign keys + if (!isImplicitManyToManyJoinTable) { + const fieldDef = requireField(this.client.$schema, model, fields[i]!); + value = this.dialect.transformPrimitive(item, fieldDef.type as BuiltinType, !!fieldDef.array); + } if (Array.isArray(value)) { result.push({ node: RawNode.createWithSql(this.dialect.buildArrayLiteralSQL(value)), @@ -297,17 +520,13 @@ export class PolicyHandler extends OperationNodeTransf } // do a select (with policy) in place of returning - const table = this.getMutationModel(node); - if (!table) { - throw new InternalError(`Unable to get table name for query node: ${node}`); - } - - const idConditions = this.buildIdConditions(table, result.rows); - const policyFilter = this.buildPolicyFilter(table, undefined, 'read'); + const { mutationModel } = this.getMutationModel(node); + const idConditions = this.buildIdConditions(mutationModel, result.rows); + const policyFilter = this.buildPolicyFilter(mutationModel, undefined, 'read'); const select: SelectQueryNode = { kind: 'SelectQueryNode', - from: FromNode.create([TableNode.create(table)]), + from: FromNode.create([TableNode.create(mutationModel)]), where: WhereNode.create(conjunction(this.dialect, [idConditions, policyFilter])), selections: node.returning.selections, }; @@ -316,7 +535,7 @@ export class PolicyHandler extends OperationNodeTransf } private buildIdConditions(table: string, rows: any[]): OperationNode { - const idFields = getIdFields(this.client.$schema, table); + const idFields = requireIdFields(this.client.$schema, table); return disjunction( this.dialect, rows.map((row) => @@ -336,13 +555,23 @@ export class PolicyHandler extends OperationNodeTransf private getMutationModel(node: InsertQueryNode | UpdateQueryNode | DeleteQueryNode) { const r = match(node) - .when(InsertQueryNode.is, (node) => getTableName(node.into) as GetModels) - .when(UpdateQueryNode.is, (node) => getTableName(node.table) as GetModels) + .when(InsertQueryNode.is, (node) => ({ + mutationModel: getTableName(node.into) as GetModels, + alias: undefined, + })) + .when(UpdateQueryNode.is, (node) => { + if (!node.table) { + throw new QueryError('Update query must have a table'); + } + const r = this.extractTableName(node.table); + return r ? { mutationModel: r.model, alias: r.alias } : undefined; + }) .when(DeleteQueryNode.is, (node) => { if (node.from.froms.length !== 1) { - throw new InternalError('Only one from table is supported for delete'); + throw new QueryError('Only one from table is supported for delete'); } - return getTableName(node.from.froms[0]) as GetModels; + const r = this.extractTableName(node.from.froms[0]!); + return r ? { mutationModel: r.model, alias: r.alias } : undefined; }) .exhaustive(); if (!r) { @@ -361,7 +590,13 @@ export class PolicyHandler extends OperationNodeTransf return InsertQueryNode.is(node) || UpdateQueryNode.is(node) || DeleteQueryNode.is(node); } - private buildPolicyFilter(model: GetModels, alias: string | undefined, operation: CRUD) { + buildPolicyFilter(model: GetModels, alias: string | undefined, operation: CRUD) { + // first check if it's a many-to-many join table, and if so, handle specially + const m2mFilter = this.getModelPolicyFilterForManyToManyJoinTable(model, alias, operation); + if (m2mFilter) { + return m2mFilter; + } + const policies = this.getModelPolicies(model, operation); if (policies.length === 0) { return falseNode(this.dialect); @@ -369,11 +604,11 @@ export class PolicyHandler extends OperationNodeTransf const allows = policies .filter((policy) => policy.kind === 'allow') - .map((policy) => this.transformPolicyCondition(model, alias, operation, policy)); + .map((policy) => this.compilePolicyCondition(model, alias, operation, policy)); const denies = policies .filter((policy) => policy.kind === 'deny') - .map((policy) => this.transformPolicyCondition(model, alias, operation, policy)); + .map((policy) => this.compilePolicyCondition(model, alias, operation, policy)); let combinedPolicy: OperationNode; @@ -397,82 +632,18 @@ export class PolicyHandler extends OperationNodeTransf return combinedPolicy; } - protected override transformSelectQuery(node: SelectQueryNode) { - let whereNode = node.where; - - node.from?.froms.forEach((from) => { - const extractResult = this.extractTableName(from); - if (extractResult) { - const { model, alias } = extractResult; - const filter = this.buildPolicyFilter(model, alias, 'read'); - whereNode = WhereNode.create( - whereNode?.where ? conjunction(this.dialect, [whereNode.where, filter]) : filter, - ); - } - }); - - const baseResult = super.transformSelectQuery({ - ...node, - where: undefined, - }); - - return { - ...baseResult, - where: whereNode, - }; - } - - protected override transformInsertQuery(node: InsertQueryNode) { - const result = super.transformInsertQuery(node); - if (!node.returning) { - return result; + private extractTableName(node: OperationNode): { model: GetModels; alias?: string } | undefined { + if (TableNode.is(node)) { + return { model: node.table.identifier.name as GetModels }; } - if (this.onlyReturningId(node)) { - return result; - } else { - // only return ID fields, that's enough for reading back the inserted row - const idFields = getIdFields(this.client.$schema, this.getMutationModel(node)); - return { - ...result, - returning: ReturningNode.create( - idFields.map((field) => SelectionNode.create(ColumnNode.create(field))), - ), - }; - } - } - - protected override transformUpdateQuery(node: UpdateQueryNode) { - const result = super.transformUpdateQuery(node); - const mutationModel = this.getMutationModel(node); - const filter = this.buildPolicyFilter(mutationModel, undefined, 'update'); - return { - ...result, - where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), - }; - } - - protected override transformDeleteQuery(node: DeleteQueryNode) { - const result = super.transformDeleteQuery(node); - const mutationModel = this.getMutationModel(node); - const filter = this.buildPolicyFilter(mutationModel, undefined, 'delete'); - return { - ...result, - where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), - }; - } - - private extractTableName(from: OperationNode): { model: GetModels; alias?: string } | undefined { - if (TableNode.is(from)) { - return { model: from.table.identifier.name as GetModels }; - } - if (AliasNode.is(from)) { - const inner = this.extractTableName(from.node); + if (AliasNode.is(node)) { + const inner = this.extractTableName(node.node); if (!inner) { return undefined; } return { model: inner.model, - alias: IdentifierNode.is(from.alias) ? from.alias.name : undefined, + alias: IdentifierNode.is(node.alias) ? node.alias.name : undefined, }; } else { // this can happen for subqueries, which will be handled when nested @@ -481,25 +652,41 @@ export class PolicyHandler extends OperationNodeTransf } } - private transformPolicyCondition( + private createPolicyFilterForFrom(node: FromNode | undefined) { + if (!node) { + return undefined; + } + return this.createPolicyFilterForTables(node.froms); + } + + private createPolicyFilterForTables(tables: readonly OperationNode[]) { + return tables.reduce((acc, table) => { + const extractResult = this.extractTableName(table); + if (extractResult) { + const { model, alias } = extractResult; + const filter = this.buildPolicyFilter(model, alias, 'read'); + return acc ? conjunction(this.dialect, [acc, filter]) : filter; + } + return acc; + }, undefined); + } + + private compilePolicyCondition( model: GetModels, alias: string | undefined, operation: CRUD, policy: Policy, ) { - return new ExpressionTransformer(this.client.$schema, this.client.$options, this.client.$auth).transform( - policy.condition, - { - model, - alias, - operation, - auth: this.client.$auth, - }, - ); + return new ExpressionTransformer(this.client).transform(policy.condition, { + model, + alias, + operation, + auth: this.client.$auth, + }); } - private getModelPolicies(modelName: string, operation: PolicyOperation) { - const modelDef = requireModel(this.client.$schema, modelName); + private getModelPolicies(model: string, operation: PolicyOperation) { + const modelDef = requireModel(this.client.$schema, model); const result: Policy[] = []; const extractOperations = (expr: Expression) => { @@ -528,4 +715,93 @@ export class PolicyHandler extends OperationNodeTransf } return result; } + + private resolveManyToManyJoinTable(tableName: string) { + for (const model of Object.values(this.client.$schema.models)) { + for (const field of Object.values(model.fields)) { + const m2m = getManyToManyRelation(this.client.$schema, model.name, field.name); + if (m2m?.joinTable === tableName) { + const sortedRecord = [ + { + model: model.name, + field: field.name, + }, + { + model: m2m.otherModel, + field: m2m.otherField, + }, + ].sort(this.manyToManySorter); + + const firstIdFields = requireIdFields(this.client.$schema, sortedRecord[0]!.model); + const secondIdFields = requireIdFields(this.client.$schema, sortedRecord[1]!.model); + invariant( + firstIdFields.length === 1 && secondIdFields.length === 1, + 'only single-field id is supported for implicit many-to-many join table', + ); + + return { + firstModel: sortedRecord[0]!.model, + firstField: sortedRecord[0]!.field, + firstIdField: firstIdFields[0]!, + secondModel: sortedRecord[1]!.model, + secondField: sortedRecord[1]!.field, + secondIdField: secondIdFields[0]!, + }; + } + } + } + return undefined; + } + + private manyToManySorter(a: { model: string; field: string }, b: { model: string; field: string }): number { + // the implicit m2m join table's "A", "B" fk fields' order is determined + // by model name's sort order, and when identical (for self-relations), + // field name's sort order + return a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field); + } + + private isManyToManyJoinTable(tableName: string) { + return !!this.resolveManyToManyJoinTable(tableName); + } + + private getModelPolicyFilterForManyToManyJoinTable( + tableName: string, + alias: string | undefined, + operation: PolicyOperation, + ): OperationNode | undefined { + const m2m = this.resolveManyToManyJoinTable(tableName); + if (!m2m) { + return undefined; + } + + // join table's permission: + // - read: requires both sides to be readable + // - mutation: requires both sides to be updatable + + const checkForOperation = operation === 'read' ? 'read' : 'update'; + const eb = expressionBuilder(); + const joinTable = alias ?? tableName; + + const aQuery = eb + .selectFrom(m2m.firstModel) + .whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, '=', `${joinTable}.A`) + .select(() => + new ExpressionWrapper( + this.buildPolicyFilter(m2m.firstModel as GetModels, undefined, checkForOperation), + ).as('$conditionA'), + ); + + const bQuery = eb + .selectFrom(m2m.secondModel) + .whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, '=', `${joinTable}.B`) + .select(() => + new ExpressionWrapper( + this.buildPolicyFilter(m2m.secondModel as GetModels, undefined, checkForOperation), + ).as('$conditionB'), + ); + + return eb.and([aQuery, bQuery]).toOperationNode(); + } + + // #endregion } diff --git a/packages/runtime/src/utils/zod-utils.ts b/packages/runtime/src/utils/zod-utils.ts new file mode 100644 index 00000000..2ca23ca8 --- /dev/null +++ b/packages/runtime/src/utils/zod-utils.ts @@ -0,0 +1,14 @@ +import { ZodError } from 'zod'; +import { fromError as fromError3 } from 'zod-validation-error/v3'; +import { fromError as fromError4 } from 'zod-validation-error/v4'; + +/** + * Format ZodError into a readable string + */ +export function formatError(error: ZodError): string { + if ('_zod' in error) { + return fromError4(error).toString(); + } else { + return fromError3(error).toString(); + } +} diff --git a/packages/runtime/test/client-api/aggregate.test.ts b/packages/runtime/test/client-api/aggregate.test.ts index 0c8ffd27..6b7edd64 100644 --- a/packages/runtime/test/client-api/aggregate.test.ts +++ b/packages/runtime/test/client-api/aggregate.test.ts @@ -1,16 +1,14 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; import { createUser } from './utils'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-aggregate-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client aggregate tests', ({ createClient }) => { +describe('Client aggregate tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/client-api/client-specs.ts b/packages/runtime/test/client-api/client-specs.ts deleted file mode 100644 index 59e50d5d..00000000 --- a/packages/runtime/test/client-api/client-specs.ts +++ /dev/null @@ -1,42 +0,0 @@ -import type { LogEvent } from 'kysely'; -import { getSchema, schema } from '../schemas/basic'; -import { makePostgresClient, makeSqliteClient } from '../utils'; -import type { ClientContract } from '../../src'; - -export function createClientSpecs(dbName: string, logQueries = false, providers: string[] = ['sqlite', 'postgresql']) { - const logger = (provider: string) => (event: LogEvent) => { - if (event.level === 'query') { - console.log(`query(${provider}):`, event.query.sql, event.query.parameters); - } - }; - return [ - ...(providers.includes('sqlite') - ? [ - { - provider: 'sqlite' as const, - schema: getSchema('sqlite'), - createClient: async (): Promise> => { - // tsc perf - return makeSqliteClient(getSchema('sqlite'), { - log: logQueries ? logger('sqlite') : undefined, - }) as unknown as ClientContract; - }, - }, - ] - : []), - ...(providers.includes('postgresql') - ? [ - { - provider: 'postgresql' as const, - schema: getSchema('postgresql'), - createClient: async (): Promise> => { - // tsc perf - return makePostgresClient(getSchema('postgresql'), dbName, { - log: logQueries ? logger('postgresql') : undefined, - }) as unknown as ClientContract; - }, - }, - ] - : []), - ] as const; -} diff --git a/packages/runtime/test/client-api/computed-fields.test.ts b/packages/runtime/test/client-api/computed-fields.test.ts index 054997a3..0ece9ddf 100644 --- a/packages/runtime/test/client-api/computed-fields.test.ts +++ b/packages/runtime/test/client-api/computed-fields.test.ts @@ -2,121 +2,113 @@ import { sql } from 'kysely'; import { afterEach, describe, expect, it } from 'vitest'; import { createTestClient } from '../utils'; -const TEST_DB = 'client-api-computed-fields'; +describe('Computed fields tests', () => { + let db: any; -describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as const }])( - 'Computed fields tests', - ({ provider }) => { - let db: any; - - afterEach(async () => { - await db?.$disconnect(); - }); + afterEach(async () => { + await db?.$disconnect(); + }); - it('works with non-optional fields', async () => { - db = await createTestClient( - ` + it('works with non-optional fields', async () => { + db = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String upperName String @computed } `, - { - provider, - dbName: TEST_DB, - computedFields: { - User: { - upperName: (eb: any) => eb.fn('upper', ['name']), - }, + { + computedFields: { + User: { + upperName: (eb: any) => eb.fn('upper', ['name']), }, - } as any, - ); - - await expect( - db.user.create({ - data: { id: 1, name: 'Alex' }, - }), - ).resolves.toMatchObject({ - upperName: 'ALEX', - }); + }, + } as any, + ); + + await expect( + db.user.create({ + data: { id: 1, name: 'Alex' }, + }), + ).resolves.toMatchObject({ + upperName: 'ALEX', + }); - await expect( - db.user.findUnique({ - where: { id: 1 }, - select: { upperName: true }, - }), - ).resolves.toMatchObject({ - upperName: 'ALEX', - }); + await expect( + db.user.findUnique({ + where: { id: 1 }, + select: { upperName: true }, + }), + ).resolves.toMatchObject({ + upperName: 'ALEX', + }); - await expect( - db.user.findFirst({ - where: { upperName: 'ALEX' }, - }), - ).resolves.toMatchObject({ - upperName: 'ALEX', - }); + await expect( + db.user.findFirst({ + where: { upperName: 'ALEX' }, + }), + ).resolves.toMatchObject({ + upperName: 'ALEX', + }); - await expect( - db.user.findFirst({ - where: { upperName: 'Alex' }, - }), - ).toResolveNull(); + await expect( + db.user.findFirst({ + where: { upperName: 'Alex' }, + }), + ).toResolveNull(); + + await expect( + db.user.findFirst({ + orderBy: { upperName: 'desc' }, + }), + ).resolves.toMatchObject({ + upperName: 'ALEX', + }); - await expect( - db.user.findFirst({ - orderBy: { upperName: 'desc' }, - }), - ).resolves.toMatchObject({ - upperName: 'ALEX', - }); + await expect( + db.user.findFirst({ + orderBy: { upperName: 'desc' }, + take: 1, + }), + ).resolves.toMatchObject({ + upperName: 'ALEX', + }); - await expect( - db.user.findFirst({ - orderBy: { upperName: 'desc' }, - take: 1, - }), - ).resolves.toMatchObject({ - upperName: 'ALEX', - }); + await expect( + db.user.aggregate({ + _count: { upperName: true }, + }), + ).resolves.toMatchObject({ + _count: { upperName: 1 }, + }); - await expect( - db.user.aggregate({ - _count: { upperName: true }, - }), - ).resolves.toMatchObject({ + await expect( + db.user.groupBy({ + by: ['upperName'], + _count: { upperName: true }, + _max: { upperName: true }, + }), + ).resolves.toEqual([ + expect.objectContaining({ _count: { upperName: 1 }, - }); - - await expect( - db.user.groupBy({ - by: ['upperName'], - _count: { upperName: true }, - _max: { upperName: true }, - }), - ).resolves.toEqual([ - expect.objectContaining({ - _count: { upperName: 1 }, - _max: { upperName: 'ALEX' }, - }), - ]); - }); + _max: { upperName: 'ALEX' }, + }), + ]); + }); - it('is typed correctly for non-optional fields', async () => { - db = await createTestClient( - ` + it('is typed correctly for non-optional fields', async () => { + db = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String upperName String @computed } `, - { - provider, - dbName: TEST_DB, - extraSourceFiles: { - main: ` + { + extraSourceFiles: { + main: ` import { ZenStackClient } from '@zenstackhq/runtime'; import { schema } from './schema'; @@ -140,54 +132,50 @@ async function main() { main(); `, - }, }, - ); - }); + }, + ); + }); - it('works with optional fields', async () => { - db = await createTestClient( - ` + it('works with optional fields', async () => { + db = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String upperName String? @computed } `, - { - provider, - dbName: TEST_DB, - computedFields: { - User: { - upperName: (eb: any) => eb.lit(null), - }, + { + computedFields: { + User: { + upperName: (eb: any) => eb.lit(null), }, - } as any, - ); - - await expect( - db.user.create({ - data: { id: 1, name: 'Alex' }, - }), - ).resolves.toMatchObject({ - upperName: null, - }); + }, + } as any, + ); + + await expect( + db.user.create({ + data: { id: 1, name: 'Alex' }, + }), + ).resolves.toMatchObject({ + upperName: null, }); + }); - it('is typed correctly for optional fields', async () => { - db = await createTestClient( - ` + it('is typed correctly for optional fields', async () => { + db = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String upperName String? @computed } `, - { - provider, - dbName: TEST_DB, - extraSourceFiles: { - main: ` + { + extraSourceFiles: { + main: ` import { ZenStackClient } from '@zenstackhq/runtime'; import { schema } from './schema'; @@ -210,14 +198,14 @@ async function main() { main(); `, - }, }, - ); - }); + }, + ); + }); - it('works with read from a relation', async () => { - db = await createTestClient( - ` + it('works with read from a relation', async () => { + db = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -232,28 +220,25 @@ model Post { authorId Int } `, - { - provider, - dbName: TEST_DB, - computedFields: { - User: { - postCount: (eb: any, context: { currentModel: string }) => - eb - .selectFrom('Post') - .whereRef('Post.authorId', '=', sql.ref(`${context.currentModel}.id`)) - .select(() => eb.fn.countAll().as('count')), - }, + { + computedFields: { + User: { + postCount: (eb: any, context: { modelAlias: string }) => + eb + .selectFrom('Post') + .whereRef('Post.authorId', '=', sql.ref(`${context.modelAlias}.id`)) + .select(() => eb.fn.countAll().as('count')), }, - } as any, - ); + }, + } as any, + ); - await db.user.create({ - data: { id: 1, name: 'Alex', posts: { create: { title: 'Post1' } } }, - }); + await db.user.create({ + data: { id: 1, name: 'Alex', posts: { create: { title: 'Post1' } } }, + }); - await expect(db.post.findFirst({ select: { id: true, author: true } })).resolves.toMatchObject({ - author: expect.objectContaining({ postCount: 1 }), - }); + await expect(db.post.findFirst({ select: { id: true, author: true } })).resolves.toMatchObject({ + author: expect.objectContaining({ postCount: 1 }), }); - }, -); + }); +}); diff --git a/packages/runtime/test/client-api/count.test.ts b/packages/runtime/test/client-api/count.test.ts index 743b4169..22a89ddc 100644 --- a/packages/runtime/test/client-api/count.test.ts +++ b/packages/runtime/test/client-api/count.test.ts @@ -1,15 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-count-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client count tests', ({ createClient }) => { +describe('Client count tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/client-api/create-many-and-return.test.ts b/packages/runtime/test/client-api/create-many-and-return.test.ts index 29d5887e..be2a46e8 100644 --- a/packages/runtime/test/client-api/create-many-and-return.test.ts +++ b/packages/runtime/test/client-api/create-many-and-return.test.ts @@ -1,15 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-create-many-and-return-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client createManyAndReturn tests', ({ createClient }) => { +describe('Client createManyAndReturn tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/client-api/create-many.test.ts b/packages/runtime/test/client-api/create-many.test.ts index d25d8587..3ccbbe73 100644 --- a/packages/runtime/test/client-api/create-many.test.ts +++ b/packages/runtime/test/client-api/create-many.test.ts @@ -1,15 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-create-many-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client createMany tests', ({ createClient }) => { +describe('Client createMany tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/client-api/create.test.ts b/packages/runtime/test/client-api/create.test.ts index 8cd692fd..41ab341c 100644 --- a/packages/runtime/test/client-api/create.test.ts +++ b/packages/runtime/test/client-api/create.test.ts @@ -1,15 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-create-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client create tests', ({ createClient }) => { +describe('Client create tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/client-api/delegate.test.ts b/packages/runtime/test/client-api/delegate.test.ts index a9ca705c..d9efff6f 100644 --- a/packages/runtime/test/client-api/delegate.test.ts +++ b/packages/runtime/test/client-api/delegate.test.ts @@ -4,1169 +4,1162 @@ import type { ClientContract } from '../../src'; import { schema, type SchemaType } from '../schemas/delegate/schema'; import { createTestClient } from '../utils'; -const DB_NAME = `client-api-delegate-tests`; +describe('Delegate model tests ', () => { + let client: ClientContract; + + beforeEach(async () => { + client = await createTestClient( + schema, + { + usePrismaPush: true, + }, + path.join(__dirname, '../schemas/delegate/schema.zmodel'), + ); + }); + + afterEach(async () => { + await client.$disconnect(); + }); + + describe('Delegate create tests', () => { + it('works with create', async () => { + // delegate model cannot be created directly + await expect( + // @ts-expect-error + client.video.create({ + data: { + duration: 100, + url: 'abc', + videoType: 'MyVideo', + }, + }), + ).rejects.toThrow('is a delegate'); + await expect( + client.user.create({ + data: { + assets: { + // @ts-expect-error + create: { assetType: 'Video' }, + }, + }, + }), + ).rejects.toThrow('is a delegate'); -describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as const }])( - 'Delegate model tests for $provider', - ({ provider }) => { - let client: ClientContract; + // create entity with two levels of delegation + await expect( + client.ratedVideo.create({ + data: { + duration: 100, + url: 'abc', + rating: 5, + }, + }), + ).resolves.toMatchObject({ + id: expect.any(Number), + duration: 100, + url: 'abc', + rating: 5, + assetType: 'Video', + videoType: 'RatedVideo', + }); - beforeEach(async () => { - client = await createTestClient( - schema, - { - usePrismaPush: true, - provider, - dbName: provider === 'postgresql' ? DB_NAME : undefined, + // create entity with relation + await expect( + client.ratedVideo.create({ + data: { + duration: 50, + url: 'bcd', + rating: 5, + user: { create: { email: 'u1@example.com' } }, + }, + include: { user: true }, + }), + ).resolves.toMatchObject({ + userId: expect.any(Number), + user: { + email: 'u1@example.com', }, - path.join(__dirname, '../schemas/delegate/schema.zmodel'), - ); - }); - - afterEach(async () => { - await client.$disconnect(); - }); + }); - describe('Delegate create tests', () => { - it('works with create', async () => { - // delegate model cannot be created directly - await expect( - // @ts-expect-error - client.video.create({ - data: { - duration: 100, - url: 'abc', - videoType: 'MyVideo', - }, - }), - ).rejects.toThrow('is a delegate'); - await expect( - client.user.create({ - data: { - assets: { - // @ts-expect-error - create: { assetType: 'Video' }, - }, - }, - }), - ).rejects.toThrow('is a delegate'); - - // create entity with two levels of delegation - await expect( - client.ratedVideo.create({ - data: { - duration: 100, - url: 'abc', - rating: 5, - }, - }), - ).resolves.toMatchObject({ - id: expect.any(Number), - duration: 100, - url: 'abc', - rating: 5, - assetType: 'Video', - videoType: 'RatedVideo', - }); - - // create entity with relation - await expect( - client.ratedVideo.create({ - data: { - duration: 50, - url: 'bcd', - rating: 5, - user: { create: { email: 'u1@example.com' } }, + // create entity with one level of delegation + await expect( + client.image.create({ + data: { + format: 'png', + gallery: { + create: {}, }, - include: { user: true }, - }), - ).resolves.toMatchObject({ - userId: expect.any(Number), - user: { - email: 'u1@example.com', }, - }); - - // create entity with one level of delegation - await expect( - client.image.create({ - data: { - format: 'png', - gallery: { - create: {}, - }, - }, - }), - ).resolves.toMatchObject({ - id: expect.any(Number), - format: 'png', - galleryId: expect.any(Number), - assetType: 'Image', - }); + }), + ).resolves.toMatchObject({ + id: expect.any(Number), + format: 'png', + galleryId: expect.any(Number), + assetType: 'Image', }); + }); + + it('works with createMany', async () => { + await expect( + client.ratedVideo.createMany({ + data: [ + { viewCount: 1, duration: 100, url: 'abc', rating: 5 }, + { viewCount: 2, duration: 200, url: 'def', rating: 4 }, + ], + }), + ).resolves.toEqual({ count: 2 }); - it('works with createMany', async () => { - await expect( - client.ratedVideo.createMany({ - data: [ - { viewCount: 1, duration: 100, url: 'abc', rating: 5 }, - { viewCount: 2, duration: 200, url: 'def', rating: 4 }, - ], + await expect(client.ratedVideo.findMany()).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ + viewCount: 1, + duration: 100, + url: 'abc', + rating: 5, }), - ).resolves.toEqual({ count: 2 }); - - await expect(client.ratedVideo.findMany()).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ - viewCount: 1, - duration: 100, - url: 'abc', - rating: 5, - }), - expect.objectContaining({ - viewCount: 2, - duration: 200, - url: 'def', - rating: 4, - }), - ]), - ); - - await expect( - client.ratedVideo.createMany({ - data: [ - { viewCount: 1, duration: 100, url: 'abc', rating: 5 }, - { viewCount: 2, duration: 200, url: 'def', rating: 4 }, - ], - skipDuplicates: true, + expect.objectContaining({ + viewCount: 2, + duration: 200, + url: 'def', + rating: 4, }), - ).rejects.toThrow('not supported'); - }); + ]), + ); - it('works with createManyAndReturn', async () => { - await expect( - client.ratedVideo.createManyAndReturn({ - data: [ - { viewCount: 1, duration: 100, url: 'abc', rating: 5 }, - { viewCount: 2, duration: 200, url: 'def', rating: 4 }, - ], - }), - ).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ - viewCount: 1, - duration: 100, - url: 'abc', - rating: 5, - }), - expect.objectContaining({ - viewCount: 2, - duration: 200, - url: 'def', - rating: 4, - }), - ]), - ); - }); + await expect( + client.ratedVideo.createMany({ + data: [ + { viewCount: 1, duration: 100, url: 'abc', rating: 5 }, + { viewCount: 2, duration: 200, url: 'def', rating: 4 }, + ], + skipDuplicates: true, + }), + ).rejects.toThrow('not supported'); + }); - it('ensures create is atomic', async () => { - // create with a relation that fails - await expect( - client.ratedVideo.create({ - data: { - duration: 100, - url: 'abc', - rating: 5, - }, + it('works with createManyAndReturn', async () => { + await expect( + client.ratedVideo.createManyAndReturn({ + data: [ + { viewCount: 1, duration: 100, url: 'abc', rating: 5 }, + { viewCount: 2, duration: 200, url: 'def', rating: 4 }, + ], + }), + ).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ + viewCount: 1, + duration: 100, + url: 'abc', + rating: 5, }), - ).toResolveTruthy(); - await expect( - client.ratedVideo.create({ - data: { - duration: 200, - url: 'abc', - rating: 3, - }, + expect.objectContaining({ + viewCount: 2, + duration: 200, + url: 'def', + rating: 4, }), - ).rejects.toThrow('constraint'); + ]), + ); + }); - await expect(client.ratedVideo.findMany()).toResolveWithLength(1); - await expect(client.video.findMany()).toResolveWithLength(1); - await expect(client.asset.findMany()).toResolveWithLength(1); - }); + it('ensures create is atomic', async () => { + // create with a relation that fails + await expect( + client.ratedVideo.create({ + data: { + duration: 100, + url: 'abc', + rating: 5, + }, + }), + ).toResolveTruthy(); + await expect( + client.ratedVideo.create({ + data: { + duration: 200, + url: 'abc', + rating: 3, + }, + }), + ).rejects.toThrow('constraint'); + + await expect(client.ratedVideo.findMany()).toResolveWithLength(1); + await expect(client.video.findMany()).toResolveWithLength(1); + await expect(client.asset.findMany()).toResolveWithLength(1); + }); + }); + + it('works with find', async () => { + const u = await client.user.create({ + data: { + email: 'u1@example.com', + }, + }); + const v = await client.ratedVideo.create({ + data: { + duration: 100, + url: 'abc', + rating: 5, + owner: { connect: { id: u.id } }, + user: { connect: { id: u.id } }, + }, + }); + + const ratedVideoContent = { + id: v.id, + createdAt: expect.any(Date), + duration: 100, + rating: 5, + assetType: 'Video', + videoType: 'RatedVideo', + }; + + // include all base fields + await expect( + client.ratedVideo.findUnique({ + where: { id: v.id }, + include: { user: true, owner: true }, + }), + ).resolves.toMatchObject({ ...ratedVideoContent, user: expect.any(Object), owner: expect.any(Object) }); + + // select fields + await expect( + client.ratedVideo.findUnique({ + where: { id: v.id }, + select: { + id: true, + viewCount: true, + url: true, + rating: true, + }, + }), + ).resolves.toEqual({ + id: v.id, + viewCount: 0, + url: 'abc', + rating: 5, + }); + + // omit fields + const r: any = await client.ratedVideo.findUnique({ + where: { id: v.id }, + omit: { + viewCount: true, + url: true, + rating: true, + }, + }); + expect(r.viewCount).toBeUndefined(); + expect(r.url).toBeUndefined(); + expect(r.rating).toBeUndefined(); + expect(r.duration).toEqual(expect.any(Number)); + + // include all sub fields + await expect( + client.video.findUnique({ + where: { id: v.id }, + }), + ).resolves.toMatchObject(ratedVideoContent); + + // include all sub fields + await expect( + client.asset.findUnique({ + where: { id: v.id }, + }), + ).resolves.toMatchObject(ratedVideoContent); + + // find as a relation + await expect( + client.user.findUnique({ + where: { id: u.id }, + include: { assets: true, ratedVideos: true }, + }), + ).resolves.toMatchObject({ + assets: [ratedVideoContent], + ratedVideos: [ratedVideoContent], + }); + + // find as a relation with selection + await expect( + client.user.findUnique({ + where: { id: u.id }, + include: { + assets: { + select: { id: true, assetType: true }, + }, + ratedVideos: { + select: { + url: true, + rating: true, + }, + }, + }, + }), + ).resolves.toMatchObject({ + assets: [{ id: v.id, assetType: 'Video' }], + ratedVideos: [{ url: 'abc', rating: 5 }], }); + }); - it('works with find', async () => { + describe('Delegate filter tests', () => { + beforeEach(async () => { const u = await client.user.create({ data: { email: 'u1@example.com', }, }); - const v = await client.ratedVideo.create({ + await client.ratedVideo.create({ data: { + viewCount: 0, duration: 100, - url: 'abc', + url: 'v1', rating: 5, owner: { connect: { id: u.id } }, user: { connect: { id: u.id } }, + comments: { create: { content: 'c1' } }, }, }); + await client.ratedVideo.create({ + data: { + viewCount: 1, + duration: 200, + url: 'v2', + rating: 4, + owner: { connect: { id: u.id } }, + user: { connect: { id: u.id } }, + comments: { create: { content: 'c2' } }, + }, + }); + }); - const ratedVideoContent = { - id: v.id, - createdAt: expect.any(Date), - duration: 100, - rating: 5, - assetType: 'Video', - videoType: 'RatedVideo', - }; - - // include all base fields + it('works with toplevel filters', async () => { await expect( - client.ratedVideo.findUnique({ - where: { id: v.id }, - include: { user: true, owner: true }, + client.asset.findMany({ + where: { viewCount: { gt: 0 } }, }), - ).resolves.toMatchObject({ ...ratedVideoContent, user: expect.any(Object), owner: expect.any(Object) }); + ).toResolveWithLength(1); - // select fields await expect( - client.ratedVideo.findUnique({ - where: { id: v.id }, - select: { - id: true, - viewCount: true, - url: true, - rating: true, - }, + client.video.findMany({ + where: { viewCount: { gt: 0 }, url: 'v1' }, }), - ).resolves.toEqual({ - id: v.id, - viewCount: 0, - url: 'abc', - rating: 5, - }); - - // omit fields - const r: any = await client.ratedVideo.findUnique({ - where: { id: v.id }, - omit: { - viewCount: true, - url: true, - rating: true, - }, - }); - expect(r.viewCount).toBeUndefined(); - expect(r.url).toBeUndefined(); - expect(r.rating).toBeUndefined(); - expect(r.duration).toEqual(expect.any(Number)); + ).toResolveWithLength(0); - // include all sub fields await expect( - client.video.findUnique({ - where: { id: v.id }, + client.video.findMany({ + where: { viewCount: { gt: 0 }, url: 'v2' }, }), - ).resolves.toMatchObject(ratedVideoContent); + ).toResolveWithLength(1); - // include all sub fields await expect( - client.asset.findUnique({ - where: { id: v.id }, + client.ratedVideo.findMany({ + where: { viewCount: { gt: 0 }, rating: 5 }, }), - ).resolves.toMatchObject(ratedVideoContent); + ).toResolveWithLength(0); - // find as a relation await expect( - client.user.findUnique({ - where: { id: u.id }, - include: { assets: true, ratedVideos: true }, + client.ratedVideo.findMany({ + where: { viewCount: { gt: 0 }, rating: 4 }, }), - ).resolves.toMatchObject({ - assets: [ratedVideoContent], - ratedVideos: [ratedVideoContent], - }); + ).toResolveWithLength(1); + }); - // find as a relation with selection + it('works with filtering relations', async () => { await expect( - client.user.findUnique({ - where: { id: u.id }, + client.user.findFirst({ include: { assets: { - select: { id: true, assetType: true }, + where: { viewCount: { gt: 0 } }, }, + }, + }), + ).resolves.toSatisfy((user) => user.assets.length === 1); + + await expect( + client.user.findFirst({ + include: { ratedVideos: { - select: { - url: true, - rating: true, - }, + where: { viewCount: { gt: 0 }, url: 'v1' }, }, }, }), - ).resolves.toMatchObject({ - assets: [{ id: v.id, assetType: 'Video' }], - ratedVideos: [{ url: 'abc', rating: 5 }], - }); - }); + ).resolves.toSatisfy((user) => user.ratedVideos.length === 0); - describe('Delegate filter tests', async () => { - beforeEach(async () => { - const u = await client.user.create({ - data: { - email: 'u1@example.com', - }, - }); - await client.ratedVideo.create({ - data: { - viewCount: 0, - duration: 100, - url: 'v1', - rating: 5, - owner: { connect: { id: u.id } }, - user: { connect: { id: u.id } }, - comments: { create: { content: 'c1' } }, - }, - }); - await client.ratedVideo.create({ - data: { - viewCount: 1, - duration: 200, - url: 'v2', - rating: 4, - owner: { connect: { id: u.id } }, - user: { connect: { id: u.id } }, - comments: { create: { content: 'c2' } }, + await expect( + client.user.findFirst({ + include: { + ratedVideos: { + where: { viewCount: { gt: 0 }, url: 'v2' }, + }, }, - }); - }); - - it('works with toplevel filters', async () => { - await expect( - client.asset.findMany({ - where: { viewCount: { gt: 0 } }, - }), - ).toResolveWithLength(1); - - await expect( - client.video.findMany({ - where: { viewCount: { gt: 0 }, url: 'v1' }, - }), - ).toResolveWithLength(0); - - await expect( - client.video.findMany({ - where: { viewCount: { gt: 0 }, url: 'v2' }, - }), - ).toResolveWithLength(1); - - await expect( - client.ratedVideo.findMany({ - where: { viewCount: { gt: 0 }, rating: 5 }, - }), - ).toResolveWithLength(0); - - await expect( - client.ratedVideo.findMany({ - where: { viewCount: { gt: 0 }, rating: 4 }, - }), - ).toResolveWithLength(1); - }); + }), + ).resolves.toSatisfy((user) => user.ratedVideos.length === 1); - it('works with filtering relations', async () => { - await expect( - client.user.findFirst({ - include: { - assets: { - where: { viewCount: { gt: 0 } }, - }, + await expect( + client.user.findFirst({ + include: { + ratedVideos: { + where: { viewCount: { gt: 0 }, rating: 5 }, }, - }), - ).resolves.toSatisfy((user) => user.assets.length === 1); + }, + }), + ).resolves.toSatisfy((user) => user.ratedVideos.length === 0); - await expect( - client.user.findFirst({ - include: { - ratedVideos: { - where: { viewCount: { gt: 0 }, url: 'v1' }, - }, + await expect( + client.user.findFirst({ + include: { + ratedVideos: { + where: { viewCount: { gt: 0 }, rating: 4 }, }, - }), - ).resolves.toSatisfy((user) => user.ratedVideos.length === 0); + }, + }), + ).resolves.toSatisfy((user) => user.ratedVideos.length === 1); + }); - await expect( - client.user.findFirst({ - include: { - ratedVideos: { - where: { viewCount: { gt: 0 }, url: 'v2' }, - }, + it('works with filtering parents', async () => { + await expect( + client.user.findFirst({ + where: { + assets: { + some: { viewCount: { gt: 0 } }, }, - }), - ).resolves.toSatisfy((user) => user.ratedVideos.length === 1); + }, + }), + ).toResolveTruthy(); - await expect( - client.user.findFirst({ - include: { - ratedVideos: { - where: { viewCount: { gt: 0 }, rating: 5 }, - }, + await expect( + client.user.findFirst({ + where: { + assets: { + some: { viewCount: { gt: 1 } }, }, - }), - ).resolves.toSatisfy((user) => user.ratedVideos.length === 0); + }, + }), + ).toResolveFalsy(); - await expect( - client.user.findFirst({ - include: { - ratedVideos: { - where: { viewCount: { gt: 0 }, rating: 4 }, - }, + await expect( + client.user.findFirst({ + where: { + ratedVideos: { + some: { viewCount: { gt: 0 }, url: 'v1' }, }, - }), - ).resolves.toSatisfy((user) => user.ratedVideos.length === 1); - }); + }, + }), + ).toResolveFalsy(); - it('works with filtering parents', async () => { - await expect( - client.user.findFirst({ - where: { - assets: { - some: { viewCount: { gt: 0 } }, - }, + await expect( + client.user.findFirst({ + where: { + ratedVideos: { + some: { viewCount: { gt: 0 }, url: 'v2' }, }, - }), - ).toResolveTruthy(); + }, + }), + ).toResolveTruthy(); + }); - await expect( - client.user.findFirst({ - where: { - assets: { - some: { viewCount: { gt: 1 } }, - }, + it('works with filtering with relations from base', async () => { + await expect( + client.video.findFirst({ + where: { + owner: { + email: 'u1@example.com', }, - }), - ).toResolveFalsy(); + }, + }), + ).toResolveTruthy(); - await expect( - client.user.findFirst({ - where: { - ratedVideos: { - some: { viewCount: { gt: 0 }, url: 'v1' }, - }, - }, - }), - ).toResolveFalsy(); - - await expect( - client.user.findFirst({ - where: { - ratedVideos: { - some: { viewCount: { gt: 0 }, url: 'v2' }, - }, - }, - }), - ).toResolveTruthy(); - }); - - it('works with filtering with relations from base', async () => { - await expect( - client.video.findFirst({ - where: { - owner: { - email: 'u1@example.com', - }, + await expect( + client.video.findFirst({ + where: { + owner: { + email: 'u2@example.com', }, - }), - ).toResolveTruthy(); + }, + }), + ).toResolveFalsy(); - await expect( - client.video.findFirst({ - where: { - owner: { - email: 'u2@example.com', - }, - }, - }), - ).toResolveFalsy(); + await expect( + client.video.findFirst({ + where: { + owner: null, + }, + }), + ).toResolveFalsy(); - await expect( - client.video.findFirst({ - where: { - owner: null, - }, - }), - ).toResolveFalsy(); + await expect( + client.video.findFirst({ + where: { + owner: { is: null }, + }, + }), + ).toResolveFalsy(); - await expect( - client.video.findFirst({ - where: { - owner: { is: null }, - }, - }), - ).toResolveFalsy(); + await expect( + client.video.findFirst({ + where: { + owner: { isNot: null }, + }, + }), + ).toResolveTruthy(); - await expect( - client.video.findFirst({ - where: { - owner: { isNot: null }, + await expect( + client.video.findFirst({ + where: { + comments: { + some: { content: 'c1' }, }, - }), - ).toResolveTruthy(); + }, + }), + ).toResolveTruthy(); - await expect( - client.video.findFirst({ - where: { - comments: { - some: { content: 'c1' }, - }, + await expect( + client.video.findFirst({ + where: { + comments: { + every: { content: 'c2' }, }, - }), - ).toResolveTruthy(); + }, + }), + ).toResolveTruthy(); - await expect( - client.video.findFirst({ - where: { - comments: { - every: { content: 'c2' }, - }, + await expect( + client.video.findFirst({ + where: { + comments: { + none: { content: 'c1' }, }, - }), - ).toResolveTruthy(); + }, + }), + ).toResolveTruthy(); - await expect( - client.video.findFirst({ - where: { - comments: { - none: { content: 'c1' }, - }, + await expect( + client.video.findFirst({ + where: { + comments: { + none: { content: { startsWith: 'c' } }, }, - }), - ).toResolveTruthy(); + }, + }), + ).toResolveFalsy(); + }); + }); - await expect( - client.video.findFirst({ - where: { - comments: { - none: { content: { startsWith: 'c' } }, - }, - }, - }), - ).toResolveFalsy(); + describe('Delegate update tests', () => { + beforeEach(async () => { + const u = await client.user.create({ + data: { + id: 1, + email: 'u1@example.com', + }, + }); + await client.ratedVideo.create({ + data: { + id: 1, + viewCount: 0, + duration: 100, + url: 'v1', + rating: 5, + owner: { connect: { id: u.id } }, + user: { connect: { id: u.id } }, + }, }); }); - describe('Delegate update tests', async () => { - beforeEach(async () => { - const u = await client.user.create({ - data: { - id: 1, - email: 'u1@example.com', - }, - }); - await client.ratedVideo.create({ - data: { - id: 1, - viewCount: 0, - duration: 100, - url: 'v1', - rating: 5, - owner: { connect: { id: u.id } }, - user: { connect: { id: u.id } }, - }, - }); + it('works with toplevel update', async () => { + // id filter + await expect( + client.ratedVideo.update({ + where: { id: 1 }, + data: { viewCount: { increment: 1 }, duration: 200, rating: { set: 4 } }, + }), + ).resolves.toMatchObject({ + viewCount: 1, + duration: 200, + rating: 4, + }); + await expect( + client.video.update({ + where: { id: 1 }, + data: { viewCount: { decrement: 1 }, duration: 100 }, + }), + ).resolves.toMatchObject({ + viewCount: 0, + duration: 100, + }); + await expect( + client.asset.update({ + where: { id: 1 }, + data: { viewCount: { increment: 1 } }, + }), + ).resolves.toMatchObject({ + viewCount: 1, }); - it('works with toplevel update', async () => { - // id filter - await expect( - client.ratedVideo.update({ - where: { id: 1 }, - data: { viewCount: { increment: 1 }, duration: 200, rating: { set: 4 } }, - }), - ).resolves.toMatchObject({ - viewCount: 1, - duration: 200, - rating: 4, - }); - await expect( - client.video.update({ - where: { id: 1 }, - data: { viewCount: { decrement: 1 }, duration: 100 }, - }), - ).resolves.toMatchObject({ - viewCount: 0, - duration: 100, - }); - await expect( - client.asset.update({ - where: { id: 1 }, - data: { viewCount: { increment: 1 } }, - }), - ).resolves.toMatchObject({ - viewCount: 1, - }); + // unique field filter + await expect( + client.ratedVideo.update({ + where: { url: 'v1' }, + data: { viewCount: 2, duration: 300, rating: 3 }, + }), + ).resolves.toMatchObject({ + viewCount: 2, + duration: 300, + rating: 3, + }); + await expect( + client.video.update({ + where: { url: 'v1' }, + data: { viewCount: 3 }, + }), + ).resolves.toMatchObject({ + viewCount: 3, + }); - // unique field filter - await expect( - client.ratedVideo.update({ - where: { url: 'v1' }, - data: { viewCount: 2, duration: 300, rating: 3 }, - }), - ).resolves.toMatchObject({ - viewCount: 2, - duration: 300, - rating: 3, - }); - await expect( - client.video.update({ - where: { url: 'v1' }, - data: { viewCount: 3 }, - }), - ).resolves.toMatchObject({ - viewCount: 3, - }); - - // not found - await expect( - client.ratedVideo.update({ - where: { url: 'v2' }, - data: { viewCount: 4 }, - }), - ).toBeRejectedNotFound(); + // not found + await expect( + client.ratedVideo.update({ + where: { url: 'v2' }, + data: { viewCount: 4 }, + }), + ).toBeRejectedNotFound(); - // update id - await expect( - client.ratedVideo.update({ - where: { id: 1 }, - data: { id: 2 }, - }), - ).resolves.toMatchObject({ - id: 2, - viewCount: 3, - }); + // update id + await expect( + client.ratedVideo.update({ + where: { id: 1 }, + data: { id: 2 }, + }), + ).resolves.toMatchObject({ + id: 2, + viewCount: 3, }); + }); - it('works with nested update', async () => { - await expect( - client.user.update({ - where: { id: 1 }, - data: { - assets: { - update: { - where: { id: 1 }, - data: { viewCount: { increment: 1 } }, - }, - }, - }, - include: { assets: true }, - }), - ).resolves.toMatchObject({ - assets: [{ viewCount: 1 }], - }); - - await expect( - client.user.update({ - where: { id: 1 }, - data: { - ratedVideos: { - update: { - where: { id: 1 }, - data: { viewCount: 2, rating: 4, duration: 200 }, - }, + it('works with nested update', async () => { + await expect( + client.user.update({ + where: { id: 1 }, + data: { + assets: { + update: { + where: { id: 1 }, + data: { viewCount: { increment: 1 } }, }, }, - include: { ratedVideos: true }, - }), - ).resolves.toMatchObject({ - ratedVideos: [{ viewCount: 2, rating: 4, duration: 200 }], - }); - - // unique filter - await expect( - client.user.update({ - where: { id: 1 }, - data: { - ratedVideos: { - update: { - where: { url: 'v1' }, - data: { viewCount: 3 }, - }, + }, + include: { assets: true }, + }), + ).resolves.toMatchObject({ + assets: [{ viewCount: 1 }], + }); + + await expect( + client.user.update({ + where: { id: 1 }, + data: { + ratedVideos: { + update: { + where: { id: 1 }, + data: { viewCount: 2, rating: 4, duration: 200 }, }, }, - include: { ratedVideos: true }, - }), - ).resolves.toMatchObject({ - ratedVideos: [{ viewCount: 3 }], - }); - - // deep nested - await expect( - client.user.update({ - where: { id: 1 }, - data: { - assets: { - update: { - where: { id: 1 }, - data: { comments: { create: { content: 'c1' } } }, - }, + }, + include: { ratedVideos: true }, + }), + ).resolves.toMatchObject({ + ratedVideos: [{ viewCount: 2, rating: 4, duration: 200 }], + }); + + // unique filter + await expect( + client.user.update({ + where: { id: 1 }, + data: { + ratedVideos: { + update: { + where: { url: 'v1' }, + data: { viewCount: 3 }, }, }, - include: { assets: { include: { comments: true } } }, - }), - ).resolves.toMatchObject({ - assets: [{ comments: [{ content: 'c1' }] }], - }); + }, + include: { ratedVideos: true }, + }), + ).resolves.toMatchObject({ + ratedVideos: [{ viewCount: 3 }], }); - it('works with updating a base relation', async () => { - await expect( - client.video.update({ - where: { id: 1 }, - data: { - owner: { update: { level: { increment: 1 } } }, + // deep nested + await expect( + client.user.update({ + where: { id: 1 }, + data: { + assets: { + update: { + where: { id: 1 }, + data: { comments: { create: { content: 'c1' } } }, + }, }, - include: { owner: true }, - }), - ).resolves.toMatchObject({ - owner: { level: 1 }, - }); + }, + include: { assets: { include: { comments: true } } }, + }), + ).resolves.toMatchObject({ + assets: [{ comments: [{ content: 'c1' }] }], }); + }); - it('works with updateMany', async () => { - await client.ratedVideo.create({ - data: { id: 2, viewCount: 1, duration: 200, url: 'abc', rating: 5 }, - }); - - // update from sub model - await expect( - client.ratedVideo.updateMany({ - where: { duration: { gt: 100 } }, - data: { viewCount: { increment: 1 }, duration: { increment: 1 }, rating: { set: 3 } }, - }), - ).resolves.toEqual({ count: 1 }); - - await expect(client.ratedVideo.findMany()).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ - viewCount: 2, - duration: 201, - rating: 3, - }), - ]), - ); - - await expect( - client.ratedVideo.updateMany({ - where: { viewCount: { gt: 1 } }, - data: { viewCount: { increment: 1 } }, - }), - ).resolves.toEqual({ count: 1 }); - - await expect( - client.ratedVideo.updateMany({ - where: { rating: 3 }, - data: { viewCount: { increment: 1 } }, - }), - ).resolves.toEqual({ count: 1 }); + it('works with updating a base relation', async () => { + await expect( + client.video.update({ + where: { id: 1 }, + data: { + owner: { update: { level: { increment: 1 } } }, + }, + include: { owner: true }, + }), + ).resolves.toMatchObject({ + owner: { level: 1 }, + }); + }); - // update from delegate model - await expect( - client.asset.updateMany({ - where: { viewCount: { gt: 0 } }, - data: { viewCount: 100 }, - }), - ).resolves.toEqual({ count: 1 }); - await expect( - client.video.updateMany({ - where: { duration: { gt: 200 } }, - data: { viewCount: 200, duration: 300 }, - }), - ).resolves.toEqual({ count: 1 }); - await expect(client.ratedVideo.findMany()).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ - viewCount: 200, - duration: 300, - }), - ]), - ); - - // updateMany with limit unsupported - await expect( - client.ratedVideo.updateMany({ - where: { duration: { gt: 200 } }, - data: { viewCount: 200, duration: 300 }, - limit: 1, - }), - ).rejects.toThrow('Updating with a limit is not supported for polymorphic models'); + it('works with updateMany', async () => { + await client.ratedVideo.create({ + data: { id: 2, viewCount: 1, duration: 200, url: 'abc', rating: 5 }, }); - it('works with updateManyAndReturn', async () => { - await client.ratedVideo.create({ - data: { id: 2, viewCount: 1, duration: 200, url: 'abc', rating: 5 }, - }); + // update from sub model + await expect( + client.ratedVideo.updateMany({ + where: { duration: { gt: 100 } }, + data: { viewCount: { increment: 1 }, duration: { increment: 1 }, rating: { set: 3 } }, + }), + ).resolves.toEqual({ count: 1 }); - // update from sub model - await expect( - client.ratedVideo.updateManyAndReturn({ - where: { duration: { gt: 100 } }, - data: { viewCount: { increment: 1 }, duration: { increment: 1 }, rating: { set: 3 } }, - }), - ).resolves.toEqual([ + await expect(client.ratedVideo.findMany()).resolves.toEqual( + expect.arrayContaining([ expect.objectContaining({ viewCount: 2, duration: 201, rating: 3, }), - ]); + ]), + ); - // update from delegate model - await expect( - client.asset.updateManyAndReturn({ - where: { viewCount: { gt: 0 } }, - data: { viewCount: 100 }, - }), - ).resolves.toEqual([ + await expect( + client.ratedVideo.updateMany({ + where: { viewCount: { gt: 1 } }, + data: { viewCount: { increment: 1 } }, + }), + ).resolves.toEqual({ count: 1 }); + + await expect( + client.ratedVideo.updateMany({ + where: { rating: 3 }, + data: { viewCount: { increment: 1 } }, + }), + ).resolves.toEqual({ count: 1 }); + + // update from delegate model + await expect( + client.asset.updateMany({ + where: { viewCount: { gt: 0 } }, + data: { viewCount: 100 }, + }), + ).resolves.toEqual({ count: 1 }); + await expect( + client.video.updateMany({ + where: { duration: { gt: 200 } }, + data: { viewCount: 200, duration: 300 }, + }), + ).resolves.toEqual({ count: 1 }); + await expect(client.ratedVideo.findMany()).resolves.toEqual( + expect.arrayContaining([ expect.objectContaining({ - viewCount: 100, - duration: 201, - rating: 3, + viewCount: 200, + duration: 300, }), - ]); + ]), + ); + + // updateMany with limit unsupported + await expect( + client.ratedVideo.updateMany({ + where: { duration: { gt: 200 } }, + data: { viewCount: 200, duration: 300 }, + limit: 1, + }), + ).rejects.toThrow('Updating with a limit is not supported for polymorphic models'); + }); + + it('works with updateManyAndReturn', async () => { + await client.ratedVideo.create({ + data: { id: 2, viewCount: 1, duration: 200, url: 'abc', rating: 5 }, }); - it('works with upsert', async () => { - await expect( - // @ts-expect-error - client.asset.upsert({ - where: { id: 2 }, - create: { - viewCount: 10, - assetType: 'Video', - }, - update: { - viewCount: { increment: 1 }, - }, - }), - ).rejects.toThrow('is a delegate'); - - // create case - await expect( - client.ratedVideo.upsert({ - where: { id: 2 }, - create: { - id: 2, - viewCount: 2, - duration: 200, - url: 'v2', - rating: 3, - }, - update: { - viewCount: { increment: 1 }, - }, - }), - ).resolves.toMatchObject({ - id: 2, + // update from sub model + await expect( + client.ratedVideo.updateManyAndReturn({ + where: { duration: { gt: 100 } }, + data: { viewCount: { increment: 1 }, duration: { increment: 1 }, rating: { set: 3 } }, + }), + ).resolves.toEqual([ + expect.objectContaining({ viewCount: 2, - }); - - // update case - await expect( - client.ratedVideo.upsert({ - where: { id: 2 }, - create: { - id: 2, - viewCount: 2, - duration: 200, - url: 'v2', - rating: 3, - }, - update: { - viewCount: 3, - duration: 300, - rating: 2, - }, - }), - ).resolves.toMatchObject({ - id: 2, - viewCount: 3, - duration: 300, - rating: 2, - }); - }); + duration: 201, + rating: 3, + }), + ]); + + // update from delegate model + await expect( + client.asset.updateManyAndReturn({ + where: { viewCount: { gt: 0 } }, + data: { viewCount: 100 }, + }), + ).resolves.toEqual([ + expect.objectContaining({ + viewCount: 100, + duration: 201, + rating: 3, + }), + ]); }); - describe('Delegate delete tests', () => { - it('works with delete', async () => { - // delete from sub model - await client.ratedVideo.create({ - data: { - id: 1, - duration: 100, - url: 'abc', - rating: 5, + it('works with upsert', async () => { + await expect( + // @ts-expect-error + client.asset.upsert({ + where: { id: 2 }, + create: { + viewCount: 10, + assetType: 'Video', }, - }); - await expect( - client.ratedVideo.delete({ - where: { url: 'abc' }, - }), - ).resolves.toMatchObject({ + update: { + viewCount: { increment: 1 }, + }, + }), + ).rejects.toThrow('is a delegate'); + + // create case + await expect( + client.ratedVideo.upsert({ + where: { id: 2 }, + create: { + id: 2, + viewCount: 2, + duration: 200, + url: 'v2', + rating: 3, + }, + update: { + viewCount: { increment: 1 }, + }, + }), + ).resolves.toMatchObject({ + id: 2, + viewCount: 2, + }); + + // update case + await expect( + client.ratedVideo.upsert({ + where: { id: 2 }, + create: { + id: 2, + viewCount: 2, + duration: 200, + url: 'v2', + rating: 3, + }, + update: { + viewCount: 3, + duration: 300, + rating: 2, + }, + }), + ).resolves.toMatchObject({ + id: 2, + viewCount: 3, + duration: 300, + rating: 2, + }); + }); + }); + + describe('Delegate delete tests', () => { + it('works with delete', async () => { + // delete from sub model + await client.ratedVideo.create({ + data: { id: 1, duration: 100, url: 'abc', rating: 5, - }); - await expect(client.ratedVideo.findMany()).toResolveWithLength(0); - await expect(client.video.findMany()).toResolveWithLength(0); - await expect(client.asset.findMany()).toResolveWithLength(0); + }, + }); + await expect( + client.ratedVideo.delete({ + where: { url: 'abc' }, + }), + ).resolves.toMatchObject({ + id: 1, + duration: 100, + url: 'abc', + rating: 5, + }); + await expect(client.ratedVideo.findMany()).toResolveWithLength(0); + await expect(client.video.findMany()).toResolveWithLength(0); + await expect(client.asset.findMany()).toResolveWithLength(0); - // delete from base model - await client.ratedVideo.create({ - data: { - id: 1, - duration: 100, - url: 'abc', - rating: 5, - }, - }); - await expect( - client.asset.delete({ - where: { id: 1 }, - }), - ).resolves.toMatchObject({ + // delete from base model + await client.ratedVideo.create({ + data: { id: 1, duration: 100, url: 'abc', rating: 5, - }); - await expect(client.ratedVideo.findMany()).toResolveWithLength(0); - await expect(client.video.findMany()).toResolveWithLength(0); - await expect(client.asset.findMany()).toResolveWithLength(0); + }, + }); + await expect( + client.asset.delete({ + where: { id: 1 }, + }), + ).resolves.toMatchObject({ + id: 1, + duration: 100, + url: 'abc', + rating: 5, + }); + await expect(client.ratedVideo.findMany()).toResolveWithLength(0); + await expect(client.video.findMany()).toResolveWithLength(0); + await expect(client.asset.findMany()).toResolveWithLength(0); - // nested delete - await client.user.create({ - data: { - id: 1, - email: 'abc', - }, - }); - await client.ratedVideo.create({ + // nested delete + await client.user.create({ + data: { + id: 1, + email: 'abc', + }, + }); + await client.ratedVideo.create({ + data: { + id: 1, + duration: 100, + url: 'abc', + rating: 5, + owner: { connect: { id: 1 } }, + }, + }); + await expect( + client.user.update({ + where: { id: 1 }, data: { - id: 1, - duration: 100, - url: 'abc', - rating: 5, - owner: { connect: { id: 1 } }, - }, - }); - await expect( - client.user.update({ - where: { id: 1 }, - data: { - assets: { - delete: { id: 1 }, - }, + assets: { + delete: { id: 1 }, }, - include: { assets: true }, - }), - ).resolves.toMatchObject({ assets: [] }); - await expect(client.ratedVideo.findMany()).toResolveWithLength(0); - await expect(client.video.findMany()).toResolveWithLength(0); - await expect(client.asset.findMany()).toResolveWithLength(0); - - // delete user should cascade to ratedVideo and in turn delete its bases - await client.ratedVideo.create({ - data: { - id: 1, - duration: 100, - url: 'abc', - rating: 5, - user: { connect: { id: 1 } }, }, - }); - await expect( - client.user.delete({ - where: { id: 1 }, - }), - ).toResolveTruthy(); - await expect(client.ratedVideo.findMany()).toResolveWithLength(0); - await expect(client.video.findMany()).toResolveWithLength(0); - await expect(client.asset.findMany()).toResolveWithLength(0); - }); - - it('works with deleteMany', async () => { - await client.ratedVideo.createMany({ - data: [ - { - id: 1, - viewCount: 1, - duration: 100, - url: 'abc', - rating: 5, - }, - { - id: 2, - viewCount: 2, - duration: 200, - url: 'def', - rating: 4, - }, - ], - }); + include: { assets: true }, + }), + ).resolves.toMatchObject({ assets: [] }); + await expect(client.ratedVideo.findMany()).toResolveWithLength(0); + await expect(client.video.findMany()).toResolveWithLength(0); + await expect(client.asset.findMany()).toResolveWithLength(0); - await expect( - client.video.deleteMany({ - where: { duration: { gt: 150 }, viewCount: 1 }, - }), - ).resolves.toMatchObject({ count: 0 }); - await expect( - client.video.deleteMany({ - where: { duration: { gt: 150 }, viewCount: 2 }, - }), - ).resolves.toMatchObject({ count: 1 }); - await expect(client.ratedVideo.findMany()).toResolveWithLength(1); - await expect(client.video.findMany()).toResolveWithLength(1); - await expect(client.asset.findMany()).toResolveWithLength(1); + // delete user should cascade to ratedVideo and in turn delete its bases + await client.ratedVideo.create({ + data: { + id: 1, + duration: 100, + url: 'abc', + rating: 5, + user: { connect: { id: 1 } }, + }, }); + await expect( + client.user.delete({ + where: { id: 1 }, + }), + ).toResolveTruthy(); + await expect(client.ratedVideo.findMany()).toResolveWithLength(0); + await expect(client.video.findMany()).toResolveWithLength(0); + await expect(client.asset.findMany()).toResolveWithLength(0); }); - describe('Delegate aggregation tests', () => { - beforeEach(async () => { - const u = await client.user.create({ - data: { - id: 1, - email: 'u1@example.com', - }, - }); - await client.ratedVideo.create({ - data: { + it('works with deleteMany', async () => { + await client.ratedVideo.createMany({ + data: [ + { id: 1, - viewCount: 0, + viewCount: 1, duration: 100, - url: 'v1', + url: 'abc', rating: 5, - owner: { connect: { id: u.id } }, - user: { connect: { id: u.id } }, - comments: { create: [{ content: 'c1' }, { content: 'c2' }] }, }, - }); - await client.ratedVideo.create({ - data: { + { id: 2, viewCount: 2, duration: 200, - url: 'v2', - rating: 3, + url: 'def', + rating: 4, }, - }); + ], }); - it('works with count', async () => { - await expect( - client.ratedVideo.count({ - where: { rating: 5 }, - }), - ).resolves.toEqual(1); - await expect( - client.ratedVideo.count({ - where: { duration: 100 }, - }), - ).resolves.toEqual(1); - await expect( - client.ratedVideo.count({ - where: { viewCount: 2 }, - }), - ).resolves.toEqual(1); - - await expect( - client.video.count({ - where: { duration: 100 }, - }), - ).resolves.toEqual(1); - await expect( - client.asset.count({ - where: { viewCount: { gt: 0 } }, - }), - ).resolves.toEqual(1); + await expect( + client.video.deleteMany({ + where: { duration: { gt: 150 }, viewCount: 1 }, + }), + ).resolves.toMatchObject({ count: 0 }); + await expect( + client.video.deleteMany({ + where: { duration: { gt: 150 }, viewCount: 2 }, + }), + ).resolves.toMatchObject({ count: 1 }); + await expect(client.ratedVideo.findMany()).toResolveWithLength(1); + await expect(client.video.findMany()).toResolveWithLength(1); + await expect(client.asset.findMany()).toResolveWithLength(1); + }); + }); - // field selection - await expect( - client.ratedVideo.count({ - select: { _all: true, viewCount: true, url: true, rating: true }, - }), - ).resolves.toMatchObject({ - _all: 2, - viewCount: 2, - url: 2, - rating: 2, - }); - await expect( - client.video.count({ - select: { _all: true, viewCount: true, url: true }, - }), - ).resolves.toMatchObject({ - _all: 2, - viewCount: 2, - url: 2, - }); - await expect( - client.asset.count({ - select: { _all: true, viewCount: true }, - }), - ).resolves.toMatchObject({ - _all: 2, + describe('Delegate aggregation tests', () => { + beforeEach(async () => { + const u = await client.user.create({ + data: { + id: 1, + email: 'u1@example.com', + }, + }); + await client.ratedVideo.create({ + data: { + id: 1, + viewCount: 0, + duration: 100, + url: 'v1', + rating: 5, + owner: { connect: { id: u.id } }, + user: { connect: { id: u.id } }, + comments: { create: [{ content: 'c1' }, { content: 'c2' }] }, + }, + }); + await client.ratedVideo.create({ + data: { + id: 2, viewCount: 2, - }); + duration: 200, + url: 'v2', + rating: 3, + }, }); + }); - it('works with aggregate', async () => { - await expect( - client.ratedVideo.aggregate({ - where: { viewCount: { gte: 0 }, duration: { gt: 0 }, rating: { gt: 0 } }, - _avg: { viewCount: true, duration: true, rating: true }, - _count: true, - }), - ).resolves.toMatchObject({ - _avg: { - viewCount: 1, - duration: 150, - rating: 4, - }, - _count: 2, - }); - await expect( - client.video.aggregate({ - where: { viewCount: { gte: 0 }, duration: { gt: 0 } }, - _avg: { viewCount: true, duration: true }, - _count: true, - }), - ).resolves.toMatchObject({ - _avg: { - viewCount: 1, - duration: 150, - }, - _count: 2, - }); - await expect( - client.asset.aggregate({ - where: { viewCount: { gte: 0 } }, - _avg: { viewCount: true }, - _count: true, - }), - ).resolves.toMatchObject({ - _avg: { - viewCount: 1, - }, - _count: 2, - }); + it('works with count', async () => { + await expect( + client.ratedVideo.count({ + where: { rating: 5 }, + }), + ).resolves.toEqual(1); + await expect( + client.ratedVideo.count({ + where: { duration: 100 }, + }), + ).resolves.toEqual(1); + await expect( + client.ratedVideo.count({ + where: { viewCount: 2 }, + }), + ).resolves.toEqual(1); - // just count - await expect( - client.ratedVideo.aggregate({ - _count: true, - }), - ).resolves.toMatchObject({ - _count: 2, - }); + await expect( + client.video.count({ + where: { duration: 100 }, + }), + ).resolves.toEqual(1); + await expect( + client.asset.count({ + where: { viewCount: { gt: 0 } }, + }), + ).resolves.toEqual(1); + + // field selection + await expect( + client.ratedVideo.count({ + select: { _all: true, viewCount: true, url: true, rating: true }, + }), + ).resolves.toMatchObject({ + _all: 2, + viewCount: 2, + url: 2, + rating: 2, + }); + await expect( + client.video.count({ + select: { _all: true, viewCount: true, url: true }, + }), + ).resolves.toMatchObject({ + _all: 2, + viewCount: 2, + url: 2, + }); + await expect( + client.asset.count({ + select: { _all: true, viewCount: true }, + }), + ).resolves.toMatchObject({ + _all: 2, + viewCount: 2, + }); + }); + + it('works with aggregate', async () => { + await expect( + client.ratedVideo.aggregate({ + where: { viewCount: { gte: 0 }, duration: { gt: 0 }, rating: { gt: 0 } }, + _avg: { viewCount: true, duration: true, rating: true }, + _count: true, + }), + ).resolves.toMatchObject({ + _avg: { + viewCount: 1, + duration: 150, + rating: 4, + }, + _count: 2, + }); + await expect( + client.video.aggregate({ + where: { viewCount: { gte: 0 }, duration: { gt: 0 } }, + _avg: { viewCount: true, duration: true }, + _count: true, + }), + ).resolves.toMatchObject({ + _avg: { + viewCount: 1, + duration: 150, + }, + _count: 2, + }); + await expect( + client.asset.aggregate({ + where: { viewCount: { gte: 0 } }, + _avg: { viewCount: true }, + _count: true, + }), + ).resolves.toMatchObject({ + _avg: { + viewCount: 1, + }, + _count: 2, + }); + + // just count + await expect( + client.ratedVideo.aggregate({ + _count: true, + }), + ).resolves.toMatchObject({ + _count: 2, }); }); - }, -); + }); +}); diff --git a/packages/runtime/test/client-api/delete-many.test.ts b/packages/runtime/test/client-api/delete-many.test.ts index df8f3ac4..b31896f0 100644 --- a/packages/runtime/test/client-api/delete-many.test.ts +++ b/packages/runtime/test/client-api/delete-many.test.ts @@ -1,15 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-delete-many-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client deleteMany tests', ({ createClient }) => { +describe('Client deleteMany tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/client-api/delete.test.ts b/packages/runtime/test/client-api/delete.test.ts index b67216ae..4e518c07 100644 --- a/packages/runtime/test/client-api/delete.test.ts +++ b/packages/runtime/test/client-api/delete.test.ts @@ -1,15 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-delete-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client delete tests', ({ createClient }) => { +describe('Client delete tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/client-api/filter.test.ts b/packages/runtime/test/client-api/filter.test.ts index b7ec82af..26af9dd7 100644 --- a/packages/runtime/test/client-api/filter.test.ts +++ b/packages/runtime/test/client-api/filter.test.ts @@ -1,15 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-filter-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client filter tests for $provider', ({ createClient, provider }) => { +describe('Client filter tests ', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { @@ -76,7 +74,7 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client filter tests for $provider' }), ).toResolveTruthy(); - if (provider === 'sqlite') { + if (client.$schema.provider.type === 'sqlite') { // sqlite: equalities are case-sensitive, match is case-insensitive await expect( client.user.findFirst({ @@ -126,7 +124,7 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client filter tests for $provider' }, }), ).toResolveTruthy(); - } else if (provider === 'postgresql') { + } else if (client.$schema.provider.type === 'postgresql') { // postgresql: default is case-sensitive, but can be toggled with "mode" await expect( diff --git a/packages/runtime/test/client-api/find.test.ts b/packages/runtime/test/client-api/find.test.ts index 36d20eba..1f8219ec 100644 --- a/packages/runtime/test/client-api/find.test.ts +++ b/packages/runtime/test/client-api/find.test.ts @@ -2,16 +2,14 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { InputValidationError, NotFoundError } from '../../src/client/errors'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; import { createPosts, createUser } from './utils'; -const PG_DB_NAME = 'client-api-find-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', ({ createClient, provider }) => { +describe('Client find tests ', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { @@ -265,7 +263,7 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', role: 'USER', }); - if (provider === 'sqlite') { + if (client.$schema.provider.type === 'sqlite') { await expect(client.user.findMany({ distinct: ['role'] } as any)).rejects.toThrow('not supported'); return; } @@ -675,7 +673,8 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', posts: [expect.objectContaining({ title: 'Post1' })], }); - if (provider === 'postgresql') { + // @ts-ignore + if (client.$schema.provider.type === 'postgresql') { await expect( client.user.findUnique({ where: { id: user.id }, @@ -901,7 +900,12 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', }, }, }); - expect(u.posts[0]).toMatchObject(post2); + expect(u.posts[0]).toMatchObject({ + title: post2.title, + published: post2.published, + createdAt: expect.any(Date), + updatedAt: expect.any(Date), + }); u = await client.user.findUniqueOrThrow({ where: { id: user.id }, include: { @@ -912,7 +916,12 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', }, }, }); - expect(u.posts[0]).toMatchObject(post1); + expect(u.posts[0]).toMatchObject({ + title: post1.title, + published: post1.published, + createdAt: expect.any(Date), + updatedAt: expect.any(Date), + }); // cursor u = await client.user.findUniqueOrThrow({ diff --git a/packages/runtime/test/client-api/group-by.test.ts b/packages/runtime/test/client-api/group-by.test.ts index a9c0d56a..b0909e34 100644 --- a/packages/runtime/test/client-api/group-by.test.ts +++ b/packages/runtime/test/client-api/group-by.test.ts @@ -1,16 +1,14 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; import { createPosts, createUser } from './utils'; -const PG_DB_NAME = 'client-api-group-by-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client groupBy tests', ({ createClient }) => { +describe('Client groupBy tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { @@ -269,7 +267,7 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client groupBy tests', ({ createCl age: 10, }, }), - ).rejects.toThrow(/must be in \\"by\\"/); + ).rejects.toThrow(/must be in "by"/); }); it('complains about fields in orderBy that are not in by', async () => { @@ -280,6 +278,6 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client groupBy tests', ({ createCl age: 'asc', }, }), - ).rejects.toThrow(/must be in \\"by\\"/); + ).rejects.toThrow(/must be in "by"/); }); }); diff --git a/packages/runtime/test/client-api/mixin.test.ts b/packages/runtime/test/client-api/mixin.test.ts index 9f85d30a..e7b9dcac 100644 --- a/packages/runtime/test/client-api/mixin.test.ts +++ b/packages/runtime/test/client-api/mixin.test.ts @@ -75,7 +75,7 @@ model Bar with CommonFields { description: 'Bar', }, }), - ).rejects.toThrow('constraint failed'); + ).rejects.toThrow('constraint'); }); it('supports multiple-level mixins', async () => { diff --git a/packages/runtime/test/client-api/name-mapping.test.ts b/packages/runtime/test/client-api/name-mapping.test.ts index 41341f7c..cfb9bec2 100644 --- a/packages/runtime/test/client-api/name-mapping.test.ts +++ b/packages/runtime/test/client-api/name-mapping.test.ts @@ -4,84 +4,24 @@ import type { ClientContract } from '../../src'; import { schema, type SchemaType } from '../schemas/name-mapping/schema'; import { createTestClient } from '../utils'; -const TEST_DB = 'client-api-name-mapper-test'; - -describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as const }])( - 'Name mapping tests', - ({ provider }) => { - let db: ClientContract; - - beforeEach(async () => { - db = await createTestClient( - schema, - { usePrismaPush: true, provider, dbName: TEST_DB }, - path.join(__dirname, '../schemas/name-mapping/schema.zmodel'), - ); - }); - - afterEach(async () => { - await db.$disconnect(); - }); - - it('works with create', async () => { - await expect( - db.user.create({ - data: { - email: 'u1@test.com', - posts: { - create: { - title: 'Post1', - }, - }, - }, - }), - ).resolves.toMatchObject({ - id: expect.any(Number), - email: 'u1@test.com', - }); - - await expect( - db.$qb - .insertInto('User') - .values({ - email: 'u2@test.com', - }) - .returning(['id', 'email']) - .executeTakeFirst(), - ).resolves.toMatchObject({ - id: expect.any(Number), - email: 'u2@test.com', - }); - - await expect( - db.$qb - .insertInto('User') - .values({ - email: 'u3@test.com', - }) - .returning(['User.id', 'User.email']) - .executeTakeFirst(), - ).resolves.toMatchObject({ - id: expect.any(Number), - email: 'u3@test.com', - }); - - await expect( - db.$qb - .insertInto('User') - .values({ - email: 'u4@test.com', - }) - .returningAll() - .executeTakeFirst(), - ).resolves.toMatchObject({ - id: expect.any(Number), - email: 'u4@test.com', - }); - }); - - it('works with find', async () => { - const user = await db.user.create({ +describe('Name mapping tests', () => { + let db: ClientContract; + + beforeEach(async () => { + db = await createTestClient( + schema, + { usePrismaPush: true }, + path.join(__dirname, '../schemas/name-mapping/schema.zmodel'), + ); + }); + + afterEach(async () => { + await db.$disconnect(); + }); + + it('works with create', async () => { + await expect( + db.user.create({ data: { email: 'u1@test.com', posts: { @@ -90,346 +30,401 @@ describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as cons }, }, }, - }); - - await expect( - db.user.findFirst({ - where: { email: 'u1@test.com' }, - select: { - id: true, - email: true, - posts: { where: { title: { contains: 'Post1' } }, select: { title: true } }, - }, - }), - ).resolves.toMatchObject({ - id: expect.any(Number), - email: 'u1@test.com', - posts: [{ title: 'Post1' }], - }); + }), + ).resolves.toMatchObject({ + id: expect.any(Number), + email: 'u1@test.com', + }); - await expect( - db.$qb.selectFrom('User').selectAll().where('email', '=', 'u1@test.com').executeTakeFirst(), - ).resolves.toMatchObject({ - id: expect.any(Number), - email: 'u1@test.com', - }); + await expect( + db.$qb + .insertInto('User') + .values({ + email: 'u2@test.com', + }) + .returning(['id', 'email']) + .executeTakeFirst(), + ).resolves.toMatchObject({ + id: expect.any(Number), + email: 'u2@test.com', + }); - await expect( - db.$qb.selectFrom('User').select(['User.email']).where('email', '=', 'u1@test.com').executeTakeFirst(), - ).resolves.toMatchObject({ - email: 'u1@test.com', - }); - - await expect( - db.$qb - .selectFrom('User') - .select(['email']) - .whereRef('email', '=', 'email') - .orderBy(['email']) - .executeTakeFirst(), - ).resolves.toMatchObject({ - email: 'u1@test.com', - }); - - await expect( - db.$qb - .selectFrom('Post') - .innerJoin('User', 'User.id', 'Post.authorId') - .select(['User.email', 'Post.authorId', 'Post.title']) - .whereRef('Post.authorId', '=', 'User.id') - .executeTakeFirst(), - ).resolves.toMatchObject({ - authorId: user.id, - title: 'Post1', - }); - - await expect( - db.$qb - .selectFrom('Post') - .select(['id', 'title']) - .select((eb) => - eb.selectFrom('User').select(['email']).whereRef('User.id', '=', 'Post.authorId').as('email'), - ) - .executeTakeFirst(), - ).resolves.toMatchObject({ - id: user.id, - title: 'Post1', - email: 'u1@test.com', - }); + await expect( + db.$qb + .insertInto('User') + .values({ + email: 'u3@test.com', + }) + .returning(['User.id', 'User.email']) + .executeTakeFirst(), + ).resolves.toMatchObject({ + id: expect.any(Number), + email: 'u3@test.com', }); - it('works with update', async () => { - const user = await db.user.create({ - data: { - email: 'u1@test.com', - posts: { - create: { - id: 1, - title: 'Post1', - }, + await expect( + db.$qb + .insertInto('User') + .values({ + email: 'u4@test.com', + }) + .returningAll() + .executeTakeFirst(), + ).resolves.toMatchObject({ + id: expect.any(Number), + email: 'u4@test.com', + }); + }); + + it('works with find', async () => { + const user = await db.user.create({ + data: { + email: 'u1@test.com', + posts: { + create: { + title: 'Post1', }, }, - }); - - await expect( - db.user.update({ - where: { id: user.id }, - data: { - email: 'u2@test.com', - posts: { - update: { - where: { id: 1 }, - data: { title: 'Post2' }, - }, - }, - }, - include: { posts: true }, - }), - ).resolves.toMatchObject({ - id: user.id, - email: 'u2@test.com', - posts: [expect.objectContaining({ title: 'Post2' })], - }); - - await expect( - db.$qb - .updateTable('User') - .set({ email: (eb) => eb.fn('upper', [eb.ref('email')]) }) - .where('email', '=', 'u2@test.com') - .returning(['email']) - .executeTakeFirst(), - ).resolves.toMatchObject({ email: 'U2@TEST.COM' }); - - await expect( - db.$qb.updateTable('User as u').set({ email: 'u3@test.com' }).returningAll().executeTakeFirst(), - ).resolves.toMatchObject({ id: expect.any(Number), email: 'u3@test.com' }); + }, }); - it('works with delete', async () => { - const user = await db.user.create({ - data: { - email: 'u1@test.com', - posts: { - create: { - id: 1, - title: 'Post1', - }, - }, + await expect( + db.user.findFirst({ + where: { email: 'u1@test.com' }, + select: { + id: true, + email: true, + posts: { where: { title: { contains: 'Post1' } }, select: { title: true } }, }, - }); - - await expect( - db.$qb.deleteFrom('Post').where('title', '=', 'Post1').returning(['id', 'title']).executeTakeFirst(), - ).resolves.toMatchObject({ - id: user.id, - title: 'Post1', - }); - - await expect( - db.user.delete({ - where: { email: 'u1@test.com' }, - include: { posts: true }, - }), - ).resolves.toMatchObject({ - email: 'u1@test.com', - posts: [], - }); + }), + ).resolves.toMatchObject({ + id: expect.any(Number), + email: 'u1@test.com', + posts: [{ title: 'Post1' }], }); - it('works with count', async () => { - await db.user.create({ - data: { - email: 'u1@test.com', - posts: { - create: [{ title: 'Post1' }, { title: 'Post2' }], + await expect( + db.$qb.selectFrom('User').selectAll().where('email', '=', 'u1@test.com').executeTakeFirst(), + ).resolves.toMatchObject({ + id: expect.any(Number), + email: 'u1@test.com', + }); + + await expect( + db.$qb.selectFrom('User').select(['User.email']).where('email', '=', 'u1@test.com').executeTakeFirst(), + ).resolves.toMatchObject({ + email: 'u1@test.com', + }); + + await expect( + db.$qb + .selectFrom('User') + .select(['email']) + .whereRef('email', '=', 'email') + .orderBy(['email']) + .executeTakeFirst(), + ).resolves.toMatchObject({ + email: 'u1@test.com', + }); + + await expect( + db.$qb + .selectFrom('Post') + .innerJoin('User', 'User.id', 'Post.authorId') + .select(['User.email', 'Post.authorId', 'Post.title']) + .whereRef('Post.authorId', '=', 'User.id') + .executeTakeFirst(), + ).resolves.toMatchObject({ + authorId: user.id, + title: 'Post1', + }); + + await expect( + db.$qb + .selectFrom('Post') + .select(['id', 'title']) + .select((eb) => + eb.selectFrom('User').select(['email']).whereRef('User.id', '=', 'Post.authorId').as('email'), + ) + .executeTakeFirst(), + ).resolves.toMatchObject({ + id: user.id, + title: 'Post1', + email: 'u1@test.com', + }); + }); + + it('works with update', async () => { + const user = await db.user.create({ + data: { + email: 'u1@test.com', + posts: { + create: { + id: 1, + title: 'Post1', }, }, - }); + }, + }); - await db.user.create({ + await expect( + db.user.update({ + where: { id: user.id }, data: { email: 'u2@test.com', posts: { - create: [{ title: 'Post3' }], + update: { + where: { id: 1 }, + data: { title: 'Post2' }, + }, }, }, - }); + include: { posts: true }, + }), + ).resolves.toMatchObject({ + id: user.id, + email: 'u2@test.com', + posts: [expect.objectContaining({ title: 'Post2' })], + }); - // Test ORM count operations - await expect(db.user.count()).resolves.toBe(2); - await expect(db.post.count()).resolves.toBe(3); - await expect(db.user.count({ select: { email: true } })).resolves.toMatchObject({ - email: 2, - }); + await expect( + db.$qb + .updateTable('User') + .set({ email: (eb) => eb.fn('upper', [eb.ref('email')]) }) + .where('email', '=', 'u2@test.com') + .returning(['email']) + .executeTakeFirst(), + ).resolves.toMatchObject({ email: 'U2@TEST.COM' }); + + await expect( + db.$qb.updateTable('User as u').set({ email: 'u3@test.com' }).returningAll().executeTakeFirst(), + ).resolves.toMatchObject({ id: expect.any(Number), email: 'u3@test.com' }); + }); + + it('works with delete', async () => { + const user = await db.user.create({ + data: { + email: 'u1@test.com', + posts: { + create: { + id: 1, + title: 'Post1', + }, + }, + }, + }); - await expect(db.user.count({ where: { email: 'u1@test.com' } })).resolves.toBe(1); - await expect(db.post.count({ where: { title: { contains: 'Post1' } } })).resolves.toBe(1); + await expect( + db.$qb.deleteFrom('Post').where('title', '=', 'Post1').returning(['id', 'title']).executeTakeFirst(), + ).resolves.toMatchObject({ + id: user.id, + title: 'Post1', + }); - await expect(db.post.count({ where: { author: { email: 'u1@test.com' } } })).resolves.toBe(2); + await expect( + db.user.delete({ + where: { email: 'u1@test.com' }, + include: { posts: true }, + }), + ).resolves.toMatchObject({ + email: 'u1@test.com', + posts: [], + }); + }); - // Test Kysely count operations - const r = await db.$qb - .selectFrom('User') - .select((eb) => eb.fn.count('email').as('count')) - .executeTakeFirst(); - await expect(Number(r?.count)).toBe(2); + it('works with count', async () => { + await db.user.create({ + data: { + email: 'u1@test.com', + posts: { + create: [{ title: 'Post1' }, { title: 'Post2' }], + }, + }, }); - it('works with aggregate', async () => { - await db.user.create({ - data: { - id: 1, - email: 'u1@test.com', - posts: { - create: [ - { id: 1, title: 'Post1' }, - { id: 2, title: 'Post2' }, - ], - }, + await db.user.create({ + data: { + email: 'u2@test.com', + posts: { + create: [{ title: 'Post3' }], }, - }); + }, + }); - await db.user.create({ - data: { - id: 2, - email: 'u2@test.com', - posts: { - create: [{ id: 3, title: 'Post3' }], - }, + // Test ORM count operations + await expect(db.user.count()).resolves.toBe(2); + await expect(db.post.count()).resolves.toBe(3); + await expect(db.user.count({ select: { email: true } })).resolves.toMatchObject({ + email: 2, + }); + + await expect(db.user.count({ where: { email: 'u1@test.com' } })).resolves.toBe(1); + await expect(db.post.count({ where: { title: { contains: 'Post1' } } })).resolves.toBe(1); + + await expect(db.post.count({ where: { author: { email: 'u1@test.com' } } })).resolves.toBe(2); + + // Test Kysely count operations + const r = await db.$qb + .selectFrom('User') + .select((eb) => eb.fn.count('email').as('count')) + .executeTakeFirst(); + await expect(Number(r?.count)).toBe(2); + }); + + it('works with aggregate', async () => { + await db.user.create({ + data: { + id: 1, + email: 'u1@test.com', + posts: { + create: [ + { id: 1, title: 'Post1' }, + { id: 2, title: 'Post2' }, + ], }, - }); - - // Test ORM aggregate operations - await expect(db.user.aggregate({ _count: { id: true, email: true } })).resolves.toMatchObject({ - _count: { id: 2, email: 2 }, - }); - - await expect( - db.post.aggregate({ _count: { authorId: true }, _min: { authorId: true }, _max: { authorId: true } }), - ).resolves.toMatchObject({ - _count: { authorId: 3 }, - _min: { authorId: 1 }, - _max: { authorId: 2 }, - }); - - await expect( - db.post.aggregate({ - where: { author: { email: 'u1@test.com' } }, - _count: { authorId: true }, - _min: { authorId: true }, - _max: { authorId: true }, - }), - ).resolves.toMatchObject({ - _count: { authorId: 2 }, - _min: { authorId: 1 }, - _max: { authorId: 1 }, - }); - - // Test Kysely aggregate operations - const countResult = await db.$qb - .selectFrom('User') - .select((eb) => eb.fn.count('email').as('emailCount')) - .executeTakeFirst(); - expect(Number(countResult?.emailCount)).toBe(2); + }, + }); - const postAggResult = await db.$qb - .selectFrom('Post') - .select((eb) => [eb.fn.min('authorId').as('minAuthorId'), eb.fn.max('authorId').as('maxAuthorId')]) - .executeTakeFirst(); - expect(Number(postAggResult?.minAuthorId)).toBe(1); - expect(Number(postAggResult?.maxAuthorId)).toBe(2); + await db.user.create({ + data: { + id: 2, + email: 'u2@test.com', + posts: { + create: [{ id: 3, title: 'Post3' }], + }, + }, }); - it('works with groupBy', async () => { - // Create test data with multiple posts per user - await db.user.create({ - data: { - id: 1, - email: 'u1@test.com', - posts: { - create: [ - { id: 1, title: 'Post1' }, - { id: 2, title: 'Post2' }, - { id: 3, title: 'Post3' }, - ], - }, + // Test ORM aggregate operations + await expect(db.user.aggregate({ _count: { id: true, email: true } })).resolves.toMatchObject({ + _count: { id: 2, email: 2 }, + }); + + await expect( + db.post.aggregate({ _count: { authorId: true }, _min: { authorId: true }, _max: { authorId: true } }), + ).resolves.toMatchObject({ + _count: { authorId: 3 }, + _min: { authorId: 1 }, + _max: { authorId: 2 }, + }); + + await expect( + db.post.aggregate({ + where: { author: { email: 'u1@test.com' } }, + _count: { authorId: true }, + _min: { authorId: true }, + _max: { authorId: true }, + }), + ).resolves.toMatchObject({ + _count: { authorId: 2 }, + _min: { authorId: 1 }, + _max: { authorId: 1 }, + }); + + // Test Kysely aggregate operations + const countResult = await db.$qb + .selectFrom('User') + .select((eb) => eb.fn.count('email').as('emailCount')) + .executeTakeFirst(); + expect(Number(countResult?.emailCount)).toBe(2); + + const postAggResult = await db.$qb + .selectFrom('Post') + .select((eb) => [eb.fn.min('authorId').as('minAuthorId'), eb.fn.max('authorId').as('maxAuthorId')]) + .executeTakeFirst(); + expect(Number(postAggResult?.minAuthorId)).toBe(1); + expect(Number(postAggResult?.maxAuthorId)).toBe(2); + }); + + it('works with groupBy', async () => { + // Create test data with multiple posts per user + await db.user.create({ + data: { + id: 1, + email: 'u1@test.com', + posts: { + create: [ + { id: 1, title: 'Post1' }, + { id: 2, title: 'Post2' }, + { id: 3, title: 'Post3' }, + ], }, - }); + }, + }); - await db.user.create({ - data: { - id: 2, - email: 'u2@test.com', - posts: { - create: [ - { id: 4, title: 'Post4' }, - { id: 5, title: 'Post5' }, - ], - }, + await db.user.create({ + data: { + id: 2, + email: 'u2@test.com', + posts: { + create: [ + { id: 4, title: 'Post4' }, + { id: 5, title: 'Post5' }, + ], }, - }); + }, + }); - await db.user.create({ - data: { - id: 3, - email: 'u3@test.com', - posts: { - create: [{ id: 6, title: 'Post6' }], - }, + await db.user.create({ + data: { + id: 3, + email: 'u3@test.com', + posts: { + create: [{ id: 6, title: 'Post6' }], }, - }); - - // Test ORM groupBy operations - const userGroupBy = await db.user.groupBy({ - by: ['email'], - _count: { id: true }, - }); - expect(userGroupBy).toHaveLength(3); - expect(userGroupBy).toEqual( - expect.arrayContaining([ - { email: 'u1@test.com', _count: { id: 1 } }, - { email: 'u2@test.com', _count: { id: 1 } }, - { email: 'u3@test.com', _count: { id: 1 } }, - ]), - ); - - const postGroupBy = await db.post.groupBy({ - by: ['authorId'], - _count: { id: true }, - _min: { id: true }, - _max: { id: true }, - }); - expect(postGroupBy).toHaveLength(3); - expect(postGroupBy).toEqual( - expect.arrayContaining([ - { authorId: 1, _count: { id: 3 }, _min: { id: 1 }, _max: { id: 3 } }, - { authorId: 2, _count: { id: 2 }, _min: { id: 4 }, _max: { id: 5 } }, - { authorId: 3, _count: { id: 1 }, _min: { id: 6 }, _max: { id: 6 } }, - ]), - ); - - const filteredGroupBy = await db.post.groupBy({ - by: ['authorId'], - where: { title: { contains: 'Post' } }, - _count: { title: true }, - having: { title: { _count: { gte: 2 } } }, - }); - expect(filteredGroupBy).toHaveLength(2); - expect(filteredGroupBy).toEqual( - expect.arrayContaining([ - { authorId: 1, _count: { title: 3 } }, - { authorId: 2, _count: { title: 2 } }, - ]), - ); - - // Test Kysely groupBy operations - const kyselyUserGroupBy = await db.$qb - .selectFrom('User') - .select(['email', (eb) => eb.fn.count('email').as('count')]) - .groupBy('email') - .having((eb) => eb.fn.count('email'), '>=', 1) - .execute(); - expect(kyselyUserGroupBy).toHaveLength(3); + }, + }); + + // Test ORM groupBy operations + const userGroupBy = await db.user.groupBy({ + by: ['email'], + _count: { id: true }, + }); + expect(userGroupBy).toHaveLength(3); + expect(userGroupBy).toEqual( + expect.arrayContaining([ + { email: 'u1@test.com', _count: { id: 1 } }, + { email: 'u2@test.com', _count: { id: 1 } }, + { email: 'u3@test.com', _count: { id: 1 } }, + ]), + ); + + const postGroupBy = await db.post.groupBy({ + by: ['authorId'], + _count: { id: true }, + _min: { id: true }, + _max: { id: true }, + }); + expect(postGroupBy).toHaveLength(3); + expect(postGroupBy).toEqual( + expect.arrayContaining([ + { authorId: 1, _count: { id: 3 }, _min: { id: 1 }, _max: { id: 3 } }, + { authorId: 2, _count: { id: 2 }, _min: { id: 4 }, _max: { id: 5 } }, + { authorId: 3, _count: { id: 1 }, _min: { id: 6 }, _max: { id: 6 } }, + ]), + ); + + const filteredGroupBy = await db.post.groupBy({ + by: ['authorId'], + where: { title: { contains: 'Post' } }, + _count: { title: true }, + having: { title: { _count: { gte: 2 } } }, }); - }, -); + expect(filteredGroupBy).toHaveLength(2); + expect(filteredGroupBy).toEqual( + expect.arrayContaining([ + { authorId: 1, _count: { title: 3 } }, + { authorId: 2, _count: { title: 2 } }, + ]), + ); + + // Test Kysely groupBy operations + const kyselyUserGroupBy = await db.$qb + .selectFrom('User') + .select(['email', (eb) => eb.fn.count('email').as('count')]) + .groupBy('email') + .having((eb) => eb.fn.count('email'), '>=', 1) + .execute(); + expect(kyselyUserGroupBy).toHaveLength(3); + }); +}); diff --git a/packages/runtime/test/client-api/raw-query.test.ts b/packages/runtime/test/client-api/raw-query.test.ts index b1754b67..f7838641 100644 --- a/packages/runtime/test/client-api/raw-query.test.ts +++ b/packages/runtime/test/client-api/raw-query.test.ts @@ -1,15 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-raw-query-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client raw query tests', ({ createClient, provider }) => { +describe('Client raw query tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { @@ -39,7 +37,8 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client raw query tests', ({ create }); const sql = - provider === 'postgresql' + // @ts-ignore + client.$schema.provider.type === 'postgresql' ? `UPDATE "User" SET "email" = $1 WHERE "id" = $2` : `UPDATE "User" SET "email" = ? WHERE "id" = ?`; await expect(client.$executeRawUnsafe(sql, 'u2@test.com', '1')).resolves.toBe(1); @@ -70,7 +69,8 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client raw query tests', ({ create }); const sql = - provider === 'postgresql' + // @ts-ignore + client.$schema.provider.type === 'postgresql' ? `SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = $1` : `SELECT "User"."id", "User"."email" FROM "User" WHERE "User"."id" = ?`; const users = await client.$queryRawUnsafe<{ id: string; email: string }[]>(sql, '1'); diff --git a/packages/runtime/test/client-api/relation/many-to-many.test.ts b/packages/runtime/test/client-api/relation/many-to-many.test.ts index c951387e..dd1eacf0 100644 --- a/packages/runtime/test/client-api/relation/many-to-many.test.ts +++ b/packages/runtime/test/client-api/relation/many-to-many.test.ts @@ -1,20 +1,16 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { createTestClient } from '../../utils'; -const TEST_DB = 'client-api-relation-test-many-to-many'; +describe('Many-to-many relation tests', () => { + let client: any; -describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as const }])( - 'Many-to-many relation tests for $provider', - ({ provider }) => { - let client: any; + afterEach(async () => { + await client?.$disconnect(); + }); - afterEach(async () => { - await client?.$disconnect(); - }); - - it('works with explicit many-to-many relation', async () => { - client = await createTestClient( - ` + it('works with explicit many-to-many relation', async () => { + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -36,54 +32,50 @@ describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as cons @@unique([userId, tagId]) } `, - { - provider, - dbName: TEST_DB, - }, - ); - - await client.user.create({ data: { id: 1, name: 'User1' } }); - await client.user.create({ data: { id: 2, name: 'User2' } }); - await client.tag.create({ data: { id: 1, name: 'Tag1' } }); - await client.tag.create({ data: { id: 2, name: 'Tag2' } }); - - await client.userTag.create({ data: { userId: 1, tagId: 1 } }); - await client.userTag.create({ data: { userId: 1, tagId: 2 } }); - await client.userTag.create({ data: { userId: 2, tagId: 1 } }); - - await expect( - client.user.findMany({ - include: { tags: { include: { tag: true } } }, - }), - ).resolves.toMatchObject([ - expect.objectContaining({ - name: 'User1', - tags: [ - expect.objectContaining({ - tag: expect.objectContaining({ name: 'Tag1' }), - }), - expect.objectContaining({ - tag: expect.objectContaining({ name: 'Tag2' }), - }), - ], - }), - expect.objectContaining({ - name: 'User2', - tags: [ - expect.objectContaining({ - tag: expect.objectContaining({ name: 'Tag1' }), - }), - ], - }), - ]); - }); - - describe.each([{ relationName: undefined }, { relationName: 'myM2M' }])( - 'Implicit many-to-many relation (relation: $relationName)', - ({ relationName }) => { - beforeEach(async () => { - client = await createTestClient( - ` + ); + + await client.user.create({ data: { id: 1, name: 'User1' } }); + await client.user.create({ data: { id: 2, name: 'User2' } }); + await client.tag.create({ data: { id: 1, name: 'Tag1' } }); + await client.tag.create({ data: { id: 2, name: 'Tag2' } }); + + await client.userTag.create({ data: { userId: 1, tagId: 1 } }); + await client.userTag.create({ data: { userId: 1, tagId: 2 } }); + await client.userTag.create({ data: { userId: 2, tagId: 1 } }); + + await expect( + client.user.findMany({ + include: { tags: { include: { tag: true } } }, + }), + ).resolves.toMatchObject([ + expect.objectContaining({ + name: 'User1', + tags: [ + expect.objectContaining({ + tag: expect.objectContaining({ name: 'Tag1' }), + }), + expect.objectContaining({ + tag: expect.objectContaining({ name: 'Tag2' }), + }), + ], + }), + expect.objectContaining({ + name: 'User2', + tags: [ + expect.objectContaining({ + tag: expect.objectContaining({ name: 'Tag1' }), + }), + ], + }), + ]); + }); + + describe.each([{ relationName: undefined }, { relationName: 'myM2M' }])( + 'Implicit many-to-many relation (relation: $relationName)', + ({ relationName }) => { + beforeEach(async () => { + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -104,500 +96,494 @@ describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as cons userId Int @unique } `, - { - provider, - dbName: provider === 'postgresql' ? TEST_DB : undefined, - usePrismaPush: true, + { + usePrismaPush: true, + }, + ); + }); + + it('works with find', async () => { + await client.user.create({ + data: { + id: 1, + name: 'User1', + tags: { + create: [ + { id: 1, name: 'Tag1' }, + { id: 2, name: 'Tag2' }, + ], }, - ); + profile: { + create: { + id: 1, + age: 20, + }, + }, + }, + }); + + await client.user.create({ + data: { + id: 2, + name: 'User2', + }, + }); + + // include without filter + await expect( + client.user.findFirst({ + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ name: 'Tag1' }), expect.objectContaining({ name: 'Tag2' })], + }); + + await expect( + client.profile.findFirst({ + include: { + user: { + include: { tags: true }, + }, + }, + }), + ).resolves.toMatchObject({ + user: expect.objectContaining({ + tags: [expect.objectContaining({ name: 'Tag1' }), expect.objectContaining({ name: 'Tag2' })], + }), }); - it('works with find', async () => { - await client.user.create({ + await expect( + client.user.findUnique({ + where: { id: 2 }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [], + }); + + // include with filter + await expect( + client.user.findFirst({ + where: { id: 1 }, + include: { tags: { where: { name: 'Tag1' } } }, + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ name: 'Tag1' })], + }); + + // filter with m2m + await expect( + client.user.findMany({ + where: { tags: { some: { name: 'Tag1' } } }, + }), + ).resolves.toEqual([ + expect.objectContaining({ + name: 'User1', + }), + ]); + await expect( + client.user.findMany({ + where: { tags: { none: { name: 'Tag1' } } }, + }), + ).resolves.toEqual([ + expect.objectContaining({ + name: 'User2', + }), + ]); + }); + + it('works with create', async () => { + // create + await expect( + client.user.create({ data: { id: 1, name: 'User1', tags: { create: [ - { id: 1, name: 'Tag1' }, - { id: 2, name: 'Tag2' }, + { + id: 1, + name: 'Tag1', + }, + { + id: 2, + name: 'Tag2', + }, ], }, - profile: { - create: { - id: 1, - age: 20, - }, - }, }, - }); + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ name: 'Tag1' }), expect.objectContaining({ name: 'Tag2' })], + }); - await client.user.create({ + // connect + await expect( + client.user.create({ data: { id: 2, name: 'User2', + tags: { connect: { id: 1 } }, }, - }); - - // include without filter - await expect( - client.user.findFirst({ - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ name: 'Tag1' }), expect.objectContaining({ name: 'Tag2' })], - }); + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ name: 'Tag1' })], + }); - await expect( - client.profile.findFirst({ - include: { - user: { - include: { tags: true }, + // connectOrCreate + await expect( + client.user.create({ + data: { + id: 3, + name: 'User3', + tags: { + connectOrCreate: { + where: { id: 1 }, + create: { id: 1, name: 'Tag1' }, }, }, - }), - ).resolves.toMatchObject({ - user: expect.objectContaining({ - tags: [ - expect.objectContaining({ name: 'Tag1' }), - expect.objectContaining({ name: 'Tag2' }), - ], - }), - }); - - await expect( - client.user.findUnique({ - where: { id: 2 }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [], - }); - - // include with filter - await expect( - client.user.findFirst({ - where: { id: 1 }, - include: { tags: { where: { name: 'Tag1' } } }, - }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ name: 'Tag1' })], - }); - - // filter with m2m - await expect( - client.user.findMany({ - where: { tags: { some: { name: 'Tag1' } } }, - }), - ).resolves.toEqual([ - expect.objectContaining({ - name: 'User1', - }), - ]); - await expect( - client.user.findMany({ - where: { tags: { none: { name: 'Tag1' } } }, - }), - ).resolves.toEqual([ - expect.objectContaining({ - name: 'User2', - }), - ]); + }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 1, name: 'Tag1' })], }); - it('works with create', async () => { - // create - await expect( - client.user.create({ - data: { - id: 1, - name: 'User1', - tags: { - create: [ - { - id: 1, - name: 'Tag1', - }, - { - id: 2, - name: 'Tag2', - }, - ], - }, - }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ name: 'Tag1' }), expect.objectContaining({ name: 'Tag2' })], - }); - - // connect - await expect( - client.user.create({ - data: { - id: 2, - name: 'User2', - tags: { connect: { id: 1 } }, - }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ name: 'Tag1' })], - }); - - // connectOrCreate - await expect( - client.user.create({ - data: { - id: 3, - name: 'User3', - tags: { - connectOrCreate: { - where: { id: 1 }, - create: { id: 1, name: 'Tag1' }, - }, + await expect( + client.user.create({ + data: { + id: 4, + name: 'User4', + tags: { + connectOrCreate: { + where: { id: 3 }, + create: { id: 3, name: 'Tag3' }, }, }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 1, name: 'Tag1' })], - }); - - await expect( - client.user.create({ - data: { - id: 4, - name: 'User4', - tags: { - connectOrCreate: { - where: { id: 3 }, - create: { id: 3, name: 'Tag3' }, - }, + }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 3, name: 'Tag3' })], + }); + }); + + it('works with update', async () => { + // create + await client.user.create({ + data: { + id: 1, + name: 'User1', + tags: { + create: [ + { + id: 1, + name: 'Tag1', }, - }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 3, name: 'Tag3' })], - }); + ], + }, + }, + include: { tags: true }, }); - it('works with update', async () => { - // create - await client.user.create({ + // create + await expect( + client.user.update({ + where: { id: 1 }, data: { - id: 1, - name: 'User1', tags: { create: [ { - id: 1, - name: 'Tag1', + id: 2, + name: 'Tag2', }, ], }, }, include: { tags: true }, - }); - - // create - await expect( - client.user.update({ - where: { id: 1 }, - data: { - tags: { - create: [ - { - id: 2, - name: 'Tag2', - }, - ], - }, - }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 1 }), expect.objectContaining({ id: 2 })], - }); + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 1 }), expect.objectContaining({ id: 2 })], + }); - await client.tag.create({ + await client.tag.create({ + data: { + id: 3, + name: 'Tag3', + }, + }); + + // connect + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { connect: { id: 3 } } }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 1 }), + expect.objectContaining({ id: 2 }), + expect.objectContaining({ id: 3 }), + ], + }); + // connecting a connected entity is no-op + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { connect: { id: 3 } } }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 1 }), + expect.objectContaining({ id: 2 }), + expect.objectContaining({ id: 3 }), + ], + }); + + // disconnect - not found + await expect( + client.user.update({ + where: { id: 1 }, data: { - id: 3, - name: 'Tag3', + tags: { disconnect: { id: 3, name: 'not found' } }, }, - }); - - // connect - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { connect: { id: 3 } } }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 1 }), - expect.objectContaining({ id: 2 }), - expect.objectContaining({ id: 3 }), - ], - }); - // connecting a connected entity is no-op - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { connect: { id: 3 } } }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 1 }), - expect.objectContaining({ id: 2 }), - expect.objectContaining({ id: 3 }), - ], - }); - - // disconnect - not found - await expect( - client.user.update({ - where: { id: 1 }, - data: { - tags: { disconnect: { id: 3, name: 'not found' } }, - }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 1 }), - expect.objectContaining({ id: 2 }), - expect.objectContaining({ id: 3 }), - ], - }); - - // disconnect - found - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { disconnect: { id: 3 } } }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 1 }), expect.objectContaining({ id: 2 })], - }); - - await expect( - client.$qbRaw - .selectFrom(relationName ? `_${relationName}` : '_TagToUser') - .selectAll() - .where('B', '=', 1) // user id - .where('A', '=', 3) // tag id - .execute(), - ).resolves.toHaveLength(0); - - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { set: [{ id: 2 }, { id: 3 }] } }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 2 }), expect.objectContaining({ id: 3 })], - }); - - // update - not found - await expect( - client.user.update({ - where: { id: 1 }, - data: { - tags: { - update: { - where: { id: 1 }, - data: { name: 'Tag1-updated' }, - }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 1 }), + expect.objectContaining({ id: 2 }), + expect.objectContaining({ id: 3 }), + ], + }); + + // disconnect - found + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { disconnect: { id: 3 } } }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 1 }), expect.objectContaining({ id: 2 })], + }); + + await expect( + client.$qbRaw + .selectFrom(relationName ? `_${relationName}` : '_TagToUser') + .selectAll() + .where('B', '=', 1) // user id + .where('A', '=', 3) // tag id + .execute(), + ).resolves.toHaveLength(0); + + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { set: [{ id: 2 }, { id: 3 }] } }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 2 }), expect.objectContaining({ id: 3 })], + }); + + // update - not found + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + update: { + where: { id: 1 }, + data: { name: 'Tag1-updated' }, }, }, - }), - ).toBeRejectedNotFound(); - - // update - found - await expect( - client.user.update({ - where: { id: 1 }, - data: { - tags: { - update: { - where: { id: 2 }, - data: { name: 'Tag2-updated' }, - }, + }, + }), + ).toBeRejectedNotFound(); + + // update - found + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + update: { + where: { id: 2 }, + data: { name: 'Tag2-updated' }, }, }, - include: { tags: true }, + }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: expect.arrayContaining([ + expect.objectContaining({ + id: 2, + name: 'Tag2-updated', }), - ).resolves.toMatchObject({ - tags: expect.arrayContaining([ - expect.objectContaining({ - id: 2, - name: 'Tag2-updated', - }), - expect.objectContaining({ id: 3, name: 'Tag3' }), - ]), - }); - - // updateMany - await expect( - client.user.update({ - where: { id: 1 }, - data: { - tags: { - updateMany: { - where: { id: { not: 2 } }, - data: { name: 'Tag3-updated' }, - }, + expect.objectContaining({ id: 3, name: 'Tag3' }), + ]), + }); + + // updateMany + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + updateMany: { + where: { id: { not: 2 } }, + data: { name: 'Tag3-updated' }, }, }, - include: { tags: true }, + }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ + id: 2, + name: 'Tag2-updated', }), - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ - id: 2, - name: 'Tag2-updated', - }), - expect.objectContaining({ - id: 3, - name: 'Tag3-updated', - }), - ], - }); - - await expect(client.tag.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ - name: 'Tag1', - }); - - // upsert - update - await expect( - client.user.update({ - where: { id: 1 }, - data: { - tags: { - upsert: { - where: { id: 3 }, - create: { id: 3, name: 'Tag4' }, - update: { name: 'Tag3-updated-1' }, - }, - }, - }, - include: { tags: true }, + expect.objectContaining({ + id: 3, + name: 'Tag3-updated', }), - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ - id: 2, - name: 'Tag2-updated', - }), - expect.objectContaining({ - id: 3, - name: 'Tag3-updated-1', - }), - ], - }); - - // upsert - create - await expect( - client.user.update({ - where: { id: 1 }, - data: { - tags: { - upsert: { - where: { id: 4 }, - create: { id: 4, name: 'Tag4' }, - update: { name: 'Tag4' }, - }, + ], + }); + + await expect(client.tag.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ + name: 'Tag1', + }); + + // upsert - update + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + upsert: { + where: { id: 3 }, + create: { id: 3, name: 'Tag4' }, + update: { name: 'Tag3-updated-1' }, }, }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: expect.arrayContaining([expect.objectContaining({ id: 4, name: 'Tag4' })]), - }); - - // delete - not found - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { delete: { id: 1 } } }, - }), - ).toBeRejectedNotFound(); - - // delete - found - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { delete: { id: 2 } } }, - include: { tags: true }, + }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ + id: 2, + name: 'Tag2-updated', }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 3 }), expect.objectContaining({ id: 4 })], - }); - await expect(client.tag.findUnique({ where: { id: 2 } })).toResolveNull(); - - // deleteMany - await expect( - client.user.update({ - where: { id: 1 }, - data: { - tags: { deleteMany: { id: { in: [1, 2, 3] } } }, - }, - include: { tags: true }, + expect.objectContaining({ + id: 3, + name: 'Tag3-updated-1', }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 4 })], - }); - await expect(client.tag.findUnique({ where: { id: 3 } })).toResolveNull(); - await expect(client.tag.findUnique({ where: { id: 1 } })).toResolveTruthy(); + ], }); - it('works with delete', async () => { - await client.user.create({ + // upsert - create + await expect( + client.user.update({ + where: { id: 1 }, data: { - id: 1, - name: 'User1', tags: { - create: [ - { id: 1, name: 'Tag1' }, - { id: 2, name: 'Tag2' }, - ], + upsert: { + where: { id: 4 }, + create: { id: 4, name: 'Tag4' }, + update: { name: 'Tag4' }, + }, }, }, - }); + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: expect.arrayContaining([expect.objectContaining({ id: 4, name: 'Tag4' })]), + }); - // cascade from tag - await client.tag.delete({ + // delete - not found + await expect( + client.user.update({ where: { id: 1 }, - }); - await expect( - client.user.findUnique({ - where: { id: 1 }, - include: { tags: true }, - }), - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 2 })], - }); + data: { tags: { delete: { id: 1 } } }, + }), + ).toBeRejectedNotFound(); - // cascade from user - await client.user.delete({ + // delete - found + await expect( + client.user.update({ where: { id: 1 }, - }); - await expect( - client.tag.findUnique({ - where: { id: 2 }, - include: { users: true }, - }), - ).resolves.toMatchObject({ - users: [], - }); + data: { tags: { delete: { id: 2 } } }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 3 }), expect.objectContaining({ id: 4 })], }); - }, - ); - }, -); + await expect(client.tag.findUnique({ where: { id: 2 } })).toResolveNull(); + + // deleteMany + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { deleteMany: { id: { in: [1, 2, 3] } } }, + }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 4 })], + }); + await expect(client.tag.findUnique({ where: { id: 3 } })).toResolveNull(); + await expect(client.tag.findUnique({ where: { id: 1 } })).toResolveTruthy(); + }); + + it('works with delete', async () => { + await client.user.create({ + data: { + id: 1, + name: 'User1', + tags: { + create: [ + { id: 1, name: 'Tag1' }, + { id: 2, name: 'Tag2' }, + ], + }, + }, + }); + + // cascade from tag + await client.tag.delete({ + where: { id: 1 }, + }); + await expect( + client.user.findUnique({ + where: { id: 1 }, + include: { tags: true }, + }), + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 2 })], + }); + + // cascade from user + await client.user.delete({ + where: { id: 1 }, + }); + await expect( + client.tag.findUnique({ + where: { id: 2 }, + include: { users: true }, + }), + ).resolves.toMatchObject({ + users: [], + }); + }); + }, + ); +}); diff --git a/packages/runtime/test/client-api/relation/one-to-many.test.ts b/packages/runtime/test/client-api/relation/one-to-many.test.ts index 656c5daa..be847d5e 100644 --- a/packages/runtime/test/client-api/relation/one-to-many.test.ts +++ b/packages/runtime/test/client-api/relation/one-to-many.test.ts @@ -1,20 +1,16 @@ import { afterEach, describe, expect, it } from 'vitest'; import { createTestClient } from '../../utils'; -const TEST_DB = 'client-api-relation-test-one-to-many'; +describe('One-to-many relation tests ', () => { + let client: any; -describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as const }])( - 'One-to-many relation tests for $provider', - ({ provider }) => { - let client: any; + afterEach(async () => { + await client?.$disconnect(); + }); - afterEach(async () => { - await client?.$disconnect(); - }); - - it('works with unnamed one-to-many relation', async () => { - client = await createTestClient( - ` + it('works with unnamed one-to-many relation', async () => { + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -28,31 +24,27 @@ describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as cons userId Int } `, - { - provider, - dbName: TEST_DB, - }, - ); + ); - await expect( - client.user.create({ - data: { - name: 'User', - posts: { - create: [{ title: 'Post 1' }, { title: 'Post 2' }], - }, + await expect( + client.user.create({ + data: { + name: 'User', + posts: { + create: [{ title: 'Post 1' }, { title: 'Post 2' }], }, - include: { posts: true }, - }), - ).resolves.toMatchObject({ - name: 'User', - posts: [expect.objectContaining({ title: 'Post 1' }), expect.objectContaining({ title: 'Post 2' })], - }); + }, + include: { posts: true }, + }), + ).resolves.toMatchObject({ + name: 'User', + posts: [expect.objectContaining({ title: 'Post 1' }), expect.objectContaining({ title: 'Post 2' })], }); + }); - it('works with named one-to-many relation', async () => { - client = await createTestClient( - ` + it('works with named one-to-many relation', async () => { + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -69,30 +61,25 @@ describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as cons userId2 Int? } `, - { - provider, - dbName: TEST_DB, - }, - ); + ); - await expect( - client.user.create({ - data: { - name: 'User', - posts1: { - create: [{ title: 'Post 1' }, { title: 'Post 2' }], - }, - posts2: { - create: [{ title: 'Post 3' }, { title: 'Post 4' }], - }, + await expect( + client.user.create({ + data: { + name: 'User', + posts1: { + create: [{ title: 'Post 1' }, { title: 'Post 2' }], + }, + posts2: { + create: [{ title: 'Post 3' }, { title: 'Post 4' }], }, - include: { posts1: true, posts2: true }, - }), - ).resolves.toMatchObject({ - name: 'User', - posts1: [expect.objectContaining({ title: 'Post 1' }), expect.objectContaining({ title: 'Post 2' })], - posts2: [expect.objectContaining({ title: 'Post 3' }), expect.objectContaining({ title: 'Post 4' })], - }); + }, + include: { posts1: true, posts2: true }, + }), + ).resolves.toMatchObject({ + name: 'User', + posts1: [expect.objectContaining({ title: 'Post 1' }), expect.objectContaining({ title: 'Post 2' })], + posts2: [expect.objectContaining({ title: 'Post 3' }), expect.objectContaining({ title: 'Post 4' })], }); - }, -); + }); +}); diff --git a/packages/runtime/test/client-api/relation/one-to-one.test.ts b/packages/runtime/test/client-api/relation/one-to-one.test.ts index b1e80562..e41a0cf9 100644 --- a/packages/runtime/test/client-api/relation/one-to-one.test.ts +++ b/packages/runtime/test/client-api/relation/one-to-one.test.ts @@ -4,7 +4,7 @@ import { createTestClient } from '../../utils'; const TEST_DB = 'client-api-relation-test-one-to-one'; describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as const }])( - 'One-to-one relation tests for $provider', + 'One-to-one relation tests', ({ provider }) => { let client: any; diff --git a/packages/runtime/test/client-api/relation/self-relation.test.ts b/packages/runtime/test/client-api/relation/self-relation.test.ts index f85c20c4..65380b30 100644 --- a/packages/runtime/test/client-api/relation/self-relation.test.ts +++ b/packages/runtime/test/client-api/relation/self-relation.test.ts @@ -1,20 +1,16 @@ import { afterEach, describe, expect, it } from 'vitest'; import { createTestClient } from '../../utils'; -const TEST_DB = 'client-api-relation-test-self-relation'; +describe('Self relation tests', () => { + let client: any; -describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as const }])( - 'Self relation tests for $provider', - ({ provider }) => { - let client: any; + afterEach(async () => { + await client?.$disconnect(); + }); - afterEach(async () => { - await client?.$disconnect(); - }); - - it('works with one-to-one self relation', async () => { - client = await createTestClient( - ` + it('works with one-to-one self relation', async () => { + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -23,105 +19,103 @@ describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as cons spouseId Int? @unique } `, - { - provider, - dbName: TEST_DB, - usePrismaPush: true, + { + usePrismaPush: true, + }, + ); + + // Create first user + const alice = await client.user.create({ + data: { name: 'Alice' }, + }); + + // Create second user and establish marriage relationship + await expect( + client.user.create({ + data: { + name: 'Bob', + spouse: { connect: { id: alice.id } }, }, - ); - - // Create first user - const alice = await client.user.create({ - data: { name: 'Alice' }, - }); - - // Create second user and establish marriage relationship - await expect( - client.user.create({ - data: { - name: 'Bob', - spouse: { connect: { id: alice.id } }, - }, - include: { spouse: true }, - }), - ).resolves.toMatchObject({ - name: 'Bob', - spouse: { name: 'Alice' }, - }); - - // Verify the reverse relationship - await expect( - client.user.findUnique({ - where: { id: alice.id }, - include: { marriedTo: true }, - }), - ).resolves.toMatchObject({ - name: 'Alice', - marriedTo: { name: 'Bob' }, - }); - - // Test creating with nested create - await expect( - client.user.create({ - data: { - name: 'Charlie', - spouse: { - create: { name: 'Diana' }, - }, - }, - include: { spouse: true }, - }), - ).resolves.toMatchObject({ - name: 'Charlie', - spouse: { name: 'Diana' }, - }); - - // Verify Diana is married to Charlie - await expect( - client.user.findFirst({ - where: { name: 'Diana' }, - include: { marriedTo: true }, - }), - ).resolves.toMatchObject({ - name: 'Diana', - marriedTo: { name: 'Charlie' }, - }); - - // Test disconnecting relationship - const bob = await client.user.findFirst({ - where: { name: 'Bob' }, - }); - - await expect( - client.user.update({ - where: { id: bob!.id }, - data: { - spouse: { disconnect: true }, + include: { spouse: true }, + }), + ).resolves.toMatchObject({ + name: 'Bob', + spouse: { name: 'Alice' }, + }); + + // Verify the reverse relationship + await expect( + client.user.findUnique({ + where: { id: alice.id }, + include: { marriedTo: true }, + }), + ).resolves.toMatchObject({ + name: 'Alice', + marriedTo: { name: 'Bob' }, + }); + + // Test creating with nested create + await expect( + client.user.create({ + data: { + name: 'Charlie', + spouse: { + create: { name: 'Diana' }, }, - include: { spouse: true, marriedTo: true }, - }), - ).resolves.toMatchObject({ - name: 'Bob', - spouse: null, - marriedTo: null, - }); - - // Verify Alice is also disconnected - await expect( - client.user.findUnique({ - where: { id: alice.id }, - include: { spouse: true, marriedTo: true }, - }), - ).resolves.toMatchObject({ - name: 'Alice', - spouse: null, - marriedTo: null, - }); + }, + include: { spouse: true }, + }), + ).resolves.toMatchObject({ + name: 'Charlie', + spouse: { name: 'Diana' }, + }); + + // Verify Diana is married to Charlie + await expect( + client.user.findFirst({ + where: { name: 'Diana' }, + include: { marriedTo: true }, + }), + ).resolves.toMatchObject({ + name: 'Diana', + marriedTo: { name: 'Charlie' }, + }); + + // Test disconnecting relationship + const bob = await client.user.findFirst({ + where: { name: 'Bob' }, + }); + + await expect( + client.user.update({ + where: { id: bob!.id }, + data: { + spouse: { disconnect: true }, + }, + include: { spouse: true, marriedTo: true }, + }), + ).resolves.toMatchObject({ + name: 'Bob', + spouse: null, + marriedTo: null, }); - it('works with one-to-many self relation', async () => { - client = await createTestClient( - ` + // Verify Alice is also disconnected + await expect( + client.user.findUnique({ + where: { id: alice.id }, + include: { spouse: true, marriedTo: true }, + }), + ).resolves.toMatchObject({ + name: 'Alice', + spouse: null, + marriedTo: null, + }); + }); + + it('works with one-to-many self relation', async () => { + client = await createTestClient( + ` model Category { id Int @id @default(autoincrement()) name String @@ -130,181 +124,179 @@ describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as cons parentId Int? } `, - { - provider, - dbName: TEST_DB, - usePrismaPush: true, - }, - ); + { + usePrismaPush: true, + }, + ); + + // Create parent category + const parent = await client.category.create({ + data: { + name: 'Electronics', + }, + }); - // Create parent category - const parent = await client.category.create({ + // Create children with parent + await expect( + client.category.create({ data: { - name: 'Electronics', + name: 'Smartphones', + parent: { connect: { id: parent.id } }, }, - }); - - // Create children with parent - await expect( - client.category.create({ - data: { - name: 'Smartphones', - parent: { connect: { id: parent.id } }, - }, - include: { parent: true }, - }), - ).resolves.toMatchObject({ - name: 'Smartphones', - parent: { name: 'Electronics' }, - }); - - // Create child using nested create - await expect( - client.category.create({ - data: { - name: 'Gaming', - children: { - create: [{ name: 'Console Games' }, { name: 'PC Games' }], - }, - }, - include: { children: true }, - }), - ).resolves.toMatchObject({ - name: 'Gaming', - children: [ - expect.objectContaining({ name: 'Console Games' }), - expect.objectContaining({ name: 'PC Games' }), - ], - }); - - // Query with full hierarchy - await expect( - client.category.findFirst({ - where: { name: 'Electronics' }, - include: { - children: { - include: { parent: true }, - }, - }, - }), - ).resolves.toMatchObject({ - name: 'Electronics', - children: [ - expect.objectContaining({ - name: 'Smartphones', - parent: expect.objectContaining({ name: 'Electronics' }), - }), - ], - }); - - // Test relation manipulation with update - move child to different parent - const gaming = await client.category.findFirst({ where: { name: 'Gaming' } }); - const smartphone = await client.category.findFirst({ where: { name: 'Smartphones' } }); - - await expect( - client.category.update({ - where: { id: smartphone.id }, - data: { - parent: { connect: { id: gaming.id } }, - }, - include: { parent: true }, - }), - ).resolves.toMatchObject({ - name: 'Smartphones', - parent: { name: 'Gaming' }, - }); - - // Test update to disconnect parent (make orphan) - await expect( - client.category.update({ - where: { id: smartphone.id }, - data: { - parent: { disconnect: true }, + include: { parent: true }, + }), + ).resolves.toMatchObject({ + name: 'Smartphones', + parent: { name: 'Electronics' }, + }); + + // Create child using nested create + await expect( + client.category.create({ + data: { + name: 'Gaming', + children: { + create: [{ name: 'Console Games' }, { name: 'PC Games' }], }, - include: { parent: true }, - }), - ).resolves.toMatchObject({ - name: 'Smartphones', - parent: null, - }); - - // Test update to add new children to existing parent - const newChild = await client.category.create({ data: { name: 'Accessories' } }); - - await expect( - client.category.update({ - where: { id: parent.id }, - data: { - children: { connect: { id: newChild.id } }, + }, + include: { children: true }, + }), + ).resolves.toMatchObject({ + name: 'Gaming', + children: [ + expect.objectContaining({ name: 'Console Games' }), + expect.objectContaining({ name: 'PC Games' }), + ], + }); + + // Query with full hierarchy + await expect( + client.category.findFirst({ + where: { name: 'Electronics' }, + include: { + children: { + include: { parent: true }, }, - include: { children: true }, - }), - ).resolves.toMatchObject({ - name: 'Electronics', - children: expect.arrayContaining([expect.objectContaining({ name: 'Accessories' })]), - }); - - // Test nested relation delete - delete specific children via update - const consoleGames = await client.category.findFirst({ where: { name: 'Console Games' } }); - - await expect( - client.category.update({ - where: { id: gaming.id }, - data: { - children: { - delete: { id: consoleGames.id }, - }, + }, + }), + ).resolves.toMatchObject({ + name: 'Electronics', + children: [ + expect.objectContaining({ + name: 'Smartphones', + parent: expect.objectContaining({ name: 'Electronics' }), + }), + ], + }); + + // Test relation manipulation with update - move child to different parent + const gaming = await client.category.findFirst({ where: { name: 'Gaming' } }); + const smartphone = await client.category.findFirst({ where: { name: 'Smartphones' } }); + + await expect( + client.category.update({ + where: { id: smartphone.id }, + data: { + parent: { connect: { id: gaming.id } }, + }, + include: { parent: true }, + }), + ).resolves.toMatchObject({ + name: 'Smartphones', + parent: { name: 'Gaming' }, + }); + + // Test update to disconnect parent (make orphan) + await expect( + client.category.update({ + where: { id: smartphone.id }, + data: { + parent: { disconnect: true }, + }, + include: { parent: true }, + }), + ).resolves.toMatchObject({ + name: 'Smartphones', + parent: null, + }); + + // Test update to add new children to existing parent + const newChild = await client.category.create({ data: { name: 'Accessories' } }); + + await expect( + client.category.update({ + where: { id: parent.id }, + data: { + children: { connect: { id: newChild.id } }, + }, + include: { children: true }, + }), + ).resolves.toMatchObject({ + name: 'Electronics', + children: expect.arrayContaining([expect.objectContaining({ name: 'Accessories' })]), + }); + + // Test nested relation delete - delete specific children via update + const consoleGames = await client.category.findFirst({ where: { name: 'Console Games' } }); + + await expect( + client.category.update({ + where: { id: gaming.id }, + data: { + children: { + delete: { id: consoleGames.id }, }, - include: { children: true }, - }), - ).resolves.toMatchObject({ - name: 'Gaming', - children: [expect.objectContaining({ name: 'PC Games' })], - }); - - // Verify the deleted child no longer exists - await expect(client.category.findFirst({ where: { id: consoleGames.id } })).resolves.toBeNull(); - - // Test nested delete with multiple children - await expect( - client.category.update({ - where: { id: gaming.id }, - data: { - children: { - deleteMany: { - name: { startsWith: 'PC' }, - }, + }, + include: { children: true }, + }), + ).resolves.toMatchObject({ + name: 'Gaming', + children: [expect.objectContaining({ name: 'PC Games' })], + }); + + // Verify the deleted child no longer exists + await expect(client.category.findFirst({ where: { id: consoleGames.id } })).resolves.toBeNull(); + + // Test nested delete with multiple children + await expect( + client.category.update({ + where: { id: gaming.id }, + data: { + children: { + deleteMany: { + name: { startsWith: 'PC' }, }, }, - include: { children: true }, - }), - ).resolves.toMatchObject({ - name: 'Gaming', - children: [], - }); - - // Test update with nested delete using where condition - await expect( - client.category.update({ - where: { id: parent.id }, - data: { - children: { - deleteMany: { - name: 'Accessories', - }, + }, + include: { children: true }, + }), + ).resolves.toMatchObject({ + name: 'Gaming', + children: [], + }); + + // Test update with nested delete using where condition + await expect( + client.category.update({ + where: { id: parent.id }, + data: { + children: { + deleteMany: { + name: 'Accessories', }, }, - include: { children: true }, - }), - ).resolves.toMatchObject({ - name: 'Electronics', - children: [], - }); + }, + include: { children: true }, + }), + ).resolves.toMatchObject({ + name: 'Electronics', + children: [], }); + }); - it('works with many-to-many self relation', async () => { - client = await createTestClient( - ` + it('works with many-to-many self relation', async () => { + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -312,222 +304,217 @@ describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as cons followers User[] @relation("UserFollows") } `, - { - provider, - dbName: provider === 'postgresql' ? TEST_DB : undefined, - usePrismaPush: true, - }, - ); - - // Create users - const user1 = await client.user.create({ data: { name: 'Alice' } }); - const user2 = await client.user.create({ data: { name: 'Bob' } }); - const user3 = await client.user.create({ data: { name: 'Charlie' } }); - - // Alice follows Bob and Charlie - await expect( - client.user.update({ - where: { id: user1.id }, - data: { - following: { - connect: [{ id: user2.id }, { id: user3.id }], - }, + { + usePrismaPush: true, + }, + ); + + // Create users + const user1 = await client.user.create({ data: { name: 'Alice' } }); + const user2 = await client.user.create({ data: { name: 'Bob' } }); + const user3 = await client.user.create({ data: { name: 'Charlie' } }); + + // Alice follows Bob and Charlie + await expect( + client.user.update({ + where: { id: user1.id }, + data: { + following: { + connect: [{ id: user2.id }, { id: user3.id }], }, - include: { following: true }, - }), - ).resolves.toMatchObject({ - name: 'Alice', - following: [expect.objectContaining({ name: 'Bob' }), expect.objectContaining({ name: 'Charlie' })], - }); + }, + include: { following: true }, + }), + ).resolves.toMatchObject({ + name: 'Alice', + following: [expect.objectContaining({ name: 'Bob' }), expect.objectContaining({ name: 'Charlie' })], + }); + + // Bob follows Charlie + await client.user.update({ + where: { id: user2.id }, + data: { + following: { connect: { id: user3.id } }, + }, + }); - // Bob follows Charlie - await client.user.update({ + // Check Bob's followers (should include Alice) + await expect( + client.user.findUnique({ where: { id: user2.id }, - data: { - following: { connect: { id: user3.id } }, - }, - }); + include: { followers: true }, + }), + ).resolves.toMatchObject({ + name: 'Bob', + followers: [expect.objectContaining({ name: 'Alice' })], + }); - // Check Bob's followers (should include Alice) - await expect( - client.user.findUnique({ - where: { id: user2.id }, - include: { followers: true }, - }), - ).resolves.toMatchObject({ - name: 'Bob', - followers: [expect.objectContaining({ name: 'Alice' })], - }); - - // Check Charlie's followers (should include Alice and Bob) - await expect( - client.user.findUnique({ - where: { id: user3.id }, - include: { followers: true }, - }), - ).resolves.toMatchObject({ - name: 'Charlie', - followers: [expect.objectContaining({ name: 'Alice' }), expect.objectContaining({ name: 'Bob' })], - }); - - // Test filtering with self relation - await expect( - client.user.findMany({ - where: { - followers: { - some: { name: 'Alice' }, - }, + // Check Charlie's followers (should include Alice and Bob) + await expect( + client.user.findUnique({ + where: { id: user3.id }, + include: { followers: true }, + }), + ).resolves.toMatchObject({ + name: 'Charlie', + followers: [expect.objectContaining({ name: 'Alice' }), expect.objectContaining({ name: 'Bob' })], + }); + + // Test filtering with self relation + await expect( + client.user.findMany({ + where: { + followers: { + some: { name: 'Alice' }, }, - }), - ).resolves.toEqual([ - expect.objectContaining({ name: 'Bob' }), - expect.objectContaining({ name: 'Charlie' }), - ]); - - // Test disconnect operation - await expect( - client.user.update({ - where: { id: user1.id }, - data: { - following: { - disconnect: { id: user2.id }, - }, + }, + }), + ).resolves.toEqual([expect.objectContaining({ name: 'Bob' }), expect.objectContaining({ name: 'Charlie' })]); + + // Test disconnect operation + await expect( + client.user.update({ + where: { id: user1.id }, + data: { + following: { + disconnect: { id: user2.id }, }, - include: { following: true }, - }), - ).resolves.toMatchObject({ - name: 'Alice', - following: [expect.objectContaining({ name: 'Charlie' })], - }); - - // Verify Bob no longer has Alice as follower - await expect( - client.user.findUnique({ - where: { id: user2.id }, - include: { followers: true }, - }), - ).resolves.toMatchObject({ - name: 'Bob', - followers: [], - }); - - // Test set operation (replace all following) - await expect( - client.user.update({ - where: { id: user1.id }, - data: { - following: { - set: [{ id: user2.id }], - }, + }, + include: { following: true }, + }), + ).resolves.toMatchObject({ + name: 'Alice', + following: [expect.objectContaining({ name: 'Charlie' })], + }); + + // Verify Bob no longer has Alice as follower + await expect( + client.user.findUnique({ + where: { id: user2.id }, + include: { followers: true }, + }), + ).resolves.toMatchObject({ + name: 'Bob', + followers: [], + }); + + // Test set operation (replace all following) + await expect( + client.user.update({ + where: { id: user1.id }, + data: { + following: { + set: [{ id: user2.id }], }, - include: { following: true }, - }), - ).resolves.toMatchObject({ - name: 'Alice', - following: [expect.objectContaining({ name: 'Bob' })], - }); - - // Verify Charlie no longer has Alice as follower after set - await expect( - client.user.findUnique({ - where: { id: user3.id }, - include: { followers: true }, - }), - ).resolves.toMatchObject({ - name: 'Charlie', - followers: [expect.objectContaining({ name: 'Bob' })], - }); - - // Test connectOrCreate with existing user - await expect( - client.user.update({ - where: { id: user1.id }, - data: { - following: { - connectOrCreate: { - where: { id: user3.id }, - create: { name: 'Charlie' }, - }, + }, + include: { following: true }, + }), + ).resolves.toMatchObject({ + name: 'Alice', + following: [expect.objectContaining({ name: 'Bob' })], + }); + + // Verify Charlie no longer has Alice as follower after set + await expect( + client.user.findUnique({ + where: { id: user3.id }, + include: { followers: true }, + }), + ).resolves.toMatchObject({ + name: 'Charlie', + followers: [expect.objectContaining({ name: 'Bob' })], + }); + + // Test connectOrCreate with existing user + await expect( + client.user.update({ + where: { id: user1.id }, + data: { + following: { + connectOrCreate: { + where: { id: user3.id }, + create: { name: 'Charlie' }, }, }, - include: { following: true }, - }), - ).resolves.toMatchObject({ - name: 'Alice', - following: [expect.objectContaining({ name: 'Bob' }), expect.objectContaining({ name: 'Charlie' })], - }); - - // Test connectOrCreate with new user - await expect( - client.user.update({ - where: { id: user1.id }, - data: { - following: { - connectOrCreate: { - where: { id: 999 }, - create: { name: 'David' }, - }, + }, + include: { following: true }, + }), + ).resolves.toMatchObject({ + name: 'Alice', + following: [expect.objectContaining({ name: 'Bob' }), expect.objectContaining({ name: 'Charlie' })], + }); + + // Test connectOrCreate with new user + await expect( + client.user.update({ + where: { id: user1.id }, + data: { + following: { + connectOrCreate: { + where: { id: 999 }, + create: { name: 'David' }, }, }, - include: { following: true }, - }), - ).resolves.toMatchObject({ - name: 'Alice', - following: expect.arrayContaining([ - expect.objectContaining({ name: 'Bob' }), - expect.objectContaining({ name: 'Charlie' }), - expect.objectContaining({ name: 'David' }), - ]), - }); - - // Test create operation within update - await expect( - client.user.update({ - where: { id: user2.id }, - data: { - following: { - create: { name: 'Eve' }, - }, + }, + include: { following: true }, + }), + ).resolves.toMatchObject({ + name: 'Alice', + following: expect.arrayContaining([ + expect.objectContaining({ name: 'Bob' }), + expect.objectContaining({ name: 'Charlie' }), + expect.objectContaining({ name: 'David' }), + ]), + }); + + // Test create operation within update + await expect( + client.user.update({ + where: { id: user2.id }, + data: { + following: { + create: { name: 'Eve' }, }, - include: { following: true }, - }), - ).resolves.toMatchObject({ - name: 'Bob', - following: expect.arrayContaining([ - expect.objectContaining({ name: 'Charlie' }), - expect.objectContaining({ name: 'Eve' }), - ]), - }); - - // Test deleteMany operation (disconnect and delete) - const davidUser = await client.user.findFirst({ where: { name: 'David' } }); - const eveUser = await client.user.findFirst({ where: { name: 'Eve' } }); - - await expect( - client.user.update({ - where: { id: user1.id }, - data: { - following: { - deleteMany: { - name: { in: ['David', 'Eve'] }, - }, + }, + include: { following: true }, + }), + ).resolves.toMatchObject({ + name: 'Bob', + following: expect.arrayContaining([ + expect.objectContaining({ name: 'Charlie' }), + expect.objectContaining({ name: 'Eve' }), + ]), + }); + + // Test deleteMany operation (disconnect and delete) + const davidUser = await client.user.findFirst({ where: { name: 'David' } }); + const eveUser = await client.user.findFirst({ where: { name: 'Eve' } }); + + await expect( + client.user.update({ + where: { id: user1.id }, + data: { + following: { + deleteMany: { + name: { in: ['David', 'Eve'] }, }, }, - include: { following: true }, - }), - ).resolves.toMatchObject({ - name: 'Alice', - following: [expect.objectContaining({ name: 'Bob' }), expect.objectContaining({ name: 'Charlie' })], - }); - - // Verify David was deleted from database - await expect(client.user.findUnique({ where: { id: davidUser!.id } })).toResolveNull(); - await expect(client.user.findUnique({ where: { id: eveUser!.id } })).toResolveTruthy(); + }, + include: { following: true }, + }), + ).resolves.toMatchObject({ + name: 'Alice', + following: [expect.objectContaining({ name: 'Bob' }), expect.objectContaining({ name: 'Charlie' })], }); - it('works with explicit self-referencing many-to-many', async () => { - client = await createTestClient( - ` + // Verify David was deleted from database + await expect(client.user.findUnique({ where: { id: davidUser!.id } })).toResolveNull(); + await expect(client.user.findUnique({ where: { id: eveUser!.id } })).toResolveTruthy(); + }); + + it('works with explicit self-referencing many-to-many', async () => { + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -545,65 +532,61 @@ describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as cons @@unique([followerId, followingId]) } `, - { - provider, - dbName: TEST_DB, - }, - ); + ); - const user1 = await client.user.create({ data: { name: 'Alice' } }); - const user2 = await client.user.create({ data: { name: 'Bob' } }); + const user1 = await client.user.create({ data: { name: 'Alice' } }); + const user2 = await client.user.create({ data: { name: 'Bob' } }); - // Create follow relationship - await client.userFollow.create({ - data: { - followerId: user1.id, - followingId: user2.id, - }, - }); + // Create follow relationship + await client.userFollow.create({ + data: { + followerId: user1.id, + followingId: user2.id, + }, + }); - // Query following relationships - await expect( - client.user.findUnique({ - where: { id: user1.id }, - include: { - followingRelations: { - include: { following: true }, - }, + // Query following relationships + await expect( + client.user.findUnique({ + where: { id: user1.id }, + include: { + followingRelations: { + include: { following: true }, }, - }), - ).resolves.toMatchObject({ - name: 'Alice', - followingRelations: [ - expect.objectContaining({ - following: expect.objectContaining({ name: 'Bob' }), - }), - ], - }); - - // Query follower relationships - await expect( - client.user.findUnique({ - where: { id: user2.id }, - include: { - followerRelations: { - include: { follower: true }, - }, + }, + }), + ).resolves.toMatchObject({ + name: 'Alice', + followingRelations: [ + expect.objectContaining({ + following: expect.objectContaining({ name: 'Bob' }), + }), + ], + }); + + // Query follower relationships + await expect( + client.user.findUnique({ + where: { id: user2.id }, + include: { + followerRelations: { + include: { follower: true }, }, - }), - ).resolves.toMatchObject({ - name: 'Bob', - followerRelations: [ - expect.objectContaining({ - follower: expect.objectContaining({ name: 'Alice' }), - }), - ], - }); - }); - - it('works with multiple self relations on same model', async () => { - client = await createTestClient( - ` + }, + }), + ).resolves.toMatchObject({ + name: 'Bob', + followerRelations: [ + expect.objectContaining({ + follower: expect.objectContaining({ name: 'Alice' }), + }), + ], + }); + }); + + it('works with multiple self relations on same model', async () => { + client = await createTestClient( + ` model Person { id Int @id @default(autoincrement()) name String @@ -616,64 +599,60 @@ describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as cons mentorId Int? } `, - { - provider, - usePrismaPush: true, - dbName: TEST_DB, - }, - ); + { usePrismaPush: true }, + ); - // Create CEO - const ceo = await client.person.create({ - data: { name: 'CEO' }, - }); + // Create CEO + const ceo = await client.person.create({ + data: { name: 'CEO' }, + }); - // Create manager who reports to CEO and is also a mentor - const manager = await client.person.create({ + // Create manager who reports to CEO and is also a mentor + const manager = await client.person.create({ + data: { + name: 'Manager', + manager: { connect: { id: ceo.id } }, + }, + }); + + // Create employee who reports to manager and is mentored by CEO + await expect( + client.person.create({ data: { - name: 'Manager', - manager: { connect: { id: ceo.id } }, + name: 'Employee', + manager: { connect: { id: manager.id } }, + mentor: { connect: { id: ceo.id } }, }, - }); - - // Create employee who reports to manager and is mentored by CEO - await expect( - client.person.create({ - data: { - name: 'Employee', - manager: { connect: { id: manager.id } }, - mentor: { connect: { id: ceo.id } }, - }, - include: { - manager: true, - mentor: true, - }, - }), - ).resolves.toMatchObject({ - name: 'Employee', - manager: { name: 'Manager' }, - mentor: { name: 'CEO' }, - }); - - // Check CEO's reports and mentees - await expect( - client.person.findUnique({ - where: { id: ceo.id }, - include: { - reports: true, - mentees: true, - }, - }), - ).resolves.toMatchObject({ - name: 'CEO', - reports: [expect.objectContaining({ name: 'Manager' })], - mentees: [expect.objectContaining({ name: 'Employee' })], - }); + include: { + manager: true, + mentor: true, + }, + }), + ).resolves.toMatchObject({ + name: 'Employee', + manager: { name: 'Manager' }, + mentor: { name: 'CEO' }, }); - it('works with deep self relation queries', async () => { - client = await createTestClient( - ` + // Check CEO's reports and mentees + await expect( + client.person.findUnique({ + where: { id: ceo.id }, + include: { + reports: true, + mentees: true, + }, + }), + ).resolves.toMatchObject({ + name: 'CEO', + reports: [expect.objectContaining({ name: 'Manager' })], + mentees: [expect.objectContaining({ name: 'Employee' })], + }); + }); + + it('works with deep self relation queries', async () => { + client = await createTestClient( + ` model Comment { id Int @id @default(autoincrement()) content String @@ -682,76 +661,71 @@ describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as cons parentId Int? } `, - { - provider, - usePrismaPush: true, - dbName: TEST_DB, - }, - ); + { usePrismaPush: true }, + ); - // Create nested comment thread - const topComment = await client.comment.create({ - data: { - content: 'Top level comment', - replies: { - create: [ - { - content: 'First reply', - replies: { - create: [{ content: 'Nested reply 1' }, { content: 'Nested reply 2' }], - }, + // Create nested comment thread + const topComment = await client.comment.create({ + data: { + content: 'Top level comment', + replies: { + create: [ + { + content: 'First reply', + replies: { + create: [{ content: 'Nested reply 1' }, { content: 'Nested reply 2' }], }, - { content: 'Second reply' }, - ], - }, - }, - include: { - replies: { - include: { - replies: true, }, + { content: 'Second reply' }, + ], + }, + }, + include: { + replies: { + include: { + replies: true, }, }, - }); + }, + }); - expect(topComment).toMatchObject({ - content: 'Top level comment', - replies: [ - expect.objectContaining({ - content: 'First reply', - replies: [ - expect.objectContaining({ content: 'Nested reply 1' }), - expect.objectContaining({ content: 'Nested reply 2' }), - ], - }), - expect.objectContaining({ - content: 'Second reply', - replies: [], - }), - ], - }); - - // Query from nested comment up the chain - const nestedReply = await client.comment.findFirst({ - where: { content: 'Nested reply 1' }, - include: { - parent: { - include: { - parent: true, - }, + expect(topComment).toMatchObject({ + content: 'Top level comment', + replies: [ + expect.objectContaining({ + content: 'First reply', + replies: [ + expect.objectContaining({ content: 'Nested reply 1' }), + expect.objectContaining({ content: 'Nested reply 2' }), + ], + }), + expect.objectContaining({ + content: 'Second reply', + replies: [], + }), + ], + }); + + // Query from nested comment up the chain + const nestedReply = await client.comment.findFirst({ + where: { content: 'Nested reply 1' }, + include: { + parent: { + include: { + parent: true, }, }, - }); + }, + }); - expect(nestedReply).toMatchObject({ - content: 'Nested reply 1', + expect(nestedReply).toMatchObject({ + content: 'Nested reply 1', + parent: expect.objectContaining({ + content: 'First reply', parent: expect.objectContaining({ - content: 'First reply', - parent: expect.objectContaining({ - content: 'Top level comment', - }), + content: 'Top level comment', }), - }); + }), }); - }, -); + }); +}); diff --git a/packages/runtime/test/client-api/scalar-list.test.ts b/packages/runtime/test/client-api/scalar-list.test.ts index b10744e5..c9bfb0fc 100644 --- a/packages/runtime/test/client-api/scalar-list.test.ts +++ b/packages/runtime/test/client-api/scalar-list.test.ts @@ -1,8 +1,6 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-scalar-list-tests'; - describe('Scalar list tests', () => { const schema = ` model User { @@ -18,7 +16,6 @@ describe('Scalar list tests', () => { beforeEach(async () => { client = await createTestClient(schema, { provider: 'postgresql', - dbName: PG_DB_NAME, }); }); diff --git a/packages/runtime/test/client-api/transaction.test.ts b/packages/runtime/test/client-api/transaction.test.ts index 1daac7c5..c235420a 100644 --- a/packages/runtime/test/client-api/transaction.test.ts +++ b/packages/runtime/test/client-api/transaction.test.ts @@ -1,15 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-transaction-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client raw query tests', ({ createClient }) => { +describe('Client raw query tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/client-api/type-coverage.test.ts b/packages/runtime/test/client-api/type-coverage.test.ts index 50f3e2bb..1055c712 100644 --- a/packages/runtime/test/client-api/type-coverage.test.ts +++ b/packages/runtime/test/client-api/type-coverage.test.ts @@ -1,10 +1,8 @@ import Decimal from 'decimal.js'; import { describe, expect, it } from 'vitest'; -import { createTestClient } from '../utils'; +import { createTestClient, getTestDbProvider } from '../utils'; -const PG_DB_NAME = 'client-api-type-coverage-tests'; - -describe.each(['sqlite', 'postgresql'] as const)('zmodel type coverage tests', (provider) => { +describe('Zmodel type coverage tests', () => { it('supports all types - plain', async () => { const date = new Date(); const data = { @@ -37,7 +35,6 @@ describe.each(['sqlite', 'postgresql'] as const)('zmodel type coverage tests', ( Json Json } `, - { provider, dbName: PG_DB_NAME }, ); await db.foo.create({ data }); @@ -64,7 +61,6 @@ describe.each(['sqlite', 'postgresql'] as const)('zmodel type coverage tests', ( Json Json @default("{\\"foo\\":\\"bar\\"}") } `, - { provider, dbName: PG_DB_NAME }, ); await db.foo.create({ data: { id: '1' } }); @@ -84,7 +80,7 @@ describe.each(['sqlite', 'postgresql'] as const)('zmodel type coverage tests', ( }); it('supports all types - array', async () => { - if (provider === 'sqlite') { + if (getTestDbProvider() === 'sqlite') { return; } @@ -120,7 +116,6 @@ describe.each(['sqlite', 'postgresql'] as const)('zmodel type coverage tests', ( Json Json[] } `, - { provider, dbName: PG_DB_NAME }, ); await db.foo.create({ data }); @@ -131,7 +126,7 @@ describe.each(['sqlite', 'postgresql'] as const)('zmodel type coverage tests', ( }); it('supports all types - array for plain json field', async () => { - if (provider === 'sqlite') { + if (getTestDbProvider() === 'sqlite') { return; } @@ -149,7 +144,6 @@ describe.each(['sqlite', 'postgresql'] as const)('zmodel type coverage tests', ( Json Json } `, - { provider, dbName: PG_DB_NAME }, ); await db.foo.create({ data }); diff --git a/packages/runtime/test/client-api/typed-json-fields.test.ts b/packages/runtime/test/client-api/typed-json-fields.test.ts index 4ea57c7e..65757169 100644 --- a/packages/runtime/test/client-api/typed-json-fields.test.ts +++ b/packages/runtime/test/client-api/typed-json-fields.test.ts @@ -1,12 +1,8 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-typed-json-fields-tests'; - -describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as const }])( - 'Typed JSON fields', - ({ provider }) => { - const schema = ` +describe('Typed JSON fields', () => { + const schema = ` type Identity { providers IdentityProvider[] } @@ -22,200 +18,197 @@ model User { } `; - let client: any; + let client: any; - beforeEach(async () => { - client = await createTestClient(schema, { - usePrismaPush: true, - provider, - dbName: provider === 'postgresql' ? PG_DB_NAME : undefined, - }); + beforeEach(async () => { + client = await createTestClient(schema, { + usePrismaPush: true, }); - - afterEach(async () => { - await client?.$disconnect(); + }); + + afterEach(async () => { + await client?.$disconnect(); + }); + + it('works with create', async () => { + await expect( + client.user.create({ + data: {}, + }), + ).resolves.toMatchObject({ + identity: null, }); - it('works with create', async () => { - await expect( - client.user.create({ - data: {}, - }), - ).resolves.toMatchObject({ - identity: null, - }); - - await expect( - client.user.create({ - data: { - identity: { - providers: [ - { - id: '123', - name: 'Google', - }, - ], - }, + await expect( + client.user.create({ + data: { + identity: { + providers: [ + { + id: '123', + name: 'Google', + }, + ], }, - }), - ).resolves.toMatchObject({ - identity: { - providers: [ - { - id: '123', - name: 'Google', - }, - ], }, - }); - - await expect( - client.user.create({ - data: { - identity: { - providers: [ - { - id: '123', - }, - ], - }, + }), + ).resolves.toMatchObject({ + identity: { + providers: [ + { + id: '123', + name: 'Google', + }, + ], + }, + }); + + await expect( + client.user.create({ + data: { + identity: { + providers: [ + { + id: '123', + }, + ], }, - }), - ).resolves.toMatchObject({ - identity: { - providers: [ - { - id: '123', - }, - ], }, - }); - - await expect( - client.user.create({ - data: { - identity: { - providers: [ - { - id: '123', - foo: 1, - }, - ], - }, + }), + ).resolves.toMatchObject({ + identity: { + providers: [ + { + id: '123', + }, + ], + }, + }); + + await expect( + client.user.create({ + data: { + identity: { + providers: [ + { + id: '123', + foo: 1, + }, + ], }, - }), - ).resolves.toMatchObject({ - identity: { - providers: [ - { - id: '123', - foo: 1, - }, - ], }, - }); - - await expect( - client.user.create({ - data: { - identity: { - providers: [ - { - name: 'Google', - }, - ], - }, + }), + ).resolves.toMatchObject({ + identity: { + providers: [ + { + id: '123', + foo: 1, }, - }), - ).rejects.toThrow(/invalid/i); + ], + }, }); - it('works with find', async () => { - await expect( - client.user.create({ - data: { id: 1 }, - }), - ).toResolveTruthy(); - await expect(client.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ - identity: null, - }); - - await expect( - client.user.create({ - data: { - id: 2, - identity: { - providers: [ - { - id: '123', - name: 'Google', - }, - ], - }, + await expect( + client.user.create({ + data: { + identity: { + providers: [ + { + name: 'Google', + }, + ], }, - }), - ).toResolveTruthy(); - - await expect(client.user.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ - identity: { - providers: [ - { - id: '123', - name: 'Google', - }, - ], }, - }); + }), + ).rejects.toThrow(/invalid/i); + }); + + it('works with find', async () => { + await expect( + client.user.create({ + data: { id: 1 }, + }), + ).toResolveTruthy(); + await expect(client.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ + identity: null, }); - it('works with update', async () => { - await expect( - client.user.create({ - data: { id: 1 }, - }), - ).toResolveTruthy(); - - await expect( - client.user.update({ - where: { id: 1 }, - data: { - identity: { - providers: [ - { - id: '123', - name: 'Google', - foo: 1, - }, - ], - }, + await expect( + client.user.create({ + data: { + id: 2, + identity: { + providers: [ + { + id: '123', + name: 'Google', + }, + ], }, - }), - ).resolves.toMatchObject({ - identity: { - providers: [ - { - id: '123', - name: 'Google', - foo: 1, - }, - ], }, - }); - - await expect( - client.user.update({ - where: { id: 1 }, - data: { - identity: { - providers: [ - { - name: 'GitHub', - }, - ], - }, + }), + ).toResolveTruthy(); + + await expect(client.user.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ + identity: { + providers: [ + { + id: '123', + name: 'Google', }, - }), - ).rejects.toThrow(/invalid/i); + ], + }, }); - }, -); + }); + + it('works with update', async () => { + await expect( + client.user.create({ + data: { id: 1 }, + }), + ).toResolveTruthy(); + + await expect( + client.user.update({ + where: { id: 1 }, + data: { + identity: { + providers: [ + { + id: '123', + name: 'Google', + foo: 1, + }, + ], + }, + }, + }), + ).resolves.toMatchObject({ + identity: { + providers: [ + { + id: '123', + name: 'Google', + foo: 1, + }, + ], + }, + }); + + await expect( + client.user.update({ + where: { id: 1 }, + data: { + identity: { + providers: [ + { + name: 'GitHub', + }, + ], + }, + }, + }), + ).rejects.toThrow(/invalid/i); + }); +}); diff --git a/packages/runtime/test/client-api/undefined-values.test.ts b/packages/runtime/test/client-api/undefined-values.test.ts index 07037be7..74d2851b 100644 --- a/packages/runtime/test/client-api/undefined-values.test.ts +++ b/packages/runtime/test/client-api/undefined-values.test.ts @@ -1,16 +1,14 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; import { createUser } from './utils'; -const PG_DB_NAME = 'client-api-undefined-values-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client undefined values tests for $provider', ({ createClient }) => { +describe('Client undefined values tests ', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/client-api/update-many.test.ts b/packages/runtime/test/client-api/update-many.test.ts index eaef7e63..934154fe 100644 --- a/packages/runtime/test/client-api/update-many.test.ts +++ b/packages/runtime/test/client-api/update-many.test.ts @@ -1,15 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-update-many-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client updateMany tests', ({ createClient }) => { +describe('Client updateMany tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/client-api/update.test.ts b/packages/runtime/test/client-api/update.test.ts index a82a87bc..82ec4a5a 100644 --- a/packages/runtime/test/client-api/update.test.ts +++ b/packages/runtime/test/client-api/update.test.ts @@ -1,16 +1,14 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; import { createUser } from './utils'; -const PG_DB_NAME = 'client-api-update-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client update tests', ({ createClient }) => { +describe('Client update tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/client-api/upsert.test.ts b/packages/runtime/test/client-api/upsert.test.ts index 02ede41c..cbb16d65 100644 --- a/packages/runtime/test/client-api/upsert.test.ts +++ b/packages/runtime/test/client-api/upsert.test.ts @@ -1,15 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; -import { createClientSpecs } from './client-specs'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'client-api-upsert-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Client upsert tests', ({ createClient }) => { +describe('Client upsert tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/plugin/entity-mutation-hooks.test.ts b/packages/runtime/test/plugin/entity-mutation-hooks.test.ts index 96961c7c..9172d0f5 100644 --- a/packages/runtime/test/plugin/entity-mutation-hooks.test.ts +++ b/packages/runtime/test/plugin/entity-mutation-hooks.test.ts @@ -4,689 +4,681 @@ import { type ClientContract } from '../../src'; import { schema } from '../schemas/basic'; import { createTestClient } from '../utils'; -const TEST_DB = 'client-api-entity-mutation-hooks-test'; - -describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as const }])( - 'Entity mutation hooks tests for $provider', - ({ provider }) => { - let _client: ClientContract; - - beforeEach(async () => { - _client = await createTestClient(schema, { - provider, - dbName: TEST_DB, - }); +describe('Entity mutation hooks tests', () => { + let _client: ClientContract; + + beforeEach(async () => { + _client = await createTestClient(schema, {}); + }); + + afterEach(async () => { + await _client?.$disconnect(); + }); + + it('can intercept all mutations', async () => { + const beforeCalled = { create: false, update: false, delete: false }; + const afterCalled = { create: false, update: false, delete: false }; + + const client = _client.$use({ + id: 'test', + onEntityMutation: { + beforeEntityMutation(args) { + beforeCalled[args.action] = true; + if (args.action === 'create') { + expect(InsertQueryNode.is(args.queryNode)).toBe(true); + } + if (args.action === 'update') { + expect(UpdateQueryNode.is(args.queryNode)).toBe(true); + } + if (args.action === 'delete') { + expect(DeleteQueryNode.is(args.queryNode)).toBe(true); + } + }, + afterEntityMutation(args) { + afterCalled[args.action] = true; + }, + }, }); - afterEach(async () => { - await _client?.$disconnect(); + const user = await client.user.create({ + data: { email: 'u1@test.com' }, }); + await client.user.update({ + where: { id: user.id }, + data: { email: 'u2@test.com' }, + }); + await client.user.delete({ where: { id: user.id } }); - it('can intercept all mutations', async () => { - const beforeCalled = { create: false, update: false, delete: false }; - const afterCalled = { create: false, update: false, delete: false }; + expect(beforeCalled).toEqual({ + create: true, + update: true, + delete: true, + }); + expect(afterCalled).toEqual({ + create: true, + update: true, + delete: true, + }); + }); + + it('can intercept with loading before mutation entities', async () => { + const queryIds = { + update: { before: '', after: '' }, + delete: { before: '', after: '' }, + }; + + const client = _client.$use({ + id: 'test', + onEntityMutation: { + async beforeEntityMutation(args) { + if (args.action === 'update' || args.action === 'delete') { + await expect(args.loadBeforeMutationEntities()).resolves.toEqual([ + expect.objectContaining({ + email: args.action === 'update' ? 'u1@test.com' : 'u3@test.com', + }), + ]); + queryIds[args.action].before = args.queryId; + } + }, + async afterEntityMutation(args) { + if (args.action === 'update' || args.action === 'delete') { + queryIds[args.action].after = args.queryId; + } + }, + }, + }); - const client = _client.$use({ - id: 'test', - onEntityMutation: { - beforeEntityMutation(args) { - beforeCalled[args.action] = true; + const user = await client.user.create({ + data: { email: 'u1@test.com' }, + }); + await client.user.create({ + data: { email: 'u2@test.com' }, + }); + await client.user.update({ + where: { id: user.id }, + data: { email: 'u3@test.com' }, + }); + await client.user.delete({ where: { id: user.id } }); + + expect(queryIds.update.before).toBeTruthy(); + expect(queryIds.delete.before).toBeTruthy(); + expect(queryIds.update.before).toBe(queryIds.update.after); + expect(queryIds.delete.before).toBe(queryIds.delete.after); + }); + + it('can intercept with loading after mutation entities', async () => { + let userCreateIntercepted = false; + let userUpdateIntercepted = false; + const client = _client.$use({ + id: 'test', + onEntityMutation: { + async afterEntityMutation(args) { + if (args.action === 'create' || args.action === 'update') { if (args.action === 'create') { - expect(InsertQueryNode.is(args.queryNode)).toBe(true); + userCreateIntercepted = true; } if (args.action === 'update') { - expect(UpdateQueryNode.is(args.queryNode)).toBe(true); + userUpdateIntercepted = true; } - if (args.action === 'delete') { - expect(DeleteQueryNode.is(args.queryNode)).toBe(true); + await expect(args.loadAfterMutationEntities()).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ + email: args.action === 'create' ? 'u1@test.com' : 'u2@test.com', + }), + ]), + ); + } + }, + }, + }); + + const user = await client.user.create({ + data: { email: 'u1@test.com' }, + }); + await client.user.update({ + where: { id: user.id }, + data: { email: 'u2@test.com' }, + }); + + expect(userCreateIntercepted).toBe(true); + expect(userUpdateIntercepted).toBe(true); + }); + + it('can intercept multi-entity mutations', async () => { + let userCreateIntercepted = false; + let userUpdateIntercepted = false; + let userDeleteIntercepted = false; + + const client = _client.$use({ + id: 'test', + onEntityMutation: { + async afterEntityMutation(args) { + if (args.action === 'create') { + userCreateIntercepted = true; + await expect(args.loadAfterMutationEntities()).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ email: 'u1@test.com' }), + expect.objectContaining({ email: 'u2@test.com' }), + ]), + ); + } else if (args.action === 'update') { + userUpdateIntercepted = true; + await expect(args.loadAfterMutationEntities()).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ + email: 'u1@test.com', + name: 'A user', + }), + expect.objectContaining({ + email: 'u2@test.com', + name: 'A user', + }), + ]), + ); + } else if (args.action === 'delete') { + userDeleteIntercepted = true; + await expect(args.loadAfterMutationEntities()).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ email: 'u1@test.com' }), + expect.objectContaining({ email: 'u2@test.com' }), + ]), + ); + } + }, + }, + }); + + await client.user.createMany({ + data: [{ email: 'u1@test.com' }, { email: 'u2@test.com' }], + }); + await client.user.updateMany({ + data: { name: 'A user' }, + }); + + expect(userCreateIntercepted).toBe(true); + expect(userUpdateIntercepted).toBe(true); + expect(userDeleteIntercepted).toBe(false); + }); + + it('can intercept nested mutations', async () => { + let post1Intercepted = false; + let post2Intercepted = false; + const client = _client.$use({ + id: 'test', + onEntityMutation: { + async afterEntityMutation(args) { + if (args.action === 'create') { + if (args.model === 'Post') { + const afterEntities = await args.loadAfterMutationEntities(); + if ((afterEntities![0] as any).title === 'Post1') { + post1Intercepted = true; + } + if ((afterEntities![0] as any).title === 'Post2') { + post2Intercepted = true; + } } - }, - afterEntityMutation(args) { - afterCalled[args.action] = true; - }, + } }, - }); + }, + }); - const user = await client.user.create({ + const user = await client.user.create({ + data: { + email: 'u1@test.com', + posts: { create: { title: 'Post1' } }, + }, + }); + await client.user.update({ + where: { id: user.id }, + data: { + email: 'u2@test.com', + posts: { create: { title: 'Post2' } }, + }, + }); + + expect(post1Intercepted).toBe(true); + expect(post2Intercepted).toBe(true); + }); + + it('triggers multiple afterEntityMutation hooks for multiple mutations', async () => { + const triggered: any[] = []; + + const client = _client.$use({ + id: 'test', + onEntityMutation: { + async afterEntityMutation(args) { + triggered.push({ + action: args.action, + model: args.model, + afterMutationEntities: await args.loadAfterMutationEntities(), + }); + }, + }, + }); + + await client.$transaction(async (tx) => { + await tx.user.create({ data: { email: 'u1@test.com' }, }); - await client.user.update({ - where: { id: user.id }, + await tx.user.update({ + where: { email: 'u1@test.com' }, data: { email: 'u2@test.com' }, }); - await client.user.delete({ where: { id: user.id } }); - - expect(beforeCalled).toEqual({ - create: true, - update: true, - delete: true, - }); - expect(afterCalled).toEqual({ - create: true, - update: true, - delete: true, - }); + await tx.user.delete({ where: { email: 'u2@test.com' } }); }); - it('can intercept with loading before mutation entities', async () => { - const queryIds = { - update: { before: '', after: '' }, - delete: { before: '', after: '' }, - }; + expect(triggered).toEqual([ + expect.objectContaining({ + action: 'create', + model: 'User', + afterMutationEntities: [expect.objectContaining({ email: 'u1@test.com' })], + }), + expect.objectContaining({ + action: 'update', + model: 'User', + afterMutationEntities: [expect.objectContaining({ email: 'u2@test.com' })], + }), + expect.objectContaining({ + action: 'delete', + model: 'User', + afterMutationEntities: undefined, + }), + ]); + }); + + describe('Without outer transaction', () => { + it('persists hooks db side effects when run out of tx', async () => { + let intercepted = false; const client = _client.$use({ id: 'test', onEntityMutation: { - async beforeEntityMutation(args) { - if (args.action === 'update' || args.action === 'delete') { - await expect(args.loadBeforeMutationEntities()).resolves.toEqual([ - expect.objectContaining({ - email: args.action === 'update' ? 'u1@test.com' : 'u3@test.com', - }), - ]); - queryIds[args.action].before = args.queryId; - } + async beforeEntityMutation(ctx) { + await ctx.client.profile.create({ + data: { bio: 'Bio1' }, + }); }, - async afterEntityMutation(args) { - if (args.action === 'update' || args.action === 'delete') { - queryIds[args.action].after = args.queryId; - } + async afterEntityMutation(ctx) { + intercepted = true; + await ctx.client.user.update({ + where: { email: 'u1@test.com' }, + data: { email: 'u2@test.com' }, + }); }, }, }); - const user = await client.user.create({ - data: { email: 'u1@test.com' }, - }); await client.user.create({ - data: { email: 'u2@test.com' }, - }); - await client.user.update({ - where: { id: user.id }, - data: { email: 'u3@test.com' }, + data: { email: 'u1@test.com' }, }); - await client.user.delete({ where: { id: user.id } }); - - expect(queryIds.update.before).toBeTruthy(); - expect(queryIds.delete.before).toBeTruthy(); - expect(queryIds.update.before).toBe(queryIds.update.after); - expect(queryIds.delete.before).toBe(queryIds.delete.after); + expect(intercepted).toBe(true); + // both the mutation and hook's side effect are persisted + await expect(client.profile.findMany()).toResolveWithLength(1); + await expect(client.user.findFirst()).resolves.toMatchObject({ email: 'u2@test.com' }); }); - it('can intercept with loading after mutation entities', async () => { - let userCreateIntercepted = false; - let userUpdateIntercepted = false; + it('persists hooks db side effects when run within tx', async () => { + let intercepted = false; + const client = _client.$use({ id: 'test', onEntityMutation: { - async afterEntityMutation(args) { - if (args.action === 'create' || args.action === 'update') { - if (args.action === 'create') { - userCreateIntercepted = true; - } - if (args.action === 'update') { - userUpdateIntercepted = true; - } - await expect(args.loadAfterMutationEntities()).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ - email: args.action === 'create' ? 'u1@test.com' : 'u2@test.com', - }), - ]), - ); - } + runAfterMutationWithinTransaction: true, + async beforeEntityMutation(ctx) { + await ctx.client.profile.create({ + data: { bio: 'Bio1' }, + }); + }, + async afterEntityMutation(ctx) { + intercepted = true; + await ctx.client.user.update({ + where: { email: 'u1@test.com' }, + data: { email: 'u2@test.com' }, + }); }, }, }); - const user = await client.user.create({ + await client.user.create({ data: { email: 'u1@test.com' }, }); - await client.user.update({ - where: { id: user.id }, - data: { email: 'u2@test.com' }, - }); - - expect(userCreateIntercepted).toBe(true); - expect(userUpdateIntercepted).toBe(true); + expect(intercepted).toBe(true); + // both the mutation and hook's side effect are persisted + await expect(client.profile.findMany()).toResolveWithLength(1); + await expect(client.user.findFirst()).resolves.toMatchObject({ email: 'u2@test.com' }); }); - it('can intercept multi-entity mutations', async () => { - let userCreateIntercepted = false; - let userUpdateIntercepted = false; - let userDeleteIntercepted = false; - + it('fails the mutation if before mutation hook throws', async () => { const client = _client.$use({ id: 'test', onEntityMutation: { - async afterEntityMutation(args) { - if (args.action === 'create') { - userCreateIntercepted = true; - await expect(args.loadAfterMutationEntities()).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ email: 'u1@test.com' }), - expect.objectContaining({ email: 'u2@test.com' }), - ]), - ); - } else if (args.action === 'update') { - userUpdateIntercepted = true; - await expect(args.loadAfterMutationEntities()).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ - email: 'u1@test.com', - name: 'A user', - }), - expect.objectContaining({ - email: 'u2@test.com', - name: 'A user', - }), - ]), - ); - } else if (args.action === 'delete') { - userDeleteIntercepted = true; - await expect(args.loadAfterMutationEntities()).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ email: 'u1@test.com' }), - expect.objectContaining({ email: 'u2@test.com' }), - ]), - ); - } + async beforeEntityMutation() { + throw new Error('trigger failure'); }, }, }); - await client.user.createMany({ - data: [{ email: 'u1@test.com' }, { email: 'u2@test.com' }], - }); - await client.user.updateMany({ - data: { name: 'A user' }, - }); + await expect( + client.user.create({ + data: { email: 'u1@test.com' }, + }), + ).rejects.toThrow(); - expect(userCreateIntercepted).toBe(true); - expect(userUpdateIntercepted).toBe(true); - expect(userDeleteIntercepted).toBe(false); + // mutation is persisted + await expect(client.user.findMany()).toResolveWithLength(0); }); - it('can intercept nested mutations', async () => { - let post1Intercepted = false; - let post2Intercepted = false; + it('does not affect the database operation if after mutation hook throws', async () => { + let intercepted = false; + const client = _client.$use({ id: 'test', onEntityMutation: { - async afterEntityMutation(args) { - if (args.action === 'create') { - if (args.model === 'Post') { - const afterEntities = await args.loadAfterMutationEntities(); - if ((afterEntities![0] as any).title === 'Post1') { - post1Intercepted = true; - } - if ((afterEntities![0] as any).title === 'Post2') { - post2Intercepted = true; - } - } - } + async afterEntityMutation() { + intercepted = true; + throw new Error('trigger rollback'); }, }, }); - const user = await client.user.create({ - data: { - email: 'u1@test.com', - posts: { create: { title: 'Post1' } }, - }, - }); - await client.user.update({ - where: { id: user.id }, - data: { - email: 'u2@test.com', - posts: { create: { title: 'Post2' } }, - }, + await client.user.create({ + data: { email: 'u1@test.com' }, }); - expect(post1Intercepted).toBe(true); - expect(post2Intercepted).toBe(true); + expect(intercepted).toBe(true); + // mutation is persisted + await expect(client.user.findMany()).toResolveWithLength(1); }); - it('triggers multiple afterEntityMutation hooks for multiple mutations', async () => { - const triggered: any[] = []; + it('fails the entire transaction if specified to run inside the tx', async () => { + let intercepted = false; const client = _client.$use({ id: 'test', onEntityMutation: { - async afterEntityMutation(args) { - triggered.push({ - action: args.action, - model: args.model, - afterMutationEntities: await args.loadAfterMutationEntities(), - }); + runAfterMutationWithinTransaction: true, + async afterEntityMutation(ctx) { + intercepted = true; + await ctx.client.user.create({ data: { email: 'u2@test.com' } }); + throw new Error('trigger rollback'); }, }, }); - await client.$transaction(async (tx) => { - await tx.user.create({ + await expect( + client.user.create({ data: { email: 'u1@test.com' }, - }); - await tx.user.update({ - where: { email: 'u1@test.com' }, - data: { email: 'u2@test.com' }, - }); - await tx.user.delete({ where: { email: 'u2@test.com' } }); - }); - - expect(triggered).toEqual([ - expect.objectContaining({ - action: 'create', - model: 'User', - afterMutationEntities: [expect.objectContaining({ email: 'u1@test.com' })], }), - expect.objectContaining({ - action: 'update', - model: 'User', - afterMutationEntities: [expect.objectContaining({ email: 'u2@test.com' })], - }), - expect.objectContaining({ - action: 'delete', - model: 'User', - afterMutationEntities: undefined, - }), - ]); - }); - - describe('Without outer transaction', () => { - it('persists hooks db side effects when run out of tx', async () => { - let intercepted = false; + ).rejects.toThrow(); - const client = _client.$use({ - id: 'test', - onEntityMutation: { - async beforeEntityMutation(ctx) { - await ctx.client.profile.create({ - data: { bio: 'Bio1' }, - }); - }, - async afterEntityMutation(ctx) { - intercepted = true; - await ctx.client.user.update({ - where: { email: 'u1@test.com' }, - data: { email: 'u2@test.com' }, - }); - }, - }, - }); - - await client.user.create({ - data: { email: 'u1@test.com' }, - }); - expect(intercepted).toBe(true); - // both the mutation and hook's side effect are persisted - await expect(client.profile.findMany()).toResolveWithLength(1); - await expect(client.user.findFirst()).resolves.toMatchObject({ email: 'u2@test.com' }); - }); + expect(intercepted).toBe(true); + // mutation is not persisted + await expect(client.user.findMany()).toResolveWithLength(0); + }); - it('persists hooks db side effects when run within tx', async () => { - let intercepted = false; + it('does not trigger afterEntityMutation hook if a transaction is rolled back', async () => { + let intercepted = false; - const client = _client.$use({ - id: 'test', - onEntityMutation: { - runAfterMutationWithinTransaction: true, - async beforeEntityMutation(ctx) { - await ctx.client.profile.create({ - data: { bio: 'Bio1' }, - }); - }, - async afterEntityMutation(ctx) { - intercepted = true; - await ctx.client.user.update({ - where: { email: 'u1@test.com' }, - data: { email: 'u2@test.com' }, - }); - }, + const client = _client.$use({ + id: 'test', + onEntityMutation: { + async afterEntityMutation(ctx) { + intercepted = true; + await ctx.client.user.create({ data: { email: 'u2@test.com' } }); }, - }); - - await client.user.create({ - data: { email: 'u1@test.com' }, - }); - expect(intercepted).toBe(true); - // both the mutation and hook's side effect are persisted - await expect(client.profile.findMany()).toResolveWithLength(1); - await expect(client.user.findFirst()).resolves.toMatchObject({ email: 'u2@test.com' }); + }, }); - it('fails the mutation if before mutation hook throws', async () => { - const client = _client.$use({ - id: 'test', - onEntityMutation: { - async beforeEntityMutation() { - throw new Error('trigger failure'); - }, - }, - }); - - await expect( - client.user.create({ + try { + await client.$transaction(async (tx) => { + await tx.user.create({ data: { email: 'u1@test.com' }, - }), - ).rejects.toThrow(); + }); + throw new Error('trigger rollback'); + }); + } catch { + // noop + } - // mutation is persisted - await expect(client.user.findMany()).toResolveWithLength(0); - }); + expect(intercepted).toBe(false); + // neither the mutation nor the hook's side effect are persisted + await expect(client.user.findMany()).toResolveWithLength(0); + }); - it('does not affect the database operation if after mutation hook throws', async () => { - let intercepted = false; + it('triggers afterEntityMutation hook if a transaction is rolled back but hook runs within tx', async () => { + let intercepted = false; - const client = _client.$use({ - id: 'test', - onEntityMutation: { - async afterEntityMutation() { - intercepted = true; - throw new Error('trigger rollback'); - }, + const client = _client.$use({ + id: 'test', + onEntityMutation: { + runAfterMutationWithinTransaction: true, + async afterEntityMutation(ctx) { + intercepted = true; + await ctx.client.user.create({ data: { email: 'u2@test.com' } }); }, - }); + }, + }); - await client.user.create({ - data: { email: 'u1@test.com' }, + try { + await client.$transaction(async (tx) => { + await tx.user.create({ + data: { email: 'u1@test.com' }, + }); + throw new Error('trigger rollback'); }); + } catch { + // noop + } - expect(intercepted).toBe(true); - // mutation is persisted - await expect(client.user.findMany()).toResolveWithLength(1); - }); - - it('fails the entire transaction if specified to run inside the tx', async () => { - let intercepted = false; + expect(intercepted).toBe(true); + // neither the mutation nor the hook's side effect are persisted + await expect(client.user.findMany()).toResolveWithLength(0); + }); + }); - const client = _client.$use({ - id: 'test', - onEntityMutation: { - runAfterMutationWithinTransaction: true, - async afterEntityMutation(ctx) { + describe('With outer transaction', () => { + it('sees changes in the transaction prior to reading before mutation entities', async () => { + let intercepted = false; + const client = _client.$use({ + id: 'test', + onEntityMutation: { + async beforeEntityMutation(args) { + if (args.action === 'update') { intercepted = true; - await ctx.client.user.create({ data: { email: 'u2@test.com' } }); - throw new Error('trigger rollback'); - }, + await expect(args.loadBeforeMutationEntities()).resolves.toEqual([ + expect.objectContaining({ email: 'u1@test.com' }), + ]); + } }, - }); - - await expect( - client.user.create({ - data: { email: 'u1@test.com' }, - }), - ).rejects.toThrow(); - - expect(intercepted).toBe(true); - // mutation is not persisted - await expect(client.user.findMany()).toResolveWithLength(0); + }, }); - it('does not trigger afterEntityMutation hook if a transaction is rolled back', async () => { - let intercepted = false; - - const client = _client.$use({ - id: 'test', - onEntityMutation: { - async afterEntityMutation(ctx) { - intercepted = true; - await ctx.client.user.create({ data: { email: 'u2@test.com' } }); - }, - }, + await client.$transaction(async (tx) => { + await tx.user.create({ data: { email: 'u1@test.com' } }); + await tx.user.update({ + where: { email: 'u1@test.com' }, + data: { email: 'u2@test.com' }, }); - - try { - await client.$transaction(async (tx) => { - await tx.user.create({ - data: { email: 'u1@test.com' }, - }); - throw new Error('trigger rollback'); - }); - } catch { - // noop - } - - expect(intercepted).toBe(false); - // neither the mutation nor the hook's side effect are persisted - await expect(client.user.findMany()).toResolveWithLength(0); }); - it('triggers afterEntityMutation hook if a transaction is rolled back but hook runs within tx', async () => { - let intercepted = false; + expect(intercepted).toBe(true); + }); - const client = _client.$use({ - id: 'test', - onEntityMutation: { - runAfterMutationWithinTransaction: true, - async afterEntityMutation(ctx) { - intercepted = true; - await ctx.client.user.create({ data: { email: 'u2@test.com' } }); - }, + it('runs before mutation hook within the transaction', async () => { + let intercepted = false; + const client = _client.$use({ + id: 'test', + onEntityMutation: { + async beforeEntityMutation(ctx) { + intercepted = true; + await ctx.client.profile.create({ + data: { bio: 'Bio1' }, + }); }, - }); + }, + }); - try { - await client.$transaction(async (tx) => { - await tx.user.create({ - data: { email: 'u1@test.com' }, - }); - throw new Error('trigger rollback'); + await expect( + client.$transaction(async (tx) => { + await tx.user.create({ + data: { email: 'u1@test.com' }, }); - } catch { - // noop - } + throw new Error('trigger rollback'); + }), + ).rejects.toThrow(); - expect(intercepted).toBe(true); - // neither the mutation nor the hook's side effect are persisted - await expect(client.user.findMany()).toResolveWithLength(0); - }); + expect(intercepted).toBe(true); + await expect(client.user.findMany()).toResolveWithLength(0); + await expect(client.profile.findMany()).toResolveWithLength(0); }); - describe('With outer transaction', () => { - it('sees changes in the transaction prior to reading before mutation entities', async () => { - let intercepted = false; - const client = _client.$use({ - id: 'test', - onEntityMutation: { - async beforeEntityMutation(args) { - if (args.action === 'update') { - intercepted = true; - await expect(args.loadBeforeMutationEntities()).resolves.toEqual([ - expect.objectContaining({ email: 'u1@test.com' }), - ]); - } - }, - }, - }); + it('persists hooks db side effects when run out of tx', async () => { + let intercepted = false; + let txVisible = false; - await client.$transaction(async (tx) => { - await tx.user.create({ data: { email: 'u1@test.com' } }); - await tx.user.update({ - where: { email: 'u1@test.com' }, - data: { email: 'u2@test.com' }, - }); - }); - - expect(intercepted).toBe(true); - }); - - it('runs before mutation hook within the transaction', async () => { - let intercepted = false; - const client = _client.$use({ - id: 'test', - onEntityMutation: { - async beforeEntityMutation(ctx) { - intercepted = true; + const client = _client.$use({ + id: 'test', + onEntityMutation: { + async beforeEntityMutation(ctx) { + const r = await ctx.client.user.findUnique({ where: { email: 'u1@test.com' } }); + if (r) { + // second create + txVisible = true; + } else { + // first create await ctx.client.profile.create({ data: { bio: 'Bio1' }, }); - }, + } }, - }); - - await expect( - client.$transaction(async (tx) => { - await tx.user.create({ - data: { email: 'u1@test.com' }, + async afterEntityMutation(ctx) { + if (intercepted) { + return; + } + intercepted = true; + await ctx.client.user.update({ + where: { email: 'u1@test.com' }, + data: { email: 'u3@test.com' }, }); - throw new Error('trigger rollback'); - }), - ).rejects.toThrow(); - - expect(intercepted).toBe(true); - await expect(client.user.findMany()).toResolveWithLength(0); - await expect(client.profile.findMany()).toResolveWithLength(0); + }, + }, }); - it('persists hooks db side effects when run out of tx', async () => { - let intercepted = false; - let txVisible = false; - - const client = _client.$use({ - id: 'test', - onEntityMutation: { - async beforeEntityMutation(ctx) { - const r = await ctx.client.user.findUnique({ where: { email: 'u1@test.com' } }); - if (r) { - // second create - txVisible = true; - } else { - // first create - await ctx.client.profile.create({ - data: { bio: 'Bio1' }, - }); - } - }, - async afterEntityMutation(ctx) { - if (intercepted) { - return; - } - intercepted = true; - await ctx.client.user.update({ - where: { email: 'u1@test.com' }, - data: { email: 'u3@test.com' }, - }); - }, - }, + await client.$transaction(async (tx) => { + await tx.user.create({ + data: { email: 'u1@test.com' }, }); - - await client.$transaction(async (tx) => { - await tx.user.create({ - data: { email: 'u1@test.com' }, - }); - await tx.user.create({ - data: { email: 'u2@test.com' }, - }); + await tx.user.create({ + data: { email: 'u2@test.com' }, }); - - expect(intercepted).toBe(true); - expect(txVisible).toBe(true); - - // both the mutation and hook's side effect are persisted - await expect(client.profile.findMany()).toResolveWithLength(1); - await expect(client.user.findMany()).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ email: 'u2@test.com' }), - expect.objectContaining({ email: 'u3@test.com' }), - ]), - ); }); - it('persists hooks db side effects when run within tx', async () => { - let intercepted = false; + expect(intercepted).toBe(true); + expect(txVisible).toBe(true); - const client = _client.$use({ - id: 'test', - onEntityMutation: { - runAfterMutationWithinTransaction: true, - async afterEntityMutation(ctx) { - if (intercepted) { - return; - } - intercepted = true; - await ctx.client.user.update({ - where: { email: 'u1@test.com' }, - data: { email: 'u3@test.com' }, - }); - }, + // both the mutation and hook's side effect are persisted + await expect(client.profile.findMany()).toResolveWithLength(1); + await expect(client.user.findMany()).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ email: 'u2@test.com' }), + expect.objectContaining({ email: 'u3@test.com' }), + ]), + ); + }); + + it('persists hooks db side effects when run within tx', async () => { + let intercepted = false; + + const client = _client.$use({ + id: 'test', + onEntityMutation: { + runAfterMutationWithinTransaction: true, + async afterEntityMutation(ctx) { + if (intercepted) { + return; + } + intercepted = true; + await ctx.client.user.update({ + where: { email: 'u1@test.com' }, + data: { email: 'u3@test.com' }, + }); }, - }); + }, + }); - await client.$transaction(async (tx) => { - await tx.user.create({ - data: { email: 'u1@test.com' }, - }); - await tx.user.create({ - data: { email: 'u2@test.com' }, - }); + await client.$transaction(async (tx) => { + await tx.user.create({ + data: { email: 'u1@test.com' }, }); + await tx.user.create({ + data: { email: 'u2@test.com' }, + }); + }); - expect(intercepted).toBe(true); + expect(intercepted).toBe(true); - // both the mutation and hook's side effect are persisted - await expect(client.user.findMany()).resolves.toEqual( - expect.arrayContaining([ - expect.objectContaining({ email: 'u2@test.com' }), - expect.objectContaining({ email: 'u3@test.com' }), - ]), - ); - }); + // both the mutation and hook's side effect are persisted + await expect(client.user.findMany()).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ email: 'u2@test.com' }), + expect.objectContaining({ email: 'u3@test.com' }), + ]), + ); + }); - it('persists mutation when run out of tx and throws', async () => { - let intercepted = false; + it('persists mutation when run out of tx and throws', async () => { + let intercepted = false; - const client = _client.$use({ - id: 'test', - onEntityMutation: { - async afterEntityMutation(ctx) { - intercepted = true; - await ctx.client.user.create({ data: { email: 'u2@test.com' } }); - throw new Error('trigger error'); - }, + const client = _client.$use({ + id: 'test', + onEntityMutation: { + async afterEntityMutation(ctx) { + intercepted = true; + await ctx.client.user.create({ data: { email: 'u2@test.com' } }); + throw new Error('trigger error'); }, - }); + }, + }); - await client.$transaction(async (tx) => { - await tx.user.create({ - data: { email: 'u1@test.com' }, - }); + await client.$transaction(async (tx) => { + await tx.user.create({ + data: { email: 'u1@test.com' }, }); + }); - expect(intercepted).toBe(true); + expect(intercepted).toBe(true); - // both the mutation and hook's side effect are persisted - await expect(client.user.findMany()).toResolveWithLength(2); - }); + // both the mutation and hook's side effect are persisted + await expect(client.user.findMany()).toResolveWithLength(2); + }); - it('rolls back mutation when run within tx and throws', async () => { - let intercepted = false; + it('rolls back mutation when run within tx and throws', async () => { + let intercepted = false; - const client = _client.$use({ - id: 'test', - onEntityMutation: { - runAfterMutationWithinTransaction: true, - async afterEntityMutation(ctx) { - intercepted = true; - await ctx.client.user.create({ data: { email: 'u2@test.com' } }); - throw new Error('trigger error'); - }, + const client = _client.$use({ + id: 'test', + onEntityMutation: { + runAfterMutationWithinTransaction: true, + async afterEntityMutation(ctx) { + intercepted = true; + await ctx.client.user.create({ data: { email: 'u2@test.com' } }); + throw new Error('trigger error'); }, - }); + }, + }); - await expect( - client.$transaction(async (tx) => { - await tx.user.create({ - data: { email: 'u1@test.com' }, - }); - }), - ).rejects.toThrow(); + await expect( + client.$transaction(async (tx) => { + await tx.user.create({ + data: { email: 'u1@test.com' }, + }); + }), + ).rejects.toThrow(); - expect(intercepted).toBe(true); + expect(intercepted).toBe(true); - // both the mutation and hook's side effect are rolled back - await expect(client.user.findMany()).toResolveWithLength(0); - }); + // both the mutation and hook's side effect are rolled back + await expect(client.user.findMany()).toResolveWithLength(0); }); - }, -); + }); +}); diff --git a/packages/runtime/test/plugin/on-kysely-query.test.ts b/packages/runtime/test/plugin/on-kysely-query.test.ts index 75105927..4fe5855d 100644 --- a/packages/runtime/test/plugin/on-kysely-query.test.ts +++ b/packages/runtime/test/plugin/on-kysely-query.test.ts @@ -1,17 +1,18 @@ -import SQLite from 'better-sqlite3'; -import { InsertQueryNode, Kysely, PrimitiveValueListNode, SqliteDialect, ValuesNode, type QueryResult } from 'kysely'; -import { beforeEach, describe, expect, it } from 'vitest'; -import { ZenStackClient, type ClientContract } from '../../src/client'; +import { InsertQueryNode, Kysely, PrimitiveValueListNode, ValuesNode, type QueryResult } from 'kysely'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import { type ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; +import { createTestClient } from '../utils'; describe('On kysely query tests', () => { let _client: ClientContract; beforeEach(async () => { - _client = new ZenStackClient(schema, { - dialect: new SqliteDialect({ database: new SQLite(':memory:') }), - }); - await _client.$pushSchema(); + _client = await createTestClient(schema); + }); + + afterEach(async () => { + await _client.$disconnect(); }); it('intercepts queries', async () => { diff --git a/packages/runtime/test/plugin/on-query-hooks.test.ts b/packages/runtime/test/plugin/on-query-hooks.test.ts index 3e4c478b..3a6df8ca 100644 --- a/packages/runtime/test/plugin/on-query-hooks.test.ts +++ b/packages/runtime/test/plugin/on-query-hooks.test.ts @@ -1,17 +1,13 @@ -import SQLite from 'better-sqlite3'; -import { SqliteDialect } from 'kysely'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; -import { definePlugin, ZenStackClient, type ClientContract } from '../../src/client'; +import { definePlugin, type ClientContract } from '../../src/client'; import { schema } from '../schemas/basic'; +import { createTestClient } from '../utils'; describe('On query hooks tests', () => { let _client: ClientContract; beforeEach(async () => { - _client = new ZenStackClient(schema, { - dialect: new SqliteDialect({ database: new SQLite(':memory:') }), - }); - await _client.$pushSchema(); + _client = await createTestClient(schema); }); afterEach(async () => { diff --git a/packages/runtime/test/policy/read.test.ts b/packages/runtime/test/policy/basic-schema-read.test.ts similarity index 79% rename from packages/runtime/test/policy/read.test.ts rename to packages/runtime/test/policy/basic-schema-read.test.ts index eb1ccb41..c8b8c87e 100644 --- a/packages/runtime/test/policy/read.test.ts +++ b/packages/runtime/test/policy/basic-schema-read.test.ts @@ -1,16 +1,14 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { type ClientContract } from '../../src/client'; import { PolicyPlugin } from '../../src/plugins/policy/plugin'; -import { createClientSpecs } from '../client-api/client-specs'; import { schema } from '../schemas/basic'; +import { createTestClient } from '../utils'; -const PG_DB_NAME = 'policy-read-tests'; - -describe.each(createClientSpecs(PG_DB_NAME))('Read policy tests', ({ createClient }) => { +describe('Read policy tests', () => { let client: ClientContract; beforeEach(async () => { - client = await createClient(); + client = await createTestClient(schema); }); afterEach(async () => { @@ -74,12 +72,6 @@ describe.each(createClientSpecs(PG_DB_NAME))('Read policy tests', ({ createClien await expect(anonClient.$qb.selectFrom('User').selectAll().executeTakeFirst()).toResolveFalsy(); const authClient = anonClient.$setAuth({ id: user.id }); - const foundUser = await authClient.$qb.selectFrom('User').selectAll().executeTakeFirstOrThrow(); - - if (typeof foundUser.createdAt === 'string') { - expect(Date.parse(foundUser.createdAt)).toEqual(user.createdAt.getTime()); - } else { - expect(foundUser.createdAt).toEqual(user.createdAt); - } + await expect(authClient.$qb.selectFrom('User').selectAll().executeTakeFirstOrThrow()).toResolveTruthy(); }); }); diff --git a/packages/runtime/test/policy/crud/create.test.ts b/packages/runtime/test/policy/crud/create.test.ts index dbd7a414..d5eb0657 100644 --- a/packages/runtime/test/policy/crud/create.test.ts +++ b/packages/runtime/test/policy/crud/create.test.ts @@ -273,4 +273,98 @@ model Profile { }, }); }); + + it('works with unnamed many-to-many relation', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + groups Group[] + private Boolean + @@allow('create,read', true) + @@allow('update', !private) +} + +model Group { + id Int @id + private Boolean + users User[] + @@allow('create,read', true) + @@allow('update', !private) +} + `, + { usePrismaPush: true }, + ); + + await expect( + db.user.create({ + data: { id: 1, private: false, groups: { create: [{ id: 1, private: false }] } }, + }), + ).toResolveTruthy(); + + await expect( + db.user.create({ + data: { id: 2, private: true, groups: { create: [{ id: 2, private: false }] } }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { id: 2, private: false, groups: { create: [{ id: 2, private: true }] } }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { id: 2, private: true, groups: { create: [{ id: 2, private: true }] } }, + }), + ).toBeRejectedByPolicy(); + }); + + it('works with named many-to-many relation', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + groups Group[] @relation("UserGroups") + private Boolean + @@allow('create,read', true) + @@allow('update', !private) +} + +model Group { + id Int @id + private Boolean + users User[] @relation("UserGroups") + @@allow('create,read', true) + @@allow('update', !private) +} + `, + { usePrismaPush: true }, + ); + + await expect( + db.user.create({ + data: { id: 1, private: false, groups: { create: [{ id: 1, private: false }] } }, + }), + ).toResolveTruthy(); + + await expect( + db.user.create({ + data: { id: 2, private: true, groups: { create: [{ id: 2, private: false }] } }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { id: 2, private: false, groups: { create: [{ id: 2, private: true }] } }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { id: 2, private: true, groups: { create: [{ id: 2, private: true }] } }, + }), + ).toBeRejectedByPolicy(); + }); }); diff --git a/packages/runtime/test/policy/crud/delete.test.ts b/packages/runtime/test/policy/crud/delete.test.ts new file mode 100644 index 00000000..f515f0dc --- /dev/null +++ b/packages/runtime/test/policy/crud/delete.test.ts @@ -0,0 +1,51 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Delete policy tests', () => { + it('works with top-level delete/deleteMany', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create,read', true) + @@allow('delete', x > 0) +} +`, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.delete({ where: { id: 1 } })).toBeRejectedNotFound(); + + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.delete({ where: { id: 2 } })).toResolveTruthy(); + await expect(db.foo.count()).resolves.toBe(1); + + await db.foo.create({ data: { id: 3, x: 1 } }); + await expect(db.foo.deleteMany()).resolves.toMatchObject({ count: 1 }); + await expect(db.foo.count()).resolves.toBe(1); + }); + + it('works with query builder delete', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create,read', true) + @@allow('delete', x > 0) +} +`, + ); + await db.foo.create({ data: { id: 1, x: 0 } }); + await db.foo.create({ data: { id: 2, x: 1 } }); + + await expect(db.$qb.deleteFrom('Foo').where('id', '=', 1).executeTakeFirst()).resolves.toMatchObject({ + numDeletedRows: 0n, + }); + await expect(db.foo.count()).resolves.toBe(2); + + await expect(db.$qb.deleteFrom('Foo').executeTakeFirst()).resolves.toMatchObject({ numDeletedRows: 1n }); + await expect(db.foo.count()).resolves.toBe(1); + }); +}); diff --git a/packages/runtime/test/policy/crud/read.test.ts b/packages/runtime/test/policy/crud/read.test.ts new file mode 100644 index 00000000..46f4e38b --- /dev/null +++ b/packages/runtime/test/policy/crud/read.test.ts @@ -0,0 +1,695 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Read policy tests', () => { + describe('Find tests', () => { + it('works with top-level find', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('read', x > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.findUnique({ where: { id: 1 } })).toResolveNull(); + + await db.$unuseAll().foo.update({ where: { id: 1 }, data: { x: 1 } }); + await expect(db.foo.findUnique({ where: { id: 1 } })).toResolveTruthy(); + }); + + it('works with mutation read-back', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create,update', true) + @@allow('read', x > 0) +} +`, + ); + + await expect(db.foo.create({ data: { id: 1, x: 0 } })).toBeRejectedByPolicy(); + await expect(db.$unuseAll().foo.count()).resolves.toBe(1); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); + }); + + it('works with to-one relation optional owner-side read', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bar Bar? @relation(fields: [barId], references: [id]) + barId Int? @unique + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + foo Foo? + @@allow('create,update', true) + @@allow('read', y > 0) +} +`, + ); + + await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null }); + await db.bar.update({ where: { id: 1 }, data: { y: 1 } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ + id: 1, + bar: { id: 1 }, + }); + }); + + // TODO: check if we should be consistent with v2 and filter out the parent entity + // if a non-optional child relation is included but not readable + it('works with to-one relation non-optional owner-side read', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bar Bar @relation(fields: [barId], references: [id]) + barId Int @unique + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + foo Foo? + @@allow('create,update', true) + @@allow('read', y > 0) +} +`, + ); + + await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null }); + await db.bar.update({ where: { id: 1 }, data: { y: 1 } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ + id: 1, + bar: { id: 1 }, + }); + }); + + it('works with to-one relation non-owner-side read', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bar Bar? + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + foo Foo @relation(fields: [fooId], references: [id]) + fooId Int @unique + @@allow('create,update', true) + @@allow('read', y > 0) +} +`, + ); + + await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null }); + await db.bar.update({ where: { id: 1 }, data: { y: 1 } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ + id: 1, + bar: { id: 1 }, + }); + }); + + it('works with to-many relation read', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bars Bar[] + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + foo Foo? @relation(fields: [fooId], references: [id]) + fooId Int? + @@allow('create,update', true) + @@allow('read', y > 0) +} +`, + ); + + await db.foo.create({ + data: { + id: 1, + bars: { + create: [ + { id: 1, y: 0 }, + { id: 2, y: 1 }, + ], + }, + }, + }); + await expect(db.foo.findFirst({ include: { bars: true } })).resolves.toMatchObject({ + id: 1, + bars: [{ id: 2 }], + }); + }); + + it('works with unnamed many-to-many relation read', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + groups Group[] + @@allow('all', true) +} + +model Group { + id Int @id + private Boolean + users User[] + @@allow('read', !private) +} +`, + { usePrismaPush: true }, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + groups: { + create: [ + { id: 1, private: true }, + { id: 2, private: false }, + ], + }, + }, + }); + await expect(db.user.findFirst({ include: { groups: true } })).resolves.toMatchObject({ + groups: [{ id: 2 }], + }); + await expect( + db.user.findFirst({ where: { id: 1 }, select: { _count: { select: { groups: true } } } }), + ).resolves.toMatchObject({ + _count: { groups: 1 }, + }); + }); + + it('works with named many-to-many relation read', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + groups Group[] @relation("UserGroups") + @@allow('all', true) +} + +model Group { + id Int @id + private Boolean + users User[] @relation("UserGroups") + @@allow('read', !private) +} +`, + { usePrismaPush: true }, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + groups: { + create: [ + { id: 1, private: true }, + { id: 2, private: false }, + ], + }, + }, + }); + await expect(db.user.findFirst({ include: { groups: true } })).resolves.toMatchObject({ + groups: [{ id: 2 }], + }); + await expect( + db.user.findFirst({ where: { id: 1 }, select: { _count: { select: { groups: true } } } }), + ).resolves.toMatchObject({ + _count: { groups: 1 }, + }); + }); + + it('works with filtered by to-one relation field', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bar Bar? @relation(fields: [barId], references: [id]) + barId Int? @unique + @@allow('create', true) + @@allow('read', bar.y > 0) +} + +model Bar { + id Int @id + y Int + foo Foo? + @@allow('all', true) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } }); + await expect(db.foo.findMany()).resolves.toHaveLength(0); + await db.bar.update({ where: { id: 1 }, data: { y: 1 } }); + await expect(db.foo.findMany()).resolves.toHaveLength(1); + }); + + it('works with filtered by to-one relation non-null', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bar Bar? @relation(fields: [barId], references: [id]) + barId Int? @unique + @@allow('create,update', true) + @@allow('read', bar != null) + @@allow('read', this.bar != null) +} + +model Bar { + id Int @id + y Int + foo Foo? + @@allow('all', true) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1 } }); + await expect(db.foo.findMany()).resolves.toHaveLength(0); + await db.foo.update({ where: { id: 1 }, data: { bar: { create: { id: 1, y: 0 } } } }); + await expect(db.foo.findMany()).resolves.toHaveLength(1); + }); + + it('works with filtered by to-many relation', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bars Bar[] + @@allow('create,update', true) + @@allow('read', bars?[y > 0]) + @@allow('read', this.bars?[y > 0]) +} + +model Bar { + id Int @id + y Int + foo Foo? @relation(fields: [fooId], references: [id]) + fooId Int? + @@allow('all', true) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, bars: { create: [{ id: 1, y: 0 }] } } }); + await expect(db.foo.findMany()).resolves.toHaveLength(0); + await db.foo.update({ where: { id: 1 }, data: { bars: { create: { id: 2, y: 1 } } } }); + await expect(db.foo.findMany()).resolves.toHaveLength(1); + }); + + it('works with counting relations', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bars Bar[] + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + foo Foo? @relation(fields: [fooId], references: [id]) + fooId Int? + @@allow('create,update', true) + @@allow('read', y > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ + data: { + id: 1, + bars: { + create: [ + { id: 1, y: 0 }, + { id: 2, y: 1 }, + ], + }, + }, + }); + await expect( + db.foo.findFirst({ where: { id: 1 }, select: { _count: { select: { bars: true } } } }), + ).resolves.toMatchObject({ _count: { bars: 1 } }); + }); + }); + + describe('Count tests', () => { + it('works with top-level count', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + name String + @@allow('create', true) + @@allow('read', x > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 0, name: 'Foo1' } }); + await db.$unuseAll().foo.create({ data: { id: 2, x: 0, name: 'Foo2' } }); + await expect(db.foo.count()).resolves.toBe(0); + await expect(db.foo.count({ select: { _all: true, name: true } })).resolves.toEqual({ _all: 0, name: 0 }); + + await db.$unuseAll().foo.update({ where: { id: 1 }, data: { x: 1 } }); + await expect(db.foo.count()).resolves.toBe(1); + await expect(db.foo.count({ select: { _all: true, name: true } })).resolves.toEqual({ _all: 1, name: 1 }); + }); + }); + + describe('Aggregate tests', () => { + it('respects read policies', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('read', x > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 0 } }); + await db.$unuseAll().foo.create({ data: { id: 2, x: 1 } }); + await db.$unuseAll().foo.create({ data: { id: 3, x: 3 } }); + + await expect( + db.foo.aggregate({ + _count: true, + _sum: { x: true }, + _avg: { x: true }, + _min: { x: true }, + _max: { x: true }, + }), + ).resolves.toEqual({ + _count: 2, + _sum: { x: 4 }, + _avg: { x: 2 }, + _min: { x: 1 }, + _max: { x: 3 }, + }); + }); + }); + + describe('GroupBy tests', () => { + it('respects read policies', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + y Int + @@allow('create', true) + @@allow('read', x > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 0, y: 1 } }); + await db.$unuseAll().foo.create({ data: { id: 2, x: 1, y: 1 } }); + await db.$unuseAll().foo.create({ data: { id: 3, x: 3, y: 2 } }); + await db.$unuseAll().foo.create({ data: { id: 4, x: 5, y: 2 } }); + + await expect( + db.foo.groupBy({ + by: ['y'], + _count: { _all: true }, + _sum: { x: true }, + _avg: { x: true }, + _min: { x: true }, + _max: { x: true }, + orderBy: { y: 'asc' }, + }), + ).resolves.toEqual([ + { + y: 1, + _count: { _all: 1 }, + _sum: { x: 1 }, + _avg: { x: 1 }, + _min: { x: 1 }, + _max: { x: 1 }, + }, + { + y: 2, + _count: { _all: 2 }, + _sum: { x: 8 }, + _avg: { x: 4 }, + _min: { x: 3 }, + _max: { x: 5 }, + }, + ]); + }); + }); + + describe('Query builder tests', () => { + it('works with simple selects', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('read', x > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 0 } }); + await db.$unuseAll().foo.create({ data: { id: 2, x: 1 } }); + + await expect(db.$qb.selectFrom('Foo').selectAll().execute()).resolves.toHaveLength(1); + await expect(db.$qb.selectFrom('Foo as f').selectAll().execute()).resolves.toHaveLength(1); + await expect(db.$qb.selectFrom('Foo').selectAll().execute()).resolves.toHaveLength(1); + await expect(db.$qb.selectFrom('Foo').where('id', '=', 1).selectAll().execute()).resolves.toHaveLength(0); + + // nested query + await expect( + db.$qb + .selectFrom((eb: any) => eb.selectFrom('Foo').selectAll().as('f')) + .selectAll() + .execute(), + ).resolves.toHaveLength(1); + await expect( + db.$qb + .selectFrom((eb: any) => eb.selectFrom('Foo').selectAll().as('f')) + .selectAll() + .where('f.id', '=', 1) + .execute(), + ).resolves.toHaveLength(0); + }); + + it('works with joins', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + bars Bar[] + @@allow('create', true) + @@allow('read', x > 0) +} + +model Bar { + id Int @id + y Int + foo Foo? @relation(fields: [fooId], references: [id]) + fooId Int? + @@allow('create', true) + @@allow('read', y > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ + data: { + id: 1, + x: 1, + bars: { + create: [ + { id: 1, y: 0 }, + { id: 2, y: 1 }, + ], + }, + }, + }); + await db.$unuseAll().foo.create({ + data: { + id: 2, + x: 0, + bars: { + create: { id: 3, y: 1 }, + }, + }, + }); + + // direct join + await expect( + db.$qb.selectFrom('Foo').innerJoin('Bar', 'Bar.fooId', 'Foo.id').select(['Foo.id', 'x', 'y']).execute(), + ).resolves.toEqual([expect.objectContaining({ id: 1, x: 1, y: 1 })]); + + // through alias + await expect( + db.$qb + .selectFrom('Foo as f') + .innerJoin( + (eb: any) => eb.selectFrom('Bar').selectAll().as('b'), + (join: any) => join.onRef('b.fooId', '=', 'f.id'), + ) + .select(['f.id', 'x', 'y']) + .execute(), + ).resolves.toEqual([expect.objectContaining({ id: 1, x: 1, y: 1 })]); + }); + + it('works with implicit cross join', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('read', x > 0) +} + +model Bar { + id Int @id + y Int + @@allow('create', true) + @@allow('read', y > 0) +} +`, + { provider: 'postgresql', dbName: 'policy-test-implicit-cross-join' }, + ); + + await db.$unuseAll().foo.createMany({ + data: [ + { id: 1, x: 1 }, + { id: 2, x: 0 }, + ], + }); + await db.$unuseAll().bar.createMany({ + data: [ + { id: 1, y: 1 }, + { id: 2, y: 0 }, + ], + }); + + await expect( + db.$qb.selectFrom(['Foo', 'Bar']).select(['Foo.id as fooId', 'Bar.id as barId', 'x', 'y']).execute(), + ).resolves.toEqual([ + { + fooId: 1, + barId: 1, + x: 1, + y: 1, + }, + ]); + }); + + it('works with update from', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + @@allow('read', y > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 1 } }); + await db.$unuseAll().bar.create({ data: { id: 1, y: 0 } }); + + // update with from, only one row is visible + await expect( + db.$qb + .updateTable('Foo') + .from('Bar as bar') + .whereRef('Foo.id', '=', 'bar.id') + .set((eb: any) => ({ x: eb.ref('bar.y') })) + .executeTakeFirst(), + ).resolves.toMatchObject({ numUpdatedRows: 0n }); + await expect(db.foo.findFirst()).resolves.toMatchObject({ x: 1 }); + + await db.$unuseAll().bar.update({ where: { id: 1 }, data: { y: 2 } }); + await expect( + db.$qb + .updateTable('Foo') + .from('Bar as bar') + .whereRef('Foo.id', '=', 'bar.id') + .set((eb: any) => ({ x: eb.ref('bar.y') })) + .executeTakeFirst(), + ).resolves.toMatchObject({ numUpdatedRows: 1n }); + await expect(db.foo.findFirst()).resolves.toMatchObject({ x: 2 }); + }); + + it('works with delete using', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + @@allow('read', y > 0) +} +`, + { provider: 'postgresql', dbName: 'policy-test-delete-using' }, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 1 } }); + await db.$unuseAll().bar.create({ data: { id: 1, y: 0 } }); + + await expect( + db.$qb.deleteFrom('Foo').using('Bar as bar').whereRef('Foo.id', '=', 'bar.id').executeTakeFirst(), + ).resolves.toMatchObject({ numDeletedRows: 0n }); + await expect(db.foo.findFirst()).resolves.toBeTruthy(); + + await db.$unuseAll().bar.update({ where: { id: 1 }, data: { y: 2 } }); + await expect( + db.$qb.deleteFrom('Foo').using('Bar as bar').whereRef('Foo.id', '=', 'bar.id').executeTakeFirst(), + ).resolves.toMatchObject({ numDeletedRows: 1n }); + await expect(db.foo.findFirst()).resolves.toBeNull(); + }); + }); +}); diff --git a/packages/runtime/test/policy/crud/update.test.ts b/packages/runtime/test/policy/crud/update.test.ts index e0082a49..c092682b 100644 --- a/packages/runtime/test/policy/crud/update.test.ts +++ b/packages/runtime/test/policy/crud/update.test.ts @@ -156,7 +156,7 @@ model Profile { }); }); - it('works with to-one relation check owner side', async () => { + it('works with to-one relation check non-owner side', async () => { const db = await createPolicyTestClient( ` model User { @@ -338,10 +338,222 @@ model Post { }); await expect(db.user.update({ where: { id: 3 }, data: { name: 'UpdatedUser3' } })).toResolveTruthy(); }); + + it('works with unnamed many-to-many relation check', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + groups Group[] + @@allow('create,read', true) + @@allow('update', groups?[!private]) +} + +model Group { + id Int @id + private Boolean + members User[] + @@allow('all', true) +} +`, + { usePrismaPush: true }, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + name: 'User1', + groups: { + create: [ + { id: 1, private: true }, + { id: 2, private: false }, + ], + }, + }, + }); + + await expect(db.user.update({ where: { id: 1 }, data: { name: 'User2' } })).toResolveTruthy(); + + await db.$unuseAll().group.update({ where: { id: 2 }, data: { private: true } }); + // not satisfying update policy anymore + await expect(db.user.update({ where: { id: 1 }, data: { name: 'User3' } })).toBeRejectedNotFound(); + }); + + it('works with named many-to-many relation check', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + name String + groups Group[] @relation("UserGroups") + @@allow('create,read', true) + @@allow('update', groups?[!private]) +} + +model Group { + id Int @id + private Boolean + members User[] @relation("UserGroups") + @@allow('all', true) +} +`, + { usePrismaPush: true }, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + name: 'User1', + groups: { + create: [ + { id: 1, private: true }, + { id: 2, private: false }, + ], + }, + }, + }); + + await expect(db.user.update({ where: { id: 1 }, data: { name: 'User2' } })).toResolveTruthy(); + + await db.$unuseAll().group.update({ where: { id: 2 }, data: { private: true } }); + // not satisfying update policy anymore + await expect(db.user.update({ where: { id: 1 }, data: { name: 'User3' } })).toBeRejectedNotFound(); + }); + }); + + describe('Nested create tests', () => { + it('works with nested create non-owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + @@allow('create', user.id == auth().id) + @@allow('read', true) +} + `, + ); + + await db.user.create({ data: { id: 1 } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { create: { id: 1 } } } }), + ).toBeRejectedByPolicy(); + await expect( + db.$setAuth({ id: 1 }).user.update({ + where: { id: 1 }, + data: { profile: { create: { id: 1 } } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: { + id: 1, + }, + }); + }); + + it('works with nested create owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? @relation(fields: [profileId], references: [id]) + profileId Int? @unique + @@allow('create,read', true) + @@allow('update', auth() == this) +} + +model Profile { + id Int @id + user User? + @@allow('all', true) +} +`, + ); + + await db.user.create({ data: { id: 1 } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { create: { id: 1 } } } }), + ).toBeRejectedNotFound(); + await expect( + db.$setAuth({ id: 1 }).user.update({ + where: { id: 1 }, + data: { profile: { create: { id: 1 } } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ + profile: { + id: 1, + }, + }); + }); + + it('works with nested create many', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + posts Post[] + @@allow('all', true) +} + +model Post { + id Int @id + title String + user User @relation(fields: [userId], references: [id]) + userId Int + @@allow('read', true) + @@allow('create', auth() == this.user) +} +`, + ); + + await db.user.create({ data: { id: 1 } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + posts: { + createMany: { + data: [ + { id: 1, title: 'Post1' }, + { id: 2, title: 'Post2' }, + ], + }, + }, + }, + }), + ).toBeRejectedByPolicy(); + await expect( + db.$setAuth({ id: 1 }).user.update({ + where: { id: 1 }, + data: { + posts: { + createMany: { + data: [ + { id: 1, title: 'Post1' }, + { id: 2, title: 'Post2' }, + ], + }, + }, + }, + include: { posts: true }, + }), + ).resolves.toMatchObject({ + posts: [{ id: 1 }, { id: 2 }], + }); + }); }); describe('Nested update tests', () => { - it('works with nested update owner side', async () => { + it('works with nested update non-owner side', async () => { const db = await createPolicyTestClient( ` model User { @@ -384,7 +596,7 @@ model Profile { }); }); - it('works with nested update non-owner side', async () => { + it('works with nested update owner side', async () => { const db = await createPolicyTestClient( ` model User { @@ -426,6 +638,188 @@ model Profile { }, }); }); + + it('works with nested update many', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + posts Post[] + @@allow('all', true) +} + +model Post { + id Int @id + title String + private Boolean + user User @relation(fields: [userId], references: [id]) + userId Int + @@allow('create,read', true) + @@allow('update', !private) +} +`, + ); + + await db.user.create({ + data: { + id: 1, + posts: { + create: [ + { id: 1, title: 'Post 1', private: true }, + { id: 2, title: 'Post 2', private: false }, + ], + }, + }, + }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + posts: { + updateMany: { + where: { title: { contains: 'Post' } }, + data: { title: 'Updated Title' }, + }, + }, + }, + include: { posts: true }, + }), + ).resolves.toMatchObject({ + posts: [{ title: 'Post 1' }, { title: 'Updated Title' }], + }); + }); + + it('works with nested upsert', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + posts Post[] + @@allow('all', true) +} + +model Post { + id Int @id + title String + user User @relation(fields: [userId], references: [id]) + userId Int + @@allow('read', true) + @@allow('create', contains(title, 'Foo')) + @@allow('update', contains(title, 'Bar')) +} +`, + ); + + await db.user.create({ data: { id: 1 } }); + // can't create + await expect( + db.user.update({ + where: { id: 1 }, + data: { + posts: { + upsert: { + where: { id: 1 }, + create: { id: 1, title: 'Post1' }, + update: { title: 'Post1' }, + }, + }, + }, + }), + ).toBeRejectedByPolicy(); + // can create + await expect( + db.user.update({ + where: { id: 1 }, + data: { + posts: { + upsert: { + where: { id: 1 }, + create: { id: 1, title: 'Foo Post' }, + update: { title: 'Post1' }, + }, + }, + }, + include: { posts: true }, + }), + ).resolves.toMatchObject({ + posts: [{ id: 1, title: 'Foo Post' }], + }); + // can't update + await expect( + db.user.update({ + where: { id: 1 }, + data: { + posts: { + upsert: { + where: { id: 1 }, + create: { id: 1, title: 'Foo Post' }, + update: { title: 'Post1' }, + }, + }, + }, + }), + ).rejects.toThrow('constraint'); + await db.$unuseAll().post.update({ where: { id: 1 }, data: { title: 'Bar Post' } }); + // can update + await expect( + db.user.update({ + where: { id: 1 }, + data: { + posts: { + upsert: { + where: { id: 1 }, + create: { id: 1, title: 'Foo Post' }, + update: { title: 'Bar Updated' }, + }, + }, + }, + include: { posts: true }, + }), + ).resolves.toMatchObject({ + posts: [{ id: 1, title: 'Bar Updated' }], + }); + }); + }); + + describe('Nested delete tests', () => { + it('works with nested delete non-owner side', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? + @@allow('all', true) +} + +model Profile { + id Int @id + private Boolean + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + @@allow('create,read', true) + @@allow('delete', !private) +} +`, + ); + + await db.user.create({ data: { id: 1, profile: { create: { id: 1, private: true } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profile: { delete: true } }, + }), + ).toBeRejectedNotFound(); + + await db.user.create({ data: { id: 2, profile: { create: { id: 2, private: false } } } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { profile: { delete: true } }, + include: { profile: true }, + }), + ).resolves.toMatchObject({ profile: null }); + await expect(db.profile.findUnique({ where: { id: 2 } })).resolves.toBeNull(); + }); }); describe('Relation manipulation tests', () => { @@ -576,9 +970,294 @@ model Profile { }), ).toResolveTruthy(); }); + + it('works with many-to-many relation manipulation', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + private Boolean + groups Group[] @relation("UserGroups") + @@allow('create,read', true) + @@allow('update,delete', !private) +} + +model Group { + id Int @id + private Boolean + members User[] @relation("UserGroups") + @@allow('create,read', true) + @@allow('update,delete', !private) +} +`, + { usePrismaPush: true }, + ); + + await db.$unuseAll().user.create({ data: { id: 1, private: true } }); + await db.$unuseAll().user.create({ data: { id: 2, private: false } }); + + // user not updatable + await expect( + db.user.update({ where: { id: 1 }, data: { groups: { create: { id: 1, private: false } } } }), + ).toBeRejectedByPolicy(); + + // group not updatable + await expect( + db.user.update({ where: { id: 2 }, data: { groups: { create: { id: 1, private: true } } } }), + ).toBeRejectedByPolicy(); + + // both updatable + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { create: { id: 1, private: false } } }, + include: { groups: true }, + }), + ).toResolveTruthy(); + + // disconnect + await expect( + db.user.update({ where: { id: 2 }, data: { groups: { disconnect: { id: 1 } } } }), + ).toResolveTruthy(); + + // set + await expect( + db.user.update({ where: { id: 2 }, data: { groups: { set: [{ id: 1 }] } } }), + ).toResolveTruthy(); + + // delete + await expect( + db.user.update({ where: { id: 2 }, data: { groups: { delete: { id: 1 } } } }), + ).toResolveTruthy(); + + // recreate group as private + await db.$unuseAll().group.create({ data: { id: 2, private: true } }); + + // connect rejected + await expect( + db.user.update({ where: { id: 2 }, data: { groups: { connect: { id: 2 } } } }), + ).toBeRejectedByPolicy(); + + // disconnect rejected + await db.$unuseAll().user.update({ where: { id: 2 }, data: { groups: { connect: { id: 2 } } } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { disconnect: { id: 2 } } }, + include: { groups: true }, + }), + ).resolves.toMatchObject({ + groups: [{ id: 2 }], // verify not disconnected + }); + + // delete rejected + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { delete: { id: 2 } } }, + include: { groups: true }, + }), + ).toBeRejectedNotFound(); + + // set rejected + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { set: [] } }, + include: { groups: true }, + }), + ).resolves.toMatchObject({ + groups: [{ id: 2 }], // verify not disconnected + }); + + await db.$unuseAll().group.update({ where: { id: 2 }, data: { private: false } }); + await db.$unuseAll().group.create({ data: { id: 3, private: true } }); + + // set rejected + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { set: [{ id: 3 }] } }, + include: { groups: true }, + }), + ).toBeRejectedByPolicy(); + + // relation unchanged + await expect(db.user.findUnique({ where: { id: 2 }, include: { groups: true } })).resolves.toMatchObject({ + groups: [{ id: 2 }], + }); + + // set success + await db.$unuseAll().group.update({ where: { id: 3 }, data: { private: false } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { groups: { set: [{ id: 3 }] } }, + include: { groups: true }, + }), + ).resolves.toMatchObject({ + groups: [{ id: 3 }], + }); + }); + }); + + describe('Upsert tests', () => { + it('works with upsert', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', x > 0) + @@allow('update', x > 1) + @@allow('read', true) +} +`, + ); + // can't create + await expect( + db.foo.upsert({ where: { id: 1 }, create: { id: 1, x: 0 }, update: { x: 2 } }), + ).toBeRejectedByPolicy(); + await expect( + db.foo.upsert({ where: { id: 1 }, create: { id: 1, x: 1 }, update: { x: 2 } }), + ).resolves.toMatchObject({ x: 1 }); + // can't update, but create violates unique constraint + await expect( + db.foo.upsert({ where: { id: 1 }, create: { id: 1, x: 1 }, update: { x: 1 } }), + ).rejects.toThrow('constraint'); + await db.$unuseAll().foo.update({ where: { id: 1 }, data: { x: 2 } }); + // can update now + await expect( + db.foo.upsert({ where: { id: 1 }, create: { id: 1, x: 1 }, update: { x: 3 } }), + ).resolves.toMatchObject({ x: 3 }); + }); + }); + + describe('Update many tests', () => { + it('works with update many', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('update', x > 1) + @@allow('read', true) +} +`, + ); + + await db.foo.createMany({ + data: [ + { id: 1, x: 1 }, + { id: 2, x: 2 }, + { id: 3, x: 3 }, + ], + }); + await expect(db.foo.updateMany({ data: { x: 5 } })).resolves.toMatchObject({ count: 2 }); + await expect(db.foo.findMany()).resolves.toEqual( + expect.arrayContaining([ + { id: 1, x: 1 }, + { id: 2, x: 5 }, + { id: 3, x: 5 }, + ]), + ); + }); }); - // describe('Upsert tests', () => {}); + describe('Query builder tests', () => { + it('works with simple update', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('update', x > 1) + @@allow('read', true) +} +`, + ); + + await db.foo.createMany({ + data: [ + { id: 1, x: 1 }, + { id: 2, x: 2 }, + { id: 3, x: 3 }, + ], + }); - // describe('Update many tests', () => {}); + // not updatable + await expect( + db.$qb.updateTable('Foo').set({ x: 5 }).where('id', '=', 1).executeTakeFirst(), + ).resolves.toMatchObject({ numUpdatedRows: 0n }); + + // with where + await expect( + db.$qb.updateTable('Foo').set({ x: 5 }).where('id', '=', 2).executeTakeFirst(), + ).resolves.toMatchObject({ numUpdatedRows: 1n }); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 5 }); + + // without where + await expect(db.$qb.updateTable('Foo').set({ x: 6 }).executeTakeFirst()).resolves.toMatchObject({ + numUpdatedRows: 2n, + }); + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + }); + + it('works with insert on conflict do update', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('update', x > 1) + @@allow('read', true) +} +`, + ); + + await db.foo.createMany({ + data: [ + { id: 1, x: 1 }, + { id: 2, x: 2 }, + { id: 3, x: 3 }, + ], + }); + + // #1 not updatable + await expect( + db.$qb + .insertInto('Foo') + .values({ id: 1, x: 5 }) + .onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 })) + .executeTakeFirst(), + ).resolves.toMatchObject({ numInsertedOrUpdatedRows: 0n }); + await expect(db.foo.count()).resolves.toBe(3); + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + + // with where, #1 not updatable + await expect( + db.$qb + .insertInto('Foo') + .values({ id: 1, x: 5 }) + .onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 5 }).where('Foo.id', '=', 1)) + .executeTakeFirst(), + ).resolves.toMatchObject({ numInsertedOrUpdatedRows: 0n }); + await expect(db.foo.count()).resolves.toBe(3); + await expect(db.foo.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ x: 1 }); + + // with where, #2 updatable + await expect( + db.$qb + .insertInto('Foo') + .values({ id: 2, x: 5 }) + .onConflict((oc: any) => oc.column('id').doUpdateSet({ x: 6 }).where('Foo.id', '=', 2)) + .executeTakeFirst(), + ).resolves.toMatchObject({ numInsertedOrUpdatedRows: 1n }); + await expect(db.foo.count()).resolves.toBe(3); + await expect(db.foo.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ x: 6 }); + }); + }); }); diff --git a/packages/runtime/test/policy/auth.test.ts b/packages/runtime/test/policy/migrated/auth.test.ts similarity index 99% rename from packages/runtime/test/policy/auth.test.ts rename to packages/runtime/test/policy/migrated/auth.test.ts index f00c79e7..d075e3d7 100644 --- a/packages/runtime/test/policy/auth.test.ts +++ b/packages/runtime/test/policy/migrated/auth.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from 'vitest'; -import { createPolicyTestClient } from './utils'; +import { createPolicyTestClient } from '../utils'; describe('auth() tests', () => { it('works with string id non-null test', async () => { @@ -536,7 +536,7 @@ model Post { ); await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); - await expect(db.post.create({ data: { title: 'title' } })).rejects.toThrow('constraint failed'); + await expect(db.post.create({ data: { title: 'title' } })).rejects.toThrow('constraint'); await expect(db.post.findMany({})).toResolveTruthy(); }); diff --git a/packages/runtime/test/policy/client-extensions.test.ts b/packages/runtime/test/policy/migrated/client-extensions.test.ts similarity index 97% rename from packages/runtime/test/policy/client-extensions.test.ts rename to packages/runtime/test/policy/migrated/client-extensions.test.ts index 1f725172..16692543 100644 --- a/packages/runtime/test/policy/client-extensions.test.ts +++ b/packages/runtime/test/policy/migrated/client-extensions.test.ts @@ -1,6 +1,6 @@ import { describe, expect, it } from 'vitest'; -import { definePlugin } from '../../src/client'; -import { createPolicyTestClient } from './utils'; +import { definePlugin } from '../../../src/client'; +import { createPolicyTestClient } from '../utils'; describe('client extensions tests for policies', () => { it('query override one model', async () => { diff --git a/packages/runtime/test/policy/connect-disconnect.test.ts b/packages/runtime/test/policy/migrated/connect-disconnect.test.ts similarity index 99% rename from packages/runtime/test/policy/connect-disconnect.test.ts rename to packages/runtime/test/policy/migrated/connect-disconnect.test.ts index d6e30128..02d8e04e 100644 --- a/packages/runtime/test/policy/connect-disconnect.test.ts +++ b/packages/runtime/test/policy/migrated/connect-disconnect.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from 'vitest'; -import { createPolicyTestClient } from './utils'; +import { createPolicyTestClient } from '../utils'; describe('connect and disconnect tests', () => { const modelToMany = ` diff --git a/packages/runtime/test/policy/create-many-and-return.test.ts b/packages/runtime/test/policy/migrated/create-many-and-return.test.ts similarity index 98% rename from packages/runtime/test/policy/create-many-and-return.test.ts rename to packages/runtime/test/policy/migrated/create-many-and-return.test.ts index 97829ce4..1df0e5b6 100644 --- a/packages/runtime/test/policy/create-many-and-return.test.ts +++ b/packages/runtime/test/policy/migrated/create-many-and-return.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from 'vitest'; -import { createPolicyTestClient } from './utils'; +import { createPolicyTestClient } from '../utils'; describe('createManyAndReturn tests', () => { it('works with model-level policies', async () => { diff --git a/packages/runtime/test/policy/cross-model-field-comparison.test.ts b/packages/runtime/test/policy/migrated/cross-model-field-comparison.test.ts similarity index 99% rename from packages/runtime/test/policy/cross-model-field-comparison.test.ts rename to packages/runtime/test/policy/migrated/cross-model-field-comparison.test.ts index 146992a2..f0a35f79 100644 --- a/packages/runtime/test/policy/cross-model-field-comparison.test.ts +++ b/packages/runtime/test/policy/migrated/cross-model-field-comparison.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from 'vitest'; -import { createPolicyTestClient } from './utils'; +import { createPolicyTestClient } from '../utils'; describe('cross-model field comparison tests', () => { it('works with to-one relation', async () => { diff --git a/packages/runtime/test/policy/current-model.test.ts b/packages/runtime/test/policy/migrated/current-model.test.ts similarity index 99% rename from packages/runtime/test/policy/current-model.test.ts rename to packages/runtime/test/policy/migrated/current-model.test.ts index 024e658f..61ea1d24 100644 --- a/packages/runtime/test/policy/current-model.test.ts +++ b/packages/runtime/test/policy/migrated/current-model.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect } from 'vitest'; -import { createPolicyTestClient } from './utils'; +import { createPolicyTestClient } from '../utils'; describe('currentModel tests', () => { it('works in models', async () => { diff --git a/packages/runtime/test/policy/current-operation.test.ts b/packages/runtime/test/policy/migrated/current-operation.test.ts similarity index 98% rename from packages/runtime/test/policy/current-operation.test.ts rename to packages/runtime/test/policy/migrated/current-operation.test.ts index 957f8779..42d67939 100644 --- a/packages/runtime/test/policy/current-operation.test.ts +++ b/packages/runtime/test/policy/migrated/current-operation.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect } from 'vitest'; -import { createPolicyTestClient } from './utils'; +import { createPolicyTestClient } from '../utils'; describe('currentOperation tests', () => { it('works with specific rules', async () => { diff --git a/packages/runtime/test/policy/deep-nested.test.ts b/packages/runtime/test/policy/migrated/deep-nested.test.ts similarity index 99% rename from packages/runtime/test/policy/deep-nested.test.ts rename to packages/runtime/test/policy/migrated/deep-nested.test.ts index a35e34b8..a88134ce 100644 --- a/packages/runtime/test/policy/deep-nested.test.ts +++ b/packages/runtime/test/policy/migrated/deep-nested.test.ts @@ -1,5 +1,5 @@ import { beforeEach, describe, expect, it } from 'vitest'; -import { createPolicyTestClient } from './utils'; +import { createPolicyTestClient } from '../utils'; describe('deep nested operations tests', () => { const model = ` @@ -482,7 +482,7 @@ describe('deep nested operations tests', () => { }, }, }), - ).rejects.toThrow('constraint failed'); + ).rejects.toThrow('constraint'); // createMany skip duplicate await db.m1.update({ diff --git a/packages/runtime/test/policy/empty-policy.test.ts b/packages/runtime/test/policy/migrated/empty-policy.test.ts similarity index 98% rename from packages/runtime/test/policy/empty-policy.test.ts rename to packages/runtime/test/policy/migrated/empty-policy.test.ts index 432454c1..452845b3 100644 --- a/packages/runtime/test/policy/empty-policy.test.ts +++ b/packages/runtime/test/policy/migrated/empty-policy.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from 'vitest'; -import { createPolicyTestClient } from './utils'; +import { createPolicyTestClient } from '../utils'; describe('empty policy tests', () => { it('works with simple operations', async () => { diff --git a/packages/runtime/test/policy/field-comparison.test.ts b/packages/runtime/test/policy/migrated/field-comparison.test.ts similarity index 98% rename from packages/runtime/test/policy/field-comparison.test.ts rename to packages/runtime/test/policy/migrated/field-comparison.test.ts index 1d8e3cdf..1bf33c37 100644 --- a/packages/runtime/test/policy/field-comparison.test.ts +++ b/packages/runtime/test/policy/migrated/field-comparison.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from 'vitest'; -import { createPolicyTestClient } from './utils'; +import { createPolicyTestClient } from '../utils'; describe('field comparison tests', () => { it('works with policies involving field comparison', async () => { diff --git a/packages/runtime/test/policy/multi-field-unique.test.ts b/packages/runtime/test/policy/migrated/multi-field-unique.test.ts similarity index 97% rename from packages/runtime/test/policy/multi-field-unique.test.ts rename to packages/runtime/test/policy/migrated/multi-field-unique.test.ts index 029bdaeb..7edbe019 100644 --- a/packages/runtime/test/policy/multi-field-unique.test.ts +++ b/packages/runtime/test/policy/migrated/multi-field-unique.test.ts @@ -1,9 +1,9 @@ import path from 'path'; import { afterEach, beforeAll, describe, expect, it } from 'vitest'; -import { createPolicyTestClient } from './utils'; -import { QueryError } from '../../src'; +import { createPolicyTestClient } from '../utils'; +import { QueryError } from '../../../src'; -describe('With Policy: multi-field unique', () => { +describe('Policy tests multi-field unique', () => { let origDir: string; beforeAll(async () => { diff --git a/packages/runtime/test/policy/migrated/multi-id-fields.test.ts b/packages/runtime/test/policy/migrated/multi-id-fields.test.ts new file mode 100644 index 00000000..56941f03 --- /dev/null +++ b/packages/runtime/test/policy/migrated/multi-id-fields.test.ts @@ -0,0 +1,395 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy tests multiple id fields', () => { + it('multi-id fields crud', async () => { + const db = await createPolicyTestClient( + ` + model A { + x String + y Int + value Int + b B? + @@id([x, y]) + + @@allow('read', true) + @@allow('create', value > 0) + } + + model B { + b1 String + b2 String + value Int + a A @relation(fields: [ax, ay], references: [x, y]) + ax String + ay Int + + @@allow('read', value > 2) + @@allow('create', value > 1) + + @@unique([ax, ay]) + @@id([b1, b2]) + } + `, + ); + + await expect(db.a.create({ data: { x: '1', y: 1, value: 0 } })).toBeRejectedByPolicy(); + await expect(db.a.create({ data: { x: '1', y: 2, value: 1 } })).toResolveTruthy(); + + await expect( + db.a.create({ data: { x: '2', y: 1, value: 1, b: { create: { b1: '1', b2: '2', value: 1 } } } }), + ).toBeRejectedByPolicy(); + + const r = await db.a.create({ + include: { b: true }, + data: { x: '2', y: 1, value: 1, b: { create: { b1: '1', b2: '2', value: 2 } } }, + }); + expect(r.b).toBeNull(); + + const r1 = await db.$unuseAll().b.findUnique({ where: { b1_b2: { b1: '1', b2: '2' } } }); + expect(r1.value).toBe(2); + + await expect( + db.a.create({ + include: { b: true }, + data: { x: '3', y: 1, value: 1, b: { create: { b1: '2', b2: '2', value: 3 } } }, + }), + ).toResolveTruthy(); + }); + + // TODO: `future()` support + it.skip('multi-id fields id update', async () => { + const db = await createPolicyTestClient( + ` + model A { + x String + y Int + value Int + b B? + @@id([x, y]) + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 0 && future().value > 1) + } + + model B { + b1 String + b2 String + value Int + a A @relation(fields: [ax, ay], references: [x, y]) + ax String + ay Int + + @@allow('read', value > 2) + @@allow('create', value > 1) + + @@unique([ax, ay]) + @@id([b1, b2]) + } + `, + ); + + await db.a.create({ data: { x: '1', y: 2, value: 1 } }); + + await expect( + db.a.update({ where: { x_y: { x: '1', y: 2 } }, data: { x: '2', y: 3, value: 0 } }), + ).toBeRejectedByPolicy(); + + await expect( + db.a.update({ where: { x_y: { x: '1', y: 2 } }, data: { x: '2', y: 3, value: 2 } }), + ).resolves.toMatchObject({ + x: '2', + y: 3, + value: 2, + }); + + await expect( + db.a.upsert({ + where: { x_y: { x: '2', y: 3 } }, + update: { x: '3', y: 4, value: 0 }, + create: { x: '4', y: 5, value: 5 }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.a.upsert({ + where: { x_y: { x: '2', y: 3 } }, + update: { x: '3', y: 4, value: 3 }, + create: { x: '4', y: 5, value: 5 }, + }), + ).resolves.toMatchObject({ + x: '3', + y: 4, + value: 3, + }); + }); + + it('multi-id auth', async () => { + const db = await createPolicyTestClient( + ` + model User { + x String + y String + m M? + n N? + p P? + q Q? + @@id([x, y]) + @@allow('all', true) + } + + model M { + id String @id @default(cuid()) + owner User @relation(fields: [ownerX, ownerY], references: [x, y]) + ownerX String + ownerY String + @@unique([ownerX, ownerY]) + @@allow('all', auth() == owner) + } + + model N { + id String @id @default(cuid()) + owner User @relation(fields: [ownerX, ownerY], references: [x, y]) + ownerX String + ownerY String + @@unique([ownerX, ownerY]) + @@allow('all', auth().x == owner.x && auth().y == owner.y) + } + + model P { + id String @id @default(cuid()) + owner User @relation(fields: [ownerX, ownerY], references: [x, y]) + ownerX String + ownerY String + @@unique([ownerX, ownerY]) + @@allow('all', auth() != owner) + } + + model Q { + id String @id @default(cuid()) + owner User @relation(fields: [ownerX, ownerY], references: [x, y]) + ownerX String + ownerY String + @@unique([ownerX, ownerY]) + @@allow('all', auth() != null) + } + `, + ); + + await db.$unuseAll().user.create({ data: { x: '1', y: '1' } }); + await db.$unuseAll().user.create({ data: { x: '1', y: '2' } }); + + await expect(db.m.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } })).toBeRejectedByPolicy(); + await expect(db.m.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } })).toBeRejectedByPolicy(); + await expect(db.n.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } })).toBeRejectedByPolicy(); + await expect(db.n.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } })).toBeRejectedByPolicy(); + + const dbAuth = db.$setAuth({ x: '1', y: '1' }); + + await expect( + dbAuth.m.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } }), + ).toBeRejectedByPolicy(); + await expect(dbAuth.m.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } })).toResolveTruthy(); + await expect( + dbAuth.n.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } }), + ).toBeRejectedByPolicy(); + await expect(dbAuth.n.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } })).toResolveTruthy(); + await expect( + dbAuth.p.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } }), + ).toBeRejectedByPolicy(); + await expect(dbAuth.p.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } })).toResolveTruthy(); + + await expect(db.q.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } })).toBeRejectedByPolicy(); + await expect(dbAuth.q.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } })).toResolveTruthy(); + }); + + it('multi-id to-one nested write', async () => { + const db = await createPolicyTestClient( + ` + model A { + x Int + y Int + v Int + b B @relation(fields: [bId], references: [id]) + bId Int @unique + + @@id([x, y]) + @@allow('all', v > 0) + } + + model B { + id Int @id + v Int + a A? + + @@allow('all', v > 0) + } + `, + ); + await expect( + db.b.create({ + data: { + id: 1, + v: 1, + a: { + create: { + x: 1, + y: 2, + v: 3, + }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.a.update({ + where: { x_y: { x: 1, y: 2 } }, + data: { b: { update: { v: 5 } } }, + }), + ).toResolveTruthy(); + + expect(await db.b.findUnique({ where: { id: 1 } })).toEqual(expect.objectContaining({ v: 5 })); + }); + + it('multi-id to-many nested write', async () => { + const db = await createPolicyTestClient( + ` + model A { + x Int + y Int + v Int + b B @relation(fields: [bId], references: [id]) + bId Int @unique + + @@id([x, y]) + @@allow('all', v > 0) + } + + model B { + id Int @id + v Int + a A[] + c C? + + @@allow('all', v > 0) + } + + model C { + id Int @id + v Int + b B @relation(fields: [bId], references: [id]) + bId Int @unique + + @@allow('all', v > 0) + } + `, + ); + await expect( + db.b.create({ + data: { + id: 1, + v: 1, + a: { + create: { + x: 1, + y: 2, + v: 2, + }, + }, + c: { + create: { + id: 1, + v: 3, + }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.a.update({ + where: { x_y: { x: 1, y: 2 } }, + data: { b: { update: { v: 5, c: { update: { v: 6 } } } } }, + }), + ).toResolveTruthy(); + + expect(await db.b.findUnique({ where: { id: 1 } })).toEqual(expect.objectContaining({ v: 5 })); + expect(await db.c.findUnique({ where: { id: 1 } })).toEqual(expect.objectContaining({ v: 6 })); + }); + + // TODO: `future()` support + it.skip('multi-id fields nested id update', async () => { + const db = await createPolicyTestClient( + ` + model A { + x String + y Int + value Int + b B @relation(fields: [bId], references: [id]) + bId Int + @@id([x, y]) + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 0 && future().value > 1) + } + + model B { + id Int @id @default(autoincrement()) + a A[] + @@allow('all', true) + } + `, + ); + + await db.b.create({ data: { id: 1, a: { create: { x: '1', y: 1, value: 1 } } } }); + + await expect( + db.b.update({ + where: { id: 1 }, + data: { a: { update: { where: { x_y: { x: '1', y: 1 } }, data: { x: '2', y: 2, value: 0 } } } }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.b.update({ + where: { id: 1 }, + data: { a: { update: { where: { x_y: { x: '1', y: 1 } }, data: { x: '2', y: 2, value: 2 } } } }, + include: { a: true }, + }), + ).resolves.toMatchObject({ a: expect.arrayContaining([expect.objectContaining({ x: '2', y: 2, value: 2 })]) }); + + await expect( + db.b.update({ + where: { id: 1 }, + data: { + a: { + upsert: { + where: { x_y: { x: '2', y: 2 } }, + update: { x: '3', y: 3, value: 0 }, + create: { x: '4', y: '4', value: 4 }, + }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.b.update({ + where: { id: 1 }, + data: { + a: { + upsert: { + where: { x_y: { x: '2', y: 2 } }, + update: { x: '3', y: 3, value: 3 }, + create: { x: '4', y: '4', value: 4 }, + }, + }, + }, + include: { a: true }, + }), + ).resolves.toMatchObject({ a: expect.arrayContaining([expect.objectContaining({ x: '3', y: 3, value: 3 })]) }); + }); +}); diff --git a/packages/runtime/test/policy/migrated/nested-to-many.test.ts b/packages/runtime/test/policy/migrated/nested-to-many.test.ts new file mode 100644 index 00000000..03415119 --- /dev/null +++ b/packages/runtime/test/policy/migrated/nested-to-many.test.ts @@ -0,0 +1,720 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy tests to-many', () => { + it('read filtering', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('create', true) + @@allow('read', value > 0) + } + `, + ); + + let read = await db.m1.create({ + include: { m2: true }, + data: { + id: '1', + m2: { + create: [{ value: 0 }], + }, + }, + }); + expect(read.m2).toHaveLength(0); + read = await db.m1.findFirst({ where: { id: '1' }, include: { m2: true } }); + expect(read.m2).toHaveLength(0); + + await db.m1.create({ + data: { + id: '2', + m2: { + create: [{ value: 0 }, { value: 1 }, { value: 2 }], + }, + }, + }); + read = await db.m1.findFirst({ where: { id: '2' }, include: { m2: true } }); + expect(read.m2).toHaveLength(2); + }); + + // TODO: do we need to keep the v2 semantic? + it.skip('read condition hoisting', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + m3 M3 @relation(fields: [m3Id], references:[id]) + m3Id String @unique + + m4 M4 @relation(fields: [m4Id], references:[id]) + m4Id String + + @@allow('create', true) + @@allow('read', value > 0) + } + + model M3 { + id String @id @default(uuid()) + value Int + m2 M2? + + @@allow('create', true) + @@allow('read', value > 1) + } + + model M4 { + id String @id @default(uuid()) + value Int + m2 M2[] + + @@allow('create', true) + @@allow('read', value > 1) + } + `, + ); + + await db.m1.create({ + include: { m2: true }, + data: { + id: '1', + m2: { + create: [ + { id: 'm2-1', value: 1, m3: { create: { value: 1 } }, m4: { create: { value: 1 } } }, + { id: 'm2-2', value: 1, m3: { create: { value: 2 } }, m4: { create: { value: 2 } } }, + ], + }, + }, + }); + + let read = await db.m1.findFirst({ include: { m2: true } }); + expect(read.m2).toHaveLength(2); + read = await db.m1.findFirst({ select: { m2: { select: { id: true } } } }); + expect(read.m2).toHaveLength(2); + + // check m2-m3 filtering + // including m3 causes m2 to be filtered since m3 is not nullable + read = await db.m1.findFirst({ include: { m2: { include: { m3: true } } } }); + expect(read.m2).toHaveLength(1); + read = await db.m1.findFirst({ select: { m2: { select: { m3: true } } } }); + expect(read.m2).toHaveLength(1); + + // check m2-m4 filtering + // including m3 causes m2 to be filtered since m4 is not nullable + read = await db.m1.findFirst({ include: { m2: { include: { m4: true } } } }); + expect(read.m2).toHaveLength(1); + read = await db.m1.findFirst({ select: { m2: { select: { m4: true } } } }); + expect(read.m2).toHaveLength(1); + }); + + it('create simple', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('read', true) + @@allow('create', value > 0) + } + `, + ); + + // single create denied + await expect( + db.m1.create({ + data: { + m2: { + create: { value: 0 }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.m1.create({ + data: { + m2: { + create: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + // multi create denied + await expect( + db.m1.create({ + data: { + m2: { + create: [{ value: 0 }, { value: 1 }], + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.m1.create({ + data: { + m2: { + create: [{ value: 1 }, { value: 2 }], + }, + }, + }), + ).toResolveTruthy(); + }); + + it('update simple', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('read', true) + @@allow('create', true) + @@allow('update', value > 1) + } + `, + ); + + await db.m1.create({ + data: { + id: '1', + m2: { + create: [{ id: '1', value: 1 }], + }, + }, + }); + + // update denied + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + update: { + where: { id: '1' }, + data: { value: 2 }, + }, + }, + }, + }), + ).toBeRejectedNotFound(); + + await db.m1.create({ + data: { + id: '2', + m2: { + create: { id: '2', value: 2 }, + }, + }, + }); + + // update success + const r = await db.m1.update({ + where: { id: '2' }, + include: { m2: true }, + data: { + m2: { + update: { + where: { id: '2' }, + data: { value: 3 }, + }, + }, + }, + }); + expect(r.m2).toEqual(expect.arrayContaining([expect.objectContaining({ id: '2', value: 3 })])); + }); + + it('update id field', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('read', true) + @@allow('create', true) + @@allow('update', value > 1) + } + `, + ); + + await db.m1.create({ + data: { + id: '1', + m2: { + create: { id: '1', value: 2 }, + }, + }, + }); + + let r = await db.m1.update({ + where: { id: '1' }, + include: { m2: true }, + data: { + m2: { + update: { + where: { id: '1' }, + data: { id: '2', value: 3 }, + }, + }, + }, + }); + expect(r.m2).toEqual(expect.arrayContaining([expect.objectContaining({ id: '2', value: 3 })])); + + r = await db.m1.update({ + where: { id: '1' }, + include: { m2: true }, + data: { + m2: { + upsert: { + where: { id: '2' }, + create: { id: '4', value: 4 }, + update: { id: '3', value: 4 }, + }, + }, + }, + }); + expect(r.m2).toEqual(expect.arrayContaining([expect.objectContaining({ id: '3', value: 4 })])); + }); + + it('update with create from one to many', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 1) + } + `, + ); + + await db.m1.create({ + data: { + id: '1', + m2: { + create: { value: 1 }, + }, + }, + }); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + create: [{ value: 5 }, { value: 0 }], + }, + }, + }), + ).toBeRejectedByPolicy(); + + const r = await db.m1.update({ + where: { id: '1' }, + include: { m2: true }, + data: { + m2: { + create: [{ value: 1 }, { value: 2 }], + }, + }, + }); + expect(r.m2).toHaveLength(3); + }); + + it('update with create from many to one', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + value Int + m2 M2[] + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 1) + } + + model M2 { + id String @id @default(uuid()) + m1 M1? @relation(fields: [m1Id], references:[id]) + m1Id String? + + @@allow('all', true) + } + `, + ); + + await db.m2.create({ data: { id: '1' } }); + + await expect( + db.m2.update({ + where: { id: '1' }, + data: { + m1: { + create: { value: 0 }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.m2.update({ + where: { id: '1' }, + data: { + m1: { + create: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('update with delete', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 1) + @@allow('delete', value > 2) + } + `, + ); + + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { id: '1', value: 1 }, + { id: '2', value: 2 }, + { id: '3', value: 3 }, + { id: '4', value: 4 }, + { id: '5', value: 5 }, + ], + }, + }, + }); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + delete: { id: '1' }, + }, + }, + }), + ).toBeRejectedNotFound(); + expect(await db.$unuseAll().m2.findMany()).toHaveLength(5); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + delete: [{ id: '1' }, { id: '2' }], + }, + }, + }), + ).toBeRejectedNotFound(); + expect(await db.$unuseAll().m2.findMany()).toHaveLength(5); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + deleteMany: { OR: [{ id: '2' }, { id: '3' }] }, + }, + }, + }), + ).toResolveTruthy(); + // only m2#3 should be deleted, m2#2 should remain because of policy + await expect(db.m2.findUnique({ where: { id: '3' } })).toResolveNull(); + await expect(db.m2.findUnique({ where: { id: '2' } })).toResolveTruthy(); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + delete: { id: '3' }, + }, + }, + }), + ).toBeRejectedNotFound(); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + deleteMany: { value: { gte: 4 } }, + }, + }, + }), + ).toResolveTruthy(); + + await expect(db.m2.findMany({ where: { id: { in: ['4', '5'] } } })).resolves.toHaveLength(0); + }); + + it('create with nested read', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + value Int + m2 M2[] + m3 M3? + + @@allow('read', value > 1) + @@allow('create', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('create', true) + @@allow('read', value > 0) + } + + model M3 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + + @@allow('create', true) + @@allow('read', value > 0) + } + `, + ); + + await expect( + db.m1.create({ + data: { + id: '1', + value: 1, + }, + }), + ).toBeRejectedByPolicy(); + + // included 'm1' can't be read + await expect( + db.m2.create({ + include: { m1: true }, + data: { + id: '1', + value: 1, + m1: { connect: { id: '1' } }, + }, + }), + ).resolves.toMatchObject({ m1: null }); + await expect(db.m2.findUnique({ where: { id: '1' } })).toResolveTruthy(); + + // included 'm1' can't be read + await expect( + db.m3.create({ + include: { m1: true }, + data: { + id: '1', + value: 1, + m1: { connect: { id: '1' } }, + }, + }), + ).resolves.toMatchObject({ m1: null }); + await expect(db.m3.findUnique({ where: { id: '1' } })).toResolveTruthy(); + + // nested to-many got filtered on read + const r = await db.m1.create({ + include: { m2: true }, + data: { + value: 2, + m2: { create: [{ value: 0 }, { value: 1 }] }, + }, + }); + expect(r.m2).toHaveLength(1); + + // read-back for to-one relation rejected + const r1 = await db.m1.create({ + include: { m3: true }, + data: { + value: 2, + m3: { create: { value: 0 } }, + }, + }); + expect(r1.m3).toBeNull(); + }); + + it('update with nested read', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + m3 M3? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('read', value > 1) + @@allow('create,update', true) + } + + model M3 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + + @@allow('read', value > 1) + @@allow('create,update', true) + } + `, + ); + + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { id: '1', value: 0 }, + { id: '2', value: 0 }, + ], + }, + m3: { + create: { value: 0 }, + }, + }, + }); + + const r = await db.m1.update({ + where: { id: '1' }, + include: { m3: true }, + data: { + m3: { + update: { + value: 1, + }, + }, + }, + }); + expect(r.m3).toBeNull(); + + const r1 = await db.m1.update({ + where: { id: '1' }, + include: { m3: true, m2: true }, + data: { + m3: { + update: { + value: 2, + }, + }, + }, + }); + // m3 is ok now + expect(r1.m3.value).toBe(2); + // m2 got filtered + expect(r1.m2).toHaveLength(0); + + const r2 = await db.m1.update({ + where: { id: '1' }, + select: { m2: true }, + data: { + m2: { + update: { + where: { id: '1' }, + data: { value: 2 }, + }, + }, + }, + }); + // one of m2 matches policy now + expect(r2.m2).toHaveLength(1); + }); +}); diff --git a/packages/runtime/test/policy/migrated/nested-to-one.test.ts b/packages/runtime/test/policy/migrated/nested-to-one.test.ts new file mode 100644 index 00000000..5838cae8 --- /dev/null +++ b/packages/runtime/test/policy/migrated/nested-to-one.test.ts @@ -0,0 +1,445 @@ +import { describe, it, expect } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('With Policy:nested to-one', () => { + it('read filtering for optional relation', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + value Int + + @@allow('create', true) + @@allow('read', value > 0) + } + `, + ); + + let read = await db.m1.create({ + include: { m2: true }, + data: { + id: '1', + m2: { + create: { id: '1', value: 0 }, + }, + }, + }); + expect(read.m2).toBeNull(); + + await expect(db.m1.findUnique({ where: { id: '1' }, include: { m2: true } })).resolves.toEqual( + expect.objectContaining({ m2: null }), + ); + await expect(db.m1.findMany({ include: { m2: true } })).resolves.toEqual( + expect.arrayContaining([expect.objectContaining({ m2: null })]), + ); + + await db.$unuseAll().m2.update({ where: { id: '1' }, data: { value: 1 } }); + read = await db.m1.findUnique({ where: { id: '1' }, include: { m2: true } }); + expect(read.m2).toEqual(expect.objectContaining({ id: '1', value: 1 })); + }); + + it('read rejection for non-optional relation', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2? + value Int + + @@allow('create', true) + @@allow('read', value > 0) + } + + model M2 { + id String @id @default(uuid()) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + + @@allow('all', true) + } + `, + ); + + await db.$unuseAll().m1.create({ + data: { + id: '1', + value: 0, + m2: { + create: { id: '1' }, + }, + }, + }); + + await expect(db.m2.findUnique({ where: { id: '1' }, include: { m1: true } })).resolves.toMatchObject({ + m1: null, + }); + + await db.$unuseAll().m1.update({ where: { id: '1' }, data: { value: 1 } }); + await expect(db.m2.findMany({ include: { m1: true } })).toResolveTruthy(); + }); + + // TODO: should we keep v2 semantic? + it.skip('read condition hoisting', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String @unique + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + + m1 M1? + + m3 M3 @relation(fields: [m3Id], references:[id]) + m3Id String @unique + + @@allow('create', true) + @@allow('read', value > 0) + } + + model M3 { + id String @id @default(uuid()) + value Int + m2 M2? + + @@allow('create', true) + @@allow('read', value > 1) + } + `, + ); + + await db.m1.create({ + include: { m2: true }, + data: { + id: '1', + m2: { + create: { id: 'm2-1', value: 1, m3: { create: { value: 1 } } }, + }, + }, + }); + + // check m2-m3 filtering + // including m3 causes m1 to be filtered due to hosting + await expect(db.m1.findFirst({ include: { m2: { include: { m3: true } } } })).toResolveNull(); + await expect(db.m1.findFirst({ select: { m2: { select: { m3: true } } } })).toResolveNull(); + }); + + it('create and update tests', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 1) + } + `, + ); + + // create denied + await expect( + db.m1.create({ + data: { + m2: { + create: { value: 0 }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.m1.create({ + data: { + id: '1', + m2: { + create: { id: '1', value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + // nested update denied + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + update: { value: 2 }, + }, + }, + }), + ).toBeRejectedNotFound(); + }); + + // TODO: `future()` support + it.skip('nested update id tests', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 1 && future().value > 2) + } + `, + ); + + await db.m1.create({ + data: { + id: '1', + m2: { + create: { id: '1', value: 2 }, + }, + }, + }); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + update: { id: '2', value: 1 }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + update: { id: '2', value: 3 }, + }, + }, + include: { m2: true }, + }), + ).resolves.toMatchObject({ m2: expect.objectContaining({ id: '2', value: 3 }) }); + }); + + it('nested create', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 1) + } + `, + ); + + await db.m1.create({ + data: { + id: '1', + }, + }); + + // nested create denied + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + create: { value: 0 }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + create: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('nested delete', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + + @@allow('read', true) + @@allow('create', true) + @@allow('update', true) + @@allow('delete', value > 1) + } + `, + ); + + await db.m1.create({ + data: { + id: '1', + m2: { + create: { id: '1', value: 1 }, + }, + }, + }); + + // nested delete denied + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { delete: true }, + }, + }), + ).toBeRejectedNotFound(); + expect(await db.m2.findUnique({ where: { id: '1' } })).toBeTruthy(); + + // update m2 so it can be deleted + await db.m1.update({ + where: { id: '1' }, + data: { + m2: { update: { value: 3 } }, + }, + }); + + expect( + await db.m1.update({ + where: { id: '1' }, + data: { + m2: { delete: true }, + }, + }), + ).toBeTruthy(); + // check deleted + expect(await db.m2.findUnique({ where: { id: '1' } })).toBeNull(); + }); + + it('nested relation delete', async () => { + const db = await createPolicyTestClient( + ` + model User { + id String @id @default(uuid()) + m1 M1? + + @@allow('all', true) + } + + model M1 { + id String @id @default(uuid()) + value Int + user User? @relation(fields: [userId], references: [id]) + userId String? @unique + + @@allow('read,create,update', true) + @@allow('delete', auth().id == 'user1' && value > 0) + } + `, + ); + + await db.$setAuth({ id: 'user1' }).m1.create({ + data: { + id: 'm1', + value: 1, + }, + }); + + await expect( + db.$setAuth({ id: 'user2' }).user.create({ + data: { + id: 'user2', + m1: { + connect: { id: 'm1' }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.$setAuth({ id: 'user2' }).user.update({ + where: { id: 'user2' }, + data: { + m1: { delete: true }, + }, + }), + ).toBeRejectedNotFound(); + + await expect( + db.$setAuth({ id: 'user1' }).user.create({ + data: { + id: 'user1', + m1: { + connect: { id: 'm1' }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.$setAuth({ id: 'user1' }).user.update({ + where: { id: 'user1' }, + data: { + m1: { delete: true }, + }, + }), + ).toResolveTruthy(); + + expect(await db.$unuseAll().m1.findMany()).toHaveLength(0); + }); +}); diff --git a/packages/runtime/test/policy/migrated/omit.test.ts b/packages/runtime/test/policy/migrated/omit.test.ts new file mode 100644 index 00000000..57fdd014 --- /dev/null +++ b/packages/runtime/test/policy/migrated/omit.test.ts @@ -0,0 +1,56 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('prisma omit', () => { + it('per query', async () => { + const db = await createPolicyTestClient( + ` + model User { + id String @id @default(cuid()) + name String + profile Profile? + age Int + value Int @allow('read', age > 20) + @@allow('all', age > 18) + } + + model Profile { + id String @id @default(cuid()) + user User @relation(fields: [userId], references: [id]) + userId String @unique + level Int + @@allow('all', level > 1) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + name: 'John', + age: 25, + value: 10, + profile: { + create: { level: 2 }, + }, + }, + }); + + let found = await db.user.findFirst({ + include: { profile: { omit: { level: true } } }, + omit: { + age: true, + }, + }); + expect(found.age).toBeUndefined(); + expect(found.value).toEqual(10); + expect(found.profile.level).toBeUndefined(); + + found = await db.user.findFirst({ + select: { value: true, profile: { omit: { level: true } } }, + }); + console.log(found); + expect(found.age).toBeUndefined(); + expect(found.value).toEqual(10); + expect(found.profile.level).toBeUndefined(); + }); +}); diff --git a/packages/runtime/test/policy/migrated/petstore-sample.test.ts b/packages/runtime/test/policy/migrated/petstore-sample.test.ts new file mode 100644 index 00000000..99e5e8c7 --- /dev/null +++ b/packages/runtime/test/policy/migrated/petstore-sample.test.ts @@ -0,0 +1,45 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; +import { schema } from '../../schemas/petstore/schema'; + +// TODO: `future()` support +describe.skip('Pet Store Policy Tests', () => { + it('crud', async () => { + const petData = [ + { + id: 'luna', + name: 'Luna', + category: 'kitten', + }, + { + id: 'max', + name: 'Max', + category: 'doggie', + }, + { + id: 'cooper', + name: 'Cooper', + category: 'reptile', + }, + ]; + + const db = await createPolicyTestClient(schema); + + for (const pet of petData) { + await db.$unuseAll().pet.create({ data: pet }); + } + + await db.$unuseAll().user.create({ data: { id: 'user1', email: 'user1@abc.com' } }); + + const r = await db.$setAuth({ id: 'user1' }).order.create({ + include: { user: true, pets: true }, + data: { + user: { connect: { id: 'user1' } }, + pets: { connect: [{ id: 'luna' }, { id: 'max' }] }, + }, + }); + + expect(r.user.id).toBe('user1'); + expect(r.pets).toHaveLength(2); + }); +}); diff --git a/packages/runtime/test/policy/migrated/query-reduction.test.ts b/packages/runtime/test/policy/migrated/query-reduction.test.ts new file mode 100644 index 00000000..61b11d06 --- /dev/null +++ b/packages/runtime/test/policy/migrated/query-reduction.test.ts @@ -0,0 +1,137 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('With Policy: query reduction', () => { + it('test query reduction', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + role String @default("User") + posts Post[] + private Boolean @default(false) + age Int + + @@allow('all', auth() == this) + @@allow('read', !private) + } + + model Post { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + title String + published Boolean @default(false) + viewCount Int @default(0) + + @@allow('all', auth() == user) + @@allow('read', published) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + role: 'User', + age: 18, + posts: { + create: [ + { id: 1, title: 'Post 1' }, + { id: 2, title: 'Post 2', published: true }, + ], + }, + }, + }); + await db.$unuseAll().user.create({ + data: { + id: 2, + role: 'Admin', + age: 28, + private: true, + posts: { + create: [{ id: 3, title: 'Post 3', viewCount: 100 }], + }, + }, + }); + + const dbUser1 = db.$setAuth({ id: 1 }); + const dbUser2 = db.$setAuth({ id: 2 }); + + await expect( + dbUser1.user.findMany({ + where: { id: 2, AND: { age: { gt: 20 } } }, + }), + ).resolves.toHaveLength(0); + + await expect( + dbUser2.user.findMany({ + where: { id: 2, AND: { age: { gt: 20 } } }, + }), + ).resolves.toHaveLength(1); + + await expect( + dbUser1.user.findMany({ + where: { + AND: { age: { gt: 10 } }, + OR: [{ age: { gt: 25 } }, { age: { lt: 20 } }], + NOT: { private: true }, + }, + }), + ).resolves.toHaveLength(1); + + await expect( + dbUser2.user.findMany({ + where: { + AND: { age: { gt: 10 } }, + OR: [{ age: { gt: 25 } }, { age: { lt: 20 } }], + NOT: { private: true }, + }, + }), + ).resolves.toHaveLength(1); + + // to-many relation query + await expect( + dbUser1.user.findMany({ + where: { posts: { some: { published: true } } }, + }), + ).resolves.toHaveLength(1); + await expect( + dbUser1.user.findMany({ + where: { posts: { some: { AND: [{ published: true }, { viewCount: { gt: 0 } }] } } }, + }), + ).resolves.toHaveLength(0); + await expect( + dbUser2.user.findMany({ + where: { posts: { some: { AND: [{ published: false }, { viewCount: { gt: 0 } }] } } }, + }), + ).resolves.toHaveLength(1); + await expect( + dbUser1.user.findMany({ + where: { posts: { every: { published: true } } }, + }), + ).resolves.toHaveLength(0); + await expect( + dbUser1.user.findMany({ + where: { posts: { none: { published: true } } }, + }), + ).resolves.toHaveLength(0); + + // to-one relation query + await expect( + dbUser1.post.findMany({ + where: { user: { role: 'Admin' } }, + }), + ).resolves.toHaveLength(0); + await expect( + dbUser1.post.findMany({ + where: { user: { is: { role: 'Admin' } } }, + }), + ).resolves.toHaveLength(0); + await expect( + dbUser1.post.findMany({ + where: { user: { isNot: { role: 'User' } } }, + }), + ).resolves.toHaveLength(0); + }); +}); diff --git a/packages/runtime/test/policy/migrated/relation-check.test.ts b/packages/runtime/test/policy/migrated/relation-check.test.ts new file mode 100644 index 00000000..be0947aa --- /dev/null +++ b/packages/runtime/test/policy/migrated/relation-check.test.ts @@ -0,0 +1,736 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Relation checker', () => { + it('should work for read', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', check(user, 'read')) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await expect(db.profile.findMany()).resolves.toHaveLength(1); + }); + + it('should work for simple create', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('create', check(user, 'read')) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + }, + }); + + await expect(db.profile.create({ data: { user: { connect: { id: 1 } }, age: 18 } })).toResolveTruthy(); + await expect(db.profile.create({ data: { user: { connect: { id: 2 } }, age: 18 } })).toBeRejectedByPolicy(); + }); + + it('should work for nested create', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('create', age < 30 && check(user, 'read')) + } + `, + ); + + await expect( + db.user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 18 }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + id: 3, + public: true, + profile: { + create: { age: 30 }, + }, + }, + }), + ).toBeRejectedByPolicy(); + }); + + it('should work for update', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('update', check(user, 'read') && age < 30) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { id: 1, age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { id: 2, age: 20 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + profile: { + create: { id: 3, age: 30 }, + }, + }, + }); + + await expect(db.profile.update({ where: { id: 1 }, data: { age: 21 } })).toResolveTruthy(); + await expect(db.profile.update({ where: { id: 2 }, data: { age: 21 } })).toBeRejectedNotFound(); + await expect(db.profile.update({ where: { id: 3 }, data: { age: 21 } })).toBeRejectedNotFound(); + }); + + it('should work for delete', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('delete', check(user, 'read') && age < 30) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { id: 1, age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { id: 2, age: 20 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + profile: { + create: { id: 3, age: 30 }, + }, + }, + }); + + await expect(db.profile.delete({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.profile.delete({ where: { id: 2 } })).toBeRejectedNotFound(); + await expect(db.profile.delete({ where: { id: 3 } })).toBeRejectedNotFound(); + }); + + // TODO: field-level policy support + it.skip('should work for field-level', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int @allow('read', age < 30 && check(user, 'read')) + @@allow('all', true) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + profile: { + create: { age: 30 }, + }, + }, + }); + + const p1 = await db.profile.findUnique({ where: { id: 1 } }); + expect(p1.age).toBe(18); + const p2 = await db.profile.findUnique({ where: { id: 2 } }); + expect(p2.age).toBeUndefined(); + const p3 = await db.profile.findUnique({ where: { id: 3 } }); + expect(p3.age).toBeUndefined(); + }); + + // TODO: field-level policy support + it.skip('should work for field-level with override', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int @allow('read', age < 30 && check(user, 'read'), true) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + profile: { + create: { age: 30 }, + }, + }, + }); + + const p1 = await db.profile.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(p1.age).toBe(18); + const p2 = await db.profile.findUnique({ where: { id: 2 }, select: { age: true } }); + expect(p2).toBeNull(); + const p3 = await db.profile.findUnique({ where: { id: 3 }, select: { age: true } }); + expect(p3).toBeNull(); + }); + + it('should work for cross-model field comparison', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + age Int + @@allow('read', true) + @@allow('update', age == profile.age) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('update', check(user, 'update') && age < 30) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + age: 18, + profile: { + create: { id: 1, age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + age: 18, + profile: { + create: { id: 2, age: 20 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 3, + age: 30, + profile: { + create: { id: 3, age: 30 }, + }, + }, + }); + + await expect(db.profile.update({ where: { id: 1 }, data: { age: 21 } })).toResolveTruthy(); + await expect(db.profile.update({ where: { id: 2 }, data: { age: 21 } })).toBeRejectedNotFound(); + await expect(db.profile.update({ where: { id: 3 }, data: { age: 21 } })).toBeRejectedNotFound(); + }); + + it('should work for implicit specific operations', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + @@allow('create', true) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', check(user)) + @@allow('create', check(user)) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await expect(db.profile.findMany()).resolves.toHaveLength(1); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 3 } }, age: 18 } })).toResolveTruthy(); + + await db.$unuseAll().user.create({ + data: { + id: 4, + public: false, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 4 } }, age: 18 } })).toBeRejectedByPolicy(); + }); + + it('should work for implicit all operations', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('all', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('all', check(user)) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await expect(db.profile.findMany()).resolves.toHaveLength(1); + + await db.$unuseAll().user.create({ + data: { + id: 3, + public: true, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 3 } }, age: 18 } })).toResolveTruthy(); + + await db.$unuseAll().user.create({ + data: { + id: 4, + public: false, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 4 } }, age: 18 } })).toBeRejectedByPolicy(); + }); + + it('should report error for invalid args', async () => { + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + public Boolean + @@allow('read', check(public)) + } + `, + ), + ).rejects.toThrow(/argument must be a relation field/); + + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + posts Post[] + @@allow('read', check(posts)) + } + model Post { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + } + `, + ), + ).rejects.toThrow(/argument cannot be an array field/); + + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + @@allow('read', check(profile.details)) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + details ProfileDetails? + } + + model ProfileDetails { + id Int @id @default(autoincrement()) + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int + age Int + } + `, + ), + ).rejects.toThrow(/argument must be a relation field/); + + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + posts Post[] + @@allow('read', check(posts, 'all')) + } + model Post { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + } + `, + ), + ).rejects.toThrow(/argument must be a "read", "create", "update", or "delete"/); + }); + + it('should report error for cyclic relation check', async () => { + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + profileDetails ProfileDetails? + public Boolean + @@allow('read', check(profile)) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + details ProfileDetails? + @@allow('read', check(details)) + } + + model ProfileDetails { + id Int @id @default(autoincrement()) + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', check(user)) + } + `, + ), + ).rejects.toThrow(/cyclic dependency/); + }); + + it('should report error for cyclic relation check indirect', async () => { + await expect( + createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', check(profile)) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + details ProfileDetails? + @@allow('read', check(details)) + } + + model ProfileDetails { + id Int @id @default(autoincrement()) + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int + @@allow('read', check(profile)) + } + `, + ), + ).rejects.toThrow(/cyclic dependency/); + }); + + it('should work for query builder', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', check(user)) + } + `, + ); + + await db.$unuseAll().user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await db.$unuseAll().user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await expect(db.$qb.selectFrom('Profile as p').selectAll('p').execute()).resolves.toHaveLength(1); + }); +}); diff --git a/packages/runtime/test/policy/migrated/relation-many-to-many-filter.test.ts b/packages/runtime/test/policy/migrated/relation-many-to-many-filter.test.ts new file mode 100644 index 00000000..916f8c50 --- /dev/null +++ b/packages/runtime/test/policy/migrated/relation-many-to-many-filter.test.ts @@ -0,0 +1,280 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy many-to-many relation tests', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m2 M2[] + + @@allow('read', !deleted) + @@allow('create,update', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m1 M1[] + + @@allow('read', !deleted) + @@allow('create,update', true) + } + `; + + it('some filter', async () => { + const db = await createPolicyTestClient(model, { usePrismaPush: true }); + + await db.m1.create({ + data: { + id: '1', + value: 1, + m2: { + create: [ + { + id: '1', + value: 1, + }, + { + id: '2', + value: 2, + deleted: true, + }, + ], + }, + }, + }); + + // m1 -> m2 lookup + const r = await db.m1.findFirst({ + where: { + id: '1', + m2: { + some: {}, + }, + }, + include: { + _count: { select: { m2: true } }, + }, + }); + expect(r._count.m2).toBe(1); + + // m2 -> m1 lookup + await expect( + db.m2.findFirst({ + where: { + id: '1', + m1: { + some: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + value: 1, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + some: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('none filter', async () => { + const db = await createPolicyTestClient(model, { usePrismaPush: true }); + + await db.m1.create({ + data: { + id: '1', + value: 1, + m2: { + create: [ + { id: '1', value: 1 }, + { id: '2', value: 2, deleted: true }, + ], + }, + }, + }); + + // m1 -> m2 lookup + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: {}, + }, + }, + }), + ).toResolveFalsy(); + + // m2 -> m1 lookup + await expect( + db.m2.findFirst({ + where: { + m1: { + none: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + value: 2, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + none: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('every filter', async () => { + const db = await createPolicyTestClient(model, { usePrismaPush: true }); + + await db.m1.create({ + data: { + id: '1', + value: 1, + m2: { + create: [ + { id: '1', value: 1 }, + { id: '2', value: 2, deleted: true }, + ], + }, + }, + }); + + // m1 -> m2 lookup + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + // m2 -> m1 lookup + await expect( + db.m2.findFirst({ + where: { + id: '1', + m1: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + value: 2, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + }); +}); diff --git a/packages/runtime/test/policy/migrated/relation-one-to-many-filter.test.ts b/packages/runtime/test/policy/migrated/relation-one-to-many-filter.test.ts new file mode 100644 index 00000000..4330c008 --- /dev/null +++ b/packages/runtime/test/policy/migrated/relation-one-to-many-filter.test.ts @@ -0,0 +1,1009 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Relation one-to-many filter', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + m3 M3[] + + @@allow('read', !deleted) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String + + @@allow('read', !deleted) + @@allow('create', true) + } + `; + + it('some filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + { + value: 2, + deleted: true, + m3: { + create: { + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + some: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(1); + + const r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + some: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(0); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + some: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('none filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + { + value: 2, + deleted: true, + m3: { + create: { + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + none: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(0); + + const r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + none: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(1); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + none: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('every filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + { + value: 2, + deleted: true, + m3: { + create: { + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + every: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(1); + + const r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + every: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(0); + + // m1 with empty m2 list + await db.m1.create({ + data: { + id: '2', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('_count filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + { + value: 2, + deleted: true, + m3: { + create: { + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect(db.m1.findFirst({ include: { _count: true } })).resolves.toMatchObject({ _count: { m2: 1 } }); + await expect(db.m1.findFirst({ include: { _count: { select: { m2: true } } } })).resolves.toMatchObject({ + _count: { m2: 1 }, + }); + await expect( + db.m1.findFirst({ include: { _count: { select: { m2: { where: { value: { gt: 0 } } } } } } }), + ).resolves.toMatchObject({ _count: { m2: 1 } }); + await expect( + db.m1.findFirst({ include: { _count: { select: { m2: { where: { value: { gt: 1 } } } } } } }), + ).resolves.toMatchObject({ _count: { m2: 0 } }); + + await expect(db.m1.findFirst({ include: { m2: { select: { _count: true } } } })).resolves.toMatchObject({ + m2: [{ _count: { m3: 1 } }], + }); + await expect( + db.m1.findFirst({ include: { m2: { select: { _count: { select: { m3: true } } } } } }), + ).resolves.toMatchObject({ m2: [{ _count: { m3: 1 } }] }); + await expect( + db.m1.findFirst({ + include: { m2: { select: { _count: { select: { m3: { where: { value: { gt: 1 } } } } } } } }, + }), + ).resolves.toMatchObject({ m2: [{ _count: { m3: 0 } }] }); + }); +}); + +// TODO: field-level policy support +describe.skip('Relation one-to-many filter with field-level rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + m3 M3[] + + @@allow('read', true) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) + value Int @deny('read', deleted) + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String + + @@allow('read', true) + @@allow('create', true) + } + `; + + it('some filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + m3: { + create: { + id: '3-1', + value: 1, + }, + }, + }, + { + id: '2-2', + value: 2, + deleted: true, + m3: { + create: { + id: '3-2', + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { id: '2-2' }, + }, + }, + }), + ).toResolveTruthy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + some: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(2); + + let r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + some: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(0); + + r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + some: { id: { equals: '3-2' } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(1); + }); + + it('none filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + m3: { + create: { + id: '3-1', + value: 1, + }, + }, + }, + { + id: '2-2', + value: 2, + deleted: true, + m3: { + create: { + id: '3-2', + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { id: '2-1' }, + }, + }, + }), + ).toResolveFalsy(); + + // include clause + + let r = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + none: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(2); + + r = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + none: { id: { equals: '3-2' } }, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(1); + }); + + it('every filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + m3: { + create: { + id: '3-1', + value: 1, + }, + }, + }, + { + id: '2-2', + value: 2, + deleted: true, + m3: { + create: { + id: '3-2', + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { id: { contains: '2' } }, + }, + }, + }), + ).toResolveTruthy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + every: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(2); + + let r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + every: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(1); + + r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + every: { id: { contains: '3' } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(2); + }); +}); + +// TODO: field-level policy support +describe.skip('Relation one-to-many filter with field-level override rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) @allow('read', true, true) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('read', !deleted) + @@allow('create', true) + } + `; + + it('some filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + }, + { + id: '2-2', + value: 2, + deleted: true, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { id: '2-2' }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('none filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + }, + { + id: '2-2', + value: 2, + deleted: true, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: {}, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { id: '2-1' }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('every filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + }, + { + id: '2-2', + value: 2, + deleted: true, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: {}, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { id: { contains: '2' } }, + }, + }, + }), + ).toResolveTruthy(); + }); +}); diff --git a/packages/runtime/test/policy/migrated/relation-one-to-one-filter.test.ts b/packages/runtime/test/policy/migrated/relation-one-to-one-filter.test.ts new file mode 100644 index 00000000..060eea77 --- /dev/null +++ b/packages/runtime/test/policy/migrated/relation-one-to-one-filter.test.ts @@ -0,0 +1,1096 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Relation one-to-one filter', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + m3 M3? + + @@allow('read', !deleted) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String @unique + + @@allow('read', !deleted) + @@allow('create', true) + } + `; + + it('is filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { value: 1 }, + }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with null m2 + await db.m1.create({ + data: { + id: '4', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '4', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('isNot filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + isNot: { + m3: { + isNot: { value: 0 }, + }, + }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + isNot: { + m3: { + isNot: { value: 1 }, + }, + }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with null m2 + await db.m1.create({ + data: { + id: '4', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '4', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('direct object filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + value: 1, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + value: 1, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + value: 1, + m3: { + create: { + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with null m2 + await db.m1.create({ + data: { + id: '4', + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '4', + m2: { + value: 1, + }, + }, + }), + ).toResolveFalsy(); + }); +}); + +// TODO: field-level policy support +describe.skip('Relation one-to-one filter with field-level rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + m3 M3? + + @@allow('read', true) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String @unique + + @@allow('read', true) + @@allow('create', true) + } + `; + + it('is filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { id: '2' }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { value: 1 }, + }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { id: '3' }, + }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('isNot filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { id: '2' }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('direct object filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + value: 1, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + value: 1, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + id: '2', + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { id: '3' }, + }, + }, + }), + ).toResolveTruthy(); + }); +}); + +// TODO: field-level policy support +describe.skip('Relation one-to-one filter with field-level override rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) @allow('read', true, true) + value Int + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + m3 M3? + + @@allow('read', !deleted) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) @allow('read', true, true) + value Int + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String @unique + + @@allow('read', !deleted) + @@allow('create', true) + } + `; + + it('is filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { id: '2' }, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { value: 1 }, + }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { id: '3' }, + }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('isNot filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 0 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 1 }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { id: '2' }, + }, + }, + }), + ).toResolveFalsy(); + }); + + it('direct object filter', async () => { + const db = await createPolicyTestClient(model); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + value: 1, + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + value: 1, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + id: '2', + }, + }, + }), + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { value: 1 }, + }, + }, + }), + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { id: '3' }, + }, + }, + }), + ).toResolveTruthy(); + }); +}); diff --git a/packages/runtime/test/policy/migrated/self-relation.test.ts b/packages/runtime/test/policy/migrated/self-relation.test.ts new file mode 100644 index 00000000..f06c34d2 --- /dev/null +++ b/packages/runtime/test/policy/migrated/self-relation.test.ts @@ -0,0 +1,201 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy self relations tests', () => { + it('one-to-one', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + value Int + successorId Int? @unique + successor User? @relation("BlogOwnerHistory", fields: [successorId], references: [id]) + predecessor User? @relation("BlogOwnerHistory") + + @@allow('create,update', value > 0) + @@allow('read', true) + } + `, + { usePrismaPush: true }, + ); + + // create denied + await expect( + db.user.create({ + data: { + value: 0, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + successor: { + create: { + value: 0, + }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + successor: { + create: { + value: 1, + }, + }, + predecessor: { + create: { + value: 0, + }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + successor: { + create: { + value: 1, + }, + }, + predecessor: { + create: { + value: 1, + }, + }, + }, + }), + ).toResolveTruthy(); + }); + + it('one-to-many', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + value Int + teacherId Int? + teacher User? @relation("TeacherStudents", fields: [teacherId], references: [id]) + students User[] @relation("TeacherStudents") + + @@allow('create,update', value > 0) + @@allow('read', true) + } + `, + { usePrismaPush: true }, + ); + + // create denied + await expect( + db.user.create({ + data: { + value: 0, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + teacher: { + create: { value: 0 }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + teacher: { + create: { value: 1 }, + }, + students: { + create: [{ value: 0 }, { value: 1 }], + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + teacher: { + create: { value: 1 }, + }, + students: { + create: [{ value: 1 }, { value: 2 }], + }, + }, + }), + ).toResolveTruthy(); + }); + + it('many-to-many', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + value Int + followedBy User[] @relation("UserFollows") + following User[] @relation("UserFollows") + + @@allow('create,update', value > 0) + @@allow('read', true) + } + `, + { usePrismaPush: true }, + ); + + // create denied + await expect( + db.user.create({ + data: { + value: 0, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + followedBy: { create: { value: 0 } }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + followedBy: { create: { value: 1 } }, + following: { create: [{ value: 0 }, { value: 1 }] }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + value: 1, + followedBy: { create: { value: 1 } }, + following: { create: [{ value: 1 }, { value: 2 }] }, + }, + }), + ).toResolveTruthy(); + }); +}); diff --git a/packages/runtime/test/policy/migrated/todo-sample.test.ts b/packages/runtime/test/policy/migrated/todo-sample.test.ts new file mode 100644 index 00000000..c81ac3f7 --- /dev/null +++ b/packages/runtime/test/policy/migrated/todo-sample.test.ts @@ -0,0 +1,503 @@ +import { beforeEach, describe, expect, it } from 'vitest'; +import type { ClientContract } from '../../../src'; +import { schema, type SchemaType } from '../../schemas/todo/schema'; +import { createPolicyTestClient } from '../utils'; + +describe('Todo Policy Tests', () => { + let db: ClientContract; + + beforeEach(async () => { + db = await createPolicyTestClient(schema); + }); + + it('user', async () => { + const user1 = { + id: 'user1', + email: 'user1@zenstack.dev', + name: 'User 1', + }; + const user2 = { + id: 'user2', + email: 'user2@zenstack.dev', + name: 'User 2', + }; + + const anonDb = db; + const user1Db = db.$setAuth({ id: user1.id }); + const user2Db = db.$setAuth({ id: user2.id }); + + // create user1 + // create should succeed but result can be read back anonymously + await expect(anonDb.user.create({ data: user1 })).toBeRejectedByPolicy([ + 'result is not allowed to be read back', + ]); + await expect(user1Db.user.findUnique({ where: { id: user1.id } })).toResolveTruthy(); + await expect(user2Db.user.findUnique({ where: { id: user1.id } })).toResolveNull(); + + // create user2 + await expect(anonDb.user.create({ data: user2 })).toBeRejectedByPolicy(); + + // find with user1 should only get user1 + const r = await user1Db.user.findMany(); + expect(r).toHaveLength(1); + expect(r[0]).toEqual(expect.objectContaining(user1)); + + // get user2 as user1 + await expect(user1Db.user.findUnique({ where: { id: user2.id } })).toResolveNull(); + + // add both users into the same space + await expect( + user1Db.space.create({ + data: { + name: 'Space 1', + slug: 'space1', + owner: { connect: { id: user1.id } }, + members: { + create: [ + { + user: { connect: { id: user1.id } }, + role: 'ADMIN', + }, + { + user: { connect: { id: user2.id } }, + role: 'USER', + }, + ], + }, + }, + }), + ).toResolveTruthy(); + + // now both user1 and user2 should be visible + await expect(user1Db.user.findMany()).resolves.toHaveLength(2); + await expect(user2Db.user.findMany()).resolves.toHaveLength(2); + + // update user2 as user1 + await expect( + user2Db.user.update({ + where: { id: user1.id }, + data: { name: 'hello' }, + }), + ).toBeRejectedNotFound(); + + // update user1 as user1 + await expect( + user1Db.user.update({ + where: { id: user1.id }, + data: { name: 'hello' }, + }), + ).toResolveTruthy(); + + // delete user2 as user1 + await expect(user1Db.user.delete({ where: { id: user2.id } })).toBeRejectedNotFound(); + + // delete user1 as user1 + await expect(user1Db.user.delete({ where: { id: user1.id } })).toResolveTruthy(); + await expect(user1Db.user.findUnique({ where: { id: user1.id } })).toResolveNull(); + }); + + it('todo list', async () => { + await createSpaceAndUsers(db.$unuseAll()); + + const anonDb = db; + const emptyUIDDb = db.$setAuth({ id: '' }); + const user1Db = db.$setAuth({ id: user1.id }); + const user2Db = db.$setAuth({ id: user2.id }); + const user3Db = db.$setAuth({ id: user3.id }); + + await expect( + anonDb.list.create({ + data: { + id: 'list1', + title: 'List 1', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + user1Db.list.create({ + data: { + id: 'list1', + title: 'List 1', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }), + ).toResolveTruthy(); + + await expect(user1Db.list.findMany()).resolves.toHaveLength(1); + await expect(anonDb.list.findMany()).resolves.toHaveLength(0); + await expect(emptyUIDDb.list.findMany()).resolves.toHaveLength(0); + await expect(anonDb.list.findUnique({ where: { id: 'list1' } })).toResolveNull(); + + // accessible to owner + await expect(user1Db.list.findUnique({ where: { id: 'list1' } })).resolves.toEqual( + expect.objectContaining({ id: 'list1', title: 'List 1' }), + ); + + // accessible to user in the space + await expect(user2Db.list.findUnique({ where: { id: 'list1' } })).toResolveTruthy(); + + // inaccessible to user not in the space + await expect(user3Db.list.findUnique({ where: { id: 'list1' } })).toResolveNull(); + + // make a private list + await user1Db.list.create({ + data: { + id: 'list2', + title: 'List 2', + private: true, + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }); + + // accessible to owner + await expect(user1Db.list.findUnique({ where: { id: 'list2' } })).toResolveTruthy(); + + // inaccessible to other user in the space + await expect(user2Db.list.findUnique({ where: { id: 'list2' } })).toResolveNull(); + + // create a list which doesn't match credential should fail + await expect( + user1Db.list.create({ + data: { + id: 'list3', + title: 'List 3', + owner: { connect: { id: user2.id } }, + space: { connect: { id: space1.id } }, + }, + }), + ).toBeRejectedByPolicy(); + + // create a list which doesn't match credential's space should fail + await expect( + user1Db.list.create({ + data: { + id: 'list3', + title: 'List 3', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space2.id } }, + }, + }), + ).toBeRejectedByPolicy(); + + // update list + await expect( + user1Db.list.update({ + where: { id: 'list1' }, + data: { + title: 'List 1 updated', + }, + }), + ).resolves.toEqual(expect.objectContaining({ title: 'List 1 updated' })); + + await expect( + user2Db.list.update({ + where: { id: 'list1' }, + data: { + title: 'List 1 updated', + }, + }), + ).toBeRejectedNotFound(); + + // delete list + await expect(user2Db.list.delete({ where: { id: 'list1' } })).toBeRejectedNotFound(); + await expect(user1Db.list.delete({ where: { id: 'list1' } })).toResolveTruthy(); + await expect(user1Db.list.findUnique({ where: { id: 'list1' } })).toResolveNull(); + }); + + it('todo', async () => { + await createSpaceAndUsers(db.$unuseAll()); + + const user1Db = db.$setAuth({ id: user1.id }); + const user2Db = db.$setAuth({ id: user2.id }); + + // create a public list + await user1Db.list.create({ + data: { + id: 'list1', + title: 'List 1', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }); + + // create + await expect( + user1Db.todo.create({ + data: { + id: 'todo1', + title: 'Todo 1', + owner: { connect: { id: user1.id } }, + list: { + connect: { id: 'list1' }, + }, + }, + }), + ).toResolveTruthy(); + + await expect( + user2Db.todo.create({ + data: { + id: 'todo2', + title: 'Todo 2', + owner: { connect: { id: user2.id } }, + list: { + connect: { id: 'list1' }, + }, + }, + }), + ).toResolveTruthy(); + + // read + await expect(user1Db.todo.findMany()).resolves.toHaveLength(2); + await expect(user2Db.todo.findMany()).resolves.toHaveLength(2); + + // update, user in the same space can freely update + await expect( + user1Db.todo.update({ + where: { id: 'todo1' }, + data: { + title: 'Todo 1 updated', + }, + }), + ).toResolveTruthy(); + await expect( + user1Db.todo.update({ + where: { id: 'todo2' }, + data: { + title: 'Todo 2 updated', + }, + }), + ).toResolveTruthy(); + + // create a private list + await user1Db.list.create({ + data: { + id: 'list2', + private: true, + title: 'List 2', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }); + + // create + await expect( + user1Db.todo.create({ + data: { + id: 'todo3', + title: 'Todo 3', + owner: { connect: { id: user1.id } }, + list: { + connect: { id: 'list2' }, + }, + }, + }), + ).toResolveTruthy(); + + // reject because list2 is private + await expect( + user2Db.todo.create({ + data: { + id: 'todo4', + title: 'Todo 4', + owner: { connect: { id: user2.id } }, + list: { + connect: { id: 'list2' }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + // update, only owner can update todo in a private list + await expect( + user1Db.todo.update({ + where: { id: 'todo3' }, + data: { + title: 'Todo 3 updated', + }, + }), + ).toResolveTruthy(); + await expect( + user2Db.todo.update({ + where: { id: 'todo3' }, + data: { + title: 'Todo 3 updated', + }, + }), + ).toBeRejectedNotFound(); + }); + + it('relation query', async () => { + await createSpaceAndUsers(db.$unuseAll()); + + const user1Db = db.$setAuth({ id: user1.id }); + const user2Db = db.$setAuth({ id: user2.id }); + + await user1Db.list.create({ + data: { + id: 'list1', + title: 'List 1', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }); + + await user1Db.list.create({ + data: { + id: 'list2', + title: 'List 2', + private: true, + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }); + + const r = await user1Db.space.findFirstOrThrow({ + where: { id: 'space1' }, + include: { lists: true }, + }); + expect(r.lists).toHaveLength(2); + + const r1 = await user2Db.space.findFirstOrThrow({ + where: { id: 'space1' }, + include: { lists: true }, + }); + expect(r1.lists).toHaveLength(1); + }); + + // TODO: `future()` support + it.skip('post-update checks', async () => { + await createSpaceAndUsers(db.$unuseAll()); + + const user1Db = db.$setAuth({ id: user1.id }); + + await user1Db.list.create({ + data: { + id: 'list1', + title: 'List 1', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + todos: { + create: { + id: 'todo1', + title: 'Todo 1', + owner: { connect: { id: user1.id } }, + }, + }, + }, + }); + + // change list's owner + await expect( + user1Db.list.update({ + where: { id: 'list1' }, + data: { + owner: { connect: { id: user2.id } }, + }, + }), + ).toBeRejectedByPolicy(); + + // change todo's owner + await expect( + user1Db.todo.update({ + where: { id: 'todo1' }, + data: { + owner: { connect: { id: user2.id } }, + }, + }), + ).toBeRejectedByPolicy(); + + // nested change todo's owner + await expect( + user1Db.list.update({ + where: { id: 'list1' }, + data: { + todos: { + update: { + where: { id: 'todo1' }, + data: { + owner: { connect: { id: user2.id } }, + }, + }, + }, + }, + }), + ).toBeRejectedByPolicy(); + }); +}); + +const user1 = { + id: 'user1', + email: 'user1@zenstack.dev', + name: 'User 1', +}; + +const user2 = { + id: 'user2', + email: 'user2@zenstack.dev', + name: 'User 2', +}; + +const user3 = { + id: 'user3', + email: 'user3@zenstack.dev', + name: 'User 3', +}; + +const space1 = { + id: 'space1', + name: 'Space 1', + slug: 'space1', +}; + +const space2 = { + id: 'space2', + name: 'Space 2', + slug: 'space2', +}; + +async function createSpaceAndUsers(db: ClientContract) { + // create users + await db.user.create({ data: user1 }); + await db.user.create({ data: user2 }); + await db.user.create({ data: user3 }); + + // add user1 and user2 into space1 + await db.space.create({ + data: { + ...space1, + members: { + create: [ + { + user: { connect: { id: user1.id } }, + role: 'ADMIN', + }, + { + user: { connect: { id: user2.id } }, + role: 'USER', + }, + ], + }, + }, + }); + + // add user3 to space2 + await db.space.create({ + data: { + ...space2, + members: { + create: [ + { + user: { connect: { id: user3.id } }, + role: 'ADMIN', + }, + ], + }, + }, + }); +} diff --git a/packages/runtime/test/policy/migrated/toplevel-operations.test.ts b/packages/runtime/test/policy/migrated/toplevel-operations.test.ts new file mode 100644 index 00000000..f545148c --- /dev/null +++ b/packages/runtime/test/policy/migrated/toplevel-operations.test.ts @@ -0,0 +1,258 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy toplevel operations tests', () => { + it('read tests', async () => { + const db = await createPolicyTestClient( + ` + model Model { + id String @id @default(uuid()) + value Int + + @@allow('create', true) + @@allow('read', value > 1) + } + `, + ); + + await expect( + db.model.create({ + data: { + id: '1', + value: 1, + }, + }), + ).toBeRejectedByPolicy(); + const fromPrisma = await db.$unuseAll().model.findUnique({ + where: { id: '1' }, + }); + expect(fromPrisma).toBeTruthy(); + + expect(await db.model.findMany()).toHaveLength(0); + expect(await db.model.findUnique({ where: { id: '1' } })).toBeNull(); + expect(await db.model.findFirst({ where: { id: '1' } })).toBeNull(); + await expect(db.model.findUniqueOrThrow({ where: { id: '1' } })).toBeRejectedNotFound(); + await expect(db.model.findFirstOrThrow({ where: { id: '1' } })).toBeRejectedNotFound(); + + const item2 = { + id: '2', + value: 2, + }; + const r1 = await db.model.create({ + data: item2, + }); + expect(r1).toBeTruthy(); + expect(await db.model.findMany()).toHaveLength(1); + expect(await db.model.findUnique({ where: { id: '2' } })).toEqual(expect.objectContaining(item2)); + expect(await db.model.findFirst({ where: { id: '2' } })).toEqual(expect.objectContaining(item2)); + expect(await db.model.findUniqueOrThrow({ where: { id: '2' } })).toEqual(expect.objectContaining(item2)); + expect(await db.model.findFirstOrThrow({ where: { id: '2' } })).toEqual(expect.objectContaining(item2)); + }); + + it('write tests', async () => { + const db = await createPolicyTestClient( + ` + model Model { + id String @id @default(uuid()) + value Int + + @@allow('read', value > 1) + @@allow('create', value > 0) + @@allow('update', value > 1) + } + `, + ); + + // create denied + await expect( + db.model.create({ + data: { + value: 0, + }, + }), + ).toBeRejectedByPolicy(); + + // can't read back + await expect( + db.model.create({ + data: { + id: '1', + value: 1, + }, + }), + ).toBeRejectedByPolicy(); + + // success + expect( + await db.model.create({ + data: { + id: '2', + value: 2, + }, + }), + ).toBeTruthy(); + + // update not found + await expect(db.model.update({ where: { id: '3' }, data: { value: 5 } })).toBeRejectedNotFound(); + + // update-many empty + expect( + await db.model.updateMany({ + where: { id: '3' }, + data: { value: 5 }, + }), + ).toEqual(expect.objectContaining({ count: 0 })); + + // upsert + expect( + await db.model.upsert({ + where: { id: '3' }, + create: { id: '3', value: 5 }, + update: { value: 6 }, + }), + ).toEqual(expect.objectContaining({ value: 5 })); + + // update denied + await expect( + db.model.update({ + where: { id: '1' }, + data: { + value: 3, + }, + }), + ).toBeRejectedNotFound(); + + // update success + expect( + await db.model.update({ + where: { id: '2' }, + data: { + value: 3, + }, + }), + ).toBeTruthy(); + }); + + // TODO: `future()` support + it.skip('update id tests', async () => { + const db = await createPolicyTestClient( + ` + model Model { + id String @id @default(uuid()) + value Int + + @@allow('read', value > 1) + @@allow('create', value > 0) + @@allow('update', value > 1 && future().value > 2) + } + `, + ); + + await db.model.create({ + data: { + id: '1', + value: 2, + }, + }); + + // update denied + await expect( + db.model.update({ + where: { id: '1' }, + data: { + id: '2', + value: 1, + }, + }), + ).toBeRejectedNotFound(); + + // update success + await expect( + db.model.update({ + where: { id: '1' }, + data: { + id: '2', + value: 3, + }, + }), + ).resolves.toMatchObject({ id: '2', value: 3 }); + + // upsert denied + await expect( + db.model.upsert({ + where: { id: '2' }, + update: { + id: '3', + value: 1, + }, + create: { + id: '4', + value: 5, + }, + }), + ).toBeRejectedByPolicy(); + + // upsert success + await expect( + db.model.upsert({ + where: { id: '2' }, + update: { + id: '3', + value: 4, + }, + create: { + id: '4', + value: 5, + }, + }), + ).resolves.toMatchObject({ id: '3', value: 4 }); + }); + + it('delete tests', async () => { + const db = await createPolicyTestClient( + ` + model Model { + id String @id @default(uuid()) + value Int + + @@allow('create', true) + @@allow('read', value > 2) + @@allow('delete', value > 1) + } + `, + ); + + await expect(db.model.delete({ where: { id: '1' } })).toBeRejectedNotFound(); + + await expect( + db.model.create({ + data: { id: '1', value: 1 }, + }), + ).toBeRejectedByPolicy(); + + await expect(db.model.delete({ where: { id: '1' } })).toBeRejectedNotFound(); + await expect(db.$unuseAll().model.findUnique({ where: { id: '1' } })).toResolveTruthy(); + + await expect( + db.model.create({ + data: { id: '2', value: 2 }, + }), + ).toBeRejectedByPolicy(); + await expect(db.$unuseAll().model.findUnique({ where: { id: '2' } })).toBeTruthy(); + // deleted but unable to read back + await expect(db.model.delete({ where: { id: '2' } })).toBeRejectedByPolicy(); + await expect(db.$unuseAll().model.findUnique({ where: { id: '2' } })).toResolveNull(); + + await expect( + db.model.create({ + data: { id: '2', value: 2 }, + }), + ).toBeRejectedByPolicy(); + // only '2' is deleted, '1' is rejected by policy + expect(await db.model.deleteMany()).toEqual(expect.objectContaining({ count: 1 })); + expect(await db.$unuseAll().model.findUnique({ where: { id: '2' } })).toBeNull(); + expect(await db.$unuseAll().model.findUnique({ where: { id: '1' } })).toBeTruthy(); + + expect(await db.model.deleteMany()).toEqual(expect.objectContaining({ count: 0 })); + }); +}); diff --git a/packages/runtime/test/policy/migrated/unique-as-id.test.ts b/packages/runtime/test/policy/migrated/unique-as-id.test.ts new file mode 100644 index 00000000..6b3e8588 --- /dev/null +++ b/packages/runtime/test/policy/migrated/unique-as-id.test.ts @@ -0,0 +1,276 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy unique as id tests', () => { + it('unique fields', async () => { + const db = await createPolicyTestClient( + ` + model A { + x String @unique + y Int @unique + value Int + b B? + + @@allow('read', true) + @@allow('create', value > 0) + } + + model B { + b1 String @unique + b2 String @unique + value Int + a A @relation(fields: [ax], references: [x]) + ax String @unique + + @@allow('read', value > 2) + @@allow('create', value > 1) + } + `, + ); + + await expect(db.a.create({ data: { x: '1', y: 1, value: 0 } })).toBeRejectedByPolicy(); + await expect(db.a.create({ data: { x: '1', y: 2, value: 1 } })).toResolveTruthy(); + + await expect( + db.a.create({ data: { x: '2', y: 3, value: 1, b: { create: { b1: '1', b2: '2', value: 1 } } } }), + ).toBeRejectedByPolicy(); + + const r = await db.a.create({ + include: { b: true }, + data: { x: '2', y: 3, value: 1, b: { create: { b1: '1', b2: '2', value: 2 } } }, + }); + expect(r.b).toBeNull(); + const r1 = await db.$unuseAll().b.findUnique({ where: { b1: '1' } }); + expect(r1.value).toBe(2); + + await expect( + db.a.create({ + include: { b: true }, + data: { x: '3', y: 4, value: 1, b: { create: { b1: '2', b2: '3', value: 3 } } }, + }), + ).toResolveTruthy(); + }); + + it('unique fields mixed with id', async () => { + const db = await createPolicyTestClient( + ` + model A { + id Int @id @default(autoincrement()) + x String @unique + y Int @unique + value Int + b B? + + @@allow('read', true) + @@allow('create', value > 0) + } + + model B { + id Int @id @default(autoincrement()) + b1 String @unique + b2 String @unique + value Int + a A @relation(fields: [ax], references: [x]) + ax String @unique + + @@allow('read', value > 2) + @@allow('create', value > 1) + } + `, + ); + + await expect(db.a.create({ data: { x: '1', y: 1, value: 0 } })).toBeRejectedByPolicy(); + await expect(db.a.create({ data: { x: '1', y: 2, value: 1 } })).toResolveTruthy(); + + await expect( + db.a.create({ data: { x: '2', y: 3, value: 1, b: { create: { b1: '1', b2: '2', value: 1 } } } }), + ).toBeRejectedByPolicy(); + + const r = await db.a.create({ + include: { b: true }, + data: { x: '2', y: 3, value: 1, b: { create: { b1: '1', b2: '2', value: 2 } } }, + }); + expect(r.b).toBeNull(); + const r1 = await db.$unuseAll().b.findUnique({ where: { b1: '1' } }); + expect(r1.value).toBe(2); + + await expect( + db.a.create({ + include: { b: true }, + data: { x: '3', y: 4, value: 1, b: { create: { b1: '2', b2: '3', value: 3 } } }, + }), + ).toResolveTruthy(); + }); + + it('model-level unique fields', async () => { + const db = await createPolicyTestClient( + ` + model A { + x String + y Int + value Int + b B? + @@unique([x, y]) + + @@allow('read', true) + @@allow('create', value > 0) + } + + model B { + b1 String + b2 String + value Int + a A @relation(fields: [ax, ay], references: [x, y]) + ax String + ay Int + + @@allow('read', value > 2) + @@allow('create', value > 1) + + @@unique([ax, ay]) + @@unique([b1, b2]) + } + `, + ); + + await expect(db.a.create({ data: { x: '1', y: 1, value: 0 } })).toBeRejectedByPolicy(); + await expect(db.a.create({ data: { x: '1', y: 2, value: 1 } })).toResolveTruthy(); + + await expect( + db.a.create({ data: { x: '2', y: 1, value: 1, b: { create: { b1: '1', b2: '2', value: 1 } } } }), + ).toBeRejectedByPolicy(); + + const r = await db.a.create({ + include: { b: true }, + data: { x: '2', y: 1, value: 1, b: { create: { b1: '1', b2: '2', value: 2 } } }, + }); + expect(r.b).toBeNull(); + const r1 = await db.$unuseAll().b.findUnique({ where: { b1_b2: { b1: '1', b2: '2' } } }); + expect(r1.value).toBe(2); + + await expect( + db.a.create({ + include: { b: true }, + data: { x: '3', y: 1, value: 1, b: { create: { b1: '2', b2: '2', value: 3 } } }, + }), + ).toResolveTruthy(); + }); + + it('unique fields with to-many nested update', async () => { + const db = await createPolicyTestClient( + ` + model A { + id Int @id @default(autoincrement()) + x Int + y Int + value Int + bs B[] + @@unique([x, y]) + + @@allow('read,create', true) + @@allow('update,delete', value > 0) + } + + model B { + id Int @id @default(autoincrement()) + value Int + a A @relation(fields: [aId], references: [id]) + aId Int + + @@allow('all', value > 0) + } + `, + ); + + await db.a.create({ + data: { x: 1, y: 1, value: 1, bs: { create: [{ id: 1, value: 1 }] } }, + }); + + await db.a.create({ + data: { x: 2, y: 2, value: 2, bs: { create: [{ id: 2, value: 2 }] } }, + }); + + await db.a.update({ + where: { x_y: { x: 1, y: 1 } }, + data: { bs: { updateMany: { where: {}, data: { value: 3 } } } }, + }); + + // check b#1 is updated + await expect(db.b.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ value: 3 }); + + // check b#2 is not affected + await expect(db.b.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ value: 2 }); + + await db.a.update({ + where: { x_y: { x: 1, y: 1 } }, + data: { bs: { deleteMany: {} } }, + }); + + // check b#1 is deleted + await expect(db.b.findUnique({ where: { id: 1 } })).resolves.toBeNull(); + + // check b#2 is not affected + await expect(db.b.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ value: 2 }); + }); + + it('unique fields with to-one nested update', async () => { + const db = await createPolicyTestClient( + ` + model A { + id Int @id @default(autoincrement()) + x Int + y Int + value Int + b B? + @@unique([x, y]) + + @@allow('read,create', true) + @@allow('update,delete', value > 0) + } + + model B { + id Int @id @default(autoincrement()) + value Int + a A @relation(fields: [aId], references: [id]) + aId Int @unique + + @@allow('all', value > 0) + } + `, + ); + + await db.a.create({ + data: { x: 1, y: 1, value: 1, b: { create: { id: 1, value: 1 } } }, + }); + + await db.a.create({ + data: { x: 2, y: 2, value: 2, b: { create: { id: 2, value: 2 } } }, + }); + + await db.a.update({ + where: { x_y: { x: 1, y: 1 } }, + data: { b: { update: { data: { value: 3 } } } }, + }); + + // check b#1 is updated + await expect(db.b.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ value: 3 }); + + // check b#2 is not affected + await expect(db.b.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ value: 2 }); + + await db.a.update({ + where: { x_y: { x: 1, y: 1 } }, + data: { b: { delete: true } }, + }); + + // check b#1 is deleted + await expect(db.b.findUnique({ where: { id: 1 } })).resolves.toBeNull(); + await expect(db.a.findUnique({ where: { x_y: { x: 1, y: 1 } }, include: { b: true } })).resolves.toMatchObject({ + b: null, + }); + + // check b#2 is not affected + await expect(db.b.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ value: 2 }); + await expect(db.a.findUnique({ where: { x_y: { x: 2, y: 2 } }, include: { b: true } })).resolves.toBeTruthy(); + }); +}); diff --git a/packages/runtime/test/policy/migrated/update-many-and-return.test.ts b/packages/runtime/test/policy/migrated/update-many-and-return.test.ts new file mode 100644 index 00000000..ba83335b --- /dev/null +++ b/packages/runtime/test/policy/migrated/update-many-and-return.test.ts @@ -0,0 +1,120 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Policy updateManyAndReturn tests', () => { + it('model-level policies', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + posts Post[] + level Int + + @@allow('read', level > 0) + } + + model Post { + id Int @id @default(autoincrement()) + title String + published Boolean @default(false) + userId Int + user User @relation(fields: [userId], references: [id]) + + @@allow('read', published) + @@allow('update', contains(title, 'hello')) + } + `, + ); + + const rawDb = db.$unuseAll(); + + await rawDb.user.createMany({ + data: [{ id: 1, level: 1 }], + }); + await rawDb.user.createMany({ + data: [{ id: 2, level: 0 }], + }); + + await rawDb.post.createMany({ + data: [ + { id: 1, title: 'hello1', userId: 1, published: true }, + { id: 2, title: 'world1', userId: 1, published: false }, + ], + }); + + // only post#1 is updated + const r = await db.post.updateManyAndReturn({ + data: { title: 'foo' }, + }); + expect(r).toHaveLength(1); + expect(r[0].id).toBe(1); + + // post#2 is excluded from update + await expect( + db.post.updateManyAndReturn({ + where: { id: 2 }, + data: { title: 'foo' }, + }), + ).resolves.toHaveLength(0); + + // reset + await rawDb.post.update({ where: { id: 1 }, data: { title: 'hello1' } }); + + // post#1 is updated + await expect( + db.post.updateManyAndReturn({ + where: { id: 1 }, + data: { title: 'foo' }, + }), + ).resolves.toHaveLength(1); + + // reset + await rawDb.post.update({ where: { id: 1 }, data: { title: 'hello1' } }); + + // read-back check + // post#1 updated but can't be read back + await expect( + db.post.updateManyAndReturn({ + data: { published: false }, + }), + ).toBeRejectedByPolicy(['result is not allowed to be read back']); + // but the update should have been applied + await expect(db.$unuseAll().post.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ published: false }); + }); + + // TODO: field-level policy support + it.skip('field-level policies', async () => { + const db = await createPolicyTestClient( + ` + model Post { + id Int @id @default(autoincrement()) + title String @allow('read', published) + published Boolean @default(false) + + @@allow('all', true) + } + `, + ); + + const rawDb = db.$unuseAll(); + + // update should succeed but one result's title field can't be read back + await rawDb.post.createMany({ + data: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: false }, + ], + }); + + const r = await db.post.updateManyAndReturn({ + data: { title: 'foo' }, + }); + + expect(r.length).toBe(2); + expect(r[0].title).toBeTruthy(); + expect(r[1].title).toBeUndefined(); + + // check posts are updated + await expect(rawDb.post.findMany({ where: { title: 'foo' } })).resolves.toHaveLength(2); + }); +}); diff --git a/packages/runtime/test/policy/migrated/view.test.ts b/packages/runtime/test/policy/migrated/view.test.ts new file mode 100644 index 00000000..010a7f1f --- /dev/null +++ b/packages/runtime/test/policy/migrated/view.test.ts @@ -0,0 +1,84 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('View Policy Test', () => { + it('view policy', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + email String @unique + name String? + posts Post[] + userInfo UserInfo? + } + + model Post { + id Int @id @default(autoincrement()) + title String + content String? + published Boolean @default(false) + author User? @relation(fields: [authorId], references: [id]) + authorId Int? + } + + view UserInfo { + id Int @unique + name String + email String + postCount Int + user User @relation(fields: [id], references: [id]) + + @@allow('read', postCount > 1) + } + `, + ); + + const rawDb = db.$unuseAll(); + + await rawDb.$executeRaw`CREATE VIEW "UserInfo" as select "User"."id", "User"."name", "User"."email", "User"."id" as "userId", count("Post"."id") as "postCount" from "User" left join "Post" on "User"."id" = "Post"."authorId" group by "User"."id";`; + + await rawDb.user.create({ + data: { + email: 'alice@prisma.io', + name: 'Alice', + posts: { + create: { + title: 'Check out Prisma with Next.js', + content: 'https://www.prisma.io/nextjs', + published: true, + }, + }, + }, + }); + await rawDb.user.create({ + data: { + email: 'bob@prisma.io', + name: 'Bob', + posts: { + create: [ + { + title: 'Follow Prisma on Twitter', + content: 'https://twitter.com/prisma', + published: true, + }, + { + title: 'Follow Nexus on Twitter', + content: 'https://twitter.com/nexusgql', + published: false, + }, + ], + }, + }, + }); + + await expect(rawDb.userInfo.findMany()).resolves.toHaveLength(2); + await expect(db.userInfo.findMany()).resolves.toHaveLength(1); + + const r1 = await rawDb.userInfo.findFirst({ include: { user: true } }); + expect(r1.user).toBeTruthy(); + + // user not readable + await expect(db.userInfo.findFirst({ include: { user: true } })).resolves.toMatchObject({ user: null }); + }); +}); diff --git a/packages/runtime/test/policy/policy-functions.test.ts b/packages/runtime/test/policy/policy-functions.test.ts index d37eff4b..2ac094b0 100644 --- a/packages/runtime/test/policy/policy-functions.test.ts +++ b/packages/runtime/test/policy/policy-functions.test.ts @@ -2,7 +2,7 @@ import { describe, expect, it } from 'vitest'; import { createPolicyTestClient } from './utils'; describe('policy functions tests', () => { - it('supports contains with case-sensitive field', async () => { + it('supports contains case-sensitive', async () => { const db = await createPolicyTestClient( ` model Foo { @@ -14,9 +14,51 @@ describe('policy functions tests', () => { ); await expect(db.foo.create({ data: { string: 'bcd' } })).toBeRejectedByPolicy(); + if (db.$schema.provider.type === 'sqlite') { + // sqlite is always case-insensitive + await expect(db.foo.create({ data: { string: 'Acd' } })).toResolveTruthy(); + } else { + await expect(db.foo.create({ data: { string: 'Acd' } })).toBeRejectedByPolicy(); + } await expect(db.foo.create({ data: { string: 'bac' } })).toResolveTruthy(); }); + it('supports contains explicit case-sensitive', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id String @id @default(cuid()) + string String + @@allow('all', contains(string, 'a', false)) + } + `, + ); + + await expect(db.foo.create({ data: { string: 'bcd' } })).toBeRejectedByPolicy(); + if (db.$schema.provider.type === 'sqlite') { + // sqlite is always case-insensitive + await expect(db.foo.create({ data: { string: 'Acd' } })).toResolveTruthy(); + } else { + await expect(db.foo.create({ data: { string: 'Acd' } })).toBeRejectedByPolicy(); + } + await expect(db.foo.create({ data: { string: 'bac' } })).toResolveTruthy(); + }); + + it('supports contains case-insensitive', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id String @id @default(cuid()) + string String + @@allow('all', contains(string, 'a', true)) + } + `, + ); + + await expect(db.foo.create({ data: { string: 'bcd' } })).toBeRejectedByPolicy(); + await expect(db.foo.create({ data: { string: 'Abc' } })).toResolveTruthy(); + }); + it('supports contains with case-sensitive non-field', async () => { const db = await createPolicyTestClient( ` @@ -35,6 +77,12 @@ describe('policy functions tests', () => { await expect(db.foo.create({ data: {} })).toBeRejectedByPolicy(); await expect(db.$setAuth({ id: 'user1', name: 'bcd' }).foo.create({ data: {} })).toBeRejectedByPolicy(); await expect(db.$setAuth({ id: 'user1', name: 'bac' }).foo.create({ data: {} })).toResolveTruthy(); + if (db.$schema.provider.type === 'sqlite') { + // sqlite is always case-insensitive + await expect(db.$setAuth({ id: 'user1', name: 'Abc' }).foo.create({ data: {} })).toResolveTruthy(); + } else { + await expect(db.$setAuth({ id: 'user1', name: 'Abc' }).foo.create({ data: {} })).toBeRejectedByPolicy(); + } }); it('supports contains with auth()', async () => { diff --git a/packages/runtime/test/query-builder/query-builder.test.ts b/packages/runtime/test/query-builder/query-builder.test.ts index 8eed03d5..32890468 100644 --- a/packages/runtime/test/query-builder/query-builder.test.ts +++ b/packages/runtime/test/query-builder/query-builder.test.ts @@ -1,18 +1,13 @@ import { createId } from '@paralleldrive/cuid2'; -import SQLite from 'better-sqlite3'; -import { SqliteDialect } from 'kysely'; import { describe, expect, it } from 'vitest'; -import { ZenStackClient } from '../../src'; import { getSchema } from '../schemas/basic'; +import { createTestClient } from '../utils'; describe('Client API tests', () => { const schema = getSchema('sqlite'); it('works with queries', async () => { - const client = new ZenStackClient(schema, { - dialect: new SqliteDialect({ database: new SQLite(':memory:') }), - }); - await client.$pushSchema(); + const client = await createTestClient(schema); const kysely = client.$qb; diff --git a/packages/runtime/test/schemas/petstore/input.ts b/packages/runtime/test/schemas/petstore/input.ts new file mode 100644 index 00000000..6aece67e --- /dev/null +++ b/packages/runtime/test/schemas/petstore/input.ts @@ -0,0 +1,70 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaType as $Schema } from "./schema"; +import type { FindManyArgs as $FindManyArgs, FindUniqueArgs as $FindUniqueArgs, FindFirstArgs as $FindFirstArgs, CreateArgs as $CreateArgs, CreateManyArgs as $CreateManyArgs, CreateManyAndReturnArgs as $CreateManyAndReturnArgs, UpdateArgs as $UpdateArgs, UpdateManyArgs as $UpdateManyArgs, UpdateManyAndReturnArgs as $UpdateManyAndReturnArgs, UpsertArgs as $UpsertArgs, DeleteArgs as $DeleteArgs, DeleteManyArgs as $DeleteManyArgs, CountArgs as $CountArgs, AggregateArgs as $AggregateArgs, GroupByArgs as $GroupByArgs, WhereInput as $WhereInput, SelectInput as $SelectInput, IncludeInput as $IncludeInput, OmitInput as $OmitInput } from "@zenstackhq/runtime"; +import type { SimplifiedModelResult as $SimplifiedModelResult, SelectIncludeOmit as $SelectIncludeOmit } from "@zenstackhq/runtime"; +export type UserFindManyArgs = $FindManyArgs<$Schema, "User">; +export type UserFindUniqueArgs = $FindUniqueArgs<$Schema, "User">; +export type UserFindFirstArgs = $FindFirstArgs<$Schema, "User">; +export type UserCreateArgs = $CreateArgs<$Schema, "User">; +export type UserCreateManyArgs = $CreateManyArgs<$Schema, "User">; +export type UserCreateManyAndReturnArgs = $CreateManyAndReturnArgs<$Schema, "User">; +export type UserUpdateArgs = $UpdateArgs<$Schema, "User">; +export type UserUpdateManyArgs = $UpdateManyArgs<$Schema, "User">; +export type UserUpdateManyAndReturnArgs = $UpdateManyAndReturnArgs<$Schema, "User">; +export type UserUpsertArgs = $UpsertArgs<$Schema, "User">; +export type UserDeleteArgs = $DeleteArgs<$Schema, "User">; +export type UserDeleteManyArgs = $DeleteManyArgs<$Schema, "User">; +export type UserCountArgs = $CountArgs<$Schema, "User">; +export type UserAggregateArgs = $AggregateArgs<$Schema, "User">; +export type UserGroupByArgs = $GroupByArgs<$Schema, "User">; +export type UserWhereInput = $WhereInput<$Schema, "User">; +export type UserSelect = $SelectInput<$Schema, "User">; +export type UserInclude = $IncludeInput<$Schema, "User">; +export type UserOmit = $OmitInput<$Schema, "User">; +export type UserGetPayload> = $SimplifiedModelResult<$Schema, "User", Args>; +export type PetFindManyArgs = $FindManyArgs<$Schema, "Pet">; +export type PetFindUniqueArgs = $FindUniqueArgs<$Schema, "Pet">; +export type PetFindFirstArgs = $FindFirstArgs<$Schema, "Pet">; +export type PetCreateArgs = $CreateArgs<$Schema, "Pet">; +export type PetCreateManyArgs = $CreateManyArgs<$Schema, "Pet">; +export type PetCreateManyAndReturnArgs = $CreateManyAndReturnArgs<$Schema, "Pet">; +export type PetUpdateArgs = $UpdateArgs<$Schema, "Pet">; +export type PetUpdateManyArgs = $UpdateManyArgs<$Schema, "Pet">; +export type PetUpdateManyAndReturnArgs = $UpdateManyAndReturnArgs<$Schema, "Pet">; +export type PetUpsertArgs = $UpsertArgs<$Schema, "Pet">; +export type PetDeleteArgs = $DeleteArgs<$Schema, "Pet">; +export type PetDeleteManyArgs = $DeleteManyArgs<$Schema, "Pet">; +export type PetCountArgs = $CountArgs<$Schema, "Pet">; +export type PetAggregateArgs = $AggregateArgs<$Schema, "Pet">; +export type PetGroupByArgs = $GroupByArgs<$Schema, "Pet">; +export type PetWhereInput = $WhereInput<$Schema, "Pet">; +export type PetSelect = $SelectInput<$Schema, "Pet">; +export type PetInclude = $IncludeInput<$Schema, "Pet">; +export type PetOmit = $OmitInput<$Schema, "Pet">; +export type PetGetPayload> = $SimplifiedModelResult<$Schema, "Pet", Args>; +export type OrderFindManyArgs = $FindManyArgs<$Schema, "Order">; +export type OrderFindUniqueArgs = $FindUniqueArgs<$Schema, "Order">; +export type OrderFindFirstArgs = $FindFirstArgs<$Schema, "Order">; +export type OrderCreateArgs = $CreateArgs<$Schema, "Order">; +export type OrderCreateManyArgs = $CreateManyArgs<$Schema, "Order">; +export type OrderCreateManyAndReturnArgs = $CreateManyAndReturnArgs<$Schema, "Order">; +export type OrderUpdateArgs = $UpdateArgs<$Schema, "Order">; +export type OrderUpdateManyArgs = $UpdateManyArgs<$Schema, "Order">; +export type OrderUpdateManyAndReturnArgs = $UpdateManyAndReturnArgs<$Schema, "Order">; +export type OrderUpsertArgs = $UpsertArgs<$Schema, "Order">; +export type OrderDeleteArgs = $DeleteArgs<$Schema, "Order">; +export type OrderDeleteManyArgs = $DeleteManyArgs<$Schema, "Order">; +export type OrderCountArgs = $CountArgs<$Schema, "Order">; +export type OrderAggregateArgs = $AggregateArgs<$Schema, "Order">; +export type OrderGroupByArgs = $GroupByArgs<$Schema, "Order">; +export type OrderWhereInput = $WhereInput<$Schema, "Order">; +export type OrderSelect = $SelectInput<$Schema, "Order">; +export type OrderInclude = $IncludeInput<$Schema, "Order">; +export type OrderOmit = $OmitInput<$Schema, "Order">; +export type OrderGetPayload> = $SimplifiedModelResult<$Schema, "Order", Args>; diff --git a/packages/runtime/test/schemas/petstore/models.ts b/packages/runtime/test/schemas/petstore/models.ts new file mode 100644 index 00000000..6526b66c --- /dev/null +++ b/packages/runtime/test/schemas/petstore/models.ts @@ -0,0 +1,12 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaType as $Schema } from "./schema"; +import { type ModelResult as $ModelResult } from "@zenstackhq/runtime"; +export type User = $ModelResult<$Schema, "User">; +export type Pet = $ModelResult<$Schema, "Pet">; +export type Order = $ModelResult<$Schema, "Order">; diff --git a/packages/runtime/test/schemas/petstore/schema.ts b/packages/runtime/test/schemas/petstore/schema.ts new file mode 100644 index 00000000..c6902c7e --- /dev/null +++ b/packages/runtime/test/schemas/petstore/schema.ts @@ -0,0 +1,156 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaDef, ExpressionUtils } from "../../../dist/schema"; +export const schema = { + provider: { + type: "sqlite" + }, + models: { + User: { + name: "User", + fields: { + id: { + name: "id", + type: "String", + id: true, + attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], + default: ExpressionUtils.call("cuid") + }, + email: { + name: "email", + type: "String", + unique: true, + attributes: [{ name: "@unique" }] + }, + orders: { + name: "orders", + type: "Order", + array: true, + relation: { opposite: "user" } + } + }, + attributes: [ + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("create") }, { name: "condition", value: ExpressionUtils.literal(true) }] }, + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("read") }, { name: "condition", value: ExpressionUtils.literal(true) }] } + ], + idFields: ["id"], + uniqueFields: { + id: { type: "String" }, + email: { type: "String" } + } + }, + Pet: { + name: "Pet", + fields: { + id: { + name: "id", + type: "String", + id: true, + attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], + default: ExpressionUtils.call("cuid") + }, + createdAt: { + name: "createdAt", + type: "DateTime", + attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], + default: ExpressionUtils.call("now") + }, + updatedAt: { + name: "updatedAt", + type: "DateTime", + updatedAt: true, + attributes: [{ name: "@updatedAt" }] + }, + name: { + name: "name", + type: "String" + }, + category: { + name: "category", + type: "String" + }, + order: { + name: "order", + type: "Order", + optional: true, + attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array([ExpressionUtils.field("orderId")]) }, { name: "references", value: ExpressionUtils.array([ExpressionUtils.field("id")]) }] }], + relation: { opposite: "pets", fields: ["orderId"], references: ["id"] } + }, + orderId: { + name: "orderId", + type: "String", + optional: true, + foreignKeyFor: [ + "order" + ] + } + }, + attributes: [ + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("read") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.field("orderId"), "==", ExpressionUtils._null()), "||", ExpressionUtils.binary(ExpressionUtils.member(ExpressionUtils.field("order"), ["user"]), "==", ExpressionUtils.call("auth"))) }] }, + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("update") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.field("name"), "==", ExpressionUtils.member(ExpressionUtils.call("future"), ["name"])), "&&", ExpressionUtils.binary(ExpressionUtils.field("category"), "==", ExpressionUtils.member(ExpressionUtils.call("future"), ["category"]))), "&&", ExpressionUtils.binary(ExpressionUtils.field("orderId"), "==", ExpressionUtils._null())) }] } + ], + idFields: ["id"], + uniqueFields: { + id: { type: "String" } + } + }, + Order: { + name: "Order", + fields: { + id: { + name: "id", + type: "String", + id: true, + attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], + default: ExpressionUtils.call("cuid") + }, + createdAt: { + name: "createdAt", + type: "DateTime", + attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], + default: ExpressionUtils.call("now") + }, + updatedAt: { + name: "updatedAt", + type: "DateTime", + updatedAt: true, + attributes: [{ name: "@updatedAt" }] + }, + pets: { + name: "pets", + type: "Pet", + array: true, + relation: { opposite: "order" } + }, + user: { + name: "user", + type: "User", + attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array([ExpressionUtils.field("userId")]) }, { name: "references", value: ExpressionUtils.array([ExpressionUtils.field("id")]) }] }], + relation: { opposite: "orders", fields: ["userId"], references: ["id"] } + }, + userId: { + name: "userId", + type: "String", + foreignKeyFor: [ + "user" + ] + } + }, + attributes: [ + { name: "@@allow", args: [{ name: "operation", value: ExpressionUtils.literal("read,create") }, { name: "condition", value: ExpressionUtils.binary(ExpressionUtils.call("auth"), "==", ExpressionUtils.field("user")) }] } + ], + idFields: ["id"], + uniqueFields: { + id: { type: "String" } + } + } + }, + authType: "User", + plugins: {} +} as const satisfies SchemaDef; +export type SchemaType = typeof schema; diff --git a/packages/runtime/test/schemas/petstore/schema.zmodel b/packages/runtime/test/schemas/petstore/schema.zmodel new file mode 100644 index 00000000..4a2442ca --- /dev/null +++ b/packages/runtime/test/schemas/petstore/schema.zmodel @@ -0,0 +1,52 @@ +datasource db { + provider = 'sqlite' + url = 'file:./petstore.db' +} + +generator js { + provider = 'prisma-client-js' +} + +plugin zod { + provider = '@core/zod' +} + +model User { + id String @id @default(cuid()) + email String @unique + orders Order[] + + // everybody can signup + @@allow('create', true) + + // user profile is publicly readable + @@allow('read', true) +} + +model Pet { + id String @id @default(cuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + name String + category String + order Order? @relation(fields: [orderId], references: [id]) + orderId String? + + // unsold pets are readable to all; sold ones are readable to buyers only + @@allow('read', orderId == null || order.user == auth()) + + // only allow update to 'orderId' field if it's not set yet (unsold) + @@allow('update', name == future().name && category == future().category && orderId == null ) +} + +model Order { + id String @id @default(cuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + pets Pet[] + user User @relation(fields: [userId], references: [id]) + userId String + + // users can read their orders + @@allow('read,create', auth() == user) +} diff --git a/packages/runtime/test/schemas/typing/schema.ts b/packages/runtime/test/schemas/typing/schema.ts index 90a532e0..18270ceb 100644 --- a/packages/runtime/test/schemas/typing/schema.ts +++ b/packages/runtime/test/schemas/typing/schema.ts @@ -86,7 +86,7 @@ export const schema = { }, computedFields: { postCount(_context: { - currentModel: string; + modelAlias: string; }): OperandExpression { throw new Error("This is a stub for computed field"); } diff --git a/packages/runtime/test/utils.ts b/packages/runtime/test/utils.ts index 64484593..b7245062 100644 --- a/packages/runtime/test/utils.ts +++ b/packages/runtime/test/utils.ts @@ -1,30 +1,26 @@ import { invariant } from '@zenstackhq/common-helpers'; import { loadDocument } from '@zenstackhq/language'; +import type { Model } from '@zenstackhq/language/ast'; import { PrismaSchemaGenerator } from '@zenstackhq/sdk'; -import { createTestProject, generateTsSchema } from '@zenstackhq/testtools'; +import { createTestProject, generateTsSchema, getPluginModules } from '@zenstackhq/testtools'; import SQLite from 'better-sqlite3'; import { PostgresDialect, SqliteDialect, type LogEvent } from 'kysely'; import { execSync } from 'node:child_process'; +import { createHash } from 'node:crypto'; import fs from 'node:fs'; import path from 'node:path'; import { Client as PGClient, Pool } from 'pg'; +import { expect } from 'vitest'; import type { ClientContract, ClientOptions } from '../src/client'; import { ZenStackClient } from '../src/client'; import type { SchemaDef } from '../src/schema'; -type SqliteSchema = SchemaDef & { provider: { type: 'sqlite' } }; -type PostgresSchema = SchemaDef & { provider: { type: 'postgresql' } }; - -export async function makeSqliteClient( - schema: Schema, - extraOptions?: Partial>, -): Promise> { - const client = new ZenStackClient(schema, { - ...extraOptions, - dialect: new SqliteDialect({ database: new SQLite(':memory:') }), - } as unknown as ClientOptions); - await client.$pushSchema(); - return client; +export function getTestDbProvider() { + const val = process.env['TEST_DB_PROVIDER'] ?? 'sqlite'; + if (!['sqlite', 'postgresql'].includes(val!)) { + throw new Error(`Invalid TEST_DB_PROVIDER value: ${val}`); + } + return val as 'sqlite' | 'postgresql'; } const TEST_PG_CONFIG = { @@ -34,30 +30,6 @@ const TEST_PG_CONFIG = { password: process.env['TEST_PG_PASSWORD'] ?? 'postgres', }; -export async function makePostgresClient( - schema: Schema, - dbName: string, - extraOptions?: Partial>, -): Promise> { - invariant(dbName, 'dbName is required'); - const pgClient = new PGClient(TEST_PG_CONFIG); - await pgClient.connect(); - await pgClient.query(`DROP DATABASE IF EXISTS "${dbName}"`); - await pgClient.query(`CREATE DATABASE "${dbName}"`); - - const client = new ZenStackClient(schema, { - ...extraOptions, - dialect: new PostgresDialect({ - pool: new Pool({ - ...TEST_PG_CONFIG, - database: dbName, - }), - }), - } as unknown as ClientOptions); - await client.$pushSchema(); - return client; -} - export type CreateTestClientOptions = Omit, 'dialect'> & { provider?: 'sqlite' | 'postgresql'; dbName?: string; @@ -82,25 +54,22 @@ export async function createTestClient( ): Promise { let workDir = options?.workDir; let _schema: Schema; - const provider = options?.provider ?? 'sqlite'; - - let dbName = options?.dbName; - if (!dbName) { - if (provider === 'sqlite') { - dbName = './test.db'; - } else { - throw new Error(`dbName is required for ${provider} provider`); - } - } + const provider = options?.provider ?? getTestDbProvider() ?? 'sqlite'; + + const dbName = options?.dbName ?? getTestDbName(provider); + console.log(`Using provider: ${provider}, db: ${dbName}`); const dbUrl = provider === 'sqlite' ? `file:${dbName}` : `postgres://${TEST_PG_CONFIG.user}:${TEST_PG_CONFIG.password}@${TEST_PG_CONFIG.host}:${TEST_PG_CONFIG.port}/${dbName}`; + let model: Model | undefined; + if (typeof schema === 'string') { const generated = await generateTsSchema(schema, provider, dbUrl, options?.extraSourceFiles); workDir = generated.workDir; + model = generated.model; // replace schema's provider _schema = { ...generated.schema, @@ -143,16 +112,19 @@ export async function createTestClient( if (options?.usePrismaPush) { invariant(typeof schema === 'string' || schemaFile, 'a schema file must be provided when using prisma db push'); - const r = await loadDocument(path.resolve(workDir!, 'schema.zmodel')); - if (!r.success) { - throw new Error(r.errors.join('\n')); + if (!model) { + const r = await loadDocument(path.join(workDir, 'schema.zmodel'), getPluginModules()); + if (!r.success) { + throw new Error(r.errors.join('\n')); + } + model = r.model; } - const prismaSchema = new PrismaSchemaGenerator(r.model); + const prismaSchema = new PrismaSchemaGenerator(model); const prismaSchemaText = await prismaSchema.generate(); fs.writeFileSync(path.resolve(workDir!, 'schema.prisma'), prismaSchemaText); execSync('npx prisma db push --schema ./schema.prisma --skip-generate --force-reset', { cwd: workDir, - stdio: 'inherit', + stdio: 'ignore', }); } else { if (provider === 'postgresql') { @@ -196,3 +168,26 @@ export async function createTestClient( export function testLogger(e: LogEvent) { console.log(e.query.sql, e.query.parameters); } + +function getTestDbName(provider: string) { + if (provider === 'sqlite') { + return './test.db'; + } + const testName = expect.getState().currentTestName; + const testPath = expect.getState().testPath ?? ''; + invariant(testName); + // digest test name + const digest = createHash('md5') + .update(testName + testPath) + .digest('hex'); + // compute a database name based on test name + return ( + 'test_' + + testName + .toLowerCase() + .replace(/[^a-z0-9_]/g, '_') + .replace(/_+/g, '_') + .substring(0, 30) + + digest.slice(0, 6) + ); +} diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 01e8af3b..e72757cc 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "description": "ZenStack SDK", "type": "module", "scripts": { diff --git a/packages/sdk/src/schema/schema.ts b/packages/sdk/src/schema/schema.ts index c6ea4d9b..e8beefc9 100644 --- a/packages/sdk/src/schema/schema.ts +++ b/packages/sdk/src/schema/schema.ts @@ -33,6 +33,7 @@ export type ModelDef = { computedFields?: Record; isDelegate?: boolean; subModels?: string[]; + isView?: boolean; }; export type AttributeApplication = { diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index 1d558300..75c0f44a 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -310,6 +310,8 @@ export class TsSchemaGenerator { ), ] : []), + + ...(dm.isView ? [ts.factory.createPropertyAssignment('isView', ts.factory.createTrue())] : []), ]; const computedFields = dm.fields.filter((f) => hasAttribute(f, '@computed')); @@ -376,7 +378,7 @@ export class TsSchemaGenerator { undefined, undefined, [ - // parameter: `context: { currentModel: string }` + // parameter: `context: { modelAlias: string }` ts.factory.createParameterDeclaration( undefined, undefined, @@ -385,7 +387,7 @@ export class TsSchemaGenerator { ts.factory.createTypeLiteralNode([ ts.factory.createPropertySignature( undefined, - 'currentModel', + 'modelAlias', undefined, ts.factory.createKeywordTypeNode(ts.SyntaxKind.StringKeyword), ), diff --git a/packages/tanstack-query/package.json b/packages/tanstack-query/package.json index 86a9f6ae..4dd4379e 100644 --- a/packages/tanstack-query/package.json +++ b/packages/tanstack-query/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/tanstack-query", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "description": "", "main": "index.js", "type": "module", diff --git a/packages/testtools/package.json b/packages/testtools/package.json index 81e91b7c..683b6652 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "description": "ZenStack Test Tools", "type": "module", "scripts": { diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 788f092c..b4f5386e 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -41,7 +41,7 @@ export async function generateTsSchema( const noPrelude = schemaText.includes('datasource '); fs.writeFileSync(zmodelPath, `${noPrelude ? '' : makePrelude(provider, dbUrl)}\n\n${schemaText}`); - const pluginModelFiles = glob.sync(path.resolve(__dirname, '../../runtime/src/plugins/**/plugin.zmodel')); + const pluginModelFiles = getPluginModules(); const result = await loadDocument(zmodelPath, pluginModelFiles); if (!result.success) { throw new Error(`Failed to load schema from ${zmodelPath}: ${result.errors}`); @@ -59,7 +59,11 @@ export async function generateTsSchema( } // compile the generated TS schema - return compileAndLoad(workDir); + return { ...(await compileAndLoad(workDir)), model: result.model }; +} + +export function getPluginModules() { + return glob.sync(path.resolve(__dirname, '../../runtime/src/plugins/**/plugin.zmodel')); } async function compileAndLoad(workDir: string) { diff --git a/packages/typescript-config/package.json b/packages/typescript-config/package.json index 75109c2e..833b3d0d 100644 --- a/packages/typescript-config/package.json +++ b/packages/typescript-config/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/typescript-config", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "private": true, "license": "MIT" } diff --git a/packages/vitest-config/package.json b/packages/vitest-config/package.json index 878e8fbd..a7a2d8c5 100644 --- a/packages/vitest-config/package.json +++ b/packages/vitest-config/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/vitest-config", "type": "module", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "private": true, "license": "MIT", "exports": { diff --git a/packages/zod/package.json b/packages/zod/package.json index 7bc82864..350be5d2 100644 --- a/packages/zod/package.json +++ b/packages/zod/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/zod", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "description": "", "type": "module", "main": "index.js", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c339a56b..a6638dc1 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -33,6 +33,9 @@ catalogs: typescript: specifier: ^5.8.0 version: 5.8.3 + zod-validation-error: + specifier: ^4.0.1 + version: 4.0.1 importers: @@ -299,6 +302,9 @@ importers: uuid: specifier: ^11.0.5 version: 11.0.5 + zod-validation-error: + specifier: 'catalog:' + version: 4.0.1(zod@3.25.76) devDependencies: '@types/better-sqlite3': specifier: ^7.6.13 @@ -2711,6 +2717,12 @@ packages: resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==} engines: {node: '>=10'} + zod-validation-error@4.0.1: + resolution: {integrity: sha512-F3rdaCOHs5ViJ5YTz5zzRtfkQdMdIeKudJAoxy7yB/2ZMEHw73lmCAcQw11r7++20MyGl4WV59EVh7A9rNAyog==} + engines: {node: '>=18.0.0'} + peerDependencies: + zod: ^3.25.0 || ^4.0.0 + zod@3.25.76: resolution: {integrity: sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==} @@ -4791,4 +4803,8 @@ snapshots: yocto-queue@0.1.0: {} + zod-validation-error@4.0.1(zod@3.25.76): + dependencies: + zod: 3.25.76 + zod@3.25.76: {} diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 11ea3a73..d54970c6 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -13,3 +13,4 @@ catalog: '@types/node': ^20.17.24 tmp: ^0.2.3 '@types/tmp': ^0.2.6 + 'zod-validation-error': ^4.0.1 diff --git a/samples/blog/main.ts b/samples/blog/main.ts index dfaa6c04..8bbfb5bf 100644 --- a/samples/blog/main.ts +++ b/samples/blog/main.ts @@ -8,10 +8,10 @@ async function main() { dialect: new SqliteDialect({ database: new SQLite('./zenstack/dev.db') }), computedFields: { User: { - postCount: (eb, { currentModel }) => + postCount: (eb, { modelAlias }) => eb .selectFrom('Post') - .whereRef('Post.authorId', '=', sql.ref(`${currentModel}.id`)) + .whereRef('Post.authorId', '=', sql.ref(`${modelAlias}.id`)) .select(({ fn }) => fn.countAll().as('postCount')), }, }, diff --git a/samples/blog/package.json b/samples/blog/package.json index a30dd08e..8f60eb7f 100644 --- a/samples/blog/package.json +++ b/samples/blog/package.json @@ -1,6 +1,6 @@ { "name": "sample-blog", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "description": "", "main": "index.js", "scripts": { diff --git a/samples/blog/zenstack/schema.ts b/samples/blog/zenstack/schema.ts index 95f2e4a8..4ca14e3e 100644 --- a/samples/blog/zenstack/schema.ts +++ b/samples/blog/zenstack/schema.ts @@ -76,7 +76,7 @@ export const schema = { }, computedFields: { postCount(_context: { - currentModel: string; + modelAlias: string; }): OperandExpression { throw new Error("This is a stub for computed field"); } diff --git a/tests/e2e/package.json b/tests/e2e/package.json index e3129a69..a7184260 100644 --- a/tests/e2e/package.json +++ b/tests/e2e/package.json @@ -1,6 +1,6 @@ { "name": "e2e", - "version": "3.0.0-beta.4", + "version": "3.0.0-beta.5", "private": true, "type": "module", "scripts": { diff --git a/tests/regression/package.json b/tests/regression/package.json index 1d54ca4f..c64fd64b 100644 --- a/tests/regression/package.json +++ b/tests/regression/package.json @@ -1,6 +1,6 @@ { "name": "regression", - "version": "3.0.0-beta.3", + "version": "3.0.0-beta.5", "private": true, "type": "module", "scripts": {