Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions packages/runtime/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import { InternalError, QueryError } from '../../errors';
import type { ClientOptions } from '../../options';
import {
aggregate,
buildFieldRef,
buildJoinPairs,
ensureArray,
flattenCompoundUniqueFilters,
Expand Down Expand Up @@ -931,15 +930,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
field: string,
): SelectQueryBuilder<any, any, any> {
const fieldDef = requireField(this.schema, model, field);
if (fieldDef.computed) {
// TODO: computed field from delegate base?
return query.select(() => this.fieldRef(model, field, modelAlias).as(field));
} else if (!fieldDef.originModel) {
// regular field
return query.select(this.eb.ref(`${modelAlias}.${field}`).as(field));
} else {
return this.buildSelectField(query, fieldDef.originModel, fieldDef.originModel, field);
}

// if field is defined on a delegate base, the base model is joined with its
// model name from outer query, so we should use it directly as the alias
const fieldModel = fieldDef.originModel ?? model;
const alias = fieldDef.originModel ?? modelAlias;

return query.select(() => this.fieldRef(fieldModel, field, alias).as(field));
}

buildDelegateJoin(
Expand Down Expand Up @@ -1071,7 +1068,26 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
}

fieldRef(model: string, field: string, modelAlias?: string, inlineComputedField = true) {
return buildFieldRef(this.schema, model, field, this.options, this.eb, modelAlias, inlineComputedField);
const fieldDef = requireField(this.schema, model, field);

if (!fieldDef.computed) {
// regular field
return this.eb.ref(modelAlias ? `${modelAlias}.${field}` : field);
} else {
// computed field
if (!inlineComputedField) {
return this.eb.ref(modelAlias ? `${modelAlias}.${field}` : field);
}
let computer: Function | undefined;
if ('computedFields' in this.options) {
const computedFields = this.options.computedFields as Record<string, any>;
computer = computedFields?.[fieldDef.originModel ?? model]?.[field];
}
if (!computer) {
throw new QueryError(`Computed field "${field}" implementation not provided for model "${model}"`);
}
return computer(this.eb, { modelAlias });
}
}

protected canJoinWithoutNestedSelect(
Expand Down
30 changes: 0 additions & 30 deletions packages/runtime/src/client/query-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import {
TableNode,
type Expression,
type ExpressionBuilder,
type ExpressionWrapper,
type OperationNode,
} from 'kysely';
import { match } from 'ts-pattern';
Expand All @@ -15,7 +14,6 @@ import { extractFields } from '../utils/object-utils';
import type { AGGREGATE_OPERATORS } from './constants';
import type { OrderBy } from './crud-types';
import { InternalError, QueryError } from './errors';
import type { ClientOptions } from './options';

export function hasModel(schema: SchemaDef, model: string) {
return Object.keys(schema.models)
Expand Down Expand Up @@ -180,34 +178,6 @@ export function getIdValues(schema: SchemaDef, model: string, data: any): Record
return idFields.reduce((acc, field) => ({ ...acc, [field]: data[field] }), {});
}

export function buildFieldRef<Schema extends SchemaDef>(
schema: Schema,
model: string,
field: string,
options: ClientOptions<Schema>,
eb: ExpressionBuilder<any, any>,
modelAlias?: string,
inlineComputedField = true,
): ExpressionWrapper<any, any, unknown> {
const fieldDef = requireField(schema, model, field);
if (!fieldDef.computed) {
return eb.ref(modelAlias ? `${modelAlias}.${field}` : field);
} else {
if (!inlineComputedField) {
return eb.ref(modelAlias ? `${modelAlias}.${field}` : field);
}
let computer: Function | undefined;
if ('computedFields' in options) {
const computedFields = options.computedFields as Record<string, any>;
computer = computedFields?.[model]?.[field];
}
if (!computer) {
throw new QueryError(`Computed field "${field}" implementation not provided for model "${model}"`);
}
return computer(eb, { modelAlias });
}
}

export function fieldHasDefaultValue(fieldDef: FieldDef) {
return fieldDef.default !== undefined || fieldDef.updatedAt;
}
Expand Down
3 changes: 2 additions & 1 deletion samples/blog/zenstack/schema.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ enum Role {
}

plugin policy {
// due to pnpm layout we can't directly use package name here
// due to pnpm layout we can't directly use package name here,
// don't do this in your code and use "@zenstackhq/plugin-policy" instead
provider = '../node_modules/@zenstackhq/plugin-policy/dist/index.js'
}

Expand Down
56 changes: 44 additions & 12 deletions tests/e2e/orm/client-api/computed-fields.test.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import { createTestClient } from '@zenstackhq/testtools';
import { afterEach, describe, expect, it } from 'vitest';
import { describe, expect, it } from 'vitest';

describe('Computed fields tests', () => {
let db: any;

afterEach(async () => {
await db?.$disconnect();
});

it('works with non-optional fields', async () => {
db = await createTestClient(
const db = await createTestClient(
`
model User {
id Int @id @default(autoincrement())
Expand Down Expand Up @@ -97,7 +91,7 @@ model User {
});

it('is typed correctly for non-optional fields', async () => {
db = await createTestClient(
await createTestClient(
`
model User {
id Int @id @default(autoincrement())
Expand Down Expand Up @@ -137,7 +131,7 @@ main();
});

it('works with optional fields', async () => {
db = await createTestClient(
const db = await createTestClient(
`
model User {
id Int @id @default(autoincrement())
Expand All @@ -164,7 +158,7 @@ model User {
});

it('is typed correctly for optional fields', async () => {
db = await createTestClient(
await createTestClient(
`
model User {
id Int @id @default(autoincrement())
Expand Down Expand Up @@ -203,7 +197,7 @@ main();
});

it('works with read from a relation', async () => {
db = await createTestClient(
const db = await createTestClient(
`
model User {
id Int @id @default(autoincrement())
Expand Down Expand Up @@ -240,4 +234,42 @@ model Post {
author: expect.objectContaining({ postCount: 1 }),
});
});

it('allows sub models to use computed fields from delegate base', async () => {
const db = await createTestClient(
`
model Content {
id Int @id @default(autoincrement())
title String
isNews Boolean @computed
contentType String
@@delegate(contentType)
}

model Post extends Content {
body String
}
`,
{
computedFields: {
Content: {
isNews: (eb: any) => eb('title', 'like', '%news%'),
},
},
} as any,
);

const posts = await db.post.createManyAndReturn({
data: [
{ id: 1, title: 'latest news', body: 'some news content' },
{ id: 2, title: 'random post', body: 'some other content' },
],
});
expect(posts).toEqual(
expect.arrayContaining([
expect.objectContaining({ id: 1, isNews: true }),
expect.objectContaining({ id: 2, isNews: false }),
]),
);
});
});