17
17
import re
18
18
import threading
19
19
import webbrowser
20
- from typing import Callable , List , Optional
20
+ from typing import Any , Callable , Dict , List , Optional , Tuple
21
21
from urllib .parse import urlparse
22
22
23
- from requests import Request
23
+ from requests import PreparedRequest , Request , Response , Session
24
24
from requests .auth import AuthBase , extract_cookies_to_jar
25
25
from requests .utils import parse_dict_header
26
26
32
32
33
33
class Authentication (metaclass = abc .ABCMeta ):
34
34
@abc .abstractmethod
35
- def set_http_session (self , http_session ) :
35
+ def set_http_session (self , http_session : Session ) -> Session :
36
36
pass
37
37
38
- def get_exceptions (self ):
38
+ def get_exceptions (self ) -> Tuple [ Any , ...] :
39
39
return tuple ()
40
40
41
41
42
42
class KerberosAuthentication (Authentication ):
43
43
def __init__ (
44
44
self ,
45
45
config : Optional [str ] = None ,
46
- service_name : str = None ,
46
+ service_name : Optional [ str ] = None ,
47
47
mutual_authentication : bool = False ,
48
48
force_preemptive : bool = False ,
49
49
hostname_override : Optional [str ] = None ,
@@ -62,7 +62,7 @@ def __init__(
62
62
self ._delegate = delegate
63
63
self ._ca_bundle = ca_bundle
64
64
65
- def set_http_session (self , http_session ) :
65
+ def set_http_session (self , http_session : Session ) -> Session :
66
66
try :
67
67
import requests_kerberos
68
68
except ImportError :
@@ -84,15 +84,15 @@ def set_http_session(self, http_session):
84
84
http_session .verify = self ._ca_bundle
85
85
return http_session
86
86
87
- def get_exceptions (self ):
87
+ def get_exceptions (self ) -> Tuple [ Any , ...] :
88
88
try :
89
89
from requests_kerberos .exceptions import KerberosExchangeError
90
90
91
- return ( KerberosExchangeError ,)
91
+ return KerberosExchangeError ,
92
92
except ImportError :
93
93
raise RuntimeError ("unable to import requests_kerberos" )
94
94
95
- def __eq__ (self , other ) :
95
+ def __eq__ (self , other : object ) -> bool :
96
96
if not isinstance (other , KerberosAuthentication ):
97
97
return False
98
98
return (self ._config == other ._config
@@ -107,11 +107,11 @@ def __eq__(self, other):
107
107
108
108
109
109
class BasicAuthentication (Authentication ):
110
- def __init__ (self , username , password ):
110
+ def __init__ (self , username : str , password : str ):
111
111
self ._username = username
112
112
self ._password = password
113
113
114
- def set_http_session (self , http_session ) :
114
+ def set_http_session (self , http_session : Session ) -> Session :
115
115
try :
116
116
import requests .auth
117
117
except ImportError :
@@ -120,10 +120,10 @@ def set_http_session(self, http_session):
120
120
http_session .auth = requests .auth .HTTPBasicAuth (self ._username , self ._password )
121
121
return http_session
122
122
123
- def get_exceptions (self ):
123
+ def get_exceptions (self ) -> Tuple [ Any , ...] :
124
124
return ()
125
125
126
- def __eq__ (self , other ) :
126
+ def __eq__ (self , other : object ) -> bool :
127
127
if not isinstance (other , BasicAuthentication ):
128
128
return False
129
129
return self ._username == other ._username and self ._password == other ._password
@@ -134,27 +134,27 @@ class _BearerAuth(AuthBase):
134
134
Custom implementation of Authentication class for bearer token
135
135
"""
136
136
137
- def __init__ (self , token ):
137
+ def __init__ (self , token : str ):
138
138
self .token = token
139
139
140
- def __call__ (self , r ) :
140
+ def __call__ (self , r : PreparedRequest ) -> PreparedRequest :
141
141
r .headers ["Authorization" ] = "Bearer " + self .token
142
142
return r
143
143
144
144
145
145
class JWTAuthentication (Authentication ):
146
146
147
- def __init__ (self , token ):
147
+ def __init__ (self , token : str ):
148
148
self .token = token
149
149
150
- def set_http_session (self , http_session ) :
150
+ def set_http_session (self , http_session : Session ) -> Session :
151
151
http_session .auth = _BearerAuth (self .token )
152
152
return http_session
153
153
154
- def get_exceptions (self ):
154
+ def get_exceptions (self ) -> Tuple [ Any , ...] :
155
155
return ()
156
156
157
- def __eq__ (self , other ) :
157
+ def __eq__ (self , other : object ) -> bool :
158
158
if not isinstance (other , JWTAuthentication ):
159
159
return False
160
160
return self .token == other .token
@@ -197,7 +197,7 @@ class CompositeRedirectHandler(RedirectHandler):
197
197
def __init__ (self , handlers : List [Callable [[str ], None ]]):
198
198
self .handlers = handlers
199
199
200
- def __call__ (self , url : str ):
200
+ def __call__ (self , url : str ) -> None :
201
201
for handler in self .handlers :
202
202
handler (url )
203
203
@@ -208,11 +208,11 @@ class _OAuth2TokenCache(metaclass=abc.ABCMeta):
208
208
"""
209
209
210
210
@abc .abstractmethod
211
- def get_token_from_cache (self , host : str ) -> Optional [str ]:
211
+ def get_token_from_cache (self , host : Optional [ str ] ) -> Optional [str ]:
212
212
pass
213
213
214
214
@abc .abstractmethod
215
- def store_token_to_cache (self , host : str , token : str ) -> None :
215
+ def store_token_to_cache (self , host : Optional [ str ] , token : str ) -> None :
216
216
pass
217
217
218
218
@@ -221,13 +221,13 @@ class _OAuth2TokenInMemoryCache(_OAuth2TokenCache):
221
221
In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache.
222
222
"""
223
223
224
- def __init__ (self ):
225
- self ._cache = {}
224
+ def __init__ (self ) -> None :
225
+ self ._cache : Dict [ Optional [ str ], str ] = {}
226
226
227
- def get_token_from_cache (self , host : str ) -> Optional [str ]:
227
+ def get_token_from_cache (self , host : Optional [ str ] ) -> Optional [str ]:
228
228
return self ._cache .get (host )
229
229
230
- def store_token_to_cache (self , host : str , token : str ) -> None :
230
+ def store_token_to_cache (self , host : Optional [ str ] , token : str ) -> None :
231
231
self ._cache [host ] = token
232
232
233
233
@@ -236,26 +236,26 @@ class _OAuth2KeyRingTokenCache(_OAuth2TokenCache):
236
236
Keyring Token Cache implementation
237
237
"""
238
238
239
- def __init__ (self ):
239
+ def __init__ (self ) -> None :
240
240
super ().__init__ ()
241
241
try :
242
242
self ._keyring = importlib .import_module ("keyring" )
243
243
except ImportError :
244
- self ._keyring = None
244
+ self ._keyring = None # type: ignore
245
245
logger .info ("keyring module not found. OAuth2 token will not be stored in keyring." )
246
246
247
247
def is_keyring_available (self ) -> bool :
248
248
return self ._keyring is not None
249
249
250
- def get_token_from_cache (self , host : str ) -> Optional [str ]:
250
+ def get_token_from_cache (self , host : Optional [ str ] ) -> Optional [str ]:
251
251
try :
252
252
return self ._keyring .get_password (host , "token" )
253
253
except self ._keyring .errors .NoKeyringError as e :
254
254
raise trino .exceptions .NotSupportedError ("Although keyring module is installed no backend has been "
255
255
"detected, check https://pypi.org/project/keyring/ for more "
256
256
"information." ) from e
257
257
258
- def store_token_to_cache (self , host : str , token : str ) -> None :
258
+ def store_token_to_cache (self , host : Optional [ str ] , token : str ) -> None :
259
259
try :
260
260
# keyring is installed, so we can store the token for reuse within multiple threads
261
261
self ._keyring .set_password (host , "token" , token )
@@ -280,18 +280,18 @@ def __init__(self, redirect_auth_url_handler: Callable[[str], None]):
280
280
self ._inside_oauth_attempt_lock = threading .Lock ()
281
281
self ._inside_oauth_attempt_blocker = threading .Event ()
282
282
283
- def __call__ (self , r ) :
283
+ def __call__ (self , r : PreparedRequest ) -> PreparedRequest :
284
284
host = self ._determine_host (r .url )
285
285
token = self ._get_token_from_cache (host )
286
286
287
287
if token is not None :
288
288
r .headers ['Authorization' ] = "Bearer " + token
289
289
290
- r .register_hook ('response' , self ._authenticate )
290
+ r .register_hook ('response' , self ._authenticate ) # type: ignore
291
291
292
292
return r
293
293
294
- def _authenticate (self , response , ** kwargs ) :
294
+ def _authenticate (self , response : Response , ** kwargs : Any ) -> Optional [ Response ] :
295
295
if not 400 <= response .status_code < 500 :
296
296
return response
297
297
@@ -310,7 +310,7 @@ def _authenticate(self, response, **kwargs):
310
310
311
311
return self ._retry_request (response , ** kwargs )
312
312
313
- def _attempt_oauth (self , response , ** kwargs ) :
313
+ def _attempt_oauth (self , response : Response , ** kwargs : Any ) -> None :
314
314
# we have to handle the authentication, may be token the token expired, or it wasn't there at all
315
315
auth_info = response .headers .get ('WWW-Authenticate' )
316
316
if not auth_info :
@@ -319,7 +319,8 @@ def _attempt_oauth(self, response, **kwargs):
319
319
if not _OAuth2TokenBearer ._BEARER_PREFIX .search (auth_info ):
320
320
raise exceptions .TrinoAuthError (f"Error: header info didn't match { auth_info } " )
321
321
322
- auth_info_headers = parse_dict_header (_OAuth2TokenBearer ._BEARER_PREFIX .sub ("" , auth_info , count = 1 ))
322
+ auth_info_headers = parse_dict_header (
323
+ _OAuth2TokenBearer ._BEARER_PREFIX .sub ("" , auth_info , count = 1 )) # type: ignore
323
324
324
325
auth_server = auth_info_headers .get ('x_redirect_server' )
325
326
token_server = auth_info_headers .get ('x_token_server' )
@@ -341,23 +342,26 @@ def _attempt_oauth(self, response, **kwargs):
341
342
host = self ._determine_host (request .url )
342
343
self ._store_token_to_cache (host , token )
343
344
344
- def _retry_request (self , response , ** kwargs ) :
345
+ def _retry_request (self , response : Response , ** kwargs : Any ) -> Optional [ Response ] :
345
346
request = response .request .copy ()
346
- extract_cookies_to_jar (request ._cookies , response .request , response .raw )
347
- request .prepare_cookies (request ._cookies )
347
+ extract_cookies_to_jar (request ._cookies , response .request , response .raw ) # type: ignore
348
+ request .prepare_cookies (request ._cookies ) # type: ignore
348
349
349
350
host = self ._determine_host (response .request .url )
350
- request .headers ['Authorization' ] = "Bearer " + self ._get_token_from_cache (host )
351
- retry_response = response .connection .send (request , ** kwargs )
351
+ token = self ._get_token_from_cache (host )
352
+ if token is not None :
353
+ request .headers ['Authorization' ] = "Bearer " + token
354
+ retry_response = response .connection .send (request , ** kwargs ) # type: ignore
352
355
retry_response .history .append (response )
353
356
retry_response .request = request
354
357
return retry_response
355
358
356
- def _get_token (self , token_server , response , ** kwargs ) :
359
+ def _get_token (self , token_server : str , response : Response , ** kwargs : Any ) -> str :
357
360
attempts = 0
358
361
while attempts < self .MAX_OAUTH_ATTEMPTS :
359
362
attempts += 1
360
- with response .connection .send (Request (method = 'GET' , url = token_server ).prepare (), ** kwargs ) as response :
363
+ with response .connection .send (Request ( # type: ignore
364
+ method = 'GET' , url = token_server ).prepare (), ** kwargs ) as response :
361
365
if response .status_code == 200 :
362
366
token_response = json .loads (response .text )
363
367
token = token_response .get ('token' )
@@ -377,53 +381,53 @@ def _get_token(self, token_server, response, **kwargs):
377
381
378
382
raise exceptions .TrinoAuthError ("Exceeded max attempts while getting the token" )
379
383
380
- def _get_token_from_cache (self , host : str ) -> Optional [str ]:
384
+ def _get_token_from_cache (self , host : Optional [ str ] ) -> Optional [str ]:
381
385
with self ._token_lock :
382
386
return self ._token_cache .get_token_from_cache (host )
383
387
384
- def _store_token_to_cache (self , host : str , token : str ) -> None :
388
+ def _store_token_to_cache (self , host : Optional [ str ] , token : str ) -> None :
385
389
with self ._token_lock :
386
390
self ._token_cache .store_token_to_cache (host , token )
387
391
388
392
@staticmethod
389
- def _determine_host (url ) -> Optional [str ]:
393
+ def _determine_host (url : Optional [str ]) -> Any :
390
394
return urlparse (url ).hostname
391
395
392
396
393
397
class OAuth2Authentication (Authentication ):
394
- def __init__ (self , redirect_auth_url_handler = CompositeRedirectHandler ([
398
+ def __init__ (self , redirect_auth_url_handler : CompositeRedirectHandler = CompositeRedirectHandler ([
395
399
WebBrowserRedirectHandler (),
396
400
ConsoleRedirectHandler ()
397
401
])):
398
402
self ._redirect_auth_url = redirect_auth_url_handler
399
403
self ._bearer = _OAuth2TokenBearer (self ._redirect_auth_url )
400
404
401
- def set_http_session (self , http_session ) :
405
+ def set_http_session (self , http_session : Session ) -> Session :
402
406
http_session .auth = self ._bearer
403
407
return http_session
404
408
405
- def get_exceptions (self ):
409
+ def get_exceptions (self ) -> Tuple [ Any , ...] :
406
410
return ()
407
411
408
- def __eq__ (self , other ) :
412
+ def __eq__ (self , other : object ) -> bool :
409
413
if not isinstance (other , OAuth2Authentication ):
410
414
return False
411
415
return self ._redirect_auth_url == other ._redirect_auth_url
412
416
413
417
414
418
class CertificateAuthentication (Authentication ):
415
- def __init__ (self , cert , key ):
419
+ def __init__ (self , cert : str , key : str ):
416
420
self ._cert = cert
417
421
self ._key = key
418
422
419
- def set_http_session (self , http_session ) :
423
+ def set_http_session (self , http_session : Session ) -> Session :
420
424
http_session .cert = (self ._cert , self ._key )
421
425
return http_session
422
426
423
- def get_exceptions (self ):
427
+ def get_exceptions (self ) -> Tuple [ Any , ...] :
424
428
return ()
425
429
426
- def __eq__ (self , other ) :
430
+ def __eq__ (self , other : object ) -> bool :
427
431
if not isinstance (other , CertificateAuthentication ):
428
432
return False
429
433
return self ._cert == other ._cert and self ._key == other ._key
0 commit comments