diff --git a/strawberry_django/permissions.py b/strawberry_django/permissions.py index e97595a1..c34e13b4 100644 --- a/strawberry_django/permissions.py +++ b/strawberry_django/permissions.py @@ -53,6 +53,92 @@ _M = TypeVar("_M", bound=Model) +# Borrowed from the DRF project +class OperationHolderMixin: + def __and__(self, other): + return OperandHolder(AND, self, other) + + def __or__(self, other): + return OperandHolder(OR, self, other) + + def __rand__(self, other): + return OperandHolder(AND, other, self) + + def __ror__(self, other): + return OperandHolder(OR, other, self) + + def __invert__(self): + return SingleOperandHolder(NOT, self) + + +class SingleOperandHolder(OperationHolderMixin): + def __init__(self, operator_class, op1_class): + self.operator_class = operator_class + self.op1_class = op1_class + + def __call__(self, *args, **kwargs): + op1 = self.op1_class(*args, **kwargs) + return self.operator_class(op1) + + +class OperandHolder(OperationHolderMixin): + def __init__(self, operator_class, op1_class, op2_class): + self.operator_class = operator_class + self.op1_class = op1_class + self.op2_class = op2_class + + def __call__(self, *args, **kwargs): + op1 = self.op1_class(*args, **kwargs) + op2 = self.op2_class(*args, **kwargs) + return self.operator_class(op1, op2) + + def __eq__(self, other): + return ( + isinstance(other, OperandHolder) + and self.operator_class == other.operator_class + and self.op1_class == other.op1_class + and self.op2_class == other.op2_class + ) + + def __hash__(self): + return hash((self.operator_class, self.op1_class, self.op2_class)) + + +class AND: + def __init__(self, op1, op2): + self.op1 = op1 + self.op2 = op2 + + def has_permission(self, user: UserType) -> bool: + return self.op1.has_permission(user=user) and self.op2.has_permission(user=user) + + +class OR: + def __init__(self, op1, op2): + self.op1 = op1 + self.op2 = op2 + + def has_permission(self, user: UserType) -> bool: + return self.op1.has_permission(user=user) or self.op2.has_permission(user=user) + + +class NOT: + def __init__(self, op1): + self.op1 = op1 + + def has_permission(self, user: UserType): + return not self.op1.has_permission(user=user) + + +class BasePermissionMetaclass(OperationHolderMixin, type): + pass + + +class BasePermission(metaclass=BasePermissionMetaclass): + def has_permission(self, user: UserType) -> bool: + return True + + @functools.lru_cache def _get_user_or_anonymous_getter() -> Optional[Callable[[UserType], UserType]]: try: @@ -930,3 +1016,33 @@ class HasRetvalPerm(HasPerm): "Will check if the user has any/all permissions for the resolved " "value of this field before returning it.", ) + + +class HasPermissionClasses(DjangoPermissionExtension): + def __init__( + self, + *args, + permission_classes: Iterable[ + type[BasePermission | SingleOperandHolder | OperandHolder] + ], + **kwargs, + ): + super().__init__(*args, **kwargs) + self.permission_classes = permission_classes + + def resolve_for_user( # pragma: no cover + self, + resolver: Callable, + user: UserType | None, + *, + info: Info, + source: Any, + ) -> AwaitableOrValue[Any]: + if not user: + raise DjangoNoPermission + + permissions = [permission() for permission in self.permission_classes] + for permission in permissions: + if not permission.has_permission(user=user): + raise DjangoNoPermission + return resolver()