Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastapi_cloudauth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .auth0 import Auth0, Auth0CurrentUser
from .cognito import Cognito, CognitoCurrentUser
from .firebase import FirebaseCurrentUser
from .scope import AdvancedScope
3 changes: 2 additions & 1 deletion fastapi_cloudauth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ScopedJWKsVerifier,
Verifier,
)
from fastapi_cloudauth.scope import AdvancedScope


class CloudAuth(ABC):
Expand Down Expand Up @@ -176,7 +177,7 @@ def __init__(
self,
jwks: JWKS,
user_info: Optional[Type[BaseModel]] = None,
scope_name: Optional[str] = None,
scope_name: Union[Optional[str], Optional[AdvancedScope]] = None,
scope_key: Optional[str] = None,
auto_error: bool = True,
):
Expand Down
15 changes: 15 additions & 0 deletions fastapi_cloudauth/scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pydantic
from typing import List

VALID_COMPERATOR = ["any", "all"]


class AdvancedScope(pydantic.BaseModel):
comperator: str
scopes: List[str]

@pydantic.validator("comperator")
def valid_coperator(cls, value):
if value not in VALID_COMPERATOR:
raise ValueError("Coperator mus be one of '{VALID_COMPERATOR}'")
return value
13 changes: 11 additions & 2 deletions fastapi_cloudauth/verification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import requests
from fastapi import HTTPException
Expand All @@ -17,6 +17,7 @@
NOT_VERIFIED,
SCOPE_NOT_MATCHED,
)
from fastapi_cloudauth.scope import AdvancedScope


class Verifier(ABC):
Expand Down Expand Up @@ -135,7 +136,7 @@ class ScopedJWKsVerifier(JWKsVerifier):
def __init__(
self,
jwks: JWKS,
scope_name: Optional[str] = None,
scope_name: Union[Optional[str], Optional[AdvancedScope]] = None,
scope_key: Optional[str] = None,
auto_error: bool = True,
*args: Any,
Expand All @@ -159,6 +160,14 @@ def _verify_scope(self, http_auth: HTTPAuthorizationCredentials) -> bool:
scopes = claims.get(self.scope_key)
if isinstance(scopes, str):
scopes = {scope.strip() for scope in scopes.split()}
if isinstance(self.scope_name, AdvancedScope) and scopes:
match = set(self.scope_name.scopes) & set(scopes)
if self.scope_name.comperator == "any" and len(match) > 0:
return True
elif self.scope_name.comperator == "all" and len(match) >= len(self.scope_name.scopes):
return True
else:
return False
if scopes is None or self.scope_name not in scopes:
if self.auto_error:
raise HTTPException(
Expand Down
6 changes: 6 additions & 0 deletions tests/test_cloudauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def test_invalid_scope(self):
self.failure_case("/scope/", self.ACCESS_TOKEN, detail=SCOPE_NOT_MATCHED)
self.success_case("/scope/no-error/", self.ACCESS_TOKEN)

def test_valid_scope_advanced(self):
self.success_case("/scope_advanced/", self.SCOPE_ACCESS_TOKEN)

def test_invalid_scope_advanced(self):
self.failure_case("/scope_advanced/", self.ACCESS_TOKEN, detail=SCOPE_NOT_MATCHED)

def test_valid_token_extraction(self):
self.userinfo_success_case("/access/user", self.ACCESS_TOKEN)

Expand Down
5 changes: 5 additions & 0 deletions tests/test_cognito.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic.main import BaseModel

from fastapi_cloudauth import Cognito, CognitoCurrentUser
from fastapi_cloudauth import AdvancedScope
from fastapi_cloudauth.cognito import CognitoClaims
from tests.helpers import BaseTestCloudAuth, decode_token

Expand Down Expand Up @@ -172,6 +173,10 @@ async def invalid_access_user_no_error(
async def secure_scope() -> bool:
pass

@app.get("/scope_advanced/", dependencies=[Depends(auth.scope(AdvancedScope("all", [self.scope])))])
async def secure_scope_advanced() -> bool:
pass

@app.get("/scope/no-error/")
async def secure_scope_no_error(
payload=Depends(auth_no_error.scope(self.scope)),
Expand Down