diff --git a/packages/plugins/tanstack-query/src/generator.ts b/packages/plugins/tanstack-query/src/generator.ts index c45a32517..c484f9180 100644 --- a/packages/plugins/tanstack-query/src/generator.ts +++ b/packages/plugins/tanstack-query/src/generator.ts @@ -6,6 +6,7 @@ import { ensureEmptyDir, generateModelMeta, getDataModels, + getPrismaClientGenerator, isDelegateModel, requireOption, resolvePath, @@ -52,7 +53,6 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. `Invalid value for "portable" option: ${options.portable}, a boolean value is expected` ); } - const portable = options.portable ?? false; await generateModelMeta(project, models, typeDefs, { output: path.join(outDir, '__model_meta.ts'), @@ -70,8 +70,13 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. generateModelHooks(target, version, project, outDir, dataModel, mapping, options); }); - if (portable) { - generateBundledTypes(project, outDir, options); + if (options.portable) { + const gen = getPrismaClientGenerator(model); + if (gen?.isNewGenerator) { + warnings.push(`The "portable" option is not supported with the "prisma-client" generator and is ignored.`); + } else { + generateBundledTypes(project, outDir, options); + } } await saveProject(project); diff --git a/packages/schema/src/cli/actions/generate.ts b/packages/schema/src/cli/actions/generate.ts index 229a9ddd8..2f5099312 100644 --- a/packages/schema/src/cli/actions/generate.ts +++ b/packages/schema/src/cli/actions/generate.ts @@ -1,4 +1,4 @@ -import { PluginError } from '@zenstackhq/sdk'; +import { getPrismaClientGenerator, PluginError } from '@zenstackhq/sdk'; import { isPlugin } from '@zenstackhq/sdk/ast'; import colors from 'colors'; import path from 'path'; @@ -70,6 +70,18 @@ async function runPlugins(options: Options) { const model = await loadDocument(schema); + const gen = getPrismaClientGenerator(model); + if (gen?.isNewGenerator && !options.output) { + console.error( + colors.red( + 'When using the "prisma-client" generator, you must provide an explicit output path with the "--output" CLI parameter.' + ) + ); + throw new CliError( + 'When using with the "prisma-client" generator, you must provide an explicit output path with the "--output" CLI parameter.' + ); + } + for (const name of [...(options.withPlugins ?? []), ...(options.withoutPlugins ?? [])]) { const pluginDecl = model.declarations.find((d) => isPlugin(d) && d.name === name); if (!pluginDecl) { diff --git a/packages/schema/src/cli/plugin-runner.ts b/packages/schema/src/cli/plugin-runner.ts index 54e13d41e..7c9ffdd66 100644 --- a/packages/schema/src/cli/plugin-runner.ts +++ b/packages/schema/src/cli/plugin-runner.ts @@ -112,10 +112,7 @@ export class PluginRunner { const otherPlugins = plugins.filter((p) => !p.options.preprocessor); // calculate all plugins (including core plugins implicitly enabled) - const { corePlugins, userPlugins } = this.calculateAllPlugins( - runnerOptions, - otherPlugins, - ); + const { corePlugins, userPlugins } = this.calculateAllPlugins(runnerOptions, otherPlugins); const allPlugins = [...corePlugins, ...userPlugins]; // check dependencies @@ -448,7 +445,7 @@ export class PluginRunner { } async function compileProject(project: Project, runnerOptions: PluginRunnerOptions) { - if (runnerOptions.compile !== false) { + if (!runnerOptions.output && runnerOptions.compile !== false) { // emit await emitProject(project); } else { diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index f8db9c063..06853c8e8 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -7,7 +7,7 @@ import { getDataModelAndTypeDefs, getDataModels, getForeignKeyFields, - getLiteral, + getPrismaClientGenerator, getRelationField, hasAttribute, isDelegateModel, @@ -22,7 +22,6 @@ import { ReferenceExpr, isArrayExpr, isDataModel, - isGeneratorDecl, isTypeDef, type Model, } from '@zenstackhq/sdk/ast'; @@ -56,7 +55,7 @@ import { generateTypeDefType } from './model-typedef-generator'; // information of delegate models and their sub models type DelegateInfo = [DataModel, DataModel[]][]; -const LOGICAL_CLIENT_GENERATION_PATH = './.logical-prisma-client'; +const LOGICAL_CLIENT_GENERATION_PATH = './logical-prisma-client'; export class EnhancerGenerator { // regex for matching "ModelCreateXXXInput" and "ModelUncheckedCreateXXXInput" type @@ -114,6 +113,9 @@ export class EnhancerGenerator { if (this.needsLogicalClient) { prismaTypesFixed = true; resultPrismaTypeImport = LOGICAL_CLIENT_GENERATION_PATH; + if (this.isNewPrismaClientGenerator) { + resultPrismaTypeImport += '/client'; + } const result = await this.generateLogicalPrisma(); dmmf = result.dmmf; } @@ -440,23 +442,14 @@ export type Enhanced = } private getPrismaClientGeneratorName(model: Model) { - for (const generator of model.declarations.filter(isGeneratorDecl)) { - if ( - generator.fields.some( - (f) => f.name === 'provider' && getLiteral(f.value) === 'prisma-client-js' - ) - ) { - return generator.name; - } + const gen = getPrismaClientGenerator(model); + if (!gen) { + throw new PluginError(name, `Cannot find "prisma-client-js" or "prisma-client" generator in the schema`); } - throw new PluginError(name, `Cannot find prisma-client-js generator in the schema`); + return gen.name; } private async processClientTypes(prismaClientDir: string) { - // make necessary updates to the generated `index.d.ts` file and overwrite it - const project = new Project(); - const sf = project.addSourceFileAtPath(path.join(prismaClientDir, 'index.d.ts')); - // build a map of delegate models and their sub models const delegateInfo: DelegateInfo = []; this.model.declarations @@ -468,6 +461,16 @@ export type Enhanced = } }); + if (this.isNewPrismaClientGenerator) { + await this.processClientTypesNewPrismaGenerator(prismaClientDir, delegateInfo); + } else { + await this.processClientTypesLegacyPrismaGenerator(prismaClientDir, delegateInfo); + } + } + private async processClientTypesLegacyPrismaGenerator(prismaClientDir: string, delegateInfo: DelegateInfo) { + const project = new Project(); + const sf = project.addSourceFileAtPath(path.join(prismaClientDir, 'index.d.ts')); + // transform index.d.ts and write it into a new file (better perf than in-line editing) const sfNew = project.createSourceFile(path.join(prismaClientDir, 'index-fixed.d.ts'), undefined, { overwrite: true, @@ -484,6 +487,36 @@ export type Enhanced = await sfNew.save(); } + private async processClientTypesNewPrismaGenerator(prismaClientDir: string, delegateInfo: DelegateInfo) { + const project = new Project(); + + for (const d of this.model.declarations.filter(isDataModel)) { + const fileName = `${prismaClientDir}/models/${d.name}.ts`; + const sf = project.addSourceFileAtPath(fileName); + const sfNew = project.createSourceFile(`${prismaClientDir}/models/${d.name}-fixed.ts`, undefined, { + overwrite: true, + }); + + const syntaxList = sf.getChildren()[0]; + if (!Node.isSyntaxList(syntaxList)) { + throw new PluginError(name, `Unexpected syntax list structure in ${fileName}`); + } + + syntaxList.getChildren().forEach((node) => { + if (Node.isInterfaceDeclaration(node)) { + sfNew.addInterface(this.transformInterface(node, delegateInfo)); + } else if (Node.isTypeAliasDeclaration(node)) { + sfNew.addTypeAlias(this.transformTypeAlias(node, delegateInfo)); + } else { + sfNew.addStatements(node.getText()); + } + }); + + await sfNew.move(sf.getFilePath(), { overwrite: true }); + await sfNew.save(); + } + } + private transformPrismaTypes(sf: SourceFile, sfNew: SourceFile, delegateInfo: DelegateInfo) { // copy toplevel imports sfNew.addImportDeclarations(sf.getImportDeclarations().map((n) => n.getStructure())); @@ -639,7 +672,7 @@ export type Enhanced = source = `${payloadRecord[1] .map( (concrete) => - `($${concrete.name}Payload & { scalars: { ${discriminatorDecl.name}: '${concrete.name}' } })` + `(Prisma.$${concrete.name}Payload & { scalars: { ${discriminatorDecl.name}: '${concrete.name}' } })` ) .join(' | ')}`; } @@ -916,4 +949,9 @@ export type Enhanced = private trimEmptyLines(source: string): string { return source.replace(/^\s*[\r\n]/gm, ''); } + + private get isNewPrismaClientGenerator() { + const gen = getPrismaClientGenerator(this.model); + return !!gen?.isNewGenerator; + } } diff --git a/packages/schema/src/plugins/prisma/index.ts b/packages/schema/src/plugins/prisma/index.ts index aba67c90a..85832b266 100644 --- a/packages/schema/src/plugins/prisma/index.ts +++ b/packages/schema/src/plugins/prisma/index.ts @@ -2,11 +2,10 @@ import { PluginError, type PluginFunction, type PluginOptions, - getLiteral, + getPrismaClientGenerator, normalizedRelative, resolvePath, } from '@zenstackhq/sdk'; -import { GeneratorDecl, isGeneratorDecl } from '@zenstackhq/sdk/ast'; import { getDMMF } from '@zenstackhq/sdk/prisma'; import colors from 'colors'; import fs from 'fs'; @@ -58,13 +57,9 @@ const run: PluginFunction = async (model, options, _dmmf, _globalOptions) => { } // extract user-provided prisma client output path - const generator = model.declarations.find( - (d): d is GeneratorDecl => - isGeneratorDecl(d) && - d.fields.some((f) => f.name === 'provider' && getLiteral(f.value) === 'prisma-client-js') - ); - const clientOutputField = generator?.fields.find((f) => f.name === 'output'); - const clientOutput = getLiteral(clientOutputField?.value); + const gen = getPrismaClientGenerator(model); + const clientOutput = gen?.output; + const newGenerator = !!gen?.isNewGenerator; if (clientOutput) { if (path.isAbsolute(clientOutput)) { @@ -81,6 +76,11 @@ const run: PluginFunction = async (model, options, _dmmf, _globalOptions) => { clientOutputDir = prismaClientPath; } + if (newGenerator) { + // "prisma-client" generator requires an extra "/client" import suffix + prismaClientPath = `${prismaClientPath}/client`; + } + // get PrismaClient dts path if (clientOutput) { @@ -89,7 +89,7 @@ const run: PluginFunction = async (model, options, _dmmf, _globalOptions) => { prismaClientDtsPath = path.resolve(path.dirname(options.schemaPath), clientOutputDir, 'index.d.ts'); } - if (!prismaClientDtsPath || !fs.existsSync(prismaClientDtsPath)) { + if (!newGenerator && (!prismaClientDtsPath || !fs.existsSync(prismaClientDtsPath))) { // if the file does not exist, try node module resolution try { // the resolution is relative to the schema path by default diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 177edbd0e..1bb661647 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -234,7 +234,10 @@ export class PrismaSchemaGenerator { // deal with configuring PrismaClient preview features const provider = generator.fields.find((f) => f.name === 'provider'); - if (provider?.text === JSON.stringify('prisma-client-js')) { + if ( + provider?.text === JSON.stringify('prisma-client-js') || + provider?.text === JSON.stringify('prisma-client') + ) { const prismaVersion = getPrismaVersion(); if (prismaVersion) { const previewFeatures = JSON.parse( diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 93118b5f5..1904ffdd6 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -466,7 +466,11 @@ export function getPreviewFeatures(model: Model) { const jsGenerator = model.declarations.find( (d) => isGeneratorDecl(d) && - d.fields.some((f) => f.name === 'provider' && getLiteral(f.value) === 'prisma-client-js') + d.fields.some( + (f) => + (f.name === 'provider' && getLiteral(f.value) === 'prisma-client-js') || + getLiteral(f.value) === 'prisma-client' + ) ) as GeneratorDecl | undefined; if (jsGenerator) { @@ -683,3 +687,28 @@ export function getRelationName(field: DataModelField) { } return getAttributeArgLiteral(relAttr, 'name'); } + +export function getPrismaClientGenerator(model: Model) { + const decl = model.declarations.find( + (d): d is GeneratorDecl => + isGeneratorDecl(d) && + d.fields.some( + (f) => + f.name === 'provider' && + (getLiteral(f.value) === 'prisma-client-js' || + getLiteral(f.value) === 'prisma-client') + ) + ); + if (!decl) { + return undefined; + } + + const provider = getLiteral(decl.fields.find((f) => f.name === 'provider')?.value); + return { + name: decl.name, + output: getLiteral(decl.fields.find((f) => f.name === 'output')?.value), + previewFeatures: getLiteralArray(decl.fields.find((f) => f.name === 'previewFeatures')?.value), + provider, + isNewGenerator: provider === 'prisma-client', + }; +} diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index fabd180e2..4eea10bd2 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -278,27 +278,6 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { fs.cpSync(dep, path.join(projectDir, 'node_modules', pkgJson.name), { recursive: true, force: true }); }); - const prismaLoadPath = options?.prismaLoadPath - ? path.isAbsolute(options.prismaLoadPath) - ? options.prismaLoadPath - : path.join(projectDir, options.prismaLoadPath) - : path.join(projectDir, 'node_modules/.prisma/client'); - const prismaModule = require(prismaLoadPath); - const PrismaClient = prismaModule.PrismaClient; - - let clientOptions: object = { log: ['info', 'warn', 'error'] }; - if (options?.prismaClientOptions) { - clientOptions = { ...clientOptions, ...options.prismaClientOptions }; - } - let prisma = new PrismaClient(clientOptions); - // https://github.com/prisma/prisma/issues/18292 - prisma[Symbol.for('nodejs.util.inspect.custom')] = 'PrismaClient'; - - if (opt.pulseApiKey) { - const withPulse = loadModule('@prisma/extension-pulse/node', projectDir).withPulse; - prisma = prisma.$extends(withPulse({ apiKey: opt.pulseApiKey })); - } - opt.extraSourceFiles?.forEach(({ name, content }) => { fs.writeFileSync(path.join(projectDir, name), content); }); @@ -325,6 +304,27 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { run('npx tsc --project tsconfig.json'); } + const prismaLoadPath = options?.prismaLoadPath + ? path.isAbsolute(options.prismaLoadPath) + ? options.prismaLoadPath + : path.join(projectDir, options.prismaLoadPath) + : path.join(projectDir, 'node_modules/.prisma/client'); + const prismaModule = require(prismaLoadPath); + const PrismaClient = prismaModule.PrismaClient; + + let clientOptions: object = { log: ['info', 'warn', 'error'] }; + if (options?.prismaClientOptions) { + clientOptions = { ...clientOptions, ...options.prismaClientOptions }; + } + let prisma = new PrismaClient(clientOptions); + // https://github.com/prisma/prisma/issues/18292 + prisma[Symbol.for('nodejs.util.inspect.custom')] = 'PrismaClient'; + + if (opt.pulseApiKey) { + const withPulse = loadModule('@prisma/extension-pulse/node', projectDir).withPulse; + prisma = prisma.$extends(withPulse({ apiKey: opt.pulseApiKey })); + } + if (options?.getPrismaOnly) { return { prisma, diff --git a/tests/integration/tests/cli/generate.test.ts b/tests/integration/tests/cli/generate.test.ts index c4aed8e51..857c4fed1 100644 --- a/tests/integration/tests/cli/generate.test.ts +++ b/tests/integration/tests/cli/generate.test.ts @@ -68,8 +68,10 @@ model Post { const program = createProgram(); await program.parseAsync(['generate', '--no-dependency-check', '-o', 'out'], { from: 'user' }); expect(fs.existsSync('./node_modules/.zenstack')).toBeFalsy(); - expect(fs.existsSync('./out/policy.js')).toBeTruthy(); - expect(fs.existsSync('./out/model-meta.js')).toBeTruthy(); + expect(fs.existsSync('./out/policy.ts')).toBeTruthy(); + expect(fs.existsSync('./out/model-meta.ts')).toBeTruthy(); + expect(fs.existsSync('./out/policy.js')).toBeFalsy(); + expect(fs.existsSync('./out/model-meta.js')).toBeFalsy(); expect(fs.existsSync('./out/zod')).toBeTruthy(); }); @@ -83,8 +85,10 @@ model Post { from: 'user', }); expect(fs.existsSync('./node_modules/.zenstack')).toBeFalsy(); - expect(fs.existsSync('./out/policy.js')).toBeTruthy(); - expect(fs.existsSync('./out/model-meta.js')).toBeTruthy(); + expect(fs.existsSync('./out/policy.ts')).toBeTruthy(); + expect(fs.existsSync('./out/model-meta.ts')).toBeTruthy(); + expect(fs.existsSync('./out/policy.js')).toBeFalsy(); + expect(fs.existsSync('./out/model-meta.js')).toBeFalsy(); expect(fs.existsSync('./out/zod')).toBeTruthy(); }); diff --git a/tests/integration/tests/misc/prisma-client-generator.test.ts b/tests/integration/tests/misc/prisma-client-generator.test.ts new file mode 100644 index 000000000..fbb2813c9 --- /dev/null +++ b/tests/integration/tests/misc/prisma-client-generator.test.ts @@ -0,0 +1,124 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('New prisma-client generator tests', () => { + it('works with `auth` in `@default`', async () => { + const { enhance, prisma } = await loadSchema( + ` + datasource db { + provider = "sqlite" + url = "file:./dev.db" + } + + generator client { + provider = "prisma-client" + output = "./prisma-generated" + moduleFormat = "cjs" + } + + model User { + id Int @id + posts Post[] + @@allow('all', true) + } + + model Post { + id Int @id + title String + author User @relation(fields: [authorId], references: [id]) + authorId Int @default(auth().id) + @@allow('all', true) + } + `, + { + addPrelude: false, + output: './zenstack', + compile: true, + prismaLoadPath: './prisma/prisma-generated/client', + extraSourceFiles: [ + { + name: 'main.ts', + content: ` +import { PrismaClient } from './prisma/prisma-generated/client'; +import { enhance } from './zenstack/enhance'; + +const prisma = new PrismaClient(); +const db = enhance(prisma); + +async function main() { + const post = await db.post.create({ data: { id: 1, title: 'Hello World' } }); + console.log(post.authorId); +} + +main(); +`, + }, + ], + } + ); + + const user = await prisma.user.create({ data: { id: 1 } }); + const db = enhance({ id: user.id }); + await expect(db.post.create({ data: { id: 1, title: 'Hello World' } })).resolves.toMatchObject({ + authorId: user.id, + }); + }); + + it('works with delegate models', async () => { + const { enhance } = await loadSchema( + ` + datasource db { + provider = "sqlite" + url = "file:./dev.db" + } + + generator client { + provider = "prisma-client" + output = "./prisma-generated" + moduleFormat = "cjs" + } + + model Asset { + id Int @id + name String + type String + @@delegate(type) + } + + model Post extends Asset { + title String + } + `, + { + enhancements: ['delegate'], + addPrelude: false, + output: './zenstack', + compile: true, + prismaLoadPath: './prisma/prisma-generated/client', + extraSourceFiles: [ + { + name: 'main.ts', + content: ` +import { PrismaClient } from './prisma/prisma-generated/client'; +import { enhance } from './zenstack/enhance'; + +const prisma = new PrismaClient(); +const db = enhance(prisma); + +async function main() { + const post = await db.post.create({ data: { id: 1, name: 'Test Post', title: 'Hello World' } }); + console.log(post.type, post.name, post.title); +} + +main(); +`, + }, + ], + } + ); + + const db = enhance(); + await expect( + db.post.create({ data: { id: 1, name: 'Test Post', title: 'Hello World' } }) + ).resolves.toMatchObject({ id: 1, name: 'Test Post', type: 'Post', title: 'Hello World' }); + }); +});