Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

# What's ZenStack

> Read full documentation at 👉🏻 https://zenstack.dev/v3.

ZenStack is a TypeScript database toolkit for developing full-stack or backend Node.js/Bun applications. It provides a unified data modeling and access solution with the following features:

- A modern schema-first ORM that's compatible with [Prisma](https://github.com/prisma/prisma)'s schema and API
Expand Down
38 changes: 30 additions & 8 deletions packages/orm/src/client/crud-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import type {
} from '../schema';
import type {
AtLeast,
JsonNullValues,
JsonValue,
MapBaseType,
NonEmptyArray,
NullableIf,
Expand All @@ -44,6 +46,7 @@ import type {
WrapType,
XOR,
} from '../utils/type-utils';
import type { DbNull, JsonNull } from './null-values';
import type { ClientOptions } from './options';
import type { ToKyselySchema } from './query-builder';

Expand Down Expand Up @@ -359,7 +362,7 @@ type PrimitiveFilter<T extends string, Nullable extends boolean, WithAggregation
: T extends 'Bytes'
? BytesFilter<Nullable, WithAggregations>
: T extends 'Json'
? 'Not implemented yet' // TODO: Json filter
? JsonFilter
: never;

type CommonPrimitiveFilter<
Expand Down Expand Up @@ -452,6 +455,11 @@ export type BooleanFilter<Nullable extends boolean, WithAggregations extends boo
}
: {}));

export type JsonFilter = {
equals?: JsonValue | JsonNullValues;
not?: JsonValue | JsonNullValues;
};

export type SortOrder = 'asc' | 'desc';
export type NullsOrder = 'first' | 'last';

Expand Down Expand Up @@ -772,20 +780,34 @@ type CreateScalarPayload<Schema extends SchemaDef, Model extends GetModels<Schem
}
>;

// For unknown reason toplevel `Simplify` can't simplify this type, so we added an extra layer
// to make it work
type ScalarCreatePayload<
Schema extends SchemaDef,
Model extends GetModels<Schema>,
Field extends ScalarFields<Schema, Model, false>,
> = Simplify<
| MapModelFieldType<Schema, Model, Field>
> =
| ScalarFieldMutationPayload<Schema, Model, Field>
| (FieldIsArray<Schema, Model, Field> extends true
? {
set?: MapModelFieldType<Schema, Model, Field>;
}
: never)
>;
: never);

type ScalarFieldMutationPayload<
Schema extends SchemaDef,
Model extends GetModels<Schema>,
Field extends GetModelFields<Schema, Model>,
> =
IsJsonField<Schema, Model, Field> extends true
? ModelFieldIsOptional<Schema, Model, Field> extends true
? JsonValue | JsonNull | DbNull
: JsonValue | JsonNull
: MapModelFieldType<Schema, Model, Field>;

type IsJsonField<
Schema extends SchemaDef,
Model extends GetModels<Schema>,
Field extends GetModelFields<Schema, Model>,
> = GetModelFieldType<Schema, Model, Field> extends 'Json' ? true : false;

type CreateFKPayload<Schema extends SchemaDef, Model extends GetModels<Schema>> = OptionalWrap<
Schema,
Expand Down Expand Up @@ -932,7 +954,7 @@ type ScalarUpdatePayload<
Model extends GetModels<Schema>,
Field extends NonRelationFields<Schema, Model>,
> =
| MapModelFieldType<Schema, Model, Field>
| ScalarFieldMutationPayload<Schema, Model, Field>
| (Field extends NumericFields<Schema, Model>
? {
set?: NullableIf<number, ModelFieldIsOptional<Schema, Model, Field>>;
Expand Down
63 changes: 45 additions & 18 deletions packages/orm/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import type {
StringFilter,
} from '../../crud-types';
import { createConfigError, createInvalidInputError, createNotSupportedError } from '../../errors';
import { AnyNullClass, DbNullClass, JsonNullClass } from '../../null-values';
import type { ClientOptions } from '../../options';
import {
aggregate,
Expand Down Expand Up @@ -499,24 +500,50 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
return this.buildEnumFilter(fieldRef, fieldDef, payload);
}

return (
match(fieldDef.type as BuiltinType)
.with('String', () => this.buildStringFilter(fieldRef, payload))
.with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) =>
this.buildNumberFilter(fieldRef, type, payload),
)
.with('Boolean', () => this.buildBooleanFilter(fieldRef, payload))
.with('DateTime', () => this.buildDateTimeFilter(fieldRef, payload))
.with('Bytes', () => this.buildBytesFilter(fieldRef, payload))
// TODO: JSON filters
.with('Json', () => {
throw createNotSupportedError('JSON filters are not supported yet');
})
.with('Unsupported', () => {
throw createInvalidInputError(`Unsupported field cannot be used in filters`);
})
.exhaustive()
);
return match(fieldDef.type as BuiltinType)
.with('String', () => this.buildStringFilter(fieldRef, payload))
.with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) =>
this.buildNumberFilter(fieldRef, type, payload),
)
.with('Boolean', () => this.buildBooleanFilter(fieldRef, payload))
.with('DateTime', () => this.buildDateTimeFilter(fieldRef, payload))
.with('Bytes', () => this.buildBytesFilter(fieldRef, payload))
.with('Json', () => this.buildJsonFilter(fieldRef, payload))
.with('Unsupported', () => {
throw createInvalidInputError(`Unsupported field cannot be used in filters`);
})
.exhaustive();
}

private buildJsonFilter(lhs: Expression<any>, payload: any): any {
const clauses: Expression<SqlBool>[] = [];
invariant(payload && typeof payload === 'object', 'Json filter payload must be an object');
for (const [key, value] of Object.entries(payload)) {
switch (key) {
case 'equals': {
clauses.push(this.buildJsonValueFilterClause(lhs, value));
break;
}
case 'not': {
clauses.push(this.eb.not(this.buildJsonValueFilterClause(lhs, value)));
break;
}
}
}
return this.and(...clauses);
}

private buildJsonValueFilterClause(lhs: Expression<any>, value: unknown) {
if (value instanceof DbNullClass) {
return this.eb(lhs, 'is', null);
} else if (value instanceof JsonNullClass) {
return this.eb.and([this.eb(lhs, '=', 'null'), this.eb(lhs, 'is not', null)]);
} else if (value instanceof AnyNullClass) {
// AnyNull matches both DB NULL and JSON null
return this.eb.or([this.eb(lhs, 'is', null), this.eb(lhs, '=', 'null')]);
} else {
return this.buildLiteralFilter(lhs, 'Json', value);
}
}

private buildLiteralFilter(lhs: Expression<any>, type: BuiltinType, rhs: unknown) {
Expand Down
11 changes: 10 additions & 1 deletion packages/orm/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schem
import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants';
import type { FindArgs } from '../../crud-types';
import { createInternalError } from '../../errors';
import { AnyNullClass, DbNullClass, JsonNullClass } from '../../null-values';
import type { ClientOptions } from '../../options';
import {
buildJoinPairs,
Expand All @@ -25,7 +26,6 @@ import {
requireModel,
} from '../../query-utils';
import { BaseCrudDialect } from './base-dialect';

export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect<Schema> {
private isoDateSchema = z.iso.datetime({ local: true, offset: true });

Expand All @@ -42,6 +42,15 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
return value;
}

// Handle special null classes for JSON fields
if (value instanceof JsonNullClass) {
return 'null';
} else if (value instanceof DbNullClass) {
return null;
} else if (value instanceof AnyNullClass) {
invariant(false, 'should not reach here: AnyNull is not a valid input value');
}

if (Array.isArray(value)) {
if (type === 'Json' && !forArrayField) {
// node-pg incorrectly handles array values passed to non-array JSON fields,
Expand Down
10 changes: 10 additions & 0 deletions packages/orm/src/client/crud/dialects/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schem
import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants';
import type { FindArgs } from '../../crud-types';
import { createInternalError } from '../../errors';
import { AnyNullClass, DbNullClass, JsonNullClass } from '../../null-values';
import {
getDelegateDescendantModels,
getManyToManyRelation,
Expand All @@ -33,6 +34,15 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
return value;
}

// Handle special null classes for JSON fields
if (value instanceof JsonNullClass) {
return 'null';
} else if (value instanceof DbNullClass) {
return null;
} else if (value instanceof AnyNullClass) {
invariant(false, 'should not reach here: AnyNull is not a valid input value');
}

if (Array.isArray(value)) {
return value.map((v) => this.transformPrimitive(v, type, false));
} else {
Expand Down
79 changes: 59 additions & 20 deletions packages/orm/src/client/crud/validator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
type UpsertArgs,
} from '../../crud-types';
import { createInternalError, createInvalidInputError } from '../../errors';
import { AnyNullClass, DbNullClass, JsonNullClass } from '../../null-values';
import {
fieldHasDefaultValue,
getDiscriminatorField,
Expand Down Expand Up @@ -328,8 +329,9 @@ export class InputValidator<Schema extends SchemaDef> {
addDecimalValidation(z.string(), attributes, this.extraValidationsEnabled),
]);
})
.with('DateTime', () => z.union([z.date(), z.string().datetime()]))
.with('DateTime', () => z.union([z.date(), z.iso.datetime()]))
.with('Bytes', () => z.instanceof(Uint8Array))
.with('Json', () => this.makeJsonValueSchema(false, false))
.otherwise(() => z.unknown());
}
}
Expand Down Expand Up @@ -553,20 +555,47 @@ export class InputValidator<Schema extends SchemaDef> {
// typed JSON field
return this.makeTypeDefFilterSchema(type, optional);
}
return (
match(type)
.with('String', () => this.makeStringFilterSchema(optional, withAggregations))
.with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) =>
this.makeNumberFilterSchema(this.makePrimitiveSchema(type), optional, withAggregations),
)
.with('Boolean', () => this.makeBooleanFilterSchema(optional, withAggregations))
.with('DateTime', () => this.makeDateTimeFilterSchema(optional, withAggregations))
.with('Bytes', () => this.makeBytesFilterSchema(optional, withAggregations))
// TODO: JSON filters
.with('Json', () => z.any())
.with('Unsupported', () => z.never())
.exhaustive()
);
return match(type)
.with('String', () => this.makeStringFilterSchema(optional, withAggregations))
.with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) =>
this.makeNumberFilterSchema(this.makePrimitiveSchema(type), optional, withAggregations),
)
.with('Boolean', () => this.makeBooleanFilterSchema(optional, withAggregations))
.with('DateTime', () => this.makeDateTimeFilterSchema(optional, withAggregations))
.with('Bytes', () => this.makeBytesFilterSchema(optional, withAggregations))
.with('Json', () => this.makeJsonFilterSchema(optional))
.with('Unsupported', () => z.never())
.exhaustive();
}

private makeJsonValueSchema(nullable: boolean, forFilter: boolean): z.ZodType {
const options: z.ZodType[] = [z.string(), z.number(), z.boolean(), z.instanceof(JsonNullClass)];

if (nullable) {
options.push(z.instanceof(DbNullClass));
}

if (forFilter) {
options.push(z.instanceof(AnyNullClass));
}

const schema = z.union([
...options,
z.lazy(() => this.makeJsonValueSchema(false, false).array()),
z.record(
z.string(),
z.lazy(() => this.makeJsonValueSchema(false, false)),
),
]);
return this.nullableIf(schema, nullable);
}

private makeJsonFilterSchema(optional: boolean) {
const valueSchema = this.makeJsonValueSchema(optional, true);
return z.object({
equals: valueSchema.optional(),
not: valueSchema.optional(),
});
}

private makeTypeDefFilterSchema(_type: string, _optional: boolean) {
Expand All @@ -576,7 +605,7 @@ export class InputValidator<Schema extends SchemaDef> {

private makeDateTimeFilterSchema(optional: boolean, withAggregations: boolean): ZodType {
return this.makeCommonPrimitiveFilterSchema(
z.union([z.string().datetime(), z.date()]),
z.union([z.iso.datetime(), z.date()]),
optional,
() => z.lazy(() => this.makeDateTimeFilterSchema(optional, withAggregations)),
withAggregations ? ['_count', '_min', '_max'] : undefined,
Expand Down Expand Up @@ -977,7 +1006,7 @@ export class InputValidator<Schema extends SchemaDef> {
uncheckedVariantFields[field] = fieldSchema;
}
} else {
let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes);
let fieldSchema = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes);

if (fieldDef.array) {
fieldSchema = addListValidation(fieldSchema.array(), fieldDef.attributes);
Expand All @@ -996,7 +1025,12 @@ export class InputValidator<Schema extends SchemaDef> {
}

if (fieldDef.optional) {
fieldSchema = fieldSchema.nullable();
if (fieldDef.type === 'Json') {
// DbNull for Json fields
fieldSchema = z.union([fieldSchema, z.instanceof(DbNullClass)]);
} else {
fieldSchema = fieldSchema.nullable();
}
}

uncheckedVariantFields[field] = fieldSchema;
Expand Down Expand Up @@ -1242,7 +1276,7 @@ export class InputValidator<Schema extends SchemaDef> {
uncheckedVariantFields[field] = fieldSchema;
}
} else {
let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes);
let fieldSchema = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes);

if (this.isNumericField(fieldDef)) {
fieldSchema = z.union([
Expand Down Expand Up @@ -1276,7 +1310,12 @@ export class InputValidator<Schema extends SchemaDef> {
}

if (fieldDef.optional) {
fieldSchema = fieldSchema.nullable();
if (fieldDef.type === 'Json') {
// DbNull for Json fields
fieldSchema = z.union([fieldSchema, z.instanceof(DbNullClass)]);
} else {
fieldSchema = fieldSchema.nullable();
}
}

// all fields are optional in update
Expand Down
1 change: 1 addition & 0 deletions packages/orm/src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export type * from './crud-types';
export { getCrudDialect } from './crud/dialects';
export { BaseCrudDialect } from './crud/dialects/base-dialect';
export { ORMError, ORMErrorReason, RejectedByPolicyReason } from './errors';
export { AnyNull, DbNull, JsonNull } from './null-values';
export * from './options';
export * from './plugin';
export type { ZenStackPromise } from './promise';
Expand Down
17 changes: 17 additions & 0 deletions packages/orm/src/client/null-values.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
export class DbNullClass {
__brand = 'DbNull' as const;
}
export const DbNull = new DbNullClass();
export type DbNull = typeof DbNull;

export class JsonNullClass {
__brand = 'JsonNull' as const;
}
export const JsonNull = new JsonNullClass();
export type JsonNull = typeof JsonNull;

export class AnyNullClass {
__brand = 'AnyNull' as const;
}
export const AnyNull = new AnyNullClass();
export type AnyNull = typeof AnyNull;
Loading