diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index a38cb479..e3d6e3d0 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -1,6 +1,6 @@ import { invariant, isPlainObject } from '@zenstackhq/common-helpers'; import type { Expression, ExpressionBuilder, ExpressionWrapper, SqlBool, ValueNode } from 'kysely'; -import { sql, type SelectQueryBuilder } from 'kysely'; +import { expressionBuilder, sql, type SelectQueryBuilder } from 'kysely'; import { match, P } from 'ts-pattern'; import type { BuiltinType, DataSourceProviderType, FieldDef, GetModels, SchemaDef } from '../../../schema'; import { enumerate } from '../../../utils/enumerate'; @@ -95,11 +95,9 @@ export abstract class BaseCrudDialect { result = this.and(eb, result, this.buildRelationFilter(eb, model, modelAlias, key, fieldDef, payload)); } else { // if the field is from a base model, build a reference from that model - const fieldRef = buildFieldRef( - this.schema, + const fieldRef = this.fieldRef( fieldDef.originModel ?? model, key, - this.options, eb, fieldDef.originModel ?? modelAlias, ); @@ -727,7 +725,8 @@ export abstract class BaseCrudDialect { for (const [k, v] of Object.entries(value)) { invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`); result = result.orderBy( - (eb) => aggregate(eb, sql.ref(`${modelAlias}.${k}`), field as AGGREGATE_OPERATORS), + (eb) => + aggregate(eb, this.fieldRef(model, k, eb, modelAlias), field as AGGREGATE_OPERATORS), sql.raw(this.negateSort(v, negated)), ); } @@ -740,7 +739,7 @@ export abstract class BaseCrudDialect { for (const [k, v] of Object.entries(value)) { invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`); result = result.orderBy( - (eb) => eb.fn.count(sql.ref(k)), + (eb) => eb.fn.count(this.fieldRef(model, k, eb, modelAlias)), sql.raw(this.negateSort(v, negated)), ); } @@ -753,8 +752,9 @@ export abstract class BaseCrudDialect { const fieldDef = requireField(this.schema, model, field); if (!fieldDef.relation) { + const fieldRef = this.fieldRef(model, field, expressionBuilder(), modelAlias); if (value === 'asc' || value === 'desc') { - result = result.orderBy(sql.ref(`${modelAlias}.${field}`), this.negateSort(value, negated)); + result = result.orderBy(fieldRef, this.negateSort(value, negated)); } else if ( value && typeof value === 'object' && @@ -764,7 +764,7 @@ export abstract class BaseCrudDialect { (value.nulls === 'first' || value.nulls === 'last') ) { result = result.orderBy( - sql.ref(`${modelAlias}.${field}`), + fieldRef, sql.raw(`${this.negateSort(value.sort, negated)} nulls ${value.nulls}`), ); } @@ -865,7 +865,7 @@ export abstract class BaseCrudDialect { const fieldDef = requireField(this.schema, model, field); if (fieldDef.computed) { // TODO: computed field from delegate base? - return query.select((eb) => buildFieldRef(this.schema, model, field, this.options, eb).as(field)); + return query.select((eb) => this.fieldRef(model, field, eb, modelAlias).as(field)); } else if (!fieldDef.originModel) { // regular field return query.select(sql.ref(`${modelAlias}.${field}`).as(field)); @@ -993,6 +993,10 @@ export abstract class BaseCrudDialect { return eb.not(this.and(eb, ...args)); } + fieldRef(model: string, field: string, eb: ExpressionBuilder, modelAlias?: string) { + return buildFieldRef(this.schema, model, field, this.options, eb, modelAlias); + } + // #endregion // #region abstract methods diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index 5cb9c5de..65ff3988 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -12,7 +12,6 @@ import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schem import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; import type { FindArgs } from '../../crud-types'; import { - buildFieldRef, buildJoinPairs, getDelegateDescendantModels, getIdFields, @@ -227,10 +226,7 @@ export class PostgresCrudDialect extends BaseCrudDiale ...Object.entries(relationModelDef.fields) .filter(([, value]) => !value.relation) .filter(([name]) => !(typeof payload === 'object' && (payload.omit as any)?.[name] === true)) - .map(([field]) => [ - sql.lit(field), - buildFieldRef(this.schema, relationModel, field, this.options, eb), - ]) + .map(([field]) => [sql.lit(field), this.fieldRef(relationModel, field, eb)]) .flatMap((v) => v), ); } else if (payload.select) { @@ -253,7 +249,7 @@ export class PostgresCrudDialect extends BaseCrudDiale ? // reference the synthesized JSON field eb.ref(`${parentAlias}$${relationField}$${field}.$j`) : // reference a plain field - buildFieldRef(this.schema, relationModel, field, this.options, eb); + this.fieldRef(relationModel, field, eb); return [sql.lit(field), fieldValue]; } }) diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 9277af48..127a13b4 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -13,7 +13,6 @@ import type { BuiltinType, GetModels, SchemaDef } from '../../../schema'; import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; import type { FindArgs } from '../../crud-types'; import { - buildFieldRef, getDelegateDescendantModels, getIdFields, getManyToManyRelation, @@ -171,10 +170,7 @@ export class SqliteCrudDialect extends BaseCrudDialect ...Object.entries(relationModelDef.fields) .filter(([, value]) => !value.relation) .filter(([name]) => !(typeof payload === 'object' && (payload.omit as any)?.[name] === true)) - .map(([field]) => [ - sql.lit(field), - buildFieldRef(this.schema, relationModel, field, this.options, eb), - ]) + .map(([field]) => [sql.lit(field), this.fieldRef(relationModel, field, eb)]) .flatMap((v) => v), ); } else if (payload.select) { @@ -203,10 +199,7 @@ export class SqliteCrudDialect extends BaseCrudDialect ); return [sql.lit(field), subJson]; } else { - return [ - sql.lit(field), - buildFieldRef(this.schema, relationModel, field, this.options, eb) as ArgsType, - ]; + return [sql.lit(field), this.fieldRef(relationModel, field, eb) as ArgsType]; } } }) diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index cc79057f..237b3157 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -29,7 +29,6 @@ import type { FindArgs, SelectIncludeOmit, SortOrder, WhereInput } from '../../c import { InternalError, NotFoundError, QueryError } from '../../errors'; import type { ToKysely } from '../../query-builder'; import { - buildFieldRef, ensureArray, extractIdFields, flattenCompoundUniqueFilters, @@ -187,9 +186,7 @@ export abstract class BaseOperationHandler { // make sure distinct fields are selected query = distinct.reduce( (acc, field) => - acc.select((eb) => - buildFieldRef(this.schema, model, field, this.options, eb).as(`$distinct$${field}`), - ), + acc.select((eb) => this.dialect.fieldRef(model, field, eb).as(`$distinct$${field}`)), query, ); } @@ -1267,7 +1264,7 @@ export abstract class BaseOperationHandler { const key = Object.keys(payload)[0]; const value = this.dialect.transformPrimitive(payload[key!], fieldDef.type as BuiltinType, false); const eb = expressionBuilder(); - const fieldRef = buildFieldRef(this.schema, model, field, this.options, eb); + const fieldRef = this.dialect.fieldRef(model, field, eb); return match(key) .with('set', () => value) @@ -1290,7 +1287,7 @@ export abstract class BaseOperationHandler { const key = Object.keys(payload)[0]; const value = this.dialect.transformPrimitive(payload[key!], fieldDef.type as BuiltinType, true); const eb = expressionBuilder(); - const fieldRef = buildFieldRef(this.schema, model, field, this.options, eb); + const fieldRef = this.dialect.fieldRef(model, field, eb); return match(key) .with('set', () => value) diff --git a/packages/runtime/src/client/crud/operations/group-by.ts b/packages/runtime/src/client/crud/operations/group-by.ts index cdf99b8c..14bb77b5 100644 --- a/packages/runtime/src/client/crud/operations/group-by.ts +++ b/packages/runtime/src/client/crud/operations/group-by.ts @@ -1,7 +1,7 @@ -import { sql } from 'kysely'; +import { expressionBuilder } from 'kysely'; import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; -import { getField } from '../../query-utils'; +import { aggregate, getField } from '../../query-utils'; import { BaseOperationHandler } from './base'; export class GroupByOperationHandler extends BaseOperationHandler { @@ -44,9 +44,11 @@ export class GroupByOperationHandler extends BaseOpera return subQuery.as('$sub'); }); + const fieldRef = (field: string) => this.dialect.fieldRef(this.model, field, expressionBuilder(), '$sub'); + // groupBy const bys = typeof parsedArgs.by === 'string' ? [parsedArgs.by] : (parsedArgs.by as string[]); - query = query.groupBy(bys.map((by) => sql.ref(`$sub.${by}`))); + query = query.groupBy(bys.map((by) => fieldRef(by))); // orderBy if (parsedArgs.orderBy) { @@ -59,7 +61,7 @@ export class GroupByOperationHandler extends BaseOpera // select all by fields for (const by of bys) { - query = query.select(() => sql.ref(`$sub.${by}`).as(by)); + query = query.select(() => fieldRef(by).as(by)); } // aggregations @@ -77,7 +79,7 @@ export class GroupByOperationHandler extends BaseOpera ); } else { query = query.select((eb) => - eb.cast(eb.fn.count(sql.ref(`$sub.${field}`)), 'integer').as(`${key}.${field}`), + eb.cast(eb.fn.count(fieldRef(field)), 'integer').as(`${key}.${field}`), ); } } @@ -92,15 +94,7 @@ export class GroupByOperationHandler extends BaseOpera case '_min': { Object.entries(value).forEach(([field, val]) => { if (val === true) { - query = query.select((eb) => { - const fn = match(key) - .with('_sum', () => eb.fn.sum) - .with('_avg', () => eb.fn.avg) - .with('_max', () => eb.fn.max) - .with('_min', () => eb.fn.min) - .exhaustive(); - return fn(sql.ref(`$sub.${field}`)).as(`${key}.${field}`); - }); + query = query.select((eb) => aggregate(eb, fieldRef(field), key).as(`${key}.${field}`)); } }); break; diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 1cad9bb1..d2c7b649 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -318,11 +318,7 @@ export function getDelegateDescendantModels( return [...collected]; } -export function aggregate( - eb: ExpressionBuilder, - expr: Expression, - op: AGGREGATE_OPERATORS, -): Expression { +export function aggregate(eb: ExpressionBuilder, expr: Expression, op: AGGREGATE_OPERATORS) { return match(op) .with('_count', () => eb.fn.count(expr)) .with('_sum', () => eb.fn.sum(expr)) diff --git a/packages/runtime/test/client-api/computed-fields.test.ts b/packages/runtime/test/client-api/computed-fields.test.ts index 6a28c4d3..353f495f 100644 --- a/packages/runtime/test/client-api/computed-fields.test.ts +++ b/packages/runtime/test/client-api/computed-fields.test.ts @@ -36,6 +36,58 @@ model User { ).resolves.toMatchObject({ upperName: 'ALEX', }); + + await expect( + db.user.findFirst({ + where: { upperName: 'ALEX' }, + }), + ).resolves.toMatchObject({ + upperName: 'ALEX', + }); + + await expect( + db.user.findFirst({ + where: { upperName: 'Alex' }, + }), + ).toResolveNull(); + + await expect( + db.user.findFirst({ + orderBy: { upperName: 'desc' }, + }), + ).resolves.toMatchObject({ + upperName: 'ALEX', + }); + + await expect( + db.user.findFirst({ + orderBy: { upperName: 'desc' }, + take: -1, + }), + ).resolves.toMatchObject({ + upperName: 'ALEX', + }); + + await expect( + db.user.aggregate({ + _count: { upperName: true }, + }), + ).resolves.toMatchObject({ + _count: { upperName: 1 }, + }); + + await expect( + db.user.groupBy({ + by: ['upperName'], + _count: { upperName: true }, + _max: { upperName: true }, + }), + ).resolves.toEqual([ + expect.objectContaining({ + _count: { upperName: 1 }, + _max: { upperName: 'ALEX' }, + }), + ]); }); it('is typed correctly for non-optional fields', async () => {