Skip to content

Commit 75723c7

Browse files
authored
feat(policy): support read filtering for update with "from" and delete with "using" (#253)
* feat(policy): support read filtering for update with "from" and delete with "using" * addressing pr comments * more robust alias handling * addressing pr comments
1 parent 50e92e0 commit 75723c7

File tree

9 files changed

+523
-86
lines changed

9 files changed

+523
-86
lines changed

packages/language/src/validators/expression-validator.ts

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,21 @@ export default class ExpressionValidator implements AstValidator<Expression> {
108108
supportedShapes = ['Boolean', 'Any'];
109109
}
110110

111+
const leftResolvedDecl = expr.left.$resolvedType?.decl;
112+
const rightResolvedDecl = expr.right.$resolvedType?.decl;
113+
111114
if (
112-
typeof expr.left.$resolvedType?.decl !== 'string' ||
113-
!supportedShapes.includes(expr.left.$resolvedType.decl)
115+
leftResolvedDecl &&
116+
(typeof leftResolvedDecl !== 'string' || !supportedShapes.includes(leftResolvedDecl))
114117
) {
115118
accept('error', `invalid operand type for "${expr.operator}" operator`, {
116119
node: expr.left,
117120
});
118121
return;
119122
}
120123
if (
121-
typeof expr.right.$resolvedType?.decl !== 'string' ||
122-
!supportedShapes.includes(expr.right.$resolvedType.decl)
124+
rightResolvedDecl &&
125+
(typeof rightResolvedDecl !== 'string' || !supportedShapes.includes(rightResolvedDecl))
123126
) {
124127
accept('error', `invalid operand type for "${expr.operator}" operator`, {
125128
node: expr.right,
@@ -128,14 +131,11 @@ export default class ExpressionValidator implements AstValidator<Expression> {
128131
}
129132

130133
// DateTime comparison is only allowed between two DateTime values
131-
if (expr.left.$resolvedType.decl === 'DateTime' && expr.right.$resolvedType.decl !== 'DateTime') {
134+
if (leftResolvedDecl === 'DateTime' && rightResolvedDecl && rightResolvedDecl !== 'DateTime') {
132135
accept('error', 'incompatible operand types', {
133136
node: expr,
134137
});
135-
} else if (
136-
expr.right.$resolvedType.decl === 'DateTime' &&
137-
expr.left.$resolvedType.decl !== 'DateTime'
138-
) {
138+
} else if (rightResolvedDecl === 'DateTime' && leftResolvedDecl && leftResolvedDecl !== 'DateTime') {
139139
accept('error', 'incompatible operand types', {
140140
node: expr,
141141
});

packages/runtime/src/client/crud/dialects/base-dialect.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ import {
2424
ensureArray,
2525
flattenCompoundUniqueFilters,
2626
getDelegateDescendantModels,
27-
getIdFields,
2827
getManyToManyRelation,
2928
getRelationForeignKeyFieldPairs,
3029
isEnum,
3130
isInheritedField,
3231
isRelationField,
3332
makeDefaultOrderBy,
3433
requireField,
34+
requireIdFields,
3535
requireModel,
3636
} from '../../query-utils';
3737

@@ -366,18 +366,22 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
366366
const m2m = getManyToManyRelation(this.schema, model, field);
367367
if (m2m) {
368368
// many-to-many relation
369-
const modelIdField = getIdFields(this.schema, model)[0]!;
370-
const relationIdField = getIdFields(this.schema, relationModel)[0]!;
369+
370+
const modelIdFields = requireIdFields(this.schema, model);
371+
invariant(modelIdFields.length === 1, 'many-to-many relation must have exactly one id field');
372+
const relationIdFields = requireIdFields(this.schema, relationModel);
373+
invariant(relationIdFields.length === 1, 'many-to-many relation must have exactly one id field');
374+
371375
return eb(
372-
sql.ref(`${relationFilterSelectAlias}.${relationIdField}`),
376+
sql.ref(`${relationFilterSelectAlias}.${relationIdFields[0]}`),
373377
'in',
374378
eb
375379
.selectFrom(m2m.joinTable)
376380
.select(`${m2m.joinTable}.${m2m.otherFkName}`)
377381
.whereRef(
378382
sql.ref(`${m2m.joinTable}.${m2m.parentFkName}`),
379383
'=',
380-
sql.ref(`${modelAlias}.${modelIdField}`),
384+
sql.ref(`${modelAlias}.${modelIdFields[0]}`),
381385
),
382386
);
383387
} else {
@@ -1012,7 +1016,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
10121016
otherModelAlias: string,
10131017
query: SelectQueryBuilder<any, any, any>,
10141018
) {
1015-
const idFields = getIdFields(this.schema, thisModel);
1019+
const idFields = requireIdFields(this.schema, thisModel);
10161020
query = query.leftJoin(otherModelAlias, (qb) => {
10171021
for (const idField of idFields) {
10181022
qb = qb.onRef(`${thisModelAlias}.${idField}`, '=', `${otherModelAlias}.${idField}`);

packages/runtime/src/client/crud/dialects/postgresql.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ import type { FindArgs } from '../../crud-types';
1414
import {
1515
buildJoinPairs,
1616
getDelegateDescendantModels,
17-
getIdFields,
1817
getManyToManyRelation,
1918
isRelationField,
2019
requireField,
20+
requireIdFields,
2121
requireModel,
2222
} from '../../query-utils';
2323
import { BaseCrudDialect } from './base-dialect';
@@ -157,8 +157,8 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
157157
const m2m = getManyToManyRelation(this.schema, model, relationField);
158158
if (m2m) {
159159
// many-to-many relation
160-
const parentIds = getIdFields(this.schema, model);
161-
const relationIds = getIdFields(this.schema, relationModel);
160+
const parentIds = requireIdFields(this.schema, model);
161+
const relationIds = requireIdFields(this.schema, relationModel);
162162
invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field');
163163
invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field');
164164
query = query.where((eb) =>

packages/runtime/src/client/crud/dialects/sqlite.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants';
1414
import type { FindArgs } from '../../crud-types';
1515
import {
1616
getDelegateDescendantModels,
17-
getIdFields,
1817
getManyToManyRelation,
1918
getRelationForeignKeyFieldPairs,
2019
requireField,
20+
requireIdFields,
2121
requireModel,
2222
} from '../../query-utils';
2323
import { BaseCrudDialect } from './base-dialect';
@@ -213,8 +213,8 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
213213
const m2m = getManyToManyRelation(this.schema, model, relationField);
214214
if (m2m) {
215215
// many-to-many relation
216-
const parentIds = getIdFields(this.schema, model);
217-
const relationIds = getIdFields(this.schema, relationModel);
216+
const parentIds = requireIdFields(this.schema, model);
217+
const relationIds = requireIdFields(this.schema, relationModel);
218218
invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field');
219219
invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field');
220220
selectModelQuery = selectModelQuery.where((eb) =>

packages/runtime/src/client/crud/operations/base.ts

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ import {
3131
flattenCompoundUniqueFilters,
3232
getDiscriminatorField,
3333
getField,
34-
getIdFields,
3534
getIdValues,
3635
getManyToManyRelation,
3736
getModel,
@@ -40,6 +39,7 @@ import {
4039
isRelationField,
4140
isScalarField,
4241
requireField,
42+
requireIdFields,
4343
requireModel,
4444
} from '../../query-utils';
4545
import { getCrudDialect } from '../dialects';
@@ -132,7 +132,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
132132
model: GetModels<Schema>,
133133
filter: any,
134134
): Promise<unknown | undefined> {
135-
const idFields = getIdFields(this.schema, model);
135+
const idFields = requireIdFields(this.schema, model);
136136
const _filter = flattenCompoundUniqueFilters(this.schema, model, filter);
137137
const query = kysely
138138
.selectFrom(model)
@@ -344,7 +344,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
344344
}
345345

346346
const updatedData = this.fillGeneratedAndDefaultValues(modelDef, createFields);
347-
const idFields = getIdFields(this.schema, model);
347+
const idFields = requireIdFields(this.schema, model);
348348
const query = kysely
349349
.insertInto(model)
350350
.$if(Object.keys(updatedData).length === 0, (qb) => qb.defaultValues())
@@ -481,8 +481,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
481481
a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field),
482482
);
483483

484-
const firstIds = getIdFields(this.schema, sortedRecords[0]!.model);
485-
const secondIds = getIdFields(this.schema, sortedRecords[1]!.model);
484+
const firstIds = requireIdFields(this.schema, sortedRecords[0]!.model);
485+
const secondIds = requireIdFields(this.schema, sortedRecords[1]!.model);
486486
invariant(firstIds.length === 1, 'many-to-many relation must have exactly one id field');
487487
invariant(secondIds.length === 1, 'many-to-many relation must have exactly one id field');
488488

@@ -771,7 +771,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
771771
const result = await this.executeQuery(kysely, query, 'createMany');
772772
return { count: Number(result.numAffectedRows) } as Result;
773773
} else {
774-
const idFields = getIdFields(this.schema, model);
774+
const idFields = requireIdFields(this.schema, model);
775775
const result = await query.returning(idFields as any).execute();
776776
return result as Result;
777777
}
@@ -1039,7 +1039,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
10391039
// nothing to update, return the filter so that the caller can identify the entity
10401040
return combinedWhere;
10411041
} else {
1042-
const idFields = getIdFields(this.schema, model);
1042+
const idFields = requireIdFields(this.schema, model);
10431043
const query = kysely
10441044
.updateTable(model)
10451045
.where((eb) => this.dialect.buildFilter(eb, model, model, combinedWhere))
@@ -1104,7 +1104,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
11041104
if (!filter || typeof filter !== 'object') {
11051105
return false;
11061106
}
1107-
const idFields = getIdFields(this.schema, model);
1107+
const idFields = requireIdFields(this.schema, model);
11081108
return idFields.length === Object.keys(filter).length && idFields.every((field) => field in filter);
11091109
}
11101110

@@ -1297,7 +1297,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
12971297
const result = await this.executeQuery(kysely, query, 'update');
12981298
return { count: Number(result.numAffectedRows) } as Result;
12991299
} else {
1300-
const idFields = getIdFields(this.schema, model);
1300+
const idFields = requireIdFields(this.schema, model);
13011301
const result = await query.returning(idFields as any).execute();
13021302
return result as Result;
13031303
}
@@ -1336,7 +1336,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
13361336
}
13371337

13381338
private buildIdFieldRefs(kysely: ToKysely<Schema>, model: GetModels<Schema>) {
1339-
const idFields = getIdFields(this.schema, model);
1339+
const idFields = requireIdFields(this.schema, model);
13401340
return idFields.map((f) => kysely.dynamic.ref(`${model}.${f}`));
13411341
}
13421342

@@ -2097,7 +2097,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
20972097
// reused the filter if it's a complete id filter (without extra fields)
20982098
// otherwise, read the entity by the filter
20992099
private getEntityIds(kysely: ToKysely<Schema>, model: GetModels<Schema>, uniqueFilter: any) {
2100-
const idFields: string[] = getIdFields(this.schema, model);
2100+
const idFields: string[] = requireIdFields(this.schema, model);
21012101
if (
21022102
// all id fields are provided
21032103
idFields.every((f) => f in uniqueFilter && uniqueFilter[f] !== undefined) &&

packages/runtime/src/client/query-utils.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ export function requireField(schema: SchemaDef, modelOrType: string, field: stri
5555
}
5656

5757
export function getIdFields<Schema extends SchemaDef>(schema: SchemaDef, model: GetModels<Schema>) {
58-
const modelDef = requireModel(schema, model);
59-
return modelDef?.idFields as GetModels<Schema>[];
58+
const modelDef = getModel(schema, model);
59+
return modelDef?.idFields;
6060
}
6161

6262
export function requireIdFields(schema: SchemaDef, model: string) {
@@ -231,7 +231,7 @@ export function buildJoinPairs(
231231
}
232232

233233
export function makeDefaultOrderBy<Schema extends SchemaDef>(schema: SchemaDef, model: string) {
234-
const idFields = getIdFields(schema, model);
234+
const idFields = requireIdFields(schema, model);
235235
return idFields.map((f) => ({ [f]: 'asc' }) as OrderBy<Schema, GetModels<Schema>, true, false>);
236236
}
237237

@@ -318,7 +318,7 @@ export function safeJSONStringify(value: unknown) {
318318
}
319319

320320
export function extractIdFields(entity: any, schema: SchemaDef, model: string) {
321-
const idFields = getIdFields(schema, model);
321+
const idFields = requireIdFields(schema, model);
322322
return extractFields(entity, idFields);
323323
}
324324

packages/runtime/src/plugins/policy/expression-transformer.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import { getCrudDialect } from '../../client/crud/dialects';
2525
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
2626
import { InternalError, QueryError } from '../../client/errors';
2727
import type { ClientOptions } from '../../client/options';
28-
import { getIdFields, getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils';
28+
import { getModel, getRelationForeignKeyFieldPairs, requireField, requireIdFields } from '../../client/query-utils';
2929
import type {
3030
BinaryExpression,
3131
BinaryOperator,
@@ -196,16 +196,18 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
196196
}
197197

198198
private normalizeBinaryOperationOperands(expr: BinaryExpression, context: ExpressionTransformerContext<Schema>) {
199+
// if relation fields are used directly in comparison, it can only be compared with null,
200+
// so we normalize the args with the id field (use the first id field if multiple)
199201
let normalizedLeft: Expression = expr.left;
200202
if (this.isRelationField(expr.left, context.model)) {
201203
invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field');
202-
const idFields = getIdFields(this.schema, context.model);
204+
const idFields = requireIdFields(this.schema, context.model);
203205
normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!);
204206
}
205207
let normalizedRight: Expression = expr.right;
206208
if (this.isRelationField(expr.right, context.model)) {
207209
invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field');
208-
const idFields = getIdFields(this.schema, context.model);
210+
const idFields = requireIdFields(this.schema, context.model);
209211
normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!);
210212
}
211213
return { normalizedLeft, normalizedRight };

0 commit comments

Comments
 (0)