From 5babf7a5ab89022d0e4967dd0dbfcf6eda757842 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 21 Jul 2025 23:41:37 +0800 Subject: [PATCH] feat: strongly typed JSON fields --- packages/language/res/stdlib.zmodel | 2 +- packages/runtime/src/client/crud-types.ts | 62 ++--- .../src/client/crud/dialects/sqlite.ts | 19 +- .../src/client/crud/operations/base.ts | 3 +- packages/runtime/src/client/crud/validator.ts | 59 ++++- packages/runtime/src/client/query-builder.ts | 4 +- .../runtime/src/client/result-processor.ts | 21 +- .../test/client-api/typed-json-fields.test.ts | 222 ++++++++++++++++++ packages/runtime/test/typing/schema.ts | 8 +- .../runtime/test/typing/typing-test.zmodel | 15 +- packages/runtime/test/typing/verify-typing.ts | 19 ++ packages/sdk/src/schema/schema.ts | 8 +- 12 files changed, 376 insertions(+), 66 deletions(-) create mode 100644 packages/runtime/test/client-api/typed-json-fields.test.ts diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 52f300bb..8f91957f 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -701,7 +701,7 @@ function raw(value: String): Any { /** * Marks a field to be strong-typed JSON. */ -attribute @json() @@@targetField([TypeDefField]) @@@deprecated('The "@json" attribute is not needed anymore. ZenStack will automatically use JSON to store typed fields.') +attribute @json() @@@targetField([TypeDefField]) /** * Marks a field to be computed. diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index c14cf401..821dac06 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -4,7 +4,6 @@ import type { FieldDef, FieldHasDefault, FieldIsArray, - FieldIsOptional, FieldIsRelation, FieldIsRelationArray, FieldType, @@ -19,12 +18,14 @@ import type { GetTypeDefField, GetTypeDefFields, GetTypeDefs, + ModelFieldIsOptional, NonRelationFields, RelationFields, RelationFieldType, RelationInfo, ScalarFields, SchemaDef, + TypeDefFieldIsOptional, } from '../schema'; import type { AtLeast, @@ -86,21 +87,21 @@ type ModelSelectResult Schema, RelationFieldType, Pick, - FieldIsOptional, + ModelFieldIsOptional, FieldIsArray > : ModelResult< Schema, RelationFieldType, Pick, - FieldIsOptional, + ModelFieldIsOptional, FieldIsArray > : DefaultModelResult< Schema, RelationFieldType, Omit, - FieldIsOptional, + ModelFieldIsOptional, FieldIsArray > : never; @@ -143,14 +144,14 @@ export type ModelResult< Schema, RelationFieldType, I[Key], - FieldIsOptional, + ModelFieldIsOptional, FieldIsArray > : DefaultModelResult< Schema, RelationFieldType, undefined, - FieldIsOptional, + ModelFieldIsOptional, FieldIsArray >; } @@ -169,9 +170,17 @@ export type SimplifiedModelResult< Array = false, > = Simplify>; -export type TypeDefResult> = { - [Key in GetTypeDefFields]: MapTypeDefFieldType; -}; +export type TypeDefResult> = Optional< + { + [Key in GetTypeDefFields]: MapTypeDefFieldType; + }, + // optionality + keyof { + [Key in GetTypeDefFields as TypeDefFieldIsOptional extends true + ? Key + : never]: Key; + } +>; export type BatchResult = { count: number }; @@ -193,11 +202,11 @@ export type WhereInput< RelationFilter : // enum GetModelFieldType extends GetEnums - ? EnumFilter, FieldIsOptional> + ? EnumFilter, ModelFieldIsOptional> : FieldIsArray extends true ? ArrayFilter> : // primitive - PrimitiveFilter, FieldIsOptional>; + PrimitiveFilter, ModelFieldIsOptional>; } & { $expr?: (eb: ExpressionBuilder, Model>) => OperandExpression; } & { @@ -290,7 +299,7 @@ export type OrderBy< WithRelation extends boolean, WithAggregation extends boolean, > = { - [Key in NonRelationFields]?: FieldIsOptional extends true + [Key in NonRelationFields]?: ModelFieldIsOptional extends true ? | SortOrder | { @@ -391,7 +400,7 @@ export type IncludeInput extends true ? true - : FieldIsOptional extends true + : ModelFieldIsOptional extends true ? true : false >; @@ -427,14 +436,14 @@ type ToOneRelationFilter< WhereInput> & { is?: NullableIf< WhereInput>, - FieldIsOptional + ModelFieldIsOptional >; isNot?: NullableIf< WhereInput>, - FieldIsOptional + ModelFieldIsOptional >; }, - FieldIsOptional + ModelFieldIsOptional >; type RelationFilter< @@ -460,23 +469,20 @@ type MapTypeDefFieldType< Schema extends SchemaDef, TypeDef extends GetTypeDefs, Field extends GetTypeDefFields, -> = - GetTypeDefField['type'] extends GetTypeDefs - ? WrapType< - TypeDefResult['type']>, - GetTypeDefField['optional'], - GetTypeDefField['array'] - > - : MapFieldDefType>; +> = MapFieldDefType>; type MapFieldDefType> = WrapType< - T['type'] extends GetEnums ? keyof GetEnum : MapBaseType, + T['type'] extends GetEnums + ? keyof GetEnum + : T['type'] extends GetTypeDefs + ? TypeDefResult & Record + : MapBaseType, T['optional'], T['array'] >; type OptionalFieldsForCreate> = keyof { - [Key in GetModelFields as FieldIsOptional extends true + [Key in GetModelFields as ModelFieldIsOptional extends true ? Key : FieldHasDefault extends true ? Key @@ -752,7 +758,7 @@ type ScalarUpdatePayload< | MapModelFieldType | (Field extends NumericFields ? { - set?: NullableIf>; + set?: NullableIf>; increment?: number; decrement?: number; multiply?: number; @@ -820,7 +826,7 @@ type ToOneRelationUpdateInput< connectOrCreate?: ConnectOrCreateInput; update?: NestedUpdateInput; upsert?: NestedUpsertInput; -} & (FieldIsOptional extends true +} & (ModelFieldIsOptional extends true ? { disconnect?: DisconnectInput; delete?: NestedDeleteInput; diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 7fa67905..c27cd7de 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -34,13 +34,18 @@ export class SqliteCrudDialect extends BaseCrudDialect if (Array.isArray(value)) { return value.map((v) => this.transformPrimitive(v, type, false)); } else { - return match(type) - .with('Boolean', () => (value ? 1 : 0)) - .with('DateTime', () => (value instanceof Date ? value.toISOString() : value)) - .with('Decimal', () => (value as Decimal).toString()) - .with('Bytes', () => Buffer.from(value as Uint8Array)) - .with('Json', () => JSON.stringify(value)) - .otherwise(() => value); + if (this.schema.typeDefs && type in this.schema.typeDefs) { + // typed JSON field + return JSON.stringify(value); + } else { + return match(type) + .with('Boolean', () => (value ? 1 : 0)) + .with('DateTime', () => (value instanceof Date ? value.toISOString() : value)) + .with('Decimal', () => (value as Decimal).toString()) + .with('Bytes', () => Buffer.from(value as Uint8Array)) + .with('Json', () => JSON.stringify(value)) + .otherwise(() => value); + } } } diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 32775ef8..64e4efee 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -493,7 +493,8 @@ export abstract class BaseOperationHandler { const idFields = getIdFields(this.schema, model); const query = kysely .insertInto(model) - .values(updatedData) + .$if(Object.keys(updatedData).length === 0, (qb) => qb.defaultValues()) + .$if(Object.keys(updatedData).length > 0, (qb) => qb.values(updatedData)) .returning(idFields as any) .modifyEnd( this.makeContextComment({ diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index 00dc4f2c..cad8e953 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -218,16 +218,46 @@ export class InputValidator { } private makePrimitiveSchema(type: string) { - return match(type) - .with('String', () => z.string()) - .with('Int', () => z.number()) - .with('Float', () => z.number()) - .with('Boolean', () => z.boolean()) - .with('BigInt', () => z.union([z.number(), z.bigint()])) - .with('Decimal', () => z.union([z.number(), z.instanceof(Decimal), z.string()])) - .with('DateTime', () => z.union([z.date(), z.string().datetime()])) - .with('Bytes', () => z.instanceof(Uint8Array)) - .otherwise(() => z.unknown()); + if (this.schema.typeDefs && type in this.schema.typeDefs) { + return this.makeTypeDefSchema(type); + } else { + return match(type) + .with('String', () => z.string()) + .with('Int', () => z.number()) + .with('Float', () => z.number()) + .with('Boolean', () => z.boolean()) + .with('BigInt', () => z.union([z.number(), z.bigint()])) + .with('Decimal', () => z.union([z.number(), z.instanceof(Decimal), z.string()])) + .with('DateTime', () => z.union([z.date(), z.string().datetime()])) + .with('Bytes', () => z.instanceof(Uint8Array)) + .otherwise(() => z.unknown()); + } + } + + private makeTypeDefSchema(type: string): z.ZodType { + const key = `$typedef-${type}`; + let schema = this.schemaCache.get(key); + if (schema) { + return schema; + } + const typeDef = this.schema.typeDefs?.[type]; + invariant(typeDef, `Type definition "${type}" not found in schema`); + schema = z.looseObject( + Object.fromEntries( + Object.entries(typeDef.fields).map(([field, def]) => { + let fieldSchema = this.makePrimitiveSchema(def.type); + if (def.array) { + fieldSchema = fieldSchema.array(); + } + if (def.optional) { + fieldSchema = fieldSchema.optional(); + } + return [field, fieldSchema]; + }), + ), + ); + this.schemaCache.set(key, schema); + return schema; } private makeWhereSchema(model: string, unique: boolean, withoutRelationFields = false): ZodType { @@ -396,6 +426,10 @@ export class InputValidator { } private makePrimitiveFilterSchema(type: BuiltinType, optional: boolean) { + if (this.schema.typeDefs && type in this.schema.typeDefs) { + // typed JSON field + return this.makeTypeDefFilterSchema(type, optional); + } return ( match(type) .with('String', () => this.makeStringFilterSchema(optional)) @@ -412,6 +446,11 @@ export class InputValidator { ); } + private makeTypeDefFilterSchema(_type: string, _optional: boolean) { + // TODO: strong typed JSON filtering + return z.never(); + } + private makeDateTimeFilterSchema(optional: boolean): ZodType { return this.makeCommonPrimitiveFilterSchema(z.union([z.string().datetime(), z.date()]), optional, () => z.lazy(() => this.makeDateTimeFilterSchema(optional)), diff --git a/packages/runtime/src/client/query-builder.ts b/packages/runtime/src/client/query-builder.ts index 64b76ea0..19997017 100644 --- a/packages/runtime/src/client/query-builder.ts +++ b/packages/runtime/src/client/query-builder.ts @@ -2,11 +2,11 @@ import type Decimal from 'decimal.js'; import type { Generated, Kysely } from 'kysely'; import type { FieldHasDefault, - FieldIsOptional, ForeignKeyFields, GetModelFields, GetModelFieldType, GetModels, + ModelFieldIsOptional, ScalarFields, SchemaDef, } from '../schema'; @@ -45,7 +45,7 @@ type MapType< Schema extends SchemaDef, Model extends GetModels, Field extends GetModelFields, -> = WrapNull>, FieldIsOptional>; +> = WrapNull>, ModelFieldIsOptional>; type toKyselyFieldType< Schema extends SchemaDef, diff --git a/packages/runtime/src/client/result-processor.ts b/packages/runtime/src/client/result-processor.ts index 25a2a4df..a43e4648 100644 --- a/packages/runtime/src/client/result-processor.ts +++ b/packages/runtime/src/client/result-processor.ts @@ -84,14 +84,19 @@ export class ResultProcessor { } private transformScalar(value: unknown, type: BuiltinType) { - return match(type) - .with('Boolean', () => this.transformBoolean(value)) - .with('DateTime', () => this.transformDate(value)) - .with('Bytes', () => this.transformBytes(value)) - .with('Decimal', () => this.transformDecimal(value)) - .with('BigInt', () => this.transformBigInt(value)) - .with('Json', () => this.transformJson(value)) - .otherwise(() => value); + if (this.schema.typeDefs && type in this.schema.typeDefs) { + // typed JSON field + return this.transformJson(value); + } else { + return match(type) + .with('Boolean', () => this.transformBoolean(value)) + .with('DateTime', () => this.transformDate(value)) + .with('Bytes', () => this.transformBytes(value)) + .with('Decimal', () => this.transformDecimal(value)) + .with('BigInt', () => this.transformBigInt(value)) + .with('Json', () => this.transformJson(value)) + .otherwise(() => value); + } } private transformDecimal(value: unknown) { diff --git a/packages/runtime/test/client-api/typed-json-fields.test.ts b/packages/runtime/test/client-api/typed-json-fields.test.ts new file mode 100644 index 00000000..fdf01f81 --- /dev/null +++ b/packages/runtime/test/client-api/typed-json-fields.test.ts @@ -0,0 +1,222 @@ +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import { createTestClient } from '../utils'; + +const PG_DB_NAME = 'client-api-typed-json-fields-tests'; + +describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as const }])( + 'Typed JSON fields', + ({ provider }) => { + const schema = ` +type Identity { + providers IdentityProvider[] +} + +type IdentityProvider { + id String + name String? +} + +model User { + id Int @id @default(autoincrement()) + identity Identity? @json +} + `; + + let client: any; + + beforeEach(async () => { + client = await createTestClient(schema, { + usePrismaPush: true, + provider, + dbName: provider === 'postgresql' ? PG_DB_NAME : undefined, + log: ['query'], + }); + }); + + afterEach(async () => { + await client?.$disconnect(); + }); + + it('works with create', async () => { + await expect( + client.user.create({ + data: {}, + }), + ).resolves.toMatchObject({ + identity: null, + }); + + await expect( + client.user.create({ + data: { + identity: { + providers: [ + { + id: '123', + name: 'Google', + }, + ], + }, + }, + }), + ).resolves.toMatchObject({ + identity: { + providers: [ + { + id: '123', + name: 'Google', + }, + ], + }, + }); + + await expect( + client.user.create({ + data: { + identity: { + providers: [ + { + id: '123', + }, + ], + }, + }, + }), + ).resolves.toMatchObject({ + identity: { + providers: [ + { + id: '123', + }, + ], + }, + }); + + await expect( + client.user.create({ + data: { + identity: { + providers: [ + { + id: '123', + foo: 1, + }, + ], + }, + }, + }), + ).resolves.toMatchObject({ + identity: { + providers: [ + { + id: '123', + foo: 1, + }, + ], + }, + }); + + await expect( + client.user.create({ + data: { + identity: { + providers: [ + { + name: 'Google', + }, + ], + }, + }, + }), + ).rejects.toThrow('Invalid input'); + }); + + it('works with find', async () => { + await expect( + client.user.create({ + data: { id: 1 }, + }), + ).toResolveTruthy(); + await expect(client.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ + identity: null, + }); + + await expect( + client.user.create({ + data: { + id: 2, + identity: { + providers: [ + { + id: '123', + name: 'Google', + }, + ], + }, + }, + }), + ).toResolveTruthy(); + + await expect(client.user.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ + identity: { + providers: [ + { + id: '123', + name: 'Google', + }, + ], + }, + }); + }); + + it('works with update', async () => { + await expect( + client.user.create({ + data: { id: 1 }, + }), + ).toResolveTruthy(); + + await expect( + client.user.update({ + where: { id: 1 }, + data: { + identity: { + providers: [ + { + id: '123', + name: 'Google', + foo: 1, + }, + ], + }, + }, + }), + ).resolves.toMatchObject({ + identity: { + providers: [ + { + id: '123', + name: 'Google', + foo: 1, + }, + ], + }, + }); + + await expect( + client.user.update({ + where: { id: 1 }, + data: { + identity: { + providers: [ + { + name: 'GitHub', + }, + ], + }, + }, + }), + ).rejects.toThrow('Invalid input'); + }); + }, +); diff --git a/packages/runtime/test/typing/schema.ts b/packages/runtime/test/typing/schema.ts index d529dbf3..49bf584e 100644 --- a/packages/runtime/test/typing/schema.ts +++ b/packages/runtime/test/typing/schema.ts @@ -56,6 +56,11 @@ export const schema = { type: "Int", attributes: [{ name: "@computed" }], computed: true + }, + identity: { + type: "Identity", + optional: true, + attributes: [{ name: "@json" }] } }, idFields: ["id"], @@ -261,7 +266,8 @@ export const schema = { type: "String" }, name: { - type: "String" + type: "String", + optional: true } } } diff --git a/packages/runtime/test/typing/typing-test.zmodel b/packages/runtime/test/typing/typing-test.zmodel index 2aa9aa67..2cb789d7 100644 --- a/packages/runtime/test/typing/typing-test.zmodel +++ b/packages/runtime/test/typing/typing-test.zmodel @@ -14,19 +14,20 @@ type Identity { type IdentityProvider { id String - name String + name String? } model User { - id Int @id @default(autoincrement()) - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt name String - email String @unique - role Role @default(USER) + email String @unique + role Role @default(USER) posts Post[] profile Profile? - postCount Int @computed + postCount Int @computed + identity Identity? @json } model Post { diff --git a/packages/runtime/test/typing/verify-typing.ts b/packages/runtime/test/typing/verify-typing.ts index e9410031..e815758f 100644 --- a/packages/runtime/test/typing/verify-typing.ts +++ b/packages/runtime/test/typing/verify-typing.ts @@ -40,6 +40,7 @@ async function find() { }); console.log(user1?.name); console.log(user1?.postCount); + console.log(user1?.identity?.providers[0]?.name); const users = await client.user.findMany({ include: { posts: true }, @@ -206,6 +207,24 @@ async function create() { userId: 1, }, }, + identity: { + providers: [ + { + id: '123', + name: 'GitHub', + // undeclared fields are allowed + otherField: 123, + }, + { + id: '234', + // name is optional + }, + // @ts-expect-error id is required + { + name: 'Google', + }, + ], + }, }, }); diff --git a/packages/sdk/src/schema/schema.ts b/packages/sdk/src/schema/schema.ts index 7ef99516..208024a8 100644 --- a/packages/sdk/src/schema/schema.ts +++ b/packages/sdk/src/schema/schema.ts @@ -187,12 +187,18 @@ export type RelationFieldType< ? GetModelField['type'] : never; -export type FieldIsOptional< +export type ModelFieldIsOptional< Schema extends SchemaDef, Model extends GetModels, Field extends GetModelFields, > = GetModelField['optional'] extends true ? true : false; +export type TypeDefFieldIsOptional< + Schema extends SchemaDef, + TypeDef extends GetTypeDefs, + Field extends GetTypeDefFields, +> = GetTypeDefField['optional'] extends true ? true : false; + export type FieldIsRelation< Schema extends SchemaDef, Model extends GetModels,