diff --git a/trino/auth.py b/trino/auth.py index 99c5d2d6..c2155fd1 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -33,7 +33,7 @@ from requests.auth import extract_cookies_to_jar import trino.logging -from trino.client import exceptions +from trino import exceptions from trino.constants import HEADER_USER from trino.constants import MAX_NT_PASSWORD_SIZE diff --git a/trino/client.py b/trino/client.py index da5e4047..12f3e8d4 100644 --- a/trino/client.py +++ b/trino/client.py @@ -55,12 +55,19 @@ from zoneinfo import ZoneInfo import requests +from requests import Response +from requests import Session +from requests.structures import CaseInsensitiveDict from tzlocal import get_localzone_name # type: ignore import trino.logging from trino import constants from trino import exceptions from trino._version import __version__ +from trino.auth import Authentication +from trino.exceptions import TrinoExternalError +from trino.exceptions import TrinoQueryError +from trino.exceptions import TrinoUserError from trino.mapper import RowMapper from trino.mapper import RowMapperFactory @@ -271,11 +278,11 @@ def __setstate__(self, state): self._object_lock = threading.Lock() -def get_header_values(headers, header): +def get_header_values(headers: CaseInsensitiveDict[str], header: str) -> List[str]: return [val.strip() for val in headers[header].split(",")] -def get_session_property_values(headers, header): +def get_session_property_values(headers: CaseInsensitiveDict[str], header: str) -> List[Tuple[str, str]]: kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) @@ -283,7 +290,7 @@ def get_session_property_values(headers, header): ] -def get_prepared_statement_values(headers, header): +def get_prepared_statement_values(headers: CaseInsensitiveDict[str], header: str) -> List[Tuple[str, str]]: kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) @@ -291,7 +298,7 @@ def get_prepared_statement_values(headers, header): ] -def get_roles_values(headers, header): +def get_roles_values(headers: CaseInsensitiveDict[str], header: str) -> List[Tuple[str, str]]: kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) @@ -414,9 +421,9 @@ def __init__( host: str, port: int, client_session: ClientSession, - http_session: Any = None, - http_scheme: str = None, - auth: Optional[Any] = constants.DEFAULT_AUTH, + http_session: Optional[Session] = None, + http_scheme: Optional[str] = None, + auth: Optional[Authentication] = constants.DEFAULT_AUTH, max_attempts: int = MAX_ATTEMPTS, request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT, handle_retry=_RetryWithExponentialBackoff(), @@ -454,16 +461,16 @@ def __init__( self.max_attempts = max_attempts @property - def transaction_id(self): + def transaction_id(self) -> Optional[str]: return self._client_session.transaction_id @transaction_id.setter - def transaction_id(self, value): + def transaction_id(self, value: Optional[str]) -> None: self._client_session.transaction_id = value @property - def http_headers(self) -> Dict[str, str]: - headers = requests.structures.CaseInsensitiveDict() + def http_headers(self) -> CaseInsensitiveDict[str]: + headers: CaseInsensitiveDict[str] = CaseInsensitiveDict() headers[constants.HEADER_CATALOG] = self._client_session.catalog headers[constants.HEADER_SCHEMA] = self._client_session.schema @@ -525,7 +532,7 @@ def max_attempts(self) -> int: return self._max_attempts @max_attempts.setter - def max_attempts(self, value) -> None: + def max_attempts(self, value: int) -> None: self._max_attempts = value if value == 1: # No retry self._get = self._http_session.get @@ -547,7 +554,7 @@ def max_attempts(self, value) -> None: self._post = with_retry(self._http_session.post) self._delete = with_retry(self._http_session.delete) - def get_url(self, path) -> str: + def get_url(self, path: str) -> str: return "{protocol}://{host}:{port}{path}".format( protocol=self._http_scheme, host=self._host, port=self._port, path=path ) @@ -560,7 +567,7 @@ def statement_url(self) -> str: def next_uri(self) -> Optional[str]: return self._next_uri - def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None): + def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None) -> Response: data = sql.encode("utf-8") # Deep copy of the http_headers dict since they may be modified for this # request by the provided additional_http_headers @@ -578,7 +585,7 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non ) return http_response - def get(self, url: str): + def get(self, url: str) -> Response: return self._get( url, headers=self.http_headers, @@ -586,10 +593,11 @@ def get(self, url: str): proxies=PROXIES, ) - def delete(self, url): + def delete(self, url: str) -> Response: return self._delete(url, timeout=self._request_timeout, proxies=PROXIES) - def _process_error(self, error, query_id): + @staticmethod + def _process_error(error, query_id: Optional[str]) -> Union[TrinoExternalError, TrinoQueryError, TrinoUserError]: error_type = error["errorType"] if error_type == "EXTERNAL": raise exceptions.TrinoExternalError(error, query_id) @@ -598,7 +606,8 @@ def _process_error(self, error, query_id): return exceptions.TrinoQueryError(error, query_id) - def raise_response_error(self, http_response): + @staticmethod + def raise_response_error(http_response: Response) -> None: if http_response.status_code == 502: raise exceptions.Http502Error("error 502: bad gateway") @@ -615,7 +624,7 @@ def raise_response_error(self, http_response): ) ) - def process(self, http_response) -> TrinoStatus: + def process(self, http_response: Response) -> TrinoStatus: if not http_response.ok: self.raise_response_error(http_response) @@ -682,7 +691,8 @@ def process(self, http_response) -> TrinoStatus: columns=response.get("columns"), ) - def _verify_extra_credential(self, header): + @staticmethod + def _verify_extra_credential(header: Tuple[str, str]) -> None: """ Verifies that key has ASCII only and non-whitespace characters. """