Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"@zenstackhq/testtools": "workspace:*",
"@zenstackhq/typescript-config": "workspace:*",
"@zenstackhq/vitest-config": "workspace:*",
"better-sqlite3": "^12.2.0",
"better-sqlite3": "catalog:",
"tmp": "catalog:"
}
}
4 changes: 4 additions & 0 deletions packages/plugins/policy/eslint.config.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import config from '@zenstackhq/eslint-config/base.js';

/** @type {import("eslint").Linter.Config} */
export default config;
50 changes: 50 additions & 0 deletions packages/plugins/policy/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"name": "@zenstackhq/plugin-policy",
"version": "3.0.0-beta.8",
"description": "ZenStack Policy Plugin",
"type": "module",
"scripts": {
"build": "tsc --noEmit && tsup-node",
"watch": "tsup-node --watch",
"lint": "eslint src --ext ts",
"pack": "pnpm pack"
},
"keywords": [],
"author": "ZenStack Team",
"license": "MIT",
"files": [
"dist"
],
"exports": {
".": {
"import": {
"types": "./dist/index.d.ts",
"default": "./dist/index.js"
},
"require": {
"types": "./dist/index.d.cts",
"default": "./dist/index.cjs"
}
},
"./package.json": {
"import": "./package.json",
"require": "./package.json"
}
},
"dependencies": {
"@zenstackhq/common-helpers": "workspace:*",
"@zenstackhq/sdk": "workspace:*",
"@zenstackhq/runtime": "workspace:*",
"ts-pattern": "catalog:"
},
"peerDependencies": {
"kysely": "catalog:"
},
"devDependencies": {
"@types/better-sqlite3": "^7.6.13",
"@types/pg": "^8.0.0",
"@zenstackhq/eslint-config": "workspace:*",
"@zenstackhq/typescript-config": "workspace:*",
"@zenstackhq/vitest-config": "workspace:*"
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { ColumnNode, OperationNode } from 'kysely';
import { DefaultOperationNodeVisitor } from '../../utils/default-operation-node-visitor';
import { DefaultOperationNodeVisitor } from '@zenstackhq/sdk';

/**
* Collects all column names from a query.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
type LiteralExpression,
type MemberExpression,
type UnaryExpression,
} from '../../schema';
} from '@zenstackhq/runtime/schema';

type ExpressionEvaluatorContext = {
auth?: any;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,31 @@
import { invariant } from '@zenstackhq/common-helpers';
import {
getCrudDialect,
InternalError,
QueryError,
QueryUtils,
type BaseCrudDialect,
type ClientContract,
type CRUD_EXT,
} from '@zenstackhq/runtime';
import type {
BinaryExpression,
BinaryOperator,
BuiltinType,
FieldDef,
GetModels,
LiteralExpression,
MemberExpression,
UnaryExpression,
} from '@zenstackhq/runtime/schema';
import {
ExpressionUtils,
type ArrayExpression,
type CallExpression,
type Expression,
type FieldExpression,
type SchemaDef,
} from '@zenstackhq/runtime/schema';
import {
AliasNode,
BinaryOperationNode,
Expand All @@ -20,35 +47,6 @@ import {
type OperationNode,
} from 'kysely';
import { match } from 'ts-pattern';
import type { ClientContract, CRUD_EXT } from '../../client/contract';
import { getCrudDialect } from '../../client/crud/dialects';
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
import { InternalError, QueryError } from '../../client/errors';
import {
getManyToManyRelation,
getModel,
getRelationForeignKeyFieldPairs,
requireField,
requireIdFields,
} from '../../client/query-utils';
import type {
BinaryExpression,
BinaryOperator,
BuiltinType,
FieldDef,
GetModels,
LiteralExpression,
MemberExpression,
UnaryExpression,
} from '../../schema';
import {
ExpressionUtils,
type ArrayExpression,
type CallExpression,
type Expression,
type FieldExpression,
type SchemaDef,
} from '../../schema';
import { ExpressionEvaluator } from './expression-evaluator';
import { conjunction, disjunction, falseNode, isBeforeInvocation, logicalNot, trueNode } from './utils';

Expand Down Expand Up @@ -124,7 +122,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {

@expr('field')
private _field(expr: FieldExpression, context: ExpressionTransformerContext<Schema>) {
const fieldDef = requireField(this.schema, context.model, expr.field);
const fieldDef = QueryUtils.requireField(this.schema, context.model, expr.field);
if (!fieldDef.relation) {
return this.createColumnRef(expr.field, context);
} else {
Expand Down Expand Up @@ -226,15 +224,15 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field');
const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context.model);
invariant(leftRelDef, 'failed to get relation field definition');
const idFields = requireIdFields(this.schema, leftRelDef.type);
const idFields = QueryUtils.requireIdFields(this.schema, leftRelDef.type);
normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!);
}
let normalizedRight: Expression = expr.right;
if (this.isRelationField(expr.right, context.model)) {
invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field');
const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context.model);
invariant(rightRelDef, 'failed to get relation field definition');
const idFields = requireIdFields(this.schema, rightRelDef.type);
const idFields = QueryUtils.requireIdFields(this.schema, rightRelDef.type);
normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!);
}
return { normalizedLeft, normalizedRight };
Expand Down Expand Up @@ -265,10 +263,10 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
ExpressionUtils.isMember(expr.left) && ExpressionUtils.isField(expr.left.receiver),
'left operand must be member access with field receiver',
);
const fieldDef = requireField(this.schema, context.model, expr.left.receiver.field);
const fieldDef = QueryUtils.requireField(this.schema, context.model, expr.left.receiver.field);
newContextModel = fieldDef.type;
for (const member of expr.left.members) {
const memberDef = requireField(this.schema, newContextModel, member);
const memberDef = QueryUtils.requireField(this.schema, newContextModel, member);
newContextModel = memberDef.type;
}
}
Expand Down Expand Up @@ -318,7 +316,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
if (ExpressionUtils.isNull(other)) {
return this.transformValue(expr.op === '==' ? !this.auth : !!this.auth, 'Boolean');
} else {
const authModel = getModel(this.schema, this.authType);
const authModel = QueryUtils.getModel(this.schema, this.authType);
if (!authModel) {
throw new QueryError(
`Unsupported use of \`auth()\` in policy of model "${context.model}", comparing with \`auth()\` is only possible when auth type is a model`,
Expand Down Expand Up @@ -481,7 +479,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
return this._field(ExpressionUtils.field(expr.members[0]!), context);
} else {
// transform the first segment into a relation access, then continue with the rest of the members
const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!);
const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.model, expr.members[0]!);
receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext);
members = expr.members.slice(1);
}
Expand All @@ -493,7 +491,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {

let startType: string;
if (ExpressionUtils.isField(expr.receiver)) {
const receiverField = requireField(this.schema, context.model, expr.receiver.field);
const receiverField = QueryUtils.requireField(this.schema, context.model, expr.receiver.field);
startType = receiverField.type;
} else {
// "this." case, start type is the model of the context
Expand All @@ -504,7 +502,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
const memberFields: { fromModel: string; fieldDef: FieldDef }[] = [];
let currType = startType;
for (const member of members) {
const fieldDef = requireField(this.schema, currType, member);
const fieldDef = QueryUtils.requireField(this.schema, currType, member);
memberFields.push({ fieldDef, fromModel: currType });
currType = fieldDef.type;
}
Expand Down Expand Up @@ -561,7 +559,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

const field = expr.members[0]!;
const fieldDef = requireField(this.schema, receiverType, field);
const fieldDef = QueryUtils.requireField(this.schema, receiverType, field);
const fieldValue = receiver[field] ?? null;
return this.transformValue(fieldValue, fieldDef.type as BuiltinType);
}
Expand All @@ -571,13 +569,13 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
relationModel: string,
context: ExpressionTransformerContext<Schema>,
): SelectQueryNode {
const m2m = getManyToManyRelation(this.schema, context.model, field);
const m2m = QueryUtils.getManyToManyRelation(this.schema, context.model, field);
if (m2m) {
return this.transformManyToManyRelationAccess(m2m, context);
}

const fromModel = context.model;
const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs(this.schema, fromModel, field);
const { keyPairs, ownedByModel } = QueryUtils.getRelationForeignKeyFieldPairs(this.schema, fromModel, field);

let condition: OperationNode;
if (ownedByModel) {
Expand Down Expand Up @@ -614,7 +612,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

private transformManyToManyRelationAccess(
m2m: NonNullable<ReturnType<typeof getManyToManyRelation>>,
m2m: NonNullable<ReturnType<typeof QueryUtils.getManyToManyRelation>>,
context: ExpressionTransformerContext<Schema>,
) {
const eb = expressionBuilder<any, any>();
Expand Down Expand Up @@ -672,13 +670,13 @@ export class ExpressionTransformer<Schema extends SchemaDef> {

private getFieldDefFromFieldRef(expr: Expression, model: GetModels<Schema>): FieldDef | undefined {
if (ExpressionUtils.isField(expr)) {
return requireField(this.schema, model, expr.field);
return QueryUtils.requireField(this.schema, model, expr.field);
} else if (
ExpressionUtils.isMember(expr) &&
expr.members.length === 1 &&
ExpressionUtils.isThis(expr.receiver)
) {
return requireField(this.schema, model, expr.members[0]!);
return QueryUtils.requireField(this.schema, model, expr.members[0]!);
} else {
return undefined;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import { invariant } from '@zenstackhq/common-helpers';
import type { ZModelFunction, ZModelFunctionContext } from '@zenstackhq/runtime';
import { CRUD, QueryUtils } from '@zenstackhq/runtime';
import { ExpressionWrapper, ValueNode, type Expression, type ExpressionBuilder } from 'kysely';
import { CRUD } from '../../client/contract';
import { extractFieldName } from '../../client/kysely-utils';
import type { ZModelFunction, ZModelFunctionContext } from '../../client/options';
import { buildJoinPairs, requireField } from '../../client/query-utils';
import { PolicyHandler } from './policy-handler';

/**
Expand Down Expand Up @@ -31,9 +29,9 @@ export const check: ZModelFunction<any> = (
}

// first argument must be a field reference
const fieldName = extractFieldName(arg1Node);
const fieldName = QueryUtils.extractFieldName(arg1Node);
invariant(fieldName, 'Failed to extract field name from the first argument of "check" function');
const fieldDef = requireField(client.$schema, model, fieldName);
const fieldDef = QueryUtils.requireField(client.$schema, model, fieldName);
invariant(fieldDef.relation, `Field "${fieldName}" is not a relation field in model "${model}"`);
invariant(!fieldDef.array, `Field "${fieldName}" is a to-many relation, which is not supported by "check"`);
const relationModel = fieldDef.type;
Expand All @@ -43,7 +41,7 @@ export const check: ZModelFunction<any> = (
const policyHandler = new PolicyHandler(client);

// join with parent model
const joinPairs = buildJoinPairs(client.$schema, model, modelAlias, fieldName, relationModel);
const joinPairs = QueryUtils.buildJoinPairs(client.$schema, model, modelAlias, fieldName, relationModel);
const joinCondition =
joinPairs.length === 1
? eb(eb.ref(joinPairs[0]![0]), '=', eb.ref(joinPairs[0]![1]))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
export * from './errors';
export * from './plugin';
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { type OnKyselyQueryArgs, type RuntimePlugin } from '../../client/plugin';
import type { SchemaDef } from '../../schema';
import { type OnKyselyQueryArgs, type RuntimePlugin } from '@zenstackhq/runtime';
import type { SchemaDef } from '@zenstackhq/runtime/schema';
import { check } from './functions';
import { PolicyHandler } from './policy-handler';

Expand Down
Loading