Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
- [ ] Short-circuit pre-create check for scalar-field only policies
- [x] Inject "on conflict do update"
- [x] `check` function
- [ ] Accessing tables not in the schema
- [x] Migration
- [ ] Databases
- [x] SQLite
Expand Down
8 changes: 4 additions & 4 deletions packages/language/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ attribute @@@deprecated(_ message: String)
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be allowed.
*/
attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean)
attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean)

/**
* Defines an access policy that allows the annotated field to be read or updated.
Expand All @@ -684,7 +684,7 @@ attribute @allow(_ operation: String @@@completionHint(["'create'", "'read'", "'
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be denied.
*/
attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean)
attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean)

/**
* Defines an access policy that denies the annotated field to be read or updated.
Expand All @@ -705,8 +705,8 @@ function check(field: Any, operation: String?): Boolean {
} @@@expressionContext([AccessPolicy])

/**
* Gets entities value before an update. Only valid when used in a "update" policy rule.
* Gets entity's value before an update. Only valid when used in a "update" policy rule.
*/
function future(): Any {
function before(): Any {
} @@@expressionContext([AccessPolicy])

8 changes: 2 additions & 6 deletions packages/language/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,6 @@ export function isRelationshipField(field: DataField) {
return isDataModel(field.type.reference?.ref);
}

export function isFutureExpr(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref);
}

export function isDelegateModel(node: AstNode) {
return isDataModel(node) && hasAttribute(node, '@@delegate');
}
Expand Down Expand Up @@ -450,8 +446,8 @@ export function getAuthDecl(decls: (DataModel | TypeDef)[]) {
return authModel;
}

export function isFutureInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref);
export function isBeforeInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'before' && isFromStdlib(node.function.ref);
}

export function isCollectionPredicate(node: AstNode): node is BinaryExpr {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import {
getAllAttributes,
getStringLiteral,
isAuthOrAuthMemberAccess,
isBeforeInvocation,
isCollectionPredicate,
isDataFieldReference,
isDelegateModel,
isFutureExpr,
isRelationshipField,
mapBuiltinTypeToExpressionType,
resolved,
Expand Down Expand Up @@ -166,7 +166,7 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
});
return;
}
this.validatePolicyKinds(kind, ['create', 'read', 'update', 'delete', 'all'], attr, accept);
this.validatePolicyKinds(kind, ['create', 'read', 'update', 'post-update', 'delete', 'all'], attr, accept);

if ((kind === 'create' || kind === 'all') && attr.args[1]?.value) {
// "create" rules cannot access non-owned relations because the entity does not exist yet, so
Expand Down Expand Up @@ -251,8 +251,8 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
const kindItems = this.validatePolicyKinds(kind, ['read', 'update', 'all'], attr, accept);

const expr = attr.args[1]?.value;
if (expr && AstUtils.streamAst(expr).some((node) => isFutureExpr(node))) {
accept('error', `"future()" is not allowed in field-level policy rules`, { node: expr });
if (expr && AstUtils.streamAst(expr).some((node) => isBeforeInvocation(node))) {
accept('error', `"before()" is not allowed in field-level policy rules`, { node: expr });
}

// 'update' rules are not allowed for relation fields
Expand Down
11 changes: 11 additions & 0 deletions packages/language/src/validators/expression-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ import {
isNullExpr,
isReferenceExpr,
isThisExpr,
MemberAccessExpr,
type ExpressionType,
} from '../generated/ast';

import {
findUpAst,
isAuthInvocation,
isAuthOrAuthMemberAccess,
isBeforeInvocation,
isDataFieldReference,
isEnumFieldReference,
typeAssignable,
Expand Down Expand Up @@ -59,12 +61,21 @@ export default class ExpressionValidator implements AstValidator<Expression> {

// extra validations by expression type
switch (expr.$type) {
case 'MemberAccessExpr':
this.validateMemberAccessExpr(expr, accept);
break;
case 'BinaryExpr':
this.validateBinaryExpr(expr, accept);
break;
}
}

private validateMemberAccessExpr(expr: MemberAccessExpr, accept: ValidationAcceptor) {
if (isBeforeInvocation(expr.operand) && isDataModel(expr.$resolvedType?.decl)) {
accept('error', 'relation fields cannot be accessed from `before()`', { node: expr });
}
}

private validateBinaryExpr(expr: BinaryExpr, accept: ValidationAcceptor) {
switch (expr.operator) {
case 'in': {
Expand Down
6 changes: 3 additions & 3 deletions packages/language/src/zmodel-linker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ import {
getAuthDecl,
getContainingDataModel,
isAuthInvocation,
isFutureExpr,
isBeforeInvocation,
isMemberContainer,
mapBuiltinTypeToExpressionType,
} from './utils';
Expand Down Expand Up @@ -292,8 +292,8 @@ export class ZModelLinker extends DefaultLinker {
if (authDecl) {
node.$resolvedType = { decl: authDecl, nullable: true };
}
} else if (isFutureExpr(node)) {
// future() function is resolved to current model
} else if (isBeforeInvocation(node)) {
// before() function is resolved to current model
node.$resolvedType = { decl: getContainingDataModel(node) };
} else {
this.resolveToDeclaredType(node, funcDecl.returnType);
Expand Down
6 changes: 3 additions & 3 deletions packages/language/src/zmodel-scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import {
getRecursiveBases,
isAuthInvocation,
isCollectionPredicate,
isFutureInvocation,
isBeforeInvocation,
resolveImportUri,
} from './utils';

Expand Down Expand Up @@ -170,8 +170,8 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
return this.createScopeForAuth(node, globalScope);
}

if (isFutureInvocation(operand)) {
// resolve `future()` to the containing model
if (isBeforeInvocation(operand)) {
// resolve `before()` to the containing model
return this.createScopeForContainingModel(node, globalScope);
}
return EMPTY_SCOPE;
Expand Down
10 changes: 10 additions & 0 deletions packages/runtime/src/client/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,21 @@ export interface ClientConstructor {
*/
export type CRUD = 'create' | 'read' | 'update' | 'delete';

/**
* Extended CRUD operations including 'post-update'.
*/
export type CRUD_EXT = CRUD | 'post-update';

/**
* CRUD operations.
*/
export const CRUD = ['create', 'read', 'update', 'delete'] as const;

/**
* Extended CRUD operations including 'post-update'.
*/
export const CRUD_EXT = [...CRUD, 'post-update'] as const;

//#region Model operations

export type AllModelOperations<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
Expand Down
5 changes: 3 additions & 2 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1296,8 +1296,9 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return { count: Number(result.numAffectedRows) } as Result;
} else {
const idFields = requireIdFields(this.schema, model);
const result = await query.returning(idFields as any).execute();
return result as Result;
const finalQuery = query.returning(idFields as any);
const result = await this.executeQuery(kysely, finalQuery, 'update');
return result.rows as Result;
}
}

Expand Down
4 changes: 2 additions & 2 deletions packages/runtime/src/client/options.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { Dialect, Expression, ExpressionBuilder, KyselyConfig } from 'kysely';
import type { GetModel, GetModels, ProcedureDef, SchemaDef } from '../schema';
import type { PrependParameter } from '../utils/type-utils';
import type { ClientContract, CRUD, ProcedureFunc } from './contract';
import type { ClientContract, CRUD_EXT, ProcedureFunc } from './contract';
import type { BaseCrudDialect } from './crud/dialects/base-dialect';
import type { RuntimePlugin } from './plugin';
import type { ToKyselySchema } from './query-builder';
Expand Down Expand Up @@ -30,7 +30,7 @@ export type ZModelFunctionContext<Schema extends SchemaDef> = {
/**
* The CRUD operation being performed
*/
operation: CRUD;
operation: CRUD_EXT;
};

export type ZModelFunction<Schema extends SchemaDef> = (
Expand Down
19 changes: 13 additions & 6 deletions packages/runtime/src/plugins/policy/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {
type OperationNode,
} from 'kysely';
import { match } from 'ts-pattern';
import type { ClientContract, CRUD } from '../../client/contract';
import type { ClientContract, CRUD_EXT } from '../../client/contract';
import { getCrudDialect } from '../../client/crud/dialects';
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
import { InternalError, QueryError } from '../../client/errors';
Expand Down Expand Up @@ -50,13 +50,12 @@ import {
type SchemaDef,
} from '../../schema';
import { ExpressionEvaluator } from './expression-evaluator';
import { conjunction, disjunction, falseNode, logicalNot, trueNode } from './utils';
import { conjunction, disjunction, falseNode, isBeforeInvocation, logicalNot, trueNode } from './utils';

export type ExpressionTransformerContext<Schema extends SchemaDef> = {
model: GetModels<Schema>;
alias?: string;
operation: CRUD;
auth?: any;
operation: CRUD_EXT;
memberFilter?: OperationNode;
memberSelect?: SelectionNode;
};
Expand Down Expand Up @@ -439,7 +438,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

if (this.isAuthMember(arg)) {
const valNode = this.valueMemberAccess(context.auth, arg as MemberExpression, this.authType);
const valNode = this.valueMemberAccess(this.auth, arg as MemberExpression, this.authType);
return valNode ? eb.val(valNode.value) : eb.val(null);
}

Expand All @@ -453,11 +452,19 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
@expr('member')
// @ts-ignore
private _member(expr: MemberExpression, context: ExpressionTransformerContext<Schema>) {
// auth() member access
// `auth()` member access
if (this.isAuthCall(expr.receiver)) {
return this.valueMemberAccess(this.auth, expr, this.authType);
}

// `before()` member access
if (isBeforeInvocation(expr.receiver)) {
// policy handler creates a join table named `$before` using entity value before update,
// we can directly reference the column from there
invariant(expr.members.length === 1, 'before() can only be followed by a scalar field access');
return ReferenceNode.create(ColumnNode.create(expr.members[0]!), TableNode.create('$before'));
}

invariant(
ExpressionUtils.isField(expr.receiver) || ExpressionUtils.isThis(expr.receiver),
'expect receiver to be field expression or "this"',
Expand Down
4 changes: 2 additions & 2 deletions packages/runtime/src/plugins/policy/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ export class PolicyPlugin<Schema extends SchemaDef> implements RuntimePlugin<Sch
};
}

onKyselyQuery({ query, client, proceed /*, transaction*/ }: OnKyselyQueryArgs<Schema>) {
onKyselyQuery({ query, client, proceed }: OnKyselyQueryArgs<Schema>) {
const handler = new PolicyHandler<Schema>(client);
return handler.handle(query, proceed /*, transaction*/);
return handler.handle(query, proceed);
}
}
Loading