diff --git a/packages/language/src/validators/expression-validator.ts b/packages/language/src/validators/expression-validator.ts index cf74db06..f8dc4930 100644 --- a/packages/language/src/validators/expression-validator.ts +++ b/packages/language/src/validators/expression-validator.ts @@ -108,9 +108,12 @@ export default class ExpressionValidator implements AstValidator { supportedShapes = ['Boolean', 'Any']; } + const leftResolvedDecl = expr.left.$resolvedType?.decl; + const rightResolvedDecl = expr.right.$resolvedType?.decl; + if ( - typeof expr.left.$resolvedType?.decl !== 'string' || - !supportedShapes.includes(expr.left.$resolvedType.decl) + leftResolvedDecl && + (typeof leftResolvedDecl !== 'string' || !supportedShapes.includes(leftResolvedDecl)) ) { accept('error', `invalid operand type for "${expr.operator}" operator`, { node: expr.left, @@ -118,8 +121,8 @@ export default class ExpressionValidator implements AstValidator { return; } if ( - typeof expr.right.$resolvedType?.decl !== 'string' || - !supportedShapes.includes(expr.right.$resolvedType.decl) + rightResolvedDecl && + (typeof rightResolvedDecl !== 'string' || !supportedShapes.includes(rightResolvedDecl)) ) { accept('error', `invalid operand type for "${expr.operator}" operator`, { node: expr.right, @@ -128,14 +131,11 @@ export default class ExpressionValidator implements AstValidator { } // DateTime comparison is only allowed between two DateTime values - if (expr.left.$resolvedType.decl === 'DateTime' && expr.right.$resolvedType.decl !== 'DateTime') { + if (leftResolvedDecl === 'DateTime' && rightResolvedDecl && rightResolvedDecl !== 'DateTime') { accept('error', 'incompatible operand types', { node: expr, }); - } else if ( - expr.right.$resolvedType.decl === 'DateTime' && - expr.left.$resolvedType.decl !== 'DateTime' - ) { + } else if (rightResolvedDecl === 'DateTime' && leftResolvedDecl && leftResolvedDecl !== 'DateTime') { accept('error', 'incompatible operand types', { node: expr, }); diff --git a/packages/runtime/src/client/crud/dialects/base-dialect.ts b/packages/runtime/src/client/crud/dialects/base-dialect.ts index 9f314bf9..602a3a4f 100644 --- a/packages/runtime/src/client/crud/dialects/base-dialect.ts +++ b/packages/runtime/src/client/crud/dialects/base-dialect.ts @@ -24,7 +24,6 @@ import { ensureArray, flattenCompoundUniqueFilters, getDelegateDescendantModels, - getIdFields, getManyToManyRelation, getRelationForeignKeyFieldPairs, isEnum, @@ -32,6 +31,7 @@ import { isRelationField, makeDefaultOrderBy, requireField, + requireIdFields, requireModel, } from '../../query-utils'; @@ -366,10 +366,14 @@ export abstract class BaseCrudDialect { const m2m = getManyToManyRelation(this.schema, model, field); if (m2m) { // many-to-many relation - const modelIdField = getIdFields(this.schema, model)[0]!; - const relationIdField = getIdFields(this.schema, relationModel)[0]!; + + const modelIdFields = requireIdFields(this.schema, model); + invariant(modelIdFields.length === 1, 'many-to-many relation must have exactly one id field'); + const relationIdFields = requireIdFields(this.schema, relationModel); + invariant(relationIdFields.length === 1, 'many-to-many relation must have exactly one id field'); + return eb( - sql.ref(`${relationFilterSelectAlias}.${relationIdField}`), + sql.ref(`${relationFilterSelectAlias}.${relationIdFields[0]}`), 'in', eb .selectFrom(m2m.joinTable) @@ -377,7 +381,7 @@ export abstract class BaseCrudDialect { .whereRef( sql.ref(`${m2m.joinTable}.${m2m.parentFkName}`), '=', - sql.ref(`${modelAlias}.${modelIdField}`), + sql.ref(`${modelAlias}.${modelIdFields[0]}`), ), ); } else { @@ -1012,7 +1016,7 @@ export abstract class BaseCrudDialect { otherModelAlias: string, query: SelectQueryBuilder, ) { - const idFields = getIdFields(this.schema, thisModel); + const idFields = requireIdFields(this.schema, thisModel); query = query.leftJoin(otherModelAlias, (qb) => { for (const idField of idFields) { qb = qb.onRef(`${thisModelAlias}.${idField}`, '=', `${otherModelAlias}.${idField}`); diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index a71e987d..07606133 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -14,10 +14,10 @@ import type { FindArgs } from '../../crud-types'; import { buildJoinPairs, getDelegateDescendantModels, - getIdFields, getManyToManyRelation, isRelationField, requireField, + requireIdFields, requireModel, } from '../../query-utils'; import { BaseCrudDialect } from './base-dialect'; @@ -157,8 +157,8 @@ export class PostgresCrudDialect extends BaseCrudDiale const m2m = getManyToManyRelation(this.schema, model, relationField); if (m2m) { // many-to-many relation - const parentIds = getIdFields(this.schema, model); - const relationIds = getIdFields(this.schema, relationModel); + const parentIds = requireIdFields(this.schema, model); + const relationIds = requireIdFields(this.schema, relationModel); invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field'); invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field'); query = query.where((eb) => diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 69de608d..5f4515ed 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -14,10 +14,10 @@ import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; import type { FindArgs } from '../../crud-types'; import { getDelegateDescendantModels, - getIdFields, getManyToManyRelation, getRelationForeignKeyFieldPairs, requireField, + requireIdFields, requireModel, } from '../../query-utils'; import { BaseCrudDialect } from './base-dialect'; @@ -213,8 +213,8 @@ export class SqliteCrudDialect extends BaseCrudDialect const m2m = getManyToManyRelation(this.schema, model, relationField); if (m2m) { // many-to-many relation - const parentIds = getIdFields(this.schema, model); - const relationIds = getIdFields(this.schema, relationModel); + const parentIds = requireIdFields(this.schema, model); + const relationIds = requireIdFields(this.schema, relationModel); invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field'); invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field'); selectModelQuery = selectModelQuery.where((eb) => diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 6d5cae1b..a4ce1c52 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -31,7 +31,6 @@ import { flattenCompoundUniqueFilters, getDiscriminatorField, getField, - getIdFields, getIdValues, getManyToManyRelation, getModel, @@ -40,6 +39,7 @@ import { isRelationField, isScalarField, requireField, + requireIdFields, requireModel, } from '../../query-utils'; import { getCrudDialect } from '../dialects'; @@ -132,7 +132,7 @@ export abstract class BaseOperationHandler { model: GetModels, filter: any, ): Promise { - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); const _filter = flattenCompoundUniqueFilters(this.schema, model, filter); const query = kysely .selectFrom(model) @@ -344,7 +344,7 @@ export abstract class BaseOperationHandler { } const updatedData = this.fillGeneratedAndDefaultValues(modelDef, createFields); - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); const query = kysely .insertInto(model) .$if(Object.keys(updatedData).length === 0, (qb) => qb.defaultValues()) @@ -481,8 +481,8 @@ export abstract class BaseOperationHandler { a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field), ); - const firstIds = getIdFields(this.schema, sortedRecords[0]!.model); - const secondIds = getIdFields(this.schema, sortedRecords[1]!.model); + const firstIds = requireIdFields(this.schema, sortedRecords[0]!.model); + const secondIds = requireIdFields(this.schema, sortedRecords[1]!.model); invariant(firstIds.length === 1, 'many-to-many relation must have exactly one id field'); invariant(secondIds.length === 1, 'many-to-many relation must have exactly one id field'); @@ -771,7 +771,7 @@ export abstract class BaseOperationHandler { const result = await this.executeQuery(kysely, query, 'createMany'); return { count: Number(result.numAffectedRows) } as Result; } else { - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); const result = await query.returning(idFields as any).execute(); return result as Result; } @@ -1039,7 +1039,7 @@ export abstract class BaseOperationHandler { // nothing to update, return the filter so that the caller can identify the entity return combinedWhere; } else { - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); const query = kysely .updateTable(model) .where((eb) => this.dialect.buildFilter(eb, model, model, combinedWhere)) @@ -1104,7 +1104,7 @@ export abstract class BaseOperationHandler { if (!filter || typeof filter !== 'object') { return false; } - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); return idFields.length === Object.keys(filter).length && idFields.every((field) => field in filter); } @@ -1297,7 +1297,7 @@ export abstract class BaseOperationHandler { const result = await this.executeQuery(kysely, query, 'update'); return { count: Number(result.numAffectedRows) } as Result; } else { - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); const result = await query.returning(idFields as any).execute(); return result as Result; } @@ -1336,7 +1336,7 @@ export abstract class BaseOperationHandler { } private buildIdFieldRefs(kysely: ToKysely, model: GetModels) { - const idFields = getIdFields(this.schema, model); + const idFields = requireIdFields(this.schema, model); return idFields.map((f) => kysely.dynamic.ref(`${model}.${f}`)); } @@ -2097,7 +2097,7 @@ export abstract class BaseOperationHandler { // reused the filter if it's a complete id filter (without extra fields) // otherwise, read the entity by the filter private getEntityIds(kysely: ToKysely, model: GetModels, uniqueFilter: any) { - const idFields: string[] = getIdFields(this.schema, model); + const idFields: string[] = requireIdFields(this.schema, model); if ( // all id fields are provided idFields.every((f) => f in uniqueFilter && uniqueFilter[f] !== undefined) && diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index fdce2aaf..b574ed05 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -55,8 +55,8 @@ export function requireField(schema: SchemaDef, modelOrType: string, field: stri } export function getIdFields(schema: SchemaDef, model: GetModels) { - const modelDef = requireModel(schema, model); - return modelDef?.idFields as GetModels[]; + const modelDef = getModel(schema, model); + return modelDef?.idFields; } export function requireIdFields(schema: SchemaDef, model: string) { @@ -231,7 +231,7 @@ export function buildJoinPairs( } export function makeDefaultOrderBy(schema: SchemaDef, model: string) { - const idFields = getIdFields(schema, model); + const idFields = requireIdFields(schema, model); return idFields.map((f) => ({ [f]: 'asc' }) as OrderBy, true, false>); } @@ -318,7 +318,7 @@ export function safeJSONStringify(value: unknown) { } export function extractIdFields(entity: any, schema: SchemaDef, model: string) { - const idFields = getIdFields(schema, model); + const idFields = requireIdFields(schema, model); return extractFields(entity, idFields); } diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index 2c39cbb7..04df8cca 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -25,7 +25,7 @@ import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; import type { ClientOptions } from '../../client/options'; -import { getIdFields, getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils'; +import { getModel, getRelationForeignKeyFieldPairs, requireField, requireIdFields } from '../../client/query-utils'; import type { BinaryExpression, BinaryOperator, @@ -196,16 +196,18 @@ export class ExpressionTransformer { } private normalizeBinaryOperationOperands(expr: BinaryExpression, context: ExpressionTransformerContext) { + // if relation fields are used directly in comparison, it can only be compared with null, + // so we normalize the args with the id field (use the first id field if multiple) let normalizedLeft: Expression = expr.left; if (this.isRelationField(expr.left, context.model)) { invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field'); - const idFields = getIdFields(this.schema, context.model); + const idFields = requireIdFields(this.schema, context.model); normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!); } let normalizedRight: Expression = expr.right; if (this.isRelationField(expr.right, context.model)) { invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field'); - const idFields = getIdFields(this.schema, context.model); + const idFields = requireIdFields(this.schema, context.model); normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!); } return { normalizedLeft, normalizedRight }; diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index 6d018980..029ec58c 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -8,6 +8,7 @@ import { FunctionNode, IdentifierNode, InsertQueryNode, + JoinNode, OperationNodeTransformer, OperatorNode, ParensNode, @@ -31,9 +32,9 @@ import type { ClientContract } from '../../client'; import type { CRUD } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; -import { InternalError } from '../../client/errors'; +import { InternalError, QueryError } from '../../client/errors'; import type { ProceedKyselyQueryFunction } from '../../client/plugin'; -import { getIdFields, requireField, requireModel } from '../../client/query-utils'; +import { requireField, requireIdFields, requireModel } from '../../client/query-utils'; import { ExpressionUtils, type BuiltinType, type Expression, type GetModels, type SchemaDef } from '../../schema'; import { ColumnCollector } from './column-collector'; import { RejectedByPolicyError } from './errors'; @@ -72,7 +73,7 @@ export class PolicyHandler extends OperationNodeTransf } let mutationRequiresTransaction = false; - const mutationModel = this.getMutationModel(node); + const { mutationModel } = this.getMutationModel(node); if (InsertQueryNode.is(node)) { // reject create if unconditional deny @@ -138,18 +139,15 @@ export class PolicyHandler extends OperationNodeTransf // #region overrides protected override transformSelectQuery(node: SelectQueryNode) { - let whereNode = node.where; + let whereNode = this.transformNode(node.where); - node.from?.froms.forEach((from) => { - const extractResult = this.extractTableName(from); - if (extractResult) { - const { model, alias } = extractResult; - const filter = this.buildPolicyFilter(model, alias, 'read'); - whereNode = WhereNode.create( - whereNode?.where ? conjunction(this.dialect, [whereNode.where, filter]) : filter, - ); - } - }); + // get combined policy filter for all froms, and merge into where clause + const policyFilter = this.createPolicyFilterForFrom(node.from); + if (policyFilter) { + whereNode = WhereNode.create( + whereNode?.where ? conjunction(this.dialect, [whereNode.where, policyFilter]) : policyFilter, + ); + } const baseResult = super.transformSelectQuery({ ...node, @@ -162,6 +160,27 @@ export class PolicyHandler extends OperationNodeTransf }; } + protected override transformJoin(node: JoinNode) { + const table = this.extractTableName(node.table); + if (!table) { + // unable to extract table name, can be a subquery, which will be handled when nested transformation happens + return super.transformJoin(node); + } + + // build a nested query with policy filter applied + const filter = this.buildPolicyFilter(table.model, table.alias, 'read'); + const nestedSelect: SelectQueryNode = { + kind: 'SelectQueryNode', + from: FromNode.create([node.table]), + selections: [SelectionNode.createSelectAll()], + where: WhereNode.create(filter), + }; + return { + ...node, + table: AliasNode.create(ParensNode.create(nestedSelect), IdentifierNode.create(table.alias ?? table.model)), + }; + } + protected override transformInsertQuery(node: InsertQueryNode) { // pre-insert check is done in `handle()` @@ -169,8 +188,8 @@ export class PolicyHandler extends OperationNodeTransf if (onConflict?.updates) { // for "on conflict do update", we need to apply policy filter to the "where" clause - const mutationModel = this.getMutationModel(node); - const filter = this.buildPolicyFilter(mutationModel, undefined, 'update'); + const { mutationModel, alias } = this.getMutationModel(node); + const filter = this.buildPolicyFilter(mutationModel, alias, 'update'); if (onConflict.updateWhere) { onConflict = { ...onConflict, @@ -197,7 +216,8 @@ export class PolicyHandler extends OperationNodeTransf return result; } else { // only return ID fields, that's enough for reading back the inserted row - const idFields = getIdFields(this.client.$schema, this.getMutationModel(node)); + const { mutationModel } = this.getMutationModel(node); + const idFields = requireIdFields(this.client.$schema, mutationModel); return { ...result, returning: ReturningNode.create( @@ -209,8 +229,17 @@ export class PolicyHandler extends OperationNodeTransf protected override transformUpdateQuery(node: UpdateQueryNode) { const result = super.transformUpdateQuery(node); - const mutationModel = this.getMutationModel(node); - const filter = this.buildPolicyFilter(mutationModel, undefined, 'update'); + const { mutationModel, alias } = this.getMutationModel(node); + let filter = this.buildPolicyFilter(mutationModel, alias, 'update'); + + if (node.from) { + // for update with from (join), we need to merge join tables' policy filters to the "where" clause + const joinFilter = this.createPolicyFilterForFrom(node.from); + if (joinFilter) { + filter = conjunction(this.dialect, [filter, joinFilter]); + } + } + return { ...result, where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), @@ -219,8 +248,17 @@ export class PolicyHandler extends OperationNodeTransf protected override transformDeleteQuery(node: DeleteQueryNode) { const result = super.transformDeleteQuery(node); - const mutationModel = this.getMutationModel(node); - const filter = this.buildPolicyFilter(mutationModel, undefined, 'delete'); + const { mutationModel, alias } = this.getMutationModel(node); + let filter = this.buildPolicyFilter(mutationModel, alias, 'delete'); + + if (node.using) { + // for delete with using (join), we need to merge join tables' policy filters to the "where" clause + const joinFilter = this.createPolicyFilterForTables(node.using.tables); + if (joinFilter) { + filter = conjunction(this.dialect, [filter, joinFilter]); + } + } + return { ...result, where: WhereNode.create(result.where ? conjunction(this.dialect, [result.where.where, filter]) : filter), @@ -235,19 +273,20 @@ export class PolicyHandler extends OperationNodeTransf if (!node.returning) { return true; } - const idFields = getIdFields(this.client.$schema, this.getMutationModel(node)); + const { mutationModel } = this.getMutationModel(node); + const idFields = requireIdFields(this.client.$schema, mutationModel); const collector = new ColumnCollector(); const selectedColumns = collector.collect(node.returning); return selectedColumns.every((c) => idFields.includes(c)); } private async enforcePreCreatePolicy(node: InsertQueryNode, proceed: ProceedKyselyQueryFunction) { - const model = this.getMutationModel(node); + const { mutationModel } = this.getMutationModel(node); const fields = node.columns?.map((c) => c.column.name) ?? []; - const valueRows = node.values ? this.unwrapCreateValueRows(node.values, model, fields) : [[]]; + const valueRows = node.values ? this.unwrapCreateValueRows(node.values, mutationModel, fields) : [[]]; for (const values of valueRows) { await this.enforcePreCreatePolicyForOne( - model, + mutationModel, fields, values.map((v) => v.node), proceed, @@ -394,17 +433,13 @@ export class PolicyHandler extends OperationNodeTransf } // do a select (with policy) in place of returning - const table = this.getMutationModel(node); - if (!table) { - throw new InternalError(`Unable to get table name for query node: ${node}`); - } - - const idConditions = this.buildIdConditions(table, result.rows); - const policyFilter = this.buildPolicyFilter(table, undefined, 'read'); + const { mutationModel } = this.getMutationModel(node); + const idConditions = this.buildIdConditions(mutationModel, result.rows); + const policyFilter = this.buildPolicyFilter(mutationModel, undefined, 'read'); const select: SelectQueryNode = { kind: 'SelectQueryNode', - from: FromNode.create([TableNode.create(table)]), + from: FromNode.create([TableNode.create(mutationModel)]), where: WhereNode.create(conjunction(this.dialect, [idConditions, policyFilter])), selections: node.returning.selections, }; @@ -413,7 +448,7 @@ export class PolicyHandler extends OperationNodeTransf } private buildIdConditions(table: string, rows: any[]): OperationNode { - const idFields = getIdFields(this.client.$schema, table); + const idFields = requireIdFields(this.client.$schema, table); return disjunction( this.dialect, rows.map((row) => @@ -433,13 +468,23 @@ export class PolicyHandler extends OperationNodeTransf private getMutationModel(node: InsertQueryNode | UpdateQueryNode | DeleteQueryNode) { const r = match(node) - .when(InsertQueryNode.is, (node) => getTableName(node.into) as GetModels) - .when(UpdateQueryNode.is, (node) => getTableName(node.table) as GetModels) + .when(InsertQueryNode.is, (node) => ({ + mutationModel: getTableName(node.into) as GetModels, + alias: undefined, + })) + .when(UpdateQueryNode.is, (node) => { + if (!node.table) { + throw new QueryError('Update query must have a table'); + } + const r = this.extractTableName(node.table); + return r ? { mutationModel: r.model, alias: r.alias } : undefined; + }) .when(DeleteQueryNode.is, (node) => { if (node.from.froms.length !== 1) { - throw new InternalError('Only one from table is supported for delete'); + throw new QueryError('Only one from table is supported for delete'); } - return getTableName(node.from.froms[0]) as GetModels; + const r = this.extractTableName(node.from.froms[0]!); + return r ? { mutationModel: r.model, alias: r.alias } : undefined; }) .exhaustive(); if (!r) { @@ -466,11 +511,11 @@ export class PolicyHandler extends OperationNodeTransf const allows = policies .filter((policy) => policy.kind === 'allow') - .map((policy) => this.transformPolicyCondition(model, alias, operation, policy)); + .map((policy) => this.compilePolicyCondition(model, alias, operation, policy)); const denies = policies .filter((policy) => policy.kind === 'deny') - .map((policy) => this.transformPolicyCondition(model, alias, operation, policy)); + .map((policy) => this.compilePolicyCondition(model, alias, operation, policy)); let combinedPolicy: OperationNode; @@ -494,18 +539,18 @@ export class PolicyHandler extends OperationNodeTransf return combinedPolicy; } - private extractTableName(from: OperationNode): { model: GetModels; alias?: string } | undefined { - if (TableNode.is(from)) { - return { model: from.table.identifier.name as GetModels }; + private extractTableName(node: OperationNode): { model: GetModels; alias?: string } | undefined { + if (TableNode.is(node)) { + return { model: node.table.identifier.name as GetModels }; } - if (AliasNode.is(from)) { - const inner = this.extractTableName(from.node); + if (AliasNode.is(node)) { + const inner = this.extractTableName(node.node); if (!inner) { return undefined; } return { model: inner.model, - alias: IdentifierNode.is(from.alias) ? from.alias.name : undefined, + alias: IdentifierNode.is(node.alias) ? node.alias.name : undefined, }; } else { // this can happen for subqueries, which will be handled when nested @@ -514,7 +559,26 @@ export class PolicyHandler extends OperationNodeTransf } } - private transformPolicyCondition( + private createPolicyFilterForFrom(node: FromNode | undefined) { + if (!node) { + return undefined; + } + return this.createPolicyFilterForTables(node.froms); + } + + private createPolicyFilterForTables(tables: readonly OperationNode[]) { + return tables.reduce((acc, table) => { + const extractResult = this.extractTableName(table); + if (extractResult) { + const { model, alias } = extractResult; + const filter = this.buildPolicyFilter(model, alias, 'read'); + return acc ? conjunction(this.dialect, [acc, filter]) : filter; + } + return acc; + }, undefined); + } + + private compilePolicyCondition( model: GetModels, alias: string | undefined, operation: CRUD, diff --git a/packages/runtime/test/policy/crud/read.test.ts b/packages/runtime/test/policy/crud/read.test.ts index a0e42815..467abea6 100644 --- a/packages/runtime/test/policy/crud/read.test.ts +++ b/packages/runtime/test/policy/crud/read.test.ts @@ -244,5 +244,372 @@ model Bar { await db.foo.update({ where: { id: 1 }, data: { bars: { create: { id: 2, y: 1 } } } }); await expect(db.foo.findMany()).resolves.toHaveLength(1); }); + + it('works with counting relations', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bars Bar[] + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + foo Foo? @relation(fields: [fooId], references: [id]) + fooId Int? + @@allow('create,update', true) + @@allow('read', y > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ + data: { + id: 1, + bars: { + create: [ + { id: 1, y: 0 }, + { id: 2, y: 1 }, + ], + }, + }, + }); + await expect( + db.foo.findFirst({ where: { id: 1 }, select: { _count: { select: { bars: true } } } }), + ).resolves.toMatchObject({ _count: { bars: 1 } }); + }); + }); + + describe('Count tests', () => { + it('works with top-level count', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + name String + @@allow('create', true) + @@allow('read', x > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 0, name: 'Foo1' } }); + await db.$unuseAll().foo.create({ data: { id: 2, x: 0, name: 'Foo2' } }); + await expect(db.foo.count()).resolves.toBe(0); + await expect(db.foo.count({ select: { _all: true, name: true } })).resolves.toEqual({ _all: 0, name: 0 }); + + await db.$unuseAll().foo.update({ where: { id: 1 }, data: { x: 1 } }); + await expect(db.foo.count()).resolves.toBe(1); + await expect(db.foo.count({ select: { _all: true, name: true } })).resolves.toEqual({ _all: 1, name: 1 }); + }); + }); + + describe('Aggregate tests', () => { + it('respects read policies', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('read', x > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 0 } }); + await db.$unuseAll().foo.create({ data: { id: 2, x: 1 } }); + await db.$unuseAll().foo.create({ data: { id: 3, x: 3 } }); + + await expect( + db.foo.aggregate({ + _count: true, + _sum: { x: true }, + _avg: { x: true }, + _min: { x: true }, + _max: { x: true }, + }), + ).resolves.toEqual({ + _count: 2, + _sum: { x: 4 }, + _avg: { x: 2 }, + _min: { x: 1 }, + _max: { x: 3 }, + }); + }); + }); + + describe('GroupBy tests', () => { + it('respects read policies', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + y Int + @@allow('create', true) + @@allow('read', x > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 0, y: 1 } }); + await db.$unuseAll().foo.create({ data: { id: 2, x: 1, y: 1 } }); + await db.$unuseAll().foo.create({ data: { id: 3, x: 3, y: 2 } }); + await db.$unuseAll().foo.create({ data: { id: 4, x: 5, y: 2 } }); + + await expect( + db.foo.groupBy({ + by: ['y'], + _count: { _all: true }, + _sum: { x: true }, + _avg: { x: true }, + _min: { x: true }, + _max: { x: true }, + orderBy: { y: 'asc' }, + }), + ).resolves.toEqual([ + { + y: 1, + _count: { _all: 1 }, + _sum: { x: 1 }, + _avg: { x: 1 }, + _min: { x: 1 }, + _max: { x: 1 }, + }, + { + y: 2, + _count: { _all: 2 }, + _sum: { x: 8 }, + _avg: { x: 4 }, + _min: { x: 3 }, + _max: { x: 5 }, + }, + ]); + }); + }); + + describe('Query builder tests', () => { + it('works with simple selects', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('read', x > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 0 } }); + await db.$unuseAll().foo.create({ data: { id: 2, x: 1 } }); + + await expect(db.$qb.selectFrom('Foo').selectAll().execute()).resolves.toHaveLength(1); + await expect(db.$qb.selectFrom('Foo as f').selectAll().execute()).resolves.toHaveLength(1); + await expect(db.$qb.selectFrom('Foo').selectAll().execute()).resolves.toHaveLength(1); + await expect(db.$qb.selectFrom('Foo').where('id', '=', 1).selectAll().execute()).resolves.toHaveLength(0); + + // nested query + await expect( + db.$qb + .selectFrom((eb: any) => eb.selectFrom('Foo').selectAll().as('f')) + .selectAll() + .execute(), + ).resolves.toHaveLength(1); + await expect( + db.$qb + .selectFrom((eb: any) => eb.selectFrom('Foo').selectAll().as('f')) + .selectAll() + .where('f.id', '=', 1) + .execute(), + ).resolves.toHaveLength(0); + }); + + it('works with joins', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + bars Bar[] + @@allow('create', true) + @@allow('read', x > 0) +} + +model Bar { + id Int @id + y Int + foo Foo? @relation(fields: [fooId], references: [id]) + fooId Int? + @@allow('create', true) + @@allow('read', y > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ + data: { + id: 1, + x: 1, + bars: { + create: [ + { id: 1, y: 0 }, + { id: 2, y: 1 }, + ], + }, + }, + }); + await db.$unuseAll().foo.create({ + data: { + id: 2, + x: 0, + bars: { + create: { id: 3, y: 1 }, + }, + }, + }); + + // direct join + await expect( + db.$qb.selectFrom('Foo').innerJoin('Bar', 'Bar.fooId', 'Foo.id').select(['Foo.id', 'x', 'y']).execute(), + ).resolves.toEqual([expect.objectContaining({ id: 1, x: 1, y: 1 })]); + + // through alias + await expect( + db.$qb + .selectFrom('Foo as f') + .innerJoin( + (eb: any) => eb.selectFrom('Bar').selectAll().as('b'), + (join: any) => join.onRef('b.fooId', '=', 'f.id'), + ) + .select(['f.id', 'x', 'y']) + .execute(), + ).resolves.toEqual([expect.objectContaining({ id: 1, x: 1, y: 1 })]); + }); + + it('works with implicit cross join', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('read', x > 0) +} + +model Bar { + id Int @id + y Int + @@allow('create', true) + @@allow('read', y > 0) +} +`, + { provider: 'postgresql', dbName: 'policy-test-implicit-cross-join' }, + ); + + await db.$unuseAll().foo.createMany({ + data: [ + { id: 1, x: 1 }, + { id: 2, x: 0 }, + ], + }); + await db.$unuseAll().bar.createMany({ + data: [ + { id: 1, y: 1 }, + { id: 2, y: 0 }, + ], + }); + + await expect( + db.$qb.selectFrom(['Foo', 'Bar']).select(['Foo.id as fooId', 'Bar.id as barId', 'x', 'y']).execute(), + ).resolves.toEqual([ + { + fooId: 1, + barId: 1, + x: 1, + y: 1, + }, + ]); + }); + + it('works with update from', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + @@allow('read', y > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 1 } }); + await db.$unuseAll().bar.create({ data: { id: 1, y: 0 } }); + + // update with from, only one row is visible + await expect( + db.$qb + .updateTable('Foo') + .from('Bar as bar') + .whereRef('Foo.id', '=', 'bar.id') + .set((eb: any) => ({ x: eb.ref('bar.y') })) + .executeTakeFirst(), + ).resolves.toMatchObject({ numUpdatedRows: 0n }); + await expect(db.foo.findFirst()).resolves.toMatchObject({ x: 1 }); + + await db.$unuseAll().bar.update({ where: { id: 1 }, data: { y: 2 } }); + await expect( + db.$qb + .updateTable('Foo') + .from('Bar as bar') + .whereRef('Foo.id', '=', 'bar.id') + .set((eb: any) => ({ x: eb.ref('bar.y') })) + .executeTakeFirst(), + ).resolves.toMatchObject({ numUpdatedRows: 1n }); + await expect(db.foo.findFirst()).resolves.toMatchObject({ x: 2 }); + }); + + it('works with delete using', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + @@allow('read', y > 0) +} +`, + { provider: 'postgresql', dbName: 'policy-test-delete-using' }, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 1 } }); + await db.$unuseAll().bar.create({ data: { id: 1, y: 0 } }); + + await expect( + db.$qb.deleteFrom('Foo').using('Bar as bar').whereRef('Foo.id', '=', 'bar.id').executeTakeFirst(), + ).resolves.toMatchObject({ numDeletedRows: 0n }); + await expect(db.foo.findFirst()).resolves.toBeTruthy(); + + await db.$unuseAll().bar.update({ where: { id: 1 }, data: { y: 2 } }); + await expect( + db.$qb.deleteFrom('Foo').using('Bar as bar').whereRef('Foo.id', '=', 'bar.id').executeTakeFirst(), + ).resolves.toMatchObject({ numDeletedRows: 1n }); + await expect(db.foo.findFirst()).resolves.toBeNull(); + }); }); });