Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions packages/plugins/policy/src/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -575,20 +575,34 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

const fromModel = context.model;
const relationFieldDef = QueryUtils.requireField(this.schema, fromModel, field);
const { keyPairs, ownedByModel } = QueryUtils.getRelationForeignKeyFieldPairs(this.schema, fromModel, field);

let condition: OperationNode;
if (ownedByModel) {
// `fromModel` owns the fk

condition = conjunction(
this.dialect,
keyPairs.map(({ fk, pk }) =>
BinaryOperationNode.create(
ReferenceNode.create(ColumnNode.create(fk), TableNode.create(context.alias ?? fromModel)),
keyPairs.map(({ fk, pk }) => {
let fkRef: OperationNode = ReferenceNode.create(
ColumnNode.create(fk),
TableNode.create(context.alias ?? fromModel),
);
if (relationFieldDef.originModel && relationFieldDef.originModel !== fromModel) {
fkRef = this.buildDelegateBaseFieldSelect(
fromModel,
context.alias ?? fromModel,
fk,
relationFieldDef.originModel,
);
}
return BinaryOperationNode.create(
fkRef,
OperatorNode.create('='),
ReferenceNode.create(ColumnNode.create(pk), TableNode.create(relationModel)),
),
),
);
}),
);
} else {
// `relationModel` owns the fk
Expand Down Expand Up @@ -633,8 +647,47 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
return relationQuery.toOperationNode();
}

private createColumnRef(column: string, context: ExpressionTransformerContext<Schema>): ReferenceNode {
return ReferenceNode.create(ColumnNode.create(column), TableNode.create(context.alias ?? context.model));
private createColumnRef(column: string, context: ExpressionTransformerContext<Schema>) {
// if field comes from a delegate base model, we need to use the join alias
// of that base model

const tableName = context.alias ?? context.model;

// "create" policies evaluate table from "VALUES" node so no join from delegate bases are
// created and thus we should directly use the model table name
if (context.operation === 'create') {
return ReferenceNode.create(ColumnNode.create(column), TableNode.create(tableName));
}

const fieldDef = QueryUtils.requireField(this.schema, context.model, column);
if (!fieldDef.originModel || fieldDef.originModel === context.model) {
return ReferenceNode.create(ColumnNode.create(column), TableNode.create(tableName));
}

return this.buildDelegateBaseFieldSelect(context.model, tableName, column, fieldDef.originModel);
}

private buildDelegateBaseFieldSelect(model: string, modelAlias: string, field: string, baseModel: string) {
const idFields = QueryUtils.requireIdFields(this.client.$schema, model);
return {
kind: 'SelectQueryNode',
from: FromNode.create([TableNode.create(baseModel)]),
selections: [
SelectionNode.create(ReferenceNode.create(ColumnNode.create(field), TableNode.create(baseModel))),
],
where: WhereNode.create(
conjunction(
this.dialect,
idFields.map((idField) =>
BinaryOperationNode.create(
ReferenceNode.create(ColumnNode.create(idField), TableNode.create(baseModel)),
OperatorNode.create('='),
ReferenceNode.create(ColumnNode.create(idField), TableNode.create(modelAlias)),
),
),
),
),
} satisfies SelectQueryNode;
}

private isAuthCall(value: unknown): value is CallExpression {
Expand Down
58 changes: 50 additions & 8 deletions packages/plugins/policy/src/functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,60 @@ export const check: ZModelFunction<any> = (
invariant(!fieldDef.array, `Field "${fieldName}" is a to-many relation, which is not supported by "check"`);
const relationModel = fieldDef.type;

const op = arg2Node ? (arg2Node.value as CRUD) : operation;
// build the join condition between the current model and the related model
const joinConditions: Expression<any>[] = [];
const fkInfo = QueryUtils.getRelationForeignKeyFieldPairs(client.$schema, model, fieldName);
const idFields = QueryUtils.requireIdFields(client.$schema, model);

const policyHandler = new PolicyHandler(client);
// helper to build a base model select for delegate models
const buildBaseSelect = (baseModel: string, field: string): Expression<any> => {
return eb
.selectFrom(baseModel)
.select(field)
.where(
eb.and(
idFields.map((idField) =>
eb(eb.ref(`${fieldDef.originModel}.${idField}`), '=', eb.ref(`${modelAlias}.${idField}`)),
),
),
);
};

if (fkInfo.ownedByModel) {
// model owns the relation
joinConditions.push(
...fkInfo.keyPairs.map(({ fk, pk }) => {
let fkRef: Expression<any>;
if (fieldDef.originModel && fieldDef.originModel !== model) {
// relation is actually defined in a delegate base model, select from there
fkRef = buildBaseSelect(fieldDef.originModel, fk);
} else {
fkRef = eb.ref(`${modelAlias}.${fk}`);
}
return eb(fkRef, '=', eb.ref(`${relationModel}.${pk}`));
}),
);
} else {
// related model owns the relation
joinConditions.push(
...fkInfo.keyPairs.map(({ fk, pk }) => {
let pkRef: Expression<any>;
if (fieldDef.originModel && fieldDef.originModel !== model) {
// relation is actually defined in a delegate base model, select from there
pkRef = buildBaseSelect(fieldDef.originModel, pk);
} else {
pkRef = eb.ref(`${modelAlias}.${pk}`);
}
return eb(pkRef, '=', eb.ref(`${relationModel}.${fk}`));
}),
);
}

// join with parent model
const joinPairs = QueryUtils.buildJoinPairs(client.$schema, model, modelAlias, fieldName, relationModel);
const joinCondition =
joinPairs.length === 1
? eb(eb.ref(joinPairs[0]![0]), '=', eb.ref(joinPairs[0]![1]))
: eb.and(joinPairs.map(([left, right]) => eb(eb.ref(left), '=', eb.ref(right))));
const joinCondition = joinConditions.length === 1 ? joinConditions[0]! : eb.and(joinConditions);

// policy condition of the related model
const policyHandler = new PolicyHandler(client);
const op = arg2Node ? (arg2Node.value as CRUD) : operation;
const policyCondition = policyHandler.buildPolicyFilter(relationModel, undefined, op);

// build the final nested select that evaluates the policy condition
Expand Down
5 changes: 5 additions & 0 deletions packages/plugins/policy/src/policy-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
// #region overrides

protected override transformSelectQuery(node: SelectQueryNode) {
if (!node.from) {
return super.transformSelectQuery(node);
}

let whereNode = this.transformNode(node.where);

// get combined policy filter for all froms, and merge into where clause
Expand Down Expand Up @@ -327,6 +331,7 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf

// build a nested query with policy filter applied
const filter = this.buildPolicyFilter(table.model, table.alias, 'read');

const nestedSelect: SelectQueryNode = {
kind: 'SelectQueryNode',
from: FromNode.create([node.table]),
Expand Down
3 changes: 1 addition & 2 deletions packages/testtools/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"description": "ZenStack Test Tools",
"type": "module",
"scripts": {
"build": "tsc --noEmit && tsup-node && copyfiles -f ./src/types.d.ts ./dist",
"build": "tsc --noEmit && tsup-node",
"watch": "tsup-node --watch",
"lint": "eslint src --ext ts",
"pack": "pnpm pack"
Expand Down Expand Up @@ -53,7 +53,6 @@
"@types/pg": "^8.11.11",
"@zenstackhq/eslint-config": "workspace:*",
"@zenstackhq/typescript-config": "workspace:*",
"copyfiles": "^2.4.1",
"typescript": "catalog:"
}
}
5 changes: 4 additions & 1 deletion packages/testtools/tsup.config.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import fs from 'fs';
import { defineConfig } from 'tsup';

export default defineConfig({
Expand All @@ -7,7 +8,9 @@ export default defineConfig({
outDir: 'dist',
splitting: false,
sourcemap: true,
clean: true,
dts: true,
format: ['cjs', 'esm'],
async onSuccess() {
fs.cpSync('src/types.d.ts', 'dist/types.d.ts', { force: true });
},
});
Loading