Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {

if (!ownedByModel) {
// assign fks from parent
const parentFkFields = this.buildFkAssignments(
const parentFkFields = await this.buildFkAssignments(
kysely,
fromRelation.model,
fromRelation.field,
fromRelation.ids,
Expand Down Expand Up @@ -433,7 +434,12 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return { baseEntity, remainingFields };
}

private buildFkAssignments(model: string, relationField: string, entity: any) {
private async buildFkAssignments(
kysely: ToKysely<Schema>,
model: GetModels<Schema>,
relationField: string,
entity: any,
) {
const parentFkFields: any = {};

invariant(relationField, 'parentField must be defined if parentModel is defined');
Expand All @@ -443,7 +449,18 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {

for (const pair of keyPairs) {
if (!(pair.pk in entity)) {
throw new QueryError(`Field "${pair.pk}" not found in parent created data`);
// the relation may be using a non-id field as fk, so we read in-place
// to fetch that field
const extraRead = await this.readUnique(kysely, model, {
where: entity,
select: { [pair.pk]: true },
} as any);
if (!extraRead) {
throw new QueryError(`Field "${pair.pk}" not found in parent created data`);
} else {
// update the parent entity
Object.assign(entity, extraRead);
}
}
Object.assign(parentFkFields, {
[pair.fk]: (entity as any)[pair.pk],
Expand Down Expand Up @@ -1411,7 +1428,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
...(enumerate(value) as { where: any; data: any }[]).map((item) => {
let where;
let data;
if ('where' in item) {
if ('data' in item && typeof item.data === 'object') {
where = item.where;
data = item.data;
} else {
Expand Down
8 changes: 6 additions & 2 deletions packages/runtime/src/client/crud/operations/create.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { match } from 'ts-pattern';
import { RejectedByPolicyError } from '../../../plugins/policy/errors';
import { RejectedByPolicyError, RejectedByPolicyReason } from '../../../plugins/policy/errors';
import type { GetModels, SchemaDef } from '../../../schema';
import type { CreateArgs, CreateManyAndReturnArgs, CreateManyArgs, WhereInput } from '../../crud-types';
import { getIdValues } from '../../query-utils';
Expand Down Expand Up @@ -40,7 +40,11 @@ export class CreateOperationHandler<Schema extends SchemaDef> extends BaseOperat
});

if (!result && this.hasPolicyEnabled) {
throw new RejectedByPolicyError(this.model, `result is not allowed to be read back`);
throw new RejectedByPolicyError(
this.model,
RejectedByPolicyReason.CANNOT_READ_BACK,
`result is not allowed to be read back`,
);
}

return result;
Expand Down
12 changes: 9 additions & 3 deletions packages/runtime/src/client/crud/operations/delete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { SchemaDef } from '../../../schema';
import type { DeleteArgs, DeleteManyArgs } from '../../crud-types';
import { NotFoundError } from '../../errors';
import { BaseOperationHandler } from './base';
import { RejectedByPolicyError, RejectedByPolicyReason } from '../../../plugins/policy';

export class DeleteOperationHandler<Schema extends SchemaDef> extends BaseOperationHandler<Schema> {
async handle(operation: 'delete' | 'deleteMany', args: unknown | undefined) {
Expand All @@ -24,9 +25,6 @@ export class DeleteOperationHandler<Schema extends SchemaDef> extends BaseOperat
omit: args.omit,
where: args.where,
});
if (!existing) {
throw new NotFoundError(this.model);
}

// TODO: avoid using transaction for simple delete
await this.safeTransaction(async (tx) => {
Expand All @@ -36,6 +34,14 @@ export class DeleteOperationHandler<Schema extends SchemaDef> extends BaseOperat
}
});

if (!existing && this.hasPolicyEnabled) {
throw new RejectedByPolicyError(
this.model,
RejectedByPolicyReason.CANNOT_READ_BACK,
'result is not allowed to be read back',
);
}

return existing;
}

Expand Down
31 changes: 26 additions & 5 deletions packages/runtime/src/client/crud/operations/update.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { match } from 'ts-pattern';
import { RejectedByPolicyError } from '../../../plugins/policy/errors';
import { RejectedByPolicyError, RejectedByPolicyReason } from '../../../plugins/policy/errors';
import type { GetModels, SchemaDef } from '../../../schema';
import type { UpdateArgs, UpdateManyAndReturnArgs, UpdateManyArgs, UpsertArgs, WhereInput } from '../../crud-types';
import { getIdValues } from '../../query-utils';
Expand Down Expand Up @@ -48,7 +48,11 @@ export class UpdateOperationHandler<Schema extends SchemaDef> extends BaseOperat
// update succeeded but result cannot be read back
if (this.hasPolicyEnabled) {
// if access policy is enabled, we assume it's due to read violation (not guaranteed though)
throw new RejectedByPolicyError(this.model, 'result is not allowed to be read back');
throw new RejectedByPolicyError(
this.model,
RejectedByPolicyReason.CANNOT_READ_BACK,
'result is not allowed to be read back',
);
} else {
// this can happen if the entity is cascade deleted during the update, return null to
// be consistent with Prisma even though it doesn't comply with the method signature
Expand All @@ -71,16 +75,29 @@ export class UpdateOperationHandler<Schema extends SchemaDef> extends BaseOperat
return [];
}

return this.safeTransaction(async (tx) => {
const { readBackResult, updateResult } = await this.safeTransaction(async (tx) => {
const updateResult = await this.updateMany(tx, this.model, args.where, args.data, args.limit, true);
return this.read(tx, this.model, {
const readBackResult = await this.read(tx, this.model, {
select: args.select,
omit: args.omit,
where: {
OR: updateResult.map((item) => getIdValues(this.schema, this.model, item) as any),
} as any, // TODO: fix type
});

return { readBackResult, updateResult };
});

if (readBackResult.length < updateResult.length && this.hasPolicyEnabled) {
// some of the updated entities cannot be read back
throw new RejectedByPolicyError(
this.model,
RejectedByPolicyReason.CANNOT_READ_BACK,
'result is not allowed to be read back',
);
}

return readBackResult;
}

private async runUpsert(args: UpsertArgs<Schema, GetModels<Schema>>) {
Expand Down Expand Up @@ -113,7 +130,11 @@ export class UpdateOperationHandler<Schema extends SchemaDef> extends BaseOperat
});

if (!result && this.hasPolicyEnabled) {
throw new RejectedByPolicyError(this.model, 'result is not allowed to be read back');
throw new RejectedByPolicyError(
this.model,
RejectedByPolicyReason.CANNOT_READ_BACK,
'result is not allowed to be read back',
);
}

return result;
Expand Down
35 changes: 22 additions & 13 deletions packages/runtime/src/client/crud/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { invariant } from '@zenstackhq/common-helpers';
import Decimal from 'decimal.js';
import stableStringify from 'json-stable-stringify';
import { match, P } from 'ts-pattern';
import { z, ZodType } from 'zod';
import { z, ZodSchema, ZodType } from 'zod';
import {
type BuiltinType,
type EnumDef,
Expand Down Expand Up @@ -764,13 +764,15 @@ export class InputValidator<Schema extends SchemaDef> {

private makeCreateSchema(model: string) {
const dataSchema = this.makeCreateDataSchema(model, false);
const schema = z.strictObject({
let schema: ZodSchema = z.strictObject({
data: dataSchema,
select: this.makeSelectSchema(model).optional(),
include: this.makeIncludeSchema(model).optional(),
omit: this.makeOmitSchema(model).optional(),
});
return this.refineForSelectIncludeMutuallyExclusive(schema);
schema = this.refineForSelectIncludeMutuallyExclusive(schema);
schema = this.refineForSelectOmitMutuallyExclusive(schema);
return schema;
}

private makeCreateManySchema(model: string) {
Expand Down Expand Up @@ -934,15 +936,15 @@ export class InputValidator<Schema extends SchemaDef> {
fields['update'] = array
? this.orArray(
z.strictObject({
where: this.makeWhereSchema(fieldType, true),
where: this.makeWhereSchema(fieldType, true).optional(),
data: this.makeUpdateDataSchema(fieldType, withoutFields),
}),
true,
).optional()
: z
.union([
z.strictObject({
where: this.makeWhereSchema(fieldType, true),
where: this.makeWhereSchema(fieldType, true).optional(),
data: this.makeUpdateDataSchema(fieldType, withoutFields),
}),
this.makeUpdateDataSchema(fieldType, withoutFields),
Expand Down Expand Up @@ -1026,14 +1028,16 @@ export class InputValidator<Schema extends SchemaDef> {
// #region Update

private makeUpdateSchema(model: string) {
const schema = z.strictObject({
let schema: ZodSchema = z.strictObject({
where: this.makeWhereSchema(model, true),
data: this.makeUpdateDataSchema(model),
select: this.makeSelectSchema(model).optional(),
include: this.makeIncludeSchema(model).optional(),
omit: this.makeOmitSchema(model).optional(),
});
return this.refineForSelectIncludeMutuallyExclusive(schema);
schema = this.refineForSelectIncludeMutuallyExclusive(schema);
schema = this.refineForSelectOmitMutuallyExclusive(schema);
return schema;
}

private makeUpdateManySchema(model: string) {
Expand All @@ -1046,23 +1050,26 @@ export class InputValidator<Schema extends SchemaDef> {

private makeUpdateManyAndReturnSchema(model: string) {
const base = this.makeUpdateManySchema(model);
const result = base.extend({
let schema: ZodSchema = base.extend({
select: this.makeSelectSchema(model).optional(),
omit: this.makeOmitSchema(model).optional(),
});
return this.refineForSelectOmitMutuallyExclusive(result);
schema = this.refineForSelectOmitMutuallyExclusive(schema);
return schema;
}

private makeUpsertSchema(model: string) {
const schema = z.strictObject({
let schema: ZodSchema = z.strictObject({
where: this.makeWhereSchema(model, true),
create: this.makeCreateDataSchema(model, false),
update: this.makeUpdateDataSchema(model),
select: this.makeSelectSchema(model).optional(),
include: this.makeIncludeSchema(model).optional(),
omit: this.makeOmitSchema(model).optional(),
});
return this.refineForSelectIncludeMutuallyExclusive(schema);
schema = this.refineForSelectIncludeMutuallyExclusive(schema);
schema = this.refineForSelectOmitMutuallyExclusive(schema);
return schema;
}

private makeUpdateDataSchema(model: string, withoutFields: string[] = [], withoutRelationFields = false) {
Expand Down Expand Up @@ -1166,12 +1173,14 @@ export class InputValidator<Schema extends SchemaDef> {
// #region Delete

private makeDeleteSchema(model: GetModels<Schema>) {
const schema = z.strictObject({
let schema: ZodSchema = z.strictObject({
where: this.makeWhereSchema(model, true),
select: this.makeSelectSchema(model).optional(),
include: this.makeIncludeSchema(model).optional(),
});
return this.refineForSelectIncludeMutuallyExclusive(schema);
schema = this.refineForSelectIncludeMutuallyExclusive(schema);
schema = this.refineForSelectOmitMutuallyExclusive(schema);
return schema;
}

private makeDeleteManySchema(model: GetModels<Schema>) {
Expand Down
7 changes: 4 additions & 3 deletions packages/runtime/src/client/helpers/schema-db-pusher.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,19 @@ export class SchemaDbPusher<Schema extends SchemaDef> {
}

// sort models so that target of fk constraints are created first
const sortedModels = this.sortModels(this.schema.models);
const models = Object.values(this.schema.models).filter((m) => !m.isView);
const sortedModels = this.sortModels(models);
for (const modelDef of sortedModels) {
const createTable = this.createModelTable(tx, modelDef);
await createTable.execute();
}
});
}

private sortModels(models: Record<string, ModelDef>): ModelDef[] {
private sortModels(models: ModelDef[]): ModelDef[] {
const graph: [ModelDef, ModelDef | undefined][] = [];

for (const model of Object.values(models)) {
for (const model of models) {
let added = false;

if (model.baseModel) {
Expand Down
25 changes: 23 additions & 2 deletions packages/runtime/src/plugins/policy/errors.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
/**
* Reason code for policy rejection.
*/
export enum RejectedByPolicyReason {
/**
* Rejected because the operation is not allowed by policy.
*/
NO_ACCESS = 'no-access',

/**
* Rejected because the result cannot be read back after mutation due to policy.
*/
CANNOT_READ_BACK = 'cannot-read-back',

/**
* Other reasons.
*/
OTHER = 'other',
}

/**
* Error thrown when an operation is rejected by access policy.
*/
export class RejectedByPolicyError extends Error {
constructor(
public readonly model: string | undefined,
public readonly reason?: string,
public readonly reason: RejectedByPolicyReason = RejectedByPolicyReason.NO_ACCESS,
message?: string,
) {
super(reason ?? `Operation rejected by policy${model ? ': ' + model : ''}`);
super(message ?? `Operation rejected by policy${model ? ': ' + model : ''}`);
}
}
16 changes: 13 additions & 3 deletions packages/runtime/src/plugins/policy/policy-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import type { ProceedKyselyQueryFunction } from '../../client/plugin';
import { getManyToManyRelation, requireField, requireIdFields, requireModel } from '../../client/query-utils';
import { ExpressionUtils, type BuiltinType, type Expression, type GetModels, type SchemaDef } from '../../schema';
import { ColumnCollector } from './column-collector';
import { RejectedByPolicyError } from './errors';
import { RejectedByPolicyError, RejectedByPolicyReason } from './errors';
import { ExpressionTransformer } from './expression-transformer';
import type { Policy, PolicyOperation } from './types';
import { buildIsFalse, conjunction, disjunction, falseNode, getTableName } from './utils';
Expand All @@ -66,7 +66,11 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
) {
if (!this.isCrudQueryNode(node)) {
// non-CRUD queries are not allowed
throw new RejectedByPolicyError(undefined, 'non-CRUD queries are not allowed');
throw new RejectedByPolicyError(
undefined,
RejectedByPolicyReason.OTHER,
'non-CRUD queries are not allowed',
);
}

if (!this.isMutationQueryNode(node)) {
Expand Down Expand Up @@ -106,7 +110,11 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
} else {
const readBackResult = await this.processReadBack(node, result, proceed);
if (readBackResult.rows.length !== result.rows.length) {
throw new RejectedByPolicyError(mutationModel, 'result is not allowed to be read back');
throw new RejectedByPolicyError(
mutationModel,
RejectedByPolicyReason.CANNOT_READ_BACK,
'result is not allowed to be read back',
);
}
return readBackResult;
}
Expand Down Expand Up @@ -335,12 +343,14 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
if (!result.rows[0]?.$conditionA) {
throw new RejectedByPolicyError(
m2m.firstModel as GetModels<Schema>,
RejectedByPolicyReason.CANNOT_READ_BACK,
`many-to-many relation participant model "${m2m.firstModel}" not updatable`,
);
}
if (!result.rows[0]?.$conditionB) {
throw new RejectedByPolicyError(
m2m.secondModel as GetModels<Schema>,
RejectedByPolicyReason.NO_ACCESS,
`many-to-many relation participant model "${m2m.secondModel}" not updatable`,
);
}
Expand Down
Loading