diff --git a/packages/server/src/api/rest/index.ts b/packages/server/src/api/rest/index.ts index 32367df44..4bf2dcfe3 100644 --- a/packages/server/src/api/rest/index.ts +++ b/packages/server/src/api/rest/index.ts @@ -22,19 +22,6 @@ import { LoggerConfig, Response } from '../../types'; import { APIHandlerBase, RequestContext } from '../base'; import { logWarning, registerCustomSerializers } from '../utils'; -const urlPatterns = { - // collection operations - collection: new UrlPattern('/:type'), - // single resource operations - single: new UrlPattern('/:type/:id'), - // related entity fetching - fetchRelationship: new UrlPattern('/:type/:id/:relationship'), - // relationship operations - relationship: new UrlPattern('/:type/:id/relationships/:relationship'), -}; - -export const idDivider = '_'; - /** * Request handler options */ @@ -52,6 +39,19 @@ export type Options = { * Defaults to 100. Set to Infinity to disable pagination. */ pageSize?: number; + + /** + * The divider used to separate compound ID fields in the URL. + * Defaults to '_'. + */ + idDivider?: string; + + /** + * The charset used for URL segment values. Defaults to `a-zA-Z0-9-_~ %`. You can change it if your entity's ID values + * allow different characters. Specifically, if your models use compound IDs and the idDivider is set to a different value, + * it should be included in the charset. + */ + urlSegmentCharset?: string; }; type RelationshipInfo = { @@ -93,6 +93,8 @@ const FilterOperations = [ type FilterOperationType = (typeof FilterOperations)[number] | undefined; +const prismaIdDivider = '_'; + registerCustomSerializers(); /** @@ -210,8 +212,30 @@ class RequestHandler extends APIHandlerBase { // all known types and their metadata private typeMap: Record; + // divider used to separate compound ID fields + private idDivider; + + private urlPatterns; + constructor(private readonly options: Options) { super(); + this.idDivider = options.idDivider ?? prismaIdDivider; + const segmentCharset = options.urlSegmentCharset ?? 'a-zA-Z0-9-_~ %'; + this.urlPatterns = this.buildUrlPatterns(this.idDivider, segmentCharset); + } + + buildUrlPatterns(idDivider: string, urlSegmentNameCharset: string) { + const options = { segmentValueCharset: urlSegmentNameCharset }; + return { + // collection operations + collection: new UrlPattern('/:type', options), + // single resource operations + single: new UrlPattern('/:type/:id', options), + // related entity fetching + fetchRelationship: new UrlPattern('/:type/:id/:relationship', options), + // relationship operations + relationship: new UrlPattern('/:type/:id/relationships/:relationship', options), + }; } async handleRequest({ @@ -245,19 +269,19 @@ class RequestHandler extends APIHandlerBase { try { switch (method) { case 'GET': { - let match = urlPatterns.single.match(path); + let match = this.urlPatterns.single.match(path); if (match) { // single resource read return await this.processSingleRead(prisma, match.type, match.id, query); } - match = urlPatterns.fetchRelationship.match(path); + match = this.urlPatterns.fetchRelationship.match(path); if (match) { // fetch related resource(s) return await this.processFetchRelated(prisma, match.type, match.id, match.relationship, query); } - match = urlPatterns.relationship.match(path); + match = this.urlPatterns.relationship.match(path); if (match) { // read relationship return await this.processReadRelationship( @@ -269,7 +293,7 @@ class RequestHandler extends APIHandlerBase { ); } - match = urlPatterns.collection.match(path); + match = this.urlPatterns.collection.match(path); if (match) { // collection read return await this.processCollectionRead(prisma, match.type, query); @@ -283,13 +307,13 @@ class RequestHandler extends APIHandlerBase { return this.makeError('invalidPayload'); } - let match = urlPatterns.collection.match(path); + let match = this.urlPatterns.collection.match(path); if (match) { // resource creation return await this.processCreate(prisma, match.type, query, requestBody, modelMeta, zodSchemas); } - match = urlPatterns.relationship.match(path); + match = this.urlPatterns.relationship.match(path); if (match) { // relationship creation (collection relationship only) return await this.processRelationshipCRUD( @@ -313,7 +337,7 @@ class RequestHandler extends APIHandlerBase { return this.makeError('invalidPayload'); } - let match = urlPatterns.single.match(path); + let match = this.urlPatterns.single.match(path); if (match) { // resource update return await this.processUpdate( @@ -327,7 +351,7 @@ class RequestHandler extends APIHandlerBase { ); } - match = urlPatterns.relationship.match(path); + match = this.urlPatterns.relationship.match(path); if (match) { // relationship update return await this.processRelationshipCRUD( @@ -345,13 +369,13 @@ class RequestHandler extends APIHandlerBase { } case 'DELETE': { - let match = urlPatterns.single.match(path); + let match = this.urlPatterns.single.match(path); if (match) { // resource deletion return await this.processDelete(prisma, match.type, match.id); } - match = urlPatterns.relationship.match(path); + match = this.urlPatterns.relationship.match(path); if (match) { // relationship deletion (collection relationship only) return await this.processRelationshipCRUD( @@ -391,7 +415,7 @@ class RequestHandler extends APIHandlerBase { return this.makeUnsupportedModelError(type); } - const args: any = { where: this.makeIdFilter(typeInfo.idFields, resourceId) }; + const args: any = { where: this.makePrismaIdFilter(typeInfo.idFields, resourceId) }; // include IDs of relation fields so that they can be serialized this.includeRelationshipIds(type, args, 'include'); @@ -456,7 +480,7 @@ class RequestHandler extends APIHandlerBase { select = select ?? { [relationship]: true }; const args: any = { - where: this.makeIdFilter(typeInfo.idFields, resourceId), + where: this.makePrismaIdFilter(typeInfo.idFields, resourceId), select, }; @@ -514,7 +538,7 @@ class RequestHandler extends APIHandlerBase { } const args: any = { - where: this.makeIdFilter(typeInfo.idFields, resourceId), + where: this.makePrismaIdFilter(typeInfo.idFields, resourceId), select: this.makeIdSelect(typeInfo.idFields), }; @@ -753,7 +777,7 @@ class RequestHandler extends APIHandlerBase { if (relationInfo.isCollection) { createPayload.data[key] = { connect: enumerate(data.data).map((item: any) => ({ - [this.makeIdKey(relationInfo.idFields)]: item.id, + [this.makePrismaIdKey(relationInfo.idFields)]: item.id, })), }; } else { @@ -762,7 +786,7 @@ class RequestHandler extends APIHandlerBase { } createPayload.data[key] = { connect: { - [this.makeIdKey(relationInfo.idFields)]: data.data.id, + [this.makePrismaIdKey(relationInfo.idFields)]: data.data.id, }, }; } @@ -770,7 +794,7 @@ class RequestHandler extends APIHandlerBase { // make sure ID fields are included for result serialization createPayload.include = { ...createPayload.include, - [key]: { select: { [this.makeIdKey(relationInfo.idFields)]: true } }, + [key]: { select: { [this.makePrismaIdKey(relationInfo.idFields)]: true } }, }; } } @@ -807,7 +831,7 @@ class RequestHandler extends APIHandlerBase { } const updateArgs: any = { - where: this.makeIdFilter(typeInfo.idFields, resourceId), + where: this.makePrismaIdFilter(typeInfo.idFields, resourceId), select: { ...typeInfo.idFields.reduce((acc, field) => ({ ...acc, [field.name]: true }), {}), [relationship]: { select: this.makeIdSelect(relationInfo.idFields) }, @@ -842,7 +866,7 @@ class RequestHandler extends APIHandlerBase { updateArgs.data = { [relationship]: { connect: { - [this.makeIdKey(relationInfo.idFields)]: parsed.data.data.id, + [this.makePrismaIdKey(relationInfo.idFields)]: parsed.data.data.id, }, }, }; @@ -866,7 +890,7 @@ class RequestHandler extends APIHandlerBase { updateArgs.data = { [relationship]: { [relationVerb]: enumerate(parsed.data.data).map((item: any) => - this.makeIdFilter(relationInfo.idFields, item.id) + this.makePrismaIdFilter(relationInfo.idFields, item.id) ), }, }; @@ -907,7 +931,7 @@ class RequestHandler extends APIHandlerBase { } const updatePayload: any = { - where: this.makeIdFilter(typeInfo.idFields, resourceId), + where: this.makePrismaIdFilter(typeInfo.idFields, resourceId), data: { ...attributes }, }; @@ -926,7 +950,7 @@ class RequestHandler extends APIHandlerBase { if (relationInfo.isCollection) { updatePayload.data[key] = { set: enumerate(data.data).map((item: any) => ({ - [this.makeIdKey(relationInfo.idFields)]: item.id, + [this.makePrismaIdKey(relationInfo.idFields)]: item.id, })), }; } else { @@ -935,13 +959,13 @@ class RequestHandler extends APIHandlerBase { } updatePayload.data[key] = { set: { - [this.makeIdKey(relationInfo.idFields)]: data.data.id, + [this.makePrismaIdKey(relationInfo.idFields)]: data.data.id, }, }; } updatePayload.include = { ...updatePayload.include, - [key]: { select: { [this.makeIdKey(relationInfo.idFields)]: true } }, + [key]: { select: { [this.makePrismaIdKey(relationInfo.idFields)]: true } }, }; } } @@ -960,7 +984,7 @@ class RequestHandler extends APIHandlerBase { } await prisma[type].delete({ - where: this.makeIdFilter(typeInfo.idFields, resourceId), + where: this.makePrismaIdFilter(typeInfo.idFields, resourceId), }); return { status: 204, @@ -1110,7 +1134,7 @@ class RequestHandler extends APIHandlerBase { if (ids.length === 0) { return undefined; } else { - return data[ids.map((id) => id.name).join(idDivider)]; + return data[this.makeIdKey(ids)]; } } @@ -1206,15 +1230,16 @@ class RequestHandler extends APIHandlerBase { return r.toString(); } - private makeIdFilter(idFields: FieldInfo[], resourceId: string) { + private makePrismaIdFilter(idFields: FieldInfo[], resourceId: string) { if (idFields.length === 1) { return { [idFields[0].name]: this.coerce(idFields[0].type, resourceId) }; } else { return { - [idFields.map((idf) => idf.name).join(idDivider)]: idFields.reduce( + // TODO: support `@@id` with custom name + [idFields.map((idf) => idf.name).join(prismaIdDivider)]: idFields.reduce( (acc, curr, idx) => ({ ...acc, - [curr.name]: this.coerce(curr.type, resourceId.split(idDivider)[idx]), + [curr.name]: this.coerce(curr.type, resourceId.split(this.idDivider)[idx]), }), {} ), @@ -1230,11 +1255,16 @@ class RequestHandler extends APIHandlerBase { } private makeIdKey(idFields: FieldInfo[]) { - return idFields.map((idf) => idf.name).join(idDivider); + return idFields.map((idf) => idf.name).join(this.idDivider); + } + + private makePrismaIdKey(idFields: FieldInfo[]) { + // TODO: support `@@id` with custom name + return idFields.map((idf) => idf.name).join(prismaIdDivider); } private makeCompoundId(idFields: FieldInfo[], item: any) { - return idFields.map((idf) => item[idf.name]).join(idDivider); + return idFields.map((idf) => item[idf.name]).join(this.idDivider); } private includeRelationshipIds(model: string, args: any, mode: 'select' | 'include') { @@ -1557,11 +1587,11 @@ class RequestHandler extends APIHandlerBase { const values = value.split(',').filter((i) => i); const filterValue = values.length > 1 - ? { OR: values.map((v) => this.makeIdFilter(info.idFields, v)) } - : this.makeIdFilter(info.idFields, value); + ? { OR: values.map((v) => this.makePrismaIdFilter(info.idFields, v)) } + : this.makePrismaIdFilter(info.idFields, value); return { some: filterValue }; } else { - return { is: this.makeIdFilter(info.idFields, value) }; + return { is: this.makePrismaIdFilter(info.idFields, value) }; } } else { const coerced = this.coerce(fieldInfo.type, value); diff --git a/packages/server/tests/api/rest.test.ts b/packages/server/tests/api/rest.test.ts index fa7d0cfb8..640dcbebe 100644 --- a/packages/server/tests/api/rest.test.ts +++ b/packages/server/tests/api/rest.test.ts @@ -2,10 +2,12 @@ /// import { CrudFailureReason, type ModelMeta } from '@zenstackhq/runtime'; -import { loadSchema, run } from '@zenstackhq/testtools'; +import { createPostgresDb, dropPostgresDb, loadSchema, run } from '@zenstackhq/testtools'; import { Decimal } from 'decimal.js'; import SuperJSON from 'superjson'; -import makeHandler, { idDivider } from '../../src/api/rest'; +import makeHandler from '../../src/api/rest'; + +const idDivider = '_'; describe('REST server tests', () => { let prisma: any; @@ -2519,4 +2521,121 @@ describe('REST server tests', () => { expect(Buffer.isBuffer(included.attributes.bytes)).toBeTruthy(); }); }); + + describe('REST server tests - compound id with custom separator', () => { + const schema = ` + enum Role { + COMMON_USER + ADMIN_USER + } + + model User { + email String + role Role + enabled Boolean @default(true) + + @@id([email, role]) + } + `; + const idDivider = ':'; + const dbName = 'restful-compound-id-custom-separator'; + + beforeAll(async () => { + const params = await loadSchema(schema, { + provider: 'postgresql', + dbUrl: await createPostgresDb(dbName), + }); + + prisma = params.prisma; + zodSchemas = params.zodSchemas; + modelMeta = params.modelMeta; + + const _handler = makeHandler({ + endpoint: 'http://localhost/api', + pageSize: 5, + idDivider, + urlSegmentCharset: 'a-zA-Z0-9-_~ %@.:', + }); + handler = (args) => + _handler({ ...args, zodSchemas, modelMeta, url: new URL(`http://localhost/${args.path}`) }); + }); + + afterAll(async () => { + if (prisma) { + await prisma.$disconnect(); + } + await dropPostgresDb(dbName); + }); + + it('POST', async () => { + const r = await handler({ + method: 'post', + path: '/user', + query: {}, + requestBody: { + data: { + type: 'user', + attributes: { email: 'user1@abc.com', role: 'COMMON_USER' }, + }, + }, + prisma, + }); + + expect(r.status).toBe(201); + }); + + it('GET', async () => { + await prisma.user.create({ + data: { email: 'user1@abc.com', role: 'COMMON_USER' }, + }); + + const r = await handler({ + method: 'get', + path: '/user', + query: {}, + prisma, + }); + + expect(r.status).toBe(200); + expect(r.body.data).toHaveLength(1); + }); + + it('GET single', async () => { + await prisma.user.create({ + data: { email: 'user1@abc.com', role: 'COMMON_USER' }, + }); + + const r = await handler({ + method: 'get', + path: '/user/user1@abc.com:COMMON_USER', + query: {}, + prisma, + }); + + expect(r.status).toBe(200); + expect(r.body.data.attributes.email).toBe('user1@abc.com'); + }); + + it('PUT', async () => { + await prisma.user.create({ + data: { email: 'user1@abc.com', role: 'COMMON_USER' }, + }); + + const r = await handler({ + method: 'put', + path: '/user/user1@abc.com:COMMON_USER', + query: {}, + requestBody: { + data: { + type: 'user', + attributes: { enabled: false }, + }, + }, + prisma, + }); + + expect(r.status).toBe(200); + expect(r.body.data.attributes.enabled).toBe(false); + }); + }); }); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index a0da7754f..6771d2899 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -706,7 +706,7 @@ importers: version: 10.3.9(@nestjs/common@10.3.9(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.9) '@nestjs/testing': specifier: ^10.3.7 - version: 10.3.9(@nestjs/common@10.3.9(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.9(@nestjs/common@10.3.9(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.9)(encoding@0.1.13)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.9(@nestjs/common@10.3.9(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.9)) + version: 10.3.9(@nestjs/common@10.3.9(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.9)(@nestjs/platform-express@10.3.9) '@sveltejs/kit': specifier: 1.21.0 version: 1.21.0(svelte@4.2.18)(vite@5.3.2(@types/node@20.14.9)(terser@5.31.1)) @@ -3986,7 +3986,7 @@ packages: engines: {node: '>= 14'} concat-map@0.0.1: - resolution: {integrity: sha1-2Klr13/Wjfd5OnMDajug1UBdR3s=} + resolution: {integrity: sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==} concat-stream@1.6.2: resolution: {integrity: sha512-27HBghJxjiZtIk3Ycvn/4kbJk/1uZuJFfuPEns6LaEvpvG1f0hTea8lilrouyo9mVc2GWdcEZ8OLoGmSADlrCw==} @@ -4424,7 +4424,7 @@ packages: resolution: {integrity: sha512-nagl3RYrbNv6kQkeJIpt6NJZy8twLB/2vtz6yN9Z4vRKHN4/QZJIEbqohALSgwKdnksuY3k5Addp5lg8sVoVcQ==} ee-first@1.1.1: - resolution: {integrity: sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0=} + resolution: {integrity: sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==} electron-to-chromium@1.4.814: resolution: {integrity: sha512-GVulpHjFu1Y9ZvikvbArHmAhZXtm3wHlpjTMcXNGKl4IQ4jMQjlnz8yMQYYqdLHKi/jEL2+CBC2akWVCoIGUdw==} @@ -6100,7 +6100,7 @@ packages: resolution: {integrity: sha512-/sKlQJCBYVY9Ers9hqzKou4H6V5UWc/M59TH2dvkt+84itfnq7uFOMLpOiOS4ujvHP4etln18fmIxA5R5fll0g==} media-typer@0.3.0: - resolution: {integrity: sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g=} + resolution: {integrity: sha512-dq+qelQ9akHpcOl/gUVRTxVIOkAJ1wR3QAvb4RsVjS8oVoFjDGTc679wJYmUmknUF5HwMLOgb5O+a3KxfWapPQ==} engines: {node: '>= 0.6'} merge-descriptors@1.0.1: @@ -8260,7 +8260,7 @@ packages: resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==} utils-merge@1.0.1: - resolution: {integrity: sha1-n5VxD1CiZ5R7LMwSR0HBAoQn5xM=} + resolution: {integrity: sha512-pMZTvIkT1d+TFGvDOqodOclx0QWkkgi6Tdoa8gC8ffGAAqz9pzPTZWAybbsHHoED/ztMtkv/VoYTYyShUn81hA==} engines: {node: '>= 0.4.0'} uuid@10.0.0: @@ -10080,7 +10080,7 @@ snapshots: transitivePeerDependencies: - supports-color - '@nestjs/testing@10.3.9(@nestjs/common@10.3.9(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.9(@nestjs/common@10.3.9(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.9)(encoding@0.1.13)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.9(@nestjs/common@10.3.9(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.9))': + '@nestjs/testing@10.3.9(@nestjs/common@10.3.9(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.9)(@nestjs/platform-express@10.3.9)': dependencies: '@nestjs/common': 10.3.9(reflect-metadata@0.2.2)(rxjs@7.8.1) '@nestjs/core': 10.3.9(@nestjs/common@10.3.9(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.9)(encoding@0.1.13)(reflect-metadata@0.2.2)(rxjs@7.8.1)