Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions packages/language/src/validators/expression-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,21 @@ export default class ExpressionValidator implements AstValidator<Expression> {
supportedShapes = ['Boolean', 'Any'];
}

const leftResolvedDecl = expr.left.$resolvedType?.decl;
const rightResolvedDecl = expr.right.$resolvedType?.decl;

if (
typeof expr.left.$resolvedType?.decl !== 'string' ||
!supportedShapes.includes(expr.left.$resolvedType.decl)
leftResolvedDecl &&
(typeof leftResolvedDecl !== 'string' || !supportedShapes.includes(leftResolvedDecl))
) {
accept('error', `invalid operand type for "${expr.operator}" operator`, {
node: expr.left,
});
return;
}
if (
typeof expr.right.$resolvedType?.decl !== 'string' ||
!supportedShapes.includes(expr.right.$resolvedType.decl)
rightResolvedDecl &&
(typeof rightResolvedDecl !== 'string' || !supportedShapes.includes(rightResolvedDecl))
) {
accept('error', `invalid operand type for "${expr.operator}" operator`, {
node: expr.right,
Expand All @@ -128,14 +131,11 @@ export default class ExpressionValidator implements AstValidator<Expression> {
}

// DateTime comparison is only allowed between two DateTime values
if (expr.left.$resolvedType.decl === 'DateTime' && expr.right.$resolvedType.decl !== 'DateTime') {
if (leftResolvedDecl === 'DateTime' && rightResolvedDecl && rightResolvedDecl !== 'DateTime') {
accept('error', 'incompatible operand types', {
node: expr,
});
} else if (
expr.right.$resolvedType.decl === 'DateTime' &&
expr.left.$resolvedType.decl !== 'DateTime'
) {
} else if (rightResolvedDecl === 'DateTime' && leftResolvedDecl && leftResolvedDecl !== 'DateTime') {
accept('error', 'incompatible operand types', {
node: expr,
});
Expand Down
16 changes: 10 additions & 6 deletions packages/runtime/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ import {
ensureArray,
flattenCompoundUniqueFilters,
getDelegateDescendantModels,
getIdFields,
getManyToManyRelation,
getRelationForeignKeyFieldPairs,
isEnum,
isInheritedField,
isRelationField,
makeDefaultOrderBy,
requireField,
requireIdFields,
requireModel,
} from '../../query-utils';

Expand Down Expand Up @@ -366,18 +366,22 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
const m2m = getManyToManyRelation(this.schema, model, field);
if (m2m) {
// many-to-many relation
const modelIdField = getIdFields(this.schema, model)[0]!;
const relationIdField = getIdFields(this.schema, relationModel)[0]!;

const modelIdFields = requireIdFields(this.schema, model);
invariant(modelIdFields.length === 1, 'many-to-many relation must have exactly one id field');
const relationIdFields = requireIdFields(this.schema, relationModel);
invariant(relationIdFields.length === 1, 'many-to-many relation must have exactly one id field');

return eb(
sql.ref(`${relationFilterSelectAlias}.${relationIdField}`),
sql.ref(`${relationFilterSelectAlias}.${relationIdFields[0]}`),
'in',
eb
.selectFrom(m2m.joinTable)
.select(`${m2m.joinTable}.${m2m.otherFkName}`)
.whereRef(
sql.ref(`${m2m.joinTable}.${m2m.parentFkName}`),
'=',
sql.ref(`${modelAlias}.${modelIdField}`),
sql.ref(`${modelAlias}.${modelIdFields[0]}`),
),
);
} else {
Expand Down Expand Up @@ -1012,7 +1016,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
otherModelAlias: string,
query: SelectQueryBuilder<any, any, any>,
) {
const idFields = getIdFields(this.schema, thisModel);
const idFields = requireIdFields(this.schema, thisModel);
query = query.leftJoin(otherModelAlias, (qb) => {
for (const idField of idFields) {
qb = qb.onRef(`${thisModelAlias}.${idField}`, '=', `${otherModelAlias}.${idField}`);
Expand Down
6 changes: 3 additions & 3 deletions packages/runtime/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import type { FindArgs } from '../../crud-types';
import {
buildJoinPairs,
getDelegateDescendantModels,
getIdFields,
getManyToManyRelation,
isRelationField,
requireField,
requireIdFields,
requireModel,
} from '../../query-utils';
import { BaseCrudDialect } from './base-dialect';
Expand Down Expand Up @@ -157,8 +157,8 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
const m2m = getManyToManyRelation(this.schema, model, relationField);
if (m2m) {
// many-to-many relation
const parentIds = getIdFields(this.schema, model);
const relationIds = getIdFields(this.schema, relationModel);
const parentIds = requireIdFields(this.schema, model);
const relationIds = requireIdFields(this.schema, relationModel);
invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field');
invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field');
query = query.where((eb) =>
Expand Down
6 changes: 3 additions & 3 deletions packages/runtime/src/client/crud/dialects/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants';
import type { FindArgs } from '../../crud-types';
import {
getDelegateDescendantModels,
getIdFields,
getManyToManyRelation,
getRelationForeignKeyFieldPairs,
requireField,
requireIdFields,
requireModel,
} from '../../query-utils';
import { BaseCrudDialect } from './base-dialect';
Expand Down Expand Up @@ -213,8 +213,8 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
const m2m = getManyToManyRelation(this.schema, model, relationField);
if (m2m) {
// many-to-many relation
const parentIds = getIdFields(this.schema, model);
const relationIds = getIdFields(this.schema, relationModel);
const parentIds = requireIdFields(this.schema, model);
const relationIds = requireIdFields(this.schema, relationModel);
invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field');
invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field');
selectModelQuery = selectModelQuery.where((eb) =>
Expand Down
22 changes: 11 additions & 11 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import {
flattenCompoundUniqueFilters,
getDiscriminatorField,
getField,
getIdFields,
getIdValues,
getManyToManyRelation,
getModel,
Expand All @@ -40,6 +39,7 @@ import {
isRelationField,
isScalarField,
requireField,
requireIdFields,
requireModel,
} from '../../query-utils';
import { getCrudDialect } from '../dialects';
Expand Down Expand Up @@ -132,7 +132,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
model: GetModels<Schema>,
filter: any,
): Promise<unknown | undefined> {
const idFields = getIdFields(this.schema, model);
const idFields = requireIdFields(this.schema, model);
const _filter = flattenCompoundUniqueFilters(this.schema, model, filter);
const query = kysely
.selectFrom(model)
Expand Down Expand Up @@ -344,7 +344,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
}

const updatedData = this.fillGeneratedAndDefaultValues(modelDef, createFields);
const idFields = getIdFields(this.schema, model);
const idFields = requireIdFields(this.schema, model);
const query = kysely
.insertInto(model)
.$if(Object.keys(updatedData).length === 0, (qb) => qb.defaultValues())
Expand Down Expand Up @@ -481,8 +481,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field),
);

const firstIds = getIdFields(this.schema, sortedRecords[0]!.model);
const secondIds = getIdFields(this.schema, sortedRecords[1]!.model);
const firstIds = requireIdFields(this.schema, sortedRecords[0]!.model);
const secondIds = requireIdFields(this.schema, sortedRecords[1]!.model);
invariant(firstIds.length === 1, 'many-to-many relation must have exactly one id field');
invariant(secondIds.length === 1, 'many-to-many relation must have exactly one id field');

Expand Down Expand Up @@ -771,7 +771,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
const result = await this.executeQuery(kysely, query, 'createMany');
return { count: Number(result.numAffectedRows) } as Result;
} else {
const idFields = getIdFields(this.schema, model);
const idFields = requireIdFields(this.schema, model);
const result = await query.returning(idFields as any).execute();
return result as Result;
}
Expand Down Expand Up @@ -1039,7 +1039,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
// nothing to update, return the filter so that the caller can identify the entity
return combinedWhere;
} else {
const idFields = getIdFields(this.schema, model);
const idFields = requireIdFields(this.schema, model);
const query = kysely
.updateTable(model)
.where((eb) => this.dialect.buildFilter(eb, model, model, combinedWhere))
Expand Down Expand Up @@ -1104,7 +1104,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
if (!filter || typeof filter !== 'object') {
return false;
}
const idFields = getIdFields(this.schema, model);
const idFields = requireIdFields(this.schema, model);
return idFields.length === Object.keys(filter).length && idFields.every((field) => field in filter);
}

Expand Down Expand Up @@ -1297,7 +1297,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
const result = await this.executeQuery(kysely, query, 'update');
return { count: Number(result.numAffectedRows) } as Result;
} else {
const idFields = getIdFields(this.schema, model);
const idFields = requireIdFields(this.schema, model);
const result = await query.returning(idFields as any).execute();
return result as Result;
}
Expand Down Expand Up @@ -1336,7 +1336,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
}

private buildIdFieldRefs(kysely: ToKysely<Schema>, model: GetModels<Schema>) {
const idFields = getIdFields(this.schema, model);
const idFields = requireIdFields(this.schema, model);
return idFields.map((f) => kysely.dynamic.ref(`${model}.${f}`));
}

Expand Down Expand Up @@ -2097,7 +2097,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
// reused the filter if it's a complete id filter (without extra fields)
// otherwise, read the entity by the filter
private getEntityIds(kysely: ToKysely<Schema>, model: GetModels<Schema>, uniqueFilter: any) {
const idFields: string[] = getIdFields(this.schema, model);
const idFields: string[] = requireIdFields(this.schema, model);
if (
// all id fields are provided
idFields.every((f) => f in uniqueFilter && uniqueFilter[f] !== undefined) &&
Expand Down
8 changes: 4 additions & 4 deletions packages/runtime/src/client/query-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ export function requireField(schema: SchemaDef, modelOrType: string, field: stri
}

export function getIdFields<Schema extends SchemaDef>(schema: SchemaDef, model: GetModels<Schema>) {
const modelDef = requireModel(schema, model);
return modelDef?.idFields as GetModels<Schema>[];
const modelDef = getModel(schema, model);
return modelDef?.idFields;
}

export function requireIdFields(schema: SchemaDef, model: string) {
Expand Down Expand Up @@ -231,7 +231,7 @@ export function buildJoinPairs(
}

export function makeDefaultOrderBy<Schema extends SchemaDef>(schema: SchemaDef, model: string) {
const idFields = getIdFields(schema, model);
const idFields = requireIdFields(schema, model);
return idFields.map((f) => ({ [f]: 'asc' }) as OrderBy<Schema, GetModels<Schema>, true, false>);
}

Expand Down Expand Up @@ -318,7 +318,7 @@ export function safeJSONStringify(value: unknown) {
}

export function extractIdFields(entity: any, schema: SchemaDef, model: string) {
const idFields = getIdFields(schema, model);
const idFields = requireIdFields(schema, model);
return extractFields(entity, idFields);
}

Expand Down
8 changes: 5 additions & 3 deletions packages/runtime/src/plugins/policy/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import { getCrudDialect } from '../../client/crud/dialects';
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
import { InternalError, QueryError } from '../../client/errors';
import type { ClientOptions } from '../../client/options';
import { getIdFields, getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils';
import { getModel, getRelationForeignKeyFieldPairs, requireField, requireIdFields } from '../../client/query-utils';
import type {
BinaryExpression,
BinaryOperator,
Expand Down Expand Up @@ -196,16 +196,18 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

private normalizeBinaryOperationOperands(expr: BinaryExpression, context: ExpressionTransformerContext<Schema>) {
// if relation fields are used directly in comparison, it can only be compared with null,
// so we normalize the args with the id field (use the first id field if multiple)
let normalizedLeft: Expression = expr.left;
if (this.isRelationField(expr.left, context.model)) {
invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field');
const idFields = getIdFields(this.schema, context.model);
const idFields = requireIdFields(this.schema, context.model);
normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!);
}
let normalizedRight: Expression = expr.right;
if (this.isRelationField(expr.right, context.model)) {
invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field');
const idFields = getIdFields(this.schema, context.model);
const idFields = requireIdFields(this.schema, context.model);
normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!);
}
return { normalizedLeft, normalizedRight };
Expand Down
Loading