diff --git a/TODO.md b/TODO.md index 4a861b4f..d913fdbb 100644 --- a/TODO.md +++ b/TODO.md @@ -3,13 +3,13 @@ - [ ] CLI - [x] generate - [x] migrate - - [ ] db + - [x] db - [x] push - - [ ] seed + - [x] seed - [x] info - [x] init - [x] validate - - [ ] format + - [x] format - [ ] repl - [x] plugin mechanism - [x] built-in plugins diff --git a/packages/language/src/validators/function-invocation-validator.ts b/packages/language/src/validators/function-invocation-validator.ts index f4552743..a2ff34fd 100644 --- a/packages/language/src/validators/function-invocation-validator.ts +++ b/packages/language/src/validators/function-invocation-validator.ts @@ -179,6 +179,17 @@ export default class FunctionInvocationValidator implements AstValidator { @@ -184,10 +184,9 @@ export class ZModelScopeProvider extends DefaultScopeProvider { const globalScope = this.getGlobalScope(referenceType, context); const collection = collectionPredicate.left; - // TODO: generalize it + // TODO: full support of typedef member access // // typedef's fields are only added to the scope if the access starts with `auth().` - // const allowTypeDefScope = isAuthOrAuthMemberAccess(collection); - const allowTypeDefScope = false; + const allowTypeDefScope = isAuthOrAuthMemberAccess(collection); return match(collection) .when(isReferenceExpr, (expr) => { diff --git a/packages/orm/src/utils/schema-utils.ts b/packages/orm/src/utils/schema-utils.ts index 8c0824d4..cd5ce553 100644 --- a/packages/orm/src/utils/schema-utils.ts +++ b/packages/orm/src/utils/schema-utils.ts @@ -12,9 +12,11 @@ import type { UnaryExpression, } from '../schema'; +export type VisitResult = void | { abort: true }; + export class ExpressionVisitor { - visit(expr: Expression): void { - match(expr) + visit(expr: Expression): VisitResult { + return match(expr) .with({ kind: 'literal' }, (e) => this.visitLiteral(e)) .with({ kind: 'array' }, (e) => this.visitArray(e)) .with({ kind: 'field' }, (e) => this.visitField(e)) @@ -27,32 +29,68 @@ export class ExpressionVisitor { .exhaustive(); } - protected visitLiteral(_e: LiteralExpression) {} + protected visitLiteral(_e: LiteralExpression): VisitResult {} - protected visitArray(e: ArrayExpression) { - e.items.forEach((item) => this.visit(item)); + protected visitArray(e: ArrayExpression): VisitResult { + for (const item of e.items) { + const result = this.visit(item); + if (result?.abort) { + return result; + } + } } - protected visitField(_e: FieldExpression) {} + protected visitField(_e: FieldExpression): VisitResult {} - protected visitMember(e: MemberExpression) { - this.visit(e.receiver); + protected visitMember(e: MemberExpression): VisitResult { + return this.visit(e.receiver); } - protected visitBinary(e: BinaryExpression) { - this.visit(e.left); - this.visit(e.right); + protected visitBinary(e: BinaryExpression): VisitResult { + const l = this.visit(e.left); + if (l?.abort) { + return l; + } else { + return this.visit(e.right); + } } - protected visitUnary(e: UnaryExpression) { - this.visit(e.operand); + protected visitUnary(e: UnaryExpression): VisitResult { + return this.visit(e.operand); } - protected visitCall(e: CallExpression) { - e.args?.forEach((arg) => this.visit(arg)); + protected visitCall(e: CallExpression): VisitResult { + for (const arg of e.args ?? []) { + const r = this.visit(arg); + if (r?.abort) { + return r; + } + } } - protected visitThis(_e: ThisExpression) {} + protected visitThis(_e: ThisExpression): VisitResult {} + + protected visitNull(_e: NullExpression): VisitResult {} +} - protected visitNull(_e: NullExpression) {} +export class MatchingExpressionVisitor extends ExpressionVisitor { + private found = false; + + constructor(private predicate: (expr: Expression) => boolean) { + super(); + } + + find(expr: Expression) { + this.visit(expr); + return this.found; + } + + override visit(expr: Expression) { + if (this.predicate(expr)) { + this.found = true; + return { abort: true } as const; + } else { + return super.visit(expr); + } + } } diff --git a/packages/plugins/policy/src/expression-evaluator.ts b/packages/plugins/policy/src/expression-evaluator.ts index a09c87d1..e1d4e8e4 100644 --- a/packages/plugins/policy/src/expression-evaluator.ts +++ b/packages/plugins/policy/src/expression-evaluator.ts @@ -79,6 +79,11 @@ export class ExpressionEvaluator { const left = this.evaluate(expr.left, context); const right = this.evaluate(expr.right, context); + if (!['==', '!='].includes(expr.op) && (left === null || right === null)) { + // non-equality comparison with null always yields null (follow SQL logic) + return null; + } + return match(expr.op) .with('==', () => left === right) .with('!=', () => left !== right) @@ -102,7 +107,7 @@ export class ExpressionEvaluator { const left = this.evaluate(expr.left, context); if (!left) { - return false; + return null; } invariant(Array.isArray(left), 'expected array'); diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index 8036eb16..0ea84a97 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -1,5 +1,12 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { getCrudDialect, QueryUtils, type BaseCrudDialect, type ClientContract, type CRUD_EXT } from '@zenstackhq/orm'; +import { + getCrudDialect, + QueryUtils, + SchemaUtils, + type BaseCrudDialect, + type ClientContract, + type CRUD_EXT, +} from '@zenstackhq/orm'; import type { BinaryExpression, BinaryOperator, @@ -40,6 +47,7 @@ import { } from 'kysely'; import { match } from 'ts-pattern'; import { ExpressionEvaluator } from './expression-evaluator'; +import { CollectionPredicateOperator } from './types'; import { conjunction, createUnsupportedError, @@ -50,12 +58,49 @@ import { trueNode, } from './utils'; +/** + * Context for transforming a policy expression + */ export type ExpressionTransformerContext = { - model: string; + /** + * The current model or type name fields should be resolved against + */ + modelOrType: string; + + /** + * The alias name that should be used to address a model + */ alias?: string; + + /** + * The CRUD operation + */ operation: CRUD_EXT; + + /** + * In case of transforming a collection predicate's LHS, the compiled RHS filter expression + */ memberFilter?: OperationNode; + + /** + * In case of transforming a collection predicate's LHS, the field name to select as the predicate result + */ memberSelect?: SelectionNode; + + /** + * The value object that fields should be evaluated against + */ + contextValue?: Record; + + /** + * The model or type name that `this` keyword refers to + */ + thisType: string; + + /** + * The table alias name used to compile `this` keyword + */ + thisAlias?: string; }; // a registry of expression handlers marked with @expr @@ -122,7 +167,13 @@ export class ExpressionTransformer { @expr('field') private _field(expr: FieldExpression, context: ExpressionTransformerContext) { - const fieldDef = QueryUtils.requireField(this.schema, context.model, expr.field); + if (context.contextValue) { + // if we're transforming against a value object, fields should be evaluated directly + const fieldDef = QueryUtils.requireField(this.schema, context.modelOrType, expr.field); + return this.transformValue(context.contextValue[expr.field], fieldDef.type as BuiltinType); + } + + const fieldDef = QueryUtils.requireField(this.schema, context.modelOrType, expr.field); if (!fieldDef.relation) { return this.createColumnRef(expr.field, context); } else { @@ -202,35 +253,45 @@ export class ExpressionTransformer { } private transformNullCheck(expr: OperationNode, operator: BinaryOperator) { - invariant(operator === '==' || operator === '!=', 'operator must be "==" or "!=" for null comparison'); - if (ValueNode.is(expr)) { - if (expr.value === null) { - return operator === '==' ? trueNode(this.dialect) : falseNode(this.dialect); + if (operator === '==' || operator === '!=') { + // equality checks against null + if (ValueNode.is(expr)) { + if (expr.value === null) { + return operator === '==' ? trueNode(this.dialect) : falseNode(this.dialect); + } else { + return operator === '==' ? falseNode(this.dialect) : trueNode(this.dialect); + } } else { - return operator === '==' ? falseNode(this.dialect) : trueNode(this.dialect); + return operator === '==' + ? BinaryOperationNode.create(expr, OperatorNode.create('is'), ValueNode.createImmediate(null)) + : BinaryOperationNode.create(expr, OperatorNode.create('is not'), ValueNode.createImmediate(null)); } } else { - return operator === '==' - ? BinaryOperationNode.create(expr, OperatorNode.create('is'), ValueNode.createImmediate(null)) - : BinaryOperationNode.create(expr, OperatorNode.create('is not'), ValueNode.createImmediate(null)); + // otherwise any comparison with null is null + return ValueNode.createImmediate(null); } } private normalizeBinaryOperationOperands(expr: BinaryExpression, context: ExpressionTransformerContext) { + if (context.contextValue) { + // no normalization needed if evaluating against a value object + return { normalizedLeft: expr.left, normalizedRight: expr.right }; + } + // if relation fields are used directly in comparison, it can only be compared with null, // so we normalize the args with the id field (use the first id field if multiple) let normalizedLeft: Expression = expr.left; - if (this.isRelationField(expr.left, context.model)) { + if (this.isRelationField(expr.left, context.modelOrType)) { invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field'); - const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context.model); + const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context.modelOrType); invariant(leftRelDef, 'failed to get relation field definition'); const idFields = QueryUtils.requireIdFields(this.schema, leftRelDef.type); normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!); } let normalizedRight: Expression = expr.right; - if (this.isRelationField(expr.right, context.model)) { + if (this.isRelationField(expr.right, context.modelOrType)) { invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field'); - const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context.model); + const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context.modelOrType); invariant(rightRelDef, 'failed to get relation field definition'); const idFields = QueryUtils.requireIdFields(this.schema, rightRelDef.type); normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!); @@ -239,22 +300,35 @@ export class ExpressionTransformer { } private transformCollectionPredicate(expr: BinaryExpression, context: ExpressionTransformerContext) { - invariant(expr.op === '?' || expr.op === '!' || expr.op === '^', 'expected "?" or "!" or "^" operator'); + this.ensureCollectionPredicateOperator(expr.op); - if (this.isAuthCall(expr.left) || this.isAuthMember(expr.left)) { - const value = new ExpressionEvaluator().evaluate(expr, { - auth: this.auth, - }); - return this.transformValue(value, 'Boolean'); + if (this.isAuthMember(expr.left) || context.contextValue) { + invariant( + ExpressionUtils.isMember(expr.left) || ExpressionUtils.isField(expr.left), + 'expected member or field expression', + ); + + // LHS of the expression is evaluated as a value + const evaluator = new ExpressionEvaluator(); + const receiver = evaluator.evaluate(expr.left, { thisValue: context.contextValue, auth: this.auth }); + + // get LHS's type + const baseType = this.isAuthMember(expr.left) ? this.authType : context.modelOrType; + const memberType = this.getMemberType(baseType, expr.left); + + // transform the entire expression with a value LHS and the correct context type + return this.transformValueCollectionPredicate(receiver, expr, { ...context, modelOrType: memberType }); } + // otherwise, transform the expression with relation joins + invariant( ExpressionUtils.isField(expr.left) || ExpressionUtils.isMember(expr.left), 'left operand must be field or member access', ); let newContextModel: string; - const fieldDef = this.getFieldDefFromFieldRef(expr.left, context.model); + const fieldDef = this.getFieldDefFromFieldRef(expr.left, context.modelOrType); if (fieldDef) { invariant(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr.left)}`); newContextModel = fieldDef.type; @@ -263,7 +337,7 @@ export class ExpressionTransformer { ExpressionUtils.isMember(expr.left) && ExpressionUtils.isField(expr.left.receiver), 'left operand must be member access with field receiver', ); - const fieldDef = QueryUtils.requireField(this.schema, context.model, expr.left.receiver.field); + const fieldDef = QueryUtils.requireField(this.schema, context.modelOrType, expr.left.receiver.field); newContextModel = fieldDef.type; for (const member of expr.left.members) { const memberDef = QueryUtils.requireField(this.schema, newContextModel, member); @@ -273,7 +347,7 @@ export class ExpressionTransformer { let predicateFilter = this.transform(expr.right, { ...context, - model: newContextModel, + modelOrType: newContextModel, alias: undefined, }); @@ -296,10 +370,80 @@ export class ExpressionTransformer { }); } + private ensureCollectionPredicateOperator(op: BinaryOperator): asserts op is CollectionPredicateOperator { + invariant(CollectionPredicateOperator.includes(op as any), 'expected "?" or "!" or "^" operator'); + } + + private transformValueCollectionPredicate( + receiver: any, + expr: BinaryExpression, + context: ExpressionTransformerContext, + ) { + if (!receiver) { + return ValueNode.createImmediate(null); + } + + this.ensureCollectionPredicateOperator(expr.op); + + const visitor = new SchemaUtils.MatchingExpressionVisitor((e) => ExpressionUtils.isThis(e)); + if (!visitor.find(expr.right)) { + // right side only refers to the value tree, evaluate directly as an optimization + const value = new ExpressionEvaluator().evaluate(expr, { + auth: this.auth, + thisValue: context.contextValue, + }); + return this.transformValue(value, 'Boolean'); + } else { + // right side refers to `this`, need expand into a real filter + // e.g.: `auth().profiles?[age == this.age], where `this` refer to the containing model + invariant(Array.isArray(receiver), 'array value is expected'); + + // for each LHS element, transform RHS + // e.g.: `auth().profiles[age == this.age]`, each `auth().profiles` element (which is a value) + // is used to build an expression for the RHS `age == this.age` + // the transformation happens recursively for nested collection predicates + const components = receiver.map((item) => + this.transform(expr.right, { + operation: context.operation, + thisType: context.thisType, + thisAlias: context.thisAlias, + modelOrType: context.modelOrType, + contextValue: item, + }), + ); + + // compose the components based on the operator + return ( + match(expr.op) + // some + .with('?', () => disjunction(this.dialect, components)) + // every + .with('!', () => conjunction(this.dialect, components)) + // none + .with('^', () => logicalNot(this.dialect, disjunction(this.dialect, components))) + .exhaustive() + ); + } + } + + private getMemberType(receiverType: string, expr: MemberExpression | FieldExpression) { + if (ExpressionUtils.isField(expr)) { + const fieldDef = QueryUtils.requireField(this.schema, receiverType, expr.field); + return fieldDef.type; + } else { + let currType = receiverType; + for (const member of expr.members) { + const fieldDef = QueryUtils.requireField(this.schema, currType, member); + currType = fieldDef.type; + } + return currType; + } + } + private transformAuthBinary(expr: BinaryExpression, context: ExpressionTransformerContext) { if (expr.op !== '==' && expr.op !== '!=') { throw createUnsupportedError( - `Unsupported operator for \`auth()\` in policy of model "${context.model}": ${expr.op}`, + `Unsupported operator for \`auth()\` in policy of model "${context.modelOrType}": ${expr.op}`, ); } @@ -319,7 +463,7 @@ export class ExpressionTransformer { const authModel = QueryUtils.getModel(this.schema, this.authType); if (!authModel) { throw createUnsupportedError( - `Unsupported use of \`auth()\` in policy of model "${context.model}", comparing with \`auth()\` is only possible when auth type is a model`, + `Unsupported use of \`auth()\` in policy of model "${context.modelOrType}", comparing with \`auth()\` is only possible when auth type is a model`, ); } @@ -358,7 +502,13 @@ export class ExpressionTransformer { } else if (value === false) { return falseNode(this.dialect); } else { - return ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null); + const transformed = this.dialect.transformPrimitive(value, type, false) ?? null; + if (!Array.isArray(transformed)) { + // simple primitives can be immediate values + return ValueNode.createImmediate(transformed); + } else { + return ValueNode.create(transformed); + } } } @@ -396,8 +546,8 @@ export class ExpressionTransformer { { client: this.client, dialect: this.dialect, - model: context.model as GetModels, - modelAlias: context.alias ?? context.model, + model: context.modelOrType as GetModels, + modelAlias: context.alias ?? context.modelOrType, operation: context.operation, }, ); @@ -476,10 +626,16 @@ export class ExpressionTransformer { if (ExpressionUtils.isThis(expr.receiver)) { if (expr.members.length === 1) { // `this.relation` case, equivalent to field access - return this._field(ExpressionUtils.field(expr.members[0]!), context); + return this._field(ExpressionUtils.field(expr.members[0]!), { + ...context, + alias: context.thisAlias, + modelOrType: context.thisType, + thisType: context.thisType, + contextValue: undefined, + }); } else { // transform the first segment into a relation access, then continue with the rest of the members - const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.model, expr.members[0]!); + const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.thisType, expr.members[0]!); receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext); members = expr.members.slice(1); } @@ -491,11 +647,11 @@ export class ExpressionTransformer { let startType: string; if (ExpressionUtils.isField(expr.receiver)) { - const receiverField = QueryUtils.requireField(this.schema, context.model, expr.receiver.field); + const receiverField = QueryUtils.requireField(this.schema, context.modelOrType, expr.receiver.field); startType = receiverField.type; } else { - // "this." case, start type is the model of the context - startType = context.model; + // "this." case + startType = context.thisType; } // traverse forward to collect member types @@ -516,7 +672,7 @@ export class ExpressionTransformer { if (fieldDef.relation) { const relation = this.transformRelationAccess(member, fieldDef.type, { ...restContext, - model: fromModel, + modelOrType: fromModel, alias: undefined, }); @@ -554,14 +710,24 @@ export class ExpressionTransformer { return ValueNode.createImmediate(null); } - if (expr.members.length !== 1) { - throw new Error(`Only single member access is supported`); - } + invariant(expr.members.length > 0, 'member expression must have at least one member'); - const field = expr.members[0]!; - const fieldDef = QueryUtils.requireField(this.schema, receiverType, field); - const fieldValue = receiver[field] ?? null; - return this.transformValue(fieldValue, fieldDef.type as BuiltinType); + let curr: any = receiver; + let currType = receiverType; + for (let i = 0; i < expr.members.length; i++) { + const field = expr.members[i]!; + curr = curr?.[field]; + if (curr === undefined) { + curr = ValueNode.createImmediate(null); + break; + } + currType = QueryUtils.requireField(this.schema, currType, field).type; + if (i === expr.members.length - 1) { + // last segment (which is the value), make sure it's transformed + curr = this.transformValue(curr, currType as BuiltinType); + } + } + return curr; } private transformRelationAccess( @@ -569,12 +735,12 @@ export class ExpressionTransformer { relationModel: string, context: ExpressionTransformerContext, ): SelectQueryNode { - const m2m = QueryUtils.getManyToManyRelation(this.schema, context.model, field); + const m2m = QueryUtils.getManyToManyRelation(this.schema, context.modelOrType, field); if (m2m) { return this.transformManyToManyRelationAccess(m2m, context); } - const fromModel = context.model; + const fromModel = context.modelOrType; const relationFieldDef = QueryUtils.requireField(this.schema, fromModel, field); const { keyPairs, ownedByModel } = QueryUtils.getRelationForeignKeyFieldPairs(this.schema, fromModel, field); @@ -641,7 +807,7 @@ export class ExpressionTransformer { .onRef( `${m2m.joinTable}.${m2m.parentFkName}`, '=', - `${context.alias ?? context.model}.${m2m.parentPKName}`, + `${context.alias ?? context.modelOrType}.${m2m.parentPKName}`, ), ); return relationQuery.toOperationNode(); @@ -651,7 +817,7 @@ export class ExpressionTransformer { // if field comes from a delegate base model, we need to use the join alias // of that base model - const tableName = context.alias ?? context.model; + const tableName = context.alias ?? context.modelOrType; // "create" policies evaluate table from "VALUES" node so no join from delegate bases are // created and thus we should directly use the model table name @@ -659,12 +825,12 @@ export class ExpressionTransformer { return ReferenceNode.create(ColumnNode.create(column), TableNode.create(tableName)); } - const fieldDef = QueryUtils.requireField(this.schema, context.model, column); - if (!fieldDef.originModel || fieldDef.originModel === context.model) { + const fieldDef = QueryUtils.requireField(this.schema, context.modelOrType, column); + if (!fieldDef.originModel || fieldDef.originModel === context.modelOrType) { return ReferenceNode.create(ColumnNode.create(column), TableNode.create(tableName)); } - return this.buildDelegateBaseFieldSelect(context.model, tableName, column, fieldDef.originModel); + return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel); } private buildDelegateBaseFieldSelect(model: string, modelAlias: string, field: string, baseModel: string) { @@ -723,13 +889,13 @@ export class ExpressionTransformer { private getFieldDefFromFieldRef(expr: Expression, model: string): FieldDef | undefined { if (ExpressionUtils.isField(expr)) { - return QueryUtils.requireField(this.schema, model, expr.field); + return QueryUtils.getField(this.schema, model, expr.field); } else if ( ExpressionUtils.isMember(expr) && expr.members.length === 1 && ExpressionUtils.isThis(expr.receiver) ) { - return QueryUtils.requireField(this.schema, model, expr.members[0]!); + return QueryUtils.getField(this.schema, model, expr.members[0]!); } else { return undefined; } diff --git a/packages/plugins/policy/src/policy-handler.ts b/packages/plugins/policy/src/policy-handler.ts index e1c24d6c..eeb4a8b8 100644 --- a/packages/plugins/policy/src/policy-handler.ts +++ b/packages/plugins/policy/src/policy-handler.ts @@ -881,7 +881,9 @@ export class PolicyHandler extends OperationNodeTransf private compilePolicyCondition(model: string, alias: string | undefined, operation: CRUD_EXT, policy: Policy) { return new ExpressionTransformer(this.client).transform(policy.condition, { - model, + modelOrType: model, + thisType: model, // type name for `this`, never changed during the entire transformation + thisAlias: alias, // alias for `this`, never changed during the entire transformation alias, operation, }); diff --git a/packages/plugins/policy/src/types.ts b/packages/plugins/policy/src/types.ts index dc042aa9..8f2f635b 100644 --- a/packages/plugins/policy/src/types.ts +++ b/packages/plugins/policy/src/types.ts @@ -19,3 +19,13 @@ export type Policy = { operations: readonly PolicyOperation[]; condition: Expression; }; + +/** + * Operators allowed for collection predicate expressions. + */ +export const CollectionPredicateOperator = ['?', '!', '^'] as const; + +/** + * Operators allowed for collection predicate expressions. + */ +export type CollectionPredicateOperator = (typeof CollectionPredicateOperator)[number]; diff --git a/packages/sdk/src/model-utils.ts b/packages/sdk/src/model-utils.ts index c9102c22..2473f284 100644 --- a/packages/sdk/src/model-utils.ts +++ b/packages/sdk/src/model-utils.ts @@ -106,7 +106,7 @@ export function getAuthDecl(model: Model) { (d) => (isDataModel(d) || isTypeDef(d)) && d.attributes.some((attr) => attr.decl.$refText === '@@auth'), ); if (!found) { - found = model.declarations.find((d) => isDataModel(d) && d.name === 'User'); + found = model.declarations.find((d) => (isDataModel(d) || isTypeDef(d)) && d.name === 'User'); } return found; } diff --git a/tests/e2e/apps/rally/rally.test.ts b/tests/e2e/apps/rally/rally.test.ts index e14b8798..5b204275 100644 --- a/tests/e2e/apps/rally/rally.test.ts +++ b/tests/e2e/apps/rally/rally.test.ts @@ -17,7 +17,6 @@ describe('Rally app tests', () => { destination: 'models', }, ], - debug: true, dataSourceExtensions: ['citext'], usePrismaPush: true, }); diff --git a/tests/e2e/orm/policy/auth-access.test.ts b/tests/e2e/orm/policy/auth-access.test.ts new file mode 100644 index 00000000..b994324f --- /dev/null +++ b/tests/e2e/orm/policy/auth-access.test.ts @@ -0,0 +1,436 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Auth access tests', () => { + it('works with simple auth model', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + age Int +} + +model Foo { + id Int @id + name String + @@allow('all', auth().age > 18) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, name: 'Test' } }); + await expect(db.foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ age: 15 }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ age: 20 }).foo.findFirst()).toResolveTruthy(); + }); + + it('works with simple auth type', async () => { + const db = await createPolicyTestClient( + ` +type User { + age Int + @@auth +} + +model Foo { + id Int @id + name String + @@allow('all', auth().age > 18) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, name: 'Test' } }); + await expect(db.foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ age: 15 }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ age: 20 }).foo.findFirst()).toResolveTruthy(); + }); + + it('works with deep model value access', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profile Profile? +} + +model Profile { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int +} + +model Foo { + id Int @id + name String + @@allow('all', auth().profile.age > 18) +}`, + ); + await db.$unuseAll().foo.create({ data: { id: 1, name: 'Test' } }); + await expect(db.foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1 }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, profile: { age: 15 } }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ profile: { age: 20 } }).foo.findFirst()).toResolveTruthy(); + }); + + it('works with deep type value access', async () => { + const db = await createPolicyTestClient( + ` +type User { + profile Profile? + @@auth +} + +type Profile { + age Int +} + +model Foo { + id Int @id + name String + @@allow('all', auth().profile.age > 18) +}`, + ); + await db.$unuseAll().foo.create({ data: { id: 1, name: 'Test' } }); + await expect(db.foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({}).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ profile: { age: 15 } }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ profile: { age: 20 } }).foo.findFirst()).toResolveTruthy(); + }); + + it('works with shallow auth model simple collection predicates', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profiles Profile[] +} + +model Profile { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + age Int +} + +model Foo { + id Int @id + name String + @@allow('all', auth().profiles?[age > 18]) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, name: 'Test' } }); + await expect(db.foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1 }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, profiles: [{ age: 15 }] }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ profiles: [{ age: 20 }] }).foo.findFirst()).toResolveTruthy(); + await expect(db.$setAuth({ profiles: [{ age: 15 }, { age: 20 }] }).foo.findFirst()).toResolveTruthy(); + }); + + it('works with shallow auth model collection predicates involving fields - some', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profiles Profile[] +} + +model Profile { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + age Int +} + +model Foo { + id Int @id + name String + requiredAge Int + @@allow('all', auth().profiles?[age >= this.requiredAge]) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, name: 'Test', requiredAge: 18 } }); + await expect(db.foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1 }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, profiles: [] }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, profiles: [{ age: 15 }] }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ profiles: [{ age: 20 }] }).foo.findFirst()).toResolveTruthy(); + await expect(db.$setAuth({ profiles: [{ age: 15 }, { age: 20 }] }).foo.findFirst()).toResolveTruthy(); + }); + + it('works with deep auth model simple collection predicates', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profiles Profile[] +} + +model Profile { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + records ProfileRecord[] +} + +model ProfileRecord { + id Int @id + age Int + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int +} + +model Foo { + id Int @id + name String + @@allow('all', auth().profiles?[records?[age > 18]]) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, name: 'Test' } }); + await expect(db.foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1 }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, profiles: [] }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, profiles: [{ records: [] }] }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, profiles: [{ records: [{ age: 15 }] }] }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ profiles: [{ records: [{ age: 20 }] }] }).foo.findFirst()).toResolveTruthy(); + await expect( + db.$setAuth({ profiles: [{ records: [{ age: 15 }] }, { records: [{ age: 20 }] }] }).foo.findFirst(), + ).toResolveTruthy(); + await expect( + db.$setAuth({ profiles: [{ records: [{ age: 15 }, { age: 20 }] }] }).foo.findFirst(), + ).toResolveTruthy(); + }); + + it('works with shallow auth type collection predicates involving fields - some', async () => { + const db = await createPolicyTestClient( + ` +type User { + profiles Profile[] +} + +type Profile { + age Int +} + +model Foo { + id Int @id + name String + requiredAge Int + @@allow('all', auth().profiles?[age >= this.requiredAge]) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, name: 'Test', requiredAge: 18 } }); + await expect(db.foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1 }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, profiles: [] }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, profiles: [{ age: 15 }] }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ profiles: [{ age: 20 }] }).foo.findFirst()).toResolveTruthy(); + await expect(db.$setAuth({ profiles: [{ age: 15 }, { age: 20 }] }).foo.findFirst()).toResolveTruthy(); + }); + + it('works with shallow auth model collection predicates involving fields - every', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profiles Profile[] +} + +model Profile { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + age Int +} + +model Foo { + id Int @id + name String + requiredAge Int + @@allow('all', auth().profiles![age >= this.requiredAge]) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, name: 'Test', requiredAge: 18 } }); + await expect(db.foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1 }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, profiles: [] }).foo.findFirst()).toResolveTruthy(); + await expect(db.$setAuth({ id: 1, profiles: [{ age: 15 }] }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ profiles: [{ age: 18 }, { age: 20 }] }).foo.findFirst()).toResolveTruthy(); + await expect(db.$setAuth({ profiles: [{ age: 15 }, { age: 20 }] }).foo.findFirst()).toResolveFalsy(); + }); + + it('works with shallow auth model collection predicates involving fields - none', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + profiles Profile[] +} + +model Profile { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + age Int +} + +model Foo { + id Int @id + name String + requiredAge Int + @@allow('all', auth().profiles^[age >= this.requiredAge]) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, name: 'Test', requiredAge: 18 } }); + await expect(db.foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1 }).foo.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, profiles: [] }).foo.findFirst()).toResolveTruthy(); + await expect(db.$setAuth({ id: 1, profiles: [{ age: 15 }] }).foo.findFirst()).toResolveTruthy(); + await expect(db.$setAuth({ profiles: [{ age: 15 }, { age: 18 }] }).foo.findFirst()).toResolveNull(); + }); + + it('works with deep auth model collection predicates involving fields', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + roles Role[] +} + +model Role { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + permissions Permission[] +} + +model Permission { + id Int @id + role Role @relation(fields: [roleId], references: [id]) + roleId Int + canReadTypes String[] +} + +model Post { + id Int @id + type String + @@allow('all', auth().roles?[permissions![this.type in canReadTypes]] ) +} +`, + { provider: 'postgresql' }, + ); + + await db.$unuseAll().post.create({ data: { id: 1, type: 'News' } }); + + await expect(db.post.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1 }).post.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, roles: [] }).post.findFirst()).toResolveNull(); + await expect(db.$setAuth({ id: 1, roles: [{ permissions: [] }] }).post.findFirst()).toResolveTruthy(); + await expect( + db.$setAuth({ id: 1, roles: [{ permissions: [{ canReadTypes: [] }] }] }).post.findFirst(), + ).toResolveNull(); + await expect( + db.$setAuth({ id: 1, roles: [{ permissions: [{ canReadTypes: ['News'] }] }] }).post.findFirst(), + ).toResolveTruthy(); + await expect( + db.$setAuth({ roles: [{ permissions: [{ canReadTypes: ['Blog'] }] }] }).post.findFirst(), + ).toResolveNull(); + await expect( + db.$setAuth({ roles: [{ permissions: [{ canReadTypes: ['Blog', 'News'] }] }] }).post.findFirst(), + ).toResolveTruthy(); + await expect( + db + .$setAuth({ roles: [{ permissions: [{ canReadTypes: ['Blog'] }, { canReadTypes: ['News'] }] }] }) + .post.findFirst(), + ).toResolveNull(); + await expect( + db + .$setAuth({ + roles: [ + { permissions: [{ canReadTypes: ['Blog'] }] }, + { permissions: [{ canReadTypes: ['News', 'Story'] }, { canReadTypes: ['Weather'] }] }, + ], + }) + .post.findFirst(), + ).toResolveNull(); + await expect( + db + .$setAuth({ + roles: [{ permissions: [{ canReadTypes: ['Blog', 'News'] }, { canReadTypes: ['News'] }] }], + }) + .post.findFirst(), + ).toResolveTruthy(); + }); + + it('works with regression1', async () => { + const schema = ` +model User { + id Int @id @default(autoincrement()) + permissions Permission[] +} + +model Permission { + id Int @id @default(autoincrement()) + name String + canUpdateChannelById Int[] + user User @relation(fields: [userId], references: [id]) + userId Int +} + +model Channel { + id Int @id @default(autoincrement()) + name String + + @@allow('create,read', true) + @@allow('update', auth().permissions?[this.id in canUpdateChannelById]) +} +`; + + const db = await createPolicyTestClient(schema, { provider: 'postgresql' }); + + await db.channel.create({ data: { id: 1, name: 'general' } }); + + await expect(db.channel.update({ where: { id: 1 }, data: { name: 'general-updated' } })).toBeRejectedNotFound(); + + const userDb1 = db.$setAuth({ + id: 3, + permissions: [ + { + id: 3, + name: 'update-general', + canUpdateChannelById: [2], + }, + ], + }); + await expect( + userDb1.channel.update({ where: { id: 1 }, data: { name: 'general-updated' } }), + ).toBeRejectedNotFound(); + + const userDb2 = db.$setAuth({ + id: 3, + permissions: [ + { + id: 3, + name: 'update-general', + canUpdateChannelById: [1], + }, + ], + }); + await expect( + userDb2.channel.update({ where: { id: 1 }, data: { name: 'general-updated' } }), + ).resolves.toBeTruthy(); + }); +});