Skip to content

Commit b1d5f91

Browse files
hovaescohashhar
authored andcommitted
Enable mypy checks for auth.py
1 parent 46d46c6 commit b1d5f91

File tree

2 files changed

+58
-54
lines changed

2 files changed

+58
-54
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@ ignore_missing_imports = true
1919
no_implicit_optional = true
2020
warn_unused_ignores = true
2121

22-
[mypy-tests.*,trino.auth,trino.client,trino.dbapi,trino.sqlalchemy.*]
22+
[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*]
2323
ignore_errors = true

trino/auth.py

Lines changed: 57 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import re
1818
import threading
1919
import webbrowser
20-
from typing import Callable, List, Optional
20+
from typing import Any, Callable, Dict, List, Optional, Tuple
2121
from urllib.parse import urlparse
2222

23-
from requests import Request
23+
from requests import PreparedRequest, Request, Response, Session
2424
from requests.auth import AuthBase, extract_cookies_to_jar
2525
from requests.utils import parse_dict_header
2626

@@ -32,18 +32,18 @@
3232

3333
class Authentication(metaclass=abc.ABCMeta):
3434
@abc.abstractmethod
35-
def set_http_session(self, http_session):
35+
def set_http_session(self, http_session: Session) -> Session:
3636
pass
3737

38-
def get_exceptions(self):
38+
def get_exceptions(self) -> Tuple[Any, ...]:
3939
return tuple()
4040

4141

4242
class KerberosAuthentication(Authentication):
4343
def __init__(
4444
self,
4545
config: Optional[str] = None,
46-
service_name: str = None,
46+
service_name: Optional[str] = None,
4747
mutual_authentication: bool = False,
4848
force_preemptive: bool = False,
4949
hostname_override: Optional[str] = None,
@@ -62,7 +62,7 @@ def __init__(
6262
self._delegate = delegate
6363
self._ca_bundle = ca_bundle
6464

65-
def set_http_session(self, http_session):
65+
def set_http_session(self, http_session: Session) -> Session:
6666
try:
6767
import requests_kerberos
6868
except ImportError:
@@ -84,15 +84,15 @@ def set_http_session(self, http_session):
8484
http_session.verify = self._ca_bundle
8585
return http_session
8686

87-
def get_exceptions(self):
87+
def get_exceptions(self) -> Tuple[Any, ...]:
8888
try:
8989
from requests_kerberos.exceptions import KerberosExchangeError
9090

91-
return (KerberosExchangeError,)
91+
return KerberosExchangeError,
9292
except ImportError:
9393
raise RuntimeError("unable to import requests_kerberos")
9494

95-
def __eq__(self, other):
95+
def __eq__(self, other: object) -> bool:
9696
if not isinstance(other, KerberosAuthentication):
9797
return False
9898
return (self._config == other._config
@@ -107,11 +107,11 @@ def __eq__(self, other):
107107

108108

109109
class BasicAuthentication(Authentication):
110-
def __init__(self, username, password):
110+
def __init__(self, username: str, password: str):
111111
self._username = username
112112
self._password = password
113113

114-
def set_http_session(self, http_session):
114+
def set_http_session(self, http_session: Session) -> Session:
115115
try:
116116
import requests.auth
117117
except ImportError:
@@ -120,10 +120,10 @@ def set_http_session(self, http_session):
120120
http_session.auth = requests.auth.HTTPBasicAuth(self._username, self._password)
121121
return http_session
122122

123-
def get_exceptions(self):
123+
def get_exceptions(self) -> Tuple[Any, ...]:
124124
return ()
125125

126-
def __eq__(self, other):
126+
def __eq__(self, other: object) -> bool:
127127
if not isinstance(other, BasicAuthentication):
128128
return False
129129
return self._username == other._username and self._password == other._password
@@ -134,27 +134,27 @@ class _BearerAuth(AuthBase):
134134
Custom implementation of Authentication class for bearer token
135135
"""
136136

137-
def __init__(self, token):
137+
def __init__(self, token: str):
138138
self.token = token
139139

140-
def __call__(self, r):
140+
def __call__(self, r: PreparedRequest) -> PreparedRequest:
141141
r.headers["Authorization"] = "Bearer " + self.token
142142
return r
143143

144144

145145
class JWTAuthentication(Authentication):
146146

147-
def __init__(self, token):
147+
def __init__(self, token: str):
148148
self.token = token
149149

150-
def set_http_session(self, http_session):
150+
def set_http_session(self, http_session: Session) -> Session:
151151
http_session.auth = _BearerAuth(self.token)
152152
return http_session
153153

154-
def get_exceptions(self):
154+
def get_exceptions(self) -> Tuple[Any, ...]:
155155
return ()
156156

157-
def __eq__(self, other):
157+
def __eq__(self, other: object) -> bool:
158158
if not isinstance(other, JWTAuthentication):
159159
return False
160160
return self.token == other.token
@@ -197,7 +197,7 @@ class CompositeRedirectHandler(RedirectHandler):
197197
def __init__(self, handlers: List[Callable[[str], None]]):
198198
self.handlers = handlers
199199

200-
def __call__(self, url: str):
200+
def __call__(self, url: str) -> None:
201201
for handler in self.handlers:
202202
handler(url)
203203

@@ -208,11 +208,11 @@ class _OAuth2TokenCache(metaclass=abc.ABCMeta):
208208
"""
209209

210210
@abc.abstractmethod
211-
def get_token_from_cache(self, host: str) -> Optional[str]:
211+
def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
212212
pass
213213

214214
@abc.abstractmethod
215-
def store_token_to_cache(self, host: str, token: str) -> None:
215+
def store_token_to_cache(self, host: Optional[str], token: str) -> None:
216216
pass
217217

218218

@@ -221,13 +221,13 @@ class _OAuth2TokenInMemoryCache(_OAuth2TokenCache):
221221
In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache.
222222
"""
223223

224-
def __init__(self):
225-
self._cache = {}
224+
def __init__(self) -> None:
225+
self._cache: Dict[Optional[str], str] = {}
226226

227-
def get_token_from_cache(self, host: str) -> Optional[str]:
227+
def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
228228
return self._cache.get(host)
229229

230-
def store_token_to_cache(self, host: str, token: str) -> None:
230+
def store_token_to_cache(self, host: Optional[str], token: str) -> None:
231231
self._cache[host] = token
232232

233233

@@ -236,26 +236,26 @@ class _OAuth2KeyRingTokenCache(_OAuth2TokenCache):
236236
Keyring Token Cache implementation
237237
"""
238238

239-
def __init__(self):
239+
def __init__(self) -> None:
240240
super().__init__()
241241
try:
242242
self._keyring = importlib.import_module("keyring")
243243
except ImportError:
244-
self._keyring = None
244+
self._keyring = None # type: ignore
245245
logger.info("keyring module not found. OAuth2 token will not be stored in keyring.")
246246

247247
def is_keyring_available(self) -> bool:
248248
return self._keyring is not None
249249

250-
def get_token_from_cache(self, host: str) -> Optional[str]:
250+
def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
251251
try:
252252
return self._keyring.get_password(host, "token")
253253
except self._keyring.errors.NoKeyringError as e:
254254
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
255255
"detected, check https://pypi.org/project/keyring/ for more "
256256
"information.") from e
257257

258-
def store_token_to_cache(self, host: str, token: str) -> None:
258+
def store_token_to_cache(self, host: Optional[str], token: str) -> None:
259259
try:
260260
# keyring is installed, so we can store the token for reuse within multiple threads
261261
self._keyring.set_password(host, "token", token)
@@ -280,18 +280,18 @@ def __init__(self, redirect_auth_url_handler: Callable[[str], None]):
280280
self._inside_oauth_attempt_lock = threading.Lock()
281281
self._inside_oauth_attempt_blocker = threading.Event()
282282

283-
def __call__(self, r):
283+
def __call__(self, r: PreparedRequest) -> PreparedRequest:
284284
host = self._determine_host(r.url)
285285
token = self._get_token_from_cache(host)
286286

287287
if token is not None:
288288
r.headers['Authorization'] = "Bearer " + token
289289

290-
r.register_hook('response', self._authenticate)
290+
r.register_hook('response', self._authenticate) # type: ignore
291291

292292
return r
293293

294-
def _authenticate(self, response, **kwargs):
294+
def _authenticate(self, response: Response, **kwargs: Any) -> Optional[Response]:
295295
if not 400 <= response.status_code < 500:
296296
return response
297297

@@ -310,7 +310,7 @@ def _authenticate(self, response, **kwargs):
310310

311311
return self._retry_request(response, **kwargs)
312312

313-
def _attempt_oauth(self, response, **kwargs):
313+
def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:
314314
# we have to handle the authentication, may be token the token expired, or it wasn't there at all
315315
auth_info = response.headers.get('WWW-Authenticate')
316316
if not auth_info:
@@ -319,7 +319,8 @@ def _attempt_oauth(self, response, **kwargs):
319319
if not _OAuth2TokenBearer._BEARER_PREFIX.search(auth_info):
320320
raise exceptions.TrinoAuthError(f"Error: header info didn't match {auth_info}")
321321

322-
auth_info_headers = parse_dict_header(_OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1))
322+
auth_info_headers = parse_dict_header(
323+
_OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1)) # type: ignore
323324

324325
auth_server = auth_info_headers.get('x_redirect_server')
325326
token_server = auth_info_headers.get('x_token_server')
@@ -341,23 +342,26 @@ def _attempt_oauth(self, response, **kwargs):
341342
host = self._determine_host(request.url)
342343
self._store_token_to_cache(host, token)
343344

344-
def _retry_request(self, response, **kwargs):
345+
def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response]:
345346
request = response.request.copy()
346-
extract_cookies_to_jar(request._cookies, response.request, response.raw)
347-
request.prepare_cookies(request._cookies)
347+
extract_cookies_to_jar(request._cookies, response.request, response.raw) # type: ignore
348+
request.prepare_cookies(request._cookies) # type: ignore
348349

349350
host = self._determine_host(response.request.url)
350-
request.headers['Authorization'] = "Bearer " + self._get_token_from_cache(host)
351-
retry_response = response.connection.send(request, **kwargs)
351+
token = self._get_token_from_cache(host)
352+
if token is not None:
353+
request.headers['Authorization'] = "Bearer " + token
354+
retry_response = response.connection.send(request, **kwargs) # type: ignore
352355
retry_response.history.append(response)
353356
retry_response.request = request
354357
return retry_response
355358

356-
def _get_token(self, token_server, response, **kwargs):
359+
def _get_token(self, token_server: str, response: Response, **kwargs: Any) -> str:
357360
attempts = 0
358361
while attempts < self.MAX_OAUTH_ATTEMPTS:
359362
attempts += 1
360-
with response.connection.send(Request(method='GET', url=token_server).prepare(), **kwargs) as response:
363+
with response.connection.send(Request( # type: ignore
364+
method='GET', url=token_server).prepare(), **kwargs) as response:
361365
if response.status_code == 200:
362366
token_response = json.loads(response.text)
363367
token = token_response.get('token')
@@ -377,53 +381,53 @@ def _get_token(self, token_server, response, **kwargs):
377381

378382
raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token")
379383

380-
def _get_token_from_cache(self, host: str) -> Optional[str]:
384+
def _get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
381385
with self._token_lock:
382386
return self._token_cache.get_token_from_cache(host)
383387

384-
def _store_token_to_cache(self, host: str, token: str) -> None:
388+
def _store_token_to_cache(self, host: Optional[str], token: str) -> None:
385389
with self._token_lock:
386390
self._token_cache.store_token_to_cache(host, token)
387391

388392
@staticmethod
389-
def _determine_host(url) -> Optional[str]:
393+
def _determine_host(url: Optional[str]) -> Any:
390394
return urlparse(url).hostname
391395

392396

393397
class OAuth2Authentication(Authentication):
394-
def __init__(self, redirect_auth_url_handler=CompositeRedirectHandler([
398+
def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([
395399
WebBrowserRedirectHandler(),
396400
ConsoleRedirectHandler()
397401
])):
398402
self._redirect_auth_url = redirect_auth_url_handler
399403
self._bearer = _OAuth2TokenBearer(self._redirect_auth_url)
400404

401-
def set_http_session(self, http_session):
405+
def set_http_session(self, http_session: Session) -> Session:
402406
http_session.auth = self._bearer
403407
return http_session
404408

405-
def get_exceptions(self):
409+
def get_exceptions(self) -> Tuple[Any, ...]:
406410
return ()
407411

408-
def __eq__(self, other):
412+
def __eq__(self, other: object) -> bool:
409413
if not isinstance(other, OAuth2Authentication):
410414
return False
411415
return self._redirect_auth_url == other._redirect_auth_url
412416

413417

414418
class CertificateAuthentication(Authentication):
415-
def __init__(self, cert, key):
419+
def __init__(self, cert: str, key: str):
416420
self._cert = cert
417421
self._key = key
418422

419-
def set_http_session(self, http_session):
423+
def set_http_session(self, http_session: Session) -> Session:
420424
http_session.cert = (self._cert, self._key)
421425
return http_session
422426

423-
def get_exceptions(self):
427+
def get_exceptions(self) -> Tuple[Any, ...]:
424428
return ()
425429

426-
def __eq__(self, other):
430+
def __eq__(self, other: object) -> bool:
427431
if not isinstance(other, CertificateAuthentication):
428432
return False
429433
return self._cert == other._cert and self._key == other._key

0 commit comments

Comments
 (0)