Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"@zenstackhq/testtools": "workspace:*",
"@zenstackhq/typescript-config": "workspace:*",
"@zenstackhq/vitest-config": "workspace:*",
"better-sqlite3": "^12.2.0",
"better-sqlite3": "catalog:",
"tmp": "catalog:"
}
}
50 changes: 50 additions & 0 deletions packages/plugins/policy/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"name": "@zenstackhq/plugin-policy",
"version": "3.0.0-beta.8",
"description": "ZenStack Policy Plugin",
"type": "module",
"scripts": {
"build": "tsc --noEmit && tsup-node",
"watch": "tsup-node --watch",
"lint": "eslint src --ext ts",
"pack": "pnpm pack"
},
"keywords": [],
"author": "ZenStack Team",
"license": "MIT",
"files": [
"dist"
],
"exports": {
".": {
"import": {
"types": "./dist/index.d.ts",
"default": "./dist/index.js"
},
"require": {
"types": "./dist/index.d.cts",
"default": "./dist/index.cjs"
}
},
"./package.json": {
"import": "./package.json",
"require": "./package.json"
}
},
"dependencies": {
"@zenstackhq/common-helpers": "workspace:*",
"@zenstackhq/sdk": "workspace:*",
"@zenstackhq/runtime": "workspace:*",
"ts-pattern": "catalog:"
},
"peerDependencies": {
"kysely": "catalog:"
},
"devDependencies": {
"@types/better-sqlite3": "^7.6.13",
"@types/pg": "^8.0.0",
"@zenstackhq/eslint-config": "workspace:*",
"@zenstackhq/typescript-config": "workspace:*",
"@zenstackhq/vitest-config": "workspace:*"
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { ColumnNode, OperationNode } from 'kysely';
import { DefaultOperationNodeVisitor } from '../../utils/default-operation-node-visitor';
import { DefaultOperationNodeVisitor } from '@zenstackhq/sdk';

/**
* Collects all column names from a query.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
type LiteralExpression,
type MemberExpression,
type UnaryExpression,
} from '../../schema';
} from '@zenstackhq/runtime/schema';

type ExpressionEvaluatorContext = {
auth?: any;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,31 @@
import { invariant } from '@zenstackhq/common-helpers';
import {
getCrudDialect,
InternalError,
QueryError,
QueryUtils,
type BaseCrudDialect,
type ClientContract,
type CRUD_EXT,
} from '@zenstackhq/runtime';
import type {
BinaryExpression,
BinaryOperator,
BuiltinType,
FieldDef,
GetModels,
LiteralExpression,
MemberExpression,
UnaryExpression,
} from '@zenstackhq/runtime/schema';
import {
ExpressionUtils,
type ArrayExpression,
type CallExpression,
type Expression,
type FieldExpression,
type SchemaDef,
} from '@zenstackhq/runtime/schema';
import {
AliasNode,
BinaryOperationNode,
Expand All @@ -20,35 +47,6 @@ import {
type OperationNode,
} from 'kysely';
import { match } from 'ts-pattern';
import type { ClientContract, CRUD_EXT } from '../../client/contract';
import { getCrudDialect } from '../../client/crud/dialects';
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
import { InternalError, QueryError } from '../../client/errors';
import {
getManyToManyRelation,
getModel,
getRelationForeignKeyFieldPairs,
requireField,
requireIdFields,
} from '../../client/query-utils';
import type {
BinaryExpression,
BinaryOperator,
BuiltinType,
FieldDef,
GetModels,
LiteralExpression,
MemberExpression,
UnaryExpression,
} from '../../schema';
import {
ExpressionUtils,
type ArrayExpression,
type CallExpression,
type Expression,
type FieldExpression,
type SchemaDef,
} from '../../schema';
import { ExpressionEvaluator } from './expression-evaluator';
import { conjunction, disjunction, falseNode, isBeforeInvocation, logicalNot, trueNode } from './utils';

Expand Down Expand Up @@ -124,7 +122,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {

@expr('field')
private _field(expr: FieldExpression, context: ExpressionTransformerContext<Schema>) {
const fieldDef = requireField(this.schema, context.model, expr.field);
const fieldDef = QueryUtils.requireField(this.schema, context.model, expr.field);
if (!fieldDef.relation) {
return this.createColumnRef(expr.field, context);
} else {
Expand Down Expand Up @@ -226,15 +224,15 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field');
const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context.model);
invariant(leftRelDef, 'failed to get relation field definition');
const idFields = requireIdFields(this.schema, leftRelDef.type);
const idFields = QueryUtils.requireIdFields(this.schema, leftRelDef.type);
normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!);
}
let normalizedRight: Expression = expr.right;
if (this.isRelationField(expr.right, context.model)) {
invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field');
const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context.model);
invariant(rightRelDef, 'failed to get relation field definition');
const idFields = requireIdFields(this.schema, rightRelDef.type);
const idFields = QueryUtils.requireIdFields(this.schema, rightRelDef.type);
normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!);
}
return { normalizedLeft, normalizedRight };
Expand Down Expand Up @@ -265,10 +263,10 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
ExpressionUtils.isMember(expr.left) && ExpressionUtils.isField(expr.left.receiver),
'left operand must be member access with field receiver',
);
const fieldDef = requireField(this.schema, context.model, expr.left.receiver.field);
const fieldDef = QueryUtils.requireField(this.schema, context.model, expr.left.receiver.field);
newContextModel = fieldDef.type;
for (const member of expr.left.members) {
const memberDef = requireField(this.schema, newContextModel, member);
const memberDef = QueryUtils.requireField(this.schema, newContextModel, member);
newContextModel = memberDef.type;
}
}
Expand Down Expand Up @@ -318,7 +316,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
if (ExpressionUtils.isNull(other)) {
return this.transformValue(expr.op === '==' ? !this.auth : !!this.auth, 'Boolean');
} else {
const authModel = getModel(this.schema, this.authType);
const authModel = QueryUtils.getModel(this.schema, this.authType);
if (!authModel) {
throw new QueryError(
`Unsupported use of \`auth()\` in policy of model "${context.model}", comparing with \`auth()\` is only possible when auth type is a model`,
Expand Down Expand Up @@ -481,7 +479,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
return this._field(ExpressionUtils.field(expr.members[0]!), context);
} else {
// transform the first segment into a relation access, then continue with the rest of the members
const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!);
const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.model, expr.members[0]!);
receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext);
members = expr.members.slice(1);
}
Expand All @@ -493,7 +491,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {

let startType: string;
if (ExpressionUtils.isField(expr.receiver)) {
const receiverField = requireField(this.schema, context.model, expr.receiver.field);
const receiverField = QueryUtils.requireField(this.schema, context.model, expr.receiver.field);
startType = receiverField.type;
} else {
// "this." case, start type is the model of the context
Expand All @@ -504,7 +502,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
const memberFields: { fromModel: string; fieldDef: FieldDef }[] = [];
let currType = startType;
for (const member of members) {
const fieldDef = requireField(this.schema, currType, member);
const fieldDef = QueryUtils.requireField(this.schema, currType, member);
memberFields.push({ fieldDef, fromModel: currType });
currType = fieldDef.type;
}
Expand Down Expand Up @@ -561,7 +559,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

const field = expr.members[0]!;
const fieldDef = requireField(this.schema, receiverType, field);
const fieldDef = QueryUtils.requireField(this.schema, receiverType, field);
const fieldValue = receiver[field] ?? null;
return this.transformValue(fieldValue, fieldDef.type as BuiltinType);
}
Expand All @@ -571,13 +569,13 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
relationModel: string,
context: ExpressionTransformerContext<Schema>,
): SelectQueryNode {
const m2m = getManyToManyRelation(this.schema, context.model, field);
const m2m = QueryUtils.getManyToManyRelation(this.schema, context.model, field);
if (m2m) {
return this.transformManyToManyRelationAccess(m2m, context);
}

const fromModel = context.model;
const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs(this.schema, fromModel, field);
const { keyPairs, ownedByModel } = QueryUtils.getRelationForeignKeyFieldPairs(this.schema, fromModel, field);

let condition: OperationNode;
if (ownedByModel) {
Expand Down Expand Up @@ -614,7 +612,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

private transformManyToManyRelationAccess(
m2m: NonNullable<ReturnType<typeof getManyToManyRelation>>,
m2m: NonNullable<ReturnType<typeof QueryUtils.getManyToManyRelation>>,
context: ExpressionTransformerContext<Schema>,
) {
const eb = expressionBuilder<any, any>();
Expand Down Expand Up @@ -672,13 +670,13 @@ export class ExpressionTransformer<Schema extends SchemaDef> {

private getFieldDefFromFieldRef(expr: Expression, model: GetModels<Schema>): FieldDef | undefined {
if (ExpressionUtils.isField(expr)) {
return requireField(this.schema, model, expr.field);
return QueryUtils.requireField(this.schema, model, expr.field);
} else if (
ExpressionUtils.isMember(expr) &&
expr.members.length === 1 &&
ExpressionUtils.isThis(expr.receiver)
) {
return requireField(this.schema, model, expr.members[0]!);
return QueryUtils.requireField(this.schema, model, expr.members[0]!);
} else {
return undefined;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import { invariant } from '@zenstackhq/common-helpers';
import type { ZModelFunction, ZModelFunctionContext } from '@zenstackhq/runtime';
import { CRUD, QueryUtils } from '@zenstackhq/runtime';
import { ExpressionWrapper, ValueNode, type Expression, type ExpressionBuilder } from 'kysely';
import { CRUD } from '../../client/contract';
import { extractFieldName } from '../../client/kysely-utils';
import type { ZModelFunction, ZModelFunctionContext } from '../../client/options';
import { buildJoinPairs, requireField } from '../../client/query-utils';
import { PolicyHandler } from './policy-handler';

/**
Expand Down Expand Up @@ -31,9 +29,9 @@ export const check: ZModelFunction<any> = (
}

// first argument must be a field reference
const fieldName = extractFieldName(arg1Node);
const fieldName = QueryUtils.extractFieldName(arg1Node);
invariant(fieldName, 'Failed to extract field name from the first argument of "check" function');
const fieldDef = requireField(client.$schema, model, fieldName);
const fieldDef = QueryUtils.requireField(client.$schema, model, fieldName);
invariant(fieldDef.relation, `Field "${fieldName}" is not a relation field in model "${model}"`);
invariant(!fieldDef.array, `Field "${fieldName}" is a to-many relation, which is not supported by "check"`);
const relationModel = fieldDef.type;
Expand All @@ -43,7 +41,7 @@ export const check: ZModelFunction<any> = (
const policyHandler = new PolicyHandler(client);

// join with parent model
const joinPairs = buildJoinPairs(client.$schema, model, modelAlias, fieldName, relationModel);
const joinPairs = QueryUtils.buildJoinPairs(client.$schema, model, modelAlias, fieldName, relationModel);
const joinCondition =
joinPairs.length === 1
? eb(eb.ref(joinPairs[0]![0]), '=', eb.ref(joinPairs[0]![1]))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
export * from './errors';
export * from './plugin';
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { type OnKyselyQueryArgs, type RuntimePlugin } from '../../client/plugin';
import type { SchemaDef } from '../../schema';
import { type OnKyselyQueryArgs, type RuntimePlugin } from '@zenstackhq/runtime';
import type { SchemaDef } from '@zenstackhq/runtime/schema';
import { check } from './functions';
import { PolicyHandler } from './policy-handler';

Expand Down
Loading
Loading