diff --git a/.eslintrc.json b/.eslintrc.json index e04b04831..707715244 100644 --- a/.eslintrc.json +++ b/.eslintrc.json @@ -13,7 +13,7 @@ "plugin:jest/recommended" ], "rules": { - "jest/expect-expect": "off", - "@typescript-eslint/no-unused-vars": ["error", { "varsIgnorePattern": "^_", "argsIgnorePattern": "^_" }] + "@typescript-eslint/no-unused-vars": ["error", { "varsIgnorePattern": "^_", "argsIgnorePattern": "^_" }], + "jest/expect-expect": "off" } } diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index f92512bc8..cbdbce76c 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -9,7 +9,7 @@ env: on: pull_request: - branches: ['dev', 'main'] + branches: ['dev', 'main', 'v2'] jobs: build-test: @@ -32,18 +32,11 @@ jobs: strategy: matrix: node-version: [20.x] - prisma-version: [v4, v5] steps: - name: Checkout uses: actions/checkout@v3 - - name: Set Prisma Version - if: ${{ matrix.prisma-version == 'v5' }} - shell: bash - run: | - bash ./script/test-prisma-v5.sh - - name: Install pnpm uses: pnpm/action-setup@v2 with: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f1f733983..1eed5723d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -63,11 +63,11 @@ The ZModel language's definition, including its syntax definition and parser/lin ### `schema` -The `zenstack` CLI and ZModel VSCode extension implementation. The package also contains several built-in plugins: `@core/prisma`, `@core/model-meta`, `@core/access-policy`, and `core/zod`. +The `zenstack` CLI and ZModel VSCode extension implementation. The package also contains several built-in plugins: `@core/prisma`, `@core/enhancer`, and `core/zod`. ### `runtime` -Runtime enhancements to PrismaClient, including infrastructure for creating transparent proxies and concrete implementations for the `withPolicy`, `withPassword`, and `withOmit` proxies. +Runtime enhancements to PrismaClient, including infrastructure for creating transparent proxies and concrete implementations of various proxies. ### `server` diff --git a/package.json b/package.json index 16b18ea0a..fd0c2763d 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-monorepo", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "description": "", "scripts": { "build": "pnpm -r build", @@ -9,7 +9,10 @@ "test-ci": "ZENSTACK_TEST=1 pnpm -r run test --silent --forceExit", "publish-all": "pnpm --filter \"./packages/**\" -r publish --access public", "publish-preview": "pnpm --filter \"./packages/**\" -r publish --force --registry https://preview.registry.zenstack.dev/", - "unpublish-preview": "pnpm --recursive --shell-mode exec -- npm unpublish -f --registry https://preview.registry.zenstack.dev/ \"\\$PNPM_PACKAGE_NAME\"" + "unpublish-preview": "pnpm --recursive --shell-mode exec -- npm unpublish -f --registry https://preview.registry.zenstack.dev/ \"\\$PNPM_PACKAGE_NAME\"", + "publish-next": "pnpm --filter \"./packages/**\" -r publish --access public --tag next", + "publish-preview-next": "pnpm --filter \"./packages/**\" -r publish --force --registry https://preview.registry.zenstack.dev/ --tag next", + "unpublish-preview-next": "pnpm --recursive --shell-mode exec -- npm unpublish -f --registry https://preview.registry.zenstack.dev/ --tag next \"\\$PNPM_PACKAGE_NAME\"" }, "keywords": [], "author": "", diff --git a/packages/README.md b/packages/README.md index 2104ff0eb..d4f076584 100644 --- a/packages/README.md +++ b/packages/README.md @@ -51,12 +51,12 @@ At runtime, transparent proxies are created around Prisma clients for intercepti // Next.js example: pages/api/model/[...path].ts import { requestHandler } from '@zenstackhq/next'; -import { withPolicy } from '@zenstackhq/runtime'; +import { enhance } from '@zenstackhq/runtime'; import { getSessionUser } from '@lib/auth'; import { prisma } from '@lib/db'; export default requestHandler({ - getPrisma: (req, res) => withPolicy(prisma, { user: getSessionUser(req, res) }), + getPrisma: (req, res) => enhance(prisma, { user: getSessionUser(req, res) }), }); ``` diff --git a/packages/ide/jetbrains/build.gradle.kts b/packages/ide/jetbrains/build.gradle.kts index 2643f4e2a..4c8825117 100644 --- a/packages/ide/jetbrains/build.gradle.kts +++ b/packages/ide/jetbrains/build.gradle.kts @@ -9,7 +9,7 @@ plugins { } group = "dev.zenstack" -version = "1.8.2" +version = "2.0.0-alpha.1" repositories { mavenCentral() diff --git a/packages/ide/jetbrains/package.json b/packages/ide/jetbrains/package.json index ca7e7c57e..4e7fc26df 100644 --- a/packages/ide/jetbrains/package.json +++ b/packages/ide/jetbrains/package.json @@ -1,12 +1,12 @@ { "name": "jetbrains", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "displayName": "ZenStack JetBrains IDE Plugin", "description": "ZenStack JetBrains IDE plugin", "homepage": "https://zenstack.dev", "private": true, "scripts": { - "build": "./gradlew buildPlugin" + "build": "./gradlew buildPlugin" }, "author": "ZenStack Team", "license": "MIT", diff --git a/packages/language/package.json b/packages/language/package.json index ff2cfd5ec..a222a6d95 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/language", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "displayName": "ZenStack modeling language compiler", "description": "ZenStack modeling language compiler", "homepage": "https://zenstack.dev", @@ -19,11 +19,11 @@ "author": "ZenStack Team", "license": "MIT", "devDependencies": { - "langium-cli": "1.2.0", + "langium-cli": "1.3.1", "plist2": "^1.1.3" }, "dependencies": { - "langium": "1.2.0" + "langium": "1.3.1" }, "contributes": { "languages": [ diff --git a/packages/language/src/ast.ts b/packages/language/src/ast.ts index c8637115a..3da706a75 100644 --- a/packages/language/src/ast.ts +++ b/packages/language/src/ast.ts @@ -1,7 +1,8 @@ -import { AbstractDeclaration, ExpressionType, BinaryExpr } from './generated/ast'; +import { AstNode } from 'langium'; +import { AbstractDeclaration, BinaryExpr, DataModel, ExpressionType } from './generated/ast'; -export * from './generated/ast'; export { AstNode, Reference } from 'langium'; +export * from './generated/ast'; /** * Shape of type resolution result: an expression type or reference to a declaration @@ -44,16 +45,28 @@ declare module './generated/ast' { $resolvedParam?: AttributeParam; } - interface DataModel { + interface DataModelField { + $inheritedFrom?: DataModel; + } + + interface DataModelAttribute { + $inheritedFrom?: DataModel; + } + + export interface DataModel { /** - * Resolved fields, include inherited fields + * Indicates whether the model is already merged with the base types */ - $resolvedFields: Array; + $baseMerged?: boolean; } +} - interface DataModelField { - $isInherited?: boolean; - } +export interface InheritableNode extends AstNode { + $inheritedFrom?: DataModel; +} + +export interface InheritableNode extends AstNode { + $inheritedFrom?: DataModel; } declare module 'langium' { diff --git a/packages/language/src/generated/ast.ts b/packages/language/src/generated/ast.ts index 7463fb9da..a95a748d9 100644 --- a/packages/language/src/generated/ast.ts +++ b/packages/language/src/generated/ast.ts @@ -1,10 +1,24 @@ /****************************************************************************** - * This file was generated by langium-cli 1.2.0. + * This file was generated by langium-cli 1.3.1. * DO NOT EDIT MANUALLY! ******************************************************************************/ /* eslint-disable */ -import { AstNode, AbstractAstReflection, Reference, ReferenceInfo, TypeMetaData } from 'langium'; +import type { AstNode, Reference, ReferenceInfo, TypeMetaData } from 'langium'; +import { AbstractAstReflection } from 'langium'; + +export const ZModelTerminals = { + WS: /\s+/, + INTERNAL_ATTRIBUTE_NAME: /@@@([_a-zA-Z][\w_]*\.)*[_a-zA-Z][\w_]*/, + MODEL_ATTRIBUTE_NAME: /@@([_a-zA-Z][\w_]*\.)*[_a-zA-Z][\w_]*/, + FIELD_ATTRIBUTE_NAME: /@([_a-zA-Z][\w_]*\.)*[_a-zA-Z][\w_]*/, + ID: /[_a-zA-Z][\w_]*/, + STRING: /"(\\.|[^"\\])*"|'(\\.|[^'\\])*'/, + NUMBER: /[+-]?[0-9]+(\.[0-9]+)?/, + TRIPLE_SLASH_COMMENT: /\/\/\/[^\n\r]*/, + ML_COMMENT: /\/\*[\s\S]*?\*\//, + SL_COMMENT: /\/\/[^\n\r]*/, +}; export type AbstractDeclaration = Attribute | DataModel | DataSource | Enum | FunctionDecl | GeneratorDecl | Plugin; @@ -64,10 +78,10 @@ export function isReferenceTarget(item: unknown): item is ReferenceTarget { return reflection.isInstance(item, ReferenceTarget); } -export type RegularID = 'abstract' | 'attribute' | 'datasource' | 'enum' | 'import' | 'in' | 'model' | 'plugin' | 'sort' | 'view' | string; +export type RegularID = 'abstract' | 'attribute' | 'datasource' | 'enum' | 'import' | 'in' | 'model' | 'plugin' | 'view' | string; export function isRegularID(item: unknown): item is RegularID { - return item === 'model' || item === 'enum' || item === 'attribute' || item === 'datasource' || item === 'plugin' || item === 'abstract' || item === 'in' || item === 'sort' || item === 'view' || item === 'import' || (typeof item === 'string' && (/[_a-zA-Z][\w_]*/.test(item))); + return item === 'model' || item === 'enum' || item === 'attribute' || item === 'datasource' || item === 'plugin' || item === 'abstract' || item === 'in' || item === 'view' || item === 'import' || (typeof item === 'string' && (/[_a-zA-Z][\w_]*/.test(item))); } export type TypeDeclaration = DataModel | Enum; @@ -81,7 +95,6 @@ export function isTypeDeclaration(item: unknown): item is TypeDeclaration { export interface Argument extends AstNode { readonly $container: InvocationExpr; readonly $type: 'Argument'; - name?: RegularID value: Expression } @@ -92,7 +105,7 @@ export function isArgument(item: unknown): item is Argument { } export interface ArrayExpr extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'ArrayExpr'; items: Array } @@ -163,7 +176,7 @@ export function isAttributeParamType(item: unknown): item is AttributeParamType } export interface BinaryExpr extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'BinaryExpr'; left: Expression operator: '!' | '!=' | '&&' | '<' | '<=' | '==' | '>' | '>=' | '?' | '^' | 'in' | '||' @@ -177,7 +190,7 @@ export function isBinaryExpr(item: unknown): item is BinaryExpr { } export interface BooleanLiteral extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'BooleanLiteral'; value: Boolean } @@ -189,7 +202,7 @@ export function isBooleanLiteral(item: unknown): item is BooleanLiteral { } export interface ConfigArrayExpr extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'ConfigArrayExpr'; items: Array } @@ -440,7 +453,7 @@ export function isInternalAttribute(item: unknown): item is InternalAttribute { } export interface InvocationExpr extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'InvocationExpr'; args: Array function: Reference @@ -453,7 +466,7 @@ export function isInvocationExpr(item: unknown): item is InvocationExpr { } export interface MemberAccessExpr extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'MemberAccessExpr'; member: Reference operand: Expression @@ -490,7 +503,7 @@ export function isModelImport(item: unknown): item is ModelImport { } export interface NullExpr extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'NullExpr'; value: 'null' } @@ -502,7 +515,7 @@ export function isNullExpr(item: unknown): item is NullExpr { } export interface NumberLiteral extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'NumberLiteral'; value: string } @@ -514,7 +527,7 @@ export function isNumberLiteral(item: unknown): item is NumberLiteral { } export interface ObjectExpr extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'ObjectExpr'; fields: Array } @@ -554,8 +567,8 @@ export function isPluginField(item: unknown): item is PluginField { export interface ReferenceArg extends AstNode { readonly $container: ReferenceExpr; readonly $type: 'ReferenceArg'; - name: 'sort' - value: 'Asc' | 'Desc' + name: string + value: Expression } export const ReferenceArg = 'ReferenceArg'; @@ -565,7 +578,7 @@ export function isReferenceArg(item: unknown): item is ReferenceArg { } export interface ReferenceExpr extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'ReferenceExpr'; args: Array target: Reference @@ -578,7 +591,7 @@ export function isReferenceExpr(item: unknown): item is ReferenceExpr { } export interface StringLiteral extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'StringLiteral'; value: string } @@ -590,7 +603,7 @@ export function isStringLiteral(item: unknown): item is StringLiteral { } export interface ThisExpr extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'ThisExpr'; value: 'this' } @@ -602,7 +615,7 @@ export function isThisExpr(item: unknown): item is ThisExpr { } export interface UnaryExpr extends AstNode { - readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | UnaryExpr | UnsupportedFieldType; + readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'UnaryExpr'; operand: Expression operator: '!' diff --git a/packages/language/src/generated/grammar.ts b/packages/language/src/generated/grammar.ts index 5dbe02014..45aa3ff97 100644 --- a/packages/language/src/generated/grammar.ts +++ b/packages/language/src/generated/grammar.ts @@ -1,9 +1,10 @@ /****************************************************************************** - * This file was generated by langium-cli 1.2.0. + * This file was generated by langium-cli 1.3.1. * DO NOT EDIT MANUALLY! ******************************************************************************/ -import { loadGrammarFromJson, Grammar } from 'langium'; +import type { Grammar } from 'langium'; +import { loadGrammarFromJson } from 'langium'; let loadedZModelGrammar: Grammar | undefined; export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModelGrammar = loadGrammarFromJson(`{ @@ -1052,8 +1053,11 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "feature": "name", "operator": "=", "terminal": { - "$type": "Keyword", - "value": "sort" + "$type": "RuleCall", + "rule": { + "$ref": "#/rules@62" + }, + "arguments": [] } }, { @@ -1065,17 +1069,11 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "feature": "value", "operator": "=", "terminal": { - "$type": "Alternatives", - "elements": [ - { - "$type": "Keyword", - "value": "Asc" - }, - { - "$type": "Keyword", - "value": "Desc" - } - ] + "$type": "RuleCall", + "rule": { + "$ref": "#/rules@8" + }, + "arguments": [] } } ] @@ -1865,43 +1863,16 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "$type": "ParserRule", "name": "Argument", "definition": { - "$type": "Group", - "elements": [ - { - "$type": "Group", - "elements": [ - { - "$type": "Assignment", - "feature": "name", - "operator": "=", - "terminal": { - "$type": "RuleCall", - "rule": { - "$ref": "#/rules@46" - }, - "arguments": [] - } - }, - { - "$type": "Keyword", - "value": ":" - } - ], - "cardinality": "?" + "$type": "Assignment", + "feature": "value", + "operator": "=", + "terminal": { + "$type": "RuleCall", + "rule": { + "$ref": "#/rules@8" }, - { - "$type": "Assignment", - "feature": "value", - "operator": "=", - "terminal": { - "$type": "RuleCall", - "rule": { - "$ref": "#/rules@8" - }, - "arguments": [] - } - } - ] + "arguments": [] + } }, "definesHiddenTokens": false, "entry": false, @@ -2723,10 +2694,6 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "$type": "Keyword", "value": "in" }, - { - "$type": "Keyword", - "value": "sort" - }, { "$type": "Keyword", "value": "view" @@ -3452,7 +3419,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "name": "WS", "definition": { "$type": "RegexToken", - "regex": "\\\\s+" + "regex": "/\\\\s+/" }, "fragment": false }, @@ -3461,7 +3428,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "name": "INTERNAL_ATTRIBUTE_NAME", "definition": { "$type": "RegexToken", - "regex": "@@@([_a-zA-Z][\\\\w_]*\\\\.)*[_a-zA-Z][\\\\w_]*" + "regex": "/@@@([_a-zA-Z][\\\\w_]*\\\\.)*[_a-zA-Z][\\\\w_]*/" }, "fragment": false, "hidden": false @@ -3471,7 +3438,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "name": "MODEL_ATTRIBUTE_NAME", "definition": { "$type": "RegexToken", - "regex": "@@([_a-zA-Z][\\\\w_]*\\\\.)*[_a-zA-Z][\\\\w_]*" + "regex": "/@@([_a-zA-Z][\\\\w_]*\\\\.)*[_a-zA-Z][\\\\w_]*/" }, "fragment": false, "hidden": false @@ -3481,7 +3448,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "name": "FIELD_ATTRIBUTE_NAME", "definition": { "$type": "RegexToken", - "regex": "@([_a-zA-Z][\\\\w_]*\\\\.)*[_a-zA-Z][\\\\w_]*" + "regex": "/@([_a-zA-Z][\\\\w_]*\\\\.)*[_a-zA-Z][\\\\w_]*/" }, "fragment": false, "hidden": false @@ -3491,7 +3458,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "name": "ID", "definition": { "$type": "RegexToken", - "regex": "[_a-zA-Z][\\\\w_]*" + "regex": "/[_a-zA-Z][\\\\w_]*/" }, "fragment": false, "hidden": false @@ -3501,7 +3468,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "name": "STRING", "definition": { "$type": "RegexToken", - "regex": "\\"(\\\\\\\\.|[^\\"\\\\\\\\])*\\"|'(\\\\\\\\.|[^'\\\\\\\\])*'" + "regex": "/\\"(\\\\\\\\.|[^\\"\\\\\\\\])*\\"|'(\\\\\\\\.|[^'\\\\\\\\])*'/" }, "fragment": false, "hidden": false @@ -3511,7 +3478,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "name": "NUMBER", "definition": { "$type": "RegexToken", - "regex": "[+-]?[0-9]+(\\\\.[0-9]+)?" + "regex": "/[+-]?[0-9]+(\\\\.[0-9]+)?/" }, "fragment": false, "hidden": false @@ -3521,7 +3488,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "name": "TRIPLE_SLASH_COMMENT", "definition": { "$type": "RegexToken", - "regex": "\\\\/\\\\/\\\\/[^\\\\n\\\\r]*" + "regex": "/\\\\/\\\\/\\\\/[^\\\\n\\\\r]*/" }, "fragment": false, "hidden": false @@ -3532,7 +3499,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "name": "ML_COMMENT", "definition": { "$type": "RegexToken", - "regex": "\\\\/\\\\*[\\\\s\\\\S]*?\\\\*\\\\/" + "regex": "/\\\\/\\\\*[\\\\s\\\\S]*?\\\\*\\\\//" }, "fragment": false }, @@ -3542,7 +3509,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "name": "SL_COMMENT", "definition": { "$type": "RegexToken", - "regex": "\\\\/\\\\/[^\\\\n\\\\r]*" + "regex": "/\\\\/\\\\/[^\\\\n\\\\r]*/" }, "fragment": false } diff --git a/packages/language/src/generated/module.ts b/packages/language/src/generated/module.ts index ac0995108..b96dd1dee 100644 --- a/packages/language/src/generated/module.ts +++ b/packages/language/src/generated/module.ts @@ -1,17 +1,17 @@ /****************************************************************************** - * This file was generated by langium-cli 1.2.0. + * This file was generated by langium-cli 1.3.1. * DO NOT EDIT MANUALLY! ******************************************************************************/ -import { LangiumGeneratedServices, LangiumGeneratedSharedServices, LangiumSharedServices, LangiumServices, LanguageMetaData, Module } from 'langium'; +import type { LangiumGeneratedServices, LangiumGeneratedSharedServices, LangiumSharedServices, LangiumServices, LanguageMetaData, Module } from 'langium'; import { ZModelAstReflection } from './ast'; import { ZModelGrammar } from './grammar'; -export const ZModelLanguageMetaData: LanguageMetaData = { +export const ZModelLanguageMetaData = { languageId: 'zmodel', fileExtensions: ['.zmodel'], caseInsensitive: false -}; +} as const satisfies LanguageMetaData; export const ZModelGeneratedSharedModule: Module = { AstReflection: () => new ZModelAstReflection() diff --git a/packages/language/src/zmodel.langium b/packages/language/src/zmodel.langium index da445c792..8fcc72c34 100644 --- a/packages/language/src/zmodel.langium +++ b/packages/language/src/zmodel.langium @@ -80,7 +80,7 @@ fragment ReferenceArgList: args+=ReferenceArg (',' args+=ReferenceArg)*; ReferenceArg: - name=('sort') ':' value=('Asc' | 'Desc'); + name=ID ':' value=Expression; ObjectExpr: @@ -172,7 +172,7 @@ fragment ArgumentList: args+=Argument (',' args+=Argument)*; Argument: - (name=RegularID ':')? value=Expression; + value=Expression; // model DataModel: @@ -224,7 +224,7 @@ FunctionParamType: // https://github.com/langium/langium/discussions/1012 RegularID returns string: // include keywords that we'd like to work as ID in most places - ID | 'model' | 'enum' | 'attribute' | 'datasource' | 'plugin' | 'abstract' | 'in' | 'sort' | 'view' | 'import'; + ID | 'model' | 'enum' | 'attribute' | 'datasource' | 'plugin' | 'abstract' | 'in' | 'view' | 'import'; // attribute Attribute: diff --git a/packages/language/syntaxes/zmodel.tmLanguage b/packages/language/syntaxes/zmodel.tmLanguage index cf70fb761..6102b919d 100644 --- a/packages/language/syntaxes/zmodel.tmLanguage +++ b/packages/language/syntaxes/zmodel.tmLanguage @@ -20,7 +20,7 @@ name keyword.control.zmodel match - \b(Any|Asc|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|Desc|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|abstract|attribute|datasource|enum|extends|false|function|generator|import|in|model|null|plugin|sort|this|true|view)\b + \b(Any|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|abstract|attribute|datasource|enum|extends|false|function|generator|import|in|model|null|plugin|this|true|view)\b name diff --git a/packages/language/syntaxes/zmodel.tmLanguage.json b/packages/language/syntaxes/zmodel.tmLanguage.json index 00c737c97..aad6a38c7 100644 --- a/packages/language/syntaxes/zmodel.tmLanguage.json +++ b/packages/language/syntaxes/zmodel.tmLanguage.json @@ -10,7 +10,7 @@ }, { "name": "keyword.control.zmodel", - "match": "\\b(Any|Asc|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|Desc|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|abstract|attribute|datasource|enum|extends|false|function|generator|import|in|model|null|plugin|sort|this|true|view)\\b" + "match": "\\b(Any|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|abstract|attribute|datasource|enum|extends|false|function|generator|import|in|model|null|plugin|this|true|view)\\b" }, { "name": "string.quoted.double.zmodel", diff --git a/packages/misc/redwood/package.json b/packages/misc/redwood/package.json index 47cddae1b..a1195e7d8 100644 --- a/packages/misc/redwood/package.json +++ b/packages/misc/redwood/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/redwood", "displayName": "ZenStack RedwoodJS Integration", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "description": "CLI and runtime for integrating ZenStack with RedwoodJS projects.", "repository": { "type": "git", diff --git a/packages/plugins/openapi/package.json b/packages/plugins/openapi/package.json index a910ab0e8..2faa34d87 100644 --- a/packages/plugins/openapi/package.json +++ b/packages/plugins/openapi/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/openapi", "displayName": "ZenStack Plugin and Runtime for OpenAPI", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "description": "ZenStack plugin and runtime supporting OpenAPI", "main": "index.js", "repository": { diff --git a/packages/plugins/openapi/src/generator-base.ts b/packages/plugins/openapi/src/generator-base.ts index d00c081fc..1a46fa528 100644 --- a/packages/plugins/openapi/src/generator-base.ts +++ b/packages/plugins/openapi/src/generator-base.ts @@ -2,9 +2,10 @@ import type { DMMF } from '@prisma/generator-helper'; import { PluginError, PluginOptions, getDataModels, hasAttribute } from '@zenstackhq/sdk'; import { Model } from '@zenstackhq/sdk/ast'; import type { OpenAPIV3_1 as OAPI } from 'openapi-types'; +import semver from 'semver'; import { fromZodError } from 'zod-validation-error'; +import { name } from '.'; import { SecuritySchemesSchema } from './schema'; -import semver from 'semver'; export abstract class OpenAPIGeneratorBase { protected readonly DEFAULT_SPEC_VERSION = '3.1.0'; @@ -91,10 +92,7 @@ export abstract class OpenAPIGeneratorBase { if (securitySchemes) { const parsed = SecuritySchemesSchema.safeParse(securitySchemes); if (!parsed.success) { - throw new PluginError( - this.options.name, - `"securitySchemes" option is invalid: ${fromZodError(parsed.error)}` - ); + throw new PluginError(name, `"securitySchemes" option is invalid: ${fromZodError(parsed.error)}`); } return parsed.data; } diff --git a/packages/plugins/openapi/src/rpc-generator.ts b/packages/plugins/openapi/src/rpc-generator.ts index 13bb91272..c551a8aef 100644 --- a/packages/plugins/openapi/src/rpc-generator.ts +++ b/packages/plugins/openapi/src/rpc-generator.ts @@ -721,7 +721,7 @@ export class RPCOpenAPIGenerator extends OpenAPIGeneratorBase { return this.wrapArray(this.wrapNullable(this.ref(def.type, false), !def.isRequired), def.isList); default: - throw new PluginError(this.options.name, `Unsupported field kind: ${def.kind}`); + throw new PluginError(name, `Unsupported field kind: ${def.kind}`); } } diff --git a/packages/plugins/swr/package.json b/packages/plugins/swr/package.json index 223433bec..bedc7fd13 100644 --- a/packages/plugins/swr/package.json +++ b/packages/plugins/swr/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/swr", "displayName": "ZenStack plugin for generating SWR hooks", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "description": "ZenStack plugin for generating SWR hooks", "main": "index.js", "repository": { diff --git a/packages/plugins/swr/src/generator.ts b/packages/plugins/swr/src/generator.ts index e074b603c..3a47a1c87 100644 --- a/packages/plugins/swr/src/generator.ts +++ b/packages/plugins/swr/src/generator.ts @@ -38,8 +38,6 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. await generateModelMeta(project, models, { output: path.join(outDir, '__model_meta.ts'), - compile: false, - preserveTsFiles: true, generateAttributes: false, }); diff --git a/packages/plugins/swr/tests/swr.test.ts b/packages/plugins/swr/tests/swr.test.ts index 9d198269b..9759aee2d 100644 --- a/packages/plugins/swr/tests/swr.test.ts +++ b/packages/plugins/swr/tests/swr.test.ts @@ -60,7 +60,7 @@ ${sharedModel} provider: 'postgresql', pushDb: false, extraDependencies: [ - `${path.join(__dirname, '../dist')}`, + path.resolve(__dirname, '../dist'), 'react@18.2.0', '@types/react@18.2.0', 'swr@^2', diff --git a/packages/plugins/swr/tests/test-model-meta.ts b/packages/plugins/swr/tests/test-model-meta.ts index 41731ad18..71a657bad 100644 --- a/packages/plugins/swr/tests/test-model-meta.ts +++ b/packages/plugins/swr/tests/test-model-meta.ts @@ -11,39 +11,46 @@ const fieldDefaults = { }; export const modelMeta: ModelMeta = { - fields: { + models: { user: { - id: { - ...fieldDefaults, - type: 'String', - isId: true, - name: 'id', - isOptional: false, - }, - name: { ...fieldDefaults, type: 'String', name: 'name' }, - email: { ...fieldDefaults, type: 'String', name: 'name', isOptional: false }, - posts: { - ...fieldDefaults, - type: 'Post', - isDataModel: true, - isArray: true, - name: 'posts', + name: 'user', + fields: { + id: { + ...fieldDefaults, + type: 'String', + isId: true, + name: 'id', + isOptional: false, + }, + name: { ...fieldDefaults, type: 'String', name: 'name' }, + email: { ...fieldDefaults, type: 'String', name: 'name', isOptional: false }, + posts: { + ...fieldDefaults, + type: 'Post', + isDataModel: true, + isArray: true, + name: 'posts', + }, }, + uniqueConstraints: {}, }, post: { - id: { - ...fieldDefaults, - type: 'String', - isId: true, - name: 'id', - isOptional: false, + name: 'post', + fields: { + id: { + ...fieldDefaults, + type: 'String', + isId: true, + name: 'id', + isOptional: false, + }, + title: { ...fieldDefaults, type: 'String', name: 'title' }, + owner: { ...fieldDefaults, type: 'User', name: 'owner', isDataModel: true, isRelationOwner: true }, + ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true }, }, - title: { ...fieldDefaults, type: 'String', name: 'title' }, - owner: { ...fieldDefaults, type: 'User', name: 'owner', isDataModel: true, isRelationOwner: true }, - ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true }, + uniqueConstraints: {}, }, }, - uniqueConstraints: {}, deleteCascade: { user: ['Post'], }, diff --git a/packages/plugins/tanstack-query/package.json b/packages/plugins/tanstack-query/package.json index 3d5a6d94b..5e625ca17 100644 --- a/packages/plugins/tanstack-query/package.json +++ b/packages/plugins/tanstack-query/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/tanstack-query", "displayName": "ZenStack plugin for generating tanstack-query hooks", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "description": "ZenStack plugin for generating tanstack-query hooks", "main": "index.js", "exports": { diff --git a/packages/plugins/tanstack-query/src/generator.ts b/packages/plugins/tanstack-query/src/generator.ts index bf0c88e0a..58836091a 100644 --- a/packages/plugins/tanstack-query/src/generator.ts +++ b/packages/plugins/tanstack-query/src/generator.ts @@ -34,21 +34,16 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. const target = requireOption(options, 'target', name); if (!supportedTargets.includes(target)) { - throw new PluginError( - options.name, - `Unsupported target "${target}", supported values: ${supportedTargets.join(', ')}` - ); + throw new PluginError(name, `Unsupported target "${target}", supported values: ${supportedTargets.join(', ')}`); } const version = typeof options.version === 'string' ? options.version : 'v4'; if (version !== 'v4' && version !== 'v5') { - throw new PluginError(options.name, `Unsupported version "${version}": use "v4" or "v5"`); + throw new PluginError(name, `Unsupported version "${version}": use "v4" or "v5"`); } await generateModelMeta(project, models, { output: path.join(outDir, '__model_meta.ts'), - compile: false, - preserveTsFiles: true, generateAttributes: false, }); @@ -78,68 +73,88 @@ function generateQueryHook( overrideReturnType?: string, overrideInputType?: string, overrideTypeParameters?: string[], - infinite = false, - optimisticUpdate = false + supportInfinite = false, + supportOptimistic = false ) { - const capOperation = upperCaseFirst(operation); - - const argsType = overrideInputType ?? `Prisma.${model}${capOperation}Args`; - const inputType = `Prisma.SelectSubset`; - - let defaultReturnType = `Prisma.${model}GetPayload`; - if (optimisticUpdate) { - defaultReturnType += '& { $optimistic?: boolean }'; + const generateModes: ('' | 'Infinite' | 'Suspense' | 'SuspenseInfinite')[] = ['']; + if (supportInfinite) { + generateModes.push('Infinite'); } - if (returnArray) { - defaultReturnType = `Array<${defaultReturnType}>`; + + if (target === 'react' && version === 'v5') { + // react-query v5 supports suspense query + generateModes.push('Suspense'); + if (supportInfinite) { + generateModes.push('SuspenseInfinite'); + } } - const returnType = overrideReturnType ?? defaultReturnType; - const optionsType = makeQueryOptions(target, 'TQueryFnData', 'TData', infinite, version); + for (const generateMode of generateModes) { + const capOperation = upperCaseFirst(operation); - const func = sf.addFunction({ - name: `use${infinite ? 'Infinite' : ''}${capOperation}${model}`, - typeParameters: overrideTypeParameters ?? [ - `TArgs extends ${argsType}`, - `TQueryFnData = ${returnType} `, - 'TData = TQueryFnData', - 'TError = DefaultError', - ], - parameters: [ - { - name: optionalInput ? 'args?' : 'args', - type: inputType, - }, - { - name: 'options?', - type: optionsType, - }, - ...(optimisticUpdate - ? [ - { - name: 'optimisticUpdate', - type: 'boolean', - initializer: 'true', - }, - ] - : []), - ], - isExported: true, - }); + const argsType = overrideInputType ?? `Prisma.${model}${capOperation}Args`; + const inputType = `Prisma.SelectSubset`; - if (version === 'v5' && infinite && ['react', 'svelte'].includes(target)) { - // initialPageParam and getNextPageParam options are required in v5 - func.addStatements([`options = options ?? { initialPageParam: undefined, getNextPageParam: () => null };`]); - } + const infinite = generateMode.includes('Infinite'); + const suspense = generateMode.includes('Suspense'); + const optimistic = + supportOptimistic && + // infinite queries are not subject to optimistic updates + !infinite; - func.addStatements([ - makeGetContext(target), - `return ${ - infinite ? 'useInfiniteModelQuery' : 'useModelQuery' - }('${model}', \`\${endpoint}/${lowerCaseFirst( - model - )}/${operation}\`, args, options, fetch${optimisticUpdate ? ', optimisticUpdate' : ''});`, - ]); + let defaultReturnType = `Prisma.${model}GetPayload`; + if (optimistic) { + defaultReturnType += '& { $optimistic?: boolean }'; + } + if (returnArray) { + defaultReturnType = `Array<${defaultReturnType}>`; + } + + const returnType = overrideReturnType ?? defaultReturnType; + const optionsType = makeQueryOptions(target, 'TQueryFnData', 'TData', infinite, suspense, version); + + const func = sf.addFunction({ + name: `use${generateMode}${capOperation}${model}`, + typeParameters: overrideTypeParameters ?? [ + `TArgs extends ${argsType}`, + `TQueryFnData = ${returnType} `, + 'TData = TQueryFnData', + 'TError = DefaultError', + ], + parameters: [ + { + name: optionalInput ? 'args?' : 'args', + type: inputType, + }, + { + name: 'options?', + type: optionsType, + }, + ...(optimistic + ? [ + { + name: 'optimisticUpdate', + type: 'boolean', + initializer: 'true', + }, + ] + : []), + ], + isExported: true, + }); + + if (version === 'v5' && infinite && ['react', 'svelte'].includes(target)) { + // initialPageParam and getNextPageParam options are required in v5 + func.addStatements([`options = options ?? { initialPageParam: undefined, getNextPageParam: () => null };`]); + } + + func.addStatements([ + makeGetContext(target), + `return use${generateMode}ModelQuery('${model}', \`\${endpoint}/${lowerCaseFirst( + model + )}/${operation}\`, args, options, fetch${optimistic ? ', optimisticUpdate' : ''});`, + ]); + } } function generateMutationHook( @@ -313,23 +328,8 @@ function generateModelHooks( undefined, undefined, undefined, - false, - true - ); - // infinite findMany - generateQueryHook( - target, - version, - sf, - model.name, - 'findMany', - true, true, - undefined, - undefined, - undefined, - true, - false + true ); } @@ -565,19 +565,29 @@ function makeBaseImports(target: TargetFramework, version: TanStackVersion) { `type DefaultError = Error;`, ]; switch (target) { - case 'react': + case 'react': { + const suspense = + version === 'v5' + ? [ + `import { useSuspenseModelQuery, useSuspenseInfiniteModelQuery } from '${runtimeImportBase}/${target}';`, + `import type { UseSuspenseQueryOptions, UseSuspenseInfiniteQueryOptions } from '@tanstack/react-query';`, + ] + : []; return [ `import type { UseMutationOptions, UseQueryOptions, UseInfiniteQueryOptions, InfiniteData } from '@tanstack/react-query';`, `import { getHooksContext } from '${runtimeImportBase}/${target}';`, ...shared, + ...suspense, ]; - case 'vue': + } + case 'vue': { return [ `import type { UseMutationOptions, UseQueryOptions, UseInfiniteQueryOptions, InfiniteData } from '@tanstack/vue-query';`, `import { getHooksContext } from '${runtimeImportBase}/${target}';`, ...shared, ]; - case 'svelte': + } + case 'svelte': { return [ `import { derived } from 'svelte/store';`, `import type { MutationOptions, CreateQueryOptions, CreateInfiniteQueryOptions } from '@tanstack/svelte-query';`, @@ -587,6 +597,7 @@ function makeBaseImports(target: TargetFramework, version: TanStackVersion) { `import { getHooksContext } from '${runtimeImportBase}/${target}';`, ...shared, ]; + } default: throw new PluginError(name, `Unsupported target: ${target}`); } @@ -597,6 +608,7 @@ function makeQueryOptions( returnType: string, dataType: string, infinite: boolean, + suspense: boolean, version: TanStackVersion ) { switch (target) { @@ -604,8 +616,10 @@ function makeQueryOptions( return infinite ? version === 'v4' ? `Omit, 'queryKey'>` - : `Omit>, 'queryKey'>` - : `Omit, 'queryKey'>`; + : `Omit>, 'queryKey'>` + : `Omit, 'queryKey'>`; case 'vue': return `Omit, 'queryKey'>`; case 'svelte': diff --git a/packages/plugins/tanstack-query/src/runtime-v5/react.ts b/packages/plugins/tanstack-query/src/runtime-v5/react.ts index 4871e8229..375cb2676 100644 --- a/packages/plugins/tanstack-query/src/runtime-v5/react.ts +++ b/packages/plugins/tanstack-query/src/runtime-v5/react.ts @@ -4,10 +4,14 @@ import { useMutation, useQuery, useQueryClient, + useSuspenseInfiniteQuery, + useSuspenseQuery, type InfiniteData, type UseInfiniteQueryOptions, type UseMutationOptions, type UseQueryOptions, + UseSuspenseInfiniteQueryOptions, + UseSuspenseQueryOptions, } from '@tanstack/react-query-v5'; import type { ModelMeta } from '@zenstackhq/runtime/cross'; import { createContext, useContext } from 'react'; @@ -71,6 +75,33 @@ export function useModelQuery( }); } +/** + * Creates a react-query suspense query. + * + * @param model The name of the model under query. + * @param url The request URL. + * @param args The request args object, URL-encoded and appended as "?q=" parameter + * @param options The react-query options object + * @param fetch The fetch function to use for sending the HTTP request + * @param optimisticUpdate Whether to enable automatic optimistic update + * @returns useSuspenseQuery hook + */ +export function useSuspenseModelQuery( + model: string, + url: string, + args?: unknown, + options?: Omit, 'queryKey'>, + fetch?: FetchFn, + optimisticUpdate = false +) { + const reqUrl = makeUrl(url, args); + return useSuspenseQuery({ + queryKey: getQueryKey(model, url, args, false, optimisticUpdate), + queryFn: () => fetcher(reqUrl, undefined, fetch, false), + ...options, + }); +} + /** * Creates a react-query infinite query. * @@ -97,6 +128,32 @@ export function useInfiniteModelQuery( }); } +/** + * Creates a react-query infinite suspense query. + * + * @param model The name of the model under query. + * @param url The request URL. + * @param args The initial request args object, URL-encoded and appended as "?q=" parameter + * @param options The react-query infinite query options object + * @param fetch The fetch function to use for sending the HTTP request + * @returns useSuspenseInfiniteQuery hook + */ +export function useSuspenseInfiniteModelQuery( + model: string, + url: string, + args: unknown, + options: Omit>, 'queryKey'>, + fetch?: FetchFn +) { + return useSuspenseInfiniteQuery({ + queryKey: getQueryKey(model, url, args, true), + queryFn: ({ pageParam }) => { + return fetcher(makeUrl(url, pageParam ?? args), undefined, fetch, false); + }, + ...options, + }); +} + /** * Creates a react-query mutation * diff --git a/packages/plugins/tanstack-query/src/runtime/vue.ts b/packages/plugins/tanstack-query/src/runtime/vue.ts index a0f1055e8..b0a35f5f3 100644 --- a/packages/plugins/tanstack-query/src/runtime/vue.ts +++ b/packages/plugins/tanstack-query/src/runtime/vue.ts @@ -61,7 +61,7 @@ export function useModelQuery( model: string, url: string, args?: unknown, - options?: UseQueryOptions, + options?: Omit, 'queryKey'>, fetch?: FetchFn, optimisticUpdate = false ) { @@ -87,7 +87,7 @@ export function useInfiniteModelQuery( model: string, url: string, args?: unknown, - options?: UseInfiniteQueryOptions, + options?: Omit, 'queryKey'>, fetch?: FetchFn ) { return useInfiniteQuery({ diff --git a/packages/plugins/tanstack-query/tests/plugin.test.ts b/packages/plugins/tanstack-query/tests/plugin.test.ts index 38370d38a..c87e2a38f 100644 --- a/packages/plugins/tanstack-query/tests/plugin.test.ts +++ b/packages/plugins/tanstack-query/tests/plugin.test.ts @@ -61,7 +61,7 @@ ${sharedModel} provider: 'postgresql', pushDb: false, extraDependencies: ['react@18.2.0', '@types/react@18.2.0', '@tanstack/react-query@4.29.7'], - copyDependencies: [`${path.join(__dirname, '..')}/dist`], + copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, } ); @@ -83,7 +83,7 @@ ${sharedModel} provider: 'postgresql', pushDb: false, extraDependencies: ['react@18.2.0', '@types/react@18.2.0', '@tanstack/react-query@^5.0.0'], - copyDependencies: [`${path.join(__dirname, '..')}/dist`], + copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, } ); @@ -104,7 +104,7 @@ ${sharedModel} provider: 'postgresql', pushDb: false, extraDependencies: ['vue@^3.3.4', '@tanstack/vue-query@4.37.0'], - copyDependencies: [`${path.join(__dirname, '..')}/dist`], + copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, } ); @@ -126,7 +126,7 @@ ${sharedModel} provider: 'postgresql', pushDb: false, extraDependencies: ['vue@^3.3.4', '@tanstack/vue-query@latest'], - copyDependencies: [`${path.join(__dirname, '..')}/dist`], + copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, } ); @@ -147,7 +147,7 @@ ${sharedModel} provider: 'postgresql', pushDb: false, extraDependencies: ['svelte@^3.0.0', '@tanstack/svelte-query@4.29.7'], - copyDependencies: [`${path.join(__dirname, '..')}/dist`], + copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, } ); @@ -169,7 +169,7 @@ ${sharedModel} provider: 'postgresql', pushDb: false, extraDependencies: ['svelte@^3.0.0', '@tanstack/svelte-query@^5.0.0'], - copyDependencies: [`${path.join(__dirname, '..')}/dist`], + copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, } ); diff --git a/packages/plugins/tanstack-query/tests/test-model-meta.ts b/packages/plugins/tanstack-query/tests/test-model-meta.ts index 41731ad18..71a657bad 100644 --- a/packages/plugins/tanstack-query/tests/test-model-meta.ts +++ b/packages/plugins/tanstack-query/tests/test-model-meta.ts @@ -11,39 +11,46 @@ const fieldDefaults = { }; export const modelMeta: ModelMeta = { - fields: { + models: { user: { - id: { - ...fieldDefaults, - type: 'String', - isId: true, - name: 'id', - isOptional: false, - }, - name: { ...fieldDefaults, type: 'String', name: 'name' }, - email: { ...fieldDefaults, type: 'String', name: 'name', isOptional: false }, - posts: { - ...fieldDefaults, - type: 'Post', - isDataModel: true, - isArray: true, - name: 'posts', + name: 'user', + fields: { + id: { + ...fieldDefaults, + type: 'String', + isId: true, + name: 'id', + isOptional: false, + }, + name: { ...fieldDefaults, type: 'String', name: 'name' }, + email: { ...fieldDefaults, type: 'String', name: 'name', isOptional: false }, + posts: { + ...fieldDefaults, + type: 'Post', + isDataModel: true, + isArray: true, + name: 'posts', + }, }, + uniqueConstraints: {}, }, post: { - id: { - ...fieldDefaults, - type: 'String', - isId: true, - name: 'id', - isOptional: false, + name: 'post', + fields: { + id: { + ...fieldDefaults, + type: 'String', + isId: true, + name: 'id', + isOptional: false, + }, + title: { ...fieldDefaults, type: 'String', name: 'title' }, + owner: { ...fieldDefaults, type: 'User', name: 'owner', isDataModel: true, isRelationOwner: true }, + ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true }, }, - title: { ...fieldDefaults, type: 'String', name: 'title' }, - owner: { ...fieldDefaults, type: 'User', name: 'owner', isDataModel: true, isRelationOwner: true }, - ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true }, + uniqueConstraints: {}, }, }, - uniqueConstraints: {}, deleteCascade: { user: ['Post'], }, diff --git a/packages/plugins/tanstack-query/tsconfig.json b/packages/plugins/tanstack-query/tsconfig.json index 9e4f772c5..c51ec9bae 100644 --- a/packages/plugins/tanstack-query/tsconfig.json +++ b/packages/plugins/tanstack-query/tsconfig.json @@ -6,5 +6,5 @@ "jsx": "react" }, "include": ["src/**/*.ts"], - "exclude": ["src/runtime"] + "exclude": ["src/runtime", "src/runtime-v5"] } diff --git a/packages/plugins/trpc/package.json b/packages/plugins/trpc/package.json index 6620e365e..2f24da31f 100644 --- a/packages/plugins/trpc/package.json +++ b/packages/plugins/trpc/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/trpc", "displayName": "ZenStack plugin for tRPC", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "description": "ZenStack plugin for tRPC", "main": "index.js", "repository": { diff --git a/packages/plugins/trpc/tests/trpc.test.ts b/packages/plugins/trpc/tests/trpc.test.ts index ca4a9c14d..757e7e182 100644 --- a/packages/plugins/trpc/tests/trpc.test.ts +++ b/packages/plugins/trpc/tests/trpc.test.ts @@ -56,7 +56,7 @@ model Foo { { provider: 'postgresql', pushDb: false, - extraDependencies: [`${path.join(__dirname, '../dist')}`, '@trpc/client', '@trpc/server'], + extraDependencies: [path.resolve(__dirname, '../dist'), '@trpc/client', '@trpc/server'], compile: true, fullZod: true, } @@ -98,7 +98,7 @@ model Foo { `, { pushDb: false, - extraDependencies: [`${path.join(__dirname, '../dist')}`, '@trpc/client', '@trpc/server'], + extraDependencies: [path.resolve(__dirname, '../dist'), '@trpc/client', '@trpc/server'], compile: true, fullZod: true, } @@ -128,7 +128,7 @@ model Post { `, { pushDb: false, - extraDependencies: [`${path.join(__dirname, '../dist')}`, '@trpc/client', '@trpc/server'], + extraDependencies: [path.resolve(__dirname, '../dist'), '@trpc/client', '@trpc/server'], compile: true, fullZod: true, customSchemaFilePath: 'zenstack/schema.zmodel', @@ -153,7 +153,7 @@ model Post { `, { pushDb: false, - extraDependencies: [`${path.join(__dirname, '../dist')}`, '@trpc/client', '@trpc/server'], + extraDependencies: [path.resolve(__dirname, '../dist'), '@trpc/client', '@trpc/server'], compile: true, fullZod: true, customSchemaFilePath: 'zenstack/schema.zmodel', @@ -183,7 +183,7 @@ model Post { `, { pushDb: false, - extraDependencies: [`${path.join(__dirname, '../dist')}`, '@trpc/client', '@trpc/server'], + extraDependencies: [path.resolve(__dirname, '../dist'), '@trpc/client', '@trpc/server'], compile: true, fullZod: true, customSchemaFilePath: 'zenstack/schema.zmodel', @@ -230,7 +230,7 @@ model Post { { pushDb: false, extraDependencies: [ - `${path.join(__dirname, '../dist')}`, + path.resolve(__dirname, '../dist'), '@trpc/client', '@trpc/server', '@trpc/react-query', @@ -254,7 +254,7 @@ model Post { `, { pushDb: false, - extraDependencies: [`${path.join(__dirname, '../dist')}`, '@trpc/client', '@trpc/server', '@trpc/next'], + extraDependencies: [path.resolve(__dirname, '../dist'), '@trpc/client', '@trpc/server', '@trpc/next'], compile: true, fullZod: true, } @@ -284,7 +284,7 @@ model post_item { `, { pushDb: false, - extraDependencies: [`${path.join(__dirname, '../dist')}`, '@trpc/client', '@trpc/server'], + extraDependencies: [path.resolve(__dirname, '../dist'), '@trpc/client', '@trpc/server'], compile: true, fullZod: true, } @@ -331,7 +331,7 @@ model Foo { { addPrelude: false, pushDb: false, - extraDependencies: [`${path.join(__dirname, '../dist')}`, '@trpc/client', '@trpc/server'], + extraDependencies: [path.resolve(__dirname, '../dist'), '@trpc/client', '@trpc/server'], compile: true, } ); @@ -402,7 +402,7 @@ model Foo { { addPrelude: false, pushDb: false, - extraDependencies: [`${path.join(__dirname, '../dist')}`, '@trpc/client', '@trpc/server'], + extraDependencies: [path.resolve(__dirname, '../dist'), '@trpc/client', '@trpc/server'], compile: true, } ); diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 4c0125474..33e3abd6b 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/runtime", "displayName": "ZenStack Runtime Library", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "description": "Runtime of ZenStack for both client-side and server-side environments.", "repository": { "type": "git", @@ -53,13 +53,13 @@ "linkDirectory": true }, "dependencies": { - "@types/bcryptjs": "^2.4.2", "bcryptjs": "^2.4.3", "buffer": "^6.0.3", "change-case": "^4.1.2", "colors": "1.4.0", "decimal.js": "^10.4.2", "deepcopy": "^2.1.0", + "deepmerge": "^4.3.1", "lower-case-first": "^2.0.2", "pluralize": "^8.0.0", "semver": "^7.5.2", @@ -69,7 +69,8 @@ "upper-case-first": "^2.0.2", "uuid": "^9.0.0", "zod": "^3.22.4", - "zod-validation-error": "^1.5.0" + "zod-validation-error": "^1.5.0", + "z3-solver": "^4.12.5" }, "author": { "name": "ZenStack Team" diff --git a/packages/runtime/res/enhance.d.ts b/packages/runtime/res/enhance.d.ts new file mode 100644 index 000000000..4ae717bc4 --- /dev/null +++ b/packages/runtime/res/enhance.d.ts @@ -0,0 +1 @@ +export { enhance } from '.zenstack/enhance'; diff --git a/packages/runtime/res/enhance.js b/packages/runtime/res/enhance.js new file mode 100644 index 000000000..aa19af865 --- /dev/null +++ b/packages/runtime/res/enhance.js @@ -0,0 +1,10 @@ +'use strict'; +Object.defineProperty(exports, '__esModule', { value: true }); + +try { + exports.enhance = require('.zenstack/enhance').enhance; +} catch { + exports.enhance = function () { + throw new Error('Generated "enhance" function not found. Please run `zenstack generate` first.'); + }; +} diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index 36143621f..a85392887 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -61,7 +61,7 @@ export const PRISMA_PROXY_ENHANCER = '$__zenstack_enhancer'; /** * Minimum Prisma version supported */ -export const PRISMA_MINIMUM_VERSION = '4.8.0'; +export const PRISMA_MINIMUM_VERSION = '5.0.0'; /** * Selector function name for fetching pre-update entity values. @@ -97,3 +97,8 @@ export const FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX = 'updateFieldGuardOverrid * Flag that indicates if the model has field-level access control */ export const HAS_FIELD_LEVEL_POLICY_FLAG = 'hasFieldLevelPolicy'; + +/** + * Prefix for auxiliary relation field generated for delegated models + */ +export const DELEGATE_AUX_RELATION_PREFIX = 'delegate_aux'; diff --git a/packages/runtime/src/cross/model-meta.ts b/packages/runtime/src/cross/model-meta.ts index a38f7986d..9f767af0e 100644 --- a/packages/runtime/src/cross/model-meta.ts +++ b/packages/runtime/src/cross/model-meta.ts @@ -4,10 +4,22 @@ import { lowerCaseFirst } from 'lower-case-first'; * Runtime information of a data model or field attribute */ export type RuntimeAttribute = { + /** + * Attribute name + */ name: string; + + /** + * Attribute arguments + */ args: Array<{ name?: string; value: unknown }>; }; +/** + * Function for computing default value for a field + */ +export type FieldDefaultValueProvider = (userContext: unknown) => unknown; + /** * Runtime information of a data model field */ @@ -67,6 +79,16 @@ export type FieldInfo = { */ foreignKeyMapping?: Record; + /** + * Model from which the field is inherited + */ + inheritedFrom?: string; + + /** + * A function that provides a default value for the field + */ + defaultValueProvider?: FieldDefaultValueProvider; + /** * If the field is an auto-increment field */ @@ -80,23 +102,53 @@ export type FieldInfo = { export type UniqueConstraint = { name: string; fields: string[] }; /** - * ZModel data model metadata + * Metadata for a data model */ -export type ModelMeta = { +export type ModelInfo = { + /** + * Model name + */ + name: string; + + /** + * Base types + */ + baseTypes?: string[]; + /** - * Model fields + * Fields */ - fields: Record>; + fields: Record; /** - * Model unique constraints + * Unique constraints */ - uniqueConstraints: Record>; + uniqueConstraints?: Record; /** - * Information for cascading delete + * Attributes on the model */ - deleteCascade: Record; + attributes?: RuntimeAttribute[]; + + /** + * Discriminator field name + */ + discriminator?: string; +}; + +/** + * ZModel data model metadata + */ +export type ModelMeta = { + /** + * Data models + */ + models: Record; + + /** + * Mapping from model name to models that will be deleted because of it due to cascade delete + */ + deleteCascade?: Record; /** * Name of model that backs the `auth()` function @@ -107,8 +159,8 @@ export type ModelMeta = { /** * Resolves a model field to its metadata. Returns undefined if not found. */ -export function resolveField(modelMeta: ModelMeta, model: string, field: string) { - return modelMeta.fields[lowerCaseFirst(model)]?.[field]; +export function resolveField(modelMeta: ModelMeta, model: string, field: string): FieldInfo | undefined { + return modelMeta.models[lowerCaseFirst(model)]?.fields?.[field]; } /** @@ -126,5 +178,12 @@ export function requireField(modelMeta: ModelMeta, model: string, field: string) * Gets all fields of a model. */ export function getFields(modelMeta: ModelMeta, model: string) { - return modelMeta.fields[lowerCaseFirst(model)]; + return modelMeta.models[lowerCaseFirst(model)]?.fields; +} + +/** + * Gets unique constraints of a model. + */ +export function getUniqueConstraints(modelMeta: ModelMeta, model: string) { + return modelMeta.models[lowerCaseFirst(model)]?.uniqueConstraints; } diff --git a/packages/runtime/src/cross/nested-write-visitor.ts b/packages/runtime/src/cross/nested-write-visitor.ts index 7d67f6d9b..db2455d7e 100644 --- a/packages/runtime/src/cross/nested-write-visitor.ts +++ b/packages/runtime/src/cross/nested-write-visitor.ts @@ -34,7 +34,7 @@ export type NestedWriteVisitorContext = { * to let the visitor traverse it instead of its original children. */ export type NestedWriterVisitorCallback = { - create?: (model: string, args: any[], context: NestedWriteVisitorContext) => MaybePromise; + create?: (model: string, data: any, context: NestedWriteVisitorContext) => MaybePromise; createMany?: ( model: string, @@ -219,8 +219,10 @@ export class NestedWriteVisitor { case 'set': if (this.callback.set) { - const newContext = pushNewContext(field, model, {}); - await this.callback.set(model, data, newContext); + for (const item of enumerate(data)) { + const newContext = pushNewContext(field, model, item, true); + await this.callback.set(model, item, newContext); + } } break; diff --git a/packages/runtime/src/cross/query-analyzer.ts b/packages/runtime/src/cross/query-analyzer.ts index 5af410e82..bf501f020 100644 --- a/packages/runtime/src/cross/query-analyzer.ts +++ b/packages/runtime/src/cross/query-analyzer.ts @@ -81,7 +81,7 @@ function collectDeleteCascades(model: string, modelMeta: ModelMeta, result: Set< } visited.add(model); - const cascades = modelMeta.deleteCascade[lowerCaseFirst(model)]; + const cascades = modelMeta.deleteCascade?.[lowerCaseFirst(model)]; if (!cascades) { return; diff --git a/packages/runtime/src/cross/utils.ts b/packages/runtime/src/cross/utils.ts index e4237dbc7..1982513b3 100644 --- a/packages/runtime/src/cross/utils.ts +++ b/packages/runtime/src/cross/utils.ts @@ -1,5 +1,5 @@ import { lowerCaseFirst } from 'lower-case-first'; -import { ModelMeta } from '.'; +import { ModelInfo, ModelMeta } from '.'; /** * Gets field names in a data model entity, filtering out internal fields. @@ -47,7 +47,7 @@ export function zip(x: Enumerable, y: Enumerable): Array<[T1, T2 } export function getIdFields(modelMeta: ModelMeta, model: string, throwIfNotFound = false) { - let fields = modelMeta.fields[lowerCaseFirst(model)]; + let fields = modelMeta.models[lowerCaseFirst(model)]?.fields; if (!fields) { if (throwIfNotFound) { throw new Error(`Unable to load fields for ${model}`); @@ -61,3 +61,19 @@ export function getIdFields(modelMeta: ModelMeta, model: string, throwIfNotFound } return result; } + +export function getModelInfo( + modelMeta: ModelMeta, + model: string, + throwIfNotFound: Throw = false as Throw +): Throw extends true ? ModelInfo : ModelInfo | undefined { + const info = modelMeta.models[lowerCaseFirst(model)]; + if (!info && throwIfNotFound) { + throw new Error(`Unable to load info for ${model}`); + } + return info; +} + +export function isDelegateModel(modelMeta: ModelMeta, model: string) { + return !!getModelInfo(modelMeta, model)?.attributes?.some((attr) => attr.name === '@@delegate'); +} diff --git a/packages/runtime/src/enhance.d.ts b/packages/runtime/src/enhance.d.ts new file mode 100644 index 000000000..48e877878 --- /dev/null +++ b/packages/runtime/src/enhance.d.ts @@ -0,0 +1,2 @@ +// @ts-expect-error stub for re-exporting generated code +export { enhance } from '.zenstack/enhance'; diff --git a/packages/runtime/src/enhancements/create-enhancement.ts b/packages/runtime/src/enhancements/create-enhancement.ts new file mode 100644 index 000000000..1b9796970 --- /dev/null +++ b/packages/runtime/src/enhancements/create-enhancement.ts @@ -0,0 +1,188 @@ +import colors from 'colors'; +import semver from 'semver'; +import { PRISMA_MINIMUM_VERSION } from '../constants'; +import { isDelegateModel, type ModelMeta } from '../cross'; +import type { AuthUser } from '../types'; +import { withDefaultAuth } from './default-auth'; +import { withDelegate } from './delegate'; +import { Logger } from './logger'; +import { withOmit } from './omit'; +import { withPassword } from './password'; +import { withPolicy } from './policy'; +import type { ErrorTransformer } from './proxy'; +import type { PolicyDef, ZodSchemas } from './types'; + +/** + * Kinds of enhancements to `PrismaClient` + */ +export type EnhancementKind = 'password' | 'omit' | 'policy' | 'delegate'; + +/** + * All enhancement kinds + */ +const ALL_ENHANCEMENTS = ['password', 'omit', 'policy', 'delegate']; + +/** + * Transaction isolation levels: https://www.prisma.io/docs/orm/prisma-client/queries/transactions#transaction-isolation-level + */ +export type TransactionIsolationLevel = + | 'ReadUncommitted' + | 'ReadCommitted' + | 'RepeatableRead' + | 'Snapshot' + | 'Serializable'; + +export type EnhancementOptions = { + /** + * The kinds of enhancements to apply. By default all enhancements are applied. + */ + kinds?: EnhancementKind[]; + + /** + * Whether to log Prisma query + */ + logPrismaQuery?: boolean; + + /** + * Hook for transforming errors before they are thrown to the caller. + */ + errorTransformer?: ErrorTransformer; + + /** + * The `maxWait` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. + */ + transactionMaxWait?: number; + + /** + * The `timeout` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. + */ + transactionTimeout?: number; + + /** + * The `isolationLevel` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. + */ + transactionIsolationLevel?: TransactionIsolationLevel; +}; + +/** + * Options for {@link createEnhancement} + * + * @private + */ +export type InternalEnhancementOptions = EnhancementOptions & { + /** + * Policy definition + */ + policy: PolicyDef; + + /** + * Model metadata + */ + modelMeta: ModelMeta; + + /** + * Zod schemas for validation + */ + zodSchemas?: ZodSchemas; + + /** + * The Node module that contains PrismaClient + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + prismaModule: any; +}; + +/** + * Context for creating enhanced `PrismaClient` + */ +export type EnhancementContext = { + user?: AuthUser; +}; + +let hasPassword: boolean | undefined = undefined; +let hasOmit: boolean | undefined = undefined; +let hasDefaultAuth: boolean | undefined = undefined; + +/** + * Gets a Prisma client enhanced with all enhancement behaviors, including access + * policy, field validation, field omission and password hashing. + * + * @private + * + * @param prisma The Prisma client to enhance. + * @param context Context. + * @param options Options. + */ +export function createEnhancement( + prisma: DbClient, + options: InternalEnhancementOptions, + context?: EnhancementContext +) { + if (!prisma) { + throw new Error('Invalid prisma instance'); + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const prismaVer = (prisma as any)._clientVersion; + if (prismaVer && semver.lt(prismaVer, PRISMA_MINIMUM_VERSION)) { + console.warn( + `ZenStack requires Prisma version "${PRISMA_MINIMUM_VERSION}" or higher. Detected version is "${prismaVer}".` + ); + } + + const logger = new Logger(prisma); + logger.info(`Enabled ZenStack enhancements: ${options.kinds?.join(', ')}`); + + let result = prisma; + + if ( + process.env.ZENSTACK_TEST === '1' || // avoid caching in tests + hasPassword === undefined || + hasOmit === undefined || + hasDefaultAuth === undefined + ) { + const allFields = Object.values(options.modelMeta.models).flatMap((modelInfo) => + Object.values(modelInfo.fields) + ); + hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password')); + hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit')); + hasDefaultAuth = allFields.some((field) => field.defaultValueProvider); + } + + const kinds = options.kinds ?? ALL_ENHANCEMENTS; + + // delegate proxy needs to be wrapped inside policy proxy, since it may translate `deleteMany` + // and `updateMany` to plain `delete` and `update` + if (Object.values(options.modelMeta.models).some((model) => isDelegateModel(options.modelMeta, model.name))) { + if (!kinds.includes('delegate')) { + console.warn( + colors.yellow( + 'Your ZModel contains delegate models but "delegate" enhancement kind is not enabled. This may result in unexpected behavior.' + ) + ); + } else { + result = withDelegate(result, options); + } + } + + // policy proxy + if (kinds.includes('policy')) { + result = withPolicy(result, options, context); + if (hasDefaultAuth) { + // @default(auth()) proxy + result = withDefaultAuth(result, options, context); + } + } + + if (hasPassword && kinds.includes('password')) { + // @password proxy + result = withPassword(result, options); + } + + if (hasOmit && kinds.includes('omit')) { + // @omit proxy + result = withOmit(result, options); + } + + return result; +} diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts new file mode 100644 index 000000000..cce9af782 --- /dev/null +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -0,0 +1,100 @@ +/* eslint-disable @typescript-eslint/no-unused-vars */ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import deepcopy from 'deepcopy'; +import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields } from '../cross'; +import { DbClientContract } from '../types'; +import { EnhancementContext, InternalEnhancementOptions } from './create-enhancement'; +import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; + +/** + * Gets an enhanced Prisma client that supports `@default(auth())` attribute. + * + * @private + */ +export function withDefaultAuth( + prisma: DbClient, + options: InternalEnhancementOptions, + context?: EnhancementContext +): DbClient { + return makeProxy( + prisma, + options.modelMeta, + (_prisma, model) => new DefaultAuthHandler(_prisma as DbClientContract, model, options, context), + 'defaultAuth' + ); +} + +class DefaultAuthHandler extends DefaultPrismaProxyHandler { + private readonly userContext: any; + + constructor( + prisma: DbClientContract, + model: string, + options: InternalEnhancementOptions, + private readonly context?: EnhancementContext + ) { + super(prisma, model, options); + + if (!this.context?.user) { + throw new Error(`Using \`auth()\` in \`@default\` requires a user context`); + } + + this.userContext = this.context.user; + } + + // base override + protected async preprocessArgs(action: PrismaProxyActions, args: any) { + const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; + if (actionsOfInterest.includes(action)) { + const newArgs = await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); + return newArgs; + } + return args; + } + + private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { + const newArgs = deepcopy(args); + + const processCreatePayload = (model: string, data: any) => { + const fields = getFields(this.options.modelMeta, model); + for (const fieldInfo of Object.values(fields)) { + if (fieldInfo.name in data) { + // create payload already sets field value + continue; + } + + if (!fieldInfo.defaultValueProvider) { + // field doesn't have a runtime default value provider + continue; + } + + const authDefaultValue = this.getDefaultValueFromAuth(fieldInfo); + if (authDefaultValue !== undefined) { + // set field value extracted from `auth()` + data[fieldInfo.name] = authDefaultValue; + } + } + }; + + // visit create payload and set default value to fields using `auth()` in `@default()` + const visitor = new NestedWriteVisitor(this.options.modelMeta, { + create: (model, data) => { + processCreatePayload(model, data); + }, + + createMany: (model, args) => { + for (const item of enumerate(args.data)) { + processCreatePayload(model, item); + } + }, + }); + + await visitor.visit(model, action, newArgs); + return newArgs; + } + + private getDefaultValueFromAuth(fieldInfo: FieldInfo) { + return fieldInfo.defaultValueProvider?.(this.userContext); + } +} diff --git a/packages/runtime/src/enhancements/delegate.ts b/packages/runtime/src/enhancements/delegate.ts new file mode 100644 index 000000000..7032a965a --- /dev/null +++ b/packages/runtime/src/enhancements/delegate.ts @@ -0,0 +1,1133 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import deepcopy from 'deepcopy'; +import deepmerge from 'deepmerge'; +import { lowerCaseFirst } from 'lower-case-first'; +import { DELEGATE_AUX_RELATION_PREFIX } from '../constants'; +import { + FieldInfo, + ModelInfo, + NestedWriteVisitor, + enumerate, + getIdFields, + getModelInfo, + isDelegateModel, + requireField, + resolveField, +} from '../cross'; +import type { CrudContract, DbClientContract } from '../types'; +import type { InternalEnhancementOptions } from './create-enhancement'; +import { Logger } from './logger'; +import { DefaultPrismaProxyHandler, makeProxy } from './proxy'; +import { QueryUtils } from './query-utils'; +import { formatObject, prismaClientValidationError } from './utils'; + +export function withDelegate(prisma: DbClient, options: InternalEnhancementOptions): DbClient { + return makeProxy( + prisma, + options.modelMeta, + (_prisma, model) => new DelegateProxyHandler(_prisma as DbClientContract, model, options), + 'delegate' + ); +} + +export class DelegateProxyHandler extends DefaultPrismaProxyHandler { + private readonly logger: Logger; + private readonly queryUtils: QueryUtils; + + constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { + super(prisma, model, options); + this.logger = new Logger(prisma); + this.queryUtils = new QueryUtils(prisma, this.options); + } + + // #region find + + override findFirst(args: any): Promise { + return this.doFind(this.prisma, this.model, 'findFirst', args); + } + + override findFirstOrThrow(args: any): Promise { + return this.doFind(this.prisma, this.model, 'findFirstOrThrow', args); + } + + override findUnique(args: any): Promise { + return this.doFind(this.prisma, this.model, 'findUnique', args); + } + + override findUniqueOrThrow(args: any): Promise { + return this.doFind(this.prisma, this.model, 'findUniqueOrThrow', args); + } + + override async findMany(args: any): Promise { + return this.doFind(this.prisma, this.model, 'findMany', args); + } + + private async doFind( + db: CrudContract, + model: string, + method: 'findFirst' | 'findFirstOrThrow' | 'findUnique' | 'findUniqueOrThrow' | 'findMany', + args: any + ) { + if (!this.involvesDelegateModel(model)) { + return super[method](args); + } + + args = args ? deepcopy(args) : {}; + + this.injectWhereHierarchy(model, args?.where); + this.injectSelectIncludeHierarchy(model, args); + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`${method}\` ${this.getModelName(model)}: ${formatObject(args)}`); + } + const entity = await db[model][method](args); + + if (Array.isArray(entity)) { + return entity.map((item) => this.assembleHierarchy(model, item)); + } else { + return this.assembleHierarchy(model, entity); + } + } + + private injectWhereHierarchy(model: string, where: any) { + if (!where || typeof where !== 'object') { + return; + } + + Object.entries(where).forEach(([field, value]) => { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (!fieldInfo?.inheritedFrom) { + return; + } + + let base = this.getBaseModel(model); + let target = where; + + while (base) { + const baseRelationName = this.makeAuxRelationName(base); + + // prepare base layer where + let thisLayer: any; + if (target[baseRelationName]) { + thisLayer = target[baseRelationName]; + } else { + thisLayer = target[baseRelationName] = {}; + } + + if (base.name === fieldInfo.inheritedFrom) { + thisLayer[field] = value; + delete where[field]; + break; + } else { + target = thisLayer; + base = this.getBaseModel(base.name); + } + } + }); + } + + private buildWhereHierarchy(where: any) { + if (!where) { + return undefined; + } + + where = deepcopy(where); + Object.entries(where).forEach(([field, value]) => { + const fieldInfo = resolveField(this.options.modelMeta, this.model, field); + if (!fieldInfo?.inheritedFrom) { + return; + } + + let base = this.getBaseModel(this.model); + let target = where; + + while (base) { + const baseRelationName = this.makeAuxRelationName(base); + + // prepare base layer where + let thisLayer: any; + if (target[baseRelationName]) { + thisLayer = target[baseRelationName]; + } else { + thisLayer = target[baseRelationName] = {}; + } + + if (base.name === fieldInfo.inheritedFrom) { + thisLayer[field] = value; + delete where[field]; + break; + } else { + target = thisLayer; + base = this.getBaseModel(base.name); + } + } + }); + + return where; + } + + private injectSelectIncludeHierarchy(model: string, args: any) { + if (!args || typeof args !== 'object') { + return; + } + + for (const kind of ['select', 'include'] as const) { + if (args[kind] && typeof args[kind] === 'object') { + for (const [field, value] of Object.entries(args[kind])) { + if (value !== undefined) { + if (this.injectBaseFieldSelect(model, field, value, args, kind)) { + delete args[kind][field]; + } else { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (fieldInfo && this.isDelegateOrDescendantOfDelegate(fieldInfo.type)) { + let nextValue = value; + if (nextValue === true) { + // make sure the payload is an object + args[kind][field] = nextValue = {}; + } + this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue); + } + } + } + } + } + } + + if (!args.select) { + this.injectBaseIncludeRecursively(model, args); + } + } + + private buildSelectIncludeHierarchy(model: string, args: any) { + args = deepcopy(args); + const selectInclude: any = this.extractSelectInclude(args) || {}; + + if (selectInclude.select && typeof selectInclude.select === 'object') { + Object.entries(selectInclude.select).forEach(([field, value]) => { + if (value) { + if (this.injectBaseFieldSelect(model, field, value, selectInclude, 'select')) { + delete selectInclude.select[field]; + } + } + }); + } else if (selectInclude.include && typeof selectInclude.include === 'object') { + Object.entries(selectInclude.include).forEach(([field, value]) => { + if (value) { + if (this.injectBaseFieldSelect(model, field, value, selectInclude, 'include')) { + delete selectInclude.include[field]; + } + } + }); + } + + if (!selectInclude.select) { + this.injectBaseIncludeRecursively(model, selectInclude); + } + return selectInclude; + } + + private injectBaseFieldSelect( + model: string, + field: string, + value: any, + selectInclude: any, + context: 'select' | 'include' + ) { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (!fieldInfo?.inheritedFrom) { + return false; + } + + let base = this.getBaseModel(model); + let target = selectInclude; + + while (base) { + const baseRelationName = this.makeAuxRelationName(base); + + // prepare base layer select/include + // let selectOrInclude = 'select'; + let thisLayer: any; + if (target.include) { + // selectOrInclude = 'include'; + thisLayer = target.include; + } else if (target.select) { + // selectOrInclude = 'select'; + thisLayer = target.select; + } else { + // selectInclude = 'include'; + thisLayer = target.select = {}; + } + + if (base.name === fieldInfo.inheritedFrom) { + if (!thisLayer[baseRelationName]) { + thisLayer[baseRelationName] = { [context]: {} }; + } + thisLayer[baseRelationName][context][field] = value; + break; + } else { + if (!thisLayer[baseRelationName]) { + thisLayer[baseRelationName] = { select: {} }; + } + target = thisLayer[baseRelationName]; + base = this.getBaseModel(base.name); + } + } + + return true; + } + + private injectBaseIncludeRecursively(model: string, selectInclude: any) { + const base = this.getBaseModel(model); + if (!base) { + return; + } + const baseRelationName = this.makeAuxRelationName(base); + + if (selectInclude.select) { + selectInclude.include = { [baseRelationName]: {}, ...selectInclude.select }; + delete selectInclude.select; + } else { + selectInclude.include = { [baseRelationName]: {}, ...selectInclude.include }; + } + this.injectBaseIncludeRecursively(base.name, selectInclude.include[baseRelationName]); + } + + // #endregion + + // #region create + + override async create(args: any) { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!args.data) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + 'data field is required in query argument' + ); + } + + if (isDelegateModel(this.options.modelMeta, this.model)) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `Model "${this.model}" is a delegate and cannot be created directly` + ); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.create(args); + } + + return this.doCreate(this.prisma, this.model, args); + } + + override createMany(args: { data: any; skipDuplicates?: boolean }): Promise<{ count: number }> { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!args.data) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + 'data field is required in query argument' + ); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.createMany(args); + } + + if (this.isDelegateOrDescendantOfDelegate(this.model) && args.skipDuplicates) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + '`createMany` with `skipDuplicates` set to true is not supported for delegated models' + ); + } + + // note that we can't call `createMany` directly because it doesn't support + // nested created, which is needed for creating base entities + return this.queryUtils.transaction(this.prisma, async (tx) => { + const r = await Promise.all( + enumerate(args.data).map(async (item) => { + return this.doCreate(tx, this.model, item); + }) + ); + + // filter out undefined value (due to skipping duplicates) + return { count: r.filter((item) => !!item).length }; + }); + } + + private async doCreate(db: CrudContract, model: string, args: any) { + args = deepcopy(args); + + await this.injectCreateHierarchy(model, args); + this.injectSelectIncludeHierarchy(model, args); + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`create\` ${this.getModelName(model)}: ${formatObject(args)}`); + } + const result = await db[model].create(args); + return this.assembleHierarchy(model, result); + } + + private async injectCreateHierarchy(model: string, args: any) { + const visitor = new NestedWriteVisitor(this.options.modelMeta, { + create: (model, args, _context) => { + this.doProcessCreatePayload(model, args); + }, + + createMany: (model, args, _context) => { + if (args.skipDuplicates) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + '`createMany` with `skipDuplicates` set to true is not supported for delegated models' + ); + } + + for (const item of enumerate(args?.data)) { + this.doProcessCreatePayload(model, item); + } + }, + }); + + await visitor.visit(model, 'create', args); + } + + private doProcessCreatePayload(model: string, args: any) { + if (!args) { + return; + } + + this.ensureBaseCreateHierarchy(model, args); + + for (const [field, value] of Object.entries(args)) { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (fieldInfo?.inheritedFrom) { + this.injectBaseFieldData(model, fieldInfo, value, args, 'create'); + delete args[field]; + } + } + } + + // ensure the full nested "create" structure is created for base types + private ensureBaseCreateHierarchy(model: string, result: any) { + let curr = result; + let base = this.getBaseModel(model); + let sub = this.getModelInfo(model); + + while (base) { + const baseRelationName = this.makeAuxRelationName(base); + + if (!curr[baseRelationName]) { + curr[baseRelationName] = {}; + } + if (!curr[baseRelationName].create) { + curr[baseRelationName].create = {}; + if (base.discriminator) { + // set discriminator field + curr[baseRelationName].create[base.discriminator] = sub.name; + } + } + curr = curr[baseRelationName].create; + sub = base; + base = this.getBaseModel(base.name); + } + } + + // inject field data that belongs to base type into proper nesting structure + private injectBaseFieldData( + model: string, + fieldInfo: FieldInfo, + value: unknown, + args: any, + mode: 'create' | 'update' + ) { + let base = this.getBaseModel(model); + let curr = args; + + while (base) { + if (base.discriminator === fieldInfo.name) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `fields "${fieldInfo.name}" is a discriminator and cannot be set directly` + ); + } + + const baseRelationName = this.makeAuxRelationName(base); + + if (!curr[baseRelationName]) { + curr[baseRelationName] = {}; + } + if (!curr[baseRelationName][mode]) { + curr[baseRelationName][mode] = {}; + } + curr = curr[baseRelationName][mode]; + + if (fieldInfo.inheritedFrom === base.name) { + curr[fieldInfo.name] = value; + break; + } + + base = this.getBaseModel(base.name); + } + } + + // #endregion + + // #region update + + override update(args: any): Promise { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!args.data) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + 'data field is required in query argument' + ); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.update(args); + } + + return this.queryUtils.transaction(this.prisma, (tx) => this.doUpdate(tx, this.model, args)); + } + + override async updateMany(args: any): Promise<{ count: number }> { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!args.data) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + 'data field is required in query argument' + ); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.updateMany(args); + } + + const simpleUpdateMany = Object.keys(args.data).every((key) => { + // check if the `data` clause involves base fields + const fieldInfo = resolveField(this.options.modelMeta, this.model, key); + return !fieldInfo?.inheritedFrom; + }); + + return this.queryUtils.transaction(this.prisma, (tx) => + this.doUpdateMany(tx, this.model, args, simpleUpdateMany) + ); + } + + override async upsert(args: any): Promise { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!args.where) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + 'where field is required in query argument' + ); + } + + if (isDelegateModel(this.options.modelMeta, this.model)) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `Model "${this.model}" is a delegate and doesn't support upsert` + ); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.upsert(args); + } + + args = deepcopy(args); + this.injectWhereHierarchy(this.model, (args as any)?.where); + this.injectSelectIncludeHierarchy(this.model, args); + if (args.create) { + this.doProcessCreatePayload(this.model, args.create); + } + if (args.update) { + this.doProcessUpdatePayload(this.model, args.update); + } + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`upsert\` ${this.getModelName(this.model)}: ${formatObject(args)}`); + } + const result = await this.prisma[this.model].upsert(args); + return this.assembleHierarchy(this.model, result); + } + + private async doUpdate(db: CrudContract, model: string, args: any): Promise { + args = deepcopy(args); + + await this.injectUpdateHierarchy(db, model, args); + this.injectSelectIncludeHierarchy(model, args); + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`update\` ${this.getModelName(model)}: ${formatObject(args)}`); + } + const result = await db[model].update(args); + return this.assembleHierarchy(model, result); + } + + private async doUpdateMany( + db: CrudContract, + model: string, + args: any, + simpleUpdateMany: boolean + ): Promise<{ count: number }> { + if (simpleUpdateMany) { + // do a direct `updateMany` + args = deepcopy(args); + await this.injectUpdateHierarchy(db, model, args); + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`updateMany\` ${this.getModelName(model)}: ${formatObject(args)}`); + } + return db[model].updateMany(args); + } else { + // translate to plain `update` for nested write into base fields + const findArgs = { + where: deepcopy(args.where), + select: this.queryUtils.makeIdSelection(model), + }; + await this.injectUpdateHierarchy(db, model, findArgs); + if (this.options.logPrismaQuery) { + this.logger.info( + `[delegate] \`updateMany\` find candidates: ${this.getModelName(model)}: ${formatObject(findArgs)}` + ); + } + const entities = await db[model].findMany(findArgs); + + const updatePayload = { data: deepcopy(args.data), select: this.queryUtils.makeIdSelection(model) }; + await this.injectUpdateHierarchy(db, model, updatePayload); + const result = await Promise.all( + entities.map((entity) => { + const updateArgs = { + where: entity, + ...updatePayload, + }; + this.logger.info( + `[delegate] \`updateMany\` update: ${this.getModelName(model)}: ${formatObject(updateArgs)}` + ); + return db[model].update(updateArgs); + }) + ); + return { count: result.length }; + } + } + + private async injectUpdateHierarchy(db: CrudContract, model: string, args: any) { + const visitor = new NestedWriteVisitor(this.options.modelMeta, { + update: (model, args, _context) => { + this.injectWhereHierarchy(model, (args as any)?.where); + this.doProcessUpdatePayload(model, (args as any)?.data); + }, + + updateMany: async (model, args, context) => { + let simpleUpdateMany = Object.keys(args.data).every((key) => { + // check if the `data` clause involves base fields + const fieldInfo = resolveField(this.options.modelMeta, model, key); + return !fieldInfo?.inheritedFrom; + }); + + if (simpleUpdateMany) { + // check if the `where` clause involves base fields + simpleUpdateMany = Object.keys(args.where || {}).every((key) => { + const fieldInfo = resolveField(this.options.modelMeta, model, key); + return !fieldInfo?.inheritedFrom; + }); + } + + if (simpleUpdateMany) { + this.injectWhereHierarchy(model, (args as any)?.where); + this.doProcessUpdatePayload(model, (args as any)?.data); + } else { + const where = this.queryUtils.buildReversedQuery(context, false, false); + await this.queryUtils.transaction(db, async (tx) => { + await this.doUpdateMany(tx, model, { ...args, where }, simpleUpdateMany); + }); + delete context.parent['updateMany']; + } + }, + + upsert: (model, args, _context) => { + this.injectWhereHierarchy(model, (args as any)?.where); + if (args.create) { + this.doProcessCreatePayload(model, (args as any)?.create); + } + if (args.update) { + this.doProcessUpdatePayload(model, (args as any)?.update); + } + }, + + create: (model, args, _context) => { + if (isDelegateModel(this.options.modelMeta, model)) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `Model "${model}" is a delegate and cannot be created directly` + ); + } + this.doProcessCreatePayload(model, args); + }, + + createMany: (model, args, _context) => { + if (args.skipDuplicates) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + '`createMany` with `skipDuplicates` set to true is not supported for delegated models' + ); + } + + for (const item of enumerate(args?.data)) { + this.doProcessCreatePayload(model, item); + } + }, + + connect: (model, args, _context) => { + this.injectWhereHierarchy(model, args); + }, + + connectOrCreate: (model, args, _context) => { + this.injectWhereHierarchy(model, args.where); + if (args.create) { + this.doProcessCreatePayload(model, args.create); + } + }, + + disconnect: (model, args, _context) => { + this.injectWhereHierarchy(model, args); + }, + + set: (model, args, _context) => { + this.injectWhereHierarchy(model, args); + }, + + delete: async (model, _args, context) => { + const where = this.queryUtils.buildReversedQuery(context, false, false); + await this.queryUtils.transaction(db, async (tx) => { + await this.doDelete(tx, model, { where }); + }); + delete context.parent['delete']; + }, + + deleteMany: async (model, _args, context) => { + const where = this.queryUtils.buildReversedQuery(context, false, false); + await this.queryUtils.transaction(db, async (tx) => { + await this.doDeleteMany(tx, model, where); + }); + delete context.parent['deleteMany']; + }, + }); + + await visitor.visit(model, 'update', args); + } + + private doProcessUpdatePayload(model: string, data: any) { + if (!data) { + return; + } + + for (const [field, value] of Object.entries(data)) { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (fieldInfo?.inheritedFrom) { + this.injectBaseFieldData(model, fieldInfo, value, data, 'update'); + delete data[field]; + } + } + } + + // #endregion + + // #region delete + + override delete(args: any): Promise { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.delete(args); + } + + return this.queryUtils.transaction(this.prisma, async (tx) => { + const selectInclude = this.buildSelectIncludeHierarchy(this.model, args); + + // make sure id fields are selected + const idFields = this.getIdFields(this.model); + for (const idField of idFields) { + if (selectInclude?.select && !(idField.name in selectInclude.select)) { + selectInclude.select[idField.name] = true; + } + } + + const deleteArgs = { ...deepcopy(args), ...selectInclude }; + return this.doDelete(tx, this.model, deleteArgs); + }); + } + + override deleteMany(args: any): Promise<{ count: number }> { + if (!this.involvesDelegateModel(this.model)) { + return super.deleteMany(args); + } + + return this.queryUtils.transaction(this.prisma, (tx) => this.doDeleteMany(tx, this.model, args?.where)); + } + + private async doDeleteMany(db: CrudContract, model: string, where: any): Promise<{ count: number }> { + // query existing entities with id + const idSelection = this.queryUtils.makeIdSelection(model); + const findArgs = { where: deepcopy(where), select: idSelection }; + this.injectWhereHierarchy(model, findArgs.where); + + if (this.options.logPrismaQuery) { + this.logger.info( + `[delegate] \`deleteMany\` find candidates: ${this.getModelName(model)}: ${formatObject(findArgs)}` + ); + } + const entities = await db[model].findMany(findArgs); + + // recursively delete base entities (they all have the same id values) + await Promise.all(entities.map((entity) => this.doDelete(db, model, { where: entity }))); + + return { count: entities.length }; + } + + private async deleteBaseRecursively(db: CrudContract, model: string, idValues: any) { + let base = this.getBaseModel(model); + while (base) { + await db[base.name].delete({ where: idValues }); + base = this.getBaseModel(base.name); + } + } + + private async doDelete(db: CrudContract, model: string, args: any): Promise { + this.injectWhereHierarchy(model, args.where); + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`delete\` ${this.getModelName(model)}: ${formatObject(args)}`); + } + const result = await db[model].delete(args); + const idValues = this.queryUtils.getEntityIds(model, result); + + // recursively delete base entities (they all have the same id values) + await this.deleteBaseRecursively(db, model, idValues); + return this.assembleHierarchy(model, result); + } + + // #endregion + + // #region aggregation + + override aggregate(args: any): Promise { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!this.involvesDelegateModel(this.model)) { + return super.aggregate(args); + } + + // check if any aggregation operator is using fields from base + this.checkAggregationArgs('aggregate', args); + + args = deepcopy(args); + + if (args.cursor) { + args.cursor = this.buildWhereHierarchy(args.cursor); + } + + if (args.orderBy) { + args.orderBy = this.buildWhereHierarchy(args.orderBy); + } + + if (args.where) { + args.where = this.buildWhereHierarchy(args.where); + } + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`aggregate\` ${this.getModelName(this.model)}: ${formatObject(args)}`); + } + return super.aggregate(args); + } + + override count(args: any): Promise { + if (!this.involvesDelegateModel(this.model)) { + return super.count(args); + } + + // check if count select is using fields from base + this.checkAggregationArgs('count', args); + + args = deepcopy(args); + + if (args?.cursor) { + args.cursor = this.buildWhereHierarchy(args.cursor); + } + + if (args?.where) { + args.where = this.buildWhereHierarchy(args.where); + } + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`count\` ${this.getModelName(this.model)}: ${formatObject(args)}`); + } + return super.count(args); + } + + override groupBy(args: any): Promise { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!this.involvesDelegateModel(this.model)) { + return super.groupBy(args); + } + + // check if count select is using fields from base + this.checkAggregationArgs('groupBy', args); + + if (args.by) { + for (const by of enumerate(args.by)) { + const fieldInfo = resolveField(this.options.modelMeta, this.model, by); + if (fieldInfo && fieldInfo.inheritedFrom) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `groupBy with fields from base type is not supported yet: "${by}"` + ); + } + } + } + + args = deepcopy(args); + + if (args.where) { + args.where = this.buildWhereHierarchy(args.where); + } + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`groupBy\` ${this.getModelName(this.model)}: ${formatObject(args)}`); + } + return super.groupBy(args); + } + + private checkAggregationArgs(operation: 'aggregate' | 'count' | 'groupBy', args: any) { + if (!args) { + return; + } + + for (const op of ['_count', '_sum', '_avg', '_min', '_max', 'select', 'having']) { + if (args[op] && typeof args[op] === 'object') { + for (const field of Object.keys(args[op])) { + const fieldInfo = resolveField(this.options.modelMeta, this.model, field); + if (fieldInfo?.inheritedFrom) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `${operation} with fields from base type is not supported yet: "${field}"` + ); + } + } + } + } + } + + // #endregion + + // #region utils + + private extractSelectInclude(args: any) { + if (!args) { + return undefined; + } + args = deepcopy(args); + return 'select' in args + ? { select: args['select'] } + : 'include' in args + ? { include: args['include'] } + : undefined; + } + + private makeAuxRelationName(model: ModelInfo) { + return `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(model.name)}`; + } + + private getModelName(model: string) { + const info = getModelInfo(this.options.modelMeta, model, true); + return info.name; + } + + private getIdFields(model: string): FieldInfo[] { + const idFields = getIdFields(this.options.modelMeta, model); + if (idFields && idFields.length > 0) { + return idFields; + } + const base = this.getBaseModel(model); + return base ? this.getIdFields(base.name) : []; + } + + private getModelInfo(model: string) { + return getModelInfo(this.options.modelMeta, model, true); + } + + private getBaseModel(model: string) { + const baseNames = getModelInfo(this.options.modelMeta, model, true).baseTypes; + if (!baseNames) { + return undefined; + } + if (baseNames.length > 1) { + throw new Error('Multi-inheritance is not supported'); + } + return this.options.modelMeta.models[lowerCaseFirst(baseNames[0])]; + } + + private involvesDelegateModel(model: string, visited?: Set): boolean { + if (this.isDelegateOrDescendantOfDelegate(model)) { + return true; + } + + visited = visited ?? new Set(); + if (visited.has(model)) { + return false; + } + visited.add(model); + + const modelInfo = getModelInfo(this.options.modelMeta, model, true); + return Object.values(modelInfo.fields).some( + (field) => field.isDataModel && this.involvesDelegateModel(field.type, visited) + ); + } + + private isDelegateOrDescendantOfDelegate(model: string): boolean { + if (isDelegateModel(this.options.modelMeta, model)) { + return true; + } + const baseTypes = getModelInfo(this.options.modelMeta, model)?.baseTypes; + return !!( + baseTypes && + baseTypes.length > 0 && + baseTypes.some((base) => this.isDelegateOrDescendantOfDelegate(base)) + ); + } + + private assembleHierarchy(model: string, entity: any) { + if (!entity || typeof entity !== 'object') { + return entity; + } + + const result: any = {}; + const base = this.getBaseModel(model); + + if (base) { + // merge base fields + const baseRelationName = this.makeAuxRelationName(base); + const baseData = entity[baseRelationName]; + if (baseData && typeof baseData === 'object') { + const baseAssembled = this.assembleHierarchy(base.name, baseData); + Object.assign(result, baseAssembled); + } + } + + const modelInfo = getModelInfo(this.options.modelMeta, model, true); + + for (const field of Object.values(modelInfo.fields)) { + if (field.inheritedFrom) { + // already merged from base + continue; + } + + if (field.name in entity) { + const fieldValue = entity[field.name]; + if (field.isDataModel) { + if (Array.isArray(fieldValue)) { + result[field.name] = fieldValue.map((item) => this.assembleHierarchy(field.type, item)); + } else { + result[field.name] = this.assembleHierarchy(field.type, fieldValue); + } + } else { + result[field.name] = fieldValue; + } + } + } + + return result; + } + + // #endregion + + // #region backup + + private transformWhereHierarchy(where: any, contextModel: ModelInfo, forModel: ModelInfo) { + if (!where || typeof where !== 'object') { + return where; + } + + let curr: ModelInfo | undefined = contextModel; + const inheritStack: ModelInfo[] = []; + while (curr) { + inheritStack.unshift(curr); + curr = this.getBaseModel(curr.name); + } + + let result: any = {}; + for (const [key, value] of Object.entries(where)) { + const fieldInfo = requireField(this.options.modelMeta, contextModel.name, key); + const fieldHierarchy = this.transformFieldHierarchy(fieldInfo, value, contextModel, forModel, inheritStack); + result = deepmerge(result, fieldHierarchy); + } + + return result; + } + + private transformFieldHierarchy( + fieldInfo: FieldInfo, + value: unknown, + contextModel: ModelInfo, + forModel: ModelInfo, + inheritStack: ModelInfo[] + ): any { + const fieldModel = fieldInfo.inheritedFrom ? this.getModelInfo(fieldInfo.inheritedFrom) : contextModel; + if (fieldModel === forModel) { + return { [fieldInfo.name]: value }; + } + + const fieldModelPos = inheritStack.findIndex((m) => m === fieldModel); + const forModelPos = inheritStack.findIndex((m) => m === forModel); + const result: any = {}; + let curr = result; + + if (fieldModelPos > forModelPos) { + // walk down hierarchy + for (let i = forModelPos + 1; i <= fieldModelPos; i++) { + const rel = this.makeAuxRelationName(inheritStack[i]); + curr[rel] = {}; + curr = curr[rel]; + } + } else { + // walk up hierarchy + for (let i = forModelPos - 1; i >= fieldModelPos; i--) { + const rel = this.makeAuxRelationName(inheritStack[i]); + curr[rel] = {}; + curr = curr[rel]; + } + } + + curr[fieldInfo.name] = value; + return result; + } + + // #endregion +} diff --git a/packages/runtime/src/enhancements/enhance.ts b/packages/runtime/src/enhancements/enhance.ts deleted file mode 100644 index 42a504bdf..000000000 --- a/packages/runtime/src/enhancements/enhance.ts +++ /dev/null @@ -1,52 +0,0 @@ -import { getDefaultModelMeta } from '../loader'; -import { withOmit, WithOmitOptions } from './omit'; -import { withPassword, WithPasswordOptions } from './password'; -import { withPolicy, WithPolicyContext, WithPolicyOptions } from './policy'; - -/** - * Options @see enhance - */ -export type EnhancementOptions = WithPolicyOptions & WithPasswordOptions & WithOmitOptions; - -let hasPassword: boolean | undefined = undefined; -let hasOmit: boolean | undefined = undefined; - -/** - * Gets a Prisma client enhanced with all essential behaviors, including access - * policy, field validation, field omission and password hashing. - * - * It's a shortcut for calling withOmit(withPassword(withPolicy(prisma, options))). - * - * @param prisma The Prisma client to enhance. - * @param context The context to for evaluating access policies. - * @param options Options. - */ -export function enhance( - prisma: DbClient, - context?: WithPolicyContext, - options?: EnhancementOptions -) { - let result = prisma; - - if (hasPassword === undefined || hasOmit === undefined) { - const modelMeta = options?.modelMeta ?? getDefaultModelMeta(options?.loadPath); - const allFields = Object.values(modelMeta.fields).flatMap((modelInfo) => Object.values(modelInfo)); - hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password')); - hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit')); - } - - if (hasPassword) { - // @password proxy - result = withPassword(result, options); - } - - if (hasOmit) { - // @omit proxy - result = withOmit(result, options); - } - - // policy proxy - result = withPolicy(result, context, options); - - return result; -} diff --git a/packages/runtime/src/enhancements/index.ts b/packages/runtime/src/enhancements/index.ts index 25b150a71..3ddeddac0 100644 --- a/packages/runtime/src/enhancements/index.ts +++ b/packages/runtime/src/enhancements/index.ts @@ -1,9 +1,4 @@ export * from '../cross'; -export * from './enhance'; -export * from './omit'; -export * from './password'; -export * from './policy'; -export * from './preset'; +export * from './create-enhancement'; export * from './types'; export * from './utils'; -export * from './where-visitor'; diff --git a/packages/runtime/src/enhancements/policy/logger.ts b/packages/runtime/src/enhancements/logger.ts similarity index 100% rename from packages/runtime/src/enhancements/policy/logger.ts rename to packages/runtime/src/enhancements/logger.ts diff --git a/packages/runtime/src/enhancements/omit.ts b/packages/runtime/src/enhancements/omit.ts index 8b2937845..fa834166d 100644 --- a/packages/runtime/src/enhancements/omit.ts +++ b/packages/runtime/src/enhancements/omit.ts @@ -1,40 +1,28 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { enumerate, getModelFields, resolveField, type ModelMeta } from '../cross'; -import { getDefaultModelMeta } from '../loader'; +import { enumerate, getModelFields, resolveField } from '../cross'; import { DbClientContract } from '../types'; +import { InternalEnhancementOptions } from './create-enhancement'; import { DefaultPrismaProxyHandler, makeProxy } from './proxy'; -import { CommonEnhancementOptions } from './types'; /** - * Options for @see withOmit - */ -export interface WithOmitOptions extends CommonEnhancementOptions { - /** - * Model metadata - */ - modelMeta?: ModelMeta; -} - -/** - * Gets an enhanced Prisma client that supports "@omit" attribute. + * Gets an enhanced Prisma client that supports `@omit` attribute. * - * @deprecated Use {@link enhance} instead + * @private */ -export function withOmit(prisma: DbClient, options?: WithOmitOptions): DbClient { - const _modelMeta = options?.modelMeta ?? getDefaultModelMeta(options?.loadPath); +export function withOmit(prisma: DbClient, options: InternalEnhancementOptions): DbClient { return makeProxy( prisma, - _modelMeta, - (_prisma, model) => new OmitHandler(_prisma as DbClientContract, model, _modelMeta), + options.modelMeta, + (_prisma, model) => new OmitHandler(_prisma as DbClientContract, model, options), 'omit' ); } class OmitHandler extends DefaultPrismaProxyHandler { - constructor(prisma: DbClientContract, model: string, private readonly modelMeta: ModelMeta) { - super(prisma, model); + constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { + super(prisma, model, options); } // base override @@ -49,16 +37,23 @@ class OmitHandler extends DefaultPrismaProxyHandler { private async doPostProcess(entityData: any, model: string) { for (const field of getModelFields(entityData)) { - const fieldInfo = await resolveField(this.modelMeta, model, field); + const fieldInfo = await resolveField(this.options.modelMeta, model, field); if (!fieldInfo) { continue; } - if (fieldInfo.attributes?.find((attr) => attr.name === '@omit')) { + const shouldOmit = fieldInfo.attributes?.find((attr) => attr.name === '@omit'); + if (shouldOmit) { delete entityData[field]; - } else if (fieldInfo.isDataModel) { - // recurse - await this.doPostProcess(entityData[field], fieldInfo.type); + } + + if (fieldInfo.isDataModel) { + const items = + fieldInfo.isArray && Array.isArray(entityData[field]) ? entityData[field] : [entityData[field]]; + for (const item of items) { + // recurse + await this.doPostProcess(item, fieldInfo.type); + } } } } diff --git a/packages/runtime/src/enhancements/password.ts b/packages/runtime/src/enhancements/password.ts index c31846298..f83939792 100644 --- a/packages/runtime/src/enhancements/password.ts +++ b/packages/runtime/src/enhancements/password.ts @@ -3,40 +3,31 @@ import { hash } from 'bcryptjs'; import { DEFAULT_PASSWORD_SALT_LENGTH } from '../constants'; -import { NestedWriteVisitor, type ModelMeta, type PrismaWriteActionType } from '../cross'; -import { getDefaultModelMeta } from '../loader'; +import { NestedWriteVisitor, type PrismaWriteActionType } from '../cross'; import { DbClientContract } from '../types'; +import { InternalEnhancementOptions } from './create-enhancement'; import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; -import { CommonEnhancementOptions } from './types'; /** - * Options for @see withPassword - */ -export interface WithPasswordOptions extends CommonEnhancementOptions { - /** - * Model metadata - */ - modelMeta?: ModelMeta; -} - -/** - * Gets an enhanced Prisma client that supports @password attribute. + * Gets an enhanced Prisma client that supports `@password` attribute. * - * @deprecated Use {@link enhance} instead + * @private */ -export function withPassword(prisma: DbClient, options?: WithPasswordOptions): DbClient { - const _modelMeta = options?.modelMeta ?? getDefaultModelMeta(options?.loadPath); +export function withPassword( + prisma: DbClient, + options: InternalEnhancementOptions +): DbClient { return makeProxy( prisma, - _modelMeta, - (_prisma, model) => new PasswordHandler(_prisma as DbClientContract, model, _modelMeta), + options.modelMeta, + (_prisma, model) => new PasswordHandler(_prisma as DbClientContract, model, options), 'password' ); } class PasswordHandler extends DefaultPrismaProxyHandler { - constructor(prisma: DbClientContract, model: string, readonly modelMeta: ModelMeta) { - super(prisma, model); + constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { + super(prisma, model, options); } // base override @@ -49,7 +40,7 @@ class PasswordHandler extends DefaultPrismaProxyHandler { } private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { - const visitor = new NestedWriteVisitor(this.modelMeta, { + const visitor = new NestedWriteVisitor(this.options.modelMeta, { field: async (field, _action, data, context) => { const pwdAttr = field.attributes?.find((attr) => attr.name === '@password'); if (pwdAttr && field.type === 'String') { diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index e9f4daae0..2a25845d6 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -16,14 +16,14 @@ import { type FieldInfo, type ModelMeta, } from '../../cross'; -import { AuthUser, DbClientContract, DbOperations, PolicyOperationKind } from '../../types'; +import { PolicyOperationKind, type CrudContract, type DbClientContract, CRUDOperationKind } from '../../types'; +import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; +import { Logger } from '../logger'; import { PrismaProxyHandler } from '../proxy'; -import type { PolicyDef, ZodSchemas } from '../types'; +import { QueryUtils } from '../query-utils'; import { formatObject, prismaClientValidationError } from '../utils'; -import { Logger } from './logger'; import { PolicyUtil } from './policy-utils'; import { createDeferredPromise } from './promise'; -import { WithPolicyOptions } from '.'; // a record for post-write policy check type PostWriteCheckRecord = { @@ -40,31 +40,25 @@ type FindOperations = 'findUnique' | 'findUniqueOrThrow' | 'findFirst' | 'findFi */ export class PolicyProxyHandler implements PrismaProxyHandler { private readonly logger: Logger; - private readonly utils: PolicyUtil; + private readonly policyUtils: PolicyUtil; private readonly model: string; - - private readonly DEFAULT_TX_MAXWAIT = 100000; - private readonly DEFAULT_TX_TIMEOUT = 100000; + private readonly modelMeta: ModelMeta; + private readonly prismaModule: any; + private readonly queryUtils: QueryUtils; constructor( private readonly prisma: DbClient, - private readonly policy: PolicyDef, - private readonly modelMeta: ModelMeta, - private readonly zodSchemas: ZodSchemas | undefined, model: string, - private readonly user: AuthUser | undefined, - private readonly options: WithPolicyOptions | undefined + private readonly options: InternalEnhancementOptions, + private readonly context?: EnhancementContext ) { this.logger = new Logger(prisma); - this.utils = new PolicyUtil( - this.prisma, - this.modelMeta, - this.policy, - this.zodSchemas, - this.user, - this.shouldLogQuery - ); this.model = lowerCaseFirst(model); + + ({ modelMeta: this.modelMeta, prismaModule: this.prismaModule } = options); + + this.policyUtils = new PolicyUtil(prisma, options, context, this.shouldLogQuery); + this.queryUtils = new QueryUtils(prisma, options); } private get modelClient() { @@ -77,23 +71,31 @@ export class PolicyProxyHandler implements Pr findUnique(args: any) { if (!args) { - throw prismaClientValidationError(this.prisma, 'query argument is required'); + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } if (!args.where) { - throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'where field is required in query argument' + ); } return this.findWithFluentCallStubs(args, 'findUnique', false, () => null); } findUniqueOrThrow(args: any) { if (!args) { - throw prismaClientValidationError(this.prisma, 'query argument is required'); + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } if (!args.where) { - throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'where field is required in query argument' + ); } return this.findWithFluentCallStubs(args, 'findUniqueOrThrow', true, () => { - throw this.utils.notFound(this.model); + throw this.policyUtils.notFound(this.model); }); } @@ -103,7 +105,7 @@ export class PolicyProxyHandler implements Pr findFirstOrThrow(args: any) { return this.findWithFluentCallStubs(args, 'findFirstOrThrow', true, () => { - throw this.utils.notFound(this.model); + throw this.policyUtils.notFound(this.model); }); } @@ -126,12 +128,15 @@ export class PolicyProxyHandler implements Pr private doFind(args: any, actionName: FindOperations, handleRejection: () => any) { const origArgs = args; - const _args = this.utils.clone(args); - if (!this.utils.injectForRead(this.prisma, this.model, _args)) { + const _args = this.policyUtils.clone(args); + if (!this.policyUtils.injectForRead(this.prisma, this.model, _args)) { + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`${actionName}\` ${this.model}: unconditionally denied`); + } return handleRejection(); } - this.utils.injectReadCheckSelect(this.model, _args); + this.policyUtils.injectReadCheckSelect(this.model, _args); if (this.shouldLogQuery) { this.logger.info(`[policy] \`${actionName}\` ${this.model}:\n${formatObject(_args)}`); @@ -140,7 +145,7 @@ export class PolicyProxyHandler implements Pr return new Promise((resolve, reject) => { this.modelClient[actionName](_args).then( (value: any) => { - this.utils.postProcessForRead(value, this.model, origArgs); + this.policyUtils.postProcessForRead(value, this.model, origArgs); resolve(value); }, (err: any) => reject(err) @@ -151,14 +156,14 @@ export class PolicyProxyHandler implements Pr // returns a fluent API call function private fluentCall(filter: any, fieldInfo: FieldInfo, rootPromise?: Promise) { return (args: any) => { - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // combine the parent filter with the current one const backLinkField = this.requireBackLink(fieldInfo); const condition = backLinkField.isArray ? { [backLinkField.name]: { some: filter } } : { [backLinkField.name]: { is: filter } }; - args.where = this.utils.and(args.where, condition); + args.where = this.policyUtils.and(args.where, condition); const promise = createDeferredPromise(() => { // Promise for fetching @@ -204,7 +209,7 @@ export class PolicyProxyHandler implements Pr // add fluent API functions to the given promise private addFluentFunctions(promise: any, model: string, filter: any, rootPromise?: Promise) { - const fields = this.utils.getModelFields(model); + const fields = this.policyUtils.getModelFields(model); if (fields) { for (const [field, fieldInfo] of Object.entries(fields)) { if (fieldInfo.isDataModel) { @@ -220,26 +225,35 @@ export class PolicyProxyHandler implements Pr async create(args: any) { if (!args) { - throw prismaClientValidationError(this.prisma, 'query argument is required'); + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } if (!args.data) { - throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'data field is required in query argument' + ); } - this.utils.tryReject(this.prisma, this.model, 'create'); + this.policyUtils.tryReject(this.prisma, this.model, 'create'); const origArgs = args; - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // static input policy check for top-level create data - const inputCheck = this.utils.checkInputGuard(this.model, args.data, 'create'); + const inputCheck = this.policyUtils.checkInputGuard(this.model, args.data, 'create'); if (inputCheck === false) { - throw this.utils.deniedByPolicy(this.model, 'create', undefined, CrudFailureReason.ACCESS_POLICY_VIOLATION); + throw this.policyUtils.deniedByPolicy( + this.model, + 'create', + undefined, + CrudFailureReason.ACCESS_POLICY_VIOLATION + ); } const hasNestedCreateOrConnect = await this.hasNestedCreateOrConnect(args); - const { result, error } = await this.transaction(async (tx) => { + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { if ( // MUST check true here since inputCheck can be undefined (meaning static input check not possible) inputCheck === true && @@ -249,10 +263,10 @@ export class PolicyProxyHandler implements Pr // there's no nested write and we've passed input check, proceed with the create directly // validate zod schema if any - this.validateCreateInputSchema(this.model, args.data); + args.data = this.validateCreateInputSchema(this.model, args.data); // make a create args only containing data and ID selection - const createArgs: any = { data: args.data, select: this.utils.makeIdSelection(this.model) }; + const createArgs: any = { data: args.data, select: this.policyUtils.makeIdSelection(this.model) }; if (this.shouldLogQuery) { this.logger.info(`[policy] \`create\` ${this.model}: ${formatObject(createArgs)}`); @@ -260,7 +274,7 @@ export class PolicyProxyHandler implements Pr const result = await tx[this.model].create(createArgs); // filter the read-back data - return this.utils.readBack(tx, this.model, 'create', args, result); + return this.policyUtils.readBack(tx, this.model, 'create', args, result); } else { // proceed with a complex create and collect post-write checks const { result, postWriteChecks } = await this.doCreate(this.model, args, tx); @@ -269,7 +283,7 @@ export class PolicyProxyHandler implements Pr await this.runPostWriteChecks(postWriteChecks, tx); // filter the read-back data - return this.utils.readBack(tx, this.model, 'create', origArgs, result); + return this.policyUtils.readBack(tx, this.model, 'create', origArgs, result); } }); @@ -281,7 +295,7 @@ export class PolicyProxyHandler implements Pr } // create with nested write - private async doCreate(model: string, args: any, db: Record) { + private async doCreate(model: string, args: any, db: CrudContract) { // record id fields involved in the nesting context const idSelections: Array<{ path: FieldInfo[]; ids: string[] }> = []; const pushIdFields = (model: string, context: NestedWriteVisitorContext) => { @@ -305,23 +319,33 @@ export class PolicyProxyHandler implements Pr // visit the create payload const visitor = new NestedWriteVisitor(this.modelMeta, { create: async (model, args, context) => { - this.validateCreateInputSchema(model, args); + const validateResult = this.validateCreateInputSchema(model, args); + if (validateResult !== args) { + this.policyUtils.replace(args, validateResult); + } pushIdFields(model, context); }, createMany: async (model, args, context) => { - enumerate(args.data).forEach((item) => this.validateCreateInputSchema(model, item)); + enumerate(args.data).forEach((item) => { + const r = this.validateCreateInputSchema(model, item); + if (r !== item) { + this.policyUtils.replace(item, r); + } + }); pushIdFields(model, context); }, connectOrCreate: async (model, args, context) => { if (!args.where) { - throw this.utils.validationError(`'where' field is required for connectOrCreate`); + throw this.policyUtils.validationError(`'where' field is required for connectOrCreate`); } - this.validateCreateInputSchema(model, args.create); + if (args.create) { + args.create = this.validateCreateInputSchema(model, args.create); + } - const existing = await this.utils.checkExistence(db, model, args.where); + const existing = await this.policyUtils.checkExistence(db, model, args.where); if (existing) { // connect case if (context.field?.backLink) { @@ -329,7 +353,7 @@ export class PolicyProxyHandler implements Pr if (backLinkField?.isRelationOwner) { // the target side of relation owns the relation, // check if it's updatable - await this.utils.checkPolicyForUnique(model, args.where, 'update', db, args); + await this.policyUtils.checkPolicyForUnique(model, args.where, 'update', db, args); } } @@ -363,18 +387,18 @@ export class PolicyProxyHandler implements Pr connect: async (model, args, context) => { if (!args || typeof args !== 'object' || Object.keys(args).length === 0) { - throw this.utils.validationError(`'connect' field must be an non-empty object`); + throw this.policyUtils.validationError(`'connect' field must be an non-empty object`); } if (context.field?.backLink) { const backLinkField = resolveField(this.modelMeta, model, context.field.backLink); if (backLinkField?.isRelationOwner) { // check existence - await this.utils.checkExistence(db, model, args, true); + await this.policyUtils.checkExistence(db, model, args, true); // the target side of relation owns the relation, // check if it's updatable - await this.utils.checkPolicyForUnique(model, args, 'update', db, args); + await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, args); } } }, @@ -419,7 +443,7 @@ export class PolicyProxyHandler implements Pr }); // return only the ids of the top-level entity - const ids = this.utils.getEntityIds(this.model, result); + const ids = this.policyUtils.getEntityIds(this.model, result); return { result: ids, postWriteChecks: [...postCreateChecks.values()] }; } @@ -456,11 +480,11 @@ export class PolicyProxyHandler implements Pr // Validates the given create payload against Zod schema if any private validateCreateInputSchema(model: string, data: any) { - const schema = this.utils.getZodSchema(model, 'create'); + const schema = this.policyUtils.getZodSchema(model, 'create'); if (schema) { const parseResult = schema.safeParse(data); if (!parseResult.success) { - throw this.utils.deniedByPolicy( + throw this.policyUtils.deniedByPolicy( model, 'create', `input failed validation: ${fromZodError(parseResult.error)}`, @@ -468,34 +492,44 @@ export class PolicyProxyHandler implements Pr parseResult.error ); } + return parseResult.data; + } else { + return data; } } async createMany(args: { data: any; skipDuplicates?: boolean }) { if (!args) { - throw prismaClientValidationError(this.prisma, 'query argument is required'); + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } if (!args.data) { - throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'data field is required in query argument' + ); } - this.utils.tryReject(this.prisma, this.model, 'create'); + this.policyUtils.tryReject(this.prisma, this.model, 'create'); - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // do static input validation and check if post-create checks are needed let needPostCreateCheck = false; for (const item of enumerate(args.data)) { - const inputCheck = this.utils.checkInputGuard(this.model, item, 'create'); + const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create'); if (inputCheck === false) { - throw this.utils.deniedByPolicy( + throw this.policyUtils.deniedByPolicy( this.model, 'create', undefined, CrudFailureReason.ACCESS_POLICY_VIOLATION ); } else if (inputCheck === true) { - this.validateCreateInputSchema(this.model, item); + const r = this.validateCreateInputSchema(this.model, item); + if (r !== item) { + this.policyUtils.replace(item, r); + } } else if (inputCheck === undefined) { // static policy check is not possible, need to do post-create check needPostCreateCheck = true; @@ -507,7 +541,7 @@ export class PolicyProxyHandler implements Pr return this.modelClient.createMany(args); } else { // create entities in a transaction with post-create checks - return this.transaction(async (tx) => { + return this.queryUtils.transaction(this.prisma, async (tx) => { const { result, postWriteChecks } = await this.doCreateMany(this.model, args, tx); // post-create check await this.runPostWriteChecks(postWriteChecks, tx); @@ -516,11 +550,7 @@ export class PolicyProxyHandler implements Pr } } - private async doCreateMany( - model: string, - args: { data: any; skipDuplicates?: boolean }, - db: Record - ) { + private async doCreateMany(model: string, args: { data: any; skipDuplicates?: boolean }, db: CrudContract) { // We can't call the native "createMany" because we can't get back what was created // for post-create checks. Instead, do a "create" for each item and collect the results. @@ -538,7 +568,7 @@ export class PolicyProxyHandler implements Pr if (this.shouldLogQuery) { this.logger.info(`[policy] \`create\` for \`createMany\` ${model}: ${formatObject(item)}`); } - return await db[model].create({ select: this.utils.makeIdSelection(model), data: item }); + return await db[model].create({ select: this.policyUtils.makeIdSelection(model), data: item }); }) ); @@ -555,18 +585,18 @@ export class PolicyProxyHandler implements Pr }; } - private async hasDuplicatedUniqueConstraint(model: string, createData: any, db: Record) { + private async hasDuplicatedUniqueConstraint(model: string, createData: any, db: CrudContract) { // check unique constraint conflicts // we can't rely on try/catch/ignore constraint violation error: https://github.com/prisma/prisma/issues/20496 // TODO: for simple cases we should be able to translate it to an `upsert` with empty `update` payload // for each unique constraint, check if the input item has all fields set, and if so, check if // an entity already exists, and ignore accordingly - const uniqueConstraints = this.utils.getUniqueConstraints(model); + const uniqueConstraints = this.policyUtils.getUniqueConstraints(model); for (const constraint of Object.values(uniqueConstraints)) { if (constraint.fields.every((f) => createData[f] !== undefined)) { const uniqueFilter = constraint.fields.reduce((acc, f) => ({ ...acc, [f]: createData[f] }), {}); - const existing = await this.utils.checkExistence(db, model, uniqueFilter); + const existing = await this.policyUtils.checkExistence(db, model, uniqueFilter); if (existing) { return true; } @@ -587,18 +617,26 @@ export class PolicyProxyHandler implements Pr async update(args: any) { if (!args) { - throw prismaClientValidationError(this.prisma, 'query argument is required'); + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } if (!args.where) { - throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'where field is required in query argument' + ); } if (!args.data) { - throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'data field is required in query argument' + ); } - args = this.utils.clone(args); + args = this.policyUtils.clone(args); - const { result, error } = await this.transaction(async (tx) => { + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { // proceed with nested writes and collect post-write checks const { result, postWriteChecks } = await this.doUpdate(args, tx); @@ -606,7 +644,7 @@ export class PolicyProxyHandler implements Pr await this.runPostWriteChecks(postWriteChecks, tx); // filter the read-back data - return this.utils.readBack(tx, this.model, 'update', args, result); + return this.policyUtils.readBack(tx, this.model, 'update', args, result); }); if (error) { @@ -616,17 +654,17 @@ export class PolicyProxyHandler implements Pr } } - private async doUpdate(args: any, db: Record) { + private async doUpdate(args: any, db: CrudContract) { // collected post-update checks const postWriteChecks: PostWriteCheckRecord[] = []; // registers a post-update check task const _registerPostUpdateCheck = async (model: string, uniqueFilter: any) => { // both "post-update" rules and Zod schemas require a post-update check - if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) { + if (this.policyUtils.hasAuthGuard(model, 'postUpdate') || this.policyUtils.getZodSchema(model)) { // select pre-update field values let preValue: any; - const preValueSelect = this.utils.getPreValueSelect(model); + const preValueSelect = this.policyUtils.getPreValueSelect(model); if (preValueSelect && Object.keys(preValueSelect).length > 0) { preValue = await db[model].findFirst({ where: uniqueFilter, select: preValueSelect }); } @@ -653,7 +691,7 @@ export class PolicyProxyHandler implements Pr const unsafe = this.isUnsafeMutate(model, args); // handles the connection to upstream entity - const reversedQuery = this.utils.buildReversedQuery(context, true, unsafe); + const reversedQuery = this.policyUtils.buildReversedQuery(context, true, unsafe); if ((!unsafe || context.field.isRelationOwner) && reversedQuery[context.field.backLink]) { // if mutation is safe, or current field owns the relation (so the other side has no fk), // and the reverse query contains the back link, then we can build a "connect" with it @@ -688,7 +726,7 @@ export class PolicyProxyHandler implements Pr // for example when it's nested inside a one-to-one update const upstreamQuery = { where: reversedQuery[backLinkField.name], - select: this.utils.makeIdSelection(backLinkField.type), + select: this.policyUtils.makeIdSelection(backLinkField.type), }; // fetch the upstream entity @@ -738,8 +776,8 @@ export class PolicyProxyHandler implements Pr const _connectDisconnect = async (model: string, args: any, context: NestedWriteVisitorContext) => { if (context.field?.backLink) { - const backLinkField = this.utils.getModelField(model, context.field.backLink); - if (backLinkField.isRelationOwner) { + const backLinkField = this.policyUtils.getModelField(model, context.field.backLink); + if (backLinkField?.isRelationOwner) { // update happens on the related model, require updatable, // translate args to foreign keys so field-level policies can be checked const checkArgs: any = {}; @@ -751,7 +789,7 @@ export class PolicyProxyHandler implements Pr } } } - await this.utils.checkPolicyForUnique(model, args, 'update', db, checkArgs); + await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, checkArgs); // register post-update check await _registerPostUpdateCheck(model, args); @@ -763,10 +801,10 @@ export class PolicyProxyHandler implements Pr const visitor = new NestedWriteVisitor(this.modelMeta, { update: async (model, args, context) => { // build a unique query including upstream conditions - const uniqueFilter = this.utils.buildReversedQuery(context); + const uniqueFilter = this.policyUtils.buildReversedQuery(context); // handle not-found - const existing = await this.utils.checkExistence(db, model, uniqueFilter, true); + const existing = await this.policyUtils.checkExistence(db, model, uniqueFilter, true); // check if the update actually writes to this model let thisModelUpdate = false; @@ -789,13 +827,13 @@ export class PolicyProxyHandler implements Pr } if (thisModelUpdate) { - this.utils.tryReject(db, this.model, 'update'); + this.policyUtils.tryReject(db, this.model, 'update'); // check pre-update guard - await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); + await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); // handles the case where id fields are updated - const ids = this.utils.clone(existing); + const ids = this.policyUtils.clone(existing); for (const key of Object.keys(existing)) { const updateValue = (args as any).data ? (args as any).data[key] : (args as any)[key]; if ( @@ -814,15 +852,15 @@ export class PolicyProxyHandler implements Pr updateMany: async (model, args, context) => { // prepare for post-update check - if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) { - let select = this.utils.makeIdSelection(model); - const preValueSelect = this.utils.getPreValueSelect(model); + if (this.policyUtils.hasAuthGuard(model, 'postUpdate') || this.policyUtils.getZodSchema(model)) { + let select = this.policyUtils.makeIdSelection(model); + const preValueSelect = this.policyUtils.getPreValueSelect(model); if (preValueSelect) { select = { ...select, ...preValueSelect }; } - const reversedQuery = this.utils.buildReversedQuery(context); + const reversedQuery = this.policyUtils.buildReversedQuery(context); const currentSetQuery = { select, where: reversedQuery }; - this.utils.injectAuthGuardAsWhere(db, currentSetQuery, model, 'read'); + this.policyUtils.injectAuthGuardAsWhere(db, currentSetQuery, model, 'read'); if (this.shouldLogQuery) { this.logger.info( @@ -841,15 +879,15 @@ export class PolicyProxyHandler implements Pr ); } - const updateGuard = this.utils.getAuthGuard(db, model, 'update'); - if (this.utils.isTrue(updateGuard) || this.utils.isFalse(updateGuard)) { + const updateGuard = this.policyUtils.getAuthGuard(db, model, 'update'); + if (this.policyUtils.isTrue(updateGuard) || this.policyUtils.isFalse(updateGuard)) { // injects simple auth guard into where clause - this.utils.injectAuthGuardAsWhere(db, args, model, 'update'); + this.policyUtils.injectAuthGuardAsWhere(db, args, model, 'update'); } else { // we have to process `updateMany` separately because the guard may contain // filters using relation fields which are not allowed in nested `updateMany` - const reversedQuery = this.utils.buildReversedQuery(context); - const updateWhere = this.utils.and(reversedQuery, updateGuard); + const reversedQuery = this.policyUtils.buildReversedQuery(context); + const updateWhere = this.policyUtils.and(reversedQuery, updateGuard); if (this.shouldLogQuery) { this.logger.info( `[policy] \`updateMany\` ${model}:\n${formatObject({ @@ -887,15 +925,15 @@ export class PolicyProxyHandler implements Pr upsert: async (model, args, context) => { // build a unique query including upstream conditions - const uniqueFilter = this.utils.buildReversedQuery(context); + const uniqueFilter = this.policyUtils.buildReversedQuery(context); // branch based on if the update target exists - const existing = await this.utils.checkExistence(db, model, uniqueFilter); + const existing = await this.policyUtils.checkExistence(db, model, uniqueFilter); if (existing) { // update case // check pre-update guard - await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); + await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); // register post-update check await _registerPostUpdateCheck(model, uniqueFilter); @@ -924,7 +962,7 @@ export class PolicyProxyHandler implements Pr connectOrCreate: async (model, args, context) => { // the where condition is already unique, so we can use it to check if the target exists - const existing = await this.utils.checkExistence(db, model, args.where); + const existing = await this.policyUtils.checkExistence(db, model, args.where); if (existing) { // connect await _connectDisconnect(model, args.where, context); @@ -938,9 +976,9 @@ export class PolicyProxyHandler implements Pr set: async (model, args, context) => { // find the set of items to be replaced - const reversedQuery = this.utils.buildReversedQuery(context); + const reversedQuery = this.policyUtils.buildReversedQuery(context); const findCurrSetArgs = { - select: this.utils.makeIdSelection(model), + select: this.policyUtils.makeIdSelection(model), where: reversedQuery, }; if (this.shouldLogQuery) { @@ -957,25 +995,25 @@ export class PolicyProxyHandler implements Pr delete: async (model, args, context) => { // build a unique query including upstream conditions - const uniqueFilter = this.utils.buildReversedQuery(context); + const uniqueFilter = this.policyUtils.buildReversedQuery(context); // handle not-found - await this.utils.checkExistence(db, model, uniqueFilter, true); + await this.policyUtils.checkExistence(db, model, uniqueFilter, true); // check delete guard - await this.utils.checkPolicyForUnique(model, uniqueFilter, 'delete', db, args); + await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'delete', db, args); }, deleteMany: async (model, args, context) => { - const guard = await this.utils.getAuthGuard(db, model, 'delete'); - if (this.utils.isTrue(guard) || this.utils.isFalse(guard)) { + const guard = await this.policyUtils.getAuthGuard(db, model, 'delete'); + if (this.policyUtils.isTrue(guard) || this.policyUtils.isFalse(guard)) { // inject simple auth guard - context.parent.deleteMany = this.utils.and(args, guard); + context.parent.deleteMany = this.policyUtils.and(args, guard); } else { // we have to process `deleteMany` separately because the guard may contain // filters using relation fields which are not allowed in nested `deleteMany` - const reversedQuery = this.utils.buildReversedQuery(context); - const deleteWhere = this.utils.and(reversedQuery, guard); + const reversedQuery = this.policyUtils.buildReversedQuery(context); + const deleteWhere = this.policyUtils.and(reversedQuery, guard); if (this.shouldLogQuery) { this.logger.info(`[policy] \`deleteMany\` ${model}:\n${formatObject({ where: deleteWhere })}`); } @@ -994,7 +1032,7 @@ export class PolicyProxyHandler implements Pr const result = await db[this.model].update({ where: args.where, data: args.data, - select: this.utils.makeIdSelection(this.model), + select: this.policyUtils.makeIdSelection(this.model), }); return { result, postWriteChecks }; @@ -1006,7 +1044,7 @@ export class PolicyProxyHandler implements Pr } for (const k of Object.keys(args)) { const field = resolveField(this.modelMeta, model, k); - if (this.isAutoIncrementIdField(field) || field?.isForeignKey) { + if (field && (this.isAutoIncrementIdField(field) || field.isForeignKey)) { return true; } } @@ -1019,29 +1057,33 @@ export class PolicyProxyHandler implements Pr async updateMany(args: any) { if (!args) { - throw prismaClientValidationError(this.prisma, 'query argument is required'); + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } if (!args.data) { - throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'data field is required in query argument' + ); } - this.utils.tryReject(this.prisma, this.model, 'update'); + this.policyUtils.tryReject(this.prisma, this.model, 'update'); - args = this.utils.clone(args); - this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update'); + args = this.policyUtils.clone(args); + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update'); - if (this.utils.hasAuthGuard(this.model, 'postUpdate') || this.utils.getZodSchema(this.model)) { + if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) { // use a transaction to do post-update checks const postWriteChecks: PostWriteCheckRecord[] = []; - return this.transaction(async (tx) => { + return this.queryUtils.transaction(this.prisma, async (tx) => { // collect pre-update values - let select = this.utils.makeIdSelection(this.model); - const preValueSelect = this.utils.getPreValueSelect(this.model); + let select = this.policyUtils.makeIdSelection(this.model); + const preValueSelect = this.policyUtils.getPreValueSelect(this.model); if (preValueSelect) { select = { ...select, ...preValueSelect }; } const currentSetQuery = { select, where: args.where }; - this.utils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read'); + this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); @@ -1052,7 +1094,7 @@ export class PolicyProxyHandler implements Pr ...currentSet.map((preValue) => ({ model: this.model, operation: 'postUpdate' as PolicyOperationKind, - uniqueFilter: this.utils.getEntityIds(this.model, preValue), + uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue), preValue: preValueSelect ? preValue : undefined, })) ); @@ -1076,40 +1118,52 @@ export class PolicyProxyHandler implements Pr async upsert(args: any) { if (!args) { - throw prismaClientValidationError(this.prisma, 'query argument is required'); + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } if (!args.where) { - throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'where field is required in query argument' + ); } if (!args.create) { - throw prismaClientValidationError(this.prisma, 'create field is required in query argument'); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'create field is required in query argument' + ); } if (!args.update) { - throw prismaClientValidationError(this.prisma, 'update field is required in query argument'); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'update field is required in query argument' + ); } - this.utils.tryReject(this.prisma, this.model, 'create'); - this.utils.tryReject(this.prisma, this.model, 'update'); + this.policyUtils.tryReject(this.prisma, this.model, 'create'); + this.policyUtils.tryReject(this.prisma, this.model, 'update'); - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // We can call the native "upsert" because we can't tell if an entity was created or updated // for doing post-write check accordingly. Instead, decompose it into create or update. - const { result, error } = await this.transaction(async (tx) => { + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { const { where, create, update, ...rest } = args; - const existing = await this.utils.checkExistence(tx, this.model, args.where); + const existing = await this.policyUtils.checkExistence(tx, this.model, args.where); if (existing) { // update case const { result, postWriteChecks } = await this.doUpdate({ where, data: update, ...rest }, tx); await this.runPostWriteChecks(postWriteChecks, tx); - return this.utils.readBack(tx, this.model, 'update', args, result); + return this.policyUtils.readBack(tx, this.model, 'update', args, result); } else { // create case const { result, postWriteChecks } = await this.doCreate(this.model, { data: create, ...rest }, tx); await this.runPostWriteChecks(postWriteChecks, tx); - return this.utils.readBack(tx, this.model, 'create', args, result); + return this.policyUtils.readBack(tx, this.model, 'create', args, result); } }); @@ -1129,25 +1183,29 @@ export class PolicyProxyHandler implements Pr async delete(args: any) { if (!args) { - throw prismaClientValidationError(this.prisma, 'query argument is required'); + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } if (!args.where) { - throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'where field is required in query argument' + ); } - this.utils.tryReject(this.prisma, this.model, 'delete'); + this.policyUtils.tryReject(this.prisma, this.model, 'delete'); - const { result, error } = await this.transaction(async (tx) => { + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { // do a read-back before delete - const r = await this.utils.readBack(tx, this.model, 'delete', args, args.where); + const r = await this.policyUtils.readBack(tx, this.model, 'delete', args, args.where); const error = r.error; const read = r.result; // check existence - await this.utils.checkExistence(tx, this.model, args.where, true); + await this.policyUtils.checkExistence(tx, this.model, args.where, true); // inject delete guard - await this.utils.checkPolicyForUnique(this.model, args.where, 'delete', tx, args); + await this.policyUtils.checkPolicyForUnique(this.model, args.where, 'delete', tx, args); // proceed with the deletion if (this.shouldLogQuery) { @@ -1166,11 +1224,11 @@ export class PolicyProxyHandler implements Pr } async deleteMany(args: any) { - this.utils.tryReject(this.prisma, this.model, 'delete'); + this.policyUtils.tryReject(this.prisma, this.model, 'delete'); // inject policy conditions args = args ?? {}; - this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete'); + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete'); // conduct the deletion if (this.shouldLogQuery) { @@ -1185,13 +1243,13 @@ export class PolicyProxyHandler implements Pr async aggregate(args: any) { if (!args) { - throw prismaClientValidationError(this.prisma, 'query argument is required'); + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // inject policy conditions - this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`aggregate\` ${this.model}:\n${formatObject(args)}`); @@ -1201,13 +1259,13 @@ export class PolicyProxyHandler implements Pr async groupBy(args: any) { if (!args) { - throw prismaClientValidationError(this.prisma, 'query argument is required'); + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // inject policy conditions - this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`groupBy\` ${this.model}:\n${formatObject(args)}`); @@ -1217,8 +1275,8 @@ export class PolicyProxyHandler implements Pr async count(args: any) { // inject policy conditions - args = args ? this.utils.clone(args) : {}; - this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); + args = args ? this.policyUtils.clone(args) : {}; + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`count\` ${this.model}:\n${formatObject(args)}`); @@ -1231,8 +1289,8 @@ export class PolicyProxyHandler implements Pr //#region Subscribe (Prisma Pulse) async subscribe(args: any) { - const readGuard = this.utils.getAuthGuard(this.prisma, this.model, 'read'); - if (this.utils.isTrue(readGuard)) { + const readGuard = this.policyUtils.getAuthGuard(this.prisma, this.model, 'read'); + if (this.policyUtils.isTrue(readGuard)) { // no need to inject if (this.shouldLogQuery) { this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`); @@ -1245,28 +1303,28 @@ export class PolicyProxyHandler implements Pr args = { create: {}, update: {}, delete: {} }; } else { if (typeof args !== 'object') { - throw prismaClientValidationError(this.prisma, 'argument must be an object'); + throw prismaClientValidationError(this.prisma, this.prismaModule, 'argument must be an object'); } if (Object.keys(args).length === 0) { // include all args = { create: {}, update: {}, delete: {} }; } else { - args = this.utils.clone(args); + args = this.policyUtils.clone(args); } } // inject into subscribe conditions if (args.create) { - args.create.after = this.utils.and(args.create.after, readGuard); + args.create.after = this.policyUtils.and(args.create.after, readGuard); } if (args.update) { - args.update.after = this.utils.and(args.update.after, readGuard); + args.update.after = this.policyUtils.and(args.update.after, readGuard); } if (args.delete) { - args.delete.before = this.utils.and(args.delete.before, readGuard); + args.delete.before = this.policyUtils.and(args.delete.before, readGuard); } if (this.shouldLogQuery) { @@ -1277,49 +1335,36 @@ export class PolicyProxyHandler implements Pr //#endregion + //#region Check (added method for permissions check) + + async check(operation: CRUDOperationKind, args: any): Promise { + args = args ? this.policyUtils.clone(args) : {}; + + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`check\` ${this.model}\noperation:${operation}\nargs:${formatObject(args)}`); + } + + return this.policyUtils.checkPermissions(this.model, operation, args, this.policyUtils.user); + } + + //#endregion + //#region Utils private get shouldLogQuery() { return !!this.options?.logPrismaQuery && this.logger.enabled('info'); } - private transaction(action: (tx: Record) => Promise) { - if (this.prisma['$transaction']) { - const txOptions: any = { maxWait: this.DEFAULT_TX_MAXWAIT, timeout: this.DEFAULT_TX_TIMEOUT }; - if (this.options?.transactionMaxWait !== undefined) { - txOptions.maxWait = this.options.transactionMaxWait; - } - if (this.options?.transactionTimeout !== undefined) { - txOptions.timeout = this.options.transactionTimeout; - } - if (this.options?.transactionIsolationLevel !== undefined) { - txOptions.isolationLevel = this.options.transactionIsolationLevel; - } - return this.prisma.$transaction((tx) => action(tx), txOptions); - } else { - // already in transaction, don't nest - return action(this.prisma); - } - } - - private async runPostWriteChecks(postWriteChecks: PostWriteCheckRecord[], db: Record) { + private async runPostWriteChecks(postWriteChecks: PostWriteCheckRecord[], db: CrudContract) { await Promise.all( postWriteChecks.map(async ({ model, operation, uniqueFilter, preValue }) => - this.utils.checkPolicyForUnique(model, uniqueFilter, operation, db, undefined, preValue) + this.policyUtils.checkPolicyForUnique(model, uniqueFilter, operation, db, undefined, preValue) ) ); } private makeHandler(model: string) { - return new PolicyProxyHandler( - this.prisma, - this.policy, - this.modelMeta, - this.zodSchemas, - model, - this.user, - this.options - ); + return new PolicyProxyHandler(this.prisma, model, this.options, this.context); } private requireBackLink(fieldInfo: FieldInfo) { diff --git a/packages/runtime/src/enhancements/policy/index.ts b/packages/runtime/src/enhancements/policy/index.ts index d4380d72b..e197e18c1 100644 --- a/packages/runtime/src/enhancements/policy/index.ts +++ b/packages/runtime/src/enhancements/policy/index.ts @@ -1,78 +1,12 @@ -/* eslint-disable @typescript-eslint/no-var-requires */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import semver from 'semver'; -import { PRISMA_MINIMUM_VERSION } from '../../constants'; -import { getIdFields, type ModelMeta } from '../../cross'; -import { getDefaultModelMeta, getDefaultPolicy, getDefaultZodSchemas } from '../../loader'; -import { AuthUser, DbClientContract } from '../../types'; +import { getIdFields } from '../../cross'; +import { DbClientContract } from '../../types'; import { hasAllFields } from '../../validation'; -import { ErrorTransformer, makeProxy } from '../proxy'; -import type { CommonEnhancementOptions, PolicyDef, ZodSchemas } from '../types'; +import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; +import { makeProxy } from '../proxy'; import { PolicyProxyHandler } from './handler'; -/** - * Context for evaluating access policies - */ -export type WithPolicyContext = { - user?: AuthUser; -}; - -/** - * Transaction isolation levels: https://www.prisma.io/docs/orm/prisma-client/queries/transactions#transaction-isolation-level - */ -export type TransactionIsolationLevel = - | 'ReadUncommitted' - | 'ReadCommitted' - | 'RepeatableRead' - | 'Snapshot' - | 'Serializable'; - -/** - * Options for @see withPolicy - */ -export interface WithPolicyOptions extends CommonEnhancementOptions { - /** - * Policy definition - */ - policy?: PolicyDef; - - /** - * Model metadata - */ - modelMeta?: ModelMeta; - - /** - * Zod schemas for validation - */ - zodSchemas?: ZodSchemas; - - /** - * Whether to log Prisma query - */ - logPrismaQuery?: boolean; - - /** - * Hook for transforming errors before they are thrown to the caller. - */ - errorTransformer?: ErrorTransformer; - - /** - * The `maxWait` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. - */ - transactionMaxWait?: number; - - /** - * The `timeout` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. - */ - transactionTimeout?: number; - - /** - * The `isolationLevel` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. - */ - transactionIsolationLevel?: TransactionIsolationLevel; -} - /** * Gets an enhanced Prisma client with access policy check. * @@ -81,32 +15,19 @@ export interface WithPolicyOptions extends CommonEnhancementOptions { * @param policy The policy definition, will be loaded from default location if not provided * @param modelMeta The model metadata, will be loaded from default location if not provided * - * @deprecated Use {@link enhance} instead + * @private */ export function withPolicy( prisma: DbClient, - context?: WithPolicyContext, - options?: WithPolicyOptions + options: InternalEnhancementOptions, + context?: EnhancementContext ): DbClient { - if (!prisma) { - throw new Error('Invalid prisma instance'); - } - - const prismaVer = (prisma as any)._clientVersion; - if (prismaVer && semver.lt(prismaVer, PRISMA_MINIMUM_VERSION)) { - console.warn( - `ZenStack requires Prisma version "${PRISMA_MINIMUM_VERSION}" or higher. Detected version is "${prismaVer}".` - ); - } - - const _policy = options?.policy ?? getDefaultPolicy(options?.loadPath); - const _modelMeta = options?.modelMeta ?? getDefaultModelMeta(options?.loadPath); - const _zodSchemas = options?.zodSchemas ?? getDefaultZodSchemas(options?.loadPath); + const { modelMeta, policy } = options; // validate user context const userContext = context?.user; - if (userContext && _modelMeta.authModel) { - const idFields = getIdFields(_modelMeta, _modelMeta.authModel); + if (userContext && modelMeta.authModel) { + const idFields = getIdFields(modelMeta, modelMeta.authModel); if ( !hasAllFields( context.user, @@ -119,7 +40,7 @@ export function withPolicy( } // validate user context for fields used in policy expressions - const authSelector = _policy.authSelector; + const authSelector = policy.authSelector; if (authSelector) { Object.keys(authSelector).forEach((f) => { if (!(f in userContext)) { @@ -131,17 +52,8 @@ export function withPolicy( return makeProxy( prisma, - _modelMeta, - (_prisma, model) => - new PolicyProxyHandler( - _prisma as DbClientContract, - _policy, - _modelMeta, - _zodSchemas, - model, - context?.user, - options - ), + modelMeta, + (_prisma, model) => new PolicyProxyHandler(_prisma as DbClientContract, model, options, context), 'policy', options?.errorTransformer ); diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 388f9cd90..fd1616c2c 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -16,45 +16,44 @@ import { PRE_UPDATE_VALUE_SELECTOR, PrismaErrorCode, } from '../../constants'; -import { - enumerate, - getFields, - getIdFields, - getModelFields, - resolveField, - zip, - type FieldInfo, - type ModelMeta, - type NestedWriteVisitorContext, -} from '../../cross'; -import { AuthUser, DbClientContract, DbOperations, PolicyOperationKind } from '../../types'; +import { enumerate, getFields, getModelFields, resolveField, zip, type FieldInfo, type ModelMeta } from '../../cross'; +import { AuthUser, CRUDOperationKind, CrudContract, DbClientContract, PolicyOperationKind } from '../../types'; import { getVersion } from '../../version'; +import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; +import { Logger } from '../logger'; +import { QueryUtils } from '../query-utils'; import type { InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; -import { - formatObject, - prismaClientKnownRequestError, - prismaClientUnknownRequestError, - prismaClientValidationError, -} from '../utils'; -import { Logger } from './logger'; +import { formatObject, prismaClientKnownRequestError } from '../utils'; +import { init } from 'z3-solver'; /** * Access policy enforcement utilities */ -export class PolicyUtil { - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-ignore +export class PolicyUtil extends QueryUtils { private readonly logger: Logger; + private readonly modelMeta: ModelMeta; + private readonly policy: PolicyDef; + private readonly zodSchemas?: ZodSchemas; + private readonly prismaModule: any; + public readonly user?: AuthUser; constructor( private readonly db: DbClientContract, - private readonly modelMeta: ModelMeta, - private readonly policy: PolicyDef, - private readonly zodSchemas: ZodSchemas | undefined, - private readonly user?: AuthUser, + options: InternalEnhancementOptions, + context?: EnhancementContext, private readonly shouldLogQuery = false ) { + super(db, options); + this.logger = new Logger(db); + this.user = context?.user; + + ({ + modelMeta: this.modelMeta, + policy: this.policy, + zodSchemas: this.zodSchemas, + prismaModule: this.prismaModule, + } = options); } //#region Logical operators @@ -238,7 +237,7 @@ export class PolicyUtil { * @returns true if operation is unconditionally allowed, false if unconditionally denied, * otherwise returns a guard object */ - getAuthGuard(db: Record, model: string, operation: PolicyOperationKind, preValue?: any) { + getAuthGuard(db: CrudContract, model: string, operation: PolicyOperationKind, preValue?: any) { const guard = this.policy.guard[lowerCaseFirst(model)]; if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); @@ -259,7 +258,7 @@ export class PolicyUtil { /** * Get field-level read auth guard that overrides the model-level */ - getFieldOverrideReadAuthGuard(db: Record, model: string, field: string) { + getFieldOverrideReadAuthGuard(db: CrudContract, model: string, field: string) { const guard = this.requireGuard(model); const provider = guard[`${FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX}${field}`]; @@ -279,7 +278,7 @@ export class PolicyUtil { /** * Get field-level update auth guard */ - getFieldUpdateAuthGuard(db: Record, model: string, field: string) { + getFieldUpdateAuthGuard(db: CrudContract, model: string, field: string) { const guard = this.requireGuard(model); const provider = guard[`${FIELD_LEVEL_UPDATE_GUARD_PREFIX}${field}`]; @@ -299,7 +298,7 @@ export class PolicyUtil { /** * Get field-level update auth guard that overrides the model-level */ - getFieldOverrideUpdateAuthGuard(db: Record, model: string, field: string) { + getFieldOverrideUpdateAuthGuard(db: CrudContract, model: string, field: string) { const guard = this.requireGuard(model); const provider = guard[`${FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX}${field}`]; @@ -355,7 +354,7 @@ export class PolicyUtil { /** * Injects model auth guard as where clause. */ - injectAuthGuardAsWhere(db: Record, args: any, model: string, operation: PolicyOperationKind) { + injectAuthGuardAsWhere(db: CrudContract, args: any, model: string, operation: PolicyOperationKind) { let guard = this.getAuthGuard(db, model, operation); if (operation === 'update' && args) { @@ -403,7 +402,7 @@ export class PolicyUtil { } private injectGuardForRelationFields( - db: Record, + db: CrudContract, model: string, payload: any, operation: PolicyOperationKind @@ -427,7 +426,7 @@ export class PolicyUtil { } private injectGuardForToManyField( - db: Record, + db: CrudContract, fieldInfo: FieldInfo, payload: { some?: any; every?: any; none?: any }, operation: PolicyOperationKind @@ -461,7 +460,7 @@ export class PolicyUtil { } private injectGuardForToOneField( - db: Record, + db: CrudContract, fieldInfo: FieldInfo, payload: { is?: any; isNot?: any } & Record, operation: PolicyOperationKind @@ -491,7 +490,7 @@ export class PolicyUtil { /** * Injects auth guard for read operations. */ - injectForRead(db: Record, model: string, args: any) { + injectForRead(db: CrudContract, model: string, args: any) { // make select and include visible to the injection const injected: any = { select: args.select, include: args.include }; if (!this.injectAuthGuardAsWhere(db, injected, model, 'read')) { @@ -529,111 +528,14 @@ export class PolicyUtil { return true; } - // flatten unique constraint filters - private flattenGeneratedUniqueField(model: string, args: any) { - // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' } - const uniqueConstraints = this.modelMeta.uniqueConstraints?.[lowerCaseFirst(model)]; - if (uniqueConstraints && Object.keys(uniqueConstraints).length > 0) { - for (const [field, value] of Object.entries(args)) { - if ( - uniqueConstraints[field] && - uniqueConstraints[field].fields.length > 1 && - typeof value === 'object' - ) { - // multi-field unique constraint, flatten it - delete args[field]; - if (value) { - for (const [f, v] of Object.entries(value)) { - args[f] = v; - } - } - } - } - } - } - /** * Gets unique constraints for the given model. */ getUniqueConstraints(model: string) { - return this.modelMeta.uniqueConstraints?.[lowerCaseFirst(model)] ?? {}; - } - - /** - * Builds a reversed query for the given nested path. - */ - buildReversedQuery(context: NestedWriteVisitorContext, forMutationPayload = false, unsafeOperation = false) { - let result, currQuery: any; - let currField: FieldInfo | undefined; - - for (let i = context.nestingPath.length - 1; i >= 0; i--) { - const { field, model, where } = context.nestingPath[i]; - - // never modify the original where because it's shared in the structure - const visitWhere = { ...where }; - if (model && where) { - // make sure composite unique condition is flattened - this.flattenGeneratedUniqueField(model, visitWhere); - } - - if (!result) { - // first segment (bottom), just use its where clause - result = currQuery = { ...visitWhere }; - currField = field; - } else { - if (!currField) { - throw this.unknownError(`missing field in nested path`); - } - if (!currField.backLink) { - throw this.unknownError(`field ${currField.type}.${currField.name} doesn't have a backLink`); - } - - const backLinkField = this.getModelField(currField.type, currField.backLink); - if (!backLinkField) { - throw this.unknownError(`missing backLink field ${currField.backLink} in ${currField.type}`); - } - - if (backLinkField.isArray && !forMutationPayload) { - // many-side of relationship, wrap with "some" query - currQuery[currField.backLink] = { some: { ...visitWhere } }; - currQuery = currQuery[currField.backLink].some; - } else { - const fkMapping = where && backLinkField.isRelationOwner && backLinkField.foreignKeyMapping; - - // calculate if we should preserve the relation condition (e.g., { user: { id: 1 } }) - const shouldPreserveRelationCondition = - // doing a mutation - forMutationPayload && - // and it's a safe mutate - !unsafeOperation && - // and the current segment is the direct parent (the last one is the mutate itself), - // the relation condition should be preserved and will be converted to a "connect" later - i === context.nestingPath.length - 2; - - if (fkMapping && !shouldPreserveRelationCondition) { - // turn relation condition into foreign key condition, e.g.: - // { user: { id: 1 } } => { userId: 1 } - for (const [r, fk] of Object.entries(fkMapping)) { - currQuery[fk] = visitWhere[r]; - } - - if (i > 0) { - // prepare for the next segment - currQuery[currField.backLink] = {}; - } - } else { - // preserve the original structure - currQuery[currField.backLink] = { ...visitWhere }; - } - currQuery = currQuery[currField.backLink]; - } - currField = field; - } - } - return result; + return this.modelMeta.models[lowerCaseFirst(model)]?.uniqueConstraints ?? {}; } - private injectNestedReadConditions(db: Record, model: string, args: any): any[] { + private injectNestedReadConditions(db: CrudContract, model: string, args: any): any[] { const injectTarget = args.select ?? args.include; if (!injectTarget) { return []; @@ -726,7 +628,7 @@ export class PolicyUtil { model: string, uniqueFilter: any, operation: PolicyOperationKind, - db: Record, + db: CrudContract, args: any, preValue?: any ) { @@ -820,7 +722,7 @@ export class PolicyUtil { } } - private getFieldReadGuards(db: Record, model: string, args: { select?: any; include?: any }) { + private getFieldReadGuards(db: CrudContract, model: string, args: { select?: any; include?: any }) { const allFields = Object.values(getFields(this.modelMeta, model)); // all scalar fields by default @@ -843,7 +745,7 @@ export class PolicyUtil { return this.and(...allFieldGuards); } - private getFieldUpdateGuards(db: Record, model: string, args: any) { + private getFieldUpdateGuards(db: CrudContract, model: string, args: any) { const allFieldGuards = []; const allOverrideFieldGuards = []; @@ -902,7 +804,7 @@ export class PolicyUtil { /** * Tries rejecting a request based on static "false" policy. */ - tryReject(db: Record, model: string, operation: PolicyOperationKind) { + tryReject(db: CrudContract, model: string, operation: PolicyOperationKind) { const guard = this.getAuthGuard(db, model, operation); if (this.isFalse(guard)) { throw this.deniedByPolicy(model, operation, undefined, CrudFailureReason.ACCESS_POLICY_VIOLATION); @@ -912,12 +814,7 @@ export class PolicyUtil { /** * Checks if a model exists given a unique filter. */ - async checkExistence( - db: Record, - model: string, - uniqueFilter: any, - throwIfNotFound = false - ): Promise { + async checkExistence(db: CrudContract, model: string, uniqueFilter: any, throwIfNotFound = false): Promise { uniqueFilter = this.clone(uniqueFilter); this.flattenGeneratedUniqueField(model, uniqueFilter); @@ -938,7 +835,7 @@ export class PolicyUtil { * Returns an entity given a unique filter with read policy checked. Reject if not readable. */ async readBack( - db: Record, + db: CrudContract, model: string, operation: PolicyOperationKind, selectInclude: { select?: any; include?: any }, @@ -1049,7 +946,7 @@ export class PolicyUtil { } private makeAllScalarFieldSelect(model: string): any { - const fields = this.modelMeta.fields[lowerCaseFirst(model)]; + const fields = this.getModelFields(model); const result: any = {}; if (fields) { Object.entries(fields).forEach(([k, v]) => { @@ -1083,28 +980,19 @@ export class PolicyUtil { return prismaClientKnownRequestError( this.db, + this.prismaModule, `denied by policy: ${model} entities failed '${operation}' check${extra ? ', ' + extra : ''}`, args ); } notFound(model: string) { - return prismaClientKnownRequestError(this.db, `entity not found for model ${model}`, { + return prismaClientKnownRequestError(this.db, this.prismaModule, `entity not found for model ${model}`, { clientVersion: getVersion(), code: 'P2025', }); } - validationError(message: string) { - return prismaClientValidationError(this.db, message); - } - - unknownError(message: string) { - return prismaClientUnknownRequestError(this.db, message, { - clientVersion: getVersion(), - }); - } - //#endregion //#region Misc @@ -1254,26 +1142,31 @@ export class PolicyUtil { } /** - * Gets information for all fields of a model. + * Clones an object and makes sure it's not empty. */ - getModelFields(model: string) { - model = lowerCaseFirst(model); - return this.modelMeta.fields[model]; + clone(value: unknown): any { + return value ? deepcopy(value) : {}; } /** - * Gets information for a specific model field. + * Replace content of `target` object with `withObject` in-place. */ - getModelField(model: string, field: string) { - model = lowerCaseFirst(model); - return this.modelMeta.fields[model]?.[field]; - } + replace(target: any, withObject: any) { + if (!target || typeof target !== 'object' || !withObject || typeof withObject !== 'object') { + return; + } - /** - * Clones an object and makes sure it's not empty. - */ - clone(value: unknown): any { - return value ? deepcopy(value) : {}; + // remove missing keys + for (const key of Object.keys(target)) { + if (!(key in withObject)) { + delete target[key]; + } + } + + // overwrite keys + for (const [key, value] of Object.entries(withObject)) { + target[key] = value; + } } /** @@ -1289,33 +1182,6 @@ export class PolicyUtil { }, {} as any); } - /** - * Gets "id" fields for a given model. - */ - getIdFields(model: string) { - return getIdFields(this.modelMeta, model, true); - } - - /** - * Gets id field values from an entity. - */ - getEntityIds(model: string, entityData: any) { - const idFields = this.getIdFields(model); - const result: Record = {}; - for (const idField of idFields) { - result[idField.name] = entityData[idField.name]; - } - return result; - } - - /** - * Creates a selection object for id fields for the given model. - */ - makeIdSelection(model: string) { - const idFields = this.getIdFields(model); - return Object.assign({}, ...idFields.map((f) => ({ [f.name]: true }))); - } - private mergeWhereClause(where: any, extra: any) { if (!where) { throw new Error('invalid where clause'); @@ -1351,4 +1217,74 @@ export class PolicyUtil { } //#endregion + + //#region Permissions + + /** + * Checks permissions for the given operation + */ + async checkPermissions( + model: string, + operation: CRUDOperationKind, + args: any, + user: AuthUser | undefined + ): Promise { + const checkPermission = this.policy.permission?.[model][operation]; + if (!checkPermission) { + throw this.unknownError(`unable to load permission checker for model ${model} and operation ${operation}`); + } + const { Context, em } = await init(); + const z3 = Context('main'); + const result = await checkPermission(z3, args, user); + await this.killThreads(em); + return result; + } + + private delay(ms: number): Promise & { cancel(): void }; + private delay(ms: number, result: Error): Promise & { cancel(): void }; + private delay(ms: number, result: T): Promise & { cancel(): void }; + private delay(ms: number, result?: T | Error): Promise & { cancel(): void } { + let handle: any; + const promise = new Promise( + (resolve, reject) => + (handle = setTimeout(() => { + if (result instanceof Error) { + reject(result); + } else if (result !== undefined) { + resolve(result); + } + resolve(); + }, ms)) + ); + return { ...promise, cancel: () => clearTimeout(handle) }; + } + + private waitWhile(premise: () => boolean, pollMs = 100): Promise & { cancel(): void } { + let handle: any; + const promise = new Promise((resolve) => { + handle = setInterval(() => { + if (premise()) { + clearTimeout(handle); + resolve(); + } + }, pollMs); + }); + return { ...promise, cancel: () => clearInterval(handle) }; + } + + // exit process: https://github.com/Z3Prover/z3/issues/7070#issuecomment-1871017371 + private killThreads(em: any): Promise { + em.PThread.terminateAllThreads(); + + // Create a polling lock to wait for threads to return + const lockPromise = this.waitWhile(() => !em.PThread.unusedWorkers.length && !em.PThread.runningWorkers.length); + const delayPromise = this.delay(5000, new Error('Waiting for threads to be killed timed out')); + + return Promise.race([lockPromise, delayPromise]).then(() => { + lockPromise.cancel(); + delayPromise.cancel(); + }); + } + + //#endregion } diff --git a/packages/runtime/src/enhancements/preset.ts b/packages/runtime/src/enhancements/preset.ts deleted file mode 100644 index 0123dbe64..000000000 --- a/packages/runtime/src/enhancements/preset.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { EnhancementOptions, enhance } from './enhance'; -import { WithPolicyContext } from './policy'; - -/** - * Gets a Prisma client enhanced with all essential behaviors, including access - * policy, field validation, field omission and password hashing. - * - * It's a shortcut for calling withOmit(withPassword(withPolicy(prisma, options))). - * - * @param prisma The Prisma client to enhance. - * @param context The context to for evaluating access policies. - * @param options Options. - * - * @deprecated Use {@link enhance} instead - */ -export function withPresets( - prisma: DbClient, - context?: WithPolicyContext, - options?: EnhancementOptions -) { - return enhance(prisma, context, options); -} diff --git a/packages/runtime/src/enhancements/proxy.ts b/packages/runtime/src/enhancements/proxy.ts index c735d595a..6cc2bd591 100644 --- a/packages/runtime/src/enhancements/proxy.ts +++ b/packages/runtime/src/enhancements/proxy.ts @@ -2,7 +2,8 @@ import { PRISMA_PROXY_ENHANCER } from '../constants'; import type { ModelMeta } from '../cross'; -import type { DbClientContract } from '../types'; +import type { DbClientContract, PolicyOperationKind } from '../types'; +import { InternalEnhancementOptions } from './create-enhancement'; import { createDeferredPromise } from './policy/promise'; /** @@ -31,7 +32,7 @@ export interface PrismaProxyHandler { create(args: any): Promise; - createMany(args: any, skipDuplicates?: boolean): Promise; + createMany(args: { data: any; skipDuplicates?: boolean }): Promise; update(args: any): Promise; @@ -50,6 +51,8 @@ export interface PrismaProxyHandler { count(args: any): Promise; subscribe(args: any): Promise; + + check(operation: PolicyOperationKind, args: any): Promise; } /** @@ -63,7 +66,11 @@ export type PrismaProxyActions = keyof PrismaProxyHandler; * methods to allow more easily inject custom logic. */ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { - constructor(protected readonly prisma: DbClientContract, protected readonly model: string) {} + constructor( + protected readonly prisma: DbClientContract, + protected readonly model: string, + protected readonly options: InternalEnhancementOptions + ) {} async findUnique(args: any): Promise { args = await this.preprocessArgs('findUnique', args); @@ -101,9 +108,9 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { return this.processResultEntity(r); } - async createMany(args: any, skipDuplicates?: boolean | undefined): Promise<{ count: number }> { + async createMany(args: { data: any; skipDuplicates?: boolean }): Promise<{ count: number }> { args = await this.preprocessArgs('createMany', args); - return this.prisma[this.model].createMany(args, skipDuplicates); + return this.prisma[this.model].createMany(args); } async update(args: any): Promise { @@ -154,6 +161,17 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { return this.prisma[this.model].subscribe(args); } + async check(operation: PolicyOperationKind, args: any): Promise { + args = await this.preprocessArgs('check', args); + try { + return this.prisma[this.model].check(operation, args); + } catch (e) { + // FIXME: cannot catch the error `db.model.check is not a function` if policy enhancer is not enabled + // I don't understand why it doesn't work with other enhancers extending the default proxy handler (tested with omit enhancer) + throw new Error('Policy enhancer must be enabled to use `check` method'); + } + } + /** * Processes result entities before they're returned */ @@ -182,7 +200,7 @@ export function makeProxy( name = 'unnamed_enhancer', errorTransformer?: ErrorTransformer ) { - const models = Object.keys(modelMeta.fields).map((k) => k.toLowerCase()); + const models = Object.keys(modelMeta.models).map((k) => k.toLowerCase()); const proxy = new Proxy(prisma, { get: (target: any, prop: string | symbol, receiver: any) => { @@ -236,7 +254,7 @@ export function makeProxy( return propVal; } - return createHandlerProxy(makeHandler(target, prop), propVal, errorTransformer); + return createHandlerProxy(makeHandler(target, prop), propVal, prop, errorTransformer); }, }); @@ -247,6 +265,7 @@ export function makeProxy( function createHandlerProxy( handler: T, origTarget: any, + model: string, errorTransformer?: ErrorTransformer ): T { return new Proxy(handler, { @@ -277,7 +296,7 @@ function createHandlerProxy( if (capture.stack && err instanceof Error) { // save the original stack and replace it with a clean one (err as any).internalStack = err.stack; - err.stack = cleanCallStack(capture.stack, propKey.toString(), err.message); + err.stack = cleanCallStack(capture.stack, model, propKey.toString(), err.message); } if (errorTransformer) { @@ -303,9 +322,9 @@ function createHandlerProxy( } // Filter out @zenstackhq/runtime stack (generated by proxy) from stack trace -function cleanCallStack(stack: string, method: string, message: string) { +function cleanCallStack(stack: string, model: string, method: string, message: string) { // message line - let resultStack = `Error calling enhanced Prisma method \`${method}\`: ${message}`; + let resultStack = `Error calling enhanced Prisma method \`${model}.${method}\`: ${message}`; const lines = stack.split('\n'); let foundMarker = false; diff --git a/packages/runtime/src/enhancements/query-utils.ts b/packages/runtime/src/enhancements/query-utils.ts new file mode 100644 index 000000000..6959b922f --- /dev/null +++ b/packages/runtime/src/enhancements/query-utils.ts @@ -0,0 +1,172 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { + FieldInfo, + NestedWriteVisitorContext, + getIdFields, + getModelInfo, + getUniqueConstraints, + resolveField, +} from '../cross'; +import { CrudContract, DbClientContract } from '../types'; +import { getVersion } from '../version'; +import { InternalEnhancementOptions } from './create-enhancement'; +import { prismaClientUnknownRequestError, prismaClientValidationError } from './utils'; + +export class QueryUtils { + constructor(private readonly prisma: DbClientContract, private readonly options: InternalEnhancementOptions) {} + + getIdFields(model: string) { + return getIdFields(this.options.modelMeta, model, true); + } + + makeIdSelection(model: string) { + const idFields = this.getIdFields(model); + return Object.assign({}, ...idFields.map((f) => ({ [f.name]: true }))); + } + + getEntityIds(model: string, entityData: any) { + const idFields = this.getIdFields(model); + const result: Record = {}; + for (const idField of idFields) { + result[idField.name] = entityData[idField.name]; + } + return result; + } + + /** + * Initiates a transaction. + */ + transaction(db: CrudContract, action: (tx: CrudContract) => Promise) { + const fullDb = db as DbClientContract; + if (fullDb['$transaction']) { + return fullDb.$transaction( + (tx) => { + (tx as any)[Symbol.for('nodejs.util.inspect.custom')] = 'PrismaClient$tx'; + return action(tx); + }, + { + maxWait: this.options.transactionMaxWait, + timeout: this.options.transactionTimeout, + isolationLevel: this.options.transactionIsolationLevel, + } + ); + } else { + // already in transaction, don't nest + return action(db); + } + } + + buildReversedQuery(context: NestedWriteVisitorContext, mutating = false, unsafeOperation = false) { + let result, currQuery: any; + let currField: FieldInfo | undefined; + + for (let i = context.nestingPath.length - 1; i >= 0; i--) { + const { field, model, where } = context.nestingPath[i]; + + // never modify the original where because it's shared in the structure + const visitWhere = { ...where }; + if (model && where) { + // make sure composite unique condition is flattened + this.flattenGeneratedUniqueField(model, visitWhere); + } + + if (!result) { + // first segment (bottom), just use its where clause + result = currQuery = { ...visitWhere }; + currField = field; + } else { + if (!currField) { + throw this.unknownError(`missing field in nested path`); + } + if (!currField.backLink) { + throw this.unknownError(`field ${currField.type}.${currField.name} doesn't have a backLink`); + } + + const backLinkField = this.getModelField(currField.type, currField.backLink); + if (!backLinkField) { + throw this.unknownError(`missing backLink field ${currField.backLink} in ${currField.type}`); + } + + if (backLinkField.isArray && !mutating) { + // many-side of relationship, wrap with "some" query + currQuery[currField.backLink] = { some: { ...visitWhere } }; + currQuery = currQuery[currField.backLink].some; + } else { + const fkMapping = where && backLinkField.isRelationOwner && backLinkField.foreignKeyMapping; + + // calculate if we should preserve the relation condition (e.g., { user: { id: 1 } }) + const shouldPreserveRelationCondition = + // doing a mutation + mutating && + // and it's a safe mutate + !unsafeOperation && + // and the current segment is the direct parent (the last one is the mutate itself), + // the relation condition should be preserved and will be converted to a "connect" later + i === context.nestingPath.length - 2; + + if (fkMapping && !shouldPreserveRelationCondition) { + // turn relation condition into foreign key condition, e.g.: + // { user: { id: 1 } } => { userId: 1 } + for (const [r, fk] of Object.entries(fkMapping)) { + currQuery[fk] = visitWhere[r]; + } + + if (i > 0) { + // prepare for the next segment + currQuery[currField.backLink] = {}; + } + } else { + // preserve the original structure + currQuery[currField.backLink] = { ...visitWhere }; + } + currQuery = currQuery[currField.backLink]; + } + currField = field; + } + } + return result; + } + + flattenGeneratedUniqueField(model: string, args: any) { + // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' } + const uniqueConstraints = getUniqueConstraints(this.options.modelMeta, model); + if (uniqueConstraints && Object.keys(uniqueConstraints).length > 0) { + for (const [field, value] of Object.entries(args)) { + if ( + uniqueConstraints[field] && + uniqueConstraints[field].fields.length > 1 && + typeof value === 'object' + ) { + // multi-field unique constraint, flatten it + delete args[field]; + if (value) { + for (const [f, v] of Object.entries(value)) { + args[f] = v; + } + } + } + } + } + } + + validationError(message: string) { + return prismaClientValidationError(this.prisma, this.options.prismaModule, message); + } + + unknownError(message: string) { + return prismaClientUnknownRequestError(this.prisma, this.options.prismaModule, message, { + clientVersion: getVersion(), + }); + } + + getModelFields(model: string) { + return getModelInfo(this.options.modelMeta, model)?.fields; + } + + /** + * Gets information for a specific model field. + */ + getModelField(model: string, field: string) { + return resolveField(this.options.modelMeta, model, field); + } +} diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 9c8080096..639d50500 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -9,7 +9,7 @@ import { HAS_FIELD_LEVEL_POLICY_FLAG, PRE_UPDATE_VALUE_SELECTOR, } from '../constants'; -import type { DbOperations, PolicyOperationKind, QueryContext } from '../types'; +import type { CRUDOperationKind, CrudContract, PolicyOperationKind, QueryContext } from '../types'; /** * Common options for PrismaClient enhancements @@ -24,7 +24,7 @@ export interface CommonEnhancementOptions { /** * Function for getting policy guard with a given context */ -export type PolicyFunc = (context: QueryContext, db: Record) => object; +export type PolicyFunc = (context: QueryContext, db: CrudContract) => object; /** * Function for getting policy guard with a given context @@ -69,6 +69,13 @@ export type PolicyDef = { // a { select: ... } object for fetching `auth()` fields needed for policy evaluation authSelector?: object; + + // permissions checker + permission?: { + [model: string]: { + [operation in CRUDOperationKind]: PermissionsChecker; + }; + }; }; /** @@ -76,5 +83,12 @@ export type PolicyDef = { */ export type ZodSchemas = { models: Record; - input: Record>; + input?: Record>; }; + +export type Condition = { field: any; operator: string; value: any } | AndConditions | OrConditions; +type AndConditions = { AND: Condition[] }; +type OrConditions = { OR: Condition[] }; +export type Conditions = boolean | AndConditions | OrConditions; + +export type PermissionsChecker = (z3: any, args: any, user?: any) => Promise; diff --git a/packages/runtime/src/enhancements/utils.ts b/packages/runtime/src/enhancements/utils.ts index 73b4d42a0..ba2f9a2d8 100644 --- a/packages/runtime/src/enhancements/utils.ts +++ b/packages/runtime/src/enhancements/utils.ts @@ -1,6 +1,3 @@ -/* eslint-disable @typescript-eslint/no-var-requires */ - -import path from 'path'; import * as util from 'util'; import type { DbClientContract } from '../types'; @@ -11,68 +8,17 @@ export function formatObject(value: unknown) { return util.formatWithOptions({ depth: 20 }, value); } -let _PrismaClientValidationError: new (...args: unknown[]) => Error; -let _PrismaClientKnownRequestError: new (...args: unknown[]) => Error; -let _PrismaClientUnknownRequestError: new (...args: unknown[]) => Error; - -/* eslint-disable @typescript-eslint/no-explicit-any */ -function loadPrismaModule(prisma: any) { - // https://github.com/prisma/prisma/discussions/17832 - if (prisma._engineConfig?.datamodelPath) { - // try engine path first - const loadPath = path.dirname(prisma._engineConfig.datamodelPath); - try { - const _prisma = require(loadPath).Prisma; - if (typeof _prisma !== 'undefined') { - return _prisma; - } - } catch { - // noop - } - } - - try { - // Prisma v4 - return require('@prisma/client/runtime'); - } catch { - try { - // Prisma v5 - return require('@prisma/client'); - } catch (err) { - if (process.env.ZENSTACK_TEST === '1') { - // running in test, try cwd - try { - return require(path.join(process.cwd(), 'node_modules/@prisma/client/runtime')); - } catch { - return require(path.join(process.cwd(), 'node_modules/@prisma/client')); - } - } else { - throw err; - } - } - } -} - -export function prismaClientValidationError(prisma: DbClientContract, message: string) { - if (!_PrismaClientValidationError) { - const _prisma = loadPrismaModule(prisma); - _PrismaClientValidationError = _prisma.PrismaClientValidationError; - } - throw new _PrismaClientValidationError(message, { clientVersion: prisma._clientVersion }); +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function prismaClientValidationError(prisma: DbClientContract, prismaModule: any, message: string): Error { + throw new prismaModule.PrismaClientValidationError(message, { clientVersion: prisma._clientVersion }); } -export function prismaClientKnownRequestError(prisma: DbClientContract, ...args: unknown[]) { - if (!_PrismaClientKnownRequestError) { - const _prisma = loadPrismaModule(prisma); - _PrismaClientKnownRequestError = _prisma.PrismaClientKnownRequestError; - } - return new _PrismaClientKnownRequestError(...args); +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function prismaClientKnownRequestError(prisma: DbClientContract, prismaModule: any, ...args: unknown[]): Error { + return new prismaModule.PrismaClientKnownRequestError(...args); } -export function prismaClientUnknownRequestError(prisma: DbClientContract, ...args: unknown[]) { - if (!_PrismaClientUnknownRequestError) { - const _prisma = loadPrismaModule(prisma); - _PrismaClientUnknownRequestError = _prisma.PrismaClientUnknownRequestError; - } - throw new _PrismaClientUnknownRequestError(...args); +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function prismaClientUnknownRequestError(prismaModule: any, ...args: unknown[]): Error { + throw new prismaModule.PrismaClientUnknownRequestError(...args); } diff --git a/packages/runtime/src/index.ts b/packages/runtime/src/index.ts index 57df37ee4..6a2609156 100644 --- a/packages/runtime/src/index.ts +++ b/packages/runtime/src/index.ts @@ -1,7 +1,7 @@ export * from './constants'; export * from './enhancements'; export * from './error'; -export * from './loader'; export * from './types'; export * from './validation'; export * from './version'; +export * from './enhance'; diff --git a/packages/runtime/src/loader.ts b/packages/runtime/src/loader.ts deleted file mode 100644 index 1c2eef7bd..000000000 --- a/packages/runtime/src/loader.ts +++ /dev/null @@ -1,88 +0,0 @@ -/* eslint-disable @typescript-eslint/no-var-requires */ -import path from 'path'; -import { ModelMeta, PolicyDef, ZodSchemas } from './enhancements'; - -/** - * Load model metadata. - * - * @param loadPath The path to load model metadata from. If not provided, - * will use default load path. - */ -export function getDefaultModelMeta(loadPath: string | undefined): ModelMeta { - try { - if (loadPath) { - const toLoad = path.resolve(loadPath, 'model-meta'); - return require(toLoad).default; - } else { - return require('.zenstack/model-meta').default; - } - } catch { - if (process.env.ZENSTACK_TEST === '1' && !loadPath) { - try { - // special handling for running as tests, try resolving relative to CWD - return require(path.join(process.cwd(), 'node_modules', '.zenstack', 'model-meta')).default; - } catch { - throw new Error('Model meta cannot be loaded. Please make sure "zenstack generate" has been run.'); - } - } - throw new Error('Model meta cannot be loaded. Please make sure "zenstack generate" has been run.'); - } -} - -/** - * Load access policies. - * - * @param loadPath The path to load access policies from. If not provided, - * will use default load path. - */ -export function getDefaultPolicy(loadPath: string | undefined): PolicyDef { - try { - if (loadPath) { - const toLoad = path.resolve(loadPath, 'policy'); - return require(toLoad).default; - } else { - return require('.zenstack/policy').default; - } - } catch { - if (process.env.ZENSTACK_TEST === '1' && !loadPath) { - try { - // special handling for running as tests, try resolving relative to CWD - return require(path.join(process.cwd(), 'node_modules', '.zenstack', 'policy')).default; - } catch { - throw new Error( - 'Policy definition cannot be loaded from default location. Please make sure "zenstack generate" has been run.' - ); - } - } - throw new Error( - 'Policy definition cannot be loaded from default location. Please make sure "zenstack generate" has been run.' - ); - } -} - -/** - * Load zod schemas. - * - * @param loadPath The path to load zod schemas from. If not provided, - * will use default load path. - */ -export function getDefaultZodSchemas(loadPath: string | undefined): ZodSchemas | undefined { - try { - if (loadPath) { - const toLoad = path.resolve(loadPath, 'zod'); - return require(toLoad); - } else { - return require('.zenstack/zod'); - } - } catch { - if (process.env.ZENSTACK_TEST === '1' && !loadPath) { - try { - // special handling for running as tests, try resolving relative to CWD - return require(path.join(process.cwd(), 'node_modules', '.zenstack', 'zod')); - } catch { - return undefined; - } - } - return undefined; - } -} diff --git a/packages/runtime/src/package.json b/packages/runtime/src/package.json new file mode 120000 index 000000000..4e26811d4 --- /dev/null +++ b/packages/runtime/src/package.json @@ -0,0 +1 @@ +../package.json \ No newline at end of file diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index e143cacfa..cc41b545f 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -22,6 +22,7 @@ export interface DbOperations { groupBy(args: unknown): Promise; count(args?: unknown): Promise; subscribe(args?: unknown): Promise; + check(operation: PolicyOperationKind, args?: unknown): Promise; fields: Record; } @@ -35,6 +36,11 @@ export type PolicyKind = 'allow' | 'deny'; */ export type PolicyOperationKind = 'create' | 'update' | 'postUpdate' | 'read' | 'delete'; +/** + * Kinds of operations controlled by access policies + */ +export type CRUDOperationKind = 'create' | 'update' | 'read' | 'delete'; + /** * Current login user info */ @@ -56,6 +62,14 @@ export type QueryContext = { preValue?: any; }; -export type DbClientContract = Record & { - $transaction: (action: (tx: Record) => Promise, options?: unknown) => Promise; +/** + * Prisma contract for CRUD operations. + */ +export type CrudContract = Record; + +/** + * Prisma contract for database client. + */ +export type DbClientContract = CrudContract & { + $transaction: (action: (tx: CrudContract) => Promise, options?: unknown) => Promise; }; diff --git a/packages/runtime/src/version.ts b/packages/runtime/src/version.ts index 567ef7a71..b8e941547 100644 --- a/packages/runtime/src/version.ts +++ b/packages/runtime/src/version.ts @@ -1,42 +1,9 @@ -import path from 'path'; - -/* eslint-disable @typescript-eslint/no-var-requires */ -export function getVersion() { - try { - return require('./package.json').version; - } catch { - try { - // dev environment - return require('../package.json').version; - } catch { - return 'unknown'; - } - } -} +import * as pkgJson from './package.json'; /** - * Gets installed Prisma version by first checking "@prisma/client" and if not available, - * "prisma". + * Gets this package's version. + * @returns */ -export function getPrismaVersion(): string | undefined { - if (process.env.ZENSTACK_TEST === '1') { - // test environment - try { - return require(path.resolve('./node_modules/@prisma/client/package.json')).version; - } catch { - return undefined; - } - } - - try { - // eslint-disable-next-line @typescript-eslint/no-var-requires - return require('@prisma/client/package.json').version; - } catch { - try { - // eslint-disable-next-line @typescript-eslint/no-var-requires - return require('prisma/package.json').version; - } catch { - return undefined; - } - } +export function getVersion() { + return pkgJson.version; } diff --git a/packages/schema/package.json b/packages/schema/package.json index d7726b3be..09693e0e7 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack Language Tools", "description": "Build scalable web apps with minimum code by defining authorization and validation rules inside the data schema that closer to the database", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "author": { "name": "ZenStack Team" }, @@ -87,7 +87,7 @@ "colors": "1.4.0", "commander": "^8.3.0", "get-latest-version": "^5.0.1", - "langium": "1.2.0", + "langium": "1.3.1", "lower-case-first": "^2.0.2", "mixpanel": "^0.17.0", "ora": "^5.4.1", @@ -111,7 +111,7 @@ "zod-validation-error": "^1.5.0" }, "devDependencies": { - "@prisma/client": "^4.8.0", + "@prisma/client": "^5.7.1", "@types/async-exit-hook": "^2.0.0", "@types/pluralize": "^0.0.29", "@types/semver": "^7.3.13", @@ -123,7 +123,7 @@ "@zenstackhq/runtime": "workspace:*", "dotenv": "^16.0.3", "esbuild": "^0.15.12", - "prisma": "^4.8.0", + "prisma": "^5.7.1", "renamer": "^4.0.0", "tmp": "^0.2.1", "tsc-alias": "^1.7.0", diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts index 85c38e82a..3a92d393c 100644 --- a/packages/schema/src/cli/cli-util.ts +++ b/packages/schema/src/cli/cli-util.ts @@ -85,11 +85,18 @@ export async function loadDocument(fileName: string): Promise { const model = document.parseResult.value as Model; - mergeImportsDeclarations(langiumDocuments, model); + const imported = mergeImportsDeclarations(langiumDocuments, model); + // remove imported documents + await services.shared.workspace.DocumentBuilder.update( + [], + imported.map((m) => m.$document!.uri) + ); validationAfterMerge(model); - mergeBaseModel(model); + mergeBaseModel(model, services.references.Linker); + + await relinkAll(model, services); return model; } @@ -151,6 +158,8 @@ export function mergeImportsDeclarations(documents: LangiumDocuments, model: Mod }); model.declarations.push(...importedDeclarations); + + return importedModels; } export async function getPluginDocuments(services: ZModelServices, fileName: string): Promise { @@ -295,3 +304,20 @@ export function getDefaultSchemaLocation() { return path.resolve('schema.zmodel'); } + +async function relinkAll(model: Model, services: ZModelServices) { + const doc = model.$document!; + + // unlink the document + services.references.Linker.unlink(doc); + + // remove current document + await services.shared.workspace.DocumentBuilder.update([], [doc.uri]); + + // recreate the document + const newDoc = services.shared.workspace.LangiumDocumentFactory.fromModel(model, doc.uri); + (model as Mutable).$document = newDoc; + + // rebuild the document + await services.shared.workspace.DocumentBuilder.build([newDoc], { validationChecks: 'all' }); +} diff --git a/packages/schema/src/cli/index.ts b/packages/schema/src/cli/index.ts index d96ed1121..a3ea38238 100644 --- a/packages/schema/src/cli/index.ts +++ b/packages/schema/src/cli/index.ts @@ -81,7 +81,6 @@ export function createProgram() { `schema file (with extension ${schemaExtensions}). Defaults to "schema.zmodel" unless specified in package.json.` ); - const configOption = new Option('-c, --config [file]', 'config file').hideHelp(); const pmOption = new Option('-p, --package-manager ', 'package manager to use').choices([ 'npm', 'yarn', @@ -99,7 +98,6 @@ export function createProgram() { program .command('init') .description('Initialize an existing project for ZenStack.') - .addOption(configOption) .addOption(pmOption) .addOption(new Option('--prisma ', 'location of Prisma schema file to bootstrap from')) .addOption(new Option('--tag ', 'the NPM package tag to use when installing dependencies')) @@ -112,7 +110,6 @@ export function createProgram() { .description('Run code generation.') .addOption(schemaOption) .addOption(new Option('-o, --output ', 'default output directory for built-in plugins')) - .addOption(configOption) .addOption(new Option('--no-default-plugins', 'do not run default plugins')) .addOption(new Option('--no-compile', 'do not compile the output of built-in plugins')) .addOption(noVersionCheckOption) diff --git a/packages/schema/src/cli/plugin-runner.ts b/packages/schema/src/cli/plugin-runner.ts index 0609fd4fb..3e73932a1 100644 --- a/packages/schema/src/cli/plugin-runner.ts +++ b/packages/schema/src/cli/plugin-runner.ts @@ -8,24 +8,26 @@ import { getLiteral, getLiteralArray, hasValidationAttributes, + OptionValue, + PluginDeclaredOptions, PluginError, PluginFunction, - PluginOptions, resolvePath, } from '@zenstackhq/sdk'; import colors from 'colors'; import fs from 'fs'; import ora from 'ora'; import path from 'path'; -import { ensureDefaultOutputFolder } from '../plugins/plugin-utils'; +import { CorePlugins, ensureDefaultOutputFolder } from '../plugins/plugin-utils'; import { getDefaultPrismaOutputFile } from '../plugins/prisma/schema-generator'; import telemetry from '../telemetry'; import { getVersion } from '../utils/version-utils'; type PluginInfo = { name: string; + description?: string; provider: string; - options: PluginOptions; + options: PluginDeclaredOptions; run: PluginFunction; dependencies: string[]; module: any; @@ -46,16 +48,16 @@ export class PluginRunner { /** * Runs a series of nested generators */ - async run(options: PluginRunnerOptions): Promise { + async run(runnerOptions: PluginRunnerOptions): Promise { const version = getVersion(); console.log(colors.bold(`⌛️ ZenStack CLI v${version}, running plugins`)); - ensureDefaultOutputFolder(options); + ensureDefaultOutputFolder(runnerOptions); const plugins: PluginInfo[] = []; - const pluginDecls = options.schema.declarations.filter((d): d is Plugin => isPlugin(d)); + const pluginDecls = runnerOptions.schema.declarations.filter((d): d is Plugin => isPlugin(d)); - let prismaOutput = getDefaultPrismaOutputFile(options.schemaPath); + let prismaOutput = getDefaultPrismaOutputFile(runnerOptions.schemaPath); for (const pluginDecl of pluginDecls) { const pluginProvider = this.getPluginProvider(pluginDecl); @@ -68,7 +70,7 @@ export class PluginRunner { let pluginModule: any; try { - pluginModule = this.loadPluginModule(pluginProvider, options); + pluginModule = this.loadPluginModule(pluginProvider, runnerOptions.schemaPath); } catch (err) { console.error(`Unable to load plugin module ${pluginProvider}: ${err}`); throw new PluginError('', `Unable to load plugin module ${pluginProvider}`); @@ -80,19 +82,21 @@ export class PluginRunner { } const dependencies = this.getPluginDependencies(pluginModule); - const pluginName = this.getPluginName(pluginModule, pluginProvider); - const pluginOptions: PluginOptions = { schemaPath: options.schemaPath, name: pluginName }; + const pluginOptions: PluginDeclaredOptions = { + provider: pluginProvider, + }; pluginDecl.fields.forEach((f) => { const value = getLiteral(f.value) ?? getLiteralArray(f.value); if (value === undefined) { - throw new PluginError(pluginName, `Invalid option value for ${f.name}`); + throw new PluginError(pluginDecl.name, `Invalid option value for ${f.name}`); } pluginOptions[f.name] = value; }); plugins.push({ - name: pluginName, + name: pluginDecl.name, + description: this.getPluginDescription(pluginModule), provider: pluginProvider, dependencies, options: pluginOptions, @@ -102,40 +106,17 @@ export class PluginRunner { if (pluginProvider === '@core/prisma' && typeof pluginOptions.output === 'string') { // record custom prisma output path - prismaOutput = resolvePath(pluginOptions.output, pluginOptions); + prismaOutput = resolvePath(pluginOptions.output, { schemaPath: runnerOptions.schemaPath }); } } - // get core plugins that need to be enabled - const corePlugins = this.calculateCorePlugins(options, plugins); - - // shift/insert core plugins to the front - for (const corePlugin of corePlugins.reverse()) { - const existingIdx = plugins.findIndex((p) => p.provider === corePlugin.provider); - if (existingIdx >= 0) { - // shift the plugin to the front - const existing = plugins[existingIdx]; - plugins.splice(existingIdx, 1); - plugins.unshift(existing); - } else { - // synthesize a plugin and insert front - const pluginModule = require(this.getPluginModulePath(corePlugin.provider, options)); - const pluginName = this.getPluginName(pluginModule, corePlugin.provider); - plugins.unshift({ - name: pluginName, - provider: corePlugin.provider, - dependencies: [], - options: { schemaPath: options.schemaPath, name: pluginName, ...corePlugin.options }, - run: pluginModule.default, - module: pluginModule, - }); - } - } + // calculate all plugins (including core plugins implicitly enabled) + const allPlugins = this.calculateAllPlugins(runnerOptions, plugins); // check dependencies - for (const plugin of plugins) { + for (const plugin of allPlugins) { for (const dep of plugin.dependencies) { - if (!plugins.find((p) => p.provider === dep)) { + if (!allPlugins.find((p) => p.provider === dep)) { console.error(`Plugin ${plugin.provider} depends on "${dep}" but it's not declared`); throw new PluginError( plugin.name, @@ -145,7 +126,7 @@ export class PluginRunner { } } - if (plugins.length === 0) { + if (allPlugins.length === 0) { console.log(colors.yellow('No plugins configured.')); return; } @@ -153,9 +134,9 @@ export class PluginRunner { const warnings: string[] = []; let dmmf: DMMF.Document | undefined = undefined; - for (const { name, provider, run, options: pluginOptions } of plugins) { + for (const { name, description, provider, run, options: pluginOptions } of allPlugins) { // const start = Date.now(); - await this.runPlugin(name, run, options, pluginOptions, dmmf, warnings); + await this.runPlugin(name, description, run, runnerOptions, pluginOptions, dmmf, warnings); // console.log(`✅ Plugin ${colors.bold(name)} (${provider}) completed in ${Date.now() - start}ms`); if (provider === '@core/prisma') { // load prisma DMMF @@ -171,37 +152,56 @@ export class PluginRunner { console.log(`Don't forget to restart your dev server to let the changes take effect.`); } - private calculateCorePlugins(options: PluginRunnerOptions, plugins: PluginInfo[]) { - const corePlugins: Array<{ provider: string; options?: Record }> = []; + private calculateAllPlugins(options: PluginRunnerOptions, plugins: PluginInfo[]) { + const corePlugins: PluginInfo[] = []; + let zodImplicitlyAdded = false; - if (options.defaultPlugins) { - corePlugins.push( - { provider: '@core/prisma' }, - { provider: '@core/model-meta' }, - { provider: '@core/access-policy' } - ); - } else if (plugins.length > 0) { - // "@core/prisma" plugin is always enabled if any plugin is configured - corePlugins.push({ provider: '@core/prisma' }); + // 1. @core/prisma + const existingPrisma = plugins.find((p) => p.provider === CorePlugins.Prisma); + if (existingPrisma) { + corePlugins.push(existingPrisma); + plugins.splice(plugins.indexOf(existingPrisma), 1); + } else if (options.defaultPlugins || plugins.some((p) => p.provider !== CorePlugins.Prisma)) { + // "@core/prisma" is enabled as default or if any other plugin is configured + corePlugins.push(this.makeCorePlugin(CorePlugins.Prisma, options.schemaPath, {})); } - // "@core/access-policy" has implicit requirements - let zodImplicitlyAdded = false; - if ([...plugins, ...corePlugins].find((p) => p.provider === '@core/access-policy')) { - // make sure "@core/model-meta" is enabled - if (!corePlugins.find((p) => p.provider === '@core/model-meta')) { - corePlugins.push({ provider: '@core/model-meta' }); - } + const hasValidation = this.hasValidation(options.schema); + + // 2. @core/zod + const existingZod = plugins.find((p) => p.provider === CorePlugins.Zod); + if (existingZod && !existingZod.options.output) { + // we can reuse the user-provided zod plugin if it didn't specify a custom output path + plugins.splice(plugins.indexOf(existingZod), 1); + corePlugins.push(existingZod); + } - // '@core/zod' plugin is auto-enabled by "@core/access-policy" - // if there're validation rules - if (!corePlugins.find((p) => p.provider === '@core/zod') && this.hasValidation(options.schema)) { - zodImplicitlyAdded = true; - corePlugins.push({ provider: '@core/zod', options: { modelOnly: true } }); + if ( + !corePlugins.some((p) => p.provider === CorePlugins.Zod) && + (options.defaultPlugins || plugins.some((p) => p.provider === CorePlugins.Enhancer)) && + hasValidation + ) { + // ensure "@core/zod" is enabled if "@core/enhancer" is enabled and there're validation rules + zodImplicitlyAdded = true; + corePlugins.push(this.makeCorePlugin(CorePlugins.Zod, options.schemaPath, { modelOnly: true })); + } + + // 3. @core/enhancer + const existingEnhancer = plugins.find((p) => p.provider === CorePlugins.Enhancer); + if (existingEnhancer) { + corePlugins.push(existingEnhancer); + plugins.splice(plugins.indexOf(existingEnhancer), 1); + } else { + if (options.defaultPlugins) { + corePlugins.push( + this.makeCorePlugin(CorePlugins.Enhancer, options.schemaPath, { + withZodSchemas: hasValidation, + }) + ); } } - // core plugins introduced by dependencies + // collect core plugins introduced by dependencies plugins.forEach((plugin) => { // TODO: generalize this const isTrpcPlugin = @@ -217,7 +217,9 @@ export class PluginRunner { if (existing.provider === '@core/zod') { // Zod plugin can be automatically enabled in `modelOnly` mode, however // other plugin (tRPC) for now requires it to run in full mode - existing.options = {}; + if (existing.options.modelOnly) { + delete existing.options.modelOnly; + } if ( isTrpcPlugin && @@ -229,21 +231,39 @@ export class PluginRunner { } } else { // add core dependency - const toAdd = { provider: dep, options: {} as Record }; + const depOptions: Record = {}; // TODO: generalize this if (dep === '@core/zod' && isTrpcPlugin) { // pass trpc plugin's `generateModels` option down to zod plugin - toAdd.options.generateModels = plugin.options.generateModels; + depOptions.generateModels = plugin.options.generateModels; } - corePlugins.push(toAdd); + corePlugins.push(this.makeCorePlugin(dep, options.schemaPath, depOptions)); } } } }); - return corePlugins; + return [...corePlugins, ...plugins]; + } + + private makeCorePlugin( + provider: string, + schemaPath: string, + options: Record + ): PluginInfo { + const pluginModule = require(this.getPluginModulePath(provider, schemaPath)); + const pluginName = this.getPluginName(pluginModule, provider); + return { + name: pluginName, + description: this.getPluginDescription(pluginModule), + provider: provider, + dependencies: [], + options: { ...options, provider }, + run: pluginModule.default, + module: pluginModule, + }; } private hasValidation(schema: Model) { @@ -251,10 +271,15 @@ export class PluginRunner { } // eslint-disable-next-line @typescript-eslint/no-explicit-any - private getPluginName(pluginModule: any, pluginProvider: string): string { + private getPluginName(pluginModule: any, pluginProvider: string) { return typeof pluginModule.name === 'string' ? (pluginModule.name as string) : pluginProvider; } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + private getPluginDescription(pluginModule: any) { + return typeof pluginModule.description === 'string' ? (pluginModule.description as string) : undefined; + } + private getPluginDependencies(pluginModule: any) { return Array.isArray(pluginModule.dependencies) ? (pluginModule.dependencies as string[]) : []; } @@ -266,13 +291,15 @@ export class PluginRunner { private async runPlugin( name: string, + description: string | undefined, run: PluginFunction, runnerOptions: PluginRunnerOptions, - options: PluginOptions, + options: PluginDeclaredOptions, dmmf: DMMF.Document | undefined, warnings: string[] ) { - const spinner = ora(`Running plugin ${colors.cyan(name)}`).start(); + const title = description ?? `Running plugin ${colors.cyan(name)}`; + const spinner = ora(title).start(); try { await telemetry.trackSpan( 'cli:plugin:start', @@ -283,7 +310,7 @@ export class PluginRunner { options, }, async () => { - let result = run(runnerOptions.schema, options, dmmf, { + let result = run(runnerOptions.schema, { ...options, schemaPath: runnerOptions.schemaPath }, dmmf, { output: runnerOptions.output, compile: runnerOptions.compile, }); @@ -302,7 +329,7 @@ export class PluginRunner { } } - private getPluginModulePath(provider: string, options: Pick) { + private getPluginModulePath(provider: string, schemaPath: string) { let pluginModulePath = provider; if (provider.startsWith('@core/')) { pluginModulePath = provider.replace(/^@core/, path.join(__dirname, '../plugins')); @@ -312,14 +339,14 @@ export class PluginRunner { require.resolve(pluginModulePath); } catch { // relative - pluginModulePath = resolvePath(provider, options); + pluginModulePath = resolvePath(provider, { schemaPath }); } } return pluginModulePath; } - private loadPluginModule(provider: string, options: Pick) { - const pluginModulePath = this.getPluginModulePath(provider, options); + private loadPluginModule(provider: string, schemaPath: string) { + const pluginModulePath = this.getPluginModulePath(provider, schemaPath); return require(pluginModulePath); } } diff --git a/packages/schema/src/extension.ts b/packages/schema/src/extension.ts index d28f7dd87..a3e19d7f8 100644 --- a/packages/schema/src/extension.ts +++ b/packages/schema/src/extension.ts @@ -56,6 +56,6 @@ function startLanguageClient(context: vscode.ExtensionContext): LanguageClient { const client = new LanguageClient('zmodel', 'ZenStack Model', serverOptions, clientOptions); // Start the client. This will also launch the server - client.start(); + void client.start(); return client; } diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index ce1886f5e..ac727b917 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -6,8 +6,9 @@ import { isStringLiteral, ReferenceExpr, } from '@zenstackhq/language/ast'; -import { analyzePolicies, getLiteral, getModelIdFields, getModelUniqueFields } from '@zenstackhq/sdk'; +import { getLiteral, getModelIdFields, getModelUniqueFields, isDelegateModel } from '@zenstackhq/sdk'; import { AstNode, DiagnosticInfo, getDocument, ValidationAcceptor } from 'langium'; +import { getModelFieldsWithBases } from '../../utils/ast-utils'; import { IssueCodes, SCALAR_TYPES } from '../constants'; import { AstValidator } from '../types'; import { getUniqueFields } from '../utils'; @@ -20,37 +21,32 @@ import { validateDuplicatedDeclarations } from './utils'; export default class DataModelValidator implements AstValidator { validate(dm: DataModel, accept: ValidationAcceptor): void { this.validateBaseAbstractModel(dm, accept); - validateDuplicatedDeclarations(dm.$resolvedFields, accept); + validateDuplicatedDeclarations(dm, getModelFieldsWithBases(dm), accept); this.validateAttributes(dm, accept); this.validateFields(dm, accept); } private validateFields(dm: DataModel, accept: ValidationAcceptor) { - const idFields = dm.$resolvedFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id')); - const uniqueFields = dm.$resolvedFields.filter((f) => - f.attributes.find((attr) => attr.decl.ref?.name === '@unique') - ); + const allFields = getModelFieldsWithBases(dm); + const idFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id')); + const uniqueFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@unique')); const modelLevelIds = getModelIdFields(dm); const modelUniqueFields = getModelUniqueFields(dm); if ( + !dm.isAbstract && idFields.length === 0 && modelLevelIds.length === 0 && uniqueFields.length === 0 && modelUniqueFields.length === 0 ) { - const { allows, denies, hasFieldValidation } = analyzePolicies(dm); - if (allows.length > 0 || denies.length > 0 || hasFieldValidation) { - // TODO: relax this requirement to require only @unique fields - // when access policies or field valdaition is used, require an @id field - accept( - 'error', - 'Model must include a field with @id or @unique attribute, or a model-level @@id or @@unique attribute to use access policies', - { - node: dm, - } - ); - } + accept( + 'error', + 'Model must have at least one unique criteria. Either mark a single field with `@id`, `@unique` or add a multi field criterion with `@@id([])` or `@@unique([])` to the model.', + { + node: dm, + } + ); } else if (idFields.length > 0 && modelLevelIds.length > 0) { accept('error', 'Model cannot have both field-level @id and model-level @@id attributes', { node: dm, @@ -74,10 +70,10 @@ export default class DataModelValidator implements AstValidator { dm.fields.forEach((field) => this.validateField(field, accept)); if (!dm.isAbstract) { - dm.$resolvedFields + allFields .filter((x) => isDataModel(x.type.reference?.ref)) .forEach((y) => { - this.validateRelationField(y, accept); + this.validateRelationField(dm, y, accept); }); } } @@ -194,7 +190,7 @@ export default class DataModelValidator implements AstValidator { // points back const oppositeModel = field.type.reference?.ref as DataModel; if (oppositeModel) { - const oppositeModelFields = oppositeModel.$resolvedFields as DataModelField[]; + const oppositeModelFields = getModelFieldsWithBases(oppositeModel); for (const oppositeField of oppositeModelFields) { // find the opposite relation with the matching name const relAttr = oppositeField.attributes.find((a) => a.decl.ref?.name === '@relation'); @@ -213,18 +209,23 @@ export default class DataModelValidator implements AstValidator { return false; } - private validateRelationField(field: DataModelField, accept: ValidationAcceptor) { + private validateRelationField(contextModel: DataModel, field: DataModelField, accept: ValidationAcceptor) { const thisRelation = this.parseRelation(field, accept); if (!thisRelation.valid) { return; } + if (field.$container !== contextModel && isDelegateModel(field.$container as DataModel)) { + // relation fields inherited from delegate model don't need opposite relation + return; + } + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const oppositeModel = field.type.reference!.ref! as DataModel; // Use name because the current document might be updated - let oppositeFields = oppositeModel.$resolvedFields.filter( - (f) => f.type.reference?.ref?.name === field.$container.name + let oppositeFields = getModelFieldsWithBases(oppositeModel).filter( + (f) => f.type.reference?.ref?.name === contextModel.name ); oppositeFields = oppositeFields.filter((f) => { const fieldRel = this.parseRelation(f); @@ -232,13 +233,13 @@ export default class DataModelValidator implements AstValidator { }); if (oppositeFields.length === 0) { - const node = field.$isInherited ? field.$container : field; - const info: DiagnosticInfo = { node, code: IssueCodes.MissingOppositeRelation }; + const info: DiagnosticInfo = { + node: field, + code: IssueCodes.MissingOppositeRelation, + }; info.property = 'name'; - // use cstNode because the field might be inherited from parent model - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const container = field.$cstNode!.element.$container as DataModel; + const container = field.$container; const relationFieldDocUri = getDocument(container).textDocument.uri; const relationDataModelName = container.name; @@ -247,20 +248,20 @@ export default class DataModelValidator implements AstValidator { relationFieldName: field.name, relationDataModelName, relationFieldDocUri, - dataModelName: field.$container.name, + dataModelName: contextModel.name, }; info.data = data; accept( 'error', - `The relation field "${field.name}" on model "${field.$container.name}" is missing an opposite relation field on model "${oppositeModel.name}"`, + `The relation field "${field.name}" on model "${contextModel.name}" is missing an opposite relation field on model "${oppositeModel.name}"`, info ); return; } else if (oppositeFields.length > 1) { oppositeFields - .filter((x) => !x.$isInherited) + .filter((f) => f.$container !== contextModel) .forEach((f) => { if (this.isSelfRelation(f)) { // self relations are partial @@ -363,12 +364,19 @@ export default class DataModelValidator implements AstValidator { private validateBaseAbstractModel(model: DataModel, accept: ValidationAcceptor) { model.superTypes.forEach((superType, index) => { - if (!superType.ref?.isAbstract) - accept('error', `Model ${superType.$refText} cannot be extended because it's not abstract`, { - node: model, - property: 'superTypes', - index, - }); + if ( + !superType.ref?.isAbstract && + !superType.ref?.attributes.some((attr) => attr.decl.ref?.name === '@@delegate') + ) + accept( + 'error', + `Model ${superType.$refText} cannot be extended because it's neither abstract nor marked as "@@delegate"`, + { + node: model, + property: 'superTypes', + index, + } + ); }); } } diff --git a/packages/schema/src/language-server/validator/datasource-validator.ts b/packages/schema/src/language-server/validator/datasource-validator.ts index f24fed08b..d102e409f 100644 --- a/packages/schema/src/language-server/validator/datasource-validator.ts +++ b/packages/schema/src/language-server/validator/datasource-validator.ts @@ -9,7 +9,7 @@ import { SUPPORTED_PROVIDERS } from '../constants'; */ export default class DataSourceValidator implements AstValidator { validate(ds: DataSource, accept: ValidationAcceptor): void { - validateDuplicatedDeclarations(ds.fields, accept); + validateDuplicatedDeclarations(ds, ds.fields, accept); this.validateProvider(ds, accept); this.validateUrl(ds, accept); this.validateRelationMode(ds, accept); diff --git a/packages/schema/src/language-server/validator/enum-validator.ts b/packages/schema/src/language-server/validator/enum-validator.ts index 4223d8a2b..5780d91fb 100644 --- a/packages/schema/src/language-server/validator/enum-validator.ts +++ b/packages/schema/src/language-server/validator/enum-validator.ts @@ -10,7 +10,7 @@ import { validateDuplicatedDeclarations } from './utils'; export default class EnumValidator implements AstValidator { // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types validate(_enum: Enum, accept: ValidationAcceptor) { - validateDuplicatedDeclarations(_enum.fields, accept); + validateDuplicatedDeclarations(_enum, _enum.fields, accept); this.validateAttributes(_enum, accept); _enum.fields.forEach((field) => { this.validateField(field, accept); diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index 7644521b8..cfc8a39af 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -3,16 +3,16 @@ import { Expression, ExpressionType, isDataModel, + isDataModelField, isEnum, + isLiteralExpr, isMemberAccessExpr, isNullExpr, isThisExpr, - isDataModelField, - isLiteralExpr, } from '@zenstackhq/language/ast'; -import { isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk'; +import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk'; import { ValidationAcceptor } from 'langium'; -import { getContainingDataModel, isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils'; +import { getContainingDataModel, isCollectionPredicate } from '../../utils/ast-utils'; import { AstValidator } from '../types'; import { typeAssignable } from './utils'; @@ -132,18 +132,24 @@ export default class ExpressionValidator implements AstValidator { // - foo.user.id == userId // except: // - future().userId == userId - if(isMemberAccessExpr(expr.left) && isDataModelField(expr.left.member.ref) && expr.left.member.ref.$container != getContainingDataModel(expr) - || isMemberAccessExpr(expr.right) && isDataModelField(expr.right.member.ref) && expr.right.member.ref.$container != getContainingDataModel(expr)) - { + if ( + (isMemberAccessExpr(expr.left) && + isDataModelField(expr.left.member.ref) && + expr.left.member.ref.$container != getContainingDataModel(expr)) || + (isMemberAccessExpr(expr.right) && + isDataModelField(expr.right.member.ref) && + expr.right.member.ref.$container != getContainingDataModel(expr)) + ) { // foo.user.id == auth().id // foo.user.id == "123" // foo.user.id == null // foo.user.id == EnumValue - if(!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) - { - accept('error', 'comparison between fields of different models are not supported', { node: expr }); - break; - } + if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) { + accept('error', 'comparison between fields of different models are not supported', { + node: expr, + }); + break; + } } if ( @@ -205,14 +211,13 @@ export default class ExpressionValidator implements AstValidator { } } - private isNotModelFieldExpr(expr: Expression) { - return isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr) + return ( + isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr) + ); } private isAuthOrAuthMemberAccess(expr: Expression) { return isAuthInvocation(expr) || (isMemberAccessExpr(expr) && isAuthInvocation(expr.operand)); } - } - diff --git a/packages/schema/src/language-server/validator/function-invocation-validator.ts b/packages/schema/src/language-server/validator/function-invocation-validator.ts index 3bc364bd2..a6af730f2 100644 --- a/packages/schema/src/language-server/validator/function-invocation-validator.ts +++ b/packages/schema/src/language-server/validator/function-invocation-validator.ts @@ -11,10 +11,15 @@ import { isDataModelFieldAttribute, isLiteralExpr, } from '@zenstackhq/language/ast'; -import { ExpressionContext, getFunctionExpressionContext, isEnumFieldReference, isFromStdlib } from '@zenstackhq/sdk'; +import { + ExpressionContext, + getDataModelFieldReference, + getFunctionExpressionContext, + isEnumFieldReference, + isFromStdlib, +} from '@zenstackhq/sdk'; import { AstNode, ValidationAcceptor } from 'langium'; import { P, match } from 'ts-pattern'; -import { getDataModelFieldReference } from '../../utils/ast-utils'; import { AstValidator } from '../types'; import { typeAssignable } from './utils'; @@ -52,6 +57,7 @@ export default class FunctionInvocationValidator implements AstValidator ExpressionContext.DefaultValue) .with(P.union('@@allow', '@@deny', '@allow', '@deny'), () => ExpressionContext.AccessPolicy) .with('@@validate', () => ExpressionContext.ValidationRule) + .with('@@index', () => ExpressionContext.Index) .otherwise(() => undefined); // get the context allowed for the function diff --git a/packages/schema/src/language-server/validator/schema-validator.ts b/packages/schema/src/language-server/validator/schema-validator.ts index b80bf890d..d3722638e 100644 --- a/packages/schema/src/language-server/validator/schema-validator.ts +++ b/packages/schema/src/language-server/validator/schema-validator.ts @@ -13,7 +13,7 @@ export default class SchemaValidator implements AstValidator { constructor(protected readonly documents: LangiumDocuments) {} validate(model: Model, accept: ValidationAcceptor): void { this.validateImports(model, accept); - validateDuplicatedDeclarations(model.declarations, accept); + validateDuplicatedDeclarations(model, model.declarations, accept); const importedModels = resolveTransitiveImports(this.documents, model); diff --git a/packages/schema/src/language-server/validator/utils.ts b/packages/schema/src/language-server/validator/utils.ts index 50e2263d7..6a1a44336 100644 --- a/packages/schema/src/language-server/validator/utils.ts +++ b/packages/schema/src/language-server/validator/utils.ts @@ -3,7 +3,6 @@ import { AttributeParam, BuiltinType, DataModelAttribute, - DataModelField, DataModelFieldAttribute, Expression, ExpressionType, @@ -21,6 +20,7 @@ import { AstNode, ValidationAcceptor } from 'langium'; * Checks if the given declarations have duplicated names */ export function validateDuplicatedDeclarations( + container: AstNode, decls: Array, accept: ValidationAcceptor ): void { @@ -33,8 +33,8 @@ export function validateDuplicatedDeclarations( for (const [name, decls] of Object.entries(groupByName)) { if (decls.length > 1) { let errorField = decls[1]; - if (decls[0].$type === 'DataModelField') { - const nonInheritedFields = decls.filter((x) => !(x as DataModelField).$isInherited); + if (isDataModelField(decls[0])) { + const nonInheritedFields = decls.filter((x) => !(isDataModelField(x) && x.$container !== container)); if (nonInheritedFields.length > 0) { errorField = nonInheritedFields.slice(-1)[0]; } diff --git a/packages/schema/src/language-server/zmodel-code-action.ts b/packages/schema/src/language-server/zmodel-code-action.ts index aace4d0fe..5b6a6c95a 100644 --- a/packages/schema/src/language-server/zmodel-code-action.ts +++ b/packages/schema/src/language-server/zmodel-code-action.ts @@ -2,18 +2,19 @@ import { DataModel, DataModelField, Model, isDataModel } from '@zenstackhq/langu import { AstReflection, CodeActionProvider, - getDocument, IndexManager, LangiumDocument, LangiumDocuments, LangiumServices, MaybePromise, + getDocument, } from 'langium'; import { CodeAction, CodeActionKind, CodeActionParams, Command, Diagnostic } from 'vscode-languageserver'; +import { getModelFieldsWithBases } from '../utils/ast-utils'; import { IssueCodes } from './constants'; -import { ZModelFormatter } from './zmodel-formatter'; import { MissingOppositeRelationData } from './validator/datamodel-validator'; +import { ZModelFormatter } from './zmodel-formatter'; export class ZModelCodeActionProvider implements CodeActionProvider { protected readonly reflection: AstReflection; @@ -92,8 +93,8 @@ export class ZModelCodeActionProvider implements CodeActionProvider { let newText = ''; if (fieldAstNode.type.array) { - //post Post[] - const idField = container.$resolvedFields.find((f) => + // post Post[] + const idField = getModelFieldsWithBases(container).find((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id') ) as DataModelField; @@ -111,7 +112,7 @@ export class ZModelCodeActionProvider implements CodeActionProvider { const idFieldName = idField.name; const referenceIdFieldName = fieldName + this.upperCaseFirstLetter(idFieldName); - if (!oppositeModel.$resolvedFields.find((f) => f.name === referenceIdFieldName)) { + if (!getModelFieldsWithBases(oppositeModel).find((f) => f.name === referenceIdFieldName)) { referenceField = '\n' + indent + `${referenceIdFieldName} ${idField.type.type}`; } diff --git a/packages/schema/src/language-server/zmodel-completion-provider.ts b/packages/schema/src/language-server/zmodel-completion-provider.ts index 742f7087f..cd6dae0ca 100644 --- a/packages/schema/src/language-server/zmodel-completion-provider.ts +++ b/packages/schema/src/language-server/zmodel-completion-provider.ts @@ -61,7 +61,7 @@ export class ZModelCompletionProvider extends DefaultCompletionProvider { if (isDataModelAttribute(context.node) || isDataModelFieldAttribute(context.node)) { const completions = this.getCompletionFromHint(context.node); if (completions) { - completions.forEach(acceptor); + completions.forEach((c) => acceptor(context, c)); return; } } @@ -131,7 +131,7 @@ export class ZModelCompletionProvider extends DefaultCompletionProvider { return; } - const customAcceptor = (item: CompletionValueItem) => { + const customAcceptor = (context: CompletionContext, item: CompletionValueItem) => { // attributes starting with @@@ are for internal use only if (item.insertText?.startsWith('@@@') || item.label?.startsWith('@@@')) { return; @@ -156,10 +156,10 @@ export class ZModelCompletionProvider extends DefaultCompletionProvider { return; } } - acceptor(item); + acceptor(context, item); }; - super.completionForCrossReference(context, crossRef, customAcceptor); + return super.completionForCrossReference(context, crossRef, customAcceptor); } override completionForKeyword( @@ -168,13 +168,13 @@ export class ZModelCompletionProvider extends DefaultCompletionProvider { keyword: any, acceptor: CompletionAcceptor ): MaybePromise { - const customAcceptor = (item: CompletionValueItem) => { + const customAcceptor = (context: CompletionContext, item: CompletionValueItem) => { if (!this.filterKeywordForContext(context, keyword.value)) { return; } - acceptor(item); + acceptor(context, item); }; - super.completionForKeyword(context, keyword, customAcceptor); + return super.completionForKeyword(context, keyword, customAcceptor); } private filterKeywordForContext(context: CompletionContext, keyword: string) { diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index ef97cf4b6..13de8b968 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -35,7 +35,13 @@ import { isReferenceExpr, isStringLiteral, } from '@zenstackhq/language/ast'; -import { getContainingModel, hasAttribute, isFromStdlib } from '@zenstackhq/sdk'; +import { + getContainingModel, + getModelFieldsWithBases, + hasAttribute, + isAuthInvocation, + isFutureExpr, +} from '@zenstackhq/sdk'; import { AstNode, AstNodeDescription, @@ -52,12 +58,7 @@ import { } from 'langium'; import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { - getAllDeclarationsFromImports, - getContainingDataModel, - isAuthInvocation, - isCollectionPredicate, -} from '../utils/ast-utils'; +import { getAllDeclarationsFromImports, getContainingDataModel } from '../utils/ast-utils'; import { mapBuiltinTypeToExpressionType } from './validator/utils'; interface DefaultReference extends Reference { @@ -261,26 +262,9 @@ export class ZModelLinker extends DefaultLinker { } private resolveReference(node: ReferenceExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { - this.linkReference(node, 'target', document, extraScopes); - node.args.forEach((arg) => this.resolve(arg, document, extraScopes)); + this.resolveDefault(node, document, extraScopes); if (node.target.ref) { - // if the reference is inside the RHS of a collection predicate, it cannot be resolve to a field - // not belonging to the collection's model type - - const collectionPredicateContext = this.getCollectionPredicateContextDataModel(node); - if ( - // inside a collection predicate RHS - collectionPredicateContext && - // current ref expr is resolved to a field - isDataModelField(node.target.ref) && - // the resolved field doesn't belong to the collection predicate's operand's type - node.target.ref.$container !== collectionPredicateContext - ) { - this.unresolvableRefExpr(node); - return; - } - // resolve type if (node.target.ref.$type === EnumField) { this.resolveToBuiltinTypeOrDecl(node, node.target.ref.$container); @@ -290,26 +274,6 @@ export class ZModelLinker extends DefaultLinker { } } - private getCollectionPredicateContextDataModel(node: ReferenceExpr) { - let curr: AstNode | undefined = node; - while (curr) { - if ( - curr.$container && - // parent is a collection predicate - isCollectionPredicate(curr.$container) && - // the collection predicate's LHS is resolved to a DataModel - isDataModel(curr.$container.left.$resolvedType?.decl) && - // current node is the RHS - curr.$containerProperty === 'right' - ) { - // return the resolved type of LHS - return curr.$container.left.$resolvedType?.decl; - } - curr = curr.$container; - } - return undefined; - } - private resolveArray(node: ArrayExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { node.items.forEach((item) => this.resolve(item, document, extraScopes)); @@ -329,7 +293,7 @@ export class ZModelLinker extends DefaultLinker { if (node.function.ref) { // eslint-disable-next-line @typescript-eslint/ban-types const funcDecl = node.function.ref as FunctionDecl; - if (funcDecl.name === 'auth' && isFromStdlib(funcDecl)) { + if (isAuthInvocation(node)) { // auth() function is resolved to User model in the current document const model = getContainingModel(node); @@ -346,7 +310,7 @@ export class ZModelLinker extends DefaultLinker { node.$resolvedType = { decl: authModel, nullable: true }; } } - } else if (funcDecl.name === 'future' && isFromStdlib(funcDecl)) { + } else if (isFutureExpr(node)) { // future() function is resolved to current model node.$resolvedType = { decl: getContainingDataModel(node) }; } else { @@ -372,14 +336,11 @@ export class ZModelLinker extends DefaultLinker { document: LangiumDocument, extraScopes: ScopeProvider[] ) { - this.resolve(node.operand, document, extraScopes); + this.resolveDefault(node, document, extraScopes); const operandResolved = node.operand.$resolvedType; if (operandResolved && !operandResolved.array && isDataModel(operandResolved.decl)) { - const modelDecl = operandResolved.decl as DataModel; - const provider = (name: string) => modelDecl.$resolvedFields.find((f) => f.name === name); // member access is resolved only in the context of the operand type - this.linkReference(node, 'member', document, [provider], true); if (node.member.ref) { this.resolveToDeclaredType(node, node.member.ref.type); @@ -393,20 +354,10 @@ export class ZModelLinker extends DefaultLinker { } private resolveCollectionPredicate(node: BinaryExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { - this.resolve(node.left, document, extraScopes); + this.resolveDefault(node, document, extraScopes); const resolvedType = node.left.$resolvedType; if (resolvedType && isDataModel(resolvedType.decl) && resolvedType.array) { - const dataModelDecl = resolvedType.decl; - const provider = (name: string) => { - if (name === 'this') { - return dataModelDecl; - } else { - return dataModelDecl.$resolvedFields.find((f) => f.name === name); - } - }; - extraScopes = [provider, ...extraScopes]; - this.resolve(node.right, document, extraScopes); this.resolveToBuiltinTypeOrDecl(node, 'Boolean'); } else { // error is reported in validation pass @@ -460,10 +411,11 @@ export class ZModelLinker extends DefaultLinker { // // In model B, the attribute argument "myId" is resolved to the field "myId" in model A - const transtiveDataModel = attrAppliedOn.type.reference?.ref as DataModel; - if (transtiveDataModel) { + const transitiveDataModel = attrAppliedOn.type.reference?.ref as DataModel; + if (transitiveDataModel) { // resolve references in the context of the transitive data model - const scopeProvider = (name: string) => transtiveDataModel.$resolvedFields.find((f) => f.name === name); + const scopeProvider = (name: string) => + getModelFieldsWithBases(transitiveDataModel).find((f) => f.name === name); if (isArrayExpr(node.value)) { node.value.items.forEach((item) => { if (isReferenceExpr(item)) { @@ -518,13 +470,6 @@ export class ZModelLinker extends DefaultLinker { } private resolveDataModel(node: DataModel, document: LangiumDocument, extraScopes: ScopeProvider[]) { - if (node.superTypes.length > 0) { - const providers = node.superTypes.map( - (superType) => (name: string) => superType.ref?.fields.find((f) => f.name === name) - ); - extraScopes = [...providers, ...extraScopes]; - } - return this.resolveDefault(node, document, extraScopes); } diff --git a/packages/schema/src/language-server/zmodel-module.ts b/packages/schema/src/language-server/zmodel-module.ts index 07dc223e0..c0c66ce43 100644 --- a/packages/schema/src/language-server/zmodel-module.ts +++ b/packages/schema/src/language-server/zmodel-module.ts @@ -2,12 +2,15 @@ import { ZModelGeneratedModule, ZModelGeneratedSharedModule } from '@zenstackhq/ import { DefaultConfigurationProvider, DefaultDocumentBuilder, + DefaultFuzzyMatcher, DefaultIndexManager, DefaultLangiumDocumentFactory, DefaultLangiumDocuments, DefaultLanguageServer, + DefaultNodeKindProvider, DefaultServiceRegistry, DefaultSharedModuleContext, + DefaultWorkspaceSymbolProvider, LangiumDefaultSharedServices, LangiumServices, LangiumSharedServices, @@ -77,6 +80,7 @@ export const ZModelModule: Module { @@ -85,6 +89,9 @@ export function createSharedModule( lsp: { Connection: () => context.connection, LanguageServer: (services) => new DefaultLanguageServer(services), + WorkspaceSymbolProvider: (services) => new DefaultWorkspaceSymbolProvider(services), + NodeKindProvider: () => new DefaultNodeKindProvider(), + FuzzyMatcher: () => new DefaultFuzzyMatcher(), }, workspace: { LangiumDocuments: (services) => new DefaultLangiumDocuments(services), diff --git a/packages/schema/src/language-server/zmodel-scope.ts b/packages/schema/src/language-server/zmodel-scope.ts index 8eda869e8..9d685db27 100644 --- a/packages/schema/src/language-server/zmodel-scope.ts +++ b/packages/schema/src/language-server/zmodel-scope.ts @@ -1,7 +1,6 @@ import { - DataModel, + BinaryExpr, MemberAccessExpr, - Model, isDataModel, isDataModelField, isEnumField, @@ -9,8 +8,15 @@ import { isMemberAccessExpr, isModel, isReferenceExpr, + isThisExpr, } from '@zenstackhq/language/ast'; -import { getAuthModel, getDataModels } from '@zenstackhq/sdk'; +import { + getAuthModel, + getDataModels, + getModelFieldsWithBases, + getRecursiveBases, + isAuthInvocation, +} from '@zenstackhq/sdk'; import { AstNode, AstNodeDescription, @@ -19,7 +25,6 @@ import { EMPTY_SCOPE, LangiumDocument, LangiumServices, - Mutable, PrecomputedScopes, ReferenceInfo, Scope, @@ -30,8 +35,9 @@ import { stream, streamAllContents, } from 'langium'; +import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { resolveImportUri } from '../utils/ast-utils'; +import { isCollectionPredicate, isFutureInvocation, resolveImportUri } from '../utils/ast-utils'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from './constants'; /** @@ -66,49 +72,18 @@ export class ZModelScopeComputation extends DefaultScopeComputation { return result; } - override computeLocalScopes( - document: LangiumDocument, - cancelToken?: CancellationToken | undefined - ): Promise { - const result = super.computeLocalScopes(document, cancelToken); - - //the $resolvedFields would be used in Linking stage for all the documents - //so we need to set it at the end of the scope computation - this.resolveBaseModels(document); - return result; - } - - private resolveBaseModels(document: LangiumDocument) { - const model = document.parseResult.value as Model; - - model.declarations.forEach((decl) => { - if (decl.$type === 'DataModel') { - const dataModel = decl as DataModel; - dataModel.$resolvedFields = [...dataModel.fields]; - this.getRecursiveSuperTypes(dataModel).forEach((superType) => { - superType.fields.forEach((field) => { - const cloneField = Object.assign({}, field); - cloneField.$isInherited = true; - const mutable = cloneField as Mutable; - // update container - mutable.$container = dataModel; - dataModel.$resolvedFields.push(cloneField); - }); - }); - } - }); - } + override processNode(node: AstNode, document: LangiumDocument, scopes: PrecomputedScopes) { + super.processNode(node, document, scopes); - private getRecursiveSuperTypes(dataModel: DataModel): DataModel[] { - const result: DataModel[] = []; - dataModel.superTypes.forEach((superType) => { - const superTypeDecl = superType.ref; - if (superTypeDecl) { - result.push(superTypeDecl); - result.push(...this.getRecursiveSuperTypes(superTypeDecl)); + if (isDataModel(node) && !node.$baseMerged) { + // add base fields to the scope recursively + const bases = getRecursiveBases(node); + for (const base of bases) { + for (const field of base.fields) { + scopes.add(node, this.descriptions.createDescription(field, this.nameProvider.getName(field))); + } } - }); - return result; + } } } @@ -140,50 +115,129 @@ export class ZModelScopeProvider extends DefaultScopeProvider { override getScope(context: ReferenceInfo): Scope { if (isMemberAccessExpr(context.container) && context.container.operand && context.property === 'member') { - return this.getMemberAccessScope(context.container); + return this.getMemberAccessScope(context); + } + + if (isReferenceExpr(context.container) && context.property === 'target') { + // when reference expression is resolved inside a collection predicate, the scope is the collection + const containerCollectionPredicate = getCollectionPredicateContext(context.container); + if (containerCollectionPredicate) { + return this.getCollectionPredicateScope(context, containerCollectionPredicate); + } } + return super.getScope(context); } - private getMemberAccessScope(node: MemberAccessExpr) { - if (isReferenceExpr(node.operand)) { - // scope to target model's fields - const ref = node.operand.target.ref; - if (isDataModelField(ref)) { - const targetModel = ref.type.reference?.ref; - if (isDataModel(targetModel)) { - return this.createScopeForNodes(targetModel.fields); + private getMemberAccessScope(context: ReferenceInfo) { + const referenceType = this.reflection.getReferenceType(context); + const globalScope = this.getGlobalScope(referenceType, context); + const node = context.container as MemberAccessExpr; + + return match(node.operand) + .when(isReferenceExpr, (operand) => { + // operand is a reference, it can only be a model field + const ref = operand.target.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); } - } - } else if (isMemberAccessExpr(node.operand)) { - // scope to target model's fields - const ref = node.operand.member.ref; - if (isDataModelField(ref)) { - const targetModel = ref.type.reference?.ref; - if (isDataModel(targetModel)) { - return this.createScopeForNodes(targetModel.fields); + return EMPTY_SCOPE; + }) + .when(isMemberAccessExpr, (operand) => { + // operand is a member access, it must be resolved to a + const ref = operand.member.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); } - } - } else if (isInvocationExpr(node.operand)) { - // deal with member access from `auth()` and `future() - const funcName = node.operand.function.$refText; - if (funcName === 'auth') { - // resolve to `User` or `@@auth` model - const model = getContainerOfType(node, isModel); - if (model) { - const authModel = getAuthModel(getDataModels(model)); - if (authModel) { - return this.createScopeForNodes(authModel.fields); - } + return EMPTY_SCOPE; + }) + .when(isThisExpr, () => { + // operand is `this`, resolve to the containing model + return this.createScopeForContainingModel(node, globalScope); + }) + .when(isInvocationExpr, (operand) => { + // deal with member access from `auth()` and `future() + if (isAuthInvocation(operand)) { + // resolve to `User` or `@@auth` model + return this.createScopeForAuthModel(node, globalScope); } - } - if (funcName === 'future') { - const thisModel = getContainerOfType(node, isDataModel); - if (thisModel) { - return this.createScopeForNodes(thisModel.fields); + if (isFutureInvocation(operand)) { + // resolve `future()` to the containing model + return this.createScopeForContainingModel(node, globalScope); } + return EMPTY_SCOPE; + }) + .otherwise(() => EMPTY_SCOPE); + } + + private getCollectionPredicateScope(context: ReferenceInfo, collectionPredicate: BinaryExpr) { + const referenceType = this.reflection.getReferenceType(context); + const globalScope = this.getGlobalScope(referenceType, context); + const collection = collectionPredicate.left; + + return match(collection) + .when(isReferenceExpr, (expr) => { + // collection is a reference, it can only be a model field + const ref = expr.target.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); + } + return EMPTY_SCOPE; + }) + .when(isMemberAccessExpr, (expr) => { + // collection is a member access, it can only be resolved to a model field + const ref = expr.member.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); + } + return EMPTY_SCOPE; + }) + .when(isAuthInvocation, (expr) => { + return this.createScopeForAuthModel(expr, globalScope); + }) + .otherwise(() => EMPTY_SCOPE); + } + + private createScopeForContainingModel(node: AstNode, globalScope: Scope) { + const model = getContainerOfType(node, isDataModel); + if (model) { + return this.createScopeForNodes(model.fields, globalScope); + } else { + return EMPTY_SCOPE; + } + } + + private createScopeForModel(node: AstNode | undefined, globalScope: Scope) { + if (isDataModel(node)) { + return this.createScopeForNodes(getModelFieldsWithBases(node), globalScope); + } else { + return EMPTY_SCOPE; + } + } + + private createScopeForAuthModel(node: AstNode, globalScope: Scope) { + const model = getContainerOfType(node, isModel); + if (model) { + const authModel = getAuthModel(getDataModels(model, true)); + if (authModel) { + return this.createScopeForNodes(authModel.fields, globalScope); } } return EMPTY_SCOPE; } } + +function getCollectionPredicateContext(node: AstNode) { + let curr: AstNode | undefined = node; + while (curr) { + if (curr.$container && isCollectionPredicate(curr.$container) && curr.$containerProperty === 'right') { + return curr.$container; + } + curr = curr.$container; + } + return undefined; +} diff --git a/packages/schema/src/plugins/access-policy/index.ts b/packages/schema/src/plugins/access-policy/index.ts deleted file mode 100644 index cbdcbd64f..000000000 --- a/packages/schema/src/plugins/access-policy/index.ts +++ /dev/null @@ -1,10 +0,0 @@ -import { PluginFunction } from '@zenstackhq/sdk'; -import PolicyGenerator from './policy-guard-generator'; - -export const name = 'Access Policy'; - -const run: PluginFunction = async (model, options, _dmmf, globalOptions) => { - return new PolicyGenerator().generate(model, options, globalOptions); -}; - -export default run; diff --git a/packages/schema/src/plugins/enhancer/delegate/index.ts b/packages/schema/src/plugins/enhancer/delegate/index.ts new file mode 100644 index 000000000..5e4cffdfa --- /dev/null +++ b/packages/schema/src/plugins/enhancer/delegate/index.ts @@ -0,0 +1,16 @@ +import { type PluginOptions } from '@zenstackhq/sdk'; +import type { Model } from '@zenstackhq/sdk/ast'; +import type { Project } from 'ts-morph'; +import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; +import path from 'path'; + +export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { + const prismaGenerator = new PrismaSchemaGenerator(); + await prismaGenerator.generate(model, { + provider: '@internal', + schemaPath: options.schemaPath, + output: path.join(outDir, 'delegate.prisma'), + overrideClientGenerationPath: path.join(outDir, '.delegate'), + mode: 'logical', + }); +} diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts new file mode 100644 index 000000000..1d42b5912 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -0,0 +1,248 @@ +import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime'; +import { + getAttribute, + getDataModels, + getPrismaClientImportSpec, + isDelegateModel, + type PluginOptions, +} from '@zenstackhq/sdk'; +import { DataModelField, isDataModel, isReferenceExpr, type DataModel, type Model } from '@zenstackhq/sdk/ast'; +import path from 'path'; +import { + ForEachDescendantTraversalControl, + MethodSignature, + Node, + Project, + PropertySignature, + SyntaxKind, + TypeAliasDeclaration, +} from 'ts-morph'; +import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; + +export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { + const outFile = path.join(outDir, 'enhance.ts'); + let logicalPrismaClientDir: string | undefined; + + if (hasDelegateModel(model)) { + logicalPrismaClientDir = await generateLogicalPrisma(model, options, outDir); + } + + project.createSourceFile( + outFile, + `import { createEnhancement, type EnhancementContext, type EnhancementOptions, type ZodSchemas } from '@zenstackhq/runtime'; +import modelMeta from './model-meta'; +import policy from './policy'; +${options.withZodSchemas ? "import * as zodSchemas from './zod';" : 'const zodSchemas = undefined;'} +import { Prisma } from '${getPrismaClientImportSpec(model, outDir)}'; +${logicalPrismaClientDir ? `import { PrismaClient as EnhancedPrismaClient } from '${logicalPrismaClientDir}';` : ''} + +export function enhance(prisma: DbClient, context?: EnhancementContext, options?: EnhancementOptions) { + return createEnhancement(prisma, { + modelMeta, + policy, + zodSchemas: zodSchemas as unknown as (ZodSchemas | undefined), + prismaModule: Prisma, + ...options + }, context)${logicalPrismaClientDir ? ' as EnhancedPrismaClient' : ''}; +} +`, + { overwrite: true } + ); +} + +function hasDelegateModel(model: Model) { + const dataModels = getDataModels(model); + return dataModels.some( + (dm) => isDelegateModel(dm) && dataModels.some((sub) => sub.superTypes.some((base) => base.ref === dm)) + ); +} + +async function generateLogicalPrisma(model: Model, options: PluginOptions, outDir: string) { + const prismaGenerator = new PrismaSchemaGenerator(); + const prismaClientOutDir = './.delegate'; + await prismaGenerator.generate(model, { + provider: '@internal', + schemaPath: options.schemaPath, + output: path.join(outDir, 'delegate.prisma'), + overrideClientGenerationPath: prismaClientOutDir, + mode: 'logical', + }); + + await processClientTypes(model, path.join(outDir, prismaClientOutDir)); + return prismaClientOutDir; +} + +async function processClientTypes(model: Model, prismaClientDir: string) { + const project = new Project(); + const sf = project.addSourceFileAtPath(path.join(prismaClientDir, 'index.d.ts')); + + const delegateModels: [DataModel, DataModel[]][] = []; + model.declarations + .filter((d): d is DataModel => isDelegateModel(d)) + .forEach((dm) => { + delegateModels.push([ + dm, + model.declarations.filter( + (d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm) + ), + ]); + }); + + const toRemove: (PropertySignature | MethodSignature)[] = []; + const toReplaceText: [TypeAliasDeclaration, string][] = []; + + sf.forEachDescendant((desc, traversal) => { + removeAuxRelationFields(desc, toRemove, traversal); + fixDelegateUnionType(desc, delegateModels, toReplaceText, traversal); + removeCreateFromDelegateInputTypes(desc, delegateModels, toRemove, traversal); + removeDelegateToplevelCreates(desc, delegateModels, toRemove, traversal); + removeDiscriminatorFromConcreteInputTypes(desc, delegateModels, toRemove); + }); + + toRemove.forEach((n) => n.remove()); + toReplaceText.forEach(([node, text]) => node.replaceWithText(text)); + + await project.save(); +} + +function removeAuxRelationFields( + desc: Node, + toRemove: (PropertySignature | MethodSignature)[], + traversal: ForEachDescendantTraversalControl +) { + if (desc.isKind(SyntaxKind.PropertySignature) || desc.isKind(SyntaxKind.MethodSignature)) { + // remove aux fields + const name = desc.getName(); + + if (name.startsWith(DELEGATE_AUX_RELATION_PREFIX)) { + toRemove.push(desc); + traversal.skip(); + } + } +} + +function fixDelegateUnionType( + desc: Node, + delegateModels: [DataModel, DataModel[]][], + toReplaceText: [TypeAliasDeclaration, string][], + traversal: ForEachDescendantTraversalControl +) { + if (!desc.isKind(SyntaxKind.TypeAliasDeclaration)) { + return; + } + + const name = desc.getName(); + delegateModels.forEach(([delegate, concreteModels]) => { + if (name === `$${delegate.name}Payload`) { + const discriminator = getDiscriminatorField(delegate); + if (discriminator) { + toReplaceText.push([ + desc, + `export type ${name} = + ${concreteModels + .map((m) => `($${m.name}Payload & { scalars: { ${discriminator.name}: '${m.name}' } })`) + .join(' | ')};`, + ]); + traversal.skip(); + } + } + }); +} + +function removeCreateFromDelegateInputTypes( + desc: Node, + delegateModels: [DataModel, DataModel[]][], + toRemove: (PropertySignature | MethodSignature)[], + traversal: ForEachDescendantTraversalControl +) { + if (!desc.isKind(SyntaxKind.TypeAliasDeclaration)) { + return; + } + + const name = desc.getName(); + delegateModels.forEach(([delegate]) => { + // remove create related sub-payload from delegate's input types since they cannot be created directly + const regex = new RegExp(`\\${delegate.name}(Unchecked)?(Create|Update).*Input`); + if (regex.test(name)) { + desc.forEachDescendant((d, innerTraversal) => { + if ( + d.isKind(SyntaxKind.PropertySignature) && + ['create', 'upsert', 'connectOrCreate'].includes(d.getName()) + ) { + toRemove.push(d); + innerTraversal.skip(); + } + }); + traversal.skip(); + } + }); +} + +function removeDiscriminatorFromConcreteInputTypes( + desc: Node, + delegateModels: [DataModel, DataModel[]][], + toRemove: (PropertySignature | MethodSignature)[] +) { + if (!desc.isKind(SyntaxKind.TypeAliasDeclaration)) { + return; + } + + const name = desc.getName(); + delegateModels.forEach(([delegate, concretes]) => { + const discriminator = getDiscriminatorField(delegate); + if (!discriminator) { + return; + } + + concretes.forEach((concrete) => { + // remove discriminator field from the create/update input of concrete models + const regex = new RegExp(`\\${concrete.name}(Unchecked)?(Create|Update).*Input`); + if (regex.test(name)) { + desc.forEachDescendant((d, innerTraversal) => { + if (d.isKind(SyntaxKind.PropertySignature)) { + if (d.getName() === discriminator.name) { + toRemove.push(d); + } + innerTraversal.skip(); + } + }); + } + }); + }); +} + +function removeDelegateToplevelCreates( + desc: Node, + delegateModels: [DataModel, DataModel[]][], + toRemove: (PropertySignature | MethodSignature)[], + traversal: ForEachDescendantTraversalControl +) { + if (desc.isKind(SyntaxKind.InterfaceDeclaration)) { + // remove create and upsert methods from delegate interfaces since they cannot be created directly + const name = desc.getName(); + if (delegateModels.map(([dm]) => `${dm.name}Delegate`).includes(name)) { + const createMethod = desc.getMethod('create'); + if (createMethod) { + toRemove.push(createMethod); + } + const createManyMethod = desc.getMethod('createMany'); + if (createManyMethod) { + toRemove.push(createManyMethod); + } + const upsertMethod = desc.getMethod('upsert'); + if (upsertMethod) { + toRemove.push(upsertMethod); + } + traversal.skip(); + } + } +} + +function getDiscriminatorField(delegate: DataModel) { + const delegateAttr = getAttribute(delegate, '@@delegate'); + if (!delegateAttr) { + return undefined; + } + const arg = delegateAttr.args[0]?.value; + return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined; +} diff --git a/packages/schema/src/plugins/enhancer/index.ts b/packages/schema/src/plugins/enhancer/index.ts new file mode 100644 index 000000000..86e3ecf39 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/index.ts @@ -0,0 +1,48 @@ +import { + PluginError, + createProject, + emitProject, + resolvePath, + saveProject, + type PluginFunction, +} from '@zenstackhq/sdk'; +import { getDefaultOutputFolder } from '../plugin-utils'; +import { generate as generateEnhancer } from './enhance'; +import { generate as generateModelMeta } from './model-meta'; +import { generate as generatePolicy } from './policy'; + +export const name = 'Prisma Enhancer'; +export const description = 'Generating PrismaClient enhancer'; + +const run: PluginFunction = async (model, options, _dmmf, globalOptions) => { + let ourDir = options.output ? (options.output as string) : getDefaultOutputFolder(globalOptions); + if (!ourDir) { + throw new PluginError(name, `Unable to determine output path, not running plugin`); + } + ourDir = resolvePath(ourDir, options); + + const project = createProject(); + + await generateModelMeta(model, options, project, ourDir); + await generatePolicy(model, options, project, ourDir); + await generateEnhancer(model, options, project, ourDir); + + let shouldCompile = true; + if (typeof options.compile === 'boolean') { + // explicit override + shouldCompile = options.compile; + } else if (globalOptions) { + // from CLI or config file + shouldCompile = globalOptions.compile; + } + + if (!shouldCompile || options.preserveTsFiles === true) { + await saveProject(project); + } + + if (shouldCompile) { + await emitProject(project); + } +}; + +export default run; diff --git a/packages/schema/src/plugins/enhancer/model-meta/index.ts b/packages/schema/src/plugins/enhancer/model-meta/index.ts new file mode 100644 index 000000000..541106e24 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/model-meta/index.ts @@ -0,0 +1,13 @@ +import { generateModelMeta, getDataModels, type PluginOptions } from '@zenstackhq/sdk'; +import type { Model } from '@zenstackhq/sdk/ast'; +import path from 'path'; +import type { Project } from 'ts-morph'; + +export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { + const outFile = path.join(outDir, 'model-meta.ts'); + const dataModels = getDataModels(model); + await generateModelMeta(project, dataModels, { + output: outFile, + generateAttributes: true, + }); +} diff --git a/packages/schema/src/plugins/access-policy/expression-writer.ts b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts similarity index 87% rename from packages/schema/src/plugins/access-policy/expression-writer.ts rename to packages/schema/src/plugins/enhancer/policy/expression-writer.ts index 2ab3e2bdd..9333634fa 100644 --- a/packages/schema/src/plugins/access-policy/expression-writer.ts +++ b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts @@ -2,9 +2,11 @@ import { BinaryExpr, BooleanLiteral, DataModel, + DataModelField, Expression, InvocationExpr, isDataModel, + isDataModelField, isEnumField, isMemberAccessExpr, isReferenceExpr, @@ -13,25 +15,28 @@ import { MemberAccessExpr, NumberLiteral, ReferenceExpr, + ReferenceTarget, StringLiteral, UnaryExpr, } from '@zenstackhq/language/ast'; +import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime'; import { ExpressionContext, getFunctionExpressionContext, + getIdFields, getLiteral, + isAuthInvocation, isDataModelFieldReference, + isDelegateModel, isFutureExpr, PluginError, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, } from '@zenstackhq/sdk'; import { lowerCaseFirst } from 'lower-case-first'; +import invariant from 'tiny-invariant'; import { CodeBlockWriter } from 'ts-morph'; -import { name } from '.'; -import { getIdFields, isAuthInvocation } from '../../utils/ast-utils'; -import { - TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, -} from '../../utils/typescript-expression-transformer'; +import { name } from '..'; type ComparisonOperator = '==' | '!=' | '>' | '>=' | '<' | '<='; @@ -114,11 +119,44 @@ export class ExpressionWriter { throw new Error('We should never get here'); } else { this.block(() => { - this.writer.write(`${expr.target.ref?.name}: true`); + const ref = expr.target.ref; + invariant(ref); + if (this.isFieldReferenceToDelegateModel(ref)) { + const thisModel = ref.$container as DataModel; + const targetBase = ref.$inheritedFrom; + this.writeBaseHierarchy(thisModel, targetBase, () => this.writer.write(`${ref.name}: true`)); + } else { + this.writer.write(`${ref.name}: true`); + } }); } } + private writeBaseHierarchy(thisModel: DataModel, targetBase: DataModel | undefined, conditionWriter: () => void) { + if (!targetBase || thisModel === targetBase) { + conditionWriter(); + return; + } + + const base = this.getDelegateBase(thisModel); + if (!base) { + throw new PluginError(name, `Failed to resolve delegate base model for "${thisModel.name}"`); + } + + this.writer.write(`${`${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(base.name)}`}: `); + this.writer.block(() => { + this.writeBaseHierarchy(base, targetBase, conditionWriter); + }); + } + + private getDelegateBase(model: DataModel) { + return model.superTypes.map((t) => t.ref).filter((t) => t && isDelegateModel(t))?.[0]; + } + + private isFieldReferenceToDelegateModel(ref: ReferenceTarget): ref is DataModelField { + return isDataModelField(ref) && !!ref.$inheritedFrom && isDelegateModel(ref.$inheritedFrom); + } + private writeMemberAccess(expr: MemberAccessExpr) { if (this.isAuthOrAuthMemberAccess(expr)) { // member access of `auth()`, generate plain expression @@ -497,48 +535,67 @@ export class ExpressionWriter { filterOp?: FilterOperators, extraArgs?: Record ) { - let selector: string | undefined; + // let selector: string | undefined; let operand: Expression | undefined; + let fieldWriter: ((conditionWriter: () => void) => void) | undefined; if (isThisExpr(fieldAccess)) { // pass on writeCondition(); return; } else if (isReferenceExpr(fieldAccess)) { - selector = fieldAccess.target.ref?.name; + const ref = fieldAccess.target.ref; + invariant(ref); + if (this.isFieldReferenceToDelegateModel(ref)) { + const thisModel = ref.$container as DataModel; + const targetBase = ref.$inheritedFrom; + fieldWriter = (conditionWriter: () => void) => + this.writeBaseHierarchy(thisModel, targetBase, () => { + this.writer.write(`${ref.name}: `); + conditionWriter(); + }); + } else { + fieldWriter = (conditionWriter: () => void) => { + this.writer.write(`${ref.name}: `); + conditionWriter(); + }; + } } else if (isMemberAccessExpr(fieldAccess)) { - if (isFutureExpr(fieldAccess.operand)) { + if (!isFutureExpr(fieldAccess.operand)) { // future().field should be treated as the "field" - selector = fieldAccess.member.ref?.name; - } else { - selector = fieldAccess.member.ref?.name; operand = fieldAccess.operand; } + fieldWriter = (conditionWriter: () => void) => { + this.writer.write(`${fieldAccess.member.ref?.name}: `); + conditionWriter(); + }; } else { throw new PluginError(name, `Unsupported expression type: ${fieldAccess.$type}`); } - if (!selector) { + if (!fieldWriter) { throw new PluginError(name, `Failed to write FieldAccess expression`); } const writerFilterOutput = () => { - this.writer.write(selector + ': '); - if (filterOp) { - this.block(() => { - this.writer.write(`${filterOp}: `); - writeCondition(); + // this.writer.write(selector + ': '); + fieldWriter!(() => { + if (filterOp) { + this.block(() => { + this.writer.write(`${filterOp}: `); + writeCondition(); - if (extraArgs) { - for (const [k, v] of Object.entries(extraArgs)) { - this.writer.write(`,\n${k}: `); - this.plain(v); + if (extraArgs) { + for (const [k, v] of Object.entries(extraArgs)) { + this.writer.write(`,\n${k}: `); + this.plain(v); + } } - } - }); - } else { - writeCondition(); - } + }); + } else { + writeCondition(); + } + }); }; if (operand) { diff --git a/packages/schema/src/plugins/enhancer/policy/index.ts b/packages/schema/src/plugins/enhancer/policy/index.ts new file mode 100644 index 000000000..8eaf1d00b --- /dev/null +++ b/packages/schema/src/plugins/enhancer/policy/index.ts @@ -0,0 +1,8 @@ +import { type PluginOptions } from '@zenstackhq/sdk'; +import type { Model } from '@zenstackhq/sdk/ast'; +import type { Project } from 'ts-morph'; +import { PolicyGenerator } from './policy-guard-generator'; + +export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { + return new PolicyGenerator().generate(project, model, options, outDir); +} diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts similarity index 67% rename from packages/schema/src/plugins/access-policy/policy-guard-generator.ts rename to packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 2025c3d5c..34441e5c1 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -1,12 +1,11 @@ import { DataModel, - DataModelAttribute, DataModelField, - DataModelFieldAttribute, Enum, Expression, Model, isBinaryExpr, + isUnaryExpr, isDataModel, isDataModelField, isEnum, @@ -15,7 +14,6 @@ import { isMemberAccessExpr, isReferenceExpr, isThisExpr, - isUnaryExpr, } from '@zenstackhq/language/ast'; import { FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX, @@ -31,59 +29,48 @@ import { import { ExpressionContext, PluginError, - PluginGlobalOptions, PluginOptions, RUNTIME_PACKAGE, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, + Z3ExpressionTransformer, analyzePolicies, - createProject, - emitProject, getAttributeArg, getAuthModel, getDataModels, + getIdFields, getLiteral, getPrismaClientImportSpec, hasAttribute, hasValidationAttributes, + isAuthInvocation, isEnumFieldReference, isForeignKeyField, isFromStdlib, isFutureExpr, - resolvePath, resolved, - saveProject, } from '@zenstackhq/sdk'; import { streamAllContents, streamAst, streamContents } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; -import { FunctionDeclaration, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; -import { name } from '.'; -import { getIdFields, isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils'; -import { - TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, -} from '../../utils/typescript-expression-transformer'; -import { ALL_OPERATION_KINDS, getDefaultOutputFolder } from '../plugin-utils'; +import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; +import { name } from '..'; +import { isCollectionPredicate } from '../../../utils/ast-utils'; +import { ALL_OPERATION_KINDS, CRUD_OPERATION_KINDS } from '../../plugin-utils'; import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; /** * Generates source file that contains Prisma query guard objects used for injecting database queries */ -export default class PolicyGenerator { - async generate(model: Model, options: PluginOptions, globalOptions?: PluginGlobalOptions) { - let output = options.output ? (options.output as string) : getDefaultOutputFolder(globalOptions); - if (!output) { - throw new PluginError(options.name, `Unable to determine output path, not running plugin`); - } - output = resolvePath(output, options); - - const project = createProject(); +export class PolicyGenerator { + async generate(project: Project, model: Model, _options: PluginOptions, output: string) { const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true }); sf.addStatements('/* eslint-disable */'); sf.addImportDeclaration({ namedImports: [ { name: 'type QueryContext' }, - { name: 'type DbOperations' }, + { name: 'type CrudContract' }, { name: 'allFieldsEqual' }, { name: 'type PolicyDef' }, ], @@ -99,11 +86,184 @@ export default class PolicyGenerator { }); } + sf.addStatements(` + const processCondition = ( + variable: any, + condition: any, // string conditions are processed as assertions + z3: any, + ): any[] => { + const assertions: any[] = []; + if (typeof condition === 'undefined' || typeof condition === 'string') { + // noop + // user properties are not pre-processed so we have to filter them out if string + } else if (typeof condition === 'number') { + assertions.push(variable.eq(condition)); + } else if (typeof condition === 'boolean') { + assertions.push(variable.eq(condition)); + } else if ('OR' in condition) { + const orCondition = condition; + const tempAssertions: any[] = []; + for (const condition of orCondition.OR) { + if (typeof condition === 'string') { + // string are pre-processed and transformed as Assertion + throw 'Invalid OR condition'; + } + tempAssertions.push(...processCondition(variable, condition, z3)); + } + const orAssertion = z3.Or(...tempAssertions); + assertions.push(orAssertion); + } else if (z3.isBool(variable)) { + assertions.push(variable); + } else { + const tempAssertions: any[] = []; + for (const operator of Object.keys(condition)) { + const value = condition[operator]; + switch (operator) { + case 'eq': + tempAssertions.push(variable.eq(value)); + break; + case 'ne': + tempAssertions.push(variable.neq(value)); + break; + case 'lt': + tempAssertions.push(variable.lt(value)); + break; + case 'le': + tempAssertions.push(variable.le(value)); + break; + case 'gt': + tempAssertions.push(variable.gt(value)); + break; + case 'ge': + tempAssertions.push(variable.ge(value)); + break; + default: + throw new Error('Invalid operator'); + } + } + if (tempAssertions.length > 1) { + const andAssertion = z3.And(...tempAssertions); + assertions.push(andAssertion); + } else if (tempAssertions.length === 1) { + assertions.push(...tempAssertions); + } + } + return assertions; + }; + `); + + // TODO: handle string and array functions in fieldStringValueMap (in, startsWith, includes, etc.) + sf.addFunction({ + name: 'checkStringCondition', + parameters: [ + { + name: 'args', + type: 'any', + }, + { + name: 'fieldStringValueMap', + type: 'Record = {}', + }, + ], + returnType: 'boolean', + statements: (writer) => { + writer.write(` + const key = Object.keys(fieldStringValueMap)[0]; + const condition = args[key]; + if (typeof condition === 'string') { + return args[key] === fieldStringValueMap[key]; + } + if (typeof condition === 'object' && 'in' in condition) { + return condition.in.some(condition === fieldStringValueMap[key]); + } + if (typeof condition === 'object' && 'startsWith' in condition) { + return fieldStringValueMap[key].startsWith(condition.startsWith); + }; + return true; + `); + }, + }); + + sf.addFunction({ + name: 'buildAssertion', + parameters: [ + { + name: 'z3', + type: 'any', + }, + { + name: 'variables', + type: 'Record', + }, + { + name: 'args', + type: 'Record = {}', + }, + { + name: 'user?', + type: 'any', + }, + { + name: 'fieldStringValueMap', + type: 'Record = {}', + }, + ], + returnType: 'any', + statements: (writer) => { + writer.write(` + const processedVariables = Object.keys(variables).reduce((acc, key) => { + const newKey = key.replace(/^_/, ''); + acc[newKey] = variables[key]; + return acc; + }, {} as Record); + const assertions: any[] = []; + if ('OR' in args) { + const tempAssertions: any[] = []; + for (const arg of args.OR) { + tempAssertions.push(buildAssertion(z3, processedVariables, arg, user, fieldStringValueMap)); + } + const orAssertion = z3.Or(...tempAssertions); + assertions.push(orAssertion); + } + + // handle string conditions + // TODO: handle string conditions for user properties + const condition = checkStringCondition(args, fieldStringValueMap); + if (condition === false) { + return z3.Bool.val(false); + } + + const tempAssertions: any[] = []; + + for (const property of Object.keys(args)) { + const condition = args[property]; + // TODO: handle nested properties + const variable = processedVariables[property]; + if (variable) { + tempAssertions.push(...processCondition(variable, condition, z3)); + } + } + + // avoid empty assertions in case of unique value or boolean + if (tempAssertions.length > 1) { + const andAssertion = z3.And(...tempAssertions); + assertions.push(andAssertion); + } else if (tempAssertions.length === 1) { + assertions.push(...tempAssertions); + } + + return z3.And(...assertions); + `); + }, + }); + const models = getDataModels(model); const policyMap: Record> = {}; + const permissionMap: Record> = {}; for (const model of models) { policyMap[model.name] = await this.generateQueryGuardForModel(model, sf); + permissionMap[model.name] = await this.generatePermissionCheckerForModel(model, sf); } const authSelector = this.generateAuthSelector(models); @@ -144,6 +304,24 @@ export default class PolicyGenerator { writer.writeLine(','); } }); + writer.writeLine(','); + + writer.write('permission:'); + writer.inlineBlock(() => { + for (const [model, map] of Object.entries(permissionMap)) { + writer.write(`${lowerCaseFirst(model)}:`); + writer.inlineBlock(() => { + for (const [op, func] of Object.entries(map)) { + if (typeof func === 'object') { + writer.write(`${op}: ${JSON.stringify(func)},`); + } else { + writer.write(`${op}: ${func},`); + } + } + }); + writer.write(','); + } + }); if (authSelector) { writer.writeLine(','); @@ -155,23 +333,7 @@ export default class PolicyGenerator { ], }); - sf.addStatements('export default policy'); - - let shouldCompile = true; - if (typeof options.compile === 'boolean') { - // explicit override - shouldCompile = options.compile; - } else if (globalOptions) { - shouldCompile = globalOptions.compile; - } - - if (!shouldCompile || options.preserveTsFiles === true) { - // save ts files - await saveProject(project); - } - if (shouldCompile) { - await emitProject(project); - } + sf.addStatements('export default policy;'); } // Generates a { select: ... } object to select `auth()` fields used in policy rules @@ -231,7 +393,7 @@ export default class PolicyGenerator { operation: PolicyOperationKind, override = false ) { - const attributes = target.attributes as (DataModelAttribute | DataModelFieldAttribute)[]; + const attributes = target.attributes; const attrName = isDataModel(target) ? `@@${kind}` : `@${kind}`; const attrs = attributes.filter((attr) => { if (attr.decl.ref?.name !== attrName) { @@ -264,7 +426,6 @@ export default class PolicyGenerator { } else if (operation === 'postUpdate') { result = this.processUpdatePolicies(result, true); } - return result; } @@ -281,30 +442,6 @@ export default class PolicyGenerator { } } - private visitPolicyExpression(expr: Expression, postUpdate: boolean): Expression | undefined { - if (isBinaryExpr(expr) && (expr.operator === '&&' || expr.operator === '||')) { - const left = this.visitPolicyExpression(expr.left, postUpdate); - const right = this.visitPolicyExpression(expr.right, postUpdate); - if (!left) return right; - if (!right) return left; - return { ...expr, left, right }; - } - - if (isUnaryExpr(expr) && expr.operator === '!') { - const operand = this.visitPolicyExpression(expr.operand, postUpdate); - if (!operand) return undefined; - return { ...expr, operand }; - } - - if (postUpdate && !this.hasFutureReference(expr)) { - return undefined; - } else if (!postUpdate && this.hasFutureReference(expr)) { - return undefined; - } - - return expr; - } - private hasFutureReference(expr: Expression) { for (const node of streamAst(expr)) { if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) { @@ -380,6 +517,21 @@ export default class PolicyGenerator { return result; } + private async generatePermissionCheckerForModel(model: DataModel, sourceFile: SourceFile) { + const result: Record = {}; + + for (const kind of CRUD_OPERATION_KINDS) { + const denies = this.getPolicyExpressions(model, 'deny', kind); + const allows = this.getPolicyExpressions(model, 'allow', kind); + + const checkFunc = this.generatePermissionCheckerFunction(sourceFile, model, kind, allows, denies); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + result[kind] = checkFunc.getName()!; + } + + return result; + } + private generateReadFieldsCheckers( model: DataModel, sourceFile: SourceFile, @@ -794,7 +946,7 @@ export default class PolicyGenerator { { // for generating field references used by field comparison in the same model name: 'db', - type: 'Record', + type: 'CrudContract', }, ], statements, @@ -803,6 +955,150 @@ export default class PolicyGenerator { return func; } + private generatePermissionCheckerFunction( + sourceFile: SourceFile, + model: DataModel, + kind: PolicyOperationKind, + allows: Expression[], + denies: Expression[] + ) { + const statements: (string | WriterFunction)[] = []; + + statements.push((writer) => { + const transformer = new Z3ExpressionTransformer({ + context: ExpressionContext.AccessPolicy, + }); + try { + writer.writeLine('const solver = new z3.Solver();'); + + const variables: Record = this.generateVariables([...denies, ...allows]); + Object.keys(variables).forEach((key) => { + writer.writeLine(`const ${key} = ${variables[key]};`); + }); + writer.writeLine(`const _withAuth = !!user?.id;`); + writer.writeLine( + `const variables = { ${Object.keys(variables) + .map((v) => v) + .join(', ')} };` + ); + + const denyStmt = + denies.length > 1 + ? 'z3.Not(z3.Or(' + + denies + .map((deny) => { + return transformer.transform(deny); + }) + .join(', ') + + '))' + : denies.length === 1 + ? `z3.Not(${transformer.transform(denies[0])})` + : undefined; + const allowStmt = + allows.length > 1 + ? 'z3.Or(' + + allows + .map((allow) => { + return transformer.transform(allow); + }) + .join(', ') + + ')' + : allows.length === 1 + ? transformer.transform(allows[0]) + : undefined; + let assertion; + if (denyStmt && allowStmt) { + assertion = `z3.And(${denyStmt}, ${allowStmt})`; + } else if (denyStmt) { + assertion = denyStmt; + } else if (allowStmt) { + assertion = allowStmt; + } else { + assertion = `z3.Bool.val(false)`; + } + writer.writeLine(`const assertion = ${assertion};`); + writer.writeLine(`const assertionFromArgs = buildAssertion(z3, variables, args, user);`); + writer.writeLine(`solver.add(z3.And(assertion, assertionFromArgs));`); + writer.write(`return (await solver.check()) === "sat";`); + } catch (err) { + if (err instanceof TypeScriptExpressionTransformerError) { + throw new PluginError(name, err.message); + } else { + throw err; + } + } + }); + + const func = sourceFile.addFunction({ + isAsync: true, + name: `check_${model.name}_${kind}`, + returnType: 'Promise', + parameters: [ + { + name: 'z3', + type: 'any', + }, + { + name: 'args', + type: 'Record', + }, + { + name: 'user?', + type: 'any', + }, + ], + statements, + }); + + return func; + } + generateVariables(expressions: Expression[]): Record { + const result: Record = {}; + expressions.forEach((expr) => { + const variables = this.collectVariablesTypes(expr); + Object.keys(variables).forEach((key) => { + switch (variables[key]) { + case 'NumberLiteral': + result[`_${key}`] = `z3.Int.const("${key}")`; + break; + case 'BooleanLiteral': + result[`_${key}`] = `z3.Bool.const("${key}")`; + break; + default: + break; + } + }); + }); + return result; + } + collectVariablesTypes(expr: Expression): Record { + const result: Record = {}; + const visit = (node: Expression) => { + if (isReferenceExpr(node)) { + const variableName = node.target.ref?.name ?? 'unknown'; + result[variableName] = 'BooleanLiteral'; + } else if (isBinaryExpr(node) && typeof (node.right.$type !== 'StringLiteral')) { + if (isReferenceExpr(node.left)) { + // const variableName = `${lowerCaseFirst( + // node.left.target.ref?.$container.name ?? '' + // )}${upperCaseFirst(node.left.target?.ref?.name ?? '')}`; + const variableName = `${node.left.target?.ref?.name}`; + result[variableName] = node.right.$type; + // visit(node.right); + // } else if (isUnaryExpr(node) && node.operator === '!') { + // visit(node.operand); + } else { + visit(node.left); + visit(node.right); + } + } else if (isMemberAccessExpr(node) || isUnaryExpr(node)) { + visit(node.operand); + } + }; + visit(expr); + return result; + } + private generateInputCheckFunction( sourceFile: SourceFile, model: DataModel, diff --git a/packages/schema/src/plugins/model-meta/index.ts b/packages/schema/src/plugins/model-meta/index.ts deleted file mode 100644 index 8d7454674..000000000 --- a/packages/schema/src/plugins/model-meta/index.ts +++ /dev/null @@ -1,42 +0,0 @@ -import { - createProject, - generateModelMeta, - getDataModels, - PluginError, - PluginFunction, - resolvePath, -} from '@zenstackhq/sdk'; -import path from 'path'; -import { getDefaultOutputFolder } from '../plugin-utils'; - -export const name = 'Model Metadata'; - -const run: PluginFunction = async (model, options, _dmmf, globalOptions) => { - let output = options.output ? (options.output as string) : getDefaultOutputFolder(globalOptions); - if (!output) { - throw new PluginError(options.name, `Unable to determine output path, not running plugin`); - } - - output = resolvePath(output, options); - const outFile = path.join(output, 'model-meta.ts'); - const dataModels = getDataModels(model); - const project = createProject(); - - let shouldCompile = true; - if (typeof options.compile === 'boolean') { - // explicit override - shouldCompile = options.compile; - } else if (globalOptions) { - // from CLI or config file - shouldCompile = globalOptions.compile; - } - - await generateModelMeta(project, dataModels, { - output: outFile, - compile: shouldCompile, - preserveTsFiles: options.preserveTsFiles === true, - generateAttributes: true, - }); -}; - -export default run; diff --git a/packages/schema/src/plugins/plugin-utils.ts b/packages/schema/src/plugins/plugin-utils.ts index e095de898..d59741cc9 100644 --- a/packages/schema/src/plugins/plugin-utils.ts +++ b/packages/schema/src/plugins/plugin-utils.ts @@ -1,10 +1,11 @@ -import { DEFAULT_RUNTIME_LOAD_PATH, type PolicyOperationKind } from '@zenstackhq/runtime'; +import { DEFAULT_RUNTIME_LOAD_PATH, type PolicyOperationKind, type CRUDOperationKind } from '@zenstackhq/runtime'; import { PluginGlobalOptions } from '@zenstackhq/sdk'; import fs from 'fs'; import path from 'path'; import { PluginRunnerOptions } from '../cli/plugin-runner'; export const ALL_OPERATION_KINDS: PolicyOperationKind[] = ['create', 'update', 'postUpdate', 'read', 'delete']; +export const CRUD_OPERATION_KINDS: CRUDOperationKind[] = ['create', 'update', 'read', 'delete']; /** * Gets the nearest "node_modules" folder by walking up from start path. @@ -35,13 +36,9 @@ export function ensureDefaultOutputFolder(options: PluginRunnerOptions) { name: '.zenstack', version: '1.0.0', exports: { - './model-meta': { - types: './model-meta.ts', - default: './model-meta.js', - }, - './policy': { - types: './policy.d.ts', - default: './policy.js', + './enhance': { + types: './enhance.d.ts', + default: './enhance.js', }, './zod': { types: './zod/index.d.ts', @@ -81,7 +78,7 @@ export function getDefaultOutputFolder(globalOptions?: PluginGlobalOptions) { let runtimeModulePath = require.resolve('@zenstackhq/runtime'); if (process.env.ZENSTACK_TEST === '1') { - // handling the case when running as tests, resolve relative to CWD + // handle the case when running as tests, resolve relative to CWD runtimeModulePath = path.resolve(path.join(process.cwd(), 'node_modules', '@zenstackhq', 'runtime')); } @@ -95,3 +92,12 @@ export function getDefaultOutputFolder(globalOptions?: PluginGlobalOptions) { const modulesFolder = getNodeModulesFolder(runtimeModulePath); return modulesFolder ? path.join(modulesFolder, DEFAULT_RUNTIME_LOAD_PATH) : undefined; } + +/** + * Core plugin providers + */ +export enum CorePlugins { + Prisma = '@core/prisma', + Zod = '@core/zod', + Enhancer = '@core/enhancer', +} diff --git a/packages/schema/src/plugins/prisma/index.ts b/packages/schema/src/plugins/prisma/index.ts index 3a96cf40f..b27624cd7 100644 --- a/packages/schema/src/plugins/prisma/index.ts +++ b/packages/schema/src/plugins/prisma/index.ts @@ -1,7 +1,8 @@ import { PluginFunction } from '@zenstackhq/sdk'; -import PrismaSchemaGenerator from './schema-generator'; +import { PrismaSchemaGenerator } from './schema-generator'; export const name = 'Prisma'; +export const description = 'Generating Prisma schema'; const run: PluginFunction = async (model, options, _dmmf, _globalOptions) => { return new PrismaSchemaGenerator().generate(model, options); diff --git a/packages/schema/src/plugins/prisma/prisma-builder.ts b/packages/schema/src/plugins/prisma/prisma-builder.ts index 64777b62e..594913f8c 100644 --- a/packages/schema/src/plugins/prisma/prisma-builder.ts +++ b/packages/schema/src/plugins/prisma/prisma-builder.ts @@ -110,10 +110,15 @@ export class Model extends ContainerDeclaration { name: string, type: ModelFieldType | string, attributes: (FieldAttribute | PassThroughAttribute)[] = [], - documentations: string[] = [] + documentations: string[] = [], + addToFront = false ): ModelField { const field = new ModelField(name, type, attributes, documentations); - this.fields.push(field); + if (addToFront) { + this.fields.unshift(field); + } else { + this.fields.push(field); + } return field; } @@ -288,7 +293,7 @@ export class FieldReference { } export class FieldReferenceArg { - constructor(public name: 'sort', public value: 'Asc' | 'Desc') {} + constructor(public name: string, public value: string) {} toString(): string { return `${this.name}: ${this.value}`; @@ -304,10 +309,10 @@ export class FunctionCall { } export class FunctionCallArg { - constructor(public name: string | undefined, public value: string) {} + constructor(public value: string) {} toString(): string { - return this.name ? `${this.name}: ${this.value}` : this.value; + return this.value; } } diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 98dfa717e..72d2a02e6 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -16,6 +16,7 @@ import { GeneratorDecl, InvocationExpr, isArrayExpr, + isDataModel, isInvocationExpr, isLiteralExpr, isNullExpr, @@ -27,12 +28,17 @@ import { StringLiteral, } from '@zenstackhq/language/ast'; import { match } from 'ts-pattern'; +import { getIdFields } from '../../utils/ast-utils'; -import { PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime'; +import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime'; import { + getAttribute, getDMMF, getLiteral, getPrismaVersion, + isAuthInvocation, + isDelegateModel, + isIdField, PluginError, PluginOptions, resolved, @@ -41,15 +47,19 @@ import { } from '@zenstackhq/sdk'; import fs from 'fs'; import { writeFile } from 'fs/promises'; +import { streamAst } from 'langium'; +import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; import semver from 'semver'; import stripColor from 'strip-color'; +import { upperCaseFirst } from 'upper-case-first'; import { name } from '.'; import { getStringLiteral } from '../../language-server/validator/utils'; import telemetry from '../../telemetry'; import { execSync } from '../../utils/exec-utils'; import { findPackageJson } from '../../utils/pkg-utils'; import { + AttributeArgValue, ModelFieldType, AttributeArg as PrismaAttributeArg, AttributeArgValue as PrismaAttributeArgValue, @@ -73,7 +83,7 @@ const FIELD_PASSTHROUGH_ATTR = '@prisma.passthrough'; /** * Generates Prisma schema file */ -export default class PrismaSchemaGenerator { +export class PrismaSchemaGenerator { private zModelGenerator: ZModelCodeGenerator = new ZModelCodeGenerator(); private readonly PRELUDE = `////////////////////////////////////////////////////////////////////////////////////////////// @@ -83,8 +93,13 @@ export default class PrismaSchemaGenerator { `; + private mode: 'logical' | 'physical' = 'physical'; + async generate(model: Model, options: PluginOptions) { const warnings: string[] = []; + if (options.mode) { + this.mode = options.mode as 'logical' | 'physical'; + } const prismaVersion = getPrismaVersion(); if (prismaVersion && semver.lt(prismaVersion, PRISMA_MINIMUM_VERSION)) { @@ -110,7 +125,7 @@ export default class PrismaSchemaGenerator { break; case GeneratorDecl: - this.generateGenerator(prisma, decl as GeneratorDecl); + this.generateGenerator(prisma, decl as GeneratorDecl, options); break; } } @@ -127,7 +142,7 @@ export default class PrismaSchemaGenerator { if (options.format === true) { try { // run 'prisma format' - await execSync(`npx prisma format --schema ${outFile}`); + await execSync(`npx prisma format --schema ${outFile}`, { stdio: 'ignore' }); } catch { warnings.push(`Failed to format Prisma schema file`); } @@ -142,7 +157,7 @@ export default class PrismaSchemaGenerator { } try { // run 'prisma generate' - await execSync(generateCmd, 'ignore'); + await execSync(generateCmd, { stdio: 'ignore' }); } catch { await this.trackPrismaSchemaError(outFile); try { @@ -217,7 +232,11 @@ export default class PrismaSchemaGenerator { return JSON.stringify(expr.value); } - private generateGenerator(prisma: PrismaModel, decl: GeneratorDecl) { + private exprToText(expr: Expression) { + return new ZModelCodeGenerator({ quote: 'double' }).generate(expr); + } + + private generateGenerator(prisma: PrismaModel, decl: GeneratorDecl, options: PluginOptions) { const generator = prisma.addGenerator( decl.name, decl.fields.map((f) => ({ name: f.name, text: this.configExprToText(f.value) })) @@ -259,13 +278,38 @@ export default class PrismaSchemaGenerator { } } } + + if (typeof options.overrideClientGenerationPath === 'string') { + const output = generator.fields.find((f) => f.name === 'output'); + if (output) { + output.text = JSON.stringify(options.overrideClientGenerationPath); + } else { + generator.fields.push({ + name: 'output', + text: JSON.stringify(options.overrideClientGenerationPath), + }); + } + } } } private generateModel(prisma: PrismaModel, decl: DataModel) { const model = decl.isView ? prisma.addView(decl.name) : prisma.addModel(decl.name); for (const field of decl.fields) { - this.generateModelField(model, field); + if (field.$inheritedFrom) { + if ( + // abstract inheritance is always kept + field.$inheritedFrom.isAbstract || + // logical schema keeps all inherited fields + this.mode === 'logical' || + // id fields are always kept + isIdField(field) + ) { + this.generateModelField(model, field); + } + } else { + this.generateModelField(model, field); + } } for (const attr of decl.attributes.filter((attr) => this.isPrismaAttribute(attr))) { @@ -278,6 +322,148 @@ export default class PrismaSchemaGenerator { // user defined comments pass-through decl.comments.forEach((c) => model.addComment(c)); + + // generate relation fields on base models linking to concrete models + this.generateDelegateRelationForBase(model, decl); + + // generate reverse relation fields on concrete models + this.generateDelegateRelationForConcrete(model, decl); + + // expand relations on other models that reference delegated models to concrete models + this.expandPolymorphicRelations(model, decl); + } + + private generateDelegateRelationForBase(model: PrismaDataModel, decl: DataModel) { + if (this.mode !== 'physical') { + return; + } + + if (!isDelegateModel(decl)) { + return; + } + + // collect concrete models inheriting this model + const concreteModels = decl.$container.declarations.filter( + (d) => isDataModel(d) && d !== decl && d.superTypes.some((base) => base.ref === decl) + ); + + // generate an optional relation field in delegate base model to each concrete model + concreteModels.forEach((concrete) => { + const auxName = `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(concrete.name)}`; + model.addField(auxName, new ModelFieldType(concrete.name, false, true)); + }); + } + + private generateDelegateRelationForConcrete(model: PrismaDataModel, concreteDecl: DataModel) { + if (this.mode !== 'physical') { + return; + } + + // generate a relation field for each delegated base model + + const baseModels = concreteDecl.superTypes + .map((t) => t.ref) + .filter((t): t is DataModel => !!t) + .filter((t) => isDelegateModel(t)); + + baseModels.forEach((base) => { + const idFields = getIdFields(base); + + // add relation fields + const relationField = `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(base.name)}`; + model.addField(relationField, base.name, [ + new PrismaFieldAttribute('@relation', [ + new PrismaAttributeArg( + 'fields', + new AttributeArgValue( + 'Array', + idFields.map( + (idField) => + new AttributeArgValue('FieldReference', new PrismaFieldReference(idField.name)) + ) + ) + ), + new PrismaAttributeArg( + 'references', + new AttributeArgValue( + 'Array', + idFields.map( + (idField) => + new AttributeArgValue('FieldReference', new PrismaFieldReference(idField.name)) + ) + ) + ), + new PrismaAttributeArg( + 'onDelete', + new AttributeArgValue('FieldReference', new PrismaFieldReference('Cascade')) + ), + new PrismaAttributeArg( + 'onUpdate', + new AttributeArgValue('FieldReference', new PrismaFieldReference('Cascade')) + ), + ]), + ]); + }); + } + + private expandPolymorphicRelations(model: PrismaDataModel, decl: DataModel) { + if (this.mode !== 'logical') { + return; + } + + // the logical schema needs to expand relations to the delegate models to concrete ones + + // for the given model, find all concrete models that have relation to it, + // and generate an auxiliary opposite relation field + decl.fields.forEach((f) => { + const fieldType = f.type.reference?.ref; + if (!isDataModel(fieldType)) { + return; + } + + // find concrete models that inherit from this field's model type + const concreteModels = decl.$container.declarations.filter( + (d) => isDataModel(d) && isDescendantOf(d, fieldType) + ); + + concreteModels.forEach((concrete) => { + const relationField = model.addField( + `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(concrete.name)}`, + new ModelFieldType(concrete.name, f.type.array, f.type.optional) + ); + const relAttr = getAttribute(f, '@relation'); + if (relAttr) { + const fieldsArg = relAttr.args.find((arg) => arg.name === 'fields'); + if (fieldsArg) { + const idFields = getIdFields(fieldType); + idFields.forEach((idField) => { + model.addField( + `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(concrete.name)}${upperCaseFirst( + idField.name + )}`, + idField.type.type! + ); + }); + + const args = new AttributeArgValue( + 'Array', + idFields.map( + (idField) => + new AttributeArgValue('FieldReference', new PrismaFieldReference(idField.name)) + ) + ); + relationField.attributes.push( + new PrismaFieldAttribute('@relation', [ + new PrismaAttributeArg('fields', args), + new PrismaAttributeArg('references', args), + ]) + ); + } else { + relationField.attributes.push(this.makeFieldAttribute(relAttr as DataModelFieldAttribute)); + } + } + }); + }); } private isPrismaAttribute(attr: DataModelAttribute | DataModelFieldAttribute) { @@ -306,7 +492,7 @@ export default class PrismaSchemaGenerator { } } - private generateModelField(model: PrismaDataModel, field: DataModelField) { + private generateModelField(model: PrismaDataModel, field: DataModelField, addToFront = false) { const fieldType = field.type.type || field.type.reference?.ref?.name || this.getUnsupportedFieldType(field.type); if (!fieldType) { @@ -317,18 +503,48 @@ export default class PrismaSchemaGenerator { const attributes = field.attributes .filter((attr) => this.isPrismaAttribute(attr)) + // `@default` with `auth()` is handled outside Prisma + .filter((attr) => !this.isDefaultWithAuth(attr)) + .filter( + (attr) => + // when building physical schema, exclude `@default` for id fields inherited from delegate base + !( + this.mode === 'physical' && + isIdField(field) && + this.isInheritedFromDelegate(field) && + attr.decl.$refText === '@default' + ) + ) .map((attr) => this.makeFieldAttribute(attr)); const nonPrismaAttributes = field.attributes.filter((attr) => attr.decl.ref && !this.isPrismaAttribute(attr)); const documentations = nonPrismaAttributes.map((attr) => '/// ' + this.zModelGenerator.generate(attr)); - const result = model.addField(field.name, type, attributes, documentations); + const result = model.addField(field.name, type, attributes, documentations, addToFront); // user defined comments pass-through field.comments.forEach((c) => result.addComment(c)); } + private isInheritedFromDelegate(field: DataModelField) { + return field.$inheritedFrom && isDelegateModel(field.$inheritedFrom); + } + + private isDefaultWithAuth(attr: DataModelFieldAttribute) { + if (attr.decl.ref?.name !== '@default') { + return false; + } + + const expr = attr.args[0]?.value; + if (!expr) { + return false; + } + + // find `auth()` in default value expression + return streamAst(expr).some(isAuthInvocation); + } + private makeFieldAttribute(attr: DataModelFieldAttribute) { const attrName = resolved(attr.decl).name; if (attrName === FIELD_PASSTHROUGH_ATTR) { @@ -368,7 +584,7 @@ export default class PrismaSchemaGenerator { 'FieldReference', new PrismaFieldReference( resolved(node.target).name, - node.args.map((arg) => new PrismaFieldReferenceArg(arg.name, arg.value)) + node.args.map((arg) => new PrismaFieldReferenceArg(arg.name, this.exprToText(arg.value))) ) ); } else if (isInvocationExpr(node)) { @@ -391,7 +607,7 @@ export default class PrismaSchemaGenerator { throw new PluginError(name, 'Function call argument must be literal or null'); }); - return new PrismaFunctionCallArg(arg.name, val); + return new PrismaFunctionCallArg(val); }) ); } @@ -444,6 +660,10 @@ export default class PrismaSchemaGenerator { } } +function isDescendantOf(model: DataModel, superModel: DataModel): boolean { + return model.superTypes.some((s) => s.ref === superModel || isDescendantOf(s.ref!, superModel)); +} + export function getDefaultPrismaOutputFile(schemaPath: string) { // handle override from package.json const pkgJsonPath = findPackageJson(path.dirname(schemaPath)); diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index 2727a781f..a09c4ad73 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -26,7 +26,7 @@ import { name } from '.'; import { getDefaultOutputFolder } from '../plugin-utils'; import Transformer from './transformer'; import removeDir from './utils/removeDir'; -import { makeFieldSchema, makeValidationRefinements, getFieldSchemaDefault } from './utils/schema-gen'; +import { getFieldSchemaDefault, makeFieldSchema, makeValidationRefinements } from './utils/schema-gen'; export async function generate( model: Model, @@ -395,7 +395,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s //////////////////////////////////////////////// // schema for validating prisma create input (all fields optional) - let prismaCreateSchema = makePartial('baseSchema'); + let prismaCreateSchema = makePassthrough(makePartial('baseSchema')); if (refineFuncName) { prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`; } @@ -501,3 +501,7 @@ function makeOmit(schema: string, fields: string[]) { function makeMerge(schema1: string, schema2: string): string { return `${schema1}.merge(${schema2})`; } + +function makePassthrough(schema: string) { + return `${schema}.passthrough()`; +} diff --git a/packages/schema/src/plugins/zod/index.ts b/packages/schema/src/plugins/zod/index.ts index b2b43cb40..53a30b4e3 100644 --- a/packages/schema/src/plugins/zod/index.ts +++ b/packages/schema/src/plugins/zod/index.ts @@ -3,6 +3,7 @@ import invariant from 'tiny-invariant'; import { generate } from './generator'; export const name = 'Zod'; +export const description = 'Generating Zod schemas'; const run: PluginFunction = async (model, options, dmmf, globalOptions) => { invariant(dmmf); diff --git a/packages/schema/src/plugins/zod/utils/schema-gen.ts b/packages/schema/src/plugins/zod/utils/schema-gen.ts index 39e7d2bb2..74d3c18b7 100644 --- a/packages/schema/src/plugins/zod/utils/schema-gen.ts +++ b/packages/schema/src/plugins/zod/utils/schema-gen.ts @@ -1,6 +1,8 @@ import { ExpressionContext, PluginError, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, getAttributeArg, getAttributeArgLiteral, getLiteral, @@ -18,10 +20,6 @@ import { } from '@zenstackhq/sdk/ast'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '..'; -import { - TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, -} from '../../../utils/typescript-expression-transformer'; export function makeFieldSchema(field: DataModelField, respectDefault = false) { if (isDataModel(field.type.reference?.ref)) { diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index be241fe2c..fd470efb8 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -61,6 +61,9 @@ enum ExpressionContext { // used in @@validate ValidationRule + + // used in @@index + Index } /** @@ -73,7 +76,7 @@ function env(name: String): String { * Gets the current login user. */ function auth(): Any { -} @@@expressionContext([AccessPolicy]) +} @@@expressionContext([DefaultValue, AccessPolicy]) /** * Gets current date-time (as DateTime type). @@ -200,11 +203,11 @@ attribute @@@completionHint(_ values: String[]) * @param sort: Allows you to specify in what order the entries of the ID are stored in the database. The available options are Asc and Desc. * @param clustered: Defines whether the ID is clustered or non-clustered. Defaults to true. */ -attribute @id(map: String?, length: Int?, sort: String?, clustered: Boolean?) @@@prisma +attribute @id(map: String?, length: Int?, sort: SortOrder?, clustered: Boolean?) @@@prisma /** * Defines a default value for a field. - * @param value: An expression (e.g. 5, true, now()). + * @param value: An expression (e.g. 5, true, now(), auth()). */ attribute @default(_ value: ContextType, map: String?) @@@prisma @@ -215,7 +218,7 @@ attribute @default(_ value: ContextType, map: String?) @@@prisma * @param sort: Allows you to specify in what order the entries of the constraint are stored in the database. The available options are Asc and Desc. * @param clustered: Boolean Defines whether the constraint is clustered or non-clustered. Defaults to false. */ -attribute @unique(map: String?, length: Int?, sort: String?, clustered: Boolean?) @@@prisma +attribute @unique(map: String?, length: Int?, sort: SortOrder?, clustered: Boolean?) @@@prisma /** * Defines a multi-field ID (composite ID) on the model. @@ -227,7 +230,7 @@ attribute @unique(map: String?, length: Int?, sort: String?, clustered: Boolean? * @param sort: Allows you to specify in what order the entries of the ID are stored in the database. The available options are Asc and Desc. * @param clustered: Defines whether the ID is clustered or non-clustered. Defaults to true. */ -attribute @@id(_ fields: FieldReference[], name: String?, map: String?, length: Int?, sort: String?, clustered: Boolean?) @@@prisma +attribute @@id(_ fields: FieldReference[], name: String?, map: String?, length: Int?, sort: SortOrder?, clustered: Boolean?) @@@prisma /** * Defines a compound unique constraint for the specified fields. @@ -238,7 +241,7 @@ attribute @@id(_ fields: FieldReference[], name: String?, map: String?, length: * @param sort: Allows you to specify in what order the entries of the constraint are stored in the database. The available options are Asc and Desc. * @param clustered: Boolean Defines whether the constraint is clustered or non-clustered. Defaults to false. */ -attribute @@unique(_ fields: FieldReference[], name: String?, map: String?, length: Int?, sort: String?, clustered: Boolean?) @@@prisma +attribute @@unique(_ fields: FieldReference[], name: String?, map: String?, length: Int?, sort: SortOrder?, clustered: Boolean?) @@@prisma /** * Index types @@ -252,6 +255,84 @@ enum IndexType { Brin } +/** + * Operator class for index + */ +enum IndexOperatorClass { + // GIN + ArrayOps + JsonbOps + JsonbPathOps + + // Gist + InetOps + + // SpGist + TextOps + + // BRIN + BitMinMaxOps + VarBitMinMaxOps + BpcharBloomOps + BpcharMinMaxOps + ByteaBloomOps + ByteaMinMaxOps + DateBloomOps + DateMinMaxOps + DateMinMaxMultiOps + Float4BloomOps + Float4MinMaxOps + Float4MinMaxMultiOps + Float8BloomOps + Float8MinMaxOps + Float8MinMaxMultiOps + InetInclusionOps + InetBloomOps + InetMinMaxOps + InetMinMaxMultiOps + Int2BloomOps + Int2MinMaxOps + Int2MinMaxMultiOps + Int4BloomOps + Int4MinMaxOps + Int4MinMaxMultiOps + Int8BloomOps + Int8MinMaxOps + Int8MinMaxMultiOps + NumericBloomOps + NumericMinMaxOps + NumericMinMaxMultiOps + OidBloomOps + OidMinMaxOps + OidMinMaxMultiOps + TextBloomOps + TextMinMaxOps + TextMinMaxMultiOps + TimestampBloomOps + TimestampMinMaxOps + TimestampMinMaxMultiOps + TimestampTzBloomOps + TimestampTzMinMaxOps + TimestampTzMinMaxMultiOps + TimeBloomOps + TimeMinMaxOps + TimeMinMaxMultiOps + TimeTzBloomOps + TimeTzMinMaxOps + TimeTzMinMaxMultiOps + UuidBloomOps + UuidMinMaxOps + UuidMinMaxMultiOps +} + +/** + * Index sort order + */ +enum SortOrder { + Asc + Desc +} + /** * Defines an index in the database. * @@ -263,7 +344,7 @@ enum IndexType { * @params clustered: Defines whether the index is clustered or non-clustered. Defaults to false. * @params type: Allows you to specify an index access method. Defaults to BTree. */ -attribute @@index(_ fields: FieldReference[], name: String?, map: String?, length: Int?, sort: String?, clustered: Boolean?, type: IndexType?) @@@prisma +attribute @@index(_ fields: FieldReference[], name: String?, map: String?, length: Int?, sort: SortOrder?, clustered: Boolean?, type: IndexType?) @@@prisma /** * Defines meta information about the relation. @@ -598,3 +679,14 @@ attribute @prisma.passthrough(_ text: String) * A utility attribute to allow passthrough of arbitrary attribute text to the generated Prisma schema. */ attribute @@prisma.passthrough(_ text: String) + +/** + * Marks a model to be a delegate. Used for implementing polymorphism. + */ +attribute @@delegate(_ discriminator: FieldReference) + +/** + * Used for specifying operator classes for GIN index. + */ +function raw(value: String): Any { +} @@@expressionContext([Index]) diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 661f14b26..2688987a2 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -3,6 +3,7 @@ import { DataModel, DataModelField, Expression, + InheritableNode, isArrayExpr, isBinaryExpr, isDataModel, @@ -16,7 +17,17 @@ import { ReferenceExpr, } from '@zenstackhq/language/ast'; import { isFromStdlib } from '@zenstackhq/sdk'; -import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium'; +import { + AstNode, + copyAstNode, + CstNode, + getContainerOfType, + getDocument, + LangiumDocuments, + Linker, + Mutable, + Reference, +} from 'langium'; import { URI, Utils } from 'vscode-uri'; export function extractDataModelsWithAllowRules(model: Model): DataModel[] { @@ -25,39 +36,63 @@ export function extractDataModelsWithAllowRules(model: Model): DataModel[] { ) as DataModel[]; } -export function mergeBaseModel(model: Model) { - model.declarations - .filter((x) => x.$type === 'DataModel') - .forEach((decl) => { - const dataModel = decl as DataModel; +type BuildReference = ( + node: AstNode, + property: string, + refNode: CstNode | undefined, + refText: string +) => Reference; + +export function mergeBaseModel(model: Model, linker: Linker) { + const buildReference = linker.buildReference.bind(linker); + + model.declarations.filter(isDataModel).forEach((decl) => { + const dataModel = decl as DataModel; - dataModel.fields = dataModel.superTypes + const bases = getRecursiveBases(dataModel).reverse(); + if (bases.length > 0) { + dataModel.fields = bases // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - .flatMap((superType) => updateContainer(superType.ref!.fields, dataModel)) + .flatMap((base) => base.fields) + // don't inherit skip-level fields + .filter((f) => !f.$inheritedFrom) + .map((f) => cloneAst(f, dataModel, buildReference)) .concat(dataModel.fields); - dataModel.attributes = dataModel.superTypes + dataModel.attributes = bases // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - .flatMap((superType) => updateContainer(superType.ref!.attributes, dataModel)) + .flatMap((base) => base.attributes) + // don't inherit skip-level attributes + .filter((attr) => !attr.$inheritedFrom) + // don't inherit `@@delegate` attribute + .filter((attr) => attr.decl.$refText !== '@@delegate') + .map((attr) => cloneAst(attr, dataModel, buildReference)) .concat(dataModel.attributes); - }); + } + + dataModel.$baseMerged = true; + }); // remove abstract models - model.declarations = model.declarations.filter((x) => !(x.$type == 'DataModel' && x.isAbstract)); + model.declarations = model.declarations.filter((x) => !(isDataModel(x) && x.isAbstract)); } -function updateContainer(nodes: T[], container: AstNode): Mutable[] { - return nodes.map((node) => { - const cloneField = Object.assign({}, node); - const mutable = cloneField as Mutable; - // update container - mutable.$container = container; - return mutable; - }); +// deep clone an AST, relink references, and set its container +function cloneAst( + node: T, + newContainer: AstNode, + buildReference: BuildReference +): Mutable { + const clone = copyAstNode(node, buildReference) as Mutable; + clone.$container = newContainer; + clone.$containerProperty = node.$containerProperty; + clone.$containerIndex = node.$containerIndex; + clone.$inheritedFrom = node.$inheritedFrom ?? getContainerOfType(node, isDataModel); + return clone; } export function getIdFields(dataModel: DataModel) { - const fieldLevelId = dataModel.$resolvedFields.find((f) => + const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id') ); if (fieldLevelId) { @@ -67,7 +102,7 @@ export function getIdFields(dataModel: DataModel) { const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id'); if (modelIdAttr) { // get fields referenced in the attribute: @@id([field1, field2]]) - if (!isArrayExpr(modelIdAttr.args[0].value)) { + if (!isArrayExpr(modelIdAttr.args[0]?.value)) { return []; } const argValue = modelIdAttr.args[0].value; @@ -83,6 +118,10 @@ export function isAuthInvocation(node: AstNode) { return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref); } +export function isFutureInvocation(node: AstNode) { + return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref); +} + export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { return expr.target.ref; @@ -157,7 +196,6 @@ export function isCollectionPredicate(node: AstNode): node is BinaryExpr { return isBinaryExpr(node) && ['?', '!', '^'].includes(node.operator); } - export function getContainingDataModel(node: Expression): DataModel | undefined { let curr: AstNode | undefined = node.$container; while (curr) { @@ -167,4 +205,24 @@ export function getContainingDataModel(node: Expression): DataModel | undefined curr = curr.$container; } return undefined; -} \ No newline at end of file +} + +export function getModelFieldsWithBases(model: DataModel) { + if (model.$baseMerged) { + return model.fields; + } else { + return [...model.fields, ...getRecursiveBases(model).flatMap((base) => base.fields)]; + } +} + +export function getRecursiveBases(dataModel: DataModel): DataModel[] { + const result: DataModel[] = []; + dataModel.superTypes.forEach((superType) => { + const baseDecl = superType.ref; + if (baseDecl) { + result.push(baseDecl); + result.push(...getRecursiveBases(baseDecl)); + } + }); + return result; +} diff --git a/packages/schema/src/utils/exec-utils.ts b/packages/schema/src/utils/exec-utils.ts index f355ae2b4..d88e42b3d 100644 --- a/packages/schema/src/utils/exec-utils.ts +++ b/packages/schema/src/utils/exec-utils.ts @@ -1,9 +1,10 @@ -import { execSync as _exec, StdioOptions } from 'child_process'; +import { execSync as _exec, ExecSyncOptions } from 'child_process'; /** * Utility for executing command synchronously and prints outputs on current console */ -export function execSync(cmd: string, stdio: StdioOptions = 'inherit', env?: Record): void { - const mergedEnv = { ...process.env, ...env }; - _exec(cmd, { encoding: 'utf-8', stdio, env: mergedEnv }); +export function execSync(cmd: string, options?: Omit & { env?: Record }): void { + const { env, ...restOptions } = options ?? {}; + const mergedEnv = env ? { ...process.env, ...env } : undefined; + _exec(cmd, { encoding: 'utf-8', stdio: options?.stdio ?? 'inherit', env: mergedEnv, ...restOptions }); } diff --git a/packages/schema/tests/generator/expression-writer.test.ts b/packages/schema/tests/generator/expression-writer.test.ts index f9baa0de9..a4cc6ae5f 100644 --- a/packages/schema/tests/generator/expression-writer.test.ts +++ b/packages/schema/tests/generator/expression-writer.test.ts @@ -3,7 +3,7 @@ import { DataModel, Enum, Expression, isDataModel, isEnum } from '@zenstackhq/language/ast'; import * as tmp from 'tmp'; import { Project, VariableDeclarationKind } from 'ts-morph'; -import { ExpressionWriter } from '../../src/plugins/access-policy/expression-writer'; +import { ExpressionWriter } from '../../src/plugins/enhancer/policy/expression-writer'; import { loadModel } from '../utils'; describe('Expression Writer Tests', () => { diff --git a/packages/schema/tests/generator/prisma-builder.test.ts b/packages/schema/tests/generator/prisma-builder.test.ts index 48e465362..a3944401c 100644 --- a/packages/schema/tests/generator/prisma-builder.test.ts +++ b/packages/schema/tests/generator/prisma-builder.test.ts @@ -102,7 +102,7 @@ describe('Prisma Builder Tests', () => { undefined, new AttributeArgValue( 'FunctionCall', - new FunctionCall('dbgenerated', [new FunctionCallArg(undefined, '"timestamp_id()"')]) + new FunctionCall('dbgenerated', [new FunctionCallArg('"timestamp_id()"')]) ) ), ]), diff --git a/packages/schema/tests/generator/prisma-generator.test.ts b/packages/schema/tests/generator/prisma-generator.test.ts index 30a477026..67ba27f99 100644 --- a/packages/schema/tests/generator/prisma-generator.test.ts +++ b/packages/schema/tests/generator/prisma-generator.test.ts @@ -5,10 +5,27 @@ import fs from 'fs'; import path from 'path'; import tmp from 'tmp'; import { loadDocument } from '../../src/cli/cli-util'; -import PrismaSchemaGenerator from '../../src/plugins/prisma/schema-generator'; +import { PrismaSchemaGenerator } from '../../src/plugins/prisma/schema-generator'; +import { execSync } from '../../src/utils/exec-utils'; import { loadModel } from '../utils'; describe('Prisma generator test', () => { + let origDir: string; + + beforeEach(() => { + origDir = process.cwd(); + const r = tmp.dirSync({ unsafeCleanup: true }); + console.log(`Project dir: ${r.name}`); + process.chdir(r.name); + + execSync('npm init -y', { stdio: 'ignore' }); + execSync('npm install prisma'); + }); + + afterEach(() => { + process.chdir(origDir); + }); + it('datasource coverage', async () => { const model = await loadModel(` datasource db { @@ -32,15 +49,14 @@ describe('Prisma generator test', () => { } `); - const { name } = tmp.fileSync({ postfix: '.prisma' }); await new PrismaSchemaGenerator().generate(model, { name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', - output: name, + output: 'schema.prisma', }); - const content = fs.readFileSync(name, 'utf-8'); + const content = fs.readFileSync('schema.prisma', 'utf-8'); expect(content).toContain('provider = "postgresql"'); expect(content).toContain('url = env("DATABASE_URL")'); expect(content).toContain('directUrl = env("DATABASE_URL")'); @@ -107,6 +123,7 @@ describe('Prisma generator test', () => { id String @id @default(nanoid(6)) x String @default(nanoid()) y String @default(dbgenerated("gen_random_uuid()")) + z String @default(auth().id) } `); @@ -126,6 +143,7 @@ describe('Prisma generator test', () => { expect(content).toContain('@default(nanoid(6))'); expect(content).toContain('@default(nanoid())'); expect(content).toContain('@default(dbgenerated("gen_random_uuid()"))'); + expect(content).not.toContain('@default(auth().id)'); }); it('triple slash comments', async () => { @@ -346,6 +364,7 @@ describe('Prisma generator test', () => { output: name, generateClient: false, }); + console.log('Generated:', name); const content = fs.readFileSync(name, 'utf-8'); const dmmf = await getDMMF({ datamodel: content }); @@ -354,9 +373,7 @@ describe('Prisma generator test', () => { const post = dmmf.datamodel.models[0]; expect(post.name).toBe('Post'); expect(post.fields.length).toBe(5); - expect(post.fields[0].name).toBe('id'); - expect(post.fields[3].name).toBe('title'); - expect(post.fields[4].name).toBe('published'); + expect(post.fields.map((f) => f.name)).toEqual(expect.arrayContaining(['id', 'title', 'published'])); }); it('abstract multi files', async () => { diff --git a/packages/schema/tests/schema/all-features.zmodel b/packages/schema/tests/schema/all-features.zmodel index c47a7cf79..b567093fe 100644 --- a/packages/schema/tests/schema/all-features.zmodel +++ b/packages/schema/tests/schema/all-features.zmodel @@ -40,7 +40,7 @@ model Space extends Base { createdAt DateTime @default(now()) updatedAt DateTime @updatedAt name String @length(4, 50) - slug String @unique @length(4, 16) + slug String @length(4, 16) owner User? @relation(fields: [ownerId], references: [id]) ownerId String? members SpaceUser[] @@ -58,6 +58,8 @@ model Space extends Base { // space admin can update and delete @@allow('update,delete', members?[user == auth() && role == ADMIN]) + + @@index([slug(ops: raw("gin_trgm_ops"))], type: Gin) } /* diff --git a/packages/schema/tests/schema/cal-com.zmodel b/packages/schema/tests/schema/cal-com.zmodel index c6e874304..a32bd45a6 100644 --- a/packages/schema/tests/schema/cal-com.zmodel +++ b/packages/schema/tests/schema/cal-com.zmodel @@ -11,13 +11,8 @@ generator client { previewFeatures = [] } -plugin meta { - provider = '@core/model-meta' - output = '.zenstack' -} - -plugin policy { - provider = '@core/access-policy' +plugin enhancer { + provider = '@core/enhancer' output = '.zenstack' } diff --git a/packages/schema/tests/schema/parser.test.ts b/packages/schema/tests/schema/parser.test.ts index 9b4150cd5..25ada5ceb 100644 --- a/packages/schema/tests/schema/parser.test.ts +++ b/packages/schema/tests/schema/parser.test.ts @@ -224,7 +224,6 @@ describe('Parsing Tests', () => { expect(((model.attributes[1].args[0].value as ArrayExpr).items[0] as ReferenceExpr).args[0]).toEqual( expect.objectContaining({ name: 'sort', - value: 'Asc', }) ); @@ -232,7 +231,6 @@ describe('Parsing Tests', () => { expect((model.attributes[2].args[0].value as ReferenceExpr).args[0]).toEqual( expect.objectContaining({ name: 'sort', - value: 'Desc', }) ); }); diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 8b7886334..ac87665b1 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -161,11 +161,11 @@ describe('Attribute tests', () => { model A { x Int y String - @@id([x, y], name: 'x_y', map: '_x_y', length: 10, sort: 'Asc', clustered: true) + @@id([x, y], name: 'x_y', map: '_x_y', length: 10, sort: Asc, clustered: true) } model B { - id String @id(map: '_id', length: 10, sort: 'Asc', clustered: true) + id String @id(map: '_id', length: 10, sort: Asc, clustered: true) } `); @@ -175,7 +175,7 @@ describe('Attribute tests', () => { id String @id x Int y String - @@unique([x, y], name: 'x_y', map: '_x_y', length: 10, sort: 'Asc', clustered: true) + @@unique([x, y], name: 'x_y', map: '_x_y', length: 10, sort: Asc, clustered: true) } `); @@ -193,7 +193,7 @@ describe('Attribute tests', () => { ${prelude} model A { id String @id - x Int @unique(map: '_x', length: 10, sort: 'Asc', clustered: true) + x Int @unique(map: '_x', length: 10, sort: Asc, clustered: true) } `); @@ -222,7 +222,7 @@ describe('Attribute tests', () => { id String @id x Int y String - @@index([x(sort: Asc), y(sort: Desc)], name: 'myindex', map: '_myindex', length: 10, sort: 'asc', clustered: true, type: BTree) + @@index([x(sort: Asc), y(sort: Desc)], name: 'myindex', map: '_myindex', length: 10, sort: Asc, clustered: true, type: BTree) } `); @@ -251,6 +251,7 @@ describe('Attribute tests', () => { ${prelude} model _String { + id String @id _string String @db.String _string1 String @db.String(1) _text String @db.Text @@ -275,6 +276,7 @@ describe('Attribute tests', () => { } model _Boolean { + id String @id _boolean Boolean @db.Boolean _bit Boolean @db.Bit _bit1 Boolean @db.Bit(1) @@ -283,6 +285,7 @@ describe('Attribute tests', () => { } model _Int { + id String @id _int Int @db.Int _integer Int @db.Integer _smallInt Int @db.SmallInt @@ -298,12 +301,14 @@ describe('Attribute tests', () => { } model _BigInt { + id String @id _bigInt BigInt @db.BigInt _unsignedBigInt BigInt @db.UnsignedBigInt _int8 BigInt @db.Int8 } model _FloatDecimal { + id String @id _float Float @db.Float _decimal Decimal @db.Decimal _decimal1 Decimal @db.Decimal(10, 2) @@ -318,6 +323,7 @@ describe('Attribute tests', () => { } model _DateTime { + id String @id _dateTime DateTime @db.DateTime _dateTime2 DateTime @db.DateTime2 _smallDateTime DateTime @db.SmallDateTime @@ -334,11 +340,13 @@ describe('Attribute tests', () => { } model _Json { + id String @id _json Json @db.Json _jsonb Json @db.JsonB } model _Bytes { + id String @id _bytes Bytes @db.Bytes _byteA Bytes @db.ByteA _longBlob Bytes @db.LongBlob @@ -1009,6 +1017,35 @@ describe('Attribute tests', () => { }); it('auth function check', async () => { + await loadModel(` + ${prelude} + + model User { + id String @id + name String + } + model B { + id String @id + userId String @default(auth().id) + userName String @default(auth().name) + } + `); + + // expect( + // await loadModelWithError(` + // ${prelude} + + // model User { + // id String @id + // name String + // } + // model B { + // id String @id + // userData String @default(auth()) + // } + // `) + // ).toContain("Value is not assignable to parameter"); + expect( await loadModelWithError(` ${prelude} @@ -1059,11 +1096,14 @@ describe('Attribute tests', () => { model A { id String @id x Int + b B @relation(references: [id], fields: [bId]) + bId String @unique } model B { id String @id - a A + a A? + aId String @unique @@allow('all', a?[x > 0]) } `) @@ -1118,20 +1158,21 @@ describe('Attribute tests', () => { } model M { + id String @id e E @default(E1) } `); }); it('incorrect function expression context', async () => { - expect( - await loadModelWithError(` - ${prelude} - model M { - id String @id @default(auth()) - } - `) - ).toContain('function "auth" is not allowed in the current context: DefaultValue'); + // expect( + // await loadModelWithError(` + // ${prelude} + // model M { + // id String @id @default(auth()) + // } + // `) + // ).toContain('function "auth" is not allowed in the current context: DefaultValue'); expect( await loadModelWithError(` diff --git a/packages/schema/tests/schema/validation/datamodel-validation.test.ts b/packages/schema/tests/schema/validation/datamodel-validation.test.ts index e1f06d268..19535d5dd 100644 --- a/packages/schema/tests/schema/validation/datamodel-validation.test.ts +++ b/packages/schema/tests/schema/validation/datamodel-validation.test.ts @@ -120,16 +120,8 @@ describe('Data Model Validation Tests', () => { }); it('id field', async () => { - // no need for '@id' field when there's no access policy or field validation - await loadModel(` - ${prelude} - model M { - x Int - } - `); - const err = - 'Model must include a field with @id or @unique attribute, or a model-level @@id or @@unique attribute to use access policies'; + 'Model must have at least one unique criteria. Either mark a single field with `@id`, `@unique` or add a multi field criterion with `@@id([])` or `@@unique([])` to the model.'; expect( await loadModelWithError(` @@ -630,9 +622,10 @@ describe('Data Model Validation Tests', () => { b String } `); - expect(errors.length).toBe(1); - expect(errors[0]).toEqual(`Model A cannot be extended because it's not abstract`); + expect(errors[0]).toEqual( + 'Model must have at least one unique criteria. Either mark a single field with `@id`, `@unique` or add a multi field criterion with `@@id([])` or `@@unique([])` to the model.' + ); // relation incomplete from multiple level inheritance expect( diff --git a/packages/schema/tests/utils.ts b/packages/schema/tests/utils.ts index f88aae6e2..4dcd45170 100644 --- a/packages/schema/tests/utils.ts +++ b/packages/schema/tests/utils.ts @@ -16,7 +16,7 @@ export class SchemaLoadingError extends Error { export async function loadModel(content: string, validate = true, verbose = true, mergeBase = true) { const { name: docPath } = tmp.fileSync({ postfix: '.zmodel' }); fs.writeFileSync(docPath, content); - const { shared } = createZModelServices(NodeFileSystem); + const { shared, ZModel } = createZModelServices(NodeFileSystem); const stdLib = shared.workspace.LangiumDocuments.getOrCreateDocument( URI.file(path.resolve(__dirname, '../../schema/src/res/stdlib.zmodel')) ); @@ -52,7 +52,7 @@ export async function loadModel(content: string, validate = true, verbose = true const model = (await doc.parseResult.value) as Model; if (mergeBase) { - mergeBaseModel(model); + mergeBaseModel(model, ZModel.references.Linker); } return model; diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 8fa8cf619..8d81caac9 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "description": "ZenStack plugin development SDK", "main": "index.js", "scripts": { @@ -23,10 +23,12 @@ "@prisma/internals-v5": "npm:@prisma/internals@^5.0.0", "@zenstackhq/language": "workspace:*", "@zenstackhq/runtime": "workspace:*", + "langium": "1.3.1", "lower-case-first": "^2.0.2", "prettier": "^2.8.3 || 3.x", "semver": "^7.5.2", "ts-morph": "^16.0.0", + "ts-pattern": "^4.3.0", "upper-case-first": "^2.0.2" }, "devDependencies": { diff --git a/packages/sdk/src/constants.ts b/packages/sdk/src/constants.ts index e038c6958..1e0d22d67 100644 --- a/packages/sdk/src/constants.ts +++ b/packages/sdk/src/constants.ts @@ -12,6 +12,7 @@ export enum ExpressionContext { DefaultValue = 'DefaultValue', AccessPolicy = 'AccessPolicy', ValidationRule = 'ValidationRule', + Index = 'Index', } export const STD_LIB_MODULE_NAME = 'stdlib.zmodel'; diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index 64060390e..660a79d0c 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -4,6 +4,8 @@ export { generate as generateModelMeta } from './model-meta-generator'; export * from './policy'; export * from './prisma'; export * from './types'; +export * from './typescript-expression-transformer'; export * from './utils'; export * from './validation'; export * from './zmodel-code-generator'; +export * from './z3-expression-transformer'; diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index 99029e610..cd516f5ec 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -12,28 +12,30 @@ import { ReferenceExpr, } from '@zenstackhq/language/ast'; import type { RuntimeAttribute } from '@zenstackhq/runtime'; +import { streamAst } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; -import { CodeBlockWriter, Project, VariableDeclarationKind } from 'ts-morph'; +import { CodeBlockWriter, Project, SourceFile, VariableDeclarationKind } from 'ts-morph'; import { - emitProject, + ExpressionContext, getAttribute, getAttributeArg, + getAttributeArgLiteral, getAttributeArgs, getAuthModel, getDataModels, getLiteral, hasAttribute, + isDelegateModel, + isAuthInvocation, isEnumFieldReference, isForeignKeyField, isIdField, resolved, - saveProject, + TypeScriptExpressionTransformer, } from '.'; export type ModelMetaGeneratorOptions = { output: string; - compile: boolean; - preserveTsFiles: boolean; generateAttributes: boolean; }; @@ -42,142 +44,217 @@ export async function generate(project: Project, models: DataModel[], options: M sf.addStatements('/* eslint-disable */'); sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, - declarations: [{ name: 'metadata', initializer: (writer) => generateModelMetadata(models, writer, options) }], + declarations: [ + { name: 'metadata', initializer: (writer) => generateModelMetadata(models, sf, writer, options) }, + ], }); sf.addStatements('export default metadata;'); + return sf; +} + +function generateModelMetadata( + dataModels: DataModel[], + sourceFile: SourceFile, + writer: CodeBlockWriter, + options: ModelMetaGeneratorOptions +) { + writer.block(() => { + writeModels(sourceFile, writer, dataModels, options); + writeDeleteCascade(writer, dataModels); + writeAuthModel(writer, dataModels); + }); +} + +function writeModels( + sourceFile: SourceFile, + writer: CodeBlockWriter, + dataModels: DataModel[], + options: ModelMetaGeneratorOptions +) { + writer.write('models:'); + writer.block(() => { + for (const model of dataModels) { + writer.write(`${lowerCaseFirst(model.name)}:`); + writer.block(() => { + writer.write(`name: '${model.name}',`); + writeBaseTypes(writer, model); + writeFields(sourceFile, writer, model, options); + writeUniqueConstraints(writer, model); + if (options.generateAttributes) { + writeModelAttributes(writer, model); + } + writeDiscriminator(writer, model); + }); + writer.writeLine(','); + } + }); + writer.writeLine(','); +} - if (!options.compile || options.preserveTsFiles) { - // save ts files - await saveProject(project); +function writeBaseTypes(writer: CodeBlockWriter, model: DataModel) { + if (model.superTypes.length > 0) { + writer.write('baseTypes: ['); + writer.write(model.superTypes.map((t) => `'${t.ref?.name}'`).join(', ')); + writer.write('],'); } - if (options.compile) { - await emitProject(project); +} + +function writeAuthModel(writer: CodeBlockWriter, dataModels: DataModel[]) { + const authModel = getAuthModel(dataModels); + if (authModel) { + writer.writeLine(`authModel: '${authModel.name}'`); } } -function generateModelMetadata(dataModels: DataModel[], writer: CodeBlockWriter, options: ModelMetaGeneratorOptions) { +function writeDeleteCascade(writer: CodeBlockWriter, dataModels: DataModel[]) { + writer.write('deleteCascade:'); writer.block(() => { - writer.write('fields:'); - writer.block(() => { - for (const model of dataModels) { - writer.write(`${lowerCaseFirst(model.name)}:`); - writer.block(() => { - for (const f of model.fields) { - const backlink = getBackLink(f); - const fkMapping = generateForeignKeyMapping(f); - writer.write(`${f.name}: { - name: "${f.name}", - type: "${ - f.type.reference - ? f.type.reference.$refText - : // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - f.type.type! - }",`); - - if (isIdField(f)) { - writer.write(` - isId: true,`); - } - - if (isDataModel(f.type.reference?.ref)) { - writer.write(` - isDataModel: true,`); - } - - if (f.type.array) { - writer.write(` - isArray: true,`); - } - - if (f.type.optional) { - writer.write(` - isOptional: true,`); - } - - if (options.generateAttributes) { - const attrs = getFieldAttributes(f); - if (attrs.length > 0) { - writer.write(` - attributes: ${JSON.stringify(attrs)},`); - } - } else { - // only include essential attributes - const attrs = getFieldAttributes(f).filter((attr) => - ['@default', '@updatedAt'].includes(attr.name) - ); - if (attrs.length > 0) { - writer.write(` - attributes: ${JSON.stringify(attrs)},`); - } - } - - if (backlink) { - writer.write(` - backLink: '${backlink.name}',`); - } - - if (isRelationOwner(f, backlink)) { - writer.write(` - isRelationOwner: true,`); - } - - if (isForeignKeyField(f)) { - writer.write(` - isForeignKey: true,`); - } - - if (fkMapping && Object.keys(fkMapping).length > 0) { - writer.write(` - foreignKeyMapping: ${JSON.stringify(fkMapping)},`); - } - - if (isAutoIncrement(f)) { - writer.write(` - isAutoIncrement: true,`); - } - - writer.write(` - },`); - } - }); - writer.write(','); + for (const model of dataModels) { + const cascades = getDeleteCascades(model); + if (cascades.length > 0) { + writer.writeLine(`${lowerCaseFirst(model.name)}: [${cascades.map((n) => `'${n}'`).join(', ')}],`); } - }); - writer.write(','); + } + }); + writer.writeLine(','); +} +function writeUniqueConstraints(writer: CodeBlockWriter, model: DataModel) { + const constraints = getUniqueConstraints(model); + if (constraints.length > 0) { writer.write('uniqueConstraints:'); writer.block(() => { - for (const model of dataModels) { - writer.write(`${lowerCaseFirst(model.name)}:`); - writer.block(() => { - for (const constraint of getUniqueConstraints(model)) { - writer.write(`${constraint.name}: { - name: "${constraint.name}", - fields: ${JSON.stringify(constraint.fields)} - },`); - } - }); - writer.write(','); + for (const constraint of constraints) { + writer.write(`${constraint.name}: { + name: "${constraint.name}", + fields: ${JSON.stringify(constraint.fields)} + },`); } }); writer.write(','); + } +} - writer.write('deleteCascade:'); - writer.block(() => { - for (const model of dataModels) { - const cascades = getDeleteCascades(model); - if (cascades.length > 0) { - writer.writeLine(`${lowerCaseFirst(model.name)}: [${cascades.map((n) => `'${n}'`).join(', ')}],`); +function writeModelAttributes(writer: CodeBlockWriter, model: DataModel) { + const attrs = getAttributes(model); + if (attrs.length > 0) { + writer.write(` +attributes: ${JSON.stringify(attrs)},`); + } +} + +function writeDiscriminator(writer: CodeBlockWriter, model: DataModel) { + const delegateAttr = getAttribute(model, '@@delegate'); + if (!delegateAttr) { + return; + } + const discriminator = getAttributeArg(delegateAttr, 'discriminator') as ReferenceExpr; + if (!discriminator) { + return; + } + if (discriminator) { + writer.write(`discriminator: ${JSON.stringify(discriminator.target.$refText)},`); + } +} + +function writeFields( + sourceFile: SourceFile, + writer: CodeBlockWriter, + model: DataModel, + options: ModelMetaGeneratorOptions +) { + writer.write('fields:'); + writer.block(() => { + for (const f of model.fields) { + const backlink = getBackLink(f); + const fkMapping = generateForeignKeyMapping(f); + writer.write(`${f.name}: {`); + + writer.write(` + name: "${f.name}", + type: "${ + f.type.reference + ? f.type.reference.$refText + : // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + f.type.type! + }",`); + + if (isIdField(f)) { + writer.write(` + isId: true,`); + } + + if (isDataModel(f.type.reference?.ref)) { + writer.write(` + isDataModel: true,`); + } + + if (f.type.array) { + writer.write(` + isArray: true,`); + } + + if (f.type.optional) { + writer.write(` + isOptional: true,`); + } + + if (options.generateAttributes) { + const attrs = getAttributes(f); + if (attrs.length > 0) { + writer.write(` + attributes: ${JSON.stringify(attrs)},`); + } + } else { + // only include essential attributes + const attrs = getAttributes(f).filter((attr) => ['@default', '@updatedAt'].includes(attr.name)); + if (attrs.length > 0) { + writer.write(` + attributes: ${JSON.stringify(attrs)},`); } } - }); - writer.write(','); - const authModel = getAuthModel(dataModels); - if (authModel) { - writer.writeLine(`authModel: '${authModel.name}'`); + if (backlink) { + writer.write(` + backLink: '${backlink.name}',`); + } + + if (isRelationOwner(f, backlink)) { + writer.write(` + isRelationOwner: true,`); + } + + if (isForeignKeyField(f)) { + writer.write(` + isForeignKey: true,`); + } + + if (fkMapping && Object.keys(fkMapping).length > 0) { + writer.write(` + foreignKeyMapping: ${JSON.stringify(fkMapping)},`); + } + + const defaultValueProvider = generateDefaultValueProvider(f, sourceFile); + if (defaultValueProvider) { + writer.write(` + defaultValueProvider: ${defaultValueProvider},`); + } + + if (f.$inheritedFrom && isDelegateModel(f.$inheritedFrom) && !isIdField(f)) { + writer.write(` + inheritedFrom: ${JSON.stringify(f.$inheritedFrom.name)},`); + } + + if (isAutoIncrement(f)) { + writer.write(` + isAutoIncrement: true,`); + } + + writer.write(` + },`); } }); + writer.write(','); } function getBackLink(field: DataModelField) { @@ -206,13 +283,15 @@ function getBackLink(field: DataModelField) { } function getRelationName(field: DataModelField) { - const relAttr = field.attributes.find((attr) => attr.decl.ref?.name === 'relation'); - const relName = relAttr && relAttr.args?.[0] && getLiteral(relAttr.args?.[0].value); - return relName; + const relAttr = getAttribute(field, '@relation'); + if (!relAttr) { + return undefined; + } + return getAttributeArgLiteral(relAttr, 'name'); } -function getFieldAttributes(field: DataModelField): RuntimeAttribute[] { - return field.attributes +function getAttributes(target: DataModelField | DataModel): RuntimeAttribute[] { + return target.attributes .map((attr) => { const args: Array<{ name?: string; value: unknown }> = []; for (const arg of attr.args) { @@ -345,6 +424,39 @@ function getDeleteCascades(model: DataModel): string[] { .map((m) => m.name); } +function generateDefaultValueProvider(field: DataModelField, sourceFile: SourceFile) { + const defaultAttr = getAttribute(field, '@default'); + if (!defaultAttr) { + return undefined; + } + + const expr = defaultAttr.args[0]?.value; + if (!expr) { + return undefined; + } + + // find `auth()` in default value expression + const hasAuth = streamAst(expr).some(isAuthInvocation); + if (!hasAuth) { + return undefined; + } + + // generates a provider function like: + // function $default$Model$field(user: any) { ... } + const func = sourceFile.addFunction({ + name: `$default$${field.$container.name}$${field.name}`, + parameters: [{ name: 'user', type: 'any' }], + returnType: 'unknown', + statements: (writer) => { + const tsWriter = new TypeScriptExpressionTransformer({ context: ExpressionContext.DefaultValue }); + const code = tsWriter.transform(expr, false); + writer.write(`return ${code};`); + }, + }); + + return func.getName(); +} + function isAutoIncrement(field: DataModelField) { const defaultAttr = getAttribute(field, '@default'); if (!defaultAttr) { diff --git a/packages/sdk/src/prisma.ts b/packages/sdk/src/prisma.ts index 970ce58ba..77db556b4 100644 --- a/packages/sdk/src/prisma.ts +++ b/packages/sdk/src/prisma.ts @@ -1,15 +1,11 @@ /* eslint-disable @typescript-eslint/no-var-requires */ import type { DMMF } from '@prisma/generator-helper'; -import { getPrismaVersion } from '@zenstackhq/runtime'; import path from 'path'; import * as semver from 'semver'; import { GeneratorDecl, Model, Plugin, isGeneratorDecl, isPlugin } from './ast'; import { getLiteral } from './utils'; -// reexport -export { getPrismaVersion } from '@zenstackhq/runtime'; - /** * Given a ZModel and an import context directory, compute the import spec for the Prisma Client. */ @@ -91,3 +87,27 @@ export function getDMMF(options: GetDMMFOptions, defaultPrismaVersion?: string): return _getDMMF(options); } } + +/** + * Gets the installed Prisma's version + */ +export function getPrismaVersion(): string | undefined { + if (process.env.ZENSTACK_TEST === '1') { + // test environment + try { + return require(path.resolve('./node_modules/@prisma/client/package.json')).version; + } catch { + return undefined; + } + } + + try { + return require('@prisma/client/package.json').version; + } catch { + try { + return require('prisma/package.json').version; + } catch { + return undefined; + } + } +} diff --git a/packages/sdk/src/types.ts b/packages/sdk/src/types.ts index c19fdfc42..9fbbd5553 100644 --- a/packages/sdk/src/types.ts +++ b/packages/sdk/src/types.ts @@ -9,23 +9,18 @@ export type OptionValue = string | number | boolean; /** * Plugin configuration options */ -export type PluginOptions = { +export type PluginDeclaredOptions = { /*** * The provider package */ - provider?: string; - - /** - * The path of the ZModel schema - */ - schemaPath: string; - - /** - * The name of the plugin - */ - name: string; + provider: string; } & Record; +/** + * Plugin configuration options for execution + */ +export type PluginOptions = { schemaPath: string } & PluginDeclaredOptions; + /** * Global options that apply to all plugins */ diff --git a/packages/schema/src/utils/typescript-expression-transformer.ts b/packages/sdk/src/typescript-expression-transformer.ts similarity index 98% rename from packages/schema/src/utils/typescript-expression-transformer.ts rename to packages/sdk/src/typescript-expression-transformer.ts index cd868d76c..20585118c 100644 --- a/packages/schema/src/utils/typescript-expression-transformer.ts +++ b/packages/sdk/src/typescript-expression-transformer.ts @@ -17,9 +17,9 @@ import { ThisExpr, UnaryExpr, } from '@zenstackhq/language/ast'; -import { ExpressionContext, getLiteral, isFromStdlib, isFutureExpr } from '@zenstackhq/sdk'; import { match, P } from 'ts-pattern'; -import { getIdFields } from './ast-utils'; +import { ExpressionContext } from './constants'; +import { getIdFields, getLiteral, isFromStdlib, isFutureExpr } from './utils'; export class TypeScriptExpressionTransformerError extends Error { constructor(message: string) { diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index d32962f11..01d5d274d 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -22,6 +22,7 @@ import { isGeneratorDecl, isInvocationExpr, isLiteralExpr, + isMemberAccessExpr, isModel, isObjectExpr, isReferenceExpr, @@ -31,7 +32,7 @@ import { } from '@zenstackhq/language/ast'; import path from 'path'; import { ExpressionContext, STD_LIB_MODULE_NAME } from './constants'; -import { PluginError, PluginOptions } from './types'; +import { PluginError, type PluginDeclaredOptions, type PluginOptions } from './types'; /** * Gets data models that are not ignored @@ -177,7 +178,7 @@ export function isDataModelFieldReference(node: AstNode): node is ReferenceExpr * Gets `@@id` fields declared at the data model level */ export function getModelIdFields(model: DataModel) { - const idAttr = model.attributes.find((attr) => attr.decl.ref?.name === '@@id'); + const idAttr = model.attributes.find((attr) => attr.decl.$refText === '@@id'); if (!idAttr) { return []; } @@ -195,7 +196,7 @@ export function getModelIdFields(model: DataModel) { * Gets `@@unique` fields declared at the data model level */ export function getModelUniqueFields(model: DataModel) { - const uniqueAttr = model.attributes.find((attr) => attr.decl.ref?.name === '@@unique'); + const uniqueAttr = model.attributes.find((attr) => attr.decl.$refText === '@@unique'); if (!uniqueAttr) { return []; } @@ -288,7 +289,7 @@ export function resolvePath(_path: string, options: Pick(options: PluginOptions, name: string, pluginName: string): T { +export function requireOption(options: PluginDeclaredOptions, name: string, pluginName: string): T { const value = options[name]; if (value === undefined) { throw new PluginError(pluginName, `Plugin "${options.name}" is missing required option: ${name}`); @@ -296,8 +297,8 @@ export function requireOption(options: PluginOptions, name: string, pluginNam return value as T; } -export function parseOptionAsStrings(options: PluginOptions, optionaName: string, pluginName: string) { - const value = options[optionaName]; +export function parseOptionAsStrings(options: PluginDeclaredOptions, optionName: string, pluginName: string) { + const value = options[optionName]; if (value === undefined) { return undefined; } else if (typeof value === 'string') { @@ -312,7 +313,7 @@ export function parseOptionAsStrings(options: PluginOptions, optionaName: string } else { throw new PluginError( pluginName, - `Invalid "${optionaName}" option: must be a comma-separated string or an array of strings` + `Invalid "${optionName}" option: must be a comma-separated string or an array of strings` ); } } @@ -334,7 +335,11 @@ export function getFunctionExpressionContext(funcDecl: FunctionDecl) { } export function isFutureExpr(node: AstNode) { - return !!(isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)); + return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref); +} + +export function isAuthInvocation(node: AstNode) { + return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref); } export function isFromStdlib(node: AstNode) { @@ -373,3 +378,66 @@ export function getAuthModel(dataModels: DataModel[]) { } return authModel; } + +export function isDelegateModel(node: AstNode) { + return isDataModel(node) && hasAttribute(node, '@@delegate'); +} + +export function isDiscriminatorField(field: DataModelField) { + const model = field.$inheritedFrom ?? field.$container; + const delegateAttr = getAttribute(model, '@@delegate'); + if (!delegateAttr) { + return false; + } + const arg = delegateAttr.args[0]?.value; + return isDataModelFieldReference(arg) && arg.target.$refText === field.name; +} + +export function getIdFields(dataModel: DataModel) { + const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) => + f.attributes.some((attr) => attr.decl.$refText === '@id') + ); + if (fieldLevelId) { + return [fieldLevelId]; + } else { + // get model level @@id attribute + const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id'); + if (modelIdAttr) { + // get fields referenced in the attribute: @@id([field1, field2]]) + if (!isArrayExpr(modelIdAttr.args[0].value)) { + return []; + } + const argValue = modelIdAttr.args[0].value; + return argValue.items + .filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr)) + .map((expr) => expr.target.ref as DataModelField); + } + } + return []; +} + +export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { + if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { + return expr.target.ref; + } else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) { + return expr.member.ref; + } else { + return undefined; + } +} + +export function getModelFieldsWithBases(model: DataModel) { + return [...model.fields, ...getRecursiveBases(model).flatMap((base) => base.fields)]; +} + +export function getRecursiveBases(dataModel: DataModel): DataModel[] { + const result: DataModel[] = []; + dataModel.superTypes.forEach((superType) => { + const baseDecl = superType.ref; + if (baseDecl) { + result.push(baseDecl); + result.push(...getRecursiveBases(baseDecl)); + } + }); + return result; +} diff --git a/packages/sdk/src/z3-expression-transformer.ts b/packages/sdk/src/z3-expression-transformer.ts new file mode 100644 index 000000000..a68f866bf --- /dev/null +++ b/packages/sdk/src/z3-expression-transformer.ts @@ -0,0 +1,509 @@ +import { + ArrayExpr, + BinaryExpr, + BooleanLiteral, + // DataModel, + Expression, + InvocationExpr, + isBooleanLiteral, + isDataModel, + isEnumField, + isMemberAccessExpr, + isNullExpr, + isThisExpr, + LiteralExpr, + MemberAccessExpr, + NullExpr, + NumberLiteral, + ReferenceExpr, + StringLiteral, + ThisExpr, + UnaryExpr, +} from '@zenstackhq/language/ast'; +import { match, P } from 'ts-pattern'; +import { ExpressionContext } from './constants'; +import { getLiteral, isAuthInvocation, isFromStdlib, isFutureExpr } from './utils'; + +export class Z3ExpressionTransformerError extends Error { + constructor(message: string) { + super(message); + } +} + +type Options = { + isPostGuard?: boolean; + fieldReferenceContext?: string; + thisExprContext?: string; + context: ExpressionContext; +}; + +// a registry of function handlers marked with @func +const functionHandlers = new Map(); + +// function handler decorator +function func(name: string) { + return function (target: unknown, propertyKey: string, descriptor: PropertyDescriptor) { + if (!functionHandlers.get(name)) { + functionHandlers.set(name, descriptor); + } + return descriptor; + }; +} + +/** + * Transforms ZModel expression to Z3 assertion. + */ +export class Z3ExpressionTransformer { + /** + * Constructs a new Z3ExpressionTransformer. + * + * @param isPostGuard indicates if we're writing for post-update conditions + */ + constructor(private readonly options: Options) {} + + /** + * Transforms the given expression to a TypeScript expression. + * @returns + */ + transform(expr: Expression): string { + switch (expr.$type) { + case StringLiteral: + case NumberLiteral: + return this.literal(expr as LiteralExpr); + + case BooleanLiteral: + return this.boolean(expr as BooleanLiteral); + + case ArrayExpr: + return this.array(expr as ArrayExpr); + + case NullExpr: + return this.null(); + + case ThisExpr: + return this.this(expr as ThisExpr); + + case ReferenceExpr: + return this.reference(expr as ReferenceExpr); + + case InvocationExpr: + return this.invocation(expr as InvocationExpr); + + case MemberAccessExpr: + return this.memberAccess(expr as MemberAccessExpr); + + case UnaryExpr: + return this.unary(expr as UnaryExpr); + + case BinaryExpr: + // eslint-disable-next-line no-case-declarations + const assertion = this.binary(expr as BinaryExpr); + if (['&&', '||'].includes(expr.operator)) return assertion; + // eslint-disable-next-line no-case-declarations + const checkString = + expr.left.$type === 'ReferenceExpr' && expr.right.$type === 'StringLiteral' + ? { [expr.left.target.ref?.name ?? '']: expr.right.value } + : {}; + if (Object.keys(checkString).length > 0) { + return `z3.And(${assertion}, buildAssertion(z3, variables, args, user, ${JSON.stringify( + checkString + )}))`; + } + return assertion; + + default: + throw new Z3ExpressionTransformerError(`Unsupported expression type: ${expr.$type}`); + } + } + + private this(_expr: ThisExpr) { + // "this" is mapped to the input argument + return this.options.thisExprContext ?? 'input'; + } + + private memberAccess(expr: MemberAccessExpr) { + if (!expr.member.ref) { + throw new Z3ExpressionTransformerError(`Unresolved MemberAccessExpr`); + } + + if (isThisExpr(expr.operand)) { + return expr.member.ref.name; + } else if (isFutureExpr(expr.operand)) { + if (this.options?.isPostGuard !== true) { + throw new Z3ExpressionTransformerError(`future() is only supported in postUpdate rules`); + } + return expr.member.ref.name; + } else { + return `${this.transform(expr.operand)}?.${expr.member.ref.name}`; + } + } + + private invocation(expr: InvocationExpr) { + if (!expr.function.ref) { + throw new Z3ExpressionTransformerError(`Unresolved InvocationExpr`); + } + + const funcName = expr.function.ref.name; + const isStdFunc = isFromStdlib(expr.function.ref); + + if (!isStdFunc) { + throw new Z3ExpressionTransformerError('User-defined functions are not supported yet'); + } + + const handler = functionHandlers.get(funcName); + if (!handler) { + throw new Z3ExpressionTransformerError(`Unsupported function: ${funcName}`); + } + + const args = expr.args.map((arg) => arg.value); + return handler.value.call(this, args); + } + + // #region function invocation handlers + + // arguments have been type-checked + + @func('auth') + private _auth() { + return 'user'; + } + + @func('now') + private _now() { + return `(new Date())`; + } + + @func('length') + private _length(args: Expression[]) { + const field = this.transform(args[0]); + const min = getLiteral(args[1]); + const max = getLiteral(args[2]); + let result: string; + if (min === undefined) { + result = `(${field}?.length > 0)`; + } else if (max === undefined) { + result = `(${field}?.length >= ${min})`; + } else { + result = `(${field}?.length >= ${min} && ${field}?.length <= ${max})`; + } + return this.ensureBoolean(result); + } + + @func('contains') + private _contains(args: Expression[]) { + const field = this.transform(args[0]); + const caseInsensitive = getLiteral(args[2]) === true; + let result: string; + if (caseInsensitive) { + result = `${field}?.toLowerCase().includes(${this.transform(args[1])}?.toLowerCase())`; + } else { + result = `${field}?.includes(${this.transform(args[1])})`; + } + return this.ensureBoolean(result); + } + + @func('startsWith') + private _startsWith(args: Expression[]) { + const field = this.transform(args[0]); + const result = `${field}?.startsWith(${this.transform(args[1])})`; + return this.ensureBoolean(result); + } + + @func('endsWith') + private _endsWith(args: Expression[]) { + const field = this.transform(args[0]); + const result = `${field}?.endsWith(${this.transform(args[1])})`; + return this.ensureBoolean(result); + } + + @func('regex') + private _regex(args: Expression[]) { + const field = this.transform(args[0]); + const pattern = getLiteral(args[1]); + return `new RegExp(${JSON.stringify(pattern)}).test(${field})`; + } + + @func('email') + private _email(args: Expression[]) { + const field = this.transform(args[0]); + return `z.string().email().safeParse(${field}).success`; + } + + @func('datetime') + private _datetime(args: Expression[]) { + const field = this.transform(args[0]); + return `z.string().datetime({ offset: true }).safeParse(${field}).success`; + } + + @func('url') + private _url(args: Expression[]) { + const field = this.transform(args[0]); + return `z.string().url().safeParse(${field}).success`; + } + + @func('has') + private _has(args: Expression[]) { + const field = this.transform(args[0]); + const result = `${field}?.includes(${this.transform(args[1])})`; + return this.ensureBoolean(result); + } + + @func('hasEvery') + private _hasEvery(args: Expression[]) { + const field = this.transform(args[0]); + const result = `${this.transform(args[1])}?.every((item) => ${field}?.includes(item))`; + return this.ensureBoolean(result); + } + + @func('hasSome') + private _hasSome(args: Expression[]) { + const field = this.transform(args[0]); + const result = `${this.transform(args[1])}?.some((item) => ${field}?.includes(item))`; + return this.ensureBoolean(result); + } + + @func('isEmpty') + private _isEmpty(args: Expression[]) { + const field = this.transform(args[0]); + const result = `(!${field} || ${field}?.length === 0)`; + return this.ensureBoolean(result); + } + + private ensureBoolean(expr: string) { + return `(${expr} ?? false)`; + } + + // #endregion + + private reference(expr: ReferenceExpr) { + if (!expr.target.ref) { + throw new Z3ExpressionTransformerError(`Unresolved ReferenceExpr`); + } + + if (isEnumField(expr.target.ref)) { + return `${expr.target.ref.$container.name}.${expr.target.ref.name}`; + } else { + // const formattedName = `${lowerCaseFirst(expr.target.ref.$container.name)}${upperCaseFirst( + // expr.target.ref.name + // )}`; + return `_${expr.target.ref.name}`; + } + } + + private null() { + return 'undefined'; + } + + private array(expr: ArrayExpr) { + return `[${expr.items.map((item) => this.transform(item)).join(', ')}]`; + } + + private literal(expr: LiteralExpr) { + if (expr.$type === StringLiteral) { + return `'${expr.value}'`; + } else { + return expr.value.toString(); + } + } + + private boolean(expr: BooleanLiteral) { + return `z3.Bool.val(${expr.value})`; + } + + private unary(expr: UnaryExpr): string { + if (expr.operator !== '!') { + throw new Z3ExpressionTransformerError(`Unsupported unary operator: ${expr.operator}`); + } + return `z3.Not(${this.transform(expr.operand)})`; + } + + private isModelType(expr: Expression) { + return isDataModel(expr.$resolvedType?.decl); + } + + private binary(expr: BinaryExpr): string { + if (/* expr.left.$type === 'ReferenceExpr' && */ expr.right.$type === 'StringLiteral') return 'true'; + + let left = this.transform(expr.left); + let right = isBooleanLiteral(expr.right) ? `${expr.right.value}` : this.transform(expr.right); + + // TODO: improve handling of null expressions + if (isNullExpr(expr.right)) { + return `${this.withArgs(left)} ${expr.operator} ${right}`; + } + + // if (isMemberAccessExpr(expr.left) && !isAuthInvocation(expr.left)) { + // left = `args.${left}`; + // } + // if (isMemberAccessExpr(expr.right) && !isAuthInvocation(expr.right)) { + // right = `args.${right}`; + // } + // if (this.isModelType(expr.left)) { + // left = `(${left}.id)`; + // } + // if (this.isModelType(expr.right)) { + // right = `(${right}.id)`; + // } + if (this.isModelType(expr.left) && this.isModelType(expr.right)) { + // comparison between model type values, map to id comparison + left = isAuthInvocation(expr.left) + ? `(${left}?.id)` + : `((${this.withArgs(left)}?.id || ${this.withArgs(left)}Id))`; + right = isAuthInvocation(expr.right) + ? `(${right}?.id)` + : `((${this.withArgs(right)}?.id || ${this.withArgs(right)}Id))`; + let assertion = `${left} ${expr.operator} ${right}`; + + // only args values need implies + if (isAuthInvocation(expr.left) && (isMemberAccessExpr(expr.right) || this.isModelType(expr.right))) { + assertion = `z3.Implies(!!${right}, ${assertion})`; + } + if (isAuthInvocation(expr.right) && (isMemberAccessExpr(expr.left) || this.isModelType(expr.left))) { + assertion = `z3.Implies(!!${left}, ${assertion})`; + } + // TODO: handle strict equality comparison (===, !==, etc.) + return this.withAuth(expr, assertion); + } + + if (isAuthInvocation(expr.left) || isAuthInvocation(expr.right)) { + left = isAuthInvocation(expr.left) ? `(${left}?.id)` : left; + right = isAuthInvocation(expr.right) ? `(${right}.id)` : right; + const assertion = `${left} ${expr.operator} ${right}`; + if (this.needAuthCheck(expr)) { + return this.withAuth(expr, assertion); + } + return assertion; + } + + // auth().string compared to string argument + // TODO: for other type we could want to add a constraint to the auth model => we have to create a variable for it + if (this.isAuthComparison(left, right)) { + left = + isMemberAccessExpr(expr.left) && !this.isAuthMemberAccessExpr(expr.left, left) + ? `${this.withArgs(left)}` + : left; + right = + isMemberAccessExpr(expr.right) && !this.isAuthMemberAccessExpr(expr.right, right) + ? `${this.withArgs(right)}` + : right; + let assertion = `${left} ${expr.operator} ${right}`; + if (this.isAuthMemberAccessExpr(expr.left, left)) { + assertion = `z3.Implies(!!${right}, ${assertion})`; + } else if (this.isAuthMemberAccessExpr(expr.right, right)) { + assertion = `z3.Implies(!!${left}, ${assertion})`; + } + return this.withAuth(expr, assertion, true); + } + + const _default = `(${left} ${expr.operator} ${right})`; + + // if (expr.left.$type === 'ReferenceExpr') { + // left = `${lowerCaseFirst(expr.left.target.ref?.$container.name ?? '')}${upperCaseFirst( + // expr.left.target?.ref?.name ?? '' + // )}`; + // } + + return ( + match(expr.operator) + .with('||', () => `z3.Or(${left}, ${right})`) + .with('&&', () => `z3.And(${left}, ${right})`) + .with('==', () => `${left}.eq(${right})`) + .with('!=', () => `${left}.neq(${right})`) + .with('<', () => `${left}.lt(${right})`) + .with('<=', () => `${left}.le(${right})`) + .with('>', () => `${left}.gt(${right})`) + .with('>=', () => `${left}.ge(${right})`) + .with('in', () => `(${this.transform(expr.right)}?.includes(${this.transform(expr.left)}) ?? false)`) + // .with(P.union('==', '!='), () => { + // if (isThisExpr(expr.left) || isThisExpr(expr.right)) { + // // map equality comparison with `this` to id comparison + // const _this = isThisExpr(expr.left) ? expr.left : expr.right; + // const model = _this.$resolvedType?.decl as DataModel; + // const idFields = getIdFields(model); + // if (!idFields || idFields.length === 0) { + // throw new Z3ExpressionTransformerError(`model "${model.name}" does not have an id field`); + // } + // let result = `allFieldsEqual(${this.transform(expr.left, false)}, + // ${this.transform(expr.right, false)}, [${idFields.map((f) => "'" + f.name + "'").join(', ')}])`; + // if (expr.operator === '!=') { + // result = `!${result}`; + // } + // return result; + // } else { + // return _default; + // } + // }) + .with(P.union('?', '!', '^'), (op) => this.collectionPredicate(expr, op)) + .otherwise(() => _default) + ); + } + + private collectionPredicate(expr: BinaryExpr, operator: '?' | '!' | '^') { + const operand = this.transform(expr.left); + const innerTransformer = new Z3ExpressionTransformer({ + ...this.options, + isPostGuard: false, + fieldReferenceContext: '_item', + thisExprContext: '_item', + }); + const predicate = innerTransformer.transform(expr.right); + + return match(operator) + .with('?', () => `!!((${operand})?.some((_item: any) => ${predicate}))`) + .with('!', () => `!!((${operand})?.every((_item: any) => ${predicate}))`) + .with('^', () => `!((${operand})?.some((_item: any) => ${predicate}))`) + .exhaustive(); + } + + private needAuthCheck(expr: BinaryExpr) { + return ( + (isAuthInvocation(expr.left) && !(isNullExpr(expr.right) && expr.operator === '==')) || + isAuthInvocation(expr.right) + ); + } + + private withAuth(expr: BinaryExpr, assertion: string, forceAuth = false) { + if (this.needAuthCheck(expr) || forceAuth) { + return `z3.And(${assertion}, _withAuth)`; + } + return assertion; + } + + private withArgs(str: string) { + if (str.startsWith('_')) { + str = str.substring(1); + } + return `args.${str}`; + } + + // private isAuthProperty(expr: Expression) { + // return isMemberAccessExpr(expr) && expr.member.ref?.$container.name === 'User'; // TODO: how to get auth model name? + // } + + private isAuthMemberAccessExpr(expr: Expression, transformedExpr: string) { + return isMemberAccessExpr(expr) && transformedExpr.startsWith('user?.'); + } + + private isAuthComparison(left: string, right: string) { + return left.startsWith('user?.') || right.startsWith('user?.'); + } + + // private getModelFromMemberAccess(expr: MemberAccessExpr) { + // return expr.member.; + // } +} + +// false : +// const age = Z3.Int.const('age'); +// const assertion = Z3.And( +// Z3.Not(age.gt(18)), +// Z3.Not(age.lt(60)) +// ); +// Z3.solve(assertion); + +// true : +// const age = Z3.Int.const('age'); +// const assertion = Z3.Not(Z3.And( +// age.gt(18), age.lt(60) +// )); +// Z3.solve(assertion); diff --git a/packages/sdk/src/zmodel-code-generator.ts b/packages/sdk/src/zmodel-code-generator.ts index 1b1f001e1..96aaa87d9 100644 --- a/packages/sdk/src/zmodel-code-generator.ts +++ b/packages/sdk/src/zmodel-code-generator.ts @@ -49,6 +49,7 @@ export interface ZModelCodeOptions { binaryExprNumberOfSpaces: number; unaryExprNumberOfSpaces: number; indent: number; + quote: 'single' | 'double'; } // a registry of generation handlers marked with @gen @@ -75,6 +76,7 @@ export class ZModelCodeGenerator { binaryExprNumberOfSpaces: options?.binaryExprNumberOfSpaces ?? 1, unaryExprNumberOfSpaces: options?.unaryExprNumberOfSpaces ?? 0, indent: options?.indent ?? 4, + quote: options?.quote ?? 'single', }; } @@ -224,7 +226,7 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ @gen(StringLiteral) private _generateLiteralExpr(ast: LiteralExpr) { - return `'${ast.value}'`; + return this.options.quote === 'single' ? `'${ast.value}'` : `"${ast.value}"`; } @gen(NumberLiteral) @@ -265,7 +267,7 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ @gen(ReferenceArg) private _generateReferenceArg(ast: ReferenceArg) { - return `${ast.name}:${ast.value}`; + return `${ast.name}:${this.generate(ast.value)}`; } @gen(MemberAccessExpr) @@ -321,7 +323,7 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ } private argument(ast: Argument) { - return `${ast.name ? ast.name + ': ' : ''}${this.generate(ast.value)}`; + return this.generate(ast.value); } private get binaryExprSpace() { diff --git a/packages/server/package.json b/packages/server/package.json index f4b4b68ab..f05366228 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/server", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "displayName": "ZenStack Server-side Adapters", "description": "ZenStack server-side adapters", "homepage": "https://zenstack.dev", diff --git a/packages/server/src/api/base.ts b/packages/server/src/api/base.ts index ba385f31c..96c547204 100644 --- a/packages/server/src/api/base.ts +++ b/packages/server/src/api/base.ts @@ -1,5 +1,6 @@ -import { DbClientContract, ModelMeta, ZodSchemas, getDefaultModelMeta } from '@zenstackhq/runtime'; -import { LoggerConfig } from '../types'; +import type { DbClientContract, ModelMeta, ZodSchemas } from '@zenstackhq/runtime'; +import { getDefaultModelMeta } from '../shared'; +import type { LoggerConfig } from '../types'; /** * API request context diff --git a/packages/server/src/api/rest/index.ts b/packages/server/src/api/rest/index.ts index 88b463c80..52d700c63 100644 --- a/packages/server/src/api/rest/index.ts +++ b/packages/server/src/api/rest/index.ts @@ -956,7 +956,7 @@ class RequestHandler extends APIHandlerBase { private buildTypeMap(logger: LoggerConfig | undefined, modelMeta: ModelMeta): void { this.typeMap = {}; - for (const [model, fields] of Object.entries(modelMeta.fields)) { + for (const [model, { fields }] of Object.entries(modelMeta.models)) { const idFields = getIdFields(modelMeta, model); if (idFields.length === 0) { logWarning(logger, `Not including model ${model} in the API because it has no ID field`); @@ -1013,7 +1013,7 @@ class RequestHandler extends APIHandlerBase { this.serializers = new Map(); const linkers: Record> = {}; - for (const model of Object.keys(modelMeta.fields)) { + for (const model of Object.keys(modelMeta.models)) { const ids = getIdFields(modelMeta, model); if (ids.length !== 1) { continue; @@ -1027,7 +1027,7 @@ class RequestHandler extends APIHandlerBase { linkers[model] = linker; let projection: Record | null = {}; - for (const [field, fieldMeta] of Object.entries(modelMeta.fields[model])) { + for (const [field, fieldMeta] of Object.entries(modelMeta.models[model].fields)) { if (fieldMeta.isDataModel) { projection[field] = 0; } @@ -1049,14 +1049,14 @@ class RequestHandler extends APIHandlerBase { } // set relators - for (const model of Object.keys(modelMeta.fields)) { + for (const model of Object.keys(modelMeta.models)) { const serializer = this.serializers.get(model); if (!serializer) { continue; } const relators: Record> = {}; - for (const [field, fieldMeta] of Object.entries(modelMeta.fields[model])) { + for (const [field, fieldMeta] of Object.entries(modelMeta.models[model].fields)) { if (!fieldMeta.isDataModel) { continue; } diff --git a/packages/server/src/shared.ts b/packages/server/src/shared.ts index 6001fbbaa..1a9c62119 100644 --- a/packages/server/src/shared.ts +++ b/packages/server/src/shared.ts @@ -1,4 +1,6 @@ -import { ZodSchemas, getDefaultModelMeta, getDefaultZodSchemas } from '@zenstackhq/runtime'; +/* eslint-disable @typescript-eslint/no-var-requires */ +import type { ModelMeta, PolicyDef, ZodSchemas } from '@zenstackhq/runtime'; +import path from 'path'; import { AdapterBaseOptions } from './types'; export function loadAssets(options: AdapterBaseOptions) { @@ -18,3 +20,88 @@ export function loadAssets(options: AdapterBaseOptions) { return { modelMeta, zodSchemas }; } + +/** + * Load model metadata. + * + * @param loadPath The path to load model metadata from. If not provided, + * will use default load path. + */ +export function getDefaultModelMeta(loadPath: string | undefined): ModelMeta { + try { + if (loadPath) { + const toLoad = path.resolve(loadPath, 'model-meta'); + return require(toLoad).default; + } else { + return require('.zenstack/model-meta').default; + } + } catch { + if (process.env.ZENSTACK_TEST === '1' && !loadPath) { + try { + // special handling for running as tests, try resolving relative to CWD + return require(path.join(process.cwd(), 'node_modules', '.zenstack', 'model-meta')).default; + } catch { + throw new Error('Model meta cannot be loaded. Please make sure "zenstack generate" has been run.'); + } + } + throw new Error('Model meta cannot be loaded. Please make sure "zenstack generate" has been run.'); + } +} + +/** + * Load access policies. + * + * @param loadPath The path to load access policies from. If not provided, + * will use default load path. + */ +export function getDefaultPolicy(loadPath: string | undefined): PolicyDef { + try { + if (loadPath) { + const toLoad = path.resolve(loadPath, 'policy'); + return require(toLoad).default; + } else { + return require('.zenstack/policy').default; + } + } catch { + if (process.env.ZENSTACK_TEST === '1' && !loadPath) { + try { + // special handling for running as tests, try resolving relative to CWD + return require(path.join(process.cwd(), 'node_modules', '.zenstack', 'policy')).default; + } catch { + throw new Error( + 'Policy definition cannot be loaded from default location. Please make sure "zenstack generate" has been run.' + ); + } + } + throw new Error( + 'Policy definition cannot be loaded from default location. Please make sure "zenstack generate" has been run.' + ); + } +} + +/** + * Load zod schemas. + * + * @param loadPath The path to load zod schemas from. If not provided, + * will use default load path. + */ +export function getDefaultZodSchemas(loadPath: string | undefined): ZodSchemas | undefined { + try { + if (loadPath) { + const toLoad = path.resolve(loadPath, 'zod'); + return require(toLoad); + } else { + return require('.zenstack/zod'); + } + } catch { + if (process.env.ZENSTACK_TEST === '1' && !loadPath) { + try { + // special handling for running as tests, try resolving relative to CWD + return require(path.join(process.cwd(), 'node_modules', '.zenstack', 'zod')); + } catch { + return undefined; + } + } + return undefined; + } +} diff --git a/packages/server/tests/api/rest.test.ts b/packages/server/tests/api/rest.test.ts index 7b084ef8a..770d05017 100644 --- a/packages/server/tests/api/rest.test.ts +++ b/packages/server/tests/api/rest.test.ts @@ -1,7 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ /// -import { CrudFailureReason, ModelMeta, withPolicy } from '@zenstackhq/runtime'; +import { CrudFailureReason, type ModelMeta } from '@zenstackhq/runtime'; import { loadSchema, run } from '@zenstackhq/testtools'; import { Decimal } from 'decimal.js'; import SuperJSON from 'superjson'; @@ -1882,7 +1882,7 @@ describe('REST server tests', () => { beforeAll(async () => { const params = await loadSchema(schema); - prisma = withPolicy(params.prisma, undefined, params); + prisma = params.enhanceRaw(params.prisma, params); zodSchemas = params.zodSchemas; modelMeta = params.modelMeta; @@ -1995,7 +1995,7 @@ describe('REST server tests', () => { beforeAll(async () => { const params = await loadSchema(schema); - prisma = withPolicy(params.prisma, undefined, params); + prisma = params.enhanceRaw(params.prisma, params); zodSchemas = params.zodSchemas; modelMeta = params.modelMeta; diff --git a/packages/testtools/package.json b/packages/testtools/package.json index 107b4d659..3472ddca0 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "1.8.2", + "version": "2.0.0-alpha.1", "description": "ZenStack Test Tools", "main": "index.js", "private": true, @@ -24,7 +24,7 @@ "@zenstackhq/runtime": "workspace:*", "@zenstackhq/sdk": "workspace:*", "json5": "^2.2.3", - "langium": "1.2.0", + "langium": "1.3.1", "pg": "^8.11.1", "tmp": "^0.2.1", "vscode-uri": "^3.0.6", diff --git a/packages/testtools/src/model.ts b/packages/testtools/src/model.ts index 4be8a1613..29b15467d 100644 --- a/packages/testtools/src/model.ts +++ b/packages/testtools/src/model.ts @@ -16,7 +16,7 @@ export class SchemaLoadingError extends Error { export async function loadModel(content: string, validate = true, verbose = true) { const { name: docPath } = tmp.fileSync({ postfix: '.zmodel' }); fs.writeFileSync(docPath, content); - const { shared } = createZModelServices(NodeFileSystem); + const { shared, ZModel } = createZModelServices(NodeFileSystem); const stdLib = shared.workspace.LangiumDocuments.getOrCreateDocument( URI.file(path.resolve(__dirname, '../../schema/src/res/stdlib.zmodel')) ); @@ -51,7 +51,7 @@ export async function loadModel(content: string, validate = true, verbose = true const model = (await doc.parseResult.value) as Model; - mergeBaseModel(model); + mergeBaseModel(model, ZModel.references.Linker); return model; } diff --git a/packages/testtools/src/package.template.json b/packages/testtools/src/package.template.json index 8ea542361..ec738a4c5 100644 --- a/packages/testtools/src/package.template.json +++ b/packages/testtools/src/package.template.json @@ -7,15 +7,16 @@ "author": "", "license": "ISC", "dependencies": { - "@prisma/client": "^4.8.0", + "@prisma/client": "^5.7.1", "@zenstackhq/runtime": "file:/packages/runtime/dist", "@zenstackhq/swr": "file:/packages/plugins/swr/dist", "@zenstackhq/trpc": "file:/packages/plugins/trpc/dist", "@zenstackhq/openapi": "file:/packages/plugins/openapi/dist", - "prisma": "^4.8.0", + "prisma": "^5.7.1", "typescript": "^4.9.3", "zenstack": "file:/packages/schema/dist", "zod": "^3.22.4", - "decimal.js": "^10.4.2" + "decimal.js": "^10.4.2", + "z3-solver": "^4.12.5" } } diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index f69a845cc..bd64d6461 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -2,7 +2,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import type { DMMF } from '@prisma/generator-helper'; import type { Model } from '@zenstackhq/language/ast'; -import { enhance, withOmit, withPassword, withPolicy, type AuthUser, type DbOperations } from '@zenstackhq/runtime'; +import type { AuthUser, CrudContract, EnhancementKind, EnhancementOptions } from '@zenstackhq/runtime'; import { getDMMF } from '@zenstackhq/sdk'; import { execSync } from 'child_process'; import * as fs from 'fs'; @@ -24,7 +24,7 @@ import prismaPlugin from 'zenstack/plugins/prisma'; */ export const FILE_SPLITTER = '#FILE_SPLITTER#'; -export type FullDbClientContract = Record & { +export type FullDbClientContract = CrudContract & { $on(eventType: any, callback: (event: any) => void): void; $use(cb: any): void; $disconnect: () => Promise; @@ -35,14 +35,14 @@ export type FullDbClientContract = Record & { }; export function run(cmd: string, env?: Record, cwd?: string) { - const start = Date.now(); + // const start = Date.now(); execSync(cmd, { stdio: 'pipe', encoding: 'utf-8', env: { ...process.env, DO_NOT_TRACK: '1', ...env }, cwd, }); - console.log('Execution took', Date.now() - start, 'ms', '-', cmd); + // console.log('Execution took', Date.now() - start, 'ms', '-', cmd); } function normalizePath(p: string) { @@ -81,16 +81,10 @@ datasource db { generator js { provider = 'prisma-client-js' - previewFeatures = ['clientExtensions'] } -plugin meta { - provider = '@core/model-meta' - preserveTsFiles = true -} - -plugin policy { - provider = '@core/access-policy' +plugin enhancer { + provider = '@core/enhancer' preserveTsFiles = true } @@ -116,6 +110,8 @@ export type SchemaLoadOptions = { dbUrl?: string; pulseApiKey?: string; getPrismaOnly?: boolean; + enhancements?: EnhancementKind[]; + enhanceOptions?: Partial; }; const defaultOptions: SchemaLoadOptions = { @@ -224,7 +220,7 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { // https://github.com/prisma/prisma/issues/18292 prisma[Symbol.for('nodejs.util.inspect.custom')] = 'PrismaClient'; - const Prisma = require(path.join(projectRoot, 'node_modules/@prisma/client')).Prisma; + const prismaModule = require(path.join(projectRoot, 'node_modules/@prisma/client')).Prisma; if (opt.pulseApiKey) { const withPulse = require(path.join(projectRoot, 'node_modules/@prisma/extension-pulse/dist/cjs')).withPulse; @@ -248,58 +244,56 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { if (options?.getPrismaOnly) { return { prisma, - Prisma, + prismaModule, projectDir: projectRoot, - withPolicy: undefined as any, - withOmit: undefined as any, - withPassword: undefined as any, enhance: undefined as any, + enhanceRaw: undefined as any, + policy: undefined as any, + modelMeta: undefined as any, + zodSchemas: undefined as any, }; } - let policy: any; - let modelMeta: any; - let zodSchemas: any; + const outputPath = opt.output + ? path.isAbsolute(opt.output) + ? opt.output + : path.join(projectRoot, opt.output) + : path.join(projectRoot, 'node_modules', '.zenstack'); - const outputPath = path.join(projectRoot, 'node_modules'); + const policy = require(path.join(outputPath, 'policy')).default; + const modelMeta = require(path.join(outputPath, 'model-meta')).default; + let zodSchemas: any; try { - policy = require(path.join(outputPath, '.zenstack/policy')).default; - } catch { - /* noop */ - } - try { - modelMeta = require(path.join(outputPath, '.zenstack/model-meta')).default; - } catch { - /* noop */ - } - try { - zodSchemas = require(path.join(outputPath, '.zenstack/zod')); + zodSchemas = require(path.join(outputPath, 'zod')); } catch { /* noop */ } + const enhance = require(path.join(outputPath, 'enhance')).enhance; + return { projectDir: projectRoot, prisma, - Prisma, - withPolicy: (user?: AuthUser) => - withPolicy( + enhance: (user?: AuthUser, options?: EnhancementOptions): FullDbClientContract => + enhance( prisma, { user }, - { policy, modelMeta, zodSchemas, logPrismaQuery: opt.logPrismaQuery } - ), - withOmit: () => withOmit(prisma, { modelMeta }), - withPassword: () => withPassword(prisma, { modelMeta }), - enhance: (user?: AuthUser) => - enhance( - prisma, - { user }, - { policy, modelMeta, zodSchemas, logPrismaQuery: opt.logPrismaQuery } + { + policy, + modelMeta, + zodSchemas, + logPrismaQuery: opt.logPrismaQuery, + transactionTimeout: 1000000, + kinds: opt.enhancements, + ...(options ?? opt.enhanceOptions), + } ), + enhanceRaw: enhance, policy, modelMeta, zodSchemas, + prismaModule, }; } @@ -324,7 +318,12 @@ export async function loadZModelAndDmmf( const model = await loadDocument(modelFile); const { name: prismaFile } = tmp.fileSync({ postfix: '.prisma' }); - await prismaPlugin(model, { schemaPath: modelFile, name: 'Prisma', output: prismaFile, generateClient: false }); + await prismaPlugin(model, { + provider: '@core/plugin', + schemaPath: modelFile, + output: prismaFile, + generateClient: false, + }); const prismaContent = fs.readFileSync(prismaFile, { encoding: 'utf-8' }); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index d501a52ca..53fc4cf7b 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1,4 +1,4 @@ -lockfileVersion: '6.0' +lockfileVersion: '6.1' settings: autoInstallPeers: true @@ -69,12 +69,12 @@ importers: packages/language: dependencies: langium: - specifier: 1.2.0 - version: 1.2.0 + specifier: 1.3.1 + version: 1.3.1 devDependencies: langium-cli: - specifier: 1.2.0 - version: 1.2.0 + specifier: 1.3.1 + version: 1.3.1 plist2: specifier: ^1.1.3 version: 1.1.3 @@ -391,9 +391,6 @@ importers: packages/runtime: dependencies: - '@types/bcryptjs': - specifier: ^2.4.2 - version: 2.4.2 bcryptjs: specifier: ^2.4.3 version: 2.4.3 @@ -412,6 +409,9 @@ importers: deepcopy: specifier: ^2.1.0 version: 2.1.0 + deepmerge: + specifier: ^4.3.1 + version: 4.3.1 lower-case-first: specifier: ^2.0.2 version: 2.0.2 @@ -436,6 +436,9 @@ importers: uuid: specifier: ^9.0.0 version: 9.0.0 + z3-solver: + specifier: ^4.12.5 + version: 4.12.5 zod: specifier: ^3.22.4 version: 3.22.4 @@ -443,6 +446,9 @@ importers: specifier: ^1.5.0 version: 1.5.0(zod@3.22.4) devDependencies: + '@types/bcryptjs': + specifier: ^2.4.2 + version: 2.4.2 '@types/pluralize': specifier: ^0.0.29 version: 0.0.29 @@ -484,8 +490,8 @@ importers: specifier: ^5.0.1 version: 5.0.1 langium: - specifier: 1.2.0 - version: 1.2.0 + specifier: 1.3.1 + version: 1.3.1 lower-case-first: specifier: ^2.0.2 version: 2.0.2 @@ -551,8 +557,8 @@ importers: version: 1.5.0(zod@3.22.4) devDependencies: '@prisma/client': - specifier: ^4.8.0 - version: 4.16.2(prisma@4.16.2) + specifier: ^5.7.1 + version: 5.7.1(prisma@5.7.1) '@types/async-exit-hook': specifier: ^2.0.0 version: 2.0.0 @@ -587,8 +593,8 @@ importers: specifier: ^0.15.12 version: 0.15.12 prisma: - specifier: ^4.8.0 - version: 4.16.2 + specifier: ^5.7.1 + version: 5.7.1 renamer: specifier: ^4.0.0 version: 4.0.0 @@ -620,6 +626,9 @@ importers: '@zenstackhq/runtime': specifier: workspace:* version: link:../runtime/dist + langium: + specifier: 1.3.1 + version: 1.3.1 lower-case-first: specifier: ^2.0.2 version: 2.0.2 @@ -632,6 +641,9 @@ importers: ts-morph: specifier: ^16.0.0 version: 16.0.0 + ts-pattern: + specifier: ^4.3.0 + version: 4.3.0 upper-case-first: specifier: ^2.0.2 version: 2.0.2 @@ -739,8 +751,8 @@ importers: specifier: ^2.2.3 version: 2.2.3 langium: - specifier: 1.2.0 - version: 1.2.0 + specifier: 1.3.1 + version: 1.3.1 pg: specifier: ^8.11.1 version: 8.11.1 @@ -877,7 +889,7 @@ packages: resolution: {integrity: sha512-Xmwn266vad+6DAqEB2A6V/CcZVp62BbwVmcOJc2RPuwih1kw02TjQvWVWlcKGbBPd+8/0V5DEkOcizRGYsspYQ==} engines: {node: '>=6.9.0'} dependencies: - '@babel/highlight': 7.22.5 + '@babel/highlight': 7.22.20 dev: true /@babel/compat-data@7.22.9: @@ -1086,15 +1098,6 @@ packages: chalk: 2.4.2 js-tokens: 4.0.0 - /@babel/highlight@7.22.5: - resolution: {integrity: sha512-BSKlD1hgnedS5XRnGOljZawtag7H1yPfQp0tdNJCHoH6AZ+Pcm9VvkrK59/Yy593Ypg0zMxH2BxD1VPYUQ7UIw==} - engines: {node: '>=6.9.0'} - dependencies: - '@babel/helper-validator-identifier': 7.22.20 - chalk: 2.4.2 - js-tokens: 4.0.0 - dev: true - /@babel/parser@7.23.0: resolution: {integrity: sha512-vvPKKdMemU85V9WE/l5wZEmImpCtLqbnTvqDS2U1fJ96KrxoW7KrXhNsNCblQlg8Ck4b85yxdTyelsMUgFUXiw==} engines: {node: '>=6.0.0'} @@ -3486,22 +3489,19 @@ packages: resolution: {integrity: sha512-a5Sab1C4/icpTZVzZc5Ghpz88yQtGOyNqYXcZgOssB2uuAr+wF/MvN6bgtW32q7HHrvBki+BsZ0OuNv6EV3K9g==} dev: true - /@prisma/client@4.16.2(prisma@4.16.2): - resolution: {integrity: sha512-qCoEyxv1ZrQ4bKy39GnylE8Zq31IRmm8bNhNbZx7bF2cU5aiCCnSa93J2imF88MBjn7J9eUQneNxUQVJdl/rPQ==} - engines: {node: '>=14.17'} + /@prisma/client@5.7.0: + resolution: {integrity: sha512-cZmglCrfNbYpzUtz7HscVHl38e9CrUs31nrVoGUK1nIPXGgt8hT4jj2s657UXcNdQ/jBUxDgGmHyu2Nyrq1txg==} + engines: {node: '>=16.13'} requiresBuild: true peerDependencies: prisma: '*' peerDependenciesMeta: prisma: optional: true - dependencies: - '@prisma/engines-version': 4.16.1-1.4bc8b6e1b66cb932731fb1bdbbc550d1e010de81 - prisma: 4.16.2 dev: true - /@prisma/client@5.7.0: - resolution: {integrity: sha512-cZmglCrfNbYpzUtz7HscVHl38e9CrUs31nrVoGUK1nIPXGgt8hT4jj2s657UXcNdQ/jBUxDgGmHyu2Nyrq1txg==} + /@prisma/client@5.7.1(prisma@5.7.1): + resolution: {integrity: sha512-TUSa4nUcC4nf/e7X3jyO1pEd6XcI/TLRCA0KjkA46RDIpxUaRsBYEOqITwXRW2c0bMFyKcCRXrH4f7h4q9oOlg==} engines: {node: '>=16.13'} requiresBuild: true peerDependencies: @@ -3509,6 +3509,8 @@ packages: peerDependenciesMeta: prisma: optional: true + dependencies: + prisma: 5.7.1 dev: true /@prisma/debug@4.16.2: @@ -3535,17 +3537,22 @@ packages: resolution: {integrity: sha512-tZ+MOjWlVvz1kOEhNYMa4QUGURY+kgOUBqLHYIV8jmCsMuvA1tWcn7qtIMLzYWCbDcQT4ZS8xDgK0R2gl6/0wA==} dev: false - /@prisma/engines-version@4.16.1-1.4bc8b6e1b66cb932731fb1bdbbc550d1e010de81: - resolution: {integrity: sha512-q617EUWfRIDTriWADZ4YiWRZXCa/WuhNgLTVd+HqWLffjMSPzyM5uOWoauX91wvQClSKZU4pzI4JJLQ9Kl62Qg==} + /@prisma/debug@5.7.1: + resolution: {integrity: sha512-yrVSO/YZOxdeIxcBtZ5BaNqUfPrZkNsAKQIQg36cJKMxj/VYK3Vk5jMKkI+gQLl0KReo1YvX8GWKfV788SELjw==} dev: true /@prisma/engines-version@5.7.0-41.79fb5193cf0a8fdbef536e4b4a159cad677ab1b9: resolution: {integrity: sha512-V6tgRVi62jRwTm0Hglky3Scwjr/AKFBFtS+MdbsBr7UOuiu1TKLPc6xfPiyEN1+bYqjEtjxwGsHgahcJsd1rNg==} dev: false + /@prisma/engines-version@5.7.1-1.0ca5ccbcfa6bdc81c003cf549abe4269f59c41e5: + resolution: {integrity: sha512-dIR5IQK/ZxEoWRBDOHF87r1Jy+m2ih3Joi4vzJRP+FOj5yxCwS2pS5SBR3TWoVnEK1zxtLI/3N7BjHyGF84fgw==} + dev: true + /@prisma/engines@4.16.2: resolution: {integrity: sha512-vx1nxVvN4QeT/cepQce68deh/Turxy5Mr+4L4zClFuK1GlxN3+ivxfuv+ej/gvidWn1cE1uAhW7ALLNlYbRUAw==} requiresBuild: true + dev: false /@prisma/engines@5.0.0: resolution: {integrity: sha512-kyT/8fd0OpWmhAU5YnY7eP31brW1q1YrTGoblWrhQJDiN/1K+Z8S1kylcmtjqx5wsUGcP1HBWutayA/jtyt+sg==} @@ -3562,6 +3569,16 @@ packages: '@prisma/get-platform': 5.7.0 dev: false + /@prisma/engines@5.7.1: + resolution: {integrity: sha512-R+Pqbra8tpLP2cvyiUpx+SIKglav3nTCpA+rn6826CThviQ8yvbNG0s8jNpo51vS9FuZO3pOkARqG062vKX7uA==} + requiresBuild: true + dependencies: + '@prisma/debug': 5.7.1 + '@prisma/engines-version': 5.7.1-1.0ca5ccbcfa6bdc81c003cf549abe4269f59c41e5 + '@prisma/fetch-engine': 5.7.1 + '@prisma/get-platform': 5.7.1 + dev: true + /@prisma/fetch-engine@4.16.2: resolution: {integrity: sha512-lnCnHcOaNn0kw8qTJbVcNhyfIf5Lus2GFXbj3qpkdKEIB9xLgqkkuTP+35q1xFaqwQ0vy4HFpdRUpFP7njE15g==} dependencies: @@ -3620,6 +3637,14 @@ packages: '@prisma/get-platform': 5.7.0 dev: false + /@prisma/fetch-engine@5.7.1: + resolution: {integrity: sha512-9ELauIEBkIaEUpMIYPRlh5QELfoC6pyHolHVQgbNxglaINikZ9w9X7r1TIePAcm05pCNp2XPY1ObQIJW5nYfBQ==} + dependencies: + '@prisma/debug': 5.7.1 + '@prisma/engines-version': 5.7.1-1.0ca5ccbcfa6bdc81c003cf549abe4269f59c41e5 + '@prisma/get-platform': 5.7.1 + dev: true + /@prisma/generator-helper@4.16.2: resolution: {integrity: sha512-bMOH7y73Ui7gpQrioFeavMQA+Tf8ksaVf8Nhs9rQNzuSg8SSV6E9baczob0L5KGZTSgYoqnrRxuo03kVJYrnIg==} dependencies: @@ -3688,6 +3713,12 @@ packages: '@prisma/debug': 5.7.0 dev: false + /@prisma/get-platform@5.7.1: + resolution: {integrity: sha512-eDlswr3a1m5z9D/55Iyt/nZqS5UpD+DZ9MooBB3hvrcPhDQrcf9m4Tl7buy4mvAtrubQ626ECtb8c6L/f7rGSQ==} + dependencies: + '@prisma/debug': 5.7.1 + dev: true + /@prisma/internals@4.16.2: resolution: {integrity: sha512-/3OiSADA3RRgsaeEE+MDsBgL6oAMwddSheXn6wtYGUnjERAV/BmF5bMMLnTykesQqwZ1s8HrISrJ0Vf6cjOxMg==} dependencies: @@ -4530,7 +4561,7 @@ packages: /@ts-morph/common@0.17.0: resolution: {integrity: sha512-RMSSvSfs9kb0VzkvQ2NWobwnj7TxCA9vI/IjR9bDHqgAyVbu2T0DN4wiKVqomyDWqO7dPr/tErSfq7urQ1Q37g==} dependencies: - fast-glob: 3.3.1 + fast-glob: 3.3.2 minimatch: 5.1.6 mkdirp: 1.0.4 path-browserify: 1.0.1 @@ -4591,6 +4622,7 @@ packages: /@types/bcryptjs@2.4.2: resolution: {integrity: sha512-LiMQ6EOPob/4yUL66SZzu6Yh77cbzJFYll+ZfaPiPPFswtIlA/Fs1MzdKYA7JApHU49zQTbJGX3PDmCpIdDBRQ==} + dev: true /@types/body-parser@1.19.2: resolution: {integrity: sha512-ALYone6pm6QmwZoAgeyNksccT9Q4AWZQ6PvfwR37GT6r6FWUPguq6sUmNGSMV2Wr761oQoBxwGGa6DR5o1DC9g==} @@ -5789,6 +5821,12 @@ packages: engines: {node: '>=0.12.0'} dev: false + /async-mutex@0.3.2: + resolution: {integrity: sha512-HuTK7E7MT7jZEh1P9GtRW9+aTWiDWWi9InbZ5hjxrnRa39KS4BW04+xLBhYNS2aXhHUIKZSw3gj4Pn1pj+qGAA==} + dependencies: + tslib: 2.6.0 + dev: false + /async-sema@3.1.1: resolution: {integrity: sha512-tLRNUXati5MFePdAk8dw7Qt7DpxPB60ofAgn8WRhW6a2rcimZnYBP9oxHiv0OHy+Wz7kPMG+t4LGdt31+4EmGg==} dev: true @@ -8137,17 +8175,6 @@ packages: resolution: {integrity: sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ==} dev: true - /fast-glob@3.3.1: - resolution: {integrity: sha512-kNFPyjhh5cKjrUltxs+wFx+ZkbRaxxmZ+X0ZU31SOsxCEtP9VPgtq2teZw1DebupL5GmDaNQ6yKMMVcM41iqDg==} - engines: {node: '>=8.6.0'} - dependencies: - '@nodelib/fs.stat': 2.0.5 - '@nodelib/fs.walk': 1.2.8 - glob-parent: 5.1.2 - merge2: 1.4.1 - micromatch: 4.0.5 - dev: false - /fast-glob@3.3.2: resolution: {integrity: sha512-oX2ruAFQwf/Orj8m737Y5adxDQO0LAB7/S5MnxCdTNDd4p6BsyIVsv9JQsATbTSq8KHRpLwIHbVlUNatxd+1Ow==} engines: {node: '>=8.6.0'} @@ -10232,8 +10259,8 @@ packages: resolution: {integrity: sha512-dWl0Dbjm6Xm+kDxhPQJsCBTxrJzuGl0aP9rhr+TG8D3l+GL90N8O8lYUi7dTSAN2uuDqCtNgb6aEuQH5wsiV8Q==} dev: true - /langium-cli@1.2.0: - resolution: {integrity: sha512-DPyJUd4Hj8+OBNEcAQyJtW6e38+UPd758gTI7Ep0r/sDogrwJ/GJHx5nGA+r0ygpNcDPG+mS9Hw8Y05uCNNcoQ==} + /langium-cli@1.3.1: + resolution: {integrity: sha512-9faKpioKCjBD0Z4y165+wQlDFiDHOXYBlhPVgbV+neSnSB70belZLNfykAVa564360h7Br/5PogR5jW2n/tOKw==} engines: {node: '>=14.0.0'} hasBin: true dependencies: @@ -10241,12 +10268,20 @@ packages: commander: 10.0.1 fs-extra: 11.1.1 jsonschema: 1.4.1 - langium: 1.2.0 + langium: 1.3.1 + langium-railroad: 1.3.0 lodash: 4.17.21 dev: true - /langium@1.2.0: - resolution: {integrity: sha512-jFSptpFljYo9ZTHrq/GZflMUXiKo5KBNtsaIJtnIzDm9zC2FxsxejEFAtNL09262RVQt+zFeF/2iLAShFTGitw==} + /langium-railroad@1.3.0: + resolution: {integrity: sha512-I3gx79iF+Qpn2UjzfHLf2GENAD9mPdSZHL3juAZLBsxznw4se7MBrJX32oPr/35DTjU9q99wFCQoCXu7mcf+Bg==} + dependencies: + langium: 1.3.1 + railroad-diagrams: 1.0.0 + dev: true + + /langium@1.3.1: + resolution: {integrity: sha512-xC+DnAunl6cZIgYjRpgm3s1kYAB5/Wycsj24iYaXG9uai7SgvMaFZSrRvdA5rUK/lSta/CRvgF+ZFoEKEOFJ5w==} engines: {node: '>=14.0.0'} dependencies: chevrotain: 10.4.2 @@ -12468,13 +12503,13 @@ packages: hasBin: true dev: true - /prisma@4.16.2: - resolution: {integrity: sha512-SYCsBvDf0/7XSJyf2cHTLjLeTLVXYfqp7pG5eEVafFLeT0u/hLFz/9W196nDRGUOo1JfPatAEb+uEnTQImQC1g==} - engines: {node: '>=14.17'} + /prisma@5.7.1: + resolution: {integrity: sha512-ekho7ziH0WEJvC4AxuJz+ewRTMskrebPcrKuBwcNzVDniYxx+dXOGcorNeIb9VEMO5vrKzwNYvhD271Ui2jnNw==} + engines: {node: '>=16.13'} hasBin: true requiresBuild: true dependencies: - '@prisma/engines': 4.16.2 + '@prisma/engines': 5.7.1 dev: true /process-nextick-args@2.0.1: @@ -12619,6 +12654,10 @@ packages: resolution: {integrity: sha512-pNsHDxbGORSvuSScqNJ+3Km6QAVqk8CfsCBIEoDgpqLrkD2f3QM4I7d1ozJJ172OmIcoUcerZaNWqtLkRXTV3A==} dev: true + /railroad-diagrams@1.0.0: + resolution: {integrity: sha512-cz93DjNeLY0idrCNOH6PviZGRN9GJhsdm9hpn1YCS879fj4W+x5IFJhhkRZcwVgMmFF7R82UA/7Oh+R8lLZg6A==} + dev: true + /randombytes@2.1.0: resolution: {integrity: sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ==} dependencies: @@ -15456,6 +15495,13 @@ packages: engines: {node: '>=12.20'} dev: true + /z3-solver@4.12.5: + resolution: {integrity: sha512-uh6zoe+ErxG2qLjWyOcZE41eb6CHUZ3IT7VYTh1SDswPaoe7I1mvC7ujA36TROdBdPm59UUFToDmRbfiyjdA1Q==} + engines: {node: '>=16'} + dependencies: + async-mutex: 0.3.2 + dev: false + /zhead@2.1.3: resolution: {integrity: sha512-T6kZx8TYdLhuy2vURjPUj9EK9Dobnctu12CYw9ibu6Xj/UAqh2q2bQaA3vFrL4Rna5+CXYHYN3uJrUu6VulYzw==} dev: true diff --git a/script/test-prisma-v5.sh b/script/test-prisma-v5.sh deleted file mode 100755 index 51fc8e3cb..000000000 --- a/script/test-prisma-v5.sh +++ /dev/null @@ -1,3 +0,0 @@ -echo Setting Prisma Versions to V5 -npx replace-in-file '/"prisma":\s*"\^4.\d.\d"/g' '"prisma": "^5.0.0"' 'packages/testtools/**/package*.json' 'tests/integration/**/package*.json' --isRegex -npx replace-in-file '/"@prisma/client":\s*"\^4.\d.\d"/g' '"@prisma/client": "^5.0.0"' 'packages/testtools/**/package*.json' 'tests/integration/**/package*.json' --isRegex \ No newline at end of file diff --git a/tests/integration/package.json b/tests/integration/package.json index 40627f354..cace90307 100644 --- a/tests/integration/package.json +++ b/tests/integration/package.json @@ -5,7 +5,7 @@ "main": "index.js", "scripts": { "lint": "eslint . --ext .ts", - "test": "ZENSTACK_TEST=1 jest" + "test": "ZENSTACK_TEST=1 jest --runInBand" }, "keywords": [], "author": "", diff --git a/tests/integration/test-run/package.json b/tests/integration/test-run/package.json index d4e05bd29..4748eba78 100644 --- a/tests/integration/test-run/package.json +++ b/tests/integration/test-run/package.json @@ -10,13 +10,14 @@ "author": "", "license": "ISC", "dependencies": { - "@prisma/client": "^4.8.0", + "@prisma/client": "^5.0.0", "@zenstackhq/runtime": "file:../../../packages/runtime/dist", - "prisma": "^4.8.0", + "prisma": "^5.0.0", "react": "^18.2.0", "swr": "^1.3.0", "typescript": "^4.9.3", "zenstack": "file:../../../packages/schema/dist", - "zod": "^3.22.4" + "zod": "^3.22.4", + "z3-solver": "^4.12.5" } } diff --git a/tests/integration/tests/cli/config.test.ts b/tests/integration/tests/cli/config.test.ts deleted file mode 100644 index f047889fd..000000000 --- a/tests/integration/tests/cli/config.test.ts +++ /dev/null @@ -1,65 +0,0 @@ -/* eslint-disable @typescript-eslint/no-var-requires */ -/// - -import * as fs from 'fs'; -import * as tmp from 'tmp'; -import { createProgram } from '../../../../packages/schema/src/cli'; - -describe('CLI Config Tests', () => { - let origDir: string; - - beforeEach(() => { - origDir = process.cwd(); - const r = tmp.dirSync({ unsafeCleanup: true }); - console.log(`Project dir: ${r.name}`); - process.chdir(r.name); - - fs.writeFileSync('package.json', JSON.stringify({ name: 'my app', version: '1.0.0' })); - }); - - afterEach(() => { - process.chdir(origDir); - }); - - // for ensuring backward compatibility only - it('valid default config empty', async () => { - fs.writeFileSync('zenstack.config.json', JSON.stringify({})); - const program = createProgram(); - await program.parseAsync(['init', '--tag', 'latest'], { from: 'user' }); - }); - - // for ensuring backward compatibility only - it('valid default config non-empty', async () => { - fs.writeFileSync( - 'zenstack.config.json', - JSON.stringify({ guardFieldName: 'myGuardField', transactionFieldName: 'myTransactionField' }) - ); - - const program = createProgram(); - await program.parseAsync(['init', '--tag', 'latest'], { from: 'user' }); - }); - - it('custom config file does not exist', async () => { - const program = createProgram(); - const configFile = `my.config.json`; - await expect( - program.parseAsync(['init', '--tag', 'latest', '--config', configFile], { from: 'user' }) - ).rejects.toThrow(/Config file could not be found/i); - }); - - it('custom config file is not json', async () => { - const program = createProgram(); - const configFile = `my.config.json`; - fs.writeFileSync(configFile, ` 😬 😬 😬`); - await expect( - program.parseAsync(['init', '--tag', 'latest', '--config', configFile], { from: 'user' }) - ).rejects.toThrow(/Config is not a valid JSON file/i); - }); - - // for ensuring backward compatibility only - it('valid custom config file', async () => { - fs.writeFileSync('my.config.json', JSON.stringify({ guardFieldName: 'myGuardField' })); - const program = createProgram(); - await program.parseAsync(['init', '--tag', 'latest', '--config', 'my.config.json'], { from: 'user' }); - }); -}); diff --git a/tests/integration/tests/cli/generate.test.ts b/tests/integration/tests/cli/generate.test.ts index 0367033bd..544ae501a 100644 --- a/tests/integration/tests/cli/generate.test.ts +++ b/tests/integration/tests/cli/generate.test.ts @@ -86,26 +86,6 @@ model Post { expect(fs.existsSync('./out/zod')).toBeTruthy(); }); - it('generate custom output override', async () => { - fs.appendFileSync( - 'schema.zmodel', - ` - plugin policy { - provider = '@core/access-policy' - output = 'policy-out' - } - ` - ); - - const program = createProgram(); - await program.parseAsync(['generate', '--no-dependency-check', '-o', 'out'], { from: 'user' }); - expect(fs.existsSync('./node_modules/.zenstack')).toBeFalsy(); - expect(fs.existsSync('./out/model-meta.js')).toBeTruthy(); - expect(fs.existsSync('./out/zod')).toBeTruthy(); - expect(fs.existsSync('./out/policy.js')).toBeFalsy(); - expect(fs.existsSync('./policy-out/policy.js')).toBeTruthy(); - }); - it('generate no default plugins run nothing', async () => { const program = createProgram(); await program.parseAsync(['generate', '--no-dependency-check', '--no-default-plugins'], { from: 'user' }); @@ -136,8 +116,8 @@ model Post { fs.appendFileSync( 'schema.zmodel', ` - plugin policy { - provider = '@core/access-policy' + plugin enhancer { + provider = '@core/enhancer' } ` ); @@ -153,8 +133,8 @@ model Post { fs.appendFileSync( 'schema.zmodel', ` - plugin policy { - provider = '@core/access-policy' + plugin enhancer { + provider = '@core/enhancer' } ` ); diff --git a/tests/integration/tests/cli/init.test.ts b/tests/integration/tests/cli/init.test.ts index 96492b286..2dc9bdbf6 100644 --- a/tests/integration/tests/cli/init.test.ts +++ b/tests/integration/tests/cli/init.test.ts @@ -9,7 +9,9 @@ import { createProgram } from '../../../../packages/schema/src/cli'; import { execSync } from '../../../../packages/schema/src/utils/exec-utils'; import { createNpmrc } from './share'; -describe('CLI init command tests', () => { +// Skipping these tests as they seem to cause hangs intermittently when running with other tests +// eslint-disable-next-line jest/no-disabled-tests +describe.skip('CLI init command tests', () => { let origDir: string; beforeEach(() => { @@ -23,10 +25,14 @@ describe('CLI init command tests', () => { process.chdir(origDir); }); + // eslint-disable-next-line jest/no-disabled-tests it('init project t3 npm std', async () => { - execSync('npx --yes create-t3-app@latest --prisma --CI --noGit .', 'inherit', { - npm_config_user_agent: 'npm', - npm_config_cache: getWorkspaceNpmCacheFolder(__dirname), + execSync('npx --yes create-t3-app@latest --prisma --CI --noGit .', { + stdio: 'inherit', + env: { + npm_config_user_agent: 'npm', + npm_config_cache: getWorkspaceNpmCacheFolder(__dirname), + }, }); createNpmrc(); @@ -39,12 +45,13 @@ describe('CLI init command tests', () => { checkDependency('@zenstackhq/runtime', false, true); }); - // Disabled because it blows up memory on MAC, not sure why ... - // eslint-disable-next-line jest/no-disabled-tests - it.skip('init project t3 yarn std', async () => { - execSync('npx --yes create-t3-app@latest --prisma --CI --noGit .', 'inherit', { - npm_config_user_agent: 'yarn', - npm_config_cache: getWorkspaceNpmCacheFolder(__dirname), + it('init project t3 yarn std', async () => { + execSync('npx --yes create-t3-app@latest --prisma --CI --noGit .', { + stdio: 'inherit', + env: { + npm_config_user_agent: 'yarn', + npm_config_cache: getWorkspaceNpmCacheFolder(__dirname), + }, }); createNpmrc(); @@ -58,9 +65,12 @@ describe('CLI init command tests', () => { }); it('init project t3 pnpm std', async () => { - execSync('npx --yes create-t3-app@latest --prisma --CI --noGit .', 'inherit', { - npm_config_user_agent: 'pnpm', - npm_config_cache: getWorkspaceNpmCacheFolder(__dirname), + execSync('npx --yes create-t3-app@latest --prisma --CI --noGit .', { + stdio: 'inherit', + env: { + npm_config_user_agent: 'pnpm', + npm_config_cache: getWorkspaceNpmCacheFolder(__dirname), + }, }); createNpmrc(); @@ -74,9 +84,12 @@ describe('CLI init command tests', () => { }); it('init project t3 non-std prisma schema', async () => { - execSync('npx --yes create-t3-app@latest --prisma --CI --noGit .', 'inherit', { - npm_config_user_agent: 'npm', - npm_config_cache: getWorkspaceNpmCacheFolder(__dirname), + execSync('npx --yes create-t3-app@latest --prisma --CI --noGit .', { + stdio: 'inherit', + env: { + npm_config_user_agent: 'npm', + npm_config_cache: getWorkspaceNpmCacheFolder(__dirname), + }, }); createNpmrc(); fs.renameSync('prisma/schema.prisma', 'prisma/my.prisma'); diff --git a/tests/integration/tests/cli/plugins.test.ts b/tests/integration/tests/cli/plugins.test.ts index 005a0f69b..19dfb4dce 100644 --- a/tests/integration/tests/cli/plugins.test.ts +++ b/tests/integration/tests/cli/plugins.test.ts @@ -129,14 +129,8 @@ describe('CLI Plugins Tests', () => { output = 'prisma/my.prisma' generateClient = true }`, - `plugin meta { - provider = '@core/model-meta' - output = 'model-meta' - } - `, - `plugin policy { - provider = '@core/access-policy' - output = 'policy' + `plugin enhancer { + provider = '@core/enhancer' }`, `plugin tanstack { provider = '@zenstackhq/tanstack-query' diff --git a/tests/integration/tests/enhancements/with-delegate/policy.test.ts b/tests/integration/tests/enhancements/with-delegate/policy.test.ts new file mode 100644 index 000000000..d0316595d --- /dev/null +++ b/tests/integration/tests/enhancements/with-delegate/policy.test.ts @@ -0,0 +1,217 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Polymorphic Policy Test', () => { + it('simple boolean', async () => { + const booleanCondition = ` + model User { + id Int @id @default(autoincrement()) + level Int @default(0) + assets Asset[] + banned Boolean @default(false) + + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + published Boolean @default(false) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + assetType String + viewCount Int @default(0) + + @@delegate(assetType) + @@allow('create', viewCount >= 0) + @@deny('read', !published) + @@allow('read', true) + @@deny('all', owner.banned) + } + + model Video extends Asset { + watched Boolean @default(false) + videoType String + + @@delegate(videoType) + @@deny('read', !watched) + @@allow('read', true) + } + + model RatedVideo extends Video { + rated Boolean @default(false) + @@deny('read', !rated) + @@allow('read', true) + } + `; + + const booleanExpression = ` + model User { + id Int @id @default(autoincrement()) + level Int @default(0) + assets Asset[] + banned Boolean @default(false) + + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + published Boolean @default(false) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + assetType String + viewCount Int @default(0) + + @@delegate(assetType) + @@allow('create', viewCount >= 0) + @@deny('read', published == false) + @@allow('read', true) + @@deny('all', owner.banned == true) + } + + model Video extends Asset { + watched Boolean @default(false) + videoType String + + @@delegate(videoType) + @@deny('read', watched == false) + @@allow('read', true) + } + + model RatedVideo extends Video { + rated Boolean @default(false) + @@deny('read', rated == false) + @@allow('read', true) + } + `; + + for (const schema of [booleanCondition, booleanExpression]) { + const { enhanceRaw: enhance, prisma } = await loadSchema(schema); + + const fullDb = enhance(prisma, undefined, { kinds: ['delegate'], logPrismaQuery: true }); + + const user = await fullDb.user.create({ data: { id: 1 } }); + const userDb = enhance( + prisma, + { user: { id: user.id } }, + { kinds: ['delegate', 'policy'], logPrismaQuery: true } + ); + + // violating Asset create + await expect( + userDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: -1 }, + }) + ).toBeRejectedByPolicy(); + + let video = await fullDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } } }, + }); + // violating all three layer read + await expect(userDb.asset.findUnique({ where: { id: video.id } })).toResolveNull(); + await expect(userDb.video.findUnique({ where: { id: video.id } })).toResolveNull(); + await expect(userDb.ratedVideo.findUnique({ where: { id: video.id } })).toResolveNull(); + + video = await fullDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, published: true }, + }); + // violating Video && RatedVideo read + await expect(userDb.asset.findUnique({ where: { id: video.id } })).toResolveTruthy(); + await expect(userDb.video.findUnique({ where: { id: video.id } })).toResolveNull(); + await expect(userDb.ratedVideo.findUnique({ where: { id: video.id } })).toResolveNull(); + + video = await fullDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, published: true, watched: true }, + }); + // violating RatedVideo read + await expect(userDb.asset.findUnique({ where: { id: video.id } })).toResolveTruthy(); + await expect(userDb.video.findUnique({ where: { id: video.id } })).toResolveTruthy(); + await expect(userDb.ratedVideo.findUnique({ where: { id: video.id } })).toResolveNull(); + + video = await fullDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, rated: true, watched: true, published: true }, + }); + // meeting all read conditions + await expect(userDb.asset.findUnique({ where: { id: video.id } })).toResolveTruthy(); + await expect(userDb.video.findUnique({ where: { id: video.id } })).toResolveTruthy(); + await expect(userDb.ratedVideo.findUnique({ where: { id: video.id } })).toResolveTruthy(); + + // ban the user + await prisma.user.update({ where: { id: user.id }, data: { banned: true } }); + + // banned user can't read + await expect(userDb.asset.findUnique({ where: { id: video.id } })).toResolveNull(); + await expect(userDb.video.findUnique({ where: { id: video.id } })).toResolveNull(); + await expect(userDb.ratedVideo.findUnique({ where: { id: video.id } })).toResolveNull(); + + // banned user can't create + await expect( + userDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } } }, + }) + ).toBeRejectedByPolicy(); + } + }); + + it('interaction with updateMany/deleteMany', async () => { + const schema = ` + model User { + id Int @id @default(autoincrement()) + level Int @default(0) + assets Asset[] + banned Boolean @default(false) + + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + published Boolean @default(false) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + assetType String + viewCount Int @default(0) + version Int @default(0) + + @@delegate(assetType) + @@deny('update', viewCount > 0) + @@deny('delete', viewCount > 0) + @@allow('all', true) + } + + model Video extends Asset { + watched Boolean @default(false) + + @@deny('update', watched) + @@deny('delete', watched) + } + `; + + const { enhance } = await loadSchema(schema, { + logPrismaQuery: true, + }); + const db = enhance(); + + const user = await db.user.create({ data: { id: 1 } }); + const vid1 = await db.video.create({ + data: { watched: false, viewCount: 0, owner: { connect: { id: user.id } } }, + }); + const vid2 = await db.video.create({ + data: { watched: true, viewCount: 1, owner: { connect: { id: user.id } } }, + }); + + await expect(db.asset.updateMany({ data: { version: { increment: 1 } } })).resolves.toMatchObject({ + count: 1, + }); + await expect(db.asset.findUnique({ where: { id: vid1.id } })).resolves.toMatchObject({ version: 1 }); + await expect(db.asset.findUnique({ where: { id: vid2.id } })).resolves.toMatchObject({ version: 0 }); + + await expect(db.asset.deleteMany()).resolves.toMatchObject({ + count: 1, + }); + await expect(db.asset.findUnique({ where: { id: vid1.id } })).toResolveNull(); + await expect(db.asset.findUnique({ where: { id: vid2.id } })).toResolveTruthy(); + }); +}); diff --git a/tests/integration/tests/enhancements/with-delegate/polymorphism.test.ts b/tests/integration/tests/enhancements/with-delegate/polymorphism.test.ts new file mode 100644 index 000000000..0d0b24ca2 --- /dev/null +++ b/tests/integration/tests/enhancements/with-delegate/polymorphism.test.ts @@ -0,0 +1,1015 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import { PrismaErrorCode } from '@zenstackhq/runtime'; + +describe('Polymorphism Test', () => { + const schema = ` +model User { + id Int @id @default(autoincrement()) + level Int @default(0) + assets Asset[] + ratedVideos RatedVideo[] @relation('direct') + + @@allow('all', true) +} + +model Asset { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner User? @relation(fields: [ownerId], references: [id]) + ownerId Int? + assetType String + + @@delegate(assetType) + @@allow('all', true) +} + +model Video extends Asset { + duration Int + url String + videoType String + + @@delegate(videoType) +} + +model RatedVideo extends Video { + rating Int + user User? @relation(name: 'direct', fields: [userId], references: [id]) + userId Int? +} + +model Image extends Asset { + format String + gallery Gallery? @relation(fields: [galleryId], references: [id]) + galleryId Int? +} + +model Gallery { + id Int @id @default(autoincrement()) + images Image[] +} +`; + + async function setup() { + const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const db = enhance(); + + const user = await db.user.create({ data: { id: 1 } }); + + const video = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + + const videoWithOwner = await db.ratedVideo.findUnique({ where: { id: video.id }, include: { owner: true } }); + + return { db, video, user, videoWithOwner }; + } + + it('create hierarchy', async () => { + const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const db = enhance(); + + const user = await db.user.create({ data: { id: 1 } }); + + const video = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + include: { owner: true }, + }); + + expect(video).toMatchObject({ + viewCount: 1, + duration: 100, + url: 'xyz', + rating: 100, + assetType: 'Video', + videoType: 'RatedVideo', + owner: user, + }); + + await expect(db.asset.create({ data: { type: 'Video' } })).rejects.toThrow('is a delegate'); + await expect(db.video.create({ data: { type: 'RatedVideo' } })).rejects.toThrow('is a delegate'); + + const image = await db.image.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, format: 'png' }, + include: { owner: true }, + }); + expect(image).toMatchObject({ + viewCount: 1, + format: 'png', + assetType: 'Image', + owner: user, + }); + + // create in a nested payload + const gallery = await db.gallery.create({ + data: { + images: { + create: [ + { owner: { connect: { id: user.id } }, format: 'png', viewCount: 1 }, + { owner: { connect: { id: user.id } }, format: 'jpg', viewCount: 2 }, + ], + }, + }, + include: { images: { include: { owner: true } } }, + }); + expect(gallery.images).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + format: 'png', + assetType: 'Image', + viewCount: 1, + owner: user, + }), + expect.objectContaining({ + format: 'jpg', + assetType: 'Image', + viewCount: 2, + owner: user, + }), + ]) + ); + }); + + it('create with base all defaults', async () => { + const { enhance } = await loadSchema( + ` + model Base { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + type String + + @@delegate(type) + } + + model Foo extends Base { + name String + } + `, + { logPrismaQuery: true, enhancements: ['delegate'] } + ); + + const db = enhance(); + const r = await db.foo.create({ data: { name: 'foo' } }); + expect(r).toMatchObject({ name: 'foo', type: 'Foo', id: expect.any(Number), createdAt: expect.any(Date) }); + }); + + it('create with nesting', async () => { + const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const db = enhance(); + + // nested create a relation from base + await expect( + db.ratedVideo.create({ + data: { owner: { create: { id: 2 } }, url: 'xyz', rating: 200, duration: 200 }, + include: { owner: true }, + }) + ).resolves.toMatchObject({ owner: { id: 2 } }); + }); + + it('read with concrete', async () => { + const { db, user, video } = await setup(); + + // find with include + let found = await db.ratedVideo.findFirst({ include: { owner: true } }); + expect(found).toMatchObject(video); + expect(found.owner).toMatchObject(user); + + // find with select + found = await db.ratedVideo.findFirst({ select: { id: true, createdAt: true, url: true, rating: true } }); + expect(found).toMatchObject({ id: video.id, createdAt: video.createdAt, url: video.url, rating: video.rating }); + + // findFirstOrThrow + found = await db.ratedVideo.findFirstOrThrow(); + expect(found).toMatchObject(video); + await expect( + db.ratedVideo.findFirstOrThrow({ + where: { id: video.id + 1 }, + }) + ).rejects.toThrow(); + + // findUnique + found = await db.ratedVideo.findUnique({ + where: { id: video.id }, + }); + expect(found).toMatchObject(video); + + // findUniqueOrThrow + found = await db.ratedVideo.findUniqueOrThrow({ + where: { id: video.id }, + }); + expect(found).toMatchObject(video); + await expect( + db.ratedVideo.findUniqueOrThrow({ + where: { id: video.id + 1 }, + }) + ).rejects.toThrow(); + + // findMany + let items = await db.ratedVideo.findMany(); + expect(items).toHaveLength(1); + expect(items[0]).toMatchObject(video); + + // findMany not found + items = await db.ratedVideo.findMany({ where: { id: video.id + 1 } }); + expect(items).toHaveLength(0); + + // findMany with select + items = await db.ratedVideo.findMany({ select: { id: true, createdAt: true, url: true, rating: true } }); + expect(items).toHaveLength(1); + expect(items[0]).toMatchObject({ + id: video.id, + createdAt: video.createdAt, + url: video.url, + rating: video.rating, + }); + + // find with base filter + found = await db.ratedVideo.findFirst({ where: { viewCount: video.viewCount } }); + expect(found).toMatchObject(video); + found = await db.ratedVideo.findFirst({ where: { url: video.url, owner: { id: user.id } } }); + expect(found).toMatchObject(video); + + // image: single inheritance + const image = await db.image.create({ + data: { owner: { connect: { id: 1 } }, viewCount: 1, format: 'png' }, + include: { owner: true }, + }); + const readImage = await db.image.findFirst({ include: { owner: true } }); + expect(readImage).toMatchObject(image); + expect(readImage.owner).toMatchObject(user); + }); + + it('read with base', async () => { + const { db, user, video: r } = await setup(); + + let video = await db.video.findFirst({ where: { duration: r.duration }, include: { owner: true } }); + expect(video).toMatchObject({ + id: video.id, + createdAt: r.createdAt, + viewCount: r.viewCount, + url: r.url, + duration: r.duration, + assetType: 'Video', + videoType: 'RatedVideo', + }); + expect(video.rating).toBeUndefined(); + expect(video.owner).toMatchObject(user); + + const asset = await db.asset.findFirst({ where: { viewCount: r.viewCount }, include: { owner: true } }); + expect(asset).toMatchObject({ id: r.id, createdAt: r.createdAt, assetType: 'Video', viewCount: r.viewCount }); + expect(asset.url).toBeUndefined(); + expect(asset.duration).toBeUndefined(); + expect(asset.rating).toBeUndefined(); + expect(asset.videoType).toBeUndefined(); + expect(asset.owner).toMatchObject(user); + + const image = await db.image.create({ + data: { owner: { connect: { id: 1 } }, viewCount: 1, format: 'png' }, + include: { owner: true }, + }); + const imgAsset = await db.asset.findFirst({ where: { assetType: 'Image' }, include: { owner: true } }); + expect(imgAsset).toMatchObject({ + id: image.id, + createdAt: image.createdAt, + assetType: 'Image', + viewCount: image.viewCount, + }); + expect(imgAsset.format).toBeUndefined(); + expect(imgAsset.owner).toMatchObject(user); + }); + + it('update simple', async () => { + const { db, videoWithOwner: video } = await setup(); + + // update with concrete + let updated = await db.ratedVideo.update({ + where: { id: video.id }, + data: { rating: 200 }, + include: { owner: true }, + }); + expect(updated.rating).toBe(200); + expect(updated.owner).toBeTruthy(); + + // update with base + updated = await db.video.update({ + where: { id: video.id }, + data: { duration: 200 }, + select: { duration: true, createdAt: true }, + }); + expect(updated.duration).toBe(200); + expect(updated.createdAt).toBeTruthy(); + + // update with base + updated = await db.asset.update({ + where: { id: video.id }, + data: { viewCount: 200 }, + }); + expect(updated.viewCount).toBe(200); + + // set discriminator + await expect(db.ratedVideo.update({ where: { id: video.id }, data: { assetType: 'Image' } })).rejects.toThrow( + 'is a discriminator' + ); + await expect( + db.ratedVideo.update({ where: { id: video.id }, data: { videoType: 'RatedVideo' } }) + ).rejects.toThrow('is a discriminator'); + }); + + it('update nested create', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // create delegate not allowed + await expect( + db.user.update({ + where: { id: user.id }, + data: { + assets: { + create: { viewCount: 1 }, + }, + }, + include: { assets: true }, + }) + ).rejects.toThrow('is a delegate'); + + // create concrete + await expect( + db.user.update({ + where: { id: user.id }, + data: { + ratedVideos: { + create: { + viewCount: 1, + duration: 100, + url: 'xyz', + rating: 100, + owner: { connect: { id: user.id } }, + }, + }, + }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([ + expect.objectContaining({ viewCount: 1, duration: 100, url: 'xyz', rating: 100 }), + ]), + }); + + // nested create a relation from base + const newVideo = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + await expect( + db.ratedVideo.update({ + where: { id: newVideo.id }, + data: { owner: { create: { id: 2 } }, url: 'xyz', duration: 200, rating: 200 }, + include: { owner: true }, + }) + ).resolves.toMatchObject({ owner: { id: 2 } }); + }); + + it('update nested updateOne', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // update + let updated = await db.asset.update({ + where: { id: video.id }, + data: { owner: { update: { level: 1 } } }, + include: { owner: true }, + }); + expect(updated.owner.level).toBe(1); + + updated = await db.video.update({ + where: { id: video.id }, + data: { duration: 300, owner: { update: { level: 2 } } }, + include: { owner: true }, + }); + expect(updated.duration).toBe(300); + expect(updated.owner.level).toBe(2); + + updated = await db.ratedVideo.update({ + where: { id: video.id }, + data: { rating: 300, owner: { update: { level: 3 } } }, + include: { owner: true }, + }); + expect(updated.rating).toBe(300); + expect(updated.owner.level).toBe(3); + }); + + it('update nested updateMany', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // updateMany + await db.user.update({ + where: { id: user.id }, + data: { + ratedVideos: { + create: { url: 'xyz', duration: 111, rating: 222, owner: { connect: { id: user.id } } }, + }, + }, + }); + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { updateMany: { where: { duration: 111 }, data: { rating: 333 } } } }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ ratedVideos: expect.arrayContaining([expect.objectContaining({ rating: 333 })]) }); + }); + + it('update nested deleteOne', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // delete with base + await db.user.update({ + where: { id: user.id }, + data: { assets: { delete: { id: video.id } } }, + }); + await expect(db.asset.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + + // delete with concrete + let vid = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'xyz', + duration: 111, + rating: 222, + }, + }); + await db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { delete: { id: vid.id } } }, + }); + await expect(db.asset.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + + // delete with mixed filter + vid = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'xyz', + duration: 111, + rating: 222, + }, + }); + await db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { delete: { id: vid.id, duration: 111 } } }, + }); + await expect(db.asset.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + + // delete not found + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { delete: { id: vid.id } } }, + }) + ).toBeNotFound(); + }); + + it('update nested deleteMany', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // delete with base no filter + await db.user.update({ + where: { id: user.id }, + data: { assets: { deleteMany: {} } }, + }); + await expect(db.asset.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + + // delete with concrete + let vid1 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'abc', + duration: 111, + rating: 111, + }, + }); + let vid2 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'xyz', + duration: 222, + rating: 222, + }, + }); + await db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { deleteMany: { rating: 111 } } }, + }); + await expect(db.asset.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: vid2.id } })).toResolveTruthy(); + await db.asset.deleteMany(); + + // delete with mixed args + vid1 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'abc', + duration: 111, + rating: 111, + viewCount: 111, + }, + }); + vid2 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'xyz', + duration: 222, + rating: 222, + viewCount: 222, + }, + }); + await db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { deleteMany: { url: 'abc', rating: 111, viewCount: 111 } } }, + }); + await expect(db.asset.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: vid2.id } })).toResolveTruthy(); + await db.asset.deleteMany(); + + // delete not found + vid1 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'abc', + duration: 111, + rating: 111, + }, + }); + vid2 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'xyz', + duration: 222, + rating: 222, + }, + }); + await db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { deleteMany: { url: 'abc', rating: 222 } } }, + }); + await expect(db.asset.count()).resolves.toBe(2); + }); + + it('update nested relation manipulation', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // connect, disconnect with base + await expect( + db.user.update({ + where: { id: user.id }, + data: { assets: { disconnect: { id: video.id } } }, + include: { assets: true }, + }) + ).resolves.toMatchObject({ + assets: expect.arrayContaining([]), + }); + await expect( + db.user.update({ + where: { id: user.id }, + data: { assets: { connect: { id: video.id } } }, + include: { assets: true }, + }) + ).resolves.toMatchObject({ + assets: expect.arrayContaining([expect.objectContaining({ id: video.id })]), + }); + + /// connect, disconnect with concrete + + let vid1 = await db.ratedVideo.create({ + data: { + url: 'abc', + duration: 111, + rating: 111, + }, + }); + let vid2 = await db.ratedVideo.create({ + data: { + url: 'xyz', + duration: 222, + rating: 222, + }, + }); + + // connect not found + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { connect: [{ id: vid2.id + 1 }] } }, + include: { ratedVideos: true }, + }) + ).toBeRejectedWithCode(PrismaErrorCode.REQUIRED_CONNECTED_RECORD_NOT_FOUND); + + // connect found + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { connect: [{ id: vid1.id, duration: vid1.duration, rating: vid1.rating }] } }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([expect.objectContaining({ id: vid1.id })]), + }); + + // connectOrCreate + await expect( + db.user.update({ + where: { id: user.id }, + data: { + ratedVideos: { + connectOrCreate: [ + { + where: { id: vid2.id, duration: 333 }, + create: { + url: 'xyz', + duration: 333, + rating: 333, + }, + }, + ], + }, + }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([expect.objectContaining({ duration: 333 })]), + }); + + // disconnect not found + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { disconnect: [{ id: vid2.id }] } }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([expect.objectContaining({ id: vid1.id })]), + }); + + // disconnect found + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { disconnect: [{ id: vid1.id, duration: vid1.duration, rating: vid1.rating }] } }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([]), + }); + + // set + await expect( + db.user.update({ + where: { id: user.id }, + data: { + ratedVideos: { + set: [ + { id: vid1.id, viewCount: vid1.viewCount }, + { id: vid2.id, viewCount: vid2.viewCount }, + ], + }, + }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([ + expect.objectContaining({ id: vid1.id }), + expect.objectContaining({ id: vid2.id }), + ]), + }); + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { set: [] } }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([]), + }); + await expect( + db.user.update({ + where: { id: user.id }, + data: { + ratedVideos: { + set: { id: vid1.id, viewCount: vid1.viewCount }, + }, + }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([expect.objectContaining({ id: vid1.id })]), + }); + }); + + it('updateMany', async () => { + const { db, videoWithOwner: video, user } = await setup(); + const otherVideo = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 10000, duration: 10000, url: 'xyz', rating: 10000 }, + }); + + // update only the current level + await expect( + db.ratedVideo.updateMany({ + where: { rating: video.rating, viewCount: video.viewCount }, + data: { rating: 100 }, + }) + ).resolves.toMatchObject({ count: 1 }); + let read = await db.ratedVideo.findUnique({ where: { id: video.id } }); + expect(read).toMatchObject({ rating: 100 }); + + // update with concrete + await expect( + db.ratedVideo.updateMany({ + where: { id: video.id }, + data: { viewCount: 1, duration: 11, rating: 101 }, + }) + ).resolves.toMatchObject({ count: 1 }); + read = await db.ratedVideo.findUnique({ where: { id: video.id } }); + expect(read).toMatchObject({ viewCount: 1, duration: 11, rating: 101 }); + + // update with base + await db.video.updateMany({ + where: { viewCount: 1, duration: 11 }, + data: { viewCount: 2, duration: 12 }, + }); + read = await db.ratedVideo.findUnique({ where: { id: video.id } }); + expect(read).toMatchObject({ viewCount: 2, duration: 12 }); + + // update with base + await db.asset.updateMany({ + where: { viewCount: 2 }, + data: { viewCount: 3 }, + }); + read = await db.ratedVideo.findUnique({ where: { id: video.id } }); + expect(read.viewCount).toBe(3); + + // the other video is unchanged + await expect(await db.ratedVideo.findUnique({ where: { id: otherVideo.id } })).toMatchObject(otherVideo); + + // update with concrete no where + await expect( + db.ratedVideo.updateMany({ + data: { viewCount: 111, duration: 111, rating: 111 }, + }) + ).resolves.toMatchObject({ count: 2 }); + await expect(db.ratedVideo.findUnique({ where: { id: video.id } })).resolves.toMatchObject({ duration: 111 }); + await expect(db.ratedVideo.findUnique({ where: { id: otherVideo.id } })).resolves.toMatchObject({ + duration: 111, + }); + + // set discriminator + await expect(db.ratedVideo.updateMany({ data: { assetType: 'Image' } })).rejects.toThrow('is a discriminator'); + await expect(db.ratedVideo.updateMany({ data: { videoType: 'RatedVideo' } })).rejects.toThrow( + 'is a discriminator' + ); + }); + + it('upsert', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + await expect( + db.asset.upsert({ + where: { id: video.id }, + create: { id: video.id, viewCount: 1 }, + update: { viewCount: 2 }, + }) + ).rejects.toThrow('is a delegate'); + + // update + await expect( + db.ratedVideo.upsert({ + where: { id: video.id }, + create: { + viewCount: 1, + duration: 300, + url: 'xyz', + rating: 100, + owner: { connect: { id: user.id } }, + }, + update: { duration: 200 }, + }) + ).resolves.toMatchObject({ + id: video.id, + duration: 200, + }); + + // create + const created = await db.ratedVideo.upsert({ + where: { id: video.id + 1 }, + create: { viewCount: 1, duration: 300, url: 'xyz', rating: 100, owner: { connect: { id: user.id } } }, + update: { duration: 200 }, + }); + expect(created.id).not.toEqual(video.id); + expect(created.duration).toBe(300); + }); + + it('delete', async () => { + let { db, user, video: ratedVideo } = await setup(); + + let deleted = await db.ratedVideo.delete({ + where: { id: ratedVideo.id }, + select: { rating: true, owner: true }, + }); + expect(deleted).toMatchObject({ rating: 100 }); + expect(deleted.owner).toMatchObject(user); + await expect(db.ratedVideo.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + + // delete with base + ratedVideo = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + const video = await db.video.findUnique({ where: { id: ratedVideo.id } }); + deleted = await db.video.delete({ where: { id: ratedVideo.id }, include: { owner: true } }); + expect(deleted).toMatchObject(video); + expect(deleted.owner).toMatchObject(user); + await expect(db.ratedVideo.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + + // delete with concrete + ratedVideo = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + let asset = await db.asset.findUnique({ where: { id: ratedVideo.id } }); + deleted = await db.video.delete({ where: { id: ratedVideo.id }, include: { owner: true } }); + expect(deleted).toMatchObject(asset); + expect(deleted.owner).toMatchObject(user); + await expect(db.ratedVideo.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + + // delete with combined condition + ratedVideo = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + asset = await db.asset.findUnique({ where: { id: ratedVideo.id } }); + deleted = await db.video.delete({ where: { id: ratedVideo.id, viewCount: 1 } }); + expect(deleted).toMatchObject(asset); + await expect(db.ratedVideo.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + }); + + it('deleteMany', async () => { + const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const db = enhance(); + + const user = await db.user.create({ data: { id: 1 } }); + + // no where + let video1 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + let video2 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + await expect(db.ratedVideo.deleteMany()).resolves.toMatchObject({ count: 2 }); + await expect(db.ratedVideo.findUnique({ where: { id: video1.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: video1.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: video1.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: video2.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: video2.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: video2.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.count()).resolves.toBe(0); + + // with base + video1 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'abc', rating: 100 }, + }); + video2 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 2, duration: 200, url: 'xyz', rating: 200 }, + }); + await expect(db.asset.deleteMany({ where: { viewCount: 1 } })).resolves.toMatchObject({ count: 1 }); + await expect(db.asset.count()).resolves.toBe(1); + await db.asset.deleteMany(); + + // where current level + video1 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'abc', rating: 100 }, + }); + video2 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 2, duration: 200, url: 'xyz', rating: 200 }, + }); + await expect(db.ratedVideo.deleteMany({ where: { rating: 100 } })).resolves.toMatchObject({ count: 1 }); + await expect(db.ratedVideo.count()).resolves.toBe(1); + await db.ratedVideo.deleteMany(); + + // where mixed with base level + video1 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'abc', rating: 100 }, + }); + video2 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 2, duration: 200, url: 'xyz', rating: 200 }, + }); + await expect(db.ratedVideo.deleteMany({ where: { viewCount: 1, duration: 100 } })).resolves.toMatchObject({ + count: 1, + }); + await expect(db.ratedVideo.count()).resolves.toBe(1); + await db.ratedVideo.deleteMany(); + + // delete not found + video1 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'abc', rating: 100 }, + }); + video2 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 2, duration: 200, url: 'xyz', rating: 200 }, + }); + await expect(db.ratedVideo.deleteMany({ where: { viewCount: 2, duration: 100 } })).resolves.toMatchObject({ + count: 0, + }); + await expect(db.ratedVideo.count()).resolves.toBe(2); + }); + + it('aggregate', async () => { + const { db } = await setup(); + + const aggregate = await db.ratedVideo.aggregate({ + _count: true, + _sum: { rating: true }, + where: { viewCount: { gt: 0 }, rating: { gt: 10 } }, + orderBy: { + duration: 'desc', + }, + }); + expect(aggregate).toMatchObject({ _count: 1, _sum: { rating: 100 } }); + + expect(() => db.ratedVideo.aggregate({ _count: true, _sum: { rating: true, viewCount: true } })).toThrow( + 'aggregate with fields from base type is not supported yet' + ); + }); + + it('count', async () => { + const { db } = await setup(); + + let count = await db.ratedVideo.count(); + expect(count).toBe(1); + + count = await db.ratedVideo.count({ + select: { _all: true, rating: true }, + where: { viewCount: { gt: 0 }, rating: { gt: 10 } }, + }); + expect(count).toMatchObject({ _all: 1, rating: 1 }); + + expect(() => db.ratedVideo.count({ select: { rating: true, viewCount: true } })).toThrow( + 'count with fields from base type is not supported yet' + ); + }); + + it('groupBy', async () => { + const { db, video } = await setup(); + + let group = await db.ratedVideo.groupBy({ by: ['rating'] }); + expect(group).toHaveLength(1); + expect(group[0]).toMatchObject({ rating: video.rating }); + + group = await db.ratedVideo.groupBy({ + by: ['id', 'rating'], + where: { viewCount: { gt: 0 }, rating: { gt: 10 } }, + }); + expect(group).toHaveLength(1); + expect(group[0]).toMatchObject({ id: video.id, rating: video.rating }); + + group = await db.ratedVideo.groupBy({ + by: ['id'], + _sum: { rating: true }, + }); + expect(group).toHaveLength(1); + expect(group[0]).toMatchObject({ id: video.id, _sum: { rating: video.rating } }); + + group = await db.ratedVideo.groupBy({ + by: ['id'], + _sum: { rating: true }, + having: { rating: { _sum: { gt: video.rating } } }, + }); + expect(group).toHaveLength(0); + + expect(() => db.ratedVideo.groupBy({ by: 'viewCount' })).toThrow( + 'groupBy with fields from base type is not supported yet' + ); + expect(() => db.ratedVideo.groupBy({ having: { rating: { gt: 0 }, viewCount: { gt: 0 } } })).toThrow( + 'groupBy with fields from base type is not supported yet' + ); + }); +}); diff --git a/tests/integration/tests/enhancements/with-omit/with-omit.test.ts b/tests/integration/tests/enhancements/with-omit/with-omit.test.ts index 61d44b440..f7fcc7266 100644 --- a/tests/integration/tests/enhancements/with-omit/with-omit.test.ts +++ b/tests/integration/tests/enhancements/with-omit/with-omit.test.ts @@ -1,4 +1,3 @@ -import { withOmit } from '@zenstackhq/runtime'; import { loadSchema } from '@zenstackhq/testtools'; import path from 'path'; @@ -33,9 +32,9 @@ describe('Omit test', () => { `; it('omit tests', async () => { - const { withOmit } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withOmit(); + const db = enhance(); const r = await db.user.create({ include: { profile: true }, data: { @@ -79,9 +78,12 @@ describe('Omit test', () => { }); it('customization', async () => { - const { prisma } = await loadSchema(model, { getPrismaOnly: true, output: './zen' }); + const { prisma, enhance } = await loadSchema(model, { + output: './zen', + enhancements: ['omit'], + }); - const db = withOmit(prisma, { loadPath: './zen' }); + const db = enhance(prisma, { loadPath: './zen' }); const r = await db.user.create({ include: { profile: true }, data: { @@ -93,7 +95,7 @@ describe('Omit test', () => { expect(r.password).toBeUndefined(); expect(r.profile.image).toBeUndefined(); - const db1 = withOmit(prisma, { modelMeta: require(path.resolve('./zen/model-meta')).default }); + const db1 = enhance(prisma, { modelMeta: require(path.resolve('./zen/model-meta')).default }); const r1 = await db1.user.create({ include: { profile: true }, data: { @@ -105,4 +107,53 @@ describe('Omit test', () => { expect(r1.password).toBeUndefined(); expect(r1.profile.image).toBeUndefined(); }); + + it('to-many', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + posts Post[] + + @@allow('all', true) + } + + model Post { + id String @id @default(cuid()) + user User @relation(fields: [userId], references: [id]) + userId String + images Image[] + + @@allow('all', true) + } + + model Image { + id String @id @default(cuid()) + post Post @relation(fields: [postId], references: [id]) + postId String + url String @omit + + @@allow('all', true) + } + `, + { enhancements: ['omit'] } + ); + + const db = enhance(); + const r = await db.user.create({ + include: { posts: { include: { images: true } } }, + data: { + posts: { + create: [ + { images: { create: { url: 'img1' } } }, + { images: { create: [{ url: 'img2' }, { url: 'img3' }] } }, + ], + }, + }, + }); + + expect(r.posts[0].images[0].url).toBeUndefined(); + expect(r.posts[1].images[0].url).toBeUndefined(); + expect(r.posts[1].images[1].url).toBeUndefined(); + }); }); diff --git a/tests/integration/tests/enhancements/with-password/with-password.test.ts b/tests/integration/tests/enhancements/with-password/with-password.test.ts index 62e30636b..37b23ecde 100644 --- a/tests/integration/tests/enhancements/with-password/with-password.test.ts +++ b/tests/integration/tests/enhancements/with-password/with-password.test.ts @@ -1,4 +1,3 @@ -import { withPassword } from '@zenstackhq/runtime'; import { loadSchema } from '@zenstackhq/testtools'; import { compareSync } from 'bcryptjs'; import path from 'path'; @@ -23,9 +22,9 @@ describe('Password test', () => { }`; it('password tests', async () => { - const { withPassword } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withPassword(); + const db = enhance(); const r = await db.user.create({ data: { id: '1', @@ -42,26 +41,4 @@ describe('Password test', () => { }); expect(compareSync('abc456', r1.password)).toBeTruthy(); }); - - it('customization', async () => { - const { prisma } = await loadSchema(model, { getPrismaOnly: true, output: './zen' }); - - const db = withPassword(prisma, { loadPath: './zen' }); - const r = await db.user.create({ - data: { - id: '1', - password: 'abc123', - }, - }); - expect(compareSync('abc123', r.password)).toBeTruthy(); - - const db1 = withPassword(prisma, { modelMeta: require(path.resolve('./zen/model-meta')).default }); - const r1 = await db1.user.create({ - data: { - id: '2', - password: 'abc123', - }, - }); - expect(compareSync('abc123', r1.password)).toBeTruthy(); - }); }); diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index 8f095f677..f5b4e2f4f 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -13,7 +13,7 @@ describe('With Policy: auth() test', () => { }); it('undefined user with string id simple', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model User { id String @id @default(uuid()) @@ -29,15 +29,15 @@ describe('With Policy: auth() test', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - const authDb = withPolicy({ id: 'user1' }); + const authDb = enhance({ id: 'user1' }); await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); }); it('undefined user with string id more', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model User { id String @id @default(uuid()) @@ -53,15 +53,15 @@ describe('With Policy: auth() test', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - const authDb = withPolicy({ id: 'user1' }); + const authDb = enhance({ id: 'user1' }); await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); }); it('undefined user with int id', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -77,15 +77,15 @@ describe('With Policy: auth() test', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - const authDb = withPolicy({ id: 'user1' }); + const authDb = enhance({ id: 'user1' }); await expect(authDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); }); it('undefined user compared with field', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model User { id String @id @default(uuid()) @@ -106,21 +106,21 @@ describe('With Policy: auth() test', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.user.create({ data: { id: 'user1' } })).toResolveTruthy(); await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); - const authDb = withPolicy(); + const authDb = enhance(); await expect(authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - expect(() => withPolicy({ id: null })).toThrow(/Invalid user context/); + expect(() => enhance({ id: null })).toThrow(/Invalid user context/); - const authDb2 = withPolicy({ id: 'user1' }); + const authDb2 = enhance({ id: 'user1' }); await expect(authDb2.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); }); it('undefined user compared with field more', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model User { id String @id @default(uuid()) @@ -141,18 +141,18 @@ describe('With Policy: auth() test', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.user.create({ data: { id: 'user1' } })).toResolveTruthy(); await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - const authDb2 = withPolicy({ id: 'user1' }); + const authDb2 = enhance({ id: 'user1' }); await expect(authDb2.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); }); it('undefined user non-id field', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model User { id String @id @default(uuid()) @@ -174,20 +174,20 @@ describe('With Policy: auth() test', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.user.create({ data: { id: 'user1', role: 'USER' } })).toResolveTruthy(); await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - const authDb = withPolicy({ id: 'user1', role: 'USER' }); + const authDb = enhance({ id: 'user1', role: 'USER' }); await expect(authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - const authDb1 = withPolicy({ id: 'user2', role: 'ADMIN' }); + const authDb1 = enhance({ id: 'user2', role: 'ADMIN' }); await expect(authDb1.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); }); it('non User auth model', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model Foo { id String @id @default(uuid()) @@ -206,15 +206,15 @@ describe('With Policy: auth() test', () => { ` ); - const userDb = withPolicy({ id: 'user1', role: 'USER' }); + const userDb = enhance({ id: 'user1', role: 'USER' }); await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - const adminDb = withPolicy({ id: 'user1', role: 'ADMIN' }); + const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); }); it('User model ignored', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model User { id String @id @default(uuid()) @@ -233,15 +233,15 @@ describe('With Policy: auth() test', () => { ` ); - const userDb = withPolicy({ id: 'user1', role: 'USER' }); + const userDb = enhance({ id: 'user1', role: 'USER' }); await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - const adminDb = withPolicy({ id: 'user1', role: 'ADMIN' }); + const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); }); it('Auth model ignored', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model Foo { id String @id @default(uuid()) @@ -261,10 +261,10 @@ describe('With Policy: auth() test', () => { ` ); - const userDb = withPolicy({ id: 'user1', role: 'USER' }); + const userDb = enhance({ id: 'user1', role: 'USER' }); await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); - const adminDb = withPolicy({ id: 'user1', role: 'ADMIN' }); + const adminDb = enhance({ id: 'user1', role: 'ADMIN' }); await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); }); @@ -363,4 +363,146 @@ describe('With Policy: auth() test', () => { enhance({ id: '1', posts: [{ id: '1', published: true, comments: [] }] }).post.create(createPayload) ).toResolveTruthy(); }); + + it('Default auth() on literal fields', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id + name String + score Int + + } + + model Post { + id String @id @default(uuid()) + title String + score Int? @default(auth().score) + authorName String? @default(auth().name) + + @@allow('all', true) + } + ` + ); + + const userDb = enhance({ id: '1', name: 'user1', score: 10 }); + await expect(userDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + await expect(userDb.post.findMany()).resolves.toHaveLength(1); + await expect(userDb.post.count({ where: { authorName: 'user1', score: 10 } })).resolves.toBe(1); + }); + + it('Default auth() data should not override passed args', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id + name String + + } + + model Post { + id String @id @default(uuid()) + authorName String? @default(auth().name) + + @@allow('all', true) + } + ` + ); + + const userContextName = 'user1'; + const overrideName = 'no-default-auth-name'; + const userDb = enhance({ id: '1', name: userContextName }); + await expect(userDb.post.create({ data: { authorName: overrideName } })).toResolveTruthy(); + await expect(userDb.post.count({ where: { authorName: overrideName } })).resolves.toBe(1); + }); + + it('Default auth() with foreign key', async () => { + const { enhance, modelMeta } = await loadSchema( + ` + model User { + id String @id + posts Post[] + + @@allow('all', true) + + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String @default(auth().id) + + @@allow('all', true) + } + ` + ); + + const db = enhance({ id: 'userId-1' }); + await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); + await expect(db.post.create({ data: { title: 'abc' } })).resolves.toMatchObject({ authorId: 'userId-1' }); + }); + + it('Default auth() with nested user context value', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id + profile Profile? + posts Post[] + + @@allow('all', true) + } + + model Profile { + id String @id @default(uuid()) + image Image? + user User @relation(fields: [userId], references: [id]) + userId String @unique + } + + model Image { + id String @id @default(uuid()) + url String + profile Profile @relation(fields: [profileId], references: [id]) + profileId String @unique + } + + model Post { + id String @id @default(uuid()) + title String + defaultImageUrl String @default(auth().profile.image.url) + author User @relation(fields: [authorId], references: [id]) + authorId String + + @@allow('all', true) + } + ` + ); + const url = 'https://zenstack.dev'; + const db = enhance({ id: 'userId-1', profile: { image: { url } } }); + + // top-level create + await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); + await expect( + db.post.create({ data: { title: 'abc', author: { connect: { id: 'userId-1' } } } }) + ).resolves.toMatchObject({ defaultImageUrl: url }); + + // nested create + let result = await db.user.create({ + data: { + id: 'userId-2', + posts: { + create: [{ title: 'p1' }, { title: 'p2' }], + }, + }, + include: { posts: true }, + }); + expect(result.posts).toEqual( + expect.arrayContaining([ + expect.objectContaining({ title: 'p1', defaultImageUrl: url }), + expect.objectContaining({ title: 'p2', defaultImageUrl: url }), + ]) + ); + }); }); diff --git a/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts b/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts index cadb42767..13f05aa51 100644 --- a/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts +++ b/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts @@ -1,4 +1,3 @@ -import { enhance } from '@zenstackhq/runtime'; import { loadSchema } from '@zenstackhq/testtools'; import path from 'path'; @@ -14,7 +13,7 @@ describe('With Policy: client extensions', () => { }); it('all model new method', async () => { - const { prisma, Prisma } = await loadSchema( + const { prisma, enhanceRaw, prismaModule } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -29,13 +28,13 @@ describe('With Policy: client extensions', () => { await prisma.model.create({ data: { value: 1 } }); await prisma.model.create({ data: { value: 2 } }); - const ext = Prisma.defineExtension((_prisma: any) => { + const ext = prismaModule.defineExtension((_prisma: any) => { return _prisma.$extends({ name: 'prisma-extension-getAll', model: { $allModels: { async getAll(this: T, args?: any) { - const context = Prisma.getExtensionContext(this); + const context = prismaModule.getExtensionContext(this); const r = await (context as any).findMany(args); console.log('getAll result:', r); return r; @@ -46,7 +45,7 @@ describe('With Policy: client extensions', () => { }); const xprisma = prisma.$extends(ext); - const db = enhance(xprisma); + const db = enhanceRaw(xprisma); await expect(db.model.getAll()).resolves.toHaveLength(2); // FIXME: extending an enhanced client doesn't work for this case @@ -55,7 +54,7 @@ describe('With Policy: client extensions', () => { }); it('one model new method', async () => { - const { prisma, Prisma } = await loadSchema( + const { prisma, enhanceRaw, prismaModule } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -70,13 +69,13 @@ describe('With Policy: client extensions', () => { await prisma.model.create({ data: { value: 1 } }); await prisma.model.create({ data: { value: 2 } }); - const ext = Prisma.defineExtension((_prisma: any) => { + const ext = prismaModule.defineExtension((_prisma: any) => { return _prisma.$extends({ name: 'prisma-extension-getAll', model: { model: { async getAll(this: T, args?: any) { - const context = Prisma.getExtensionContext(this); + const context = prismaModule.getExtensionContext(this); const r = await (context as any).findMany(args); return r; }, @@ -86,12 +85,12 @@ describe('With Policy: client extensions', () => { }); const xprisma = prisma.$extends(ext); - const db = enhance(xprisma); + const db = enhanceRaw(xprisma); await expect(db.model.getAll()).resolves.toHaveLength(2); }); it('add client method', async () => { - const { prisma, Prisma } = await loadSchema( + const { prisma, prismaModule, enhanceRaw } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -104,7 +103,7 @@ describe('With Policy: client extensions', () => { let logged = false; - const ext = Prisma.defineExtension((_prisma: any) => { + const ext = prismaModule.defineExtension((_prisma: any) => { return _prisma.$extends({ name: 'prisma-extension-log', client: { @@ -122,7 +121,7 @@ describe('With Policy: client extensions', () => { }); it('query override one model', async () => { - const { prisma, Prisma } = await loadSchema( + const { prisma, prismaModule, enhanceRaw } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -138,7 +137,7 @@ describe('With Policy: client extensions', () => { await prisma.model.create({ data: { x: 1, y: 200 } }); await prisma.model.create({ data: { x: 2, y: 300 } }); - const ext = Prisma.defineExtension((_prisma: any) => { + const ext = prismaModule.defineExtension((_prisma: any) => { return _prisma.$extends({ name: 'prisma-extension-queryOverride', query: { @@ -154,12 +153,12 @@ describe('With Policy: client extensions', () => { }); const xprisma = prisma.$extends(ext); - const db = enhance(xprisma); + const db = enhanceRaw(xprisma); await expect(db.model.findMany()).resolves.toHaveLength(1); }); it('query override all models', async () => { - const { prisma, Prisma } = await loadSchema( + const { prisma, prismaModule, enhanceRaw } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -175,7 +174,7 @@ describe('With Policy: client extensions', () => { await prisma.model.create({ data: { x: 1, y: 200 } }); await prisma.model.create({ data: { x: 2, y: 300 } }); - const ext = Prisma.defineExtension((_prisma: any) => { + const ext = prismaModule.defineExtension((_prisma: any) => { return _prisma.$extends({ name: 'prisma-extension-queryOverride', query: { @@ -192,12 +191,12 @@ describe('With Policy: client extensions', () => { }); const xprisma = prisma.$extends(ext); - const db = enhance(xprisma); + const db = enhanceRaw(xprisma); await expect(db.model.findMany()).resolves.toHaveLength(1); }); it('query override all operations', async () => { - const { prisma, Prisma } = await loadSchema( + const { prisma, prismaModule, enhanceRaw } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -213,7 +212,7 @@ describe('With Policy: client extensions', () => { await prisma.model.create({ data: { x: 1, y: 200 } }); await prisma.model.create({ data: { x: 2, y: 300 } }); - const ext = Prisma.defineExtension((_prisma: any) => { + const ext = prismaModule.defineExtension((_prisma: any) => { return _prisma.$extends({ name: 'prisma-extension-queryOverride', query: { @@ -230,12 +229,12 @@ describe('With Policy: client extensions', () => { }); const xprisma = prisma.$extends(ext); - const db = enhance(xprisma); + const db = enhanceRaw(xprisma); await expect(db.model.findMany()).resolves.toHaveLength(1); }); it('query override everything', async () => { - const { prisma, Prisma } = await loadSchema( + const { prisma, prismaModule, enhanceRaw } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -251,7 +250,7 @@ describe('With Policy: client extensions', () => { await prisma.model.create({ data: { x: 1, y: 200 } }); await prisma.model.create({ data: { x: 2, y: 300 } }); - const ext = Prisma.defineExtension((_prisma: any) => { + const ext = prismaModule.defineExtension((_prisma: any) => { return _prisma.$extends({ name: 'prisma-extension-queryOverride', query: { @@ -266,12 +265,12 @@ describe('With Policy: client extensions', () => { }); const xprisma = prisma.$extends(ext); - const db = enhance(xprisma); + const db = enhanceRaw(xprisma); await expect(db.model.findMany()).resolves.toHaveLength(1); }); it('result mutation', async () => { - const { prisma, Prisma } = await loadSchema( + const { prisma, prismaModule, enhanceRaw } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -285,7 +284,7 @@ describe('With Policy: client extensions', () => { await prisma.model.create({ data: { value: 0 } }); await prisma.model.create({ data: { value: 1 } }); - const ext = Prisma.defineExtension((_prisma: any) => { + const ext = prismaModule.defineExtension((_prisma: any) => { return _prisma.$extends({ name: 'prisma-extension-resultMutation', query: { @@ -303,14 +302,14 @@ describe('With Policy: client extensions', () => { }); const xprisma = prisma.$extends(ext); - const db = enhance(xprisma); + const db = enhanceRaw(xprisma); const r = await db.model.findMany(); expect(r).toHaveLength(1); expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ value: 2 })])); }); it('result custom fields', async () => { - const { prisma, Prisma } = await loadSchema( + const { prisma, prismaModule, enhanceRaw } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -324,7 +323,7 @@ describe('With Policy: client extensions', () => { await prisma.model.create({ data: { value: 0 } }); await prisma.model.create({ data: { value: 1 } }); - const ext = Prisma.defineExtension((_prisma: any) => { + const ext = prismaModule.defineExtension((_prisma: any) => { return _prisma.$extends({ name: 'prisma-extension-resultNewFields', result: { @@ -341,7 +340,7 @@ describe('With Policy: client extensions', () => { }); const xprisma = prisma.$extends(ext); - const db = enhance(xprisma); + const db = enhanceRaw(xprisma); const r = await db.model.findMany(); expect(r).toHaveLength(1); expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ doubleValue: 2 })])); diff --git a/tests/integration/tests/enhancements/with-policy/connect-disconnect.test.ts b/tests/integration/tests/enhancements/with-policy/connect-disconnect.test.ts index 99ae6d626..7bc4a9ed9 100644 --- a/tests/integration/tests/enhancements/with-policy/connect-disconnect.test.ts +++ b/tests/integration/tests/enhancements/with-policy/connect-disconnect.test.ts @@ -47,9 +47,9 @@ describe('With Policy: connect-disconnect', () => { `; it('simple to-many', async () => { - const { withPolicy, prisma } = await loadSchema(modelToMany); + const { enhance, prisma } = await loadSchema(modelToMany); - const db = withPolicy(); + const db = enhance(); // m1-1 -> m2-1 await db.m2.create({ data: { id: 'm2-1', value: 1, deleted: false } }); @@ -164,9 +164,9 @@ describe('With Policy: connect-disconnect', () => { }); it('nested to-many', async () => { - const { withPolicy } = await loadSchema(modelToMany); + const { enhance } = await loadSchema(modelToMany); - const db = withPolicy(); + const db = enhance(); await db.m3.create({ data: { id: 'm3-1', value: 1, deleted: false } }); await expect( @@ -219,9 +219,9 @@ describe('With Policy: connect-disconnect', () => { `; it('to-one', async () => { - const { withPolicy, prisma } = await loadSchema(modelToOne); + const { enhance, prisma } = await loadSchema(modelToOne); - const db = withPolicy(); + const db = enhance(); await db.m2.create({ data: { id: 'm2-1', value: 1, deleted: false } }); await db.m1.create({ @@ -314,9 +314,9 @@ describe('With Policy: connect-disconnect', () => { `; it('implicit many-to-many', async () => { - const { withPolicy, prisma } = await loadSchema(modelImplicitManyToMany); + const { enhance, prisma } = await loadSchema(modelImplicitManyToMany); - const db = withPolicy(); + const db = enhance(); // await prisma.m1.create({ data: { id: 'm1-1', value: 1 } }); // await prisma.m2.create({ data: { id: 'm2-1', value: 1 } }); @@ -379,9 +379,9 @@ describe('With Policy: connect-disconnect', () => { `; it('explicit many-to-many', async () => { - const { withPolicy, prisma } = await loadSchema(modelExplicitManyToMany); + const { enhance, prisma } = await loadSchema(modelExplicitManyToMany); - const db = withPolicy(); + const db = enhance(); await prisma.m1.create({ data: { id: 'm1-1', value: 1 } }); await prisma.m2.create({ data: { id: 'm2-1', value: 1 } }); diff --git a/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts b/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts index ee8f16467..26022aa6b 100644 --- a/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts +++ b/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts @@ -69,7 +69,7 @@ describe('With Policy:deep nested', () => { beforeEach(async () => { const params = await loadSchema(model); - db = params.withPolicy(); + db = params.enhance(); prisma = params.prisma; }); diff --git a/tests/integration/tests/enhancements/with-policy/empty-policy.test.ts b/tests/integration/tests/enhancements/with-policy/empty-policy.test.ts index 4a1a4d0c5..ee0b61850 100644 --- a/tests/integration/tests/enhancements/with-policy/empty-policy.test.ts +++ b/tests/integration/tests/enhancements/with-policy/empty-policy.test.ts @@ -13,7 +13,7 @@ describe('With Policy:empty policy', () => { }); it('direct operations', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -22,7 +22,7 @@ describe('With Policy:empty policy', () => { ` ); - const db = withPolicy(); + const db = enhance(); await prisma.model.create({ data: { id: '1', value: 0 } }); await expect(db.model.create({ data: {} })).toBeRejectedByPolicy(); @@ -57,7 +57,7 @@ describe('With Policy:empty policy', () => { }); it('to-many write', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -74,7 +74,7 @@ describe('With Policy:empty policy', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect( db.m1.create({ @@ -88,7 +88,7 @@ describe('With Policy:empty policy', () => { }); it('to-one write', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -105,7 +105,7 @@ describe('With Policy:empty policy', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect( db.m1.create({ diff --git a/tests/integration/tests/enhancements/with-policy/field-comparison.test.ts b/tests/integration/tests/enhancements/with-policy/field-comparison.test.ts index 4f014d2f2..f130b2c94 100644 --- a/tests/integration/tests/enhancements/with-policy/field-comparison.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-comparison.test.ts @@ -3,7 +3,7 @@ import path from 'path'; const DB_NAME = 'field-comparison'; -describe('WithPolicy: field comparison tests', () => { +describe('Policy: field comparison tests', () => { let origDir: string; let dbUrl: string; let prisma: any; @@ -41,7 +41,7 @@ describe('WithPolicy: field comparison tests', () => { ); prisma = r.prisma; - const db = r.withPolicy(); + const db = r.enhance(); await expect(db.model.create({ data: { x: 1, y: 2 } })).toBeRejectedByPolicy(); await expect(db.model.create({ data: { x: 2, y: 1 } })).toResolveTruthy(); }); @@ -62,7 +62,7 @@ describe('WithPolicy: field comparison tests', () => { ); prisma = r.prisma; - const db = r.withPolicy(); + const db = r.enhance(); await expect(db.model.create({ data: { x: 1, y: 2 } })).toBeRejectedByPolicy(); await expect(db.model.create({ data: { x: 2, y: 1 } })).toResolveTruthy(); }); @@ -83,7 +83,7 @@ describe('WithPolicy: field comparison tests', () => { ); prisma = r.prisma; - const db = r.withPolicy(); + const db = r.enhance(); await expect(db.model.create({ data: { x: 'a', y: ['b', 'c'] } })).toBeRejectedByPolicy(); await expect(db.model.create({ data: { x: 'a', y: ['a', 'c'] } })).toResolveTruthy(); }); @@ -104,7 +104,7 @@ describe('WithPolicy: field comparison tests', () => { ); prisma = r.prisma; - const db = r.withPolicy(); + const db = r.enhance(); await expect(db.model.create({ data: { x: 'a', y: ['b', 'c'] } })).toBeRejectedByPolicy(); await expect(db.model.create({ data: { x: 'a', y: ['a', 'c'] } })).toResolveTruthy(); }); diff --git a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts index ee89c58e7..ebaf2d858 100644 --- a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts @@ -1,7 +1,7 @@ import { loadSchema } from '@zenstackhq/testtools'; import path from 'path'; -describe('With Policy: field-level policy', () => { +describe('Policy: field-level policy', () => { let origDir: string; beforeAll(async () => { @@ -13,7 +13,7 @@ describe('With Policy: field-level policy', () => { }); it('read simple', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -37,7 +37,7 @@ describe('With Policy: field-level policy', () => { await prisma.user.create({ data: { id: 1, admin: true } }); - const db = withPolicy(); + const db = enhance(); let r; // y is unreadable @@ -103,7 +103,7 @@ describe('With Policy: field-level policy', () => { }); it('read override', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -128,7 +128,7 @@ describe('With Policy: field-level policy', () => { await prisma.user.create({ data: { id: 1, admin: true } }); - const db = withPolicy(); + const db = enhance(); // created but can't read back await expect( @@ -181,7 +181,7 @@ describe('With Policy: field-level policy', () => { }); it('read filter with auth', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -205,7 +205,7 @@ describe('With Policy: field-level policy', () => { await prisma.user.create({ data: { id: 1, admin: true } }); - let db = withPolicy({ id: 1, admin: false }); + let db = enhance({ id: 1, admin: false }); let r; // y is unreadable @@ -246,7 +246,7 @@ describe('With Policy: field-level policy', () => { expect(r.y).toBeUndefined(); // y is readable - db = withPolicy({ id: 1, admin: true }); + db = enhance({ id: 1, admin: true }); r = await db.model.create({ data: { id: 2, @@ -281,7 +281,7 @@ describe('With Policy: field-level policy', () => { }); it('read filter with relation', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -306,7 +306,7 @@ describe('With Policy: field-level policy', () => { await prisma.user.create({ data: { id: 1, admin: false } }); await prisma.user.create({ data: { id: 2, admin: true } }); - const db = withPolicy(); + const db = enhance(); let r; // y is unreadable @@ -381,7 +381,7 @@ describe('With Policy: field-level policy', () => { }); it('read coverage', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model Model { id Int @id @default(autoincrement()) @@ -393,7 +393,7 @@ describe('With Policy: field-level policy', () => { ` ); - const db = withPolicy(); + const db = enhance(); let r; // y is unreadable @@ -430,7 +430,7 @@ describe('With Policy: field-level policy', () => { }); it('read relation', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -472,7 +472,7 @@ describe('With Policy: field-level policy', () => { }, }); - const db = withPolicy(); + const db = enhance(); // read to-many relation let r = await db.user.findUnique({ @@ -498,7 +498,7 @@ describe('With Policy: field-level policy', () => { }); it('update simple', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -523,7 +523,7 @@ describe('With Policy: field-level policy', () => { await prisma.user.create({ data: { id: 1 }, }); - const db = withPolicy(); + const db = enhance(); await db.model.create({ data: { id: 1, x: 0, y: 0, ownerId: 1 }, @@ -569,7 +569,7 @@ describe('With Policy: field-level policy', () => { }); it('update with override', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model Model { id Int @id @default(autoincrement()) @@ -583,7 +583,7 @@ describe('With Policy: field-level policy', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.model.create({ data: { id: 1, x: 0, y: 0, z: 0 }, @@ -648,7 +648,7 @@ describe('With Policy: field-level policy', () => { }); it('update filter with relation', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -676,7 +676,7 @@ describe('With Policy: field-level policy', () => { await prisma.user.create({ data: { id: 2, admin: true }, }); - const db = withPolicy(); + const db = enhance(); await db.model.create({ data: { id: 1, x: 0, y: 0, ownerId: 1 }, @@ -706,7 +706,7 @@ describe('With Policy: field-level policy', () => { }); it('update with nested to-many relation', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -734,7 +734,7 @@ describe('With Policy: field-level policy', () => { await prisma.user.create({ data: { id: 2, admin: true, models: { create: { id: 2, x: 0, y: 0 } } }, }); - const db = withPolicy(); + const db = enhance(); await expect( db.user.update({ @@ -758,7 +758,7 @@ describe('With Policy: field-level policy', () => { }); it('update with nested to-one relation', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -786,7 +786,7 @@ describe('With Policy: field-level policy', () => { await prisma.user.create({ data: { id: 2, admin: true, model: { create: { id: 2, x: 0, y: 0 } } }, }); - const db = withPolicy(); + const db = enhance(); await expect( db.user.update({ @@ -828,7 +828,7 @@ describe('With Policy: field-level policy', () => { }); it('update with connect to-many relation', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -854,7 +854,7 @@ describe('With Policy: field-level policy', () => { await prisma.model.create({ data: { id: 1, value: 0 } }); await prisma.model.create({ data: { id: 2, value: 1 } }); - const db = withPolicy(); + const db = enhance(); await expect( db.model.update({ @@ -922,7 +922,7 @@ describe('With Policy: field-level policy', () => { }); it('update with connect to-one relation', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -948,7 +948,7 @@ describe('With Policy: field-level policy', () => { await prisma.model.create({ data: { id: 1, value: 0 } }); await prisma.model.create({ data: { id: 2, value: 1 } }); - const db = withPolicy(); + const db = enhance(); await expect( db.model.update({ @@ -1010,7 +1010,7 @@ describe('With Policy: field-level policy', () => { }); it('updateMany simple', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -1042,7 +1042,7 @@ describe('With Policy: field-level policy', () => { }, }, }); - const db = withPolicy(); + const db = enhance(); await expect(db.model.updateMany({ data: { y: 2 } })).resolves.toEqual({ count: 1 }); await expect(db.model.findUnique({ where: { id: 1 } })).resolves.toEqual( @@ -1054,7 +1054,7 @@ describe('With Policy: field-level policy', () => { }); it('updateMany override', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model Model { id Int @id @default(autoincrement()) @@ -1067,7 +1067,7 @@ describe('With Policy: field-level policy', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.model.create({ data: { id: 1, x: 0, y: 0 } }); await db.model.create({ data: { id: 2, x: 1, y: 0 } }); @@ -1084,7 +1084,7 @@ describe('With Policy: field-level policy', () => { }); it('updateMany nested', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -1116,7 +1116,7 @@ describe('With Policy: field-level policy', () => { }, }, }); - const db = withPolicy(); + const db = enhance(); await expect( db.user.update({ where: { id: 1 }, data: { models: { updateMany: { data: { y: 2 } } } } }) @@ -1144,7 +1144,7 @@ describe('With Policy: field-level policy', () => { }); it('this expression', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @@ -1157,24 +1157,24 @@ describe('With Policy: field-level policy', () => { await prisma.user.create({ data: { id: 1, username: 'test' } }); // admin - let r = await withPolicy({ id: 1, admin: true }).user.findFirst(); + let r = await enhance({ id: 1, admin: true }).user.findFirst(); expect(r.username).toEqual('test'); // owner - r = await withPolicy({ id: 1 }).user.findFirst(); + r = await enhance({ id: 1 }).user.findFirst(); expect(r.username).toEqual('test'); // anonymous - r = await withPolicy().user.findFirst(); + r = await enhance().user.findFirst(); expect(r.username).toBeUndefined(); // non-owner - r = await withPolicy({ id: 2 }).user.findFirst(); + r = await enhance({ id: 2 }).user.findFirst(); expect(r.username).toBeUndefined(); }); it('collection predicate', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -1206,7 +1206,7 @@ describe('With Policy: field-level policy', () => { ` ); - const db = withPolicy(); + const db = enhance(); await prisma.user.create({ data: { @@ -1269,7 +1269,7 @@ describe('With Policy: field-level policy', () => { }); it('deny only without field access', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -1285,14 +1285,14 @@ describe('With Policy: field-level policy', () => { }); await expect( - withPolicy({ id: 1, role: 'ADMIN' }).user.update({ + enhance({ id: 1, role: 'ADMIN' }).user.update({ where: { id: user.id }, data: { role: 'ADMIN' }, }) ).toResolveTruthy(); await expect( - withPolicy({ id: 1, role: 'USER' }).user.update({ + enhance({ id: 1, role: 'USER' }).user.update({ where: { id: user.id }, data: { role: 'ADMIN' }, }) @@ -1300,7 +1300,7 @@ describe('With Policy: field-level policy', () => { }); it('deny only with field access', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -1317,14 +1317,14 @@ describe('With Policy: field-level policy', () => { }); await expect( - withPolicy({ id: 1, role: 'ADMIN' }).user.update({ + enhance({ id: 1, role: 'ADMIN' }).user.update({ where: { id: user1.id }, data: { role: 'ADMIN' }, }) ).toResolveTruthy(); await expect( - withPolicy({ id: 1, role: 'USER' }).user.update({ + enhance({ id: 1, role: 'USER' }).user.update({ where: { id: user1.id }, data: { role: 'ADMIN' }, }) @@ -1335,7 +1335,7 @@ describe('With Policy: field-level policy', () => { }); await expect( - withPolicy({ id: 1, role: 'ADMIN' }).user.update({ + enhance({ id: 1, role: 'ADMIN' }).user.update({ where: { id: user2.id }, data: { role: 'ADMIN' }, }) diff --git a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts index 8727f1561..16f56dddd 100644 --- a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts @@ -5,7 +5,7 @@ describe('With Policy: field validation', () => { let db: FullDbClientContract; beforeAll(async () => { - const { withPolicy, prisma: _prisma } = await loadSchema( + const { enhance, prisma: _prisma } = await loadSchema( ` model User { id String @id @default(cuid()) @@ -35,6 +35,8 @@ describe('With Policy: field validation', () => { text3 String @length(min: 3) text4 String @length(max: 5) text5 String? @endsWith('xyz') + text6 String? @trim @lower + text7 String? @upper @@allow('all', true) } @@ -49,7 +51,7 @@ describe('With Policy: field validation', () => { } ` ); - db = withPolicy(); + db = enhance(); }); beforeEach(() => { @@ -495,4 +497,61 @@ describe('With Policy: field validation', () => { }) ).toResolveTruthy(); }); + + it('string transformation', async () => { + await db.user.create({ + data: { + id: '1', + password: 'abc123!@#', + email: 'who@myorg.com', + handle: 'user1', + }, + }); + + await expect( + db.userData.create({ + data: { + userId: '1', + a: 1, + b: 0, + c: -1, + d: 0, + text1: 'abc123', + text2: 'def', + text3: 'aaa', + text4: 'abcab', + text6: ' AbC ', + text7: 'abc', + }, + }) + ).resolves.toMatchObject({ text6: 'abc', text7: 'ABC' }); + + await expect( + db.user.create({ + data: { + id: '2', + password: 'abc123!@#', + email: 'who@myorg.com', + handle: 'user2', + userData: { + create: { + a: 1, + b: 0, + c: -1, + d: 0, + text1: 'abc123', + text2: 'def', + text3: 'aaa', + text4: 'abcab', + text6: ' AbC ', + text7: 'abc', + }, + }, + }, + include: { userData: true }, + }) + ).resolves.toMatchObject({ + userData: expect.objectContaining({ text6: 'abc', text7: 'ABC' }), + }); + }); }); diff --git a/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts b/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts index 264c5da28..6c27aab1c 100644 --- a/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts +++ b/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts @@ -13,7 +13,7 @@ describe('With Policy: fluent API', () => { }); it('fluent api', async () => { - const { withPolicy, prisma } = await loadSchema( + const { enhance, prisma } = await loadSchema( ` model User { id Int @id @@ -58,7 +58,7 @@ model Post { }, }); - const db = withPolicy({ id: 1 }); + const db = enhance({ id: 1 }); // check policies await expect(db.user.findUnique({ where: { id: 1 } }).posts()).resolves.toHaveLength(2); diff --git a/tests/integration/tests/enhancements/with-policy/multi-field-unique.test.ts b/tests/integration/tests/enhancements/with-policy/multi-field-unique.test.ts index 3dcc07850..f0eeb1a8a 100644 --- a/tests/integration/tests/enhancements/with-policy/multi-field-unique.test.ts +++ b/tests/integration/tests/enhancements/with-policy/multi-field-unique.test.ts @@ -13,7 +13,7 @@ describe('With Policy: multi-field unique', () => { }); it('toplevel crud test unnamed constraint', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -28,7 +28,7 @@ describe('With Policy: multi-field unique', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.model.create({ data: { a: 'a1', b: 'b1', x: 1 } })).toResolveTruthy(); await expect(db.model.create({ data: { a: 'a1', b: 'b1', x: 2 } })).toBeRejectedWithCode('P2002'); @@ -43,7 +43,7 @@ describe('With Policy: multi-field unique', () => { }); it('toplevel crud test named constraint', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -58,7 +58,7 @@ describe('With Policy: multi-field unique', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.model.create({ data: { a: 'a1', b: 'b1', x: 1 } })).toResolveTruthy(); await expect(db.model.findUnique({ where: { myconstraint: { a: 'a1', b: 'b1' } } })).toResolveTruthy(); @@ -73,7 +73,7 @@ describe('With Policy: multi-field unique', () => { }); it('nested crud test', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -95,7 +95,7 @@ describe('With Policy: multi-field unique', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.m1.create({ data: { id: '1', m2: { create: { a: 'a1', b: 'b1', x: 1 } } } })).toResolveTruthy(); await expect( diff --git a/tests/integration/tests/enhancements/with-policy/multi-id-fields.test.ts b/tests/integration/tests/enhancements/with-policy/multi-id-fields.test.ts index f48cdba45..227dc5a27 100644 --- a/tests/integration/tests/enhancements/with-policy/multi-id-fields.test.ts +++ b/tests/integration/tests/enhancements/with-policy/multi-id-fields.test.ts @@ -13,7 +13,7 @@ describe('With Policy: multiple id fields', () => { }); it('multi-id fields', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model A { x String @@ -43,7 +43,7 @@ describe('With Policy: multiple id fields', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.a.create({ data: { x: '1', y: 1, value: 0 } })).toBeRejectedByPolicy(); await expect(db.a.create({ data: { x: '1', y: 2, value: 1 } })).toResolveTruthy(); @@ -70,7 +70,7 @@ describe('With Policy: multiple id fields', () => { }); it('multi-id auth', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { x String @@ -124,7 +124,7 @@ describe('With Policy: multiple id fields', () => { await prisma.user.create({ data: { x: '1', y: '1' } }); await prisma.user.create({ data: { x: '1', y: '2' } }); - const anonDb = withPolicy(); + const anonDb = enhance(); await expect( anonDb.m.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } }) @@ -139,7 +139,7 @@ describe('With Policy: multiple id fields', () => { anonDb.n.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } }) ).toBeRejectedByPolicy(); - const db = withPolicy({ x: '1', y: '1' }); + const db = enhance({ x: '1', y: '1' }); await expect(db.m.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } })).toBeRejectedByPolicy(); await expect(db.m.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } })).toResolveTruthy(); @@ -149,13 +149,13 @@ describe('With Policy: multiple id fields', () => { await expect(db.p.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } })).toResolveTruthy(); await expect( - withPolicy(undefined).q.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } }) + enhance(undefined).q.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } }) ).toBeRejectedByPolicy(); await expect(db.q.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } })).toResolveTruthy(); }); it('multi-id to-one nested write', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model A { x Int @@ -177,7 +177,7 @@ describe('With Policy: multiple id fields', () => { } ` ); - const db = withPolicy(); + const db = enhance(); await expect( db.b.create({ data: { @@ -205,7 +205,7 @@ describe('With Policy: multiple id fields', () => { }); it('multi-id to-many nested write', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model A { x Int @@ -237,7 +237,7 @@ describe('With Policy: multiple id fields', () => { } ` ); - const db = withPolicy(); + const db = enhance(); await expect( db.b.create({ data: { diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts index b112aeeb1..777af1118 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts @@ -13,7 +13,7 @@ describe('With Policy:nested to-many', () => { }); it('read filtering', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -34,7 +34,7 @@ describe('With Policy:nested to-many', () => { ` ); - const db = withPolicy(); + const db = enhance(); let read = await db.m1.create({ include: { m2: true }, @@ -62,7 +62,7 @@ describe('With Policy:nested to-many', () => { }); it('read condition hoisting', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -108,7 +108,7 @@ describe('With Policy:nested to-many', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ include: { m2: true }, @@ -144,7 +144,7 @@ describe('With Policy:nested to-many', () => { }); it('create simple', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -165,7 +165,7 @@ describe('With Policy:nested to-many', () => { ` ); - const db = withPolicy(); + const db = enhance(); // single create denied await expect( @@ -211,7 +211,7 @@ describe('With Policy:nested to-many', () => { }); it('update simple', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -233,7 +233,7 @@ describe('With Policy:nested to-many', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ data: { @@ -285,7 +285,7 @@ describe('With Policy:nested to-many', () => { }); it('update with create from one to many', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -307,7 +307,7 @@ describe('With Policy:nested to-many', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ data: { @@ -342,7 +342,7 @@ describe('With Policy:nested to-many', () => { }); it('update with create from many to one', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -364,7 +364,7 @@ describe('With Policy:nested to-many', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.m2.create({ data: { id: '1' } }); @@ -392,7 +392,7 @@ describe('With Policy:nested to-many', () => { }); it('update with delete', async () => { - const { withPolicy, prisma } = await loadSchema( + const { enhance, prisma } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -415,7 +415,7 @@ describe('With Policy:nested to-many', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ data: { @@ -496,7 +496,7 @@ describe('With Policy:nested to-many', () => { }); it('create with nested read', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -530,7 +530,7 @@ describe('With Policy:nested to-many', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect( db.m1.create({ @@ -589,7 +589,7 @@ describe('With Policy:nested to-many', () => { }); it('update with nested read', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -621,7 +621,7 @@ describe('With Policy:nested to-many', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ data: { id: '1', diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts index 2e14b6d02..4b30c095f 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts @@ -13,7 +13,7 @@ describe('With Policy:nested to-one', () => { }); it('read filtering for optional relation', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -34,7 +34,7 @@ describe('With Policy:nested to-one', () => { ` ); - const db = withPolicy(); + const db = enhance(); let read = await db.m1.create({ include: { m2: true }, @@ -60,7 +60,7 @@ describe('With Policy:nested to-one', () => { }); it('read rejection for non-optional relation', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -91,7 +91,7 @@ describe('With Policy:nested to-one', () => { }, }); - const db = withPolicy(); + const db = enhance(); await expect(db.m2.findUnique({ where: { id: '1' }, include: { m1: true } })).toResolveFalsy(); await expect(db.m2.findMany({ include: { m1: true } })).resolves.toHaveLength(0); @@ -100,7 +100,7 @@ describe('With Policy:nested to-one', () => { }); it('read condition hoisting', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -134,7 +134,7 @@ describe('With Policy:nested to-one', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ include: { m2: true }, @@ -153,7 +153,7 @@ describe('With Policy:nested to-one', () => { }); it('create and update tests', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -175,7 +175,7 @@ describe('With Policy:nested to-one', () => { ` ); - const db = withPolicy(); + const db = enhance(); // create denied await expect( @@ -213,7 +213,7 @@ describe('With Policy:nested to-one', () => { }); it('nested create', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -236,7 +236,7 @@ describe('With Policy:nested to-one', () => { { logPrismaQuery: true } ); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ data: { @@ -269,7 +269,7 @@ describe('With Policy:nested to-one', () => { }); it('nested delete', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -292,7 +292,7 @@ describe('With Policy:nested to-one', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ data: { @@ -335,7 +335,7 @@ describe('With Policy:nested to-one', () => { }); it('nested relation delete', async () => { - const { withPolicy, prisma } = await loadSchema( + const { enhance, prisma } = await loadSchema( ` model User { id String @id @default(uuid()) @@ -356,7 +356,7 @@ describe('With Policy:nested to-one', () => { ` ); - await withPolicy({ id: 'user1' }).m1.create({ + await enhance({ id: 'user1' }).m1.create({ data: { id: 'm1', value: 1, @@ -364,7 +364,7 @@ describe('With Policy:nested to-one', () => { }); await expect( - withPolicy({ id: 'user2' }).user.create({ + enhance({ id: 'user2' }).user.create({ data: { id: 'user2', m1: { @@ -375,7 +375,7 @@ describe('With Policy:nested to-one', () => { ).toResolveTruthy(); await expect( - withPolicy({ id: 'user2' }).user.update({ + enhance({ id: 'user2' }).user.update({ where: { id: 'user2' }, data: { m1: { delete: true }, @@ -384,7 +384,7 @@ describe('With Policy:nested to-one', () => { ).toBeRejectedByPolicy(); await expect( - withPolicy({ id: 'user1' }).user.create({ + enhance({ id: 'user1' }).user.create({ data: { id: 'user1', m1: { @@ -395,7 +395,7 @@ describe('With Policy:nested to-one', () => { ).toResolveTruthy(); await expect( - withPolicy({ id: 'user1' }).user.update({ + enhance({ id: 'user1' }).user.update({ where: { id: 'user1' }, data: { m1: { delete: true }, diff --git a/tests/integration/tests/enhancements/with-policy/options.test.ts b/tests/integration/tests/enhancements/with-policy/options.test.ts index 2c661ceb4..55c5458f4 100644 --- a/tests/integration/tests/enhancements/with-policy/options.test.ts +++ b/tests/integration/tests/enhancements/with-policy/options.test.ts @@ -1,20 +1,9 @@ -import { withPolicy } from '@zenstackhq/runtime'; import { loadSchema } from '@zenstackhq/testtools'; import path from 'path'; describe('Password test', () => { - let origDir: string; - - beforeAll(async () => { - origDir = path.resolve('.'); - }); - - afterEach(async () => { - process.chdir(origDir); - }); - it('load path', async () => { - const { prisma } = await loadSchema( + const { prisma, projectDir } = await loadSchema( ` model Foo { id String @id @default(cuid()) @@ -25,7 +14,8 @@ describe('Password test', () => { { getPrismaOnly: true, output: './zen' } ); - const db = withPolicy(prisma, undefined, { loadPath: './zen' }); + const enhance = require(path.join(projectDir, 'zen/enhance')).enhance; + const db = enhance(prisma, { loadPath: './zen' }); await expect( db.foo.create({ data: { x: 0 }, @@ -34,7 +24,7 @@ describe('Password test', () => { }); it('overrides', async () => { - const { prisma } = await loadSchema( + const { prisma, projectDir } = await loadSchema( ` model Foo { id String @id @default(cuid()) @@ -45,9 +35,10 @@ describe('Password test', () => { { getPrismaOnly: true, output: './zen' } ); - const db = withPolicy(prisma, undefined, { - modelMeta: require(path.resolve('./zen/model-meta')).default, - policy: require(path.resolve('./zen/policy')).default, + const enhance = require(path.join(projectDir, 'zen/enhance')).enhance; + const db = enhance(prisma, { + modelMeta: require(path.join(projectDir, 'zen/model-meta')).default, + policy: require(path.resolve(projectDir, 'zen/policy')).default, }); await expect( db.foo.create({ diff --git a/tests/integration/tests/enhancements/with-policy/permissions-checker.test.ts b/tests/integration/tests/enhancements/with-policy/permissions-checker.test.ts new file mode 100644 index 000000000..e055fc035 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/permissions-checker.test.ts @@ -0,0 +1,107 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import path from 'path'; + +describe('With Policy: permissions checker test', () => { + let origDir: string; + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + afterEach(() => { + process.chdir(origDir); + }); + + it('`check` method on enhanced prisma', async () => { + const { enhance, enhanceRaw, prisma } = await loadSchema( + ` + model User { + id String @id + age Int + email String @unique + role String @default("user") + posts Post[] + comments Comment[] + + // @@allow('all', true) + @@allow('create,read', age > 18 && age < 60) + + @@deny('update', age > 18 && age < 60) + + @@deny('delete', true) + @@allow('delete', true) + + } + + model Post { + id String @id @default(uuid()) + title String + rating Int + published Boolean @default(false) + authorId String @default("userId-1") + author User @relation(fields: [authorId], references: [id]) + comments Comment[] + + + @@allow('create,read', auth() == author && title == "Title" && rating > 1) + @@deny('read', !published || published == false) + + @@deny('update', rating < 10) + @@allow('update', rating > 5) + + @@deny('delete', auth() == null) + } + + model Comment { + id String @id @default(uuid()) + content String + postId String + post Post @relation(fields: [postId], references: [id]) + authorId String + author User @relation(fields: [authorId], references: [id]) + + @@deny('read', auth().age > 18 && auth().age < 60) + @@allow('create', auth().role == "editor") + @@allow('delete', auth().id == author.id) + } + ` + ); + + const authDb = enhance({ id: 'userId-1' }); + const db = enhanceRaw(prisma, {}); + + // check user + await expect(db.user.check('read', {})).toResolveTruthy(); + await expect(authDb.user.check('read', {})).toResolveTruthy(); + await expect(authDb.user.check('read', { age: { lt: 10 } })).toResolveFalsy(); + await expect(authDb.user.check('update', {})).toResolveTruthy(); + await expect(authDb.user.check('delete', {})).toResolveFalsy(); + + // check post + await expect(db.post.check('read', {})).toResolveFalsy(); + await expect(authDb.post.check('read', {})).toResolveTruthy(); + await expect(authDb.post.check('read', { author: { id: 'userId-1' } })).toResolveTruthy(); + await expect(authDb.post.check('read', { author: { id: 'invalid' } })).toResolveFalsy(); + await expect(authDb.post.check('read', { authorId: 'userId-1' })).toResolveTruthy(); + await expect(authDb.post.check('read', { authorId: 'invalid' })).toResolveFalsy(); + await expect(authDb.post.check('read', { title: 'Title' })).toResolveTruthy(); + await expect(authDb.post.check('read', { title: 'invalid' })).toResolveFalsy(); + await expect(authDb.post.check('read', { rating: 2 })).toResolveTruthy(); + await expect(authDb.post.check('read', { rating: 0 })).toResolveFalsy(); + await expect(authDb.post.check('read', { rating: { gt: 8 } })).toResolveTruthy(); + await expect(authDb.post.check('read', { rating: { lt: 1 } })).toResolveFalsy(); + await expect(authDb.post.check('create', {})).toResolveTruthy(); + await expect(authDb.post.check('update', {})).toResolveTruthy(); + await expect(authDb.post.check('update', { rating: { lt: 1 } })).toResolveFalsy(); + await expect(authDb.post.check('update', { rating: { gt: 10 } })).toResolveTruthy(); + await expect(authDb.post.check('update', { rating: 8 })).toResolveFalsy(); + await expect(db.post.check('delete', {})).toResolveFalsy(); + await expect(authDb.post.check('delete', {})).toResolveTruthy(); + await expect(authDb.post.check('read', { published: true })).toResolveTruthy(); + await expect(authDb.post.check('read', { published: false })).toResolveFalsy(); + + await expect(db.comment.check('delete', {})).toResolveFalsy(); + await expect(authDb.comment.check('delete', {})).toResolveTruthy(); + await expect(authDb.comment.check('delete', { author: { id: 'invalid' } })).toResolveFalsy(); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/petstore-sample.test.ts b/tests/integration/tests/enhancements/with-policy/petstore-sample.test.ts index 9c251faf5..691c6176a 100644 --- a/tests/integration/tests/enhancements/with-policy/petstore-sample.test.ts +++ b/tests/integration/tests/enhancements/with-policy/petstore-sample.test.ts @@ -7,11 +7,11 @@ describe('Pet Store Policy Tests', () => { let prisma: FullDbClientContract; beforeAll(async () => { - const { withPolicy, prisma: _prisma } = await loadSchemaFromFile( + const { enhance, prisma: _prisma } = await loadSchemaFromFile( path.join(__dirname, '../../schema/petstore.zmodel'), { addPrelude: false } ); - getDb = withPolicy; + getDb = enhance; prisma = _prisma; }); diff --git a/tests/integration/tests/enhancements/with-policy/post-update.test.ts b/tests/integration/tests/enhancements/with-policy/post-update.test.ts index c40d338a3..e2d7e0156 100644 --- a/tests/integration/tests/enhancements/with-policy/post-update.test.ts +++ b/tests/integration/tests/enhancements/with-policy/post-update.test.ts @@ -13,7 +13,7 @@ describe('With Policy: post update', () => { }); it('simple allow', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -25,7 +25,7 @@ describe('With Policy: post update', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.model.create({ data: { id: '1', value: 0 } })).toResolveTruthy(); await expect(db.model.update({ where: { id: '1' }, data: { value: 1 } })).toBeRejectedByPolicy(); @@ -33,7 +33,7 @@ describe('With Policy: post update', () => { }); it('simple deny', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -45,7 +45,7 @@ describe('With Policy: post update', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.model.create({ data: { id: '1', value: 0 } })).toResolveTruthy(); await expect(db.model.update({ where: { id: '1' }, data: { value: 1 } })).toBeRejectedByPolicy(); @@ -53,7 +53,7 @@ describe('With Policy: post update', () => { }); it('mixed pre and post', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -65,7 +65,7 @@ describe('With Policy: post update', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.model.create({ data: { id: '1', value: 0 } })).toResolveTruthy(); await expect(db.model.update({ where: { id: '1' }, data: { value: 1 } })).toBeRejectedByPolicy(); @@ -76,7 +76,7 @@ describe('With Policy: post update', () => { }); it('functions pre-update', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -89,7 +89,7 @@ describe('With Policy: post update', () => { ` ); - const db = withPolicy(); + const db = enhance(); await prisma.model.create({ data: { id: '1', value: 'good', x: 1 } }); await expect(db.model.update({ where: { id: '1' }, data: { value: 'hello' } })).toBeRejectedByPolicy(); @@ -100,7 +100,7 @@ describe('With Policy: post update', () => { }); it('functions post-update', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -114,7 +114,7 @@ describe('With Policy: post update', () => { { logPrismaQuery: true } ); - const db = withPolicy(); + const db = enhance(); await prisma.model.create({ data: { id: '1', value: 'good', x: 1 } }); await expect(db.model.update({ where: { id: '1' }, data: { value: 'nice' } })).toBeRejectedByPolicy(); @@ -124,7 +124,7 @@ describe('With Policy: post update', () => { }); it('collection predicate pre-update', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -145,7 +145,7 @@ describe('With Policy: post update', () => { ` ); - const db = withPolicy(); + const db = enhance(); await prisma.m1.create({ data: { @@ -181,7 +181,7 @@ describe('With Policy: post update', () => { }); it('collection predicate post-update', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -202,7 +202,7 @@ describe('With Policy: post update', () => { ` ); - const db = withPolicy(); + const db = enhance(); await prisma.m1.create({ data: { @@ -238,7 +238,7 @@ describe('With Policy: post update', () => { }); it('nested to-many', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -258,7 +258,7 @@ describe('With Policy: post update', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect( db.m1.create({ @@ -297,7 +297,7 @@ describe('With Policy: post update', () => { }); it('nested to-one', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -317,7 +317,7 @@ describe('With Policy: post update', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect( db.m1.create({ @@ -350,7 +350,7 @@ describe('With Policy: post update', () => { }); it('nested select', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -370,7 +370,7 @@ describe('With Policy: post update', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect( db.m1.create({ @@ -401,7 +401,7 @@ describe('With Policy: post update', () => { }); it('deep nesting', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -432,7 +432,7 @@ describe('With Policy: post update', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect( db.m1.create({ diff --git a/tests/integration/tests/enhancements/with-policy/query-reduction.test.ts b/tests/integration/tests/enhancements/with-policy/query-reduction.test.ts index 1654fba96..264119453 100644 --- a/tests/integration/tests/enhancements/with-policy/query-reduction.test.ts +++ b/tests/integration/tests/enhancements/with-policy/query-reduction.test.ts @@ -13,7 +13,7 @@ describe('With Policy: query reduction', () => { }); it('test query reduction', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -65,8 +65,8 @@ describe('With Policy: query reduction', () => { }, }); - const dbUser1 = withPolicy({ id: 1 }); - const dbUser2 = withPolicy({ id: 2 }); + const dbUser1 = enhance({ id: 1 }); + const dbUser2 = enhance({ id: 2 }); await expect( dbUser1.user.findMany({ diff --git a/tests/integration/tests/enhancements/with-policy/refactor.test.ts b/tests/integration/tests/enhancements/with-policy/refactor.test.ts index 126c038fa..0cd490f6c 100644 --- a/tests/integration/tests/enhancements/with-policy/refactor.test.ts +++ b/tests/integration/tests/enhancements/with-policy/refactor.test.ts @@ -21,7 +21,7 @@ describe('With Policy: refactor tests', () => { beforeEach(async () => { dbUrl = await createPostgresDb(DB_NAME); - const { prisma: _prisma, withPolicy } = await loadSchemaFromFile( + const { prisma: _prisma, enhance } = await loadSchemaFromFile( path.join(__dirname, '../../schema/refactor-pg.zmodel'), { provider: 'postgresql', @@ -29,7 +29,7 @@ describe('With Policy: refactor tests', () => { logPrismaQuery: true, } ); - getDb = withPolicy; + getDb = enhance; prisma = _prisma; anonDb = getDb(); user1Db = getDb({ id: 1 }); diff --git a/tests/integration/tests/enhancements/with-policy/relation-many-to-many-filter.test.ts b/tests/integration/tests/enhancements/with-policy/relation-many-to-many-filter.test.ts index fe0c686db..e7ddb043e 100644 --- a/tests/integration/tests/enhancements/with-policy/relation-many-to-many-filter.test.ts +++ b/tests/integration/tests/enhancements/with-policy/relation-many-to-many-filter.test.ts @@ -35,9 +35,9 @@ describe('With Policy: relation many-to-many filter', () => { `; it('some filter', async () => { - const { withPolicy } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ data: { @@ -128,9 +128,9 @@ describe('With Policy: relation many-to-many filter', () => { }); it('none filter', async () => { - const { withPolicy } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ data: { @@ -211,9 +211,9 @@ describe('With Policy: relation many-to-many filter', () => { }); it('every filter', async () => { - const { withPolicy } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ data: { diff --git a/tests/integration/tests/enhancements/with-policy/relation-one-to-many-filter.test.ts b/tests/integration/tests/enhancements/with-policy/relation-one-to-many-filter.test.ts index 3737bbf4c..1a1c40406 100644 --- a/tests/integration/tests/enhancements/with-policy/relation-one-to-many-filter.test.ts +++ b/tests/integration/tests/enhancements/with-policy/relation-one-to-many-filter.test.ts @@ -45,9 +45,9 @@ describe('With Policy: relation one-to-many filter', () => { `; it('some filter', async () => { - const { withPolicy } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withPolicy(); + const db = enhance(); // m1 with m2 and m3 await db.m1.create({ @@ -163,9 +163,9 @@ describe('With Policy: relation one-to-many filter', () => { }); it('none filter', async () => { - const { withPolicy } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withPolicy(); + const db = enhance(); // m1 with m2 and m3 await db.m1.create({ @@ -281,9 +281,9 @@ describe('With Policy: relation one-to-many filter', () => { }); it('every filter', async () => { - const { withPolicy } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withPolicy(); + const db = enhance(); // m1 with m2 and m3 await db.m1.create({ @@ -399,9 +399,9 @@ describe('With Policy: relation one-to-many filter', () => { }); it('_count filter', async () => { - const { withPolicy } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withPolicy(); + const db = enhance(); // m1 with m2 and m3 await db.m1.create({ diff --git a/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts b/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts index 7c26bc854..d076e18e5 100644 --- a/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts +++ b/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts @@ -45,9 +45,9 @@ describe('With Policy: relation one-to-one filter', () => { `; it('is filter', async () => { - const { withPolicy } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withPolicy(); + const db = enhance(); // m1 with m2 and m3 await db.m1.create({ @@ -152,9 +152,9 @@ describe('With Policy: relation one-to-one filter', () => { }); it('isNot filter', async () => { - const { withPolicy } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withPolicy(); + const db = enhance(); // m1 with m2 and m3 await db.m1.create({ @@ -261,9 +261,9 @@ describe('With Policy: relation one-to-one filter', () => { }); it('direct object filter', async () => { - const { withPolicy } = await loadSchema(model); + const { enhance } = await loadSchema(model); - const db = withPolicy(); + const db = enhance(); // m1 with m2 and m3 await db.m1.create({ diff --git a/tests/integration/tests/enhancements/with-policy/self-relation.test.ts b/tests/integration/tests/enhancements/with-policy/self-relation.test.ts index dc7cb96ca..525d30043 100644 --- a/tests/integration/tests/enhancements/with-policy/self-relation.test.ts +++ b/tests/integration/tests/enhancements/with-policy/self-relation.test.ts @@ -13,7 +13,7 @@ describe('With Policy: self relations', () => { }); it('one-to-one', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -28,7 +28,7 @@ describe('With Policy: self relations', () => { ` ); - const db = withPolicy(); + const db = enhance(); // create denied await expect( @@ -90,7 +90,7 @@ describe('With Policy: self relations', () => { }); it('one-to-many', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -105,7 +105,7 @@ describe('With Policy: self relations', () => { ` ); - const db = withPolicy(); + const db = enhance(); // create denied await expect( @@ -157,7 +157,7 @@ describe('With Policy: self relations', () => { }); it('many-to-many', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -171,7 +171,7 @@ describe('With Policy: self relations', () => { ` ); - const db = withPolicy(); + const db = enhance(); // create denied await expect( diff --git a/tests/integration/tests/enhancements/with-policy/subscription.test.ts b/tests/integration/tests/enhancements/with-policy/subscription.test.ts index 2befdd42a..a4dccf807 100644 --- a/tests/integration/tests/enhancements/with-policy/subscription.test.ts +++ b/tests/integration/tests/enhancements/with-policy/subscription.test.ts @@ -17,7 +17,7 @@ describe.skip('With Policy: subscription test', () => { }); it('subscribe auth check', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -42,11 +42,11 @@ describe.skip('With Policy: subscription test', () => { const rawSub = await prisma.model.subscribe(); - const anonDb = withPolicy(); + const anonDb = enhance(); console.log('Anonymous db subscribing'); const anonSub = await anonDb.model.subscribe(); - const authDb = withPolicy({ id: 1 }); + const authDb = enhance({ id: 1 }); console.log('Auth db subscribing'); const authSub = await authDb.model.subscribe(); @@ -75,7 +75,7 @@ describe.skip('With Policy: subscription test', () => { }); it('subscribe model check', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model Model { id Int @id @default(autoincrement()) @@ -96,7 +96,7 @@ describe.skip('With Policy: subscription test', () => { const rawSub = await prisma.model.subscribe(); - const enhanced = withPolicy(); + const enhanced = enhance(); console.log('Auth db subscribing'); const enhancedSub = await enhanced.model.subscribe(); @@ -130,7 +130,7 @@ describe.skip('With Policy: subscription test', () => { }); it('subscribe partial', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model Model { id Int @id @default(autoincrement()) @@ -151,7 +151,7 @@ describe.skip('With Policy: subscription test', () => { const rawSub = await prisma.model.subscribe({ create: {} }); - const enhanced = withPolicy(); + const enhanced = enhance(); console.log('Auth db subscribing'); const enhancedSub = await enhanced.model.subscribe({ create: {} }); @@ -185,7 +185,7 @@ describe.skip('With Policy: subscription test', () => { }); it('subscribe mixed model check', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model Model { id Int @id @default(autoincrement()) @@ -210,7 +210,7 @@ describe.skip('With Policy: subscription test', () => { delete: { before: { name: { contains: 'world' } } }, }); - const enhanced = withPolicy(); + const enhanced = enhance(); console.log('Auth db subscribing'); const enhancedSub = await enhanced.model.subscribe({ create: { after: { name: { contains: 'world' } } }, diff --git a/tests/integration/tests/enhancements/with-policy/todo-sample.test.ts b/tests/integration/tests/enhancements/with-policy/todo-sample.test.ts index 2b7dd416b..fe26dd561 100644 --- a/tests/integration/tests/enhancements/with-policy/todo-sample.test.ts +++ b/tests/integration/tests/enhancements/with-policy/todo-sample.test.ts @@ -7,11 +7,11 @@ describe('Todo Policy Tests', () => { let prisma: FullDbClientContract; beforeAll(async () => { - const { withPolicy, prisma: _prisma } = await loadSchemaFromFile( + const { enhance, prisma: _prisma } = await loadSchemaFromFile( path.join(__dirname, '../../schema/todo.zmodel'), { addPrelude: false } ); - getDb = withPolicy; + getDb = enhance; prisma = _prisma; }); diff --git a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts index 99179e015..61f25dc25 100644 --- a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts +++ b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts @@ -13,7 +13,7 @@ describe('With Policy: toplevel operations', () => { }); it('read tests', async () => { - const { withPolicy, prisma } = await loadSchema( + const { enhance, prisma } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -25,7 +25,7 @@ describe('With Policy: toplevel operations', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect( db.model.create({ @@ -62,7 +62,7 @@ describe('With Policy: toplevel operations', () => { }); it('write tests', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -75,7 +75,7 @@ describe('With Policy: toplevel operations', () => { ` ); - const db = withPolicy(); + const db = enhance(); // create denied await expect( @@ -148,7 +148,7 @@ describe('With Policy: toplevel operations', () => { }); it('delete tests', async () => { - const { withPolicy, prisma } = await loadSchema( + const { enhance, prisma } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -161,7 +161,7 @@ describe('With Policy: toplevel operations', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.model.delete({ where: { id: '1' } })).toBeNotFound(); diff --git a/tests/integration/tests/enhancements/with-policy/unique-as-id.test.ts b/tests/integration/tests/enhancements/with-policy/unique-as-id.test.ts index e4d399204..a7ec74fa5 100644 --- a/tests/integration/tests/enhancements/with-policy/unique-as-id.test.ts +++ b/tests/integration/tests/enhancements/with-policy/unique-as-id.test.ts @@ -13,7 +13,7 @@ describe('With Policy: unique as id', () => { }); it('unique fields', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model A { x String @unique @@ -38,7 +38,7 @@ describe('With Policy: unique as id', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.a.create({ data: { x: '1', y: 1, value: 0 } })).toBeRejectedByPolicy(); await expect(db.a.create({ data: { x: '1', y: 2, value: 1 } })).toResolveTruthy(); @@ -64,7 +64,7 @@ describe('With Policy: unique as id', () => { }); it('unique fields mixed with id', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model A { id Int @id @default(autoincrement()) @@ -91,7 +91,7 @@ describe('With Policy: unique as id', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.a.create({ data: { x: '1', y: 1, value: 0 } })).toBeRejectedByPolicy(); await expect(db.a.create({ data: { x: '1', y: 2, value: 1 } })).toResolveTruthy(); @@ -117,7 +117,7 @@ describe('With Policy: unique as id', () => { }); it('model-level unique fields', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model A { x String @@ -147,7 +147,7 @@ describe('With Policy: unique as id', () => { ` ); - const db = withPolicy(); + const db = enhance(); await expect(db.a.create({ data: { x: '1', y: 1, value: 0 } })).toBeRejectedByPolicy(); await expect(db.a.create({ data: { x: '1', y: 2, value: 1 } })).toResolveTruthy(); diff --git a/tests/integration/tests/enhancements/with-policy/view.test.ts b/tests/integration/tests/enhancements/with-policy/view.test.ts index f5abe6439..3c541d2b0 100644 --- a/tests/integration/tests/enhancements/with-policy/view.test.ts +++ b/tests/integration/tests/enhancements/with-policy/view.test.ts @@ -13,7 +13,7 @@ describe('View Policy Test', () => { }); it('view policy', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` datasource db { provider = "sqlite" @@ -91,7 +91,7 @@ describe('View Policy Test', () => { }, }); - const db = withPolicy(); + const db = enhance(); await expect(prisma.userInfo.findMany()).resolves.toHaveLength(2); await expect(db.userInfo.findMany()).resolves.toHaveLength(1); diff --git a/tests/integration/tests/frameworks/nextjs/test-project/package.json b/tests/integration/tests/frameworks/nextjs/test-project/package.json index 1461849f1..7b93ec340 100644 --- a/tests/integration/tests/frameworks/nextjs/test-project/package.json +++ b/tests/integration/tests/frameworks/nextjs/test-project/package.json @@ -9,7 +9,7 @@ "lint": "next lint" }, "dependencies": { - "@prisma/client": "^4.8.0", + "@prisma/client": "^5.0.0", "@types/node": "18.11.18", "@types/react": "18.0.27", "@types/react-dom": "18.0.10", @@ -22,6 +22,6 @@ "zod": "^3.22.4" }, "devDependencies": { - "prisma": "^4.8.0" + "prisma": "^5.0.0" } } diff --git a/tests/integration/tests/frameworks/trpc/generation.test.ts b/tests/integration/tests/frameworks/trpc/generation.test.ts index 5e15d9943..a58f5965d 100644 --- a/tests/integration/tests/frameworks/trpc/generation.test.ts +++ b/tests/integration/tests/frameworks/trpc/generation.test.ts @@ -21,6 +21,7 @@ describe('tRPC Routers Generation Tests', () => { `${path.join(__dirname, '../../../../../.build/zenstackhq-sdk-' + ver + '.tgz')}`, `${path.join(__dirname, '../../../../../.build/zenstackhq-runtime-' + ver + '.tgz')}`, `${path.join(__dirname, '../../../../../.build/zenstackhq-trpc-' + ver + '.tgz')}`, + `${path.join(__dirname, '../../../../../.build/zenstackhq-server-' + ver + '.tgz')}`, ]; const deps = depPkgs.join(' '); @@ -35,7 +36,9 @@ describe('tRPC Routers Generation Tests', () => { process.chdir(testDir); run('npm install'); run('npm install ' + deps); - run('npx zenstack generate --no-dependency-check --schema ./todo.zmodel', { NODE_PATH: 'node_modules' }); + run('npx zenstack generate --no-dependency-check --schema ./todo.zmodel', { + NODE_PATH: 'node_modules', + }); run('npm run build', { NODE_PATH: 'node_modules' }); }); }); diff --git a/tests/integration/tests/frameworks/trpc/test-project/package.json b/tests/integration/tests/frameworks/trpc/test-project/package.json index f27687e63..8445cc451 100644 --- a/tests/integration/tests/frameworks/trpc/test-project/package.json +++ b/tests/integration/tests/frameworks/trpc/test-project/package.json @@ -9,7 +9,7 @@ "lint": "next lint" }, "dependencies": { - "@prisma/client": "^4.8.0", + "@prisma/client": "^5.0.0", "@tanstack/react-query": "^4.22.4", "@trpc/client": "^10.34.0", "@trpc/next": "^10.34.0", @@ -26,6 +26,6 @@ "zod": "^3.22.4" }, "devDependencies": { - "prisma": "^4.8.0" + "prisma": "^5.0.0" } } diff --git a/tests/integration/tests/frameworks/trpc/test-project/todo.zmodel b/tests/integration/tests/frameworks/trpc/test-project/todo.zmodel index 6840f8978..92363c825 100644 --- a/tests/integration/tests/frameworks/trpc/test-project/todo.zmodel +++ b/tests/integration/tests/frameworks/trpc/test-project/todo.zmodel @@ -7,16 +7,6 @@ generator js { provider = 'prisma-client-js' } -plugin meta { - provider = '@core/model-meta' - output = '.zenstack' -} - -plugin policy { - provider = '@core/access-policy' - output = '.zenstack' -} - plugin trpc { provider = '@zenstackhq/trpc' output = 'server/routers/generated' diff --git a/tests/integration/tests/misc/stacktrace.test.ts b/tests/integration/tests/misc/stacktrace.test.ts index 6573ed088..f652c5514 100644 --- a/tests/integration/tests/misc/stacktrace.test.ts +++ b/tests/integration/tests/misc/stacktrace.test.ts @@ -13,7 +13,7 @@ describe('Stack trace tests', () => { }); it('stack trace', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -21,7 +21,7 @@ describe('Stack trace tests', () => { ` ); - const db = withPolicy(); + const db = enhance(); let error: Error | undefined = undefined; try { @@ -31,7 +31,7 @@ describe('Stack trace tests', () => { } expect(error?.stack).toContain( - "Error calling enhanced Prisma method `create`: denied by policy: model entities failed 'create' check" + "Error calling enhanced Prisma method `model.create`: denied by policy: model entities failed 'create' check" ); expect(error?.stack).toContain(`misc/stacktrace.test.ts`); expect((error as any).internalStack).toBeTruthy(); diff --git a/tests/integration/tests/regression/issue-665.test.ts b/tests/integration/tests/regression/issue-665.test.ts index 8bd9f717b..b6552fd2b 100644 --- a/tests/integration/tests/regression/issue-665.test.ts +++ b/tests/integration/tests/regression/issue-665.test.ts @@ -2,7 +2,7 @@ import { loadSchema } from '@zenstackhq/testtools'; describe('Regression: issue 665', () => { it('regression', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id Int @id @default(autoincrement()) @@ -20,19 +20,19 @@ describe('Regression: issue 665', () => { await prisma.user.create({ data: { id: 1, username: 'test', password: 'test', admin: true } }); // admin - let r = await withPolicy({ id: 1, admin: true }).user.findFirst(); + let r = await enhance({ id: 1, admin: true }).user.findFirst(); expect(r.username).toEqual('test'); // owner - r = await withPolicy({ id: 1 }).user.findFirst(); + r = await enhance({ id: 1 }).user.findFirst(); expect(r.username).toEqual('test'); // anonymous - r = await withPolicy().user.findFirst(); + r = await enhance().user.findFirst(); expect(r.username).toBeUndefined(); // non-owner - r = await withPolicy({ id: 2 }).user.findFirst(); + r = await enhance({ id: 2 }).user.findFirst(); expect(r.username).toBeUndefined(); }); }); diff --git a/tests/integration/tests/regression/issue-925.test.ts b/tests/integration/tests/regression/issue-925.test.ts index 34b1ac434..b19d9d615 100644 --- a/tests/integration/tests/regression/issue-925.test.ts +++ b/tests/integration/tests/regression/issue-925.test.ts @@ -1,7 +1,7 @@ -import { loadModelWithError } from '@zenstackhq/testtools'; +import { loadModel, loadModelWithError } from '@zenstackhq/testtools'; describe('Regression: issue 925', () => { - it('member reference from this', async () => { + it('member reference without using this', async () => { await expect( loadModelWithError( ` @@ -10,7 +10,7 @@ describe('Regression: issue 925', () => { company Company[] test Int - @@allow('read', auth().company?[staff?[companyId == this.test]]) + @@allow('read', auth().company?[staff?[companyId == test]]) } model Company { @@ -32,19 +32,18 @@ describe('Regression: issue 925', () => { } ` ) - ).resolves.toContain("Could not resolve reference to DataModelField named 'test'."); + ).resolves.toContain("Could not resolve reference to ReferenceTarget named 'test'."); }); - it('simple reference', async () => { - await expect( - loadModelWithError( - ` + it('reference with this', async () => { + await loadModel( + ` model User { id Int @id @default(autoincrement()) company Company[] test Int - @@allow('read', auth().company?[staff?[companyId == test]]) + @@allow('read', auth().company?[staff?[companyId == this.test]]) } model Company { @@ -65,7 +64,6 @@ describe('Regression: issue 925', () => { @@allow('read', true) } ` - ) - ).resolves.toContain("Could not resolve reference to ReferenceTarget named 'test'."); + ); }); }); diff --git a/tests/integration/tests/regression/issue-965.test.ts b/tests/integration/tests/regression/issue-965.test.ts new file mode 100644 index 000000000..79bd92075 --- /dev/null +++ b/tests/integration/tests/regression/issue-965.test.ts @@ -0,0 +1,53 @@ +import { loadModel, loadModelWithError } from '@zenstackhq/testtools'; + +describe('Regression: issue 965', () => { + it('regression1', async () => { + await loadModel(` + abstract model Base { + id String @id @default(cuid()) + } + + abstract model A { + URL String? @url + } + + abstract model B { + anotherURL String? @url + } + + abstract model C { + oneMoreURL String? @url + } + + model D extends Base, A, B { + } + + model E extends Base, B, C { + }`); + }); + + it('regression2', async () => { + await expect( + loadModelWithError(` + abstract model A { + URL String? @url + } + + abstract model B { + anotherURL String? @url + } + + abstract model C { + oneMoreURL String? @url + } + + model D extends A, B { + } + + model E extends B, C { + }`) + ).resolves.toContain( + 'Model must have at least one unique criteria. Either mark a single field with `@id`, `@unique` or add a multi field criterion with `@@id([])` or `@@unique([])` to the model.' + ); + }); +}); diff --git a/tests/integration/tests/regression/issue-971.test.ts b/tests/integration/tests/regression/issue-971.test.ts new file mode 100644 index 000000000..40990aa6a --- /dev/null +++ b/tests/integration/tests/regression/issue-971.test.ts @@ -0,0 +1,23 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Regression: issue 971', () => { + it('regression', async () => { + await loadSchema( + ` + abstract model Level1 { + id String @id @default(cuid()) + URL String? + @@validate(URL != null, "URL must be provided") // works + } + abstract model Level2 extends Level1 { + @@validate(URL != null, "URL must be provided") // works + } + abstract model Level3 extends Level2 { + @@validate(URL != null, "URL must be provided") // doesn't work + } + model Foo extends Level3 { + } + ` + ); + }); +}); diff --git a/tests/integration/tests/regression/issue-992.test.ts b/tests/integration/tests/regression/issue-992.test.ts new file mode 100644 index 000000000..40a1aac47 --- /dev/null +++ b/tests/integration/tests/regression/issue-992.test.ts @@ -0,0 +1,45 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Regression: issue 992', () => { + it('regression', async () => { + const { enhance, prisma } = await loadSchema( + ` + model Product { + id String @id @default(cuid()) + category Category @relation(fields: [categoryId], references: [id]) + categoryId String + + deleted Int @default(0) @omit + @@deny('read', deleted != 0) + @@allow('all', true) + } + + model Category { + id String @id @default(cuid()) + products Product[] + @@allow('all', true) + } + ` + ); + + await prisma.category.create({ + data: { + products: { + create: [ + { + deleted: 0, + }, + { + deleted: 0, + }, + ], + }, + }, + }); + + const db = enhance(); + const category = await db.category.findFirst({ include: { products: true } }); + expect(category.products[0].deleted).toBeUndefined(); + expect(category.products[1].deleted).toBeUndefined(); + }); +}); diff --git a/tests/integration/tests/regression/issues.test.ts b/tests/integration/tests/regression/issues.test.ts index 8353f8bad..7c2ca94cd 100644 --- a/tests/integration/tests/regression/issues.test.ts +++ b/tests/integration/tests/regression/issues.test.ts @@ -13,7 +13,7 @@ describe('GitHub issues regression', () => { }); it('issue 389', async () => { - const { withPolicy } = await loadSchema(` + const { enhance } = await loadSchema(` model model { id String @id @default(uuid()) value Int @@ -21,7 +21,7 @@ describe('GitHub issues regression', () => { @@allow('create', value > 0) } `); - const db = withPolicy(); + const db = enhance(); await expect(db.model.create({ data: { value: 0 } })).toBeRejectedByPolicy(); await expect(db.model.create({ data: { value: 1 } })).toResolveTruthy(); }); @@ -88,7 +88,7 @@ describe('GitHub issues regression', () => { }); it('select with _count', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id String @id @unique @default(uuid()) @@ -117,7 +117,7 @@ describe('GitHub issues regression', () => { }, }); - const db = withPolicy(); + const db = enhance(); const r = await db.user.findFirst({ select: { _count: { select: { posts: true } } } }); expect(r).toMatchObject({ _count: { posts: 2 } }); }); @@ -150,7 +150,7 @@ describe('GitHub issues regression', () => { }); it('issue 552', async () => { - const { withPolicy, prisma } = await loadSchema( + const { enhance, prisma } = await loadSchema( ` model Tenant { id Int @id @default(autoincrement()) @@ -240,7 +240,7 @@ describe('GitHub issues regression', () => { }, }); - const db = withPolicy({ id: 1, is_super_admin: true }); + const db = enhance({ id: 1, is_super_admin: true }); await db.userTenant.update({ where: { user_id_tenant_id: { @@ -259,7 +259,7 @@ describe('GitHub issues regression', () => { }); it('issue 609', async () => { - const { withPolicy, prisma } = await loadSchema( + const { enhance, prisma } = await loadSchema( ` model User { id String @id @default(cuid()) @@ -300,7 +300,7 @@ describe('GitHub issues regression', () => { }); // connecting a child comment from a different user to a parent comment should succeed - const db = withPolicy({ id: '2' }); + const db = enhance({ id: '2' }); await expect( db.comment.create({ data: { @@ -313,7 +313,7 @@ describe('GitHub issues regression', () => { }); it('issue 624', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id String @id @default(uuid()) @@ -327,9 +327,9 @@ model User { // can be created by anyone, even not logged in @@allow('create', true) // can be read by users in the same organization - @@allow('read', orgs?[members?[auth() == this]]) + @@allow('read', orgs?[members?[auth().id == id]]) // full access by oneself - @@allow('all', auth() == this) + @@allow('all', auth().id == id) } model Organization { @@ -343,7 +343,7 @@ model Organization { // everyone can create a organization @@allow('create', true) // any user in the organization can read the organization - @@allow('read', members?[auth() == this]) + @@allow('read', members?[auth().id == id]) } abstract model organizationBaseEntity { @@ -359,15 +359,15 @@ abstract model organizationBaseEntity { groups Group[] // when create, owner must be set to current user, and user must be in the organization - @@allow('create', owner == auth() && org.members?[this == auth()]) + @@allow('create', owner == auth() && org.members?[id == auth().id]) // only the owner can update it and is not allowed to change the owner - @@allow('update', owner == auth() && org.members?[this == auth()] && future().owner == owner) + @@allow('update', owner == auth() && org.members?[id == auth().id] && future().owner == owner) // allow owner to read @@allow('read', owner == auth()) // allow shared group members to read it - @@allow('read', groups?[users?[this == auth()]]) + @@allow('read', groups?[users?[id == auth().id]]) // allow organization to access if public - @@allow('read', isPublic && org.members?[this == auth()]) + @@allow('read', isPublic && org.members?[id == auth().id]) // can not be read if deleted @@deny('all', isDeleted == true) } @@ -394,7 +394,7 @@ model Group { orgId String // group is shared by organization - @@allow('all', org.members?[auth() == this]) + @@allow('all', org.members?[auth().id == id]) } ` ); @@ -476,7 +476,7 @@ model Group { console.log(`Created user with id: ${user.id}`); } - const db = withPolicy({ id: 'robin@prisma.io' }); + const db = enhance({ id: 'robin@prisma.io' }); await expect( db.post.findMany({ where: {}, @@ -507,7 +507,7 @@ model Group { }); it('issue 627', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id String @id @default(uuid()) @@ -541,7 +541,7 @@ model Equipment extends BaseEntityWithTenant { }, }); - const db = withPolicy({ id: 'tenant-1' }); + const db = enhance({ id: 'tenant-1' }); await expect( db.equipment.create({ data: { @@ -586,7 +586,7 @@ model TwoEnumsOneModelTest { }); it('issue 634', async () => { - const { prisma, withPolicy } = await loadSchema( + const { prisma, enhance } = await loadSchema( ` model User { id String @id @default(uuid()) @@ -616,7 +616,7 @@ model Organization { // everyone can create a organization @@allow('create', true) // any user in the organization can read the organization - @@allow('read', members?[auth() == this]) + @@allow('read', members?[auth().id == id]) } abstract model organizationBaseEntity { @@ -632,15 +632,15 @@ abstract model organizationBaseEntity { groups Group[] // when create, owner must be set to current user, and user must be in the organization - @@allow('create', owner == auth() && org.members?[this == auth()]) + @@allow('create', owner == auth() && org.members?[id == auth().id]) // only the owner can update it and is not allowed to change the owner - @@allow('update', owner == auth() && org.members?[this == auth()] && future().owner == owner) + @@allow('update', owner == auth() && org.members?[id == auth().id] && future().owner == owner) // allow owner to read @@allow('read', owner == auth()) // allow shared group members to read it - @@allow('read', groups?[users?[this == auth()]]) + @@allow('read', groups?[users?[id == auth().id]]) // allow organization to access if public - @@allow('read', isPublic && org.members?[this == auth()]) + @@allow('read', isPublic && org.members?[id == auth().id]) // can not be read if deleted @@deny('all', isDeleted == true) } @@ -667,7 +667,7 @@ model Group { orgId String // group is shared by organization - @@allow('all', org.members?[auth() == this]) + @@allow('all', org.members?[auth().id == id]) } ` ); @@ -749,7 +749,7 @@ model Group { console.log(`Created user with id: ${user.id}`); } - const db = withPolicy({ id: 'robin@prisma.io' }); + const db = enhance({ id: 'robin@prisma.io' }); await expect( db.comment.findMany({ where: { diff --git a/tests/integration/tests/schema/petstore.zmodel b/tests/integration/tests/schema/petstore.zmodel index 77ec1e643..42a279550 100644 --- a/tests/integration/tests/schema/petstore.zmodel +++ b/tests/integration/tests/schema/petstore.zmodel @@ -5,7 +5,6 @@ datasource db { generator js { provider = 'prisma-client-js' - previewFeatures = ['clientExtensions'] } plugin zod { diff --git a/tests/integration/tests/schema/todo.zmodel b/tests/integration/tests/schema/todo.zmodel index 733391bd1..c3a84707e 100644 --- a/tests/integration/tests/schema/todo.zmodel +++ b/tests/integration/tests/schema/todo.zmodel @@ -9,7 +9,6 @@ datasource db { generator js { provider = 'prisma-client-js' - previewFeatures = ['clientExtensions'] } plugin zod {