Skip to content

Commit 32fc374

Browse files
committed
Added core functionality
1 parent 6dae71b commit 32fc374

File tree

8 files changed

+1139
-0
lines changed

8 files changed

+1139
-0
lines changed

redis/auth/__init__.py

Whitespace-only changes.

redis/auth/err.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Iterable
2+
3+
4+
class RequestTokenErr(Exception):
5+
"""
6+
Represents an exception during token request.
7+
"""
8+
def __init__(self, *args):
9+
super().__init__(*args)
10+
11+
12+
class InvalidTokenSchemaErr(Exception):
13+
"""
14+
Represents an exception related to invalid token schema.
15+
"""
16+
def __init__(self, missing_fields: Iterable[str] = []):
17+
super().__init__(
18+
"Unexpected token schema. Following fields are missing: " + ", ".join(missing_fields)
19+
)
20+
21+
22+
class TokenRenewalErr(Exception):
23+
"""
24+
Represents an exception during token renewal process.
25+
"""
26+
def __init__(self, *args):
27+
super().__init__(*args)

redis/auth/idp.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from abc import ABC, abstractmethod
2+
from redis.auth.token import TokenInterface
3+
4+
"""
5+
This interface is the facade of an identity provider
6+
"""
7+
8+
9+
class IdentityProviderInterface(ABC):
10+
"""
11+
Receive a token from the identity provider. Receiving a token only works when being authenticated.
12+
"""
13+
14+
@abstractmethod
15+
def request_token(self, force_refresh=False) -> TokenInterface:
16+
pass
17+
18+
19+
class IdentityProviderConfigInterface(ABC):
20+
"""
21+
Configuration class that provides a configured identity provider.
22+
"""
23+
@abstractmethod
24+
def get_provider(self) -> IdentityProviderInterface:
25+
pass

redis/auth/token.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from abc import ABC, abstractmethod
2+
3+
import jwt
4+
from datetime import datetime, timezone
5+
6+
from redis.auth.err import InvalidTokenSchemaErr
7+
8+
9+
class TokenInterface(ABC):
10+
@abstractmethod
11+
def is_expired(self) -> bool:
12+
pass
13+
14+
@abstractmethod
15+
def ttl(self) -> float:
16+
pass
17+
18+
@abstractmethod
19+
def try_get(self, key: str) -> str:
20+
pass
21+
22+
@abstractmethod
23+
def get_value(self) -> str:
24+
pass
25+
26+
@abstractmethod
27+
def get_expires_at_ms(self) -> float:
28+
pass
29+
30+
@abstractmethod
31+
def get_received_at_ms(self) -> float:
32+
pass
33+
34+
35+
class TokenResponse:
36+
def __init__(self, token: TokenInterface):
37+
self._token = token
38+
39+
def get_token(self) -> TokenInterface:
40+
return self._token
41+
42+
def get_ttl_ms(self) -> float:
43+
return self._token.get_expires_at_ms() - self._token.get_received_at_ms()
44+
45+
46+
class SimpleToken(TokenInterface):
47+
def __init__(self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict) -> None:
48+
self.value = value
49+
self.expires_at = expires_at_ms
50+
self.received_at = received_at_ms
51+
self.claims = claims
52+
53+
def ttl(self) -> float:
54+
if self.expires_at == -1:
55+
return -1
56+
57+
return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000)
58+
59+
def is_expired(self) -> bool:
60+
if self.expires_at == -1:
61+
return False
62+
63+
return self.ttl() <= 0
64+
65+
def try_get(self, key: str) -> str:
66+
return self.claims.get(key)
67+
68+
def get_value(self) -> str:
69+
return self.value
70+
71+
def get_expires_at_ms(self) -> float:
72+
return self.expires_at
73+
74+
def get_received_at_ms(self) -> float:
75+
return self.received_at
76+
77+
78+
class JWToken(TokenInterface):
79+
80+
REQUIRED_FIELDS = {'exp'}
81+
82+
def __init__(self, token: str):
83+
self._value = token
84+
self._decoded = jwt.decode(
85+
self._value,
86+
options={"verify_signature": False},
87+
algorithms=[jwt.get_unverified_header(self._value).get('alg')]
88+
)
89+
self._validate_token()
90+
91+
def is_expired(self) -> bool:
92+
exp = self._decoded['exp']
93+
if exp == -1:
94+
return False
95+
96+
return self._decoded['exp'] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000
97+
98+
def ttl(self) -> float:
99+
exp = self._decoded['exp']
100+
if exp == -1:
101+
return -1
102+
103+
return self._decoded['exp'] * 1000 - datetime.now(timezone.utc).timestamp() * 1000
104+
105+
def try_get(self, key: str) -> str:
106+
return self._decoded.get(key)
107+
108+
def get_value(self) -> str:
109+
return self._value
110+
111+
def get_expires_at_ms(self) -> float:
112+
return float(self._decoded['exp'] * 1000)
113+
114+
def get_received_at_ms(self) -> float:
115+
return datetime.now(timezone.utc).timestamp() * 1000
116+
117+
def _validate_token(self):
118+
actual_fields = {x for x in self._decoded.keys()}
119+
120+
if len(self.REQUIRED_FIELDS - actual_fields) != 0:
121+
raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields)

0 commit comments

Comments
 (0)