Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions packages/runtime/src/client/crud/dialects/base.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { invariant, isPlainObject } from '@zenstackhq/common-helpers';
import type { Expression, ExpressionBuilder, ExpressionWrapper, SqlBool, ValueNode } from 'kysely';
import { sql, type SelectQueryBuilder } from 'kysely';
import { expressionBuilder, sql, type SelectQueryBuilder } from 'kysely';
import { match, P } from 'ts-pattern';
import type { BuiltinType, DataSourceProviderType, FieldDef, GetModels, SchemaDef } from '../../../schema';
import { enumerate } from '../../../utils/enumerate';
Expand Down Expand Up @@ -95,11 +95,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
result = this.and(eb, result, this.buildRelationFilter(eb, model, modelAlias, key, fieldDef, payload));
} else {
// if the field is from a base model, build a reference from that model
const fieldRef = buildFieldRef(
this.schema,
const fieldRef = this.fieldRef(
fieldDef.originModel ?? model,
key,
this.options,
eb,
fieldDef.originModel ?? modelAlias,
);
Expand Down Expand Up @@ -727,7 +725,8 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
for (const [k, v] of Object.entries<string>(value)) {
invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`);
result = result.orderBy(
(eb) => aggregate(eb, sql.ref(`${modelAlias}.${k}`), field as AGGREGATE_OPERATORS),
(eb) =>
aggregate(eb, this.fieldRef(model, k, eb, modelAlias), field as AGGREGATE_OPERATORS),
sql.raw(this.negateSort(v, negated)),
);
}
Expand All @@ -740,7 +739,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
for (const [k, v] of Object.entries<string>(value)) {
invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`);
result = result.orderBy(
(eb) => eb.fn.count(sql.ref(k)),
(eb) => eb.fn.count(this.fieldRef(model, k, eb, modelAlias)),
sql.raw(this.negateSort(v, negated)),
);
}
Expand All @@ -753,8 +752,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
const fieldDef = requireField(this.schema, model, field);

if (!fieldDef.relation) {
const fieldRef = this.fieldRef(model, field, expressionBuilder(), modelAlias);
if (value === 'asc' || value === 'desc') {
result = result.orderBy(sql.ref(`${modelAlias}.${field}`), this.negateSort(value, negated));
result = result.orderBy(fieldRef, this.negateSort(value, negated));
} else if (
value &&
typeof value === 'object' &&
Expand All @@ -764,7 +764,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
(value.nulls === 'first' || value.nulls === 'last')
) {
result = result.orderBy(
sql.ref(`${modelAlias}.${field}`),
fieldRef,
sql.raw(`${this.negateSort(value.sort, negated)} nulls ${value.nulls}`),
);
}
Expand Down Expand Up @@ -865,7 +865,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
const fieldDef = requireField(this.schema, model, field);
if (fieldDef.computed) {
// TODO: computed field from delegate base?
return query.select((eb) => buildFieldRef(this.schema, model, field, this.options, eb).as(field));
return query.select((eb) => this.fieldRef(model, field, eb, modelAlias).as(field));
} else if (!fieldDef.originModel) {
// regular field
return query.select(sql.ref(`${modelAlias}.${field}`).as(field));
Expand Down Expand Up @@ -993,6 +993,10 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
return eb.not(this.and(eb, ...args));
}

fieldRef(model: string, field: string, eb: ExpressionBuilder<any, any>, modelAlias?: string) {
return buildFieldRef(this.schema, model, field, this.options, eb, modelAlias);
}

// #endregion

// #region abstract methods
Expand Down
8 changes: 2 additions & 6 deletions packages/runtime/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schem
import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants';
import type { FindArgs } from '../../crud-types';
import {
buildFieldRef,
buildJoinPairs,
getDelegateDescendantModels,
getIdFields,
Expand Down Expand Up @@ -227,10 +226,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
...Object.entries(relationModelDef.fields)
.filter(([, value]) => !value.relation)
.filter(([name]) => !(typeof payload === 'object' && (payload.omit as any)?.[name] === true))
.map(([field]) => [
sql.lit(field),
buildFieldRef(this.schema, relationModel, field, this.options, eb),
])
.map(([field]) => [sql.lit(field), this.fieldRef(relationModel, field, eb)])
.flatMap((v) => v),
);
} else if (payload.select) {
Expand All @@ -253,7 +249,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
? // reference the synthesized JSON field
eb.ref(`${parentAlias}$${relationField}$${field}.$j`)
: // reference a plain field
buildFieldRef(this.schema, relationModel, field, this.options, eb);
this.fieldRef(relationModel, field, eb);
return [sql.lit(field), fieldValue];
}
})
Expand Down
11 changes: 2 additions & 9 deletions packages/runtime/src/client/crud/dialects/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import type { BuiltinType, GetModels, SchemaDef } from '../../../schema';
import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants';
import type { FindArgs } from '../../crud-types';
import {
buildFieldRef,
getDelegateDescendantModels,
getIdFields,
getManyToManyRelation,
Expand Down Expand Up @@ -171,10 +170,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
...Object.entries(relationModelDef.fields)
.filter(([, value]) => !value.relation)
.filter(([name]) => !(typeof payload === 'object' && (payload.omit as any)?.[name] === true))
.map(([field]) => [
sql.lit(field),
buildFieldRef(this.schema, relationModel, field, this.options, eb),
])
.map(([field]) => [sql.lit(field), this.fieldRef(relationModel, field, eb)])
.flatMap((v) => v),
);
} else if (payload.select) {
Expand Down Expand Up @@ -203,10 +199,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
);
return [sql.lit(field), subJson];
} else {
return [
sql.lit(field),
buildFieldRef(this.schema, relationModel, field, this.options, eb) as ArgsType,
];
return [sql.lit(field), this.fieldRef(relationModel, field, eb) as ArgsType];
}
}
})
Expand Down
9 changes: 3 additions & 6 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import type { FindArgs, SelectIncludeOmit, SortOrder, WhereInput } from '../../c
import { InternalError, NotFoundError, QueryError } from '../../errors';
import type { ToKysely } from '../../query-builder';
import {
buildFieldRef,
ensureArray,
extractIdFields,
flattenCompoundUniqueFilters,
Expand Down Expand Up @@ -187,9 +186,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
// make sure distinct fields are selected
query = distinct.reduce(
(acc, field) =>
acc.select((eb) =>
buildFieldRef(this.schema, model, field, this.options, eb).as(`$distinct$${field}`),
),
acc.select((eb) => this.dialect.fieldRef(model, field, eb).as(`$distinct$${field}`)),
query,
);
}
Expand Down Expand Up @@ -1267,7 +1264,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
const key = Object.keys(payload)[0];
const value = this.dialect.transformPrimitive(payload[key!], fieldDef.type as BuiltinType, false);
const eb = expressionBuilder<any, any>();
const fieldRef = buildFieldRef(this.schema, model, field, this.options, eb);
const fieldRef = this.dialect.fieldRef(model, field, eb);

return match(key)
.with('set', () => value)
Expand All @@ -1290,7 +1287,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
const key = Object.keys(payload)[0];
const value = this.dialect.transformPrimitive(payload[key!], fieldDef.type as BuiltinType, true);
const eb = expressionBuilder<any, any>();
const fieldRef = buildFieldRef(this.schema, model, field, this.options, eb);
const fieldRef = this.dialect.fieldRef(model, field, eb);

return match(key)
.with('set', () => value)
Expand Down
22 changes: 8 additions & 14 deletions packages/runtime/src/client/crud/operations/group-by.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { sql } from 'kysely';
import { expressionBuilder } from 'kysely';
import { match } from 'ts-pattern';
import type { SchemaDef } from '../../../schema';
import { getField } from '../../query-utils';
import { aggregate, getField } from '../../query-utils';
import { BaseOperationHandler } from './base';

export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOperationHandler<Schema> {
Expand Down Expand Up @@ -44,9 +44,11 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
return subQuery.as('$sub');
});

const fieldRef = (field: string) => this.dialect.fieldRef(this.model, field, expressionBuilder(), '$sub');

// groupBy
const bys = typeof parsedArgs.by === 'string' ? [parsedArgs.by] : (parsedArgs.by as string[]);
query = query.groupBy(bys.map((by) => sql.ref(`$sub.${by}`)));
query = query.groupBy(bys.map((by) => fieldRef(by)));

// orderBy
if (parsedArgs.orderBy) {
Expand All @@ -59,7 +61,7 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera

// select all by fields
for (const by of bys) {
query = query.select(() => sql.ref(`$sub.${by}`).as(by));
query = query.select(() => fieldRef(by).as(by));
}

// aggregations
Expand All @@ -77,7 +79,7 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
);
} else {
query = query.select((eb) =>
eb.cast(eb.fn.count(sql.ref(`$sub.${field}`)), 'integer').as(`${key}.${field}`),
eb.cast(eb.fn.count(fieldRef(field)), 'integer').as(`${key}.${field}`),
);
}
}
Expand All @@ -92,15 +94,7 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
case '_min': {
Object.entries(value).forEach(([field, val]) => {
if (val === true) {
query = query.select((eb) => {
const fn = match(key)
.with('_sum', () => eb.fn.sum)
.with('_avg', () => eb.fn.avg)
.with('_max', () => eb.fn.max)
.with('_min', () => eb.fn.min)
.exhaustive();
return fn(sql.ref(`$sub.${field}`)).as(`${key}.${field}`);
});
query = query.select((eb) => aggregate(eb, fieldRef(field), key).as(`${key}.${field}`));
}
});
break;
Expand Down
6 changes: 1 addition & 5 deletions packages/runtime/src/client/query-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,7 @@ export function getDelegateDescendantModels(
return [...collected];
}

export function aggregate(
eb: ExpressionBuilder<any, any>,
expr: Expression<any>,
op: AGGREGATE_OPERATORS,
): Expression<any> {
export function aggregate(eb: ExpressionBuilder<any, any>, expr: Expression<any>, op: AGGREGATE_OPERATORS) {
return match(op)
.with('_count', () => eb.fn.count(expr))
.with('_sum', () => eb.fn.sum(expr))
Expand Down
52 changes: 52 additions & 0 deletions packages/runtime/test/client-api/computed-fields.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,58 @@ model User {
).resolves.toMatchObject({
upperName: 'ALEX',
});

await expect(
db.user.findFirst({
where: { upperName: 'ALEX' },
}),
).resolves.toMatchObject({
upperName: 'ALEX',
});

await expect(
db.user.findFirst({
where: { upperName: 'Alex' },
}),
).toResolveNull();

await expect(
db.user.findFirst({
orderBy: { upperName: 'desc' },
}),
).resolves.toMatchObject({
upperName: 'ALEX',
});

await expect(
db.user.findFirst({
orderBy: { upperName: 'desc' },
take: -1,
}),
).resolves.toMatchObject({
upperName: 'ALEX',
});

await expect(
db.user.aggregate({
_count: { upperName: true },
}),
).resolves.toMatchObject({
_count: { upperName: 1 },
});

await expect(
db.user.groupBy({
by: ['upperName'],
_count: { upperName: true },
_max: { upperName: true },
}),
).resolves.toEqual([
expect.objectContaining({
_count: { upperName: 1 },
_max: { upperName: 'ALEX' },
}),
]);
});

it('is typed correctly for non-optional fields', async () => {
Expand Down