1+ from abc import ABC , abstractmethod
2+
3+ import jwt
4+ from datetime import datetime , timezone
5+
6+ from redis .auth .err import InvalidTokenSchemaErr
7+
8+
9+ class TokenInterface (ABC ):
10+ @abstractmethod
11+ def is_expired (self ) -> bool :
12+ pass
13+
14+ @abstractmethod
15+ def ttl (self ) -> float :
16+ pass
17+
18+ @abstractmethod
19+ def try_get (self , key : str ) -> str :
20+ pass
21+
22+ @abstractmethod
23+ def get_value (self ) -> str :
24+ pass
25+
26+ @abstractmethod
27+ def get_expires_at_ms (self ) -> float :
28+ pass
29+
30+ @abstractmethod
31+ def get_received_at_ms (self ) -> float :
32+ pass
33+
34+
35+ class TokenResponse :
36+ def __init__ (self , token : TokenInterface ):
37+ self ._token = token
38+
39+ def get_token (self ) -> TokenInterface :
40+ return self ._token
41+
42+ def get_ttl_ms (self ) -> float :
43+ return self ._token .get_expires_at_ms () - self ._token .get_received_at_ms ()
44+
45+
46+ class SimpleToken (TokenInterface ):
47+ def __init__ (self , value : str , expires_at_ms : float , received_at_ms : float , claims : dict ) -> None :
48+ self .value = value
49+ self .expires_at = expires_at_ms
50+ self .received_at = received_at_ms
51+ self .claims = claims
52+
53+ def ttl (self ) -> float :
54+ if self .expires_at == - 1 :
55+ return - 1
56+
57+ return self .expires_at - (datetime .now (timezone .utc ).timestamp () * 1000 )
58+
59+ def is_expired (self ) -> bool :
60+ if self .expires_at == - 1 :
61+ return False
62+
63+ return self .ttl () <= 0
64+
65+ def try_get (self , key : str ) -> str :
66+ return self .claims .get (key )
67+
68+ def get_value (self ) -> str :
69+ return self .value
70+
71+ def get_expires_at_ms (self ) -> float :
72+ return self .expires_at
73+
74+ def get_received_at_ms (self ) -> float :
75+ return self .received_at
76+
77+
78+ class JWToken (TokenInterface ):
79+
80+ REQUIRED_FIELDS = {'exp' }
81+
82+ def __init__ (self , token : str ):
83+ self ._value = token
84+ self ._decoded = jwt .decode (
85+ self ._value ,
86+ options = {"verify_signature" : False },
87+ algorithms = [jwt .get_unverified_header (self ._value ).get ('alg' )]
88+ )
89+ self ._validate_token ()
90+
91+ def is_expired (self ) -> bool :
92+ exp = self ._decoded ['exp' ]
93+ if exp == - 1 :
94+ return False
95+
96+ return self ._decoded ['exp' ] * 1000 <= datetime .now (timezone .utc ).timestamp () * 1000
97+
98+ def ttl (self ) -> float :
99+ exp = self ._decoded ['exp' ]
100+ if exp == - 1 :
101+ return - 1
102+
103+ return self ._decoded ['exp' ] * 1000 - datetime .now (timezone .utc ).timestamp () * 1000
104+
105+ def try_get (self , key : str ) -> str :
106+ return self ._decoded .get (key )
107+
108+ def get_value (self ) -> str :
109+ return self ._value
110+
111+ def get_expires_at_ms (self ) -> float :
112+ return float (self ._decoded ['exp' ] * 1000 )
113+
114+ def get_received_at_ms (self ) -> float :
115+ return datetime .now (timezone .utc ).timestamp () * 1000
116+
117+ def _validate_token (self ):
118+ actual_fields = {x for x in self ._decoded .keys ()}
119+
120+ if len (self .REQUIRED_FIELDS - actual_fields ) != 0 :
121+ raise InvalidTokenSchemaErr (self .REQUIRED_FIELDS - actual_fields )
0 commit comments