diff --git a/mock_tests/conftest.py b/mock_tests/conftest.py index ec2735d30..23436fa4a 100644 --- a/mock_tests/conftest.py +++ b/mock_tests/conftest.py @@ -11,6 +11,7 @@ import weaviate from mock_tests.mock_data import mock_class +from weaviate.config import AdditionalConfig, RetryConfig from weaviate.connect.base import ConnectionParams, ProtocolParams from weaviate.proto.v1 import ( batch_delete_pb2, @@ -139,6 +140,20 @@ def weaviate_client( client.close() +@pytest.fixture(scope="function") +def weaviate_client_retry_timeout( + weaviate_mock: HTTPServer, start_grpc_server: grpc.Server +) -> Generator[weaviate.WeaviateClient, None, None]: + client = weaviate.connect_to_local( + port=MOCK_PORT, + host=MOCK_IP, + grpc_port=MOCK_PORT_GRPC, + additional_config=AdditionalConfig(retry=RetryConfig(timeout_ms=500)), + ) + yield client + client.close() + + @pytest.fixture(scope="function") def weaviate_timeouts_client( weaviate_timeouts_mock: HTTPServer, start_grpc_server: grpc.Server @@ -148,7 +163,8 @@ def weaviate_timeouts_client( port=MOCK_PORT, grpc_port=MOCK_PORT_GRPC, additional_config=weaviate.classes.init.AdditionalConfig( - timeout=weaviate.classes.init.Timeout(query=0.5, insert=1.5) + timeout=weaviate.classes.init.Timeout(query=0.5, insert=1.5), + retry=weaviate.config.RetryConfig(request_retry_count=5, request_retry_backoff_ms=0), ), ) yield client @@ -253,6 +269,40 @@ def BatchObjects( class MockRetriesWeaviateService(weaviate_pb2_grpc.WeaviateServicer): search_count = 0 tenants_count = 0 + delete_count = 0 + batch_count = 0 + + def BatchObjects( + self, request: batch_pb2.BatchObjectsRequest, context: grpc.ServicerContext + ) -> batch_pb2.BatchObjectsReply: + if self.batch_count == 0: + self.batch_count += 1 + context.set_code(grpc.StatusCode.ABORTED) + context.set_details("Aborted") + return batch_pb2.BatchObjectsReply() + if self.batch_count == 1: + self.batch_count += 1 + context.set_code(grpc.StatusCode.CANCELLED) + context.set_details("Cancelled") + return batch_pb2.BatchObjectsReply() + return batch_pb2.BatchObjectsReply( + errors=[], + ) + + def BatchDelete( + self, request: batch_delete_pb2.BatchDeleteRequest, context: grpc.ServicerContext + ) -> batch_delete_pb2.BatchDeleteReply: + if self.delete_count == 0: + self.delete_count += 1 + context.set_code(grpc.StatusCode.DEADLINE_EXCEEDED) + context.set_details("Deadline Exceeded") + return batch_delete_pb2.BatchDeleteReply() + if self.delete_count == 1: + self.delete_count += 1 + context.set_code(grpc.StatusCode.UNAVAILABLE) + context.set_details("Service is unavailable") + return batch_delete_pb2.BatchDeleteReply() + return batch_delete_pb2.BatchDeleteReply(matches=1, failed=0, successful=1, objects=[]) def Search( self, request: search_get_pb2.SearchRequest, context: grpc.ServicerContext @@ -310,6 +360,15 @@ def retries( return weaviate_client.collections.use("RetriesCollection"), service +@pytest.fixture(scope="function") +def no_retries( + weaviate_client_retry_timeout: weaviate.WeaviateClient, start_grpc_server: grpc.Server +) -> tuple[weaviate.collections.Collection, MockRetriesWeaviateService]: + service = MockRetriesWeaviateService() + weaviate_pb2_grpc.add_WeaviateServicer_to_server(service, start_grpc_server) + return weaviate_client_retry_timeout.collections.use("RetriesCollection"), service + + class MockForbiddenWeaviateService(weaviate_pb2_grpc.WeaviateServicer): def Search( self, request: search_get_pb2.SearchRequest, context: grpc.ServicerContext diff --git a/mock_tests/test_collection.py b/mock_tests/test_collection.py index 38cd0b088..a784aedee 100644 --- a/mock_tests/test_collection.py +++ b/mock_tests/test_collection.py @@ -29,6 +29,7 @@ VectorIndexType, Vectorizers, ) +from weaviate.collections.classes.filters import Filter from weaviate.connect.base import ConnectionParams, ProtocolParams from weaviate.connect.integrations import _IntegrationConfig from weaviate.exceptions import ( @@ -36,6 +37,9 @@ InsufficientPermissionsError, UnexpectedStatusCodeError, WeaviateStartUpError, + WeaviateQueryError, + WeaviateBatchError, + WeaviateDeleteManyError, ) ACCESS_TOKEN = "HELLO!IamAnAccessToken" @@ -372,26 +376,43 @@ def test_grpc_retry_logic( collection = retries[0] service = retries[1] - with pytest.raises(weaviate.exceptions.WeaviateQueryError): - # checks first call correctly handles INTERNAL error - collection.query.fetch_objects() - # should perform one retry and then succeed subsequently objs = collection.query.fetch_objects().objects assert len(objs) == 1 assert objs[0].properties["name"] == "test" assert service.search_count == 2 - with pytest.raises(weaviate.exceptions.WeaviateTenantGetError): - # checks first call correctly handles error that isn't UNAVAILABLE - collection.tenants.get() - # should perform one retry and then succeed subsequently tenants = list(collection.tenants.get().values()) assert len(tenants) == 1 assert tenants[0].name == "tenant1" assert service.tenants_count == 2 + # Should perform two retry and then succeed subsequently + collection.data.insert_many(objects=[{"Hello": "World"}]) + + # should perform two retries and then succeed subsequently + deleted = collection.data.delete_many(where=Filter.by_id().equal(objs[0].uuid)) + assert deleted.matches == 1 + + +def test_grpc_retry_timeout_logic( + no_retries: tuple[weaviate.collections.Collection, MockRetriesWeaviateService], +) -> None: + collection, _ = no_retries[0], no_retries[1] + + # timeout after 1 retry + with pytest.raises(WeaviateQueryError): + collection.query.fetch_objects().objects + + # timeout after 1 retry + with pytest.raises(WeaviateBatchError): + collection.data.insert_many(objects=[{"Hello": "World"}]) + + # timeout after 1 retry + with pytest.raises(WeaviateDeleteManyError): + collection.data.delete_many(where=Filter.by_property("Hello").equal("World")) + def test_grpc_forbidden_exception(forbidden: weaviate.collections.Collection) -> None: with pytest.raises(weaviate.exceptions.InsufficientPermissionsError): diff --git a/mock_tests/test_timeouts.py b/mock_tests/test_timeouts.py index 5f5a51b57..4ad94d3b9 100644 --- a/mock_tests/test_timeouts.py +++ b/mock_tests/test_timeouts.py @@ -1,7 +1,7 @@ import pytest import weaviate -from weaviate.exceptions import WeaviateQueryError, WeaviateTimeoutError +from weaviate.exceptions import WeaviateQueryError, WeaviateTimeoutError, WeaviateBatchError def test_timeout_rest_query(timeouts_collection: weaviate.collections.Collection): @@ -21,6 +21,6 @@ def test_timeout_grpc_query(timeouts_collection: weaviate.collections.Collection def test_timeout_grpc_insert(timeouts_collection: weaviate.collections.Collection): - with pytest.raises(WeaviateQueryError) as recwarn: + with pytest.raises(WeaviateBatchError) as recwarn: timeouts_collection.data.insert_many([{"what": "ever"}]) assert "DEADLINE_EXCEEDED" in str(recwarn) diff --git a/weaviate/client_executor.py b/weaviate/client_executor.py index 6d1d0c59b..82cbf04a6 100644 --- a/weaviate/client_executor.py +++ b/weaviate/client_executor.py @@ -76,6 +76,7 @@ def __init__( additional_headers=additional_headers, embedded_db=embedded_db, connection_config=config.connection, + retry_config=config.retry, proxies=config.proxies, trust_env=config.trust_env, skip_init_checks=skip_init_checks, diff --git a/weaviate/config.py b/weaviate/config.py index bc0525531..93852c044 100644 --- a/weaviate/config.py +++ b/weaviate/config.py @@ -66,6 +66,26 @@ class Proxies(BaseModel): grpc: Optional[str] = Field(default=None) +@dataclass +class RetryConfig: + request_retry_count: int = 20 + request_retry_backoff_ms: int = 100 + timeout_ms: int = 30000 + + def __post_init__(self) -> None: + if not isinstance(self.request_retry_count, int): + raise TypeError( + f"request_retry_count must be {int}, received {type(self.request_retry_count)}" + ) + if not isinstance(self.request_retry_backoff_ms, int): + raise TypeError( + f"request_retry_backoff_ms must be {int}, received {type(self.request_retry_backoff_ms)}" + ) + + if not isinstance(self.timeout_ms, int): + raise TypeError(f"timeout_ms must be {int}, received {type(self.timeout_ms)}") + + class AdditionalConfig(BaseModel): """Use this class to specify the connection and proxy settings for your client when connecting to Weaviate. @@ -80,6 +100,7 @@ class AdditionalConfig(BaseModel): connection: ConnectionConfig = Field(default_factory=ConnectionConfig) proxies: Union[str, Proxies, None] = Field(default=None) timeout_: Union[Tuple[int, int], Timeout] = Field(default_factory=Timeout, alias="timeout") + retry: RetryConfig = Field(default_factory=RetryConfig) trust_env: bool = Field(default=False) @property diff --git a/weaviate/connect/v4.py b/weaviate/connect/v4.py index 3734f650a..d50d91a6c 100644 --- a/weaviate/connect/v4.py +++ b/weaviate/connect/v4.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import time from copy import copy from dataclasses import dataclass, field @@ -50,7 +51,7 @@ from weaviate import __version__ as client_version from weaviate.auth import AuthApiKey, AuthClientCredentials, AuthCredentials -from weaviate.config import ConnectionConfig, Proxies +from weaviate.config import ConnectionConfig, Proxies, RetryConfig from weaviate.config import Timeout as TimeoutConfig from weaviate.connect import executor from weaviate.connect.authentication import _Auth @@ -130,6 +131,7 @@ def __init__( trust_env: bool, additional_headers: Optional[Dict[str, Any]], connection_config: ConnectionConfig, + retry_config: RetryConfig, embedded_db: Optional[EmbeddedV4] = None, skip_init_checks: bool = False, ): @@ -143,6 +145,7 @@ def __init__( self._grpc_stub: Optional[weaviate_pb2_grpc.WeaviateStub] = None self._grpc_channel: Union[AsyncChannel, SyncChannel, None] = None self.timeout_config = timeout_config + self.retry_config = retry_config self.__connection_config = connection_config self.__trust_env = trust_env self._weaviate_version = _ServerVersion.from_string("") @@ -323,7 +326,11 @@ def _ping_grpc(self, colour: executor.Colour) -> Union[None, Awaitable[None]]: assert self._grpc_channel is not None try: - res = self._grpc_channel.unary_unary( + res = _Retry(self.retry_config).with_exponential_backoff( + 0, + datetime.datetime.now(), + "", + self._grpc_channel.unary_unary, "/grpc.health.v1.Health/Check", request_serializer=health_weaviate_pb2.WeaviateHealthCheckRequest.SerializeToString, response_deserializer=health_weaviate_pb2.WeaviateHealthCheckResponse.FromString, @@ -985,8 +992,9 @@ def wait_for_weaviate(self, startup_period: int) -> None: def grpc_search(self, request: search_get_pb2.SearchRequest) -> search_get_pb2.SearchReply: try: assert self.grpc_stub is not None - res = _Retry(4).with_exponential_backoff( + res = _Retry(self.retry_config).with_exponential_backoff( 0, + datetime.datetime.now(), f"Searching in collection {request.collection}", self.grpc_stub.Search, request, @@ -1010,8 +1018,15 @@ def grpc_batch_objects( ) -> Dict[int, str]: try: assert self.grpc_stub is not None - res = _Retry(max_retries).with_exponential_backoff( + res = _Retry( + RetryConfig( + request_retry_count=int(max_retries), + request_retry_backoff_ms=self.retry_config.request_retry_backoff_ms, + timeout_ms=self.retry_config.timeout_ms, + ) + ).with_exponential_backoff( count=0, + start_time=datetime.datetime.now(), error="Batch objects", f=self.grpc_stub.BatchObjects, request=request, @@ -1024,6 +1039,8 @@ def grpc_batch_objects( for err in res.errors: objects[err.index] = err.error return objects + except WeaviateRetryError as e: + raise WeaviateBatchError(str(e)) from e except RpcError as e: error = cast(Call, e) if error.code() == StatusCode.PERMISSION_DENIED: @@ -1053,14 +1070,21 @@ def grpc_batch_delete( ) -> batch_delete_pb2.BatchDeleteReply: try: assert self.grpc_stub is not None + res = _Retry(self.retry_config).with_exponential_backoff( + 0, + datetime.datetime.now(), + "Batch Delete", + self.grpc_stub.BatchDelete, + request, + metadata=self.grpc_headers(), + timeout=self.timeout_config.insert, + ) return cast( batch_delete_pb2.BatchDeleteReply, - self.grpc_stub.BatchDelete( - request, - metadata=self.grpc_headers(), - timeout=self.timeout_config.insert, - ), + res, ) + except WeaviateRetryError as e: + raise WeaviateDeleteManyError(str(e)) from e except RpcError as e: error = cast(Call, e) if error.code() == StatusCode.PERMISSION_DENIED: @@ -1072,8 +1096,9 @@ def grpc_tenants_get( ) -> tenants_pb2.TenantsGetReply: try: assert self.grpc_stub is not None - res = _Retry().with_exponential_backoff( + res = _Retry(self.retry_config).with_exponential_backoff( 0, + datetime.datetime.now(), f"Get tenants for collection {request.collection}", self.grpc_stub.TenantsGet, request, @@ -1093,8 +1118,9 @@ def grpc_aggregate( ) -> aggregate_pb2.AggregateReply: try: assert self.grpc_stub is not None - res = _Retry(4).with_exponential_backoff( + res = _Retry(self.retry_config).with_exponential_backoff( 0, + datetime.datetime.now(), f"Searching in collection {request.collection}", self.grpc_stub.Aggregate, request, @@ -1189,8 +1215,9 @@ async def grpc_search( ) -> search_get_pb2.SearchReply: try: assert self.grpc_stub is not None - res = await _Retry(4).awith_exponential_backoff( + res = await _Retry(self.retry_config).awith_exponential_backoff( 0, + datetime.datetime.now(), f"Searching in collection {request.collection}", self.grpc_stub.Search, request, @@ -1213,8 +1240,15 @@ async def grpc_batch_objects( ) -> Dict[int, str]: try: assert self.grpc_stub is not None - res = await _Retry(max_retries).awith_exponential_backoff( + res = await _Retry( + RetryConfig( + request_retry_count=int(max_retries), + request_retry_backoff_ms=self.retry_config.request_retry_backoff_ms, + timeout_ms=self.retry_config.timeout_ms, + ) + ).awith_exponential_backoff( count=0, + start_time=datetime.datetime.now(), error="Batch objects", f=self.grpc_stub.BatchObjects, request=request, @@ -1227,6 +1261,8 @@ async def grpc_batch_objects( for err in res.errors: objects[err.index] = err.error return objects + except WeaviateRetryError as e: + raise WeaviateBatchError(str(e)) from e except AioRpcError as e: if e.code().name == PERMISSION_DENIED: raise InsufficientPermissionsError(e) @@ -1242,6 +1278,8 @@ async def grpc_batch_delete( metadata=self.grpc_headers(), timeout=self.timeout_config.insert, ) + except WeaviateRetryError as e: + raise WeaviateDeleteManyError(str(e)) from e except AioRpcError as e: if e.code().name == PERMISSION_DENIED: raise InsufficientPermissionsError(e) @@ -1252,14 +1290,17 @@ async def grpc_tenants_get( ) -> tenants_pb2.TenantsGetReply: try: assert self.grpc_stub is not None - res = await _Retry().awith_exponential_backoff( + res = await _Retry(self.retry_config).awith_exponential_backoff( 0, + datetime.datetime.now(), f"Get tenants for collection {request.collection}", self.grpc_stub.TenantsGet, request, metadata=self.grpc_headers(), timeout=self.timeout_config.query, ) + except WeaviateRetryError as e: + raise WeaviateTenantGetError(str(e)) from e except AioRpcError as e: if e.code().name == PERMISSION_DENIED: raise InsufficientPermissionsError(e) @@ -1272,8 +1313,9 @@ async def grpc_aggregate( ) -> aggregate_pb2.AggregateReply: try: assert self.grpc_stub is not None - res = await _Retry(4).awith_exponential_backoff( + res = await _Retry(self.retry_config).awith_exponential_backoff( 0, + datetime.datetime.now(), f"Searching in collection {request.collection}", self.grpc_stub.Aggregate, request, diff --git a/weaviate/retry.py b/weaviate/retry.py index 23a419d4a..383100cd0 100644 --- a/weaviate/retry.py +++ b/weaviate/retry.py @@ -1,4 +1,5 @@ import asyncio +import datetime import time from typing import Awaitable, Callable, cast @@ -6,6 +7,7 @@ from grpc.aio import AioRpcError # type: ignore from typing_extensions import ParamSpec, TypeVar +from weaviate.config import RetryConfig from weaviate.exceptions import WeaviateRetryError from weaviate.logger import logger @@ -14,12 +16,27 @@ class _Retry: - def __init__(self, n: float = 4) -> None: - self.n = n + def __init__(self, retry_config: RetryConfig) -> None: + self.config = retry_config + + def is_retriable(self, e: Exception) -> bool: + if isinstance(e, AioRpcError) or isinstance(e, RpcError): + err = cast(Call, e) + return err.code() in [ + StatusCode.UNAVAILABLE, + StatusCode.NOT_FOUND, + StatusCode.DEADLINE_EXCEEDED, + StatusCode.ABORTED, + StatusCode.INTERNAL, + StatusCode.CANCELLED, + StatusCode.ABORTED, + ] + return False async def awith_exponential_backoff( self, count: int, + start_time: datetime.datetime, error: str, f: Callable[P, Awaitable[T]], *args: P.args, @@ -28,19 +45,26 @@ async def awith_exponential_backoff( try: return await f(*args, **kwargs) except AioRpcError as e: - if e.code() != StatusCode.UNAVAILABLE: + if not self.is_retriable(e): raise e + if ( + (datetime.datetime.now() - start_time).total_seconds() * 1000 + ) > self.config.timeout_ms: + raise WeaviateRetryError(str(e), count) from e + if count > self.config.request_retry_count: + raise WeaviateRetryError(str(e), count) from e logger.info( f"{error} received exception: {e}. Retrying with exponential backoff in {2**count} seconds" ) - await asyncio.sleep(2**count) - if count > self.n: - raise WeaviateRetryError(str(e), count) from e - return await self.awith_exponential_backoff(count + 1, error, f, *args, **kwargs) + await asyncio.sleep((self.config.request_retry_backoff_ms / 1000.0) ** count) + return await self.awith_exponential_backoff( + count + 1, start_time, error, f, *args, **kwargs + ) def with_exponential_backoff( self, count: int, + start_time: datetime.datetime, error: str, f: Callable[P, T], *args: P.args, @@ -48,14 +72,17 @@ def with_exponential_backoff( ) -> T: try: return f(*args, **kwargs) - except RpcError as e: - err = cast(Call, e) - if err.code() != StatusCode.UNAVAILABLE: + except Exception as e: + if not self.is_retriable(e): raise e + if ( + (datetime.datetime.now() - start_time).total_seconds() * 1000 + ) > self.config.timeout_ms: + raise WeaviateRetryError(str(e), count) from e + if count > self.config.request_retry_count: + raise WeaviateRetryError(str(e), count) from e logger.info( f"{error} received exception: {e}. Retrying with exponential backoff in {2**count} seconds" ) - time.sleep(2**count) - if count > self.n: - raise WeaviateRetryError(str(e), count) from e - return self.with_exponential_backoff(count + 1, error, f, *args, **kwargs) + time.sleep((self.config.request_retry_backoff_ms / 1000.0) ** count) + return self.with_exponential_backoff(count + 1, start_time, error, f, *args, **kwargs)