Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@
- [ ] Validation
- [ ] Access Policy
- [ ] Short-circuit pre-create check for scalar-field only policies
- [ ] Inject "replace into"
- [ ] Inject "on conflict do update"
- [ ] Inject "insert into select from"
- [x] Inject "on conflict do update"
- [x] `check` function
- [x] Migration
- [ ] Databases
- [x] SQLite
Expand Down
1 change: 1 addition & 0 deletions packages/common-helpers/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ export * from './param-case';
export * from './sleep';
export * from './tiny-invariant';
export * from './upper-case-first';
export * from './zip';
11 changes: 11 additions & 0 deletions packages/common-helpers/src/zip.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/**
* Zips two arrays into an array of tuples.
*/
export function zip<T, U>(arr1: T[], arr2: U[]): Array<[T, U]> {
const length = Math.min(arr1.length, arr2.length);
const result: Array<[T, U]> = [];
for (let i = 0; i < length; i++) {
result.push([arr1[i]!, arr2[i]!]);
}
return result;
}
32 changes: 17 additions & 15 deletions packages/runtime/src/client/crud-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ export type OmitInput<Schema extends SchemaDef, Model extends GetModels<Schema>>

export type SelectIncludeOmit<Schema extends SchemaDef, Model extends GetModels<Schema>, AllowCount extends boolean> = {
select?: SelectInput<Schema, Model, AllowCount, boolean>;
include?: IncludeInput<Schema, Model>;
include?: IncludeInput<Schema, Model, AllowCount>;
omit?: OmitInput<Schema, Model>;
};

Expand All @@ -463,14 +463,7 @@ export type SelectInput<
AllowRelation extends boolean = true,
> = {
[Key in NonRelationFields<Schema, Model>]?: boolean;
} & (AllowRelation extends true ? IncludeInput<Schema, Model> : {}) & // relation fields
// relation count
(AllowCount extends true
? // _count is only allowed if the model has to-many relations
HasToManyRelations<Schema, Model> extends true
? { _count?: SelectCount<Schema, Model> }
: {}
: {});
} & (AllowRelation extends true ? IncludeInput<Schema, Model, AllowCount> : {});

type SelectCount<Schema extends SchemaDef, Model extends GetModels<Schema>> =
| boolean
Expand All @@ -484,7 +477,11 @@ type SelectCount<Schema extends SchemaDef, Model extends GetModels<Schema>> =
};
};

export type IncludeInput<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
export type IncludeInput<
Schema extends SchemaDef,
Model extends GetModels<Schema>,
AllowCount extends boolean = true,
> = {
[Key in RelationFields<Schema, Model>]?:
| boolean
| FindArgs<
Expand All @@ -498,7 +495,12 @@ export type IncludeInput<Schema extends SchemaDef, Model extends GetModels<Schem
? true
: false
>;
};
} & (AllowCount extends true
? // _count is only allowed if the model has to-many relations
HasToManyRelations<Schema, Model> extends true
? { _count?: SelectCount<Schema, Model> }
: {}
: {});

export type Subset<T, U> = {
[key in keyof T]: key extends keyof U ? T[key] : never;
Expand Down Expand Up @@ -674,7 +676,7 @@ export type FindUniqueArgs<Schema extends SchemaDef, Model extends GetModels<Sch

export type CreateArgs<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
data: CreateInput<Schema, Model>;
select?: SelectInput<Schema, Model, true>;
select?: SelectInput<Schema, Model>;
include?: IncludeInput<Schema, Model>;
omit?: OmitInput<Schema, Model>;
};
Expand Down Expand Up @@ -813,7 +815,7 @@ type NestedCreateManyInput<
export type UpdateArgs<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
data: UpdateInput<Schema, Model>;
where: WhereUniqueInput<Schema, Model>;
select?: SelectInput<Schema, Model, true>;
select?: SelectInput<Schema, Model>;
include?: IncludeInput<Schema, Model>;
omit?: OmitInput<Schema, Model>;
};
Expand Down Expand Up @@ -841,7 +843,7 @@ export type UpsertArgs<Schema extends SchemaDef, Model extends GetModels<Schema>
create: CreateInput<Schema, Model>;
update: UpdateInput<Schema, Model>;
where: WhereUniqueInput<Schema, Model>;
select?: SelectInput<Schema, Model, true>;
select?: SelectInput<Schema, Model>;
include?: IncludeInput<Schema, Model>;
omit?: OmitInput<Schema, Model>;
};
Expand Down Expand Up @@ -958,7 +960,7 @@ type ToOneRelationUpdateInput<

export type DeleteArgs<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
where: WhereUniqueInput<Schema, Model>;
select?: SelectInput<Schema, Model, true>;
select?: SelectInput<Schema, Model>;
include?: IncludeInput<Schema, Model>;
omit?: OmitInput<Schema, Model>;
};
Expand Down
27 changes: 21 additions & 6 deletions packages/runtime/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1048,14 +1048,29 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
for (const [field, value] of Object.entries(selections.select)) {
const fieldDef = requireField(this.schema, model, field);
const fieldModel = fieldDef.type;
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel);

// build a nested query to count the number of records in the relation
let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));
let fieldCountQuery: SelectQueryBuilder<any, any, any>;

// join conditions
for (const [left, right] of joinPairs) {
fieldCountQuery = fieldCountQuery.whereRef(left, '=', right);
const m2m = getManyToManyRelation(this.schema, model, field);
if (m2m) {
// many-to-many relation, count the join table
fieldCountQuery = eb
.selectFrom(fieldModel)
.innerJoin(m2m.joinTable, (join) =>
join
.onRef(`${m2m.joinTable}.${m2m.otherFkName}`, '=', `${fieldModel}.${m2m.otherPKName}`)
.onRef(`${m2m.joinTable}.${m2m.parentFkName}`, '=', `${parentAlias}.${m2m.parentPKName}`),
)
.select(eb.fn.countAll().as(`_count$${field}`));
} else {
// build a nested query to count the number of records in the relation
fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));

// join conditions
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel);
for (const [left, right] of joinPairs) {
fieldCountQuery = fieldCountQuery.whereRef(left, '=', right);
}
}

// merge _count filter
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
entity: rightEntity,
},
].sort((a, b) =>
// the implement m2m join table's "A", "B" fk fields' order is determined
// the implicit m2m join table's "A", "B" fk fields' order is determined
// by model name's sort order, and when identical (for self-relations),
// field name's sort order
a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field),
Expand Down
30 changes: 25 additions & 5 deletions packages/runtime/src/client/crud/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@ import Decimal from 'decimal.js';
import stableStringify from 'json-stable-stringify';
import { match, P } from 'ts-pattern';
import { z, ZodType } from 'zod';
import { type BuiltinType, type EnumDef, type FieldDef, type GetModels, type SchemaDef } from '../../schema';
import {
type BuiltinType,
type EnumDef,
type FieldDef,
type GetModels,
type ModelDef,
type SchemaDef,
} from '../../schema';
import { enumerate } from '../../utils/enumerate';
import { extractFields } from '../../utils/object-utils';
import { formatError } from '../../utils/zod-utils';
Expand Down Expand Up @@ -595,10 +602,18 @@ export class InputValidator<Schema extends SchemaDef> {
}
}

const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array);
const _countSchema = this.makeCountSelectionSchema(modelDef);
if (_countSchema) {
fields['_count'] = _countSchema;
}

return z.strictObject(fields);
}

private makeCountSelectionSchema(modelDef: ModelDef) {
const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array);
if (toManyRelations.length > 0) {
fields['_count'] = z
return z
.union([
z.literal(true),
z.strictObject({
Expand All @@ -621,9 +636,9 @@ export class InputValidator<Schema extends SchemaDef> {
}),
])
.optional();
} else {
return undefined;
}

return z.strictObject(fields);
}

private makeRelationSelectIncludeSchema(fieldDef: FieldDef) {
Expand Down Expand Up @@ -677,6 +692,11 @@ export class InputValidator<Schema extends SchemaDef> {
}
}

const _countSchema = this.makeCountSelectionSchema(modelDef);
if (_countSchema) {
fields['_count'] = _countSchema;
}

return z.strictObject(fields);
}

Expand Down
8 changes: 8 additions & 0 deletions packages/runtime/src/client/query-utils.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { invariant } from '@zenstackhq/common-helpers';
import type { Expression, ExpressionBuilder, ExpressionWrapper } from 'kysely';
import { match } from 'ts-pattern';
import { ExpressionUtils, type FieldDef, type GetModels, type ModelDef, type SchemaDef } from '../schema';
Expand Down Expand Up @@ -259,11 +260,18 @@ export function getManyToManyRelation(schema: SchemaDef, model: string, field: s
orderedFK = sortedFieldNames[0] === field ? ['A', 'B'] : ['B', 'A'];
}

const modelIdFields = requireIdFields(schema, model);
invariant(modelIdFields.length === 1, 'Only single-field ID is supported for many-to-many relation');
const otherIdFields = requireIdFields(schema, fieldDef.type);
invariant(otherIdFields.length === 1, 'Only single-field ID is supported for many-to-many relation');

return {
parentFkName: orderedFK[0],
parentPKName: modelIdFields[0]!,
otherModel: fieldDef.type,
otherField: fieldDef.relation.opposite,
otherFkName: orderedFK[1],
otherPKName: otherIdFields[0]!,
joinTable: fieldDef.relation.name
? `_${fieldDef.relation.name}`
: `_${sortedModelNames[0]}To${sortedModelNames[1]}`,
Expand Down
45 changes: 42 additions & 3 deletions packages/runtime/src/plugins/policy/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ import type { ClientContract, CRUD } from '../../client/contract';
import { getCrudDialect } from '../../client/crud/dialects';
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
import { InternalError, QueryError } from '../../client/errors';
import { getModel, getRelationForeignKeyFieldPairs, requireField, requireIdFields } from '../../client/query-utils';
import {
getManyToManyRelation,
getModel,
getRelationForeignKeyFieldPairs,
requireField,
requireIdFields,
} from '../../client/query-utils';
import type {
BinaryExpression,
BinaryOperator,
Expand All @@ -44,7 +50,7 @@ import {
type SchemaDef,
} from '../../schema';
import { ExpressionEvaluator } from './expression-evaluator';
import { conjunction, disjunction, logicalNot, trueNode } from './utils';
import { conjunction, disjunction, falseNode, logicalNot, trueNode } from './utils';

export type ExpressionTransformerContext<Schema extends SchemaDef> = {
model: GetModels<Schema>;
Expand Down Expand Up @@ -335,7 +341,13 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

private transformValue(value: unknown, type: BuiltinType) {
return ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null);
if (value === true) {
return trueNode(this.dialect);
} else if (value === false) {
return falseNode(this.dialect);
} else {
return ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null);
}
}

@expr('unary')
Expand Down Expand Up @@ -537,6 +549,11 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
relationModel: string,
context: ExpressionTransformerContext<Schema>,
): SelectQueryNode {
const m2m = getManyToManyRelation(this.schema, context.model, field);
if (m2m) {
return this.transformManyToManyRelationAccess(m2m, context);
}

const fromModel = context.model;
const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs(this.schema, fromModel, field);

Expand Down Expand Up @@ -574,6 +591,28 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
};
}

private transformManyToManyRelationAccess(
m2m: NonNullable<ReturnType<typeof getManyToManyRelation>>,
context: ExpressionTransformerContext<Schema>,
) {
const eb = expressionBuilder<any, any>();
const relationQuery = eb
.selectFrom(m2m.otherModel)
// inner join with join table and additionally filter by the parent model
.innerJoin(m2m.joinTable, (join) =>
join
// relation model pk to join table fk
.onRef(`${m2m.otherModel}.${m2m.otherPKName}`, '=', `${m2m.joinTable}.${m2m.otherFkName}`)
// parent model pk to join table fk
.onRef(
`${m2m.joinTable}.${m2m.parentFkName}`,
'=',
`${context.alias ?? context.model}.${m2m.parentPKName}`,
),
);
return relationQuery.toOperationNode();
}

private createColumnRef(column: string, context: ExpressionTransformerContext<Schema>): ReferenceNode {
return ReferenceNode.create(ColumnNode.create(column), TableNode.create(context.alias ?? context.model));
}
Expand Down
Loading