diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index 44869b51..c73b4bb1 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -15,6 +15,7 @@ import { buildJoinPairs, getIdFields, getManyToManyRelation, + isRelationField, requireField, requireModel, } from '../../query-utils'; @@ -216,10 +217,15 @@ export class PostgresCrudDialect extends BaseCrudDiale objArgs.push( ...Object.entries(payload.select) .filter(([, value]) => value) - .map(([field]) => [ - sql.lit(field), - buildFieldRef(this.schema, relationModel, field, this.options, eb), - ]) + .map(([field]) => { + const fieldDef = requireField(this.schema, relationModel, field); + const fieldValue = fieldDef.relation + ? // reference the synthesized JSON field + eb.ref(`${parentName}$${relationField}$${field}.$j`) + : // reference a plain field + buildFieldRef(this.schema, relationModel, field, this.options, eb); + return [sql.lit(field), fieldValue]; + }) .flatMap((v) => v), ); } @@ -229,7 +235,11 @@ export class PostgresCrudDialect extends BaseCrudDiale objArgs.push( ...Object.entries(payload.include) .filter(([, value]) => value) - .map(([field]) => [sql.lit(field), eb.ref(`${parentName}$${relationField}$${field}.$j`)]) + .map(([field]) => [ + sql.lit(field), + // reference the synthesized JSON field + eb.ref(`${parentName}$${relationField}$${field}.$j`), + ]) .flatMap((v) => v), ); } @@ -237,19 +247,29 @@ export class PostgresCrudDialect extends BaseCrudDiale } private buildRelationJoins( - model: string, + relationModel: string, relationField: string, qb: SelectQueryBuilder, payload: true | FindArgs, true>, parentName: string, ) { let result = qb; - if (typeof payload === 'object' && payload.include && typeof payload.include === 'object') { - Object.entries(payload.include) - .filter(([, value]) => value) - .forEach(([field, value]) => { - result = this.buildRelationJSON(model, result, field, `${parentName}$${relationField}`, value); - }); + if (typeof payload === 'object') { + const selectInclude = payload.include ?? payload.select; + if (selectInclude && typeof selectInclude === 'object') { + Object.entries(selectInclude) + .filter(([, value]) => value) + .filter(([field]) => isRelationField(this.schema, relationModel, field)) + .forEach(([field, value]) => { + result = this.buildRelationJSON( + relationModel, + result, + field, + `${parentName}$${relationField}`, + value, + ); + }); + } } return result; } diff --git a/packages/runtime/src/client/crud/operations/aggregate.ts b/packages/runtime/src/client/crud/operations/aggregate.ts index 05392250..03b1ae6d 100644 --- a/packages/runtime/src/client/crud/operations/aggregate.ts +++ b/packages/runtime/src/client/crud/operations/aggregate.ts @@ -6,7 +6,11 @@ import { BaseOperationHandler } from './base'; export class AggregateOperationHandler extends BaseOperationHandler { async handle(_operation: 'aggregate', args: unknown | undefined) { - const validatedArgs = this.inputValidator.validateAggregateArgs(this.model, args); + // normalize args to strip `undefined` fields + const normalizeArgs = this.normalizeArgs(args); + + // parse args + const parsedArgs = this.inputValidator.validateAggregateArgs(this.model, normalizeArgs); let query = this.kysely.selectFrom((eb) => { // nested query for filtering and pagination @@ -15,11 +19,11 @@ export class AggregateOperationHandler extends BaseOpe let subQuery = eb .selectFrom(this.model) .selectAll(this.model as any) // TODO: check typing - .where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, validatedArgs?.where)); + .where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where)); // skip & take - const skip = validatedArgs?.skip; - let take = validatedArgs?.take; + const skip = parsedArgs?.skip; + let take = parsedArgs?.take; let negateOrderBy = false; if (take !== undefined && take < 0) { negateOrderBy = true; @@ -32,7 +36,7 @@ export class AggregateOperationHandler extends BaseOpe subQuery, this.model, this.model, - validatedArgs.orderBy, + parsedArgs.orderBy, skip !== undefined || take !== undefined, negateOrderBy, ); @@ -41,7 +45,7 @@ export class AggregateOperationHandler extends BaseOpe }); // aggregations - for (const [key, value] of Object.entries(validatedArgs)) { + for (const [key, value] of Object.entries(parsedArgs)) { switch (key) { case '_count': { if (value === true) { diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 26a75376..d436feb2 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -1,12 +1,11 @@ import { createId } from '@paralleldrive/cuid2'; -import { invariant } from '@zenstackhq/common-helpers'; +import { invariant, isPlainObject } from '@zenstackhq/common-helpers'; import { DeleteResult, expressionBuilder, ExpressionWrapper, sql, UpdateResult, - type ExpressionBuilder, type Expression as KyselyExpression, type SelectQueryBuilder, } from 'kysely'; @@ -292,33 +291,29 @@ export abstract class BaseOperationHandler { for (const [field, value] of Object.entries(selections.select)) { const fieldDef = requireField(this.schema, model, field); const fieldModel = fieldDef.type; - const jointTable = `${parentAlias}$${field}$count`; - const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, jointTable); - - query = query.leftJoin( - (eb) => { - let result = eb.selectFrom(fieldModel).selectAll(); - if ( - value && - typeof value === 'object' && - 'where' in value && - value.where && - typeof value.where === 'object' - ) { - const filter = this.dialect.buildFilter(eb, fieldModel, fieldModel, value.where); - result = result.where(filter); - } - return result.as(jointTable); - }, - (join) => { - for (const [left, right] of joinPairs) { - join = join.onRef(left, '=', right); - } - return join; - }, - ); + const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel); + + // build a nested query to count the number of records in the relation + let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`)); + + // join conditions + for (const [left, right] of joinPairs) { + fieldCountQuery = fieldCountQuery.whereRef(left, '=', right); + } + + // merge _count filter + if ( + value && + typeof value === 'object' && + 'where' in value && + value.where && + typeof value.where === 'object' + ) { + const filter = this.dialect.buildFilter(eb, fieldModel, fieldModel, value.where); + fieldCountQuery = fieldCountQuery.where(filter); + } - jsonObject[field] = this.countIdDistinct(eb, fieldDef.type, jointTable); + jsonObject[field] = fieldCountQuery; } query = query.select((eb) => this.dialect.buildJsonObject(eb, jsonObject).as('_count')); @@ -326,11 +321,6 @@ export abstract class BaseOperationHandler { return query; } - private countIdDistinct(eb: ExpressionBuilder, model: string, table: string) { - const idFields = getIdFields(this.schema, model); - return eb.fn.count(sql.join(idFields.map((f) => sql.ref(`${table}.${f}`)))).distinct(); - } - private buildSelectAllScalarFields( model: string, query: SelectQueryBuilder, @@ -479,7 +469,7 @@ export abstract class BaseOperationHandler { } else { const subM2M = getManyToManyRelation(this.schema, model, field); if (!subM2M && fieldDef.relation?.fields && fieldDef.relation?.references) { - const fkValues = await this.processOwnedRelation(kysely, fieldDef, value); + const fkValues = await this.processOwnedRelationForCreate(kysely, fieldDef, value); for (let i = 0; i < fieldDef.relation.fields.length; i++) { createFields[fieldDef.relation.fields[i]!] = fkValues[fieldDef.relation.references[i]!]; } @@ -519,7 +509,7 @@ export abstract class BaseOperationHandler { if (Object.keys(postCreateRelations).length > 0) { // process nested creates that need to happen after the current entity is created const relationPromises = Object.entries(postCreateRelations).map(([field, subPayload]) => { - return this.processNoneOwnedRelation(kysely, model, field, subPayload, createdEntity); + return this.processNoneOwnedRelationForCreate(kysely, model, field, subPayload, createdEntity); }); // await relation creation @@ -633,7 +623,7 @@ export abstract class BaseOperationHandler { .execute(); } - private async processOwnedRelation(kysely: ToKysely, relationField: FieldDef, payload: any) { + private async processOwnedRelationForCreate(kysely: ToKysely, relationField: FieldDef, payload: any) { if (!payload) { return; } @@ -696,7 +686,7 @@ export abstract class BaseOperationHandler { return result; } - private processNoneOwnedRelation( + private processNoneOwnedRelationForCreate( kysely: ToKysely, contextModel: GetModels, relationFieldName: string, @@ -706,6 +696,11 @@ export abstract class BaseOperationHandler { const relationFieldDef = this.requireField(contextModel, relationFieldName); const relationModel = relationFieldDef.type as GetModels; const tasks: Promise[] = []; + const fromRelationContext = { + model: contextModel, + field: relationFieldName, + ids: parentEntity, + }; for (const [action, subPayload] of Object.entries(payload)) { if (!subPayload) { @@ -716,11 +711,21 @@ export abstract class BaseOperationHandler { // create with a parent entity tasks.push( ...enumerate(subPayload).map((item) => - this.create(kysely, relationModel, item, { - model: contextModel, - field: relationFieldName, - ids: parentEntity, - }), + this.create(kysely, relationModel, item, fromRelationContext), + ), + ); + break; + } + + case 'createMany': { + invariant(relationFieldDef.array, 'relation must be an array for createMany'); + tasks.push( + this.createMany( + kysely, + relationModel, + subPayload as { data: any; skipDuplicates: boolean }, + false, + fromRelationContext, ), ); break; @@ -776,6 +781,11 @@ export abstract class BaseOperationHandler { returnData: ReturnData, fromRelation?: FromRelationContext, ): Promise { + if (!input.data || (Array.isArray(input.data) && input.data.length === 0)) { + // nothing todo + return returnData ? ([] as Result) : ({ count: 0 } as Result); + } + const modelDef = this.requireModel(model); let relationKeyPairs: { fk: string; pk: string }[] = []; @@ -1916,4 +1926,28 @@ export abstract class BaseOperationHandler { where: uniqueFilter, }); } + + /** + * Normalize input args to strip `undefined` fields + */ + protected normalizeArgs(args: unknown) { + if (!args) { + return; + } + const newArgs = clone(args); + this.doNormalizeArgs(newArgs); + return newArgs; + } + + private doNormalizeArgs(args: unknown) { + if (args && typeof args === 'object') { + for (const [key, value] of Object.entries(args)) { + if (value === undefined) { + delete args[key as keyof typeof args]; + } else if (value && isPlainObject(value)) { + this.doNormalizeArgs(value); + } + } + } + } } diff --git a/packages/runtime/src/client/crud/operations/count.ts b/packages/runtime/src/client/crud/operations/count.ts index f454762b..9a8cc315 100644 --- a/packages/runtime/src/client/crud/operations/count.ts +++ b/packages/runtime/src/client/crud/operations/count.ts @@ -4,22 +4,26 @@ import { BaseOperationHandler } from './base'; export class CountOperationHandler extends BaseOperationHandler { async handle(_operation: 'count', args: unknown | undefined) { - const validatedArgs = this.inputValidator.validateCountArgs(this.model, args); + // normalize args to strip `undefined` fields + const normalizeArgs = this.normalizeArgs(args); + + // parse args + const parsedArgs = this.inputValidator.validateCountArgs(this.model, normalizeArgs); let query = this.kysely.selectFrom((eb) => { // nested query for filtering and pagination let subQuery = eb .selectFrom(this.model) .selectAll() - .where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, validatedArgs?.where)); - subQuery = this.dialect.buildSkipTake(subQuery, validatedArgs?.skip, validatedArgs?.take); + .where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where)); + subQuery = this.dialect.buildSkipTake(subQuery, parsedArgs?.skip, parsedArgs?.take); return subQuery.as('$sub'); }); - if (validatedArgs?.select && typeof validatedArgs.select === 'object') { + if (parsedArgs?.select && typeof parsedArgs.select === 'object') { // count with field selection query = query.select((eb) => - Object.keys(validatedArgs.select!).map((key) => + Object.keys(parsedArgs.select!).map((key) => key === '_all' ? eb.cast(eb.fn.countAll(), 'integer').as('_all') : eb.cast(eb.fn.count(sql.ref(`$sub.${key}`)), 'integer').as(key), diff --git a/packages/runtime/src/client/crud/operations/create.ts b/packages/runtime/src/client/crud/operations/create.ts index 1b6c9288..a908346b 100644 --- a/packages/runtime/src/client/crud/operations/create.ts +++ b/packages/runtime/src/client/crud/operations/create.ts @@ -7,14 +7,17 @@ import { BaseOperationHandler } from './base'; export class CreateOperationHandler extends BaseOperationHandler { async handle(operation: 'create' | 'createMany' | 'createManyAndReturn', args: unknown | undefined) { + // normalize args to strip `undefined` fields + const normalizeArgs = this.normalizeArgs(args); + return match(operation) - .with('create', () => this.runCreate(this.inputValidator.validateCreateArgs(this.model, args))) + .with('create', () => this.runCreate(this.inputValidator.validateCreateArgs(this.model, normalizeArgs))) .with('createMany', () => { - return this.runCreateMany(this.inputValidator.validateCreateManyArgs(this.model, args)); + return this.runCreateMany(this.inputValidator.validateCreateManyArgs(this.model, normalizeArgs)); }) .with('createManyAndReturn', () => { return this.runCreateManyAndReturn( - this.inputValidator.validateCreateManyAndReturnArgs(this.model, args), + this.inputValidator.validateCreateManyAndReturnArgs(this.model, normalizeArgs), ); }) .exhaustive(); diff --git a/packages/runtime/src/client/crud/operations/delete.ts b/packages/runtime/src/client/crud/operations/delete.ts index 6933b1a8..7ee821c6 100644 --- a/packages/runtime/src/client/crud/operations/delete.ts +++ b/packages/runtime/src/client/crud/operations/delete.ts @@ -6,9 +6,14 @@ import { BaseOperationHandler } from './base'; export class DeleteOperationHandler extends BaseOperationHandler { async handle(operation: 'delete' | 'deleteMany', args: unknown | undefined) { + // normalize args to strip `undefined` fields + const normalizeArgs = this.normalizeArgs(args); + return match(operation) - .with('delete', () => this.runDelete(this.inputValidator.validateDeleteArgs(this.model, args))) - .with('deleteMany', () => this.runDeleteMany(this.inputValidator.validateDeleteManyArgs(this.model, args))) + .with('delete', () => this.runDelete(this.inputValidator.validateDeleteArgs(this.model, normalizeArgs))) + .with('deleteMany', () => + this.runDeleteMany(this.inputValidator.validateDeleteManyArgs(this.model, normalizeArgs)), + ) .exhaustive(); } diff --git a/packages/runtime/src/client/crud/operations/find.ts b/packages/runtime/src/client/crud/operations/find.ts index 8a868fad..7834e58b 100644 --- a/packages/runtime/src/client/crud/operations/find.ts +++ b/packages/runtime/src/client/crud/operations/find.ts @@ -4,10 +4,13 @@ import { BaseOperationHandler, type CrudOperation } from './base'; export class FindOperationHandler extends BaseOperationHandler { async handle(operation: CrudOperation, args: unknown, validateArgs = true): Promise { + // normalize args to strip `undefined` fields + const normalizeArgs = this.normalizeArgs(args); + // parse args const parsedArgs = validateArgs - ? this.inputValidator.validateFindArgs(this.model, operation === 'findUnique', args) - : args; + ? this.inputValidator.validateFindArgs(this.model, operation === 'findUnique', normalizeArgs) + : normalizeArgs; // run query const result = await this.read( diff --git a/packages/runtime/src/client/crud/operations/group-by.ts b/packages/runtime/src/client/crud/operations/group-by.ts index c59b1f7f..f1630c82 100644 --- a/packages/runtime/src/client/crud/operations/group-by.ts +++ b/packages/runtime/src/client/crud/operations/group-by.ts @@ -6,7 +6,11 @@ import { BaseOperationHandler } from './base'; export class GroupByeOperationHandler extends BaseOperationHandler { async handle(_operation: 'groupBy', args: unknown | undefined) { - const validatedArgs = this.inputValidator.validateGroupByArgs(this.model, args); + // normalize args to strip `undefined` fields + const normalizeArgs = this.normalizeArgs(args); + + // parse args + const parsedArgs = this.inputValidator.validateGroupByArgs(this.model, normalizeArgs); let query = this.kysely.selectFrom((eb) => { // nested query for filtering and pagination @@ -15,11 +19,11 @@ export class GroupByeOperationHandler extends BaseOper let subQuery = eb .selectFrom(this.model) .selectAll() - .where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, validatedArgs?.where)); + .where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where)); // skip & take - const skip = validatedArgs?.skip; - let take = validatedArgs?.take; + const skip = parsedArgs?.skip; + let take = parsedArgs?.take; let negateOrderBy = false; if (take !== undefined && take < 0) { negateOrderBy = true; @@ -40,17 +44,17 @@ export class GroupByeOperationHandler extends BaseOper return subQuery.as('$sub'); }); - const bys = typeof validatedArgs.by === 'string' ? [validatedArgs.by] : (validatedArgs.by as string[]); + const bys = typeof parsedArgs.by === 'string' ? [parsedArgs.by] : (parsedArgs.by as string[]); query = query.groupBy(bys as any); // orderBy - if (validatedArgs.orderBy) { - query = this.dialect.buildOrderBy(query, this.model, '$sub', validatedArgs.orderBy, false, false); + if (parsedArgs.orderBy) { + query = this.dialect.buildOrderBy(query, this.model, '$sub', parsedArgs.orderBy, false, false); } - if (validatedArgs.having) { - query = query.having((eb1) => this.dialect.buildFilter(eb1, this.model, '$sub', validatedArgs.having)); + if (parsedArgs.having) { + query = query.having((eb1) => this.dialect.buildFilter(eb1, this.model, '$sub', parsedArgs.having)); } // select all by fields @@ -59,7 +63,7 @@ export class GroupByeOperationHandler extends BaseOper } // aggregations - for (const [key, value] of Object.entries(validatedArgs)) { + for (const [key, value] of Object.entries(parsedArgs)) { switch (key) { case '_count': { if (value === true) { diff --git a/packages/runtime/src/client/crud/operations/update.ts b/packages/runtime/src/client/crud/operations/update.ts index 4771b071..577646ef 100644 --- a/packages/runtime/src/client/crud/operations/update.ts +++ b/packages/runtime/src/client/crud/operations/update.ts @@ -7,13 +7,20 @@ import { BaseOperationHandler } from './base'; export class UpdateOperationHandler extends BaseOperationHandler { async handle(operation: 'update' | 'updateMany' | 'updateManyAndReturn' | 'upsert', args: unknown) { + // normalize args to strip `undefined` fields + const normalizeArgs = this.normalizeArgs(args); + return match(operation) - .with('update', () => this.runUpdate(this.inputValidator.validateUpdateArgs(this.model, args))) - .with('updateMany', () => this.runUpdateMany(this.inputValidator.validateUpdateManyArgs(this.model, args))) + .with('update', () => this.runUpdate(this.inputValidator.validateUpdateArgs(this.model, normalizeArgs))) + .with('updateMany', () => + this.runUpdateMany(this.inputValidator.validateUpdateManyArgs(this.model, normalizeArgs)), + ) .with('updateManyAndReturn', () => - this.runUpdateManyAndReturn(this.inputValidator.validateUpdateManyAndReturnArgs(this.model, args)), + this.runUpdateManyAndReturn( + this.inputValidator.validateUpdateManyAndReturnArgs(this.model, normalizeArgs), + ), ) - .with('upsert', () => this.runUpsert(this.inputValidator.validateUpsertArgs(this.model, args))) + .with('upsert', () => this.runUpsert(this.inputValidator.validateUpsertArgs(this.model, normalizeArgs))) .exhaustive(); } diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index 46a7d7d1..754d53d4 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -21,6 +21,7 @@ import { } from '../crud-types'; import { InternalError, QueryError } from '../errors'; import { fieldHasDefaultValue, getEnum, getModel, getUniqueFields, requireField, requireModel } from '../query-utils'; +import { invariant } from '@zenstackhq/common-helpers'; type GetSchemaFunc = (model: GetModels, options: Options) => ZodType; @@ -298,10 +299,26 @@ export class InputValidator { fields[uniqueField.name] = z .object( Object.fromEntries( - Object.entries(uniqueField.defs).map(([key, def]) => [ - key, - this.makePrimitiveFilterSchema(def.type as BuiltinType, !!def.optional), - ]), + Object.entries(uniqueField.defs).map(([key, def]) => { + invariant(!def.relation, 'unique field cannot be a relation'); + let fieldSchema: ZodType; + const enumDef = getEnum(this.schema, def.type); + if (enumDef) { + // enum + if (Object.keys(enumDef).length > 0) { + fieldSchema = this.makeEnumFilterSchema(enumDef, !!def.optional); + } else { + fieldSchema = z.never(); + } + } else { + // regular field + fieldSchema = this.makePrimitiveFilterSchema( + def.type as BuiltinType, + !!def.optional, + ); + } + return [key, fieldSchema]; + }), ), ) .optional(); @@ -796,10 +813,7 @@ export class InputValidator { } } - return z - .object(fields) - .strict() - .refine((v) => Object.keys(v).length > 0, 'At least one action is required'); + return z.object(fields).strict(); } private makeSetDataSchema(model: string, canBeArray: boolean) { diff --git a/packages/runtime/src/client/result-processor.ts b/packages/runtime/src/client/result-processor.ts index 8c6e9df4..25a2a4df 100644 --- a/packages/runtime/src/client/result-processor.ts +++ b/packages/runtime/src/client/result-processor.ts @@ -135,6 +135,10 @@ export class ResultProcessor { } private fixReversedResult(data: any, model: GetModels, args: any) { + if (!data) { + return; + } + if (Array.isArray(data) && typeof args === 'object' && args && args.take !== undefined && args.take < 0) { data.reverse(); } @@ -150,7 +154,7 @@ export class ResultProcessor { continue; } const fieldDef = getField(this.schema, model, field); - if (!fieldDef?.relation) { + if (!fieldDef || !fieldDef.relation || !fieldDef.array) { continue; } this.fixReversedResult(row[field], fieldDef.type as GetModels, value); diff --git a/packages/runtime/test/client-api/create.test.ts b/packages/runtime/test/client-api/create.test.ts index 4004d7cc..51e0c06d 100644 --- a/packages/runtime/test/client-api/create.test.ts +++ b/packages/runtime/test/client-api/create.test.ts @@ -1,6 +1,5 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '../../src/client'; -import { QueryError } from '../../src/client/errors'; import { schema } from '../test-schema'; import { createClientSpecs } from './client-specs'; @@ -289,21 +288,4 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client create tests', ({ createCli expect(u3.posts).toHaveLength(3); expect(u3.posts.map((p) => p.title)).toEqual(expect.arrayContaining(['Post1', 'Post2', 'Post4'])); }); - - it('rejects empty relation payload', async () => { - await expect( - client.post.create({ - data: { title: 'Post1', author: {} }, - }), - ).rejects.toThrow('At least one action is required'); - - await expect( - client.user.create({ - data: { - email: 'u1@test.com', - posts: {}, - }, - }), - ).rejects.toThrow(QueryError); - }); }); diff --git a/packages/runtime/test/client-api/find.test.ts b/packages/runtime/test/client-api/find.test.ts index 6d70c769..8bc60821 100644 --- a/packages/runtime/test/client-api/find.test.ts +++ b/packages/runtime/test/client-api/find.test.ts @@ -645,6 +645,56 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', email: 'u1@test.com', createdAt: expect.any(Date), }); + + const r2 = await client.user.findUnique({ + where: { id: user.id }, + select: { + id: true, + posts: { + select: { + id: true, + author: { + select: { email: true }, + }, + }, + }, + }, + }); + expect(r2).toMatchObject({ + id: user.id, + posts: expect.arrayContaining([ + expect.objectContaining({ + id: expect.any(String), + author: { + email: 'u1@test.com', + }, + }), + ]), + }); + + const r3 = await client.user.findUnique({ + where: { id: user.id }, + include: { + posts: { + include: { + author: { + select: { email: true }, + }, + }, + }, + }, + }); + expect(r3).toMatchObject({ + id: user.id, + posts: expect.arrayContaining([ + expect.objectContaining({ + id: expect.any(String), + author: { + email: 'u1@test.com', + }, + }), + ]), + }); }); it('allows field omission', async () => { @@ -775,7 +825,7 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', await expect( client.user.findUnique({ where: { id: user1.id }, - select: { _count: true }, + select: { id: true, _count: true }, }), ).resolves.toMatchObject({ _count: { posts: 2 }, @@ -784,12 +834,21 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', await expect( client.user.findUnique({ where: { id: user1.id }, - select: { _count: { select: { posts: true } } }, + select: { id: true, _count: { select: { posts: true } } }, }), ).resolves.toMatchObject({ _count: { posts: 2 }, }); + await expect( + client.user.findUnique({ + where: { id: user1.id }, + select: { id: true, _count: { select: { posts: { where: { published: true } } } } }, + }), + ).resolves.toMatchObject({ + _count: { posts: 1 }, + }); + await expect( client.user.findUnique({ where: { id: user1.id }, diff --git a/packages/runtime/test/client-api/undefined-values.test.ts b/packages/runtime/test/client-api/undefined-values.test.ts new file mode 100644 index 00000000..20b9aa27 --- /dev/null +++ b/packages/runtime/test/client-api/undefined-values.test.ts @@ -0,0 +1,45 @@ +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import type { ClientContract } from '../../src/client'; +import { schema } from '../test-schema'; +import { createClientSpecs } from './client-specs'; +import { createUser } from './utils'; + +const PG_DB_NAME = 'client-api-undefined-values-tests'; + +describe.each(createClientSpecs(PG_DB_NAME, true))( + 'Client undefined values tests for $provider', + ({ createClient }) => { + let client: ClientContract; + + beforeEach(async () => { + client = await createClient(); + }); + + afterEach(async () => { + await client?.$disconnect(); + }); + + it('works with toplevel undefined args', async () => { + await expect(client.user.findMany(undefined)).toResolveTruthy(); + }); + + it('ignored with undefined filter values', async () => { + const user = await createUser(client, 'u1@test.com'); + await expect( + client.user.findFirst({ + where: { + id: undefined, + }, + }), + ).resolves.toMatchObject(user); + + await expect( + client.user.findFirst({ + where: { + email: undefined, + }, + }), + ).resolves.toMatchObject(user); + }); + }, +); diff --git a/packages/runtime/tsconfig.json b/packages/runtime/tsconfig.json index 7b9efb7a..2125902f 100644 --- a/packages/runtime/tsconfig.json +++ b/packages/runtime/tsconfig.json @@ -3,5 +3,5 @@ "compilerOptions": { "outDir": "dist" }, - "include": ["src/**/*.ts"] + "include": ["src/**/*", "test/**/*"] }