diff --git a/mock_tests/test_collection.py b/mock_tests/test_collection.py index 4236a35cb..53273faeb 100644 --- a/mock_tests/test_collection.py +++ b/mock_tests/test_collection.py @@ -37,6 +37,7 @@ WeaviateStartUpError, BackupCanceledError, InsufficientPermissionsError, + UnexpectedStatusCodeError, ) ACCESS_TOKEN = "HELLO!IamAnAccessToken" @@ -54,9 +55,15 @@ def test_insufficient_permissions( port=MOCK_PORT, host=MOCK_IP, grpc_port=MOCK_PORT_GRPC, skip_init_checks=True ) collection = client.collections.get("Test") - with pytest.raises(InsufficientPermissionsError) as e: + + with pytest.raises(InsufficientPermissionsError) as e1: + collection.config.get() + assert "this is an error" in e1.value.message + + with pytest.raises(UnexpectedStatusCodeError) as e2: collection.config.get() - assert "this is an error" in e.value.message + assert e2.value.status_code == 403 + weaviate_mock.check_assertions() diff --git a/requirements-devel.txt b/requirements-devel.txt index 1e197763f..9530acfda 100644 --- a/requirements-devel.txt +++ b/requirements-devel.txt @@ -39,7 +39,6 @@ mypy==1.13.0 mypy-extensions==1.0.0 tomli==2.2.1 types-protobuf==5.28.3.20241203 -types-requests==2.32.0.20241016 types-urllib3==1.26.25.14 typing_extensions==4.12.2 diff --git a/weaviate/connect/authentication.py b/weaviate/connect/authentication.py deleted file mode 100644 index f7ed3da4d..000000000 --- a/weaviate/connect/authentication.py +++ /dev/null @@ -1,154 +0,0 @@ -from __future__ import annotations - -from typing import Dict, Generic, List, Type, TypeVar, Union, cast -from typing import TYPE_CHECKING - -import requests -from authlib.integrations.httpx_client import OAuth2Client # type: ignore -from authlib.integrations.requests_client import OAuth2Session # type: ignore - -from weaviate.auth import ( - AuthCredentials, - AuthClientPassword, - AuthBearerToken, - AuthClientCredentials, -) -from weaviate.exceptions import MissingScopeError, AuthenticationFailedError -from ..util import _decode_json_response_dict -from ..warnings import _Warnings - -if TYPE_CHECKING: - from .base import _ConnectionBase - -AUTH_DEFAULT_TIMEOUT = 5 -OIDC_CONFIG = Dict[str, Union[str, List[str]]] - - -T = TypeVar("T", bound=Union[OAuth2Client, OAuth2Session]) - - -class _Auth(Generic[T]): - def __init__( - self, - session_type: Type[T], - oidc_config: OIDC_CONFIG, - credentials: AuthCredentials, - connection: _ConnectionBase, - ) -> None: - self._credentials: AuthCredentials = credentials - self._connection = connection - self.__session_type: Type[T] = session_type - config_url = oidc_config["href"] - client_id = oidc_config["clientId"] - assert isinstance(config_url, str) and isinstance(client_id, str) - self._open_id_config_url: str = config_url - self._client_id: str = client_id - self._default_scopes: List[str] = [] - if "scopes" in oidc_config: - default_scopes = oidc_config["scopes"] - assert isinstance(default_scopes, list) - self._default_scopes = default_scopes - - self._token_endpoint: str = self._get_token_endpoint() - self._validate(oidc_config) - - def _validate(self, oidc_config: OIDC_CONFIG) -> None: - if isinstance(self._credentials, AuthClientPassword): - if self._token_endpoint.startswith("https://login.microsoftonline.com"): - raise AuthenticationFailedError( - """Microsoft/azure does not recommend to authenticate using username and password and this method is - not supported by the python client.""" - ) - - # The grant_types_supported field is optional and does not have to be present in the response - if ( - "grant_types_supported" in oidc_config - and "password" not in oidc_config["grant_types_supported"] - ): - raise AuthenticationFailedError( - """The grant_types supported by the third-party authentication service are insufficient. Please add - the 'password' grant type.""" - ) - - def _get_token_endpoint(self) -> str: - response_auth = requests.get( - self._open_id_config_url, proxies=self._connection.get_proxies() - ) - response_auth_json = _decode_json_response_dict(response_auth, "Get token endpoint") - assert response_auth_json is not None - token_endpoint = response_auth_json["token_endpoint"] - assert isinstance(token_endpoint, str) - return token_endpoint - - def get_auth_session(self) -> T: - if isinstance(self._credentials, AuthBearerToken): - sessions = self._get_session_auth_bearer_token(self._credentials) - elif isinstance(self._credentials, AuthClientCredentials): - sessions = self._get_session_client_credential(self._credentials) - else: - assert isinstance(self._credentials, AuthClientPassword) - sessions = self._get_session_user_pw(self._credentials) - - return sessions - - def _get_session_auth_bearer_token(self, config: AuthBearerToken) -> T: - token: Dict[str, Union[str, int]] = {"access_token": config.access_token} - if config.expires_in is not None: - token["expires_in"] = config.expires_in - if config.refresh_token is not None: - token["refresh_token"] = config.refresh_token - - if "refresh_token" not in token: - _Warnings.auth_no_refresh_token(config.expires_in) - - return cast( - T, - self.__session_type( - token=token, - token_endpoint=self._token_endpoint, - client_id=self._client_id, - default_timeout=AUTH_DEFAULT_TIMEOUT, - ), - ) - - def _get_session_user_pw(self, config: AuthClientPassword) -> T: - scope: List[str] = self._default_scopes.copy() - scope.extend(config.scope_list) - session = self.__session_type( - client_id=self._client_id, - token_endpoint=self._token_endpoint, - grant_type="password", - scope=scope, - default_timeout=AUTH_DEFAULT_TIMEOUT, - ) - token = session.fetch_token(username=config.username, password=config.password) - if "refresh_token" not in token: - _Warnings.auth_no_refresh_token(token["expires_in"]) - - return cast(T, session) - - def _get_session_client_credential(self, config: AuthClientCredentials) -> T: - scope: List[str] = self._default_scopes.copy() - - if config.scope_list is not None: - scope.extend(config.scope_list) - if len(scope) == 0: - # hardcode commonly used scopes - if self._token_endpoint.startswith("https://login.microsoftonline.com"): - scope = [self._client_id + "/.default"] - else: - raise MissingScopeError - - session = self.__session_type( - client_id=self._client_id, - client_secret=config.client_secret, - token_endpoint_auth_method="client_secret_post", - scope=scope, - token_endpoint=self._token_endpoint, - grant_type="client_credentials", - token={"access_token": None, "expires_in": -100}, - default_timeout=AUTH_DEFAULT_TIMEOUT, - ) - # explicitly fetch tokens. Otherwise, authlib will do it in the background and we might have race-conditions - session.fetch_token() - return cast(T, session) diff --git a/weaviate/embedded.py b/weaviate/embedded.py index 191dc496c..03768f82e 100644 --- a/weaviate/embedded.py +++ b/weaviate/embedded.py @@ -15,7 +15,7 @@ from pathlib import Path from typing import Dict, Optional, Tuple -import requests +import httpx import validators from weaviate import exceptions @@ -88,9 +88,7 @@ def __init__(self, options: EmbeddedOptions) -> None: self._parsed_weaviate_version = version_tag self._set_download_url_from_version_tag(version_tag) elif self.options.version == "latest": - response = requests.get( - "https://api.github.com/repos/weaviate/weaviate/releases/latest" - ) + response = httpx.get("https://api.github.com/repos/weaviate/weaviate/releases/latest") latest = _decode_json_response_dict(response, "get tag of latest weaviate release") assert latest is not None self._set_download_url_from_version_tag(latest["tag_name"]) diff --git a/weaviate/exceptions.py b/weaviate/exceptions.py index cfab86042..c04f8990b 100644 --- a/weaviate/exceptions.py +++ b/weaviate/exceptions.py @@ -5,8 +5,8 @@ from json.decoder import JSONDecodeError from typing import Union, Tuple +from grpc.aio import AioRpcError # type: ignore import httpx -import requests ERROR_CODE_EXPLANATION = { 413: """Payload Too Large. Try to decrease the batch size or increase the maximum request size on your weaviate @@ -40,7 +40,7 @@ class UnexpectedStatusCodeError(WeaviateBaseError): not handled in the client implementation and suggests an error. """ - def __init__(self, message: str, response: Union[httpx.Response, requests.Response]): + def __init__(self, message: str, response: Union[httpx.Response, AioRpcError]): """ Is raised in case the status code returned from Weaviate is not handled in the client implementation and suggests an error. @@ -55,21 +55,27 @@ def __init__(self, message: str, response: Union[httpx.Response, requests.Respon `response`: The request response of which the status code was unexpected. """ - self._status_code: int = response.status_code - # Set error message - - try: - body = response.json() - except (requests.exceptions.JSONDecodeError, httpx.DecodingError, JSONDecodeError): - body = None - - msg = ( - message - + f"! Unexpected status code: {response.status_code}, with response body: {body}." - ) - if response.status_code in ERROR_CODE_EXPLANATION: - msg += " " + ERROR_CODE_EXPLANATION[response.status_code] - + if isinstance(response, httpx.Response): + self._status_code: int = response.status_code + # Set error message + + try: + body = response.json() + except (httpx.DecodingError, JSONDecodeError): + body = None + + msg = ( + message + + f"! Unexpected status code: {response.status_code}, with response body: {body}." + ) + if response.status_code in ERROR_CODE_EXPLANATION: + msg += " " + ERROR_CODE_EXPLANATION[response.status_code] + elif isinstance(response, AioRpcError): + self._status_code = int(response.code().value[0]) + msg = ( + message + + f"! Unexpected status code: {response.code().value[1]}, with response body: {response.details()}." + ) super().__init__(msg) @property @@ -81,7 +87,7 @@ def status_code(self) -> int: class ResponseCannotBeDecodedError(WeaviateBaseError): - def __init__(self, location: str, response: Union[httpx.Response, requests.Response]): + def __init__(self, location: str, response: httpx.Response): """Raised when a weaviate response cannot be decoded to json Arguments: @@ -360,10 +366,8 @@ def __init__(self, message: str, count: int) -> None: super().__init__(msg) -class InsufficientPermissionsError(WeaviateBaseError): +class InsufficientPermissionsError(UnexpectedStatusCodeError): """Is raised when a request to Weaviate fails due to insufficient permissions.""" - def __init__(self, res: httpx.Response) -> None: - err = res.json()["error"][0]["message"] - msg = f"""The request to Weaviate failed due to insufficient permissions. Details: {err}""" - super().__init__(msg) + def __init__(self, res: Union[httpx.Response, AioRpcError]) -> None: + super().__init__("forbidden", res) diff --git a/weaviate/util.py b/weaviate/util.py index 47b40ba40..bbf9743c9 100644 --- a/weaviate/util.py +++ b/weaviate/util.py @@ -12,9 +12,7 @@ from typing import Union, Sequence, Any, Optional, List, Dict, Generator, Tuple, cast import httpx -import requests import validators -from requests.exceptions import JSONDecodeError from weaviate.exceptions import ( SchemaValidationError, @@ -817,9 +815,7 @@ def _to_beacons(uuids: UUIDS, to_class: str = "") -> List[Dict[str, str]]: return [{"beacon": f"weaviate://localhost/{to_class}{uuid_to}"} for uuid_to in uuids] -def _decode_json_response_dict( - response: Union[httpx.Response, requests.Response], location: str -) -> Optional[Dict[str, Any]]: +def _decode_json_response_dict(response: httpx.Response, location: str) -> Optional[Dict[str, Any]]: if response is None: return None @@ -827,14 +823,14 @@ def _decode_json_response_dict( try: json_response = cast(Dict[str, Any], response.json()) return json_response - except JSONDecodeError: + except httpx.DecodingError: raise ResponseCannotBeDecodedError(location, response) raise UnexpectedStatusCodeError(location, response) def _decode_json_response_list( - response: Union[httpx.Response, requests.Response], location: str + response: httpx.Response, location: str ) -> Optional[List[Dict[str, Any]]]: if response is None: return None @@ -843,7 +839,7 @@ def _decode_json_response_list( try: json_response = response.json() return cast(list, json_response) - except JSONDecodeError: + except httpx.DecodingError: raise ResponseCannotBeDecodedError(location, response) raise UnexpectedStatusCodeError(location, response)