diff --git a/packages/language/src/validators/datamodel-validator.ts b/packages/language/src/validators/datamodel-validator.ts index cbcbf896..40d74dbf 100644 --- a/packages/language/src/validators/datamodel-validator.ts +++ b/packages/language/src/validators/datamodel-validator.ts @@ -443,7 +443,7 @@ export default class DataModelValidator implements AstValidator { invariant(model.baseModel.ref, 'baseModel must be resolved'); // check if the base model is a delegate model - if (!isDelegateModel(model.baseModel.ref)) { + if (!isDelegateModel(model.baseModel.ref!)) { accept('error', `Model ${model.baseModel.$refText} cannot be extended because it's not a delegate model`, { node: model, property: 'baseModel', diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index 728082d7..fd1918df 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -437,14 +437,6 @@ export type SelectIncludeOmit; }; -type Distinct> = { - distinct?: OrArray>; -}; - -type Cursor> = { - cursor?: WhereUniqueInput; -}; - export type SelectInput< Schema extends SchemaDef, Model extends GetModels, @@ -621,25 +613,34 @@ type OppositeRelationAndFK< //#region Find args +type FilterArgs> = { + where?: WhereInput; +}; + +type SortAndTakeArgs> = { + skip?: number; + take?: number; + orderBy?: OrArray>; + cursor?: WhereUniqueInput; +}; + export type FindArgs< Schema extends SchemaDef, Model extends GetModels, Collection extends boolean, AllowFilter extends boolean = true, -> = (Collection extends true - ? { - skip?: number; - take?: number; - orderBy?: OrArray>; - } & Distinct & - Cursor - : {}) & - (AllowFilter extends true - ? { - where?: WhereInput; - } - : {}) & - SelectIncludeOmit; +> = + ProviderSupportsDistinct extends true + ? (Collection extends true + ? SortAndTakeArgs & { + distinct?: OrArray>; + } + : {}) & + (AllowFilter extends true ? FilterArgs : {}) & + SelectIncludeOmit + : (Collection extends true ? SortAndTakeArgs : {}) & + (AllowFilter extends true ? FilterArgs : {}) & + SelectIncludeOmit; export type FindManyArgs> = FindArgs; export type FindFirstArgs> = FindArgs; @@ -1259,6 +1260,12 @@ type HasToManyRelations = Schema['provider'] extends 'postgresql' ? true : false; +type ProviderSupportsCaseSensitivity = Schema['provider']['type'] extends 'postgresql' + ? true + : false; + +type ProviderSupportsDistinct = Schema['provider']['type'] extends 'postgresql' + ? true + : false; // #endregion diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index 2c1738cd..34d3ffd3 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -21,6 +21,7 @@ import { aggregate, buildFieldRef, buildJoinPairs, + ensureArray, flattenCompoundUniqueFilters, getDelegateDescendantModels, getIdFields, @@ -58,6 +59,54 @@ export abstract class BaseCrudDialect { return result; } + buildFilterSortTake( + model: GetModels, + args: FindArgs, true>, + query: SelectQueryBuilder, + ) { + let result = query; + + // where + if (args.where) { + result = result.where((eb) => this.buildFilter(eb, model, model, args?.where)); + } + + // skip && take + let negateOrderBy = false; + const skip = args.skip; + let take = args.take; + if (take !== undefined && take < 0) { + negateOrderBy = true; + take = -take; + } + result = this.buildSkipTake(result, skip, take); + + // orderBy + result = this.buildOrderBy( + result, + model, + model, + args.orderBy, + skip !== undefined || take !== undefined, + negateOrderBy, + ); + + // distinct + if ('distinct' in args && (args as any).distinct) { + const distinct = ensureArray((args as any).distinct) as string[]; + if (this.supportsDistinctOn) { + result = result.distinctOn(distinct.map((f) => sql.ref(`${model}.${f}`))); + } else { + throw new QueryError(`"distinct" is not supported by "${this.schema.provider.type}" provider`); + } + } + + if (args.cursor) { + result = this.buildCursorFilter(model, result, args.cursor, args.orderBy, negateOrderBy); + } + return result; + } + buildFilter( eb: ExpressionBuilder, model: string, @@ -117,6 +166,47 @@ export abstract class BaseCrudDialect { return result; } + private buildCursorFilter( + model: string, + query: SelectQueryBuilder, + cursor: FindArgs, true>['cursor'], + orderBy: FindArgs, true>['orderBy'], + negateOrderBy: boolean, + ) { + const _orderBy = orderBy ?? makeDefaultOrderBy(this.schema, model); + + const orderByItems = ensureArray(_orderBy).flatMap((obj) => Object.entries(obj)); + + const eb = expressionBuilder(); + const cursorFilter = this.buildFilter(eb, model, model, cursor); + + let result = query; + const filters: ExpressionWrapper[] = []; + + for (let i = orderByItems.length - 1; i >= 0; i--) { + const andFilters: ExpressionWrapper[] = []; + + for (let j = 0; j <= i; j++) { + const [field, order] = orderByItems[j]!; + const _order = negateOrderBy ? (order === 'asc' ? 'desc' : 'asc') : order; + const op = j === i ? (_order === 'asc' ? '>=' : '<=') : '='; + andFilters.push( + eb( + eb.ref(`${model}.${field}`), + op, + eb.selectFrom(model).select(`${model}.${field}`).where(cursorFilter), + ), + ); + } + + filters.push(eb.and(andFilters)); + } + + result = result.where((eb) => eb.or(filters)); + + return result; + } + private isLogicalCombinator(key: string): key is (typeof LOGICAL_COMBINATORS)[number] { return LOGICAL_COMBINATORS.includes(key as any); } @@ -722,7 +812,7 @@ export abstract class BaseCrudDialect { // aggregations if (['_count', '_avg', '_sum', '_min', '_max'].includes(field)) { invariant(value && typeof value === 'object', `invalid orderBy value for field "${field}"`); - for (const [k, v] of Object.entries(value)) { + for (const [k, v] of Object.entries(value)) { invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`); result = result.orderBy( (eb) => diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index 08d07950..6d10fc12 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -91,31 +91,7 @@ export class PostgresCrudDialect extends BaseCrudDiale ); if (payload && typeof payload === 'object') { - if (payload.where) { - subQuery = subQuery.where((eb) => - this.buildFilter(eb, relationModel, relationModel, payload.where), - ); - } - - // skip & take - const skip = payload.skip; - let take = payload.take; - let negateOrderBy = false; - if (take !== undefined && take < 0) { - negateOrderBy = true; - take = -take; - } - subQuery = this.buildSkipTake(subQuery, skip, take); - - // orderBy - subQuery = this.buildOrderBy( - subQuery, - relationModel, - relationModel, - payload.orderBy, - skip !== undefined || take !== undefined, - negateOrderBy, - ); + subQuery = this.buildFilterSortTake(relationModel, payload, subQuery); } // add join conditions diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 3a2a4868..747337fe 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -85,31 +85,8 @@ export class SqliteCrudDialect extends BaseCrudDialect ); if (payload && typeof payload === 'object') { - if (payload.where) { - subQuery = subQuery.where((eb) => - this.buildFilter(eb, relationModel, relationModel, payload.where), - ); - } - - // skip & take - const skip = payload.skip; - let take = payload.take; - let negateOrderBy = false; - if (take !== undefined && take < 0) { - negateOrderBy = true; - take = -take; - } - subQuery = this.buildSkipTake(subQuery, skip, take); - - // orderBy - subQuery = this.buildOrderBy( - subQuery, - relationModel, - relationModel, - payload.orderBy, - skip !== undefined || take !== undefined, - negateOrderBy, - ); + // take care of where, orderBy, skip, take, cursor, and distinct + subQuery = this.buildFilterSortTake(relationModel, payload, subQuery); } // join conditions diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 9765ea59..19fca142 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -3,7 +3,6 @@ import { invariant, isPlainObject } from '@zenstackhq/common-helpers'; import { DeleteResult, expressionBuilder, - ExpressionWrapper, sql, UpdateResult, type Compilable, @@ -25,7 +24,7 @@ import { enumerate } from '../../../utils/enumerate'; import { extractFields, fieldsToSelectObject } from '../../../utils/object-utils'; import { NUMERIC_FIELD_TYPES } from '../../constants'; import type { CRUD } from '../../contract'; -import type { FindArgs, SelectIncludeOmit, SortOrder, WhereInput } from '../../crud-types'; +import type { FindArgs, SelectIncludeOmit, WhereInput } from '../../crud-types'; import { InternalError, NotFoundError, QueryError } from '../../errors'; import type { ToKysely } from '../../query-builder'; import { @@ -42,10 +41,8 @@ import { isForeignKeyField, isRelationField, isScalarField, - makeDefaultOrderBy, requireField, requireModel, - safeJSONStringify, } from '../../query-utils'; import { getCrudDialect } from '../dialects'; import type { BaseCrudDialect } from '../dialects/base'; @@ -150,48 +147,8 @@ export abstract class BaseOperationHandler { // table let query = this.dialect.buildSelectModel(expressionBuilder(), model); - // where - if (args?.where) { - query = query.where((eb) => this.dialect.buildFilter(eb, model, model, args?.where)); - } - - // skip && take - let negateOrderBy = false; - const skip = args?.skip; - let take = args?.take; - if (take !== undefined && take < 0) { - negateOrderBy = true; - take = -take; - } - query = this.dialect.buildSkipTake(query, skip, take); - - // orderBy - query = this.dialect.buildOrderBy( - query, - model, - model, - args?.orderBy, - skip !== undefined || take !== undefined, - negateOrderBy, - ); - - // distinct - let inMemoryDistinct: string[] | undefined = undefined; - if (args?.distinct) { - const distinct = ensureArray(args.distinct) as string[]; - if (this.dialect.supportsDistinctOn) { - query = query.distinctOn(distinct.map((f) => sql.ref(`${model}.${f}`))); - } else { - // in-memory distinct after fetching all results - inMemoryDistinct = distinct; - - // make sure distinct fields are selected - query = distinct.reduce( - (acc, field) => - acc.select((eb) => this.dialect.fieldRef(model, field, eb).as(`$distinct$${field}`)), - query, - ); - } + if (args) { + query = this.dialect.buildFilterSortTake(model, args, query); } // select @@ -209,10 +166,6 @@ export abstract class BaseOperationHandler { query = this.buildFieldSelection(model, query, args.include, model); } - if (args?.cursor) { - query = this.buildCursorFilter(model, query, args.cursor, args.orderBy, negateOrderBy); - } - query = query.modifyEnd(this.makeContextComment({ model, operation: 'read' })); let result: any[] = []; @@ -229,26 +182,6 @@ export abstract class BaseOperationHandler { throw new QueryError(message, err); } - if (inMemoryDistinct) { - const distinctResult: Record[] = []; - const seen = new Set(); - for (const r of result as any[]) { - const key = safeJSONStringify(inMemoryDistinct.map((f) => r[`$distinct$${f}`]))!; - if (!seen.has(key)) { - distinctResult.push(r); - seen.add(key); - } - } - result = distinctResult; - - // clean up distinct utility fields - for (const r of result) { - Object.keys(r) - .filter((k) => k.startsWith('$distinct$')) - .forEach((k) => delete r[k]); - } - } - return result; } @@ -314,49 +247,6 @@ export abstract class BaseOperationHandler { return query.select((eb) => this.dialect.buildCountJson(model, eb, parentAlias, payload).as('_count')); } - private buildCursorFilter( - model: string, - query: SelectQueryBuilder, - cursor: FindArgs, true>['cursor'], - orderBy: FindArgs, true>['orderBy'], - negateOrderBy: boolean, - ) { - if (!orderBy) { - orderBy = makeDefaultOrderBy(this.schema, model); - } - - const orderByItems = ensureArray(orderBy).flatMap((obj) => Object.entries(obj)); - - const eb = expressionBuilder(); - const cursorFilter = this.dialect.buildFilter(eb, model, model, cursor); - - let result = query; - const filters: ExpressionWrapper[] = []; - - for (let i = orderByItems.length - 1; i >= 0; i--) { - const andFilters: ExpressionWrapper[] = []; - - for (let j = 0; j <= i; j++) { - const [field, order] = orderByItems[j]!; - const _order = negateOrderBy ? (order === 'asc' ? 'desc' : 'asc') : order; - const op = j === i ? (_order === 'asc' ? '>=' : '<=') : '='; - andFilters.push( - eb( - eb.ref(`${model}.${field}`), - op, - eb.selectFrom(model).select(`${model}.${field}`).where(cursorFilter), - ), - ); - } - - filters.push(eb.and(andFilters)); - } - - result = result.where((eb) => eb.or(filters)); - - return result; - } - protected async create( kysely: ToKysely, model: GetModels, diff --git a/packages/runtime/test/client-api/find.test.ts b/packages/runtime/test/client-api/find.test.ts index 70cc400c..adea8d3b 100644 --- a/packages/runtime/test/client-api/find.test.ts +++ b/packages/runtime/test/client-api/find.test.ts @@ -7,7 +7,7 @@ import { createPosts, createUser } from './utils'; const PG_DB_NAME = 'client-api-find-tests'; -describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', ({ createClient }) => { +describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', ({ createClient, provider }) => { let client: ClientContract; beforeEach(async () => { @@ -241,10 +241,11 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', }); it('works with distinct', async () => { - await createUser(client, 'u1@test.com', { + const user1 = await createUser(client, 'u1@test.com', { name: 'Admin1', role: 'ADMIN', profile: { create: { bio: 'Bio1' } }, + posts: { create: [{ title: 'Post1' }, { title: 'Post1' }, { title: 'Post2' }] }, }); await createUser(client, 'u3@test.com', { name: 'User', @@ -260,8 +261,13 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', role: 'USER', }); + if (provider === 'sqlite') { + await expect(client.user.findMany({ distinct: ['role'] } as any)).rejects.toThrow('not supported'); + return; + } + // single field distinct - let r: any = await client.user.findMany({ distinct: ['role'] }); + let r: any = await client.user.findMany({ distinct: ['role'] } as any); expect(r).toHaveLength(2); expect(r).toEqual( expect.arrayContaining([ @@ -270,8 +276,15 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', ]), ); + // distinct in relation + r = await client.user.findUnique({ + where: { id: user1.id }, + include: { posts: { distinct: ['title'] } as any }, + }); + expect(r.posts).toHaveLength(2); + // distinct with include - r = await client.user.findMany({ distinct: ['role'], include: { profile: true } }); + r = await client.user.findMany({ distinct: ['role'], include: { profile: true } } as any); expect(r).toHaveLength(2); expect(r).toEqual( expect.arrayContaining([ @@ -281,14 +294,14 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', ); // distinct with select - r = await client.user.findMany({ distinct: ['role'], select: { email: true } }); + r = await client.user.findMany({ distinct: ['role'], select: { email: true } } as any); expect(r).toHaveLength(2); expect(r).toEqual(expect.arrayContaining([{ email: expect.any(String) }, { email: expect.any(String) }])); // multiple fields distinct r = await client.user.findMany({ distinct: ['role', 'name'], - }); + } as any); expect(r).toHaveLength(3); expect(r).toEqual( expect.arrayContaining([ @@ -658,26 +671,28 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', posts: [expect.objectContaining({ title: 'Post1' })], }); - await expect( - client.user.findUnique({ - where: { id: user.id }, - select: { - posts: { orderBy: { title: 'asc' }, skip: 1, take: 1, distinct: ['title'] }, - }, - }), - ).resolves.toMatchObject({ - posts: [expect.objectContaining({ title: 'Post2' })], - }); - await expect( - client.user.findUnique({ - where: { id: user.id }, - include: { - posts: { orderBy: { title: 'asc' }, skip: 1, take: 1, distinct: ['title'] }, - }, - }), - ).resolves.toMatchObject({ - posts: [expect.objectContaining({ title: 'Post2' })], - }); + if (provider === 'postgresql') { + await expect( + client.user.findUnique({ + where: { id: user.id }, + select: { + posts: { orderBy: { title: 'asc' }, skip: 1, take: 1, distinct: ['title'] } as any, + }, + }), + ).resolves.toMatchObject({ + posts: [expect.objectContaining({ title: 'Post2' })], + }); + await expect( + client.user.findUnique({ + where: { id: user.id }, + include: { + posts: { orderBy: { title: 'asc' }, skip: 1, take: 1, distinct: ['title'] } as any, + }, + }), + ).resolves.toMatchObject({ + posts: [expect.objectContaining({ title: 'Post2' })], + }); + } await expect( client.post.findFirst({ @@ -895,6 +910,19 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', }); expect(u.posts[0]).toMatchObject(post1); + // cursor + u = await client.user.findUniqueOrThrow({ + where: { id: user.id }, + include: { + posts: { + orderBy: { title: 'asc' }, + cursor: { id: post2.id }, + }, + }, + }); + expect(u.posts).toHaveLength(1); + expect(u.posts?.[0]?.id).toBe(post2.id); + // skip and take u = await client.user.findUniqueOrThrow({ where: { id: user.id }, diff --git a/packages/runtime/test/schemas/typing/typecheck.ts b/packages/runtime/test/schemas/typing/typecheck.ts index fe35c9a1..e90b0f82 100644 --- a/packages/runtime/test/schemas/typing/typecheck.ts +++ b/packages/runtime/test/schemas/typing/typecheck.ts @@ -84,7 +84,6 @@ async function find() { email: 'asc', name: 'desc', }, - distinct: ['name'], cursor: { id: 1 }, });