Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions packages/runtime/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
if ('distinct' in args && (args as any).distinct) {
const distinct = ensureArray((args as any).distinct) as string[];
if (this.supportsDistinctOn) {
result = result.distinctOn(distinct.map((f) => sql.ref(`${modelAlias}.${f}`)));
result = result.distinctOn(distinct.map((f) => this.eb.ref(`${modelAlias}.${f}`)));
} else {
throw new QueryError(`"distinct" is not supported by "${this.schema.provider.type}" provider`);
}
Expand Down Expand Up @@ -248,7 +248,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {

if (ownedByModel && !fieldDef.originModel) {
// can be short-circuited to FK null check
return this.and(...keyPairs.map(({ fk }) => this.eb(sql.ref(`${modelAlias}.${fk}`), 'is', null)));
return this.and(...keyPairs.map(({ fk }) => this.eb(this.eb.ref(`${modelAlias}.${fk}`), 'is', null)));
} else {
// translate it to `{ is: null }` filter
return this.buildToOneRelationFilter(model, modelAlias, field, fieldDef, { is: null });
Expand All @@ -268,7 +268,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {

const joinSelect = this.eb
.selectFrom(`${fieldDef.type} as ${joinAlias}`)
.where(() => this.and(...joinPairs.map(([left, right]) => this.eb(sql.ref(left), '=', sql.ref(right)))))
.where(() =>
this.and(...joinPairs.map(([left, right]) => this.eb(this.eb.ref(left), '=', this.eb.ref(right)))),
)
.select(() => this.eb.fn.count(this.eb.lit(1)).as(filterResultField));

const conditions: Expression<SqlBool>[] = [];
Expand Down Expand Up @@ -331,7 +333,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
) {
// null check needs to be converted to fk "is null" checks
if (payload === null) {
return this.eb(sql.ref(`${modelAlias}.${field}`), 'is', null);
return this.eb(this.eb.ref(`${modelAlias}.${field}`), 'is', null);
}

const relationModel = fieldDef.type;
Expand All @@ -351,15 +353,15 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
invariant(relationIdFields.length === 1, 'many-to-many relation must have exactly one id field');

return eb(
sql.ref(`${relationFilterSelectAlias}.${relationIdFields[0]}`),
this.eb.ref(`${relationFilterSelectAlias}.${relationIdFields[0]}`),
'in',
eb
.selectFrom(m2m.joinTable)
.select(`${m2m.joinTable}.${m2m.otherFkName}`)
.whereRef(
sql.ref(`${m2m.joinTable}.${m2m.parentFkName}`),
this.eb.ref(`${m2m.joinTable}.${m2m.parentFkName}`),
'=',
sql.ref(`${modelAlias}.${modelIdFields[0]}`),
this.eb.ref(`${modelAlias}.${modelIdFields[0]}`),
),
);
} else {
Expand All @@ -370,12 +372,20 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
if (relationKeyPairs.ownedByModel) {
result = this.and(
result,
eb(sql.ref(`${modelAlias}.${fk}`), '=', sql.ref(`${relationFilterSelectAlias}.${pk}`)),
eb(
this.eb.ref(`${modelAlias}.${fk}`),
'=',
this.eb.ref(`${relationFilterSelectAlias}.${pk}`),
),
);
} else {
result = this.and(
result,
eb(sql.ref(`${modelAlias}.${pk}`), '=', sql.ref(`${relationFilterSelectAlias}.${fk}`)),
eb(
this.eb.ref(`${modelAlias}.${pk}`),
'=',
this.eb.ref(`${relationFilterSelectAlias}.${fk}`),
),
);
}
}
Expand Down Expand Up @@ -833,7 +843,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, subQueryAlias);
subQuery = subQuery.where(() =>
this.and(
...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right))),
...joinPairs.map(([left, right]) =>
eb(this.eb.ref(left), '=', this.eb.ref(right)),
),
),
);
subQuery = subQuery.select(() => eb.fn.count(eb.lit(1)).as('_count'));
Expand All @@ -845,7 +857,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
result = result.leftJoin(relationModel, (join) => {
const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, relationModel);
return join.on((eb) =>
this.and(...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right)))),
this.and(
...joinPairs.map(([left, right]) => eb(this.eb.ref(left), '=', this.eb.ref(right))),
),
);
});
result = this.buildOrderBy(result, fieldDef.type, relationModel, value, false, negated);
Expand Down Expand Up @@ -934,7 +948,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
return query.select(() => this.fieldRef(model, field, modelAlias).as(field));
} else if (!fieldDef.originModel) {
// regular field
return query.select(sql.ref(`${modelAlias}.${field}`).as(field));
return query.select(this.eb.ref(`${modelAlias}.${field}`).as(field));
} else {
return this.buildSelectField(query, fieldDef.originModel, fieldDef.originModel, field);
}
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
} else {
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, relationField, relationModelAlias);
query = query.where((eb) =>
this.and(...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right)))),
this.and(...joinPairs.map(([left, right]) => eb(this.eb.ref(left), '=', this.eb.ref(right)))),
);
}
return query;
Expand Down
7 changes: 4 additions & 3 deletions packages/runtime/src/client/crud/operations/aggregate.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { sql } from 'kysely';
import { match } from 'ts-pattern';
import type { SchemaDef } from '../../../schema';
import { getField } from '../../query-utils';
Expand Down Expand Up @@ -80,7 +79,9 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
);
} else {
query = query.select((eb) =>
eb.cast(eb.fn.count(sql.ref(`$sub.${field}`)), 'integer').as(`${key}.${field}`),
eb
.cast(eb.fn.count(eb.ref(`$sub.${field}` as any)), 'integer')
.as(`${key}.${field}`),
);
}
}
Expand All @@ -102,7 +103,7 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
.with('_max', () => eb.fn.max)
.with('_min', () => eb.fn.min)
.exhaustive();
return fn(sql.ref(`$sub.${field}`)).as(`${key}.${field}`);
return fn(eb.ref(`$sub.${field}` as any)).as(`${key}.${field}`);
});
}
});
Expand Down
4 changes: 3 additions & 1 deletion packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1540,7 +1540,9 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
if (!relationFieldDef.array) {
const query = kysely
.updateTable(model)
.where((eb) => eb.and(keyPairs.map(({ fk, pk }) => eb(sql.ref(fk), '=', fromRelation.ids[pk]))))
.where((eb) =>
eb.and(keyPairs.map(({ fk, pk }) => eb(eb.ref(fk as any), '=', fromRelation.ids[pk]))),
)
.set(keyPairs.reduce((acc, { fk }) => ({ ...acc, [fk]: null }), {} as any))
.modifyEnd(
this.makeContextComment({
Expand Down
3 changes: 1 addition & 2 deletions packages/runtime/src/client/crud/operations/count.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { sql } from 'kysely';
import type { SchemaDef } from '../../../schema';
import { BaseOperationHandler } from './base';

Expand Down Expand Up @@ -40,7 +39,7 @@ export class CountOperationHandler<Schema extends SchemaDef> extends BaseOperati
Object.keys(parsedArgs.select!).map((key) =>
key === '_all'
? eb.cast(eb.fn.countAll(), 'integer').as('_all')
: eb.cast(eb.fn.count(sql.ref(`${subQueryName}.${key}`)), 'integer').as(key),
: eb.cast(eb.fn.count(eb.ref(`${subQueryName}.${key}` as any)), 'integer').as(key),
),
);
const result = await this.executeQuery(this.kysely, query, 'count');
Expand Down
16 changes: 10 additions & 6 deletions packages/runtime/src/client/executor/name-mapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ export class QueryNameMapper extends OperationNodeTransformer {
this.modelToTableMap.set(modelName, mappedName);
}

for (const [fieldName, fieldDef] of Object.entries(modelDef.fields)) {
for (const fieldDef of this.getModelFields(modelDef)) {
const mappedName = this.getMappedName(fieldDef);
if (mappedName) {
this.fieldToColumnMap.set(`${modelName}.${fieldName}`, mappedName);
this.fieldToColumnMap.set(`${modelName}.${fieldDef.name}`, mappedName);
}
}
}
Expand Down Expand Up @@ -72,11 +72,14 @@ export class QueryNameMapper extends OperationNodeTransformer {
on: this.transformNode(join.on),
}))
: undefined;
const selections = this.processSelectQuerySelections(node);
const baseResult = super.transformSelectQuery(node);

return {
...super.transformSelectQuery(node),
...baseResult,
from: FromNode.create(processedFroms.map((f) => f.node)),
joins,
selections: this.processSelectQuerySelections(node),
selections,
};
});
}
Expand Down Expand Up @@ -132,7 +135,8 @@ export class QueryNameMapper extends OperationNodeTransformer {
mappedTableName ? TableNode.create(mappedTableName) : undefined,
);
} else {
return super.transformReference(node);
// no name mapping needed
return node;
}
}

Expand Down Expand Up @@ -270,7 +274,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
if (!modelDef) {
continue;
}
if (modelDef.fields[name]) {
if (this.getModelFields(modelDef).some((f) => f.name === name)) {
return scope;
}
}
Expand Down
3 changes: 1 addition & 2 deletions tests/e2e/orm/client-api/computed-fields.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { sql } from '@zenstackhq/runtime/helpers';
import { createTestClient } from '@zenstackhq/testtools';
import { afterEach, describe, expect, it } from 'vitest';

Expand Down Expand Up @@ -226,7 +225,7 @@ model Post {
postCount: (eb: any, context: { modelAlias: string }) =>
eb
.selectFrom('Post')
.whereRef('Post.authorId', '=', sql.ref(`${context.modelAlias}.id`))
.whereRef('Post.authorId', '=', eb.ref(`${context.modelAlias}.id`))
.select(() => eb.fn.countAll().as('count')),
},
},
Expand Down
3 changes: 1 addition & 2 deletions tests/regression/test/v2-migrated/issue-1235.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { createPolicyTestClient, testLogger } from '@zenstackhq/testtools';
import { createPolicyTestClient } from '@zenstackhq/testtools';
import { describe, expect, it } from 'vitest';

describe('Regression for issue 1235', () => {
Expand All @@ -11,7 +11,6 @@ model Post {
@@allow('all', true)
}
`,
{ log: testLogger },
);

const post = await db.post.create({ data: {} });
Expand Down
35 changes: 35 additions & 0 deletions tests/regression/test/v2-migrated/issue-1506.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { createPolicyTestClient } from '@zenstackhq/testtools';
import { it } from 'vitest';

it('verifies issue 1506', async () => {
await createPolicyTestClient(
`
model A {
id Int @id @default(autoincrement())
value Int
b B @relation(fields: [bId], references: [id])
bId Int @unique
@@allow('read', true)
}
model B {
id Int @id @default(autoincrement())
value Int
a A?
c C @relation(fields: [cId], references: [id])
cId Int @unique
@@allow('read', value > c.value)
}
model C {
id Int @id @default(autoincrement())
value Int
b B?
@@allow('read', true)
}
`,
);
});
25 changes: 25 additions & 0 deletions tests/regression/test/v2-migrated/issue-1507.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import { createPolicyTestClient } from '@zenstackhq/testtools';
import { expect, it } from 'vitest';

it('verifies issue 1507', async () => {
const db = await createPolicyTestClient(
`
model User {
id Int @id @default(autoincrement())
age Int
}

model Profile {
id Int @id @default(autoincrement())
age Int

@@allow('read', auth().age == age)
}
`,
);

await db.$unuseAll().profile.create({ data: { age: 18 } });
await db.$unuseAll().profile.create({ data: { age: 20 } });
await expect(db.$setAuth({ id: 1, age: 18 }).profile.findMany()).resolves.toHaveLength(1);
await expect(db.$setAuth({ id: 1, age: 18 }).profile.count()).resolves.toBe(1);
});
30 changes: 30 additions & 0 deletions tests/regression/test/v2-migrated/issue-1518.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { createTestClient } from '@zenstackhq/testtools';
import { it } from 'vitest';

it('verifies issue 1518', async () => {
const db = await createTestClient(
`
model Activity {
id String @id @default(uuid())
title String
type String
@@delegate(type)
@@allow('all', true)
}

model TaskActivity extends Activity {
description String
@@map("task_activity")
@@allow('all', true)
}
`,
);

await db.taskActivity.create({
data: {
id: '00000000-0000-0000-0000-111111111111',
title: 'Test Activity',
description: 'Description of task',
},
});
});
Loading