diff --git a/integration/test_batch_v4.py b/integration/test_batch_v4.py index 7788ef03d..cebd4594c 100644 --- a/integration/test_batch_v4.py +++ b/integration/test_batch_v4.py @@ -1,4 +1,3 @@ -import asyncio import concurrent.futures import uuid from dataclasses import dataclass @@ -819,11 +818,34 @@ def test_references_with_to_uuids(client_factory: ClientFactory) -> None: client.collections.delete(["target", "source"]) +def test_ingest_one_hundred_thousand_data_objects( + client_factory: ClientFactory, +) -> None: + client, name = client_factory() + if client._connection._weaviate_version.is_lower_than(1, 34, 0): + pytest.skip("Server-side batching not supported in Weaviate < 1.34.0") + nr_objects = 100000 + import time + + start = time.time() + results = client.collections.use(name).data.ingest( + {"name": "test" + str(i)} for i in range(nr_objects) + ) + end = time.time() + print(f"Time taken to add {nr_objects} objects: {end - start} seconds") + assert len(results.errors) == 0 + assert len(results.all_responses) == nr_objects + assert len(results.uuids) == nr_objects + assert len(client.collections.use(name)) == nr_objects + assert results.has_errors is False + assert len(results.errors) == 0, [obj.message for obj in results.errors.values()] + client.collections.delete(name) + + @pytest.mark.asyncio -async def test_add_ten_thousand_data_objects_async( +async def test_ingest_one_hundred_thousand_data_objects_async( async_client_factory: AsyncClientFactory, ) -> None: - """Test adding ten thousand data objects.""" client, name = await async_client_factory() if client._connection._weaviate_version.is_lower_than(1, 34, 0): pytest.skip("Server-side batching not supported in Weaviate < 1.34.0") @@ -831,26 +853,15 @@ async def test_add_ten_thousand_data_objects_async( import time start = time.time() - async with client.batch.experimental(concurrency=1) as batch: - async for i in arange(nr_objects): - await batch.add_object( - collection=name, - properties={"name": "test" + str(i)}, - ) + results = await client.collections.use(name).data.ingest( + {"name": "test" + str(i)} for i in range(nr_objects) + ) end = time.time() print(f"Time taken to add {nr_objects} objects: {end - start} seconds") - assert len(client.batch.results.objs.errors) == 0 - assert len(client.batch.results.objs.all_responses) == nr_objects - assert len(client.batch.results.objs.uuids) == nr_objects + assert len(results.errors) == 0 + assert len(results.all_responses) == nr_objects + assert len(results.uuids) == nr_objects assert await client.collections.use(name).length() == nr_objects - assert client.batch.results.objs.has_errors is False - assert len(client.batch.failed_objects) == 0, [ - obj.message for obj in client.batch.failed_objects - ] + assert results.has_errors is False + assert len(results.errors) == 0, [obj.message for obj in results.errors.values()] await client.collections.delete(name) - - -async def arange(count): - for i in range(count): - yield i - await asyncio.sleep(0.0) diff --git a/weaviate/client.py b/weaviate/client.py index d7f9080f4..8cf856c51 100644 --- a/weaviate/client.py +++ b/weaviate/client.py @@ -10,7 +10,7 @@ from .auth import AuthCredentials from .backup import _Backup, _BackupAsync from .cluster import _Cluster, _ClusterAsync -from .collections.batch.client import _BatchClientWrapper, _BatchClientWrapperAsync +from .collections.batch.client import _BatchClientWrapper from .collections.collections import _Collections, _CollectionsAsync from .config import AdditionalConfig from .connect import executor @@ -76,7 +76,6 @@ def __init__( ) self.alias = _AliasAsync(self._connection) self.backup = _BackupAsync(self._connection) - self.batch = _BatchClientWrapperAsync(self._connection) self.cluster = _ClusterAsync(self._connection) self.collections = _CollectionsAsync(self._connection) self.debug = _DebugAsync(self._connection) diff --git a/weaviate/client.pyi b/weaviate/client.pyi index 9b32af15f..205a34b4e 100644 --- a/weaviate/client.pyi +++ b/weaviate/client.pyi @@ -18,7 +18,7 @@ from weaviate.users.sync import _Users from .backup import _Backup, _BackupAsync from .cluster import _Cluster, _ClusterAsync -from .collections.batch.client import _BatchClientWrapper, _BatchClientWrapperAsync +from .collections.batch.client import _BatchClientWrapper from .debug import _Debug, _DebugAsync from .rbac import _Roles, _RolesAsync from .types import NUMBER @@ -29,7 +29,6 @@ class WeaviateAsyncClient(_WeaviateClientExecutor[ConnectionAsync]): _connection: ConnectionAsync alias: _AliasAsync backup: _BackupAsync - batch: _BatchClientWrapperAsync collections: _CollectionsAsync cluster: _ClusterAsync debug: _DebugAsync diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index f5f5758d3..555df84fa 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -3,6 +3,8 @@ import uuid as uuid_package from typing import ( AsyncGenerator, + Awaitable, + Callable, List, Optional, Set, @@ -15,8 +17,6 @@ ObjectsBatchRequest, ReferencesBatchRequest, _BatchDataWrapper, - _BatchMode, - _ServerSideBatching, ) from weaviate.collections.batch.grpc_batch import _BatchGRPC from weaviate.collections.classes.batch import ( @@ -36,6 +36,7 @@ from weaviate.collections.classes.types import WeaviateProperties from weaviate.connect.v4 import ConnectionAsync from weaviate.exceptions import ( + WeaviateBatchStreamError, WeaviateBatchValidationError, WeaviateGRPCUnavailableError, WeaviateStartUpError, @@ -57,7 +58,6 @@ def __init__( connection: ConnectionAsync, consistency_level: Optional[ConsistencyLevel], results: _BatchDataWrapper, - batch_mode: _BatchMode, objects: Optional[ObjectsBatchRequest[batch_pb2.BatchObject]] = None, references: Optional[ReferencesBatchRequest] = None, ) -> None: @@ -95,8 +95,6 @@ def __init__( self.__stop = False - self.__batch_mode = batch_mode - @property def number_errors(self) -> int: """Return the number of errors in the batch.""" @@ -104,12 +102,50 @@ def number_errors(self) -> int: self.__results_for_wrapper.failed_references ) + async def __wrap(self, fn: Callable[[], Awaitable[None]]): + try: + await fn() + except Exception as e: + socket_hung_up = False + if isinstance(e, WeaviateBatchStreamError) and ( + "Socket closed" in e.message or "context canceled" in e.message + ): + socket_hung_up = True + else: + logger.error(e) + logger.error(type(e)) + self.__bg_thread_exception = e + if socket_hung_up: + # this happens during ungraceful shutdown of the coordinator + # lets restart the stream and add the cached objects again + logger.warning("Stream closed unexpectedly, restarting...") + await self.__reconnect() + # server sets this whenever it restarts, gracefully or unexpectedly, so need to clear it now + self.__is_shutting_down.clear() + with self.__objs_cache_lock: + logger.warning( + f"Re-adding {len(self.__objs_cache)} cached objects to the batch" + ) + await self.__batch_objects.aprepend( + [ + self.__batch_grpc.grpc_object(o._to_internal()) + for o in self.__objs_cache.values() + ] + ) + with self.__refs_cache_lock: + await self.__batch_references.aprepend( + [ + self.__batch_grpc.grpc_reference(o._to_internal()) + for o in self.__refs_cache.values() + ] + ) + # start a new fn with a newly reconnected channel + return await fn() + async def _start(self): - assert isinstance(self.__batch_mode, _ServerSideBatching), ( - "Only server-side batching is supported in this mode" - ) return _BgTasks( - send=asyncio.create_task(self.__send()), recv=asyncio.create_task(self.__recv()) + send=asyncio.create_task(self.__wrap(self.__send)), + recv=asyncio.create_task(self.__wrap(self.__recv)), ) async def _shutdown(self) -> None: @@ -332,74 +368,6 @@ async def __reconnect(self, retry: int = 0) -> None: logger.error("Failed to reconnect after 5 attempts") self.__bg_thread_exception = e - # def __start_bg_threads(self) -> _BgThreads: - # """Create a background thread that periodically checks how congested the batch queue is.""" - # self.__shut_background_thread_down = threading.Event() - - # def batch_send_wrapper() -> None: - # try: - # self.__batch_send() - # logger.warning("exited batch send thread") - # except Exception as e: - # logger.error(e) - # self.__bg_thread_exception = e - - # def batch_recv_wrapper() -> None: - # socket_hung_up = False - # try: - # self.__batch_recv() - # logger.warning("exited batch receive thread") - # except Exception as e: - # if isinstance(e, WeaviateBatchStreamError) and ( - # "Socket closed" in e.message or "context canceled" in e.message - # ): - # socket_hung_up = True - # else: - # logger.error(e) - # logger.error(type(e)) - # self.__bg_thread_exception = e - # if socket_hung_up: - # # this happens during ungraceful shutdown of the coordinator - # # lets restart the stream and add the cached objects again - # logger.warning("Stream closed unexpectedly, restarting...") - # self.__reconnect() - # # server sets this whenever it restarts, gracefully or unexpectedly, so need to clear it now - # self.__is_shutting_down.clear() - # with self.__objs_cache_lock: - # logger.warning( - # f"Re-adding {len(self.__objs_cache)} cached objects to the batch" - # ) - # self.__batch_objects.prepend( - # [ - # self.__batch_grpc.grpc_object(o._to_internal()) - # for o in self.__objs_cache.values() - # ] - # ) - # with self.__refs_cache_lock: - # self.__batch_references.prepend( - # [ - # self.__batch_grpc.grpc_reference(o._to_internal()) - # for o in self.__refs_cache.values() - # ] - # ) - # # start a new stream with a newly reconnected channel - # return batch_recv_wrapper() - - # threads = _BgThreads( - # send=threading.Thread( - # target=batch_send_wrapper, - # daemon=True, - # name="BgBatchSend", - # ), - # recv=threading.Thread( - # target=batch_recv_wrapper, - # daemon=True, - # name="BgBatchRecv", - # ), - # ) - # threads.start_recv() - # return threads - async def flush(self) -> None: """Flush the batch queue and wait for all requests to be finished.""" # bg thread is sending objs+refs automatically, so simply wait for everything to be done diff --git a/weaviate/collections/batch/batch_wrapper.py b/weaviate/collections/batch/batch_wrapper.py index f8e40395c..928513545 100644 --- a/weaviate/collections/batch/batch_wrapper.py +++ b/weaviate/collections/batch/batch_wrapper.py @@ -1,4 +1,3 @@ -import asyncio import time from typing import Any, Generic, List, Optional, Protocol, TypeVar, Union, cast @@ -8,9 +7,7 @@ _BatchDataWrapper, _BatchMode, _ClusterBatch, - _ClusterBatchAsync, _DynamicBatching, - _ServerSideBatching, ) from weaviate.collections.batch.sync import _BatchBaseSync from weaviate.collections.classes.batch import ( @@ -24,7 +21,7 @@ from weaviate.collections.classes.tenants import Tenant from weaviate.collections.classes.types import Properties, WeaviateProperties from weaviate.connect import executor -from weaviate.connect.v4 import ConnectionAsync, ConnectionSync +from weaviate.connect.v4 import ConnectionSync from weaviate.logger import logger from weaviate.types import UUID, VECTORS from weaviate.util import _capitalize_first_letter, _decode_json_response_list @@ -131,109 +128,6 @@ def results(self) -> BatchResult: return self._batch_data.results -class _BatchWrapperAsync: - def __init__( - self, - connection: ConnectionAsync, - consistency_level: Optional[ConsistencyLevel], - ): - self._connection = connection - self._consistency_level = consistency_level - self._current_batch: Optional[_BatchBaseAsync] = None - # config options - self._batch_mode: _BatchMode = _ServerSideBatching(1) - - self._batch_data = _BatchDataWrapper() - self._cluster = _ClusterBatchAsync(connection) - - async def __is_ready( - self, max_count: int, shards: Optional[List[Shard]], backoff_count: int = 0 - ) -> bool: - try: - readinesses = await asyncio.gather( - *[ - self.__get_shards_readiness(shard) - for shard in shards or self._batch_data.imported_shards - ] - ) - return all(all(readiness) for readiness in readinesses) - except Exception as e: - logger.warning( - f"Error while getting class shards statuses: {e}, trying again with 2**n={2**backoff_count}s exponential backoff with n={backoff_count}" - ) - if backoff_count >= max_count: - raise e - await asyncio.sleep(2**backoff_count) - return await self.__is_ready(max_count, shards, backoff_count + 1) - - async def wait_for_vector_indexing( - self, shards: Optional[List[Shard]] = None, how_many_failures: int = 5 - ) -> None: - """Wait for the all the vectors of the batch imported objects to be indexed. - - Upon network error, it will retry to get the shards' status for `how_many_failures` times - with exponential backoff (2**n seconds with n=0,1,2,...,how_many_failures). - - Args: - shards: The shards to check the status of. If `None` it will check the status of all the shards of the imported objects in the batch. - how_many_failures: How many times to try to get the shards' status before raising an exception. Default 5. - """ - if shards is not None and not isinstance(shards, list): - raise TypeError(f"'shards' must be of type List[Shard]. Given type: {type(shards)}.") - if shards is not None and not isinstance(shards[0], Shard): - raise TypeError(f"'shards' must be of type List[Shard]. Given type: {type(shards)}.") - - waiting_count = 0 - while not await self.__is_ready(how_many_failures, shards): - if waiting_count % 20 == 0: # print every 5s - logger.debug("Waiting for async indexing to finish...") - await asyncio.sleep(0.25) - waiting_count += 1 - logger.debug("Async indexing finished!") - - async def __get_shards_readiness(self, shard: Shard) -> List[bool]: - path = f"/schema/{_capitalize_first_letter(shard.collection)}/shards{'' if shard.tenant is None else f'?tenant={shard.tenant}'}" - response = await executor.aresult(self._connection.get(path=path)) - - res = _decode_json_response_list(response, "Get shards' status") - assert res is not None - return [ - (cast(str, shard.get("status")) == "READY") - & (cast(int, shard.get("vectorQueueSize")) == 0) - for shard in res - ] - - async def _get_shards_readiness(self, shard: Shard) -> List[bool]: - return await self.__get_shards_readiness(shard) - - @property - def failed_objects(self) -> List[ErrorObject]: - """Get all failed objects from the batch manager. - - Returns: - A list of all the failed objects from the batch. - """ - return self._batch_data.failed_objects - - @property - def failed_references(self) -> List[ErrorReference]: - """Get all failed references from the batch manager. - - Returns: - A list of all the failed references from the batch. - """ - return self._batch_data.failed_references - - @property - def results(self) -> BatchResult: - """Get the results of the batch operation. - - Returns: - The results of the batch operation. - """ - return self._batch_data.results - - class BatchClientProtocol(Protocol): def add_object( self, @@ -311,83 +205,6 @@ def number_errors(self) -> int: ... -class BatchClientProtocolAsync(Protocol): - async def add_object( - self, - collection: str, - properties: Optional[WeaviateProperties] = None, - references: Optional[ReferenceInputs] = None, - uuid: Optional[UUID] = None, - vector: Optional[VECTORS] = None, - tenant: Optional[Union[str, Tenant]] = None, - ) -> UUID: - """Add one object to this batch. - - NOTE: If the UUID of one of the objects already exists then the existing object will be - replaced by the new object. - - Args: - collection: The name of the collection this object belongs to. - properties: The data properties of the object to be added as a dictionary. - references: The references of the object to be added as a dictionary. - uuid: The UUID of the object as an uuid.UUID object or str. It can be a Weaviate beacon or Weaviate href. - If it is None an UUIDv4 will generated, by default None - vector: The embedding of the object. Can be used when a collection does not have a vectorization module or the given - vector was generated using the _identical_ vectorization module that is configured for the class. In this - case this vector takes precedence. - Supported types are: - - for single vectors: `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, by default None. - - for named vectors: Dict[str, *list above*], where the string is the name of the vector. - tenant: The tenant name or Tenant object to be used for this request. - - Returns: - The UUID of the added object. If one was not provided a UUIDv4 will be auto-generated for you and returned here. - - Raises: - WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. - """ - ... - - async def add_reference( - self, - from_uuid: UUID, - from_collection: str, - from_property: str, - to: ReferenceInput, - tenant: Optional[Union[str, Tenant]] = None, - ) -> None: - """Add one reference to this batch. - - Args: - from_uuid: The UUID of the object, as an uuid.UUID object or str, that should reference another object. - from_collection: The name of the collection that should reference another object. - from_property: The name of the property that contains the reference. - to: The UUID of the referenced object, as an uuid.UUID object or str, that is actually referenced. - For multi-target references use wvc.Reference.to_multi_target(). - tenant: The tenant name or Tenant object to be used for this request. - - Raises: - WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. - """ - ... - - async def flush(self) -> None: - """Flush the current batch. - - This will send all the objects and references in the current batch to Weaviate. - """ - ... - - @property - def number_errors(self) -> int: - """Get the number of errors in the current batch. - - Returns: - The number of errors in the current batch. - """ - ... - - class BatchCollectionProtocol(Generic[Properties], Protocol[Properties]): def add_object( self, @@ -444,7 +261,7 @@ def number_errors(self) -> int: ... -T = TypeVar("T", bound=Union[_BatchBase, _BatchBaseSync]) +T = TypeVar("T", bound=Union[_BatchBase, _BatchBaseSync, _BatchBaseAsync]) P = TypeVar("P", bound=Union[BatchClientProtocol, BatchCollectionProtocol[Properties]]) @@ -460,82 +277,20 @@ def __enter__(self) -> P: return self.__current_batch # pyright: ignore[reportReturnType] -class BatchClientAsync(_BatchBaseAsync): - async def add_object( - self, - collection: str, - properties: Optional[WeaviateProperties] = None, - references: Optional[ReferenceInputs] = None, - uuid: Optional[UUID] = None, - vector: Optional[VECTORS] = None, - tenant: Optional[Union[str, Tenant]] = None, - ) -> UUID: - """Add one object to this batch. - - NOTE: If the UUID of one of the objects already exists then the existing object will be - replaced by the new object. - - Args: - collection: The name of the collection this object belongs to. - properties: The data properties of the object to be added as a dictionary. - references: The references of the object to be added as a dictionary. - uuid: The UUID of the object as an uuid.UUID object or str. It can be a Weaviate beacon or Weaviate href. - If it is None an UUIDv4 will generated, by default None - vector: The embedding of the object. Can be used when a collection does not have a vectorization module or the given - vector was generated using the _identical_ vectorization module that is configured for the class. In this - case this vector takes precedence. - Supported types are: - - for single vectors: `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, by default None. - - for named vectors: Dict[str, *list above*], where the string is the name of the vector. - tenant: The tenant name or Tenant object to be used for this request. - - Returns: - The UUID of the added object. If one was not provided a UUIDv4 will be auto-generated for you and returned here. - - Raises: - WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. - """ - return await super()._add_object( - collection=collection, - properties=properties, - references=references, - uuid=uuid, - vector=vector, - tenant=tenant.name if isinstance(tenant, Tenant) else tenant, - ) - - async def add_reference( - self, - from_uuid: UUID, - from_collection: str, - from_property: str, - to: ReferenceInput, - tenant: Optional[Union[str, Tenant]] = None, - ) -> None: - """Add one reference to this batch. +class _ContextManagerSync: + def __init__(self, current_batch: _BatchBaseSync): + self.__current_batch = current_batch - Args: - from_uuid: The UUID of the object, as an uuid.UUID object or str, that should reference another object. - from_collection: The name of the collection that should reference another object. - from_property: The name of the property that contains the reference. - to: The UUID of the referenced object, as an uuid.UUID object or str, that is actually referenced. - For multi-target references use wvc.Reference.to_multi_target(). - tenant: The tenant name or Tenant object to be used for this request. + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.__current_batch._shutdown() - Raises: - WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. - """ - await super()._add_reference( - from_object_uuid=from_uuid, - from_object_collection=from_collection, - from_property_name=from_property, - to=to, - tenant=tenant.name if isinstance(tenant, Tenant) else tenant, - ) + def __enter__(self) -> _BatchBaseSync: + self.__bg_tasks = self.__current_batch._start() + return self.__current_batch -class _ContextManagerWrapperAsync: - def __init__(self, current_batch: BatchClientAsync): +class _ContextManagerAsync: + def __init__(self, current_batch: _BatchBaseAsync): self.__current_batch = current_batch async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: @@ -543,6 +298,6 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.__bg_tasks.send await self.__bg_tasks.recv - async def __aenter__(self) -> BatchClientAsync: + async def __aenter__(self) -> _BatchBaseAsync: self.__bg_tasks = await self.__current_batch._start() return self.__current_batch diff --git a/weaviate/collections/batch/client.py b/weaviate/collections/batch/client.py index 0aaddd718..902d393ae 100644 --- a/weaviate/collections/batch/client.py +++ b/weaviate/collections/batch/client.py @@ -10,20 +10,17 @@ _ServerSideBatching, ) from weaviate.collections.batch.batch_wrapper import ( - BatchClientAsync, BatchClientProtocol, _BatchMode, _BatchWrapper, - _BatchWrapperAsync, _ContextManagerWrapper, - _ContextManagerWrapperAsync, ) from weaviate.collections.batch.sync import _BatchBaseSync from weaviate.collections.classes.config import ConsistencyLevel, Vectorizers from weaviate.collections.classes.internal import ReferenceInput, ReferenceInputs from weaviate.collections.classes.tenants import Tenant from weaviate.collections.classes.types import WeaviateProperties -from weaviate.connect.v4 import ConnectionAsync, ConnectionSync +from weaviate.connect.v4 import ConnectionSync from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateUnsupportedFeatureError from weaviate.types import UUID, VECTORS @@ -184,7 +181,6 @@ def add_reference( ClientBatchingContextManager = _ContextManagerWrapper[ Union[BatchClient, BatchClientSync], BatchClientProtocol ] -AsyncClientBatchingContextManager = _ContextManagerWrapperAsync class _BatchClientWrapper(_BatchWrapper): @@ -316,46 +312,3 @@ def experimental( ) self._consistency_level = consistency_level return self.__create_batch_and_reset(_BatchClientSync) - - -class _BatchClientWrapperAsync(_BatchWrapperAsync): - def __init__( - self, - connection: ConnectionAsync, - ): - super().__init__(connection, None) - self._vectorizer_batching: Optional[bool] = None - - def __create_batch_and_reset(self): - self._batch_data = _BatchDataWrapper() # clear old data - return _ContextManagerWrapperAsync( - BatchClientAsync( - connection=self._connection, - consistency_level=self._consistency_level, - results=self._batch_data, - batch_mode=self._batch_mode, - ) - ) - - def experimental( - self, - *, - concurrency: Optional[int] = None, - consistency_level: Optional[ConsistencyLevel] = None, - ) -> AsyncClientBatchingContextManager: - """Configure the batching context manager using the experimental server-side batching mode. - - When you exit the context manager, the final batch will be sent automatically. - """ - if self._connection._weaviate_version.is_lower_than(1, 34, 0): - raise WeaviateUnsupportedFeatureError( - "Server-side batching", str(self._connection._weaviate_version), "1.34.0" - ) - self._batch_mode = _ServerSideBatching( - # concurrency=concurrency - # if concurrency is not None - # else len(self._cluster.get_nodes_status()) - concurrency=1, # hard-code until client-side multi-threading is fixed - ) - self._consistency_level = consistency_level - return self.__create_batch_and_reset() diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py index d87b5db02..030d85c7e 100644 --- a/weaviate/collections/batch/sync.py +++ b/weaviate/collections/batch/sync.py @@ -14,7 +14,6 @@ _BatchMode, _BgThreads, _ClusterBatch, - _ServerSideBatching, ) from weaviate.collections.batch.grpc_batch import _BatchGRPC from weaviate.collections.classes.batch import ( @@ -51,9 +50,9 @@ def __init__( connection: ConnectionSync, consistency_level: Optional[ConsistencyLevel], results: _BatchDataWrapper, - batch_mode: _BatchMode, - executor: ThreadPoolExecutor, - vectorizer_batching: bool, + batch_mode: Optional[_BatchMode] = None, + executor: Optional[ThreadPoolExecutor] = None, + vectorizer_batching: bool = False, objects: Optional[ObjectsBatchRequest[batch_pb2.BatchObject]] = None, references: Optional[ReferencesBatchRequest] = None, ) -> None: @@ -96,8 +95,6 @@ def __init__( self.__stop = False - self.__batch_mode = batch_mode - self.__total = 0 @property @@ -118,12 +115,7 @@ def __any_threads_alive(self) -> bool: ) def _start(self) -> None: - assert isinstance(self.__batch_mode, _ServerSideBatching), ( - "Only server-side batching is supported in this mode" - ) - self.__bg_threads = [ - self.__start_bg_threads() for _ in range(self.__batch_mode.concurrency) - ] + self.__bg_threads = [self.__start_bg_threads() for _ in range(1)] logger.warning( f"Provisioned {len(self.__bg_threads)} stream(s) to the server for batch processing" ) diff --git a/weaviate/collections/data/async_.pyi b/weaviate/collections/data/async_.pyi index 15108447a..8dd092bf3 100644 --- a/weaviate/collections/data/async_.pyi +++ b/weaviate/collections/data/async_.pyi @@ -1,7 +1,6 @@ import uuid as uuid_package -from typing import Generic, List, Literal, Optional, Sequence, Union, overload +from typing import Generic, Iterable, List, Literal, Optional, Sequence, Union, overload -from weaviate.collections.batch.collection import _BatchCollectionWrapper from weaviate.collections.batch.grpc_batch import _BatchGRPC from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC from weaviate.collections.batch.rest import _BatchREST @@ -30,7 +29,6 @@ class _DataCollectionAsync( __batch_delete: _BatchDeleteGRPC __batch_grpc: _BatchGRPC __batch_rest: _BatchREST - __batch: _BatchCollectionWrapper[Properties] async def insert( self, @@ -81,3 +79,6 @@ class _DataCollectionAsync( async def delete_many( self, where: _Filters, *, verbose: bool = False, dry_run: bool = False ) -> Union[DeleteManyReturn[List[DeleteManyObject]], DeleteManyReturn[None]]: ... + async def ingest( + self, objs: Iterable[Union[Properties, DataObject[Properties, Optional[ReferenceInputs]]]] + ) -> BatchObjectReturn: ... diff --git a/weaviate/collections/data/executor.py b/weaviate/collections/data/executor.py index 8d6d12d40..2a1a2e296 100644 --- a/weaviate/collections/data/executor.py +++ b/weaviate/collections/data/executor.py @@ -5,6 +5,7 @@ Any, Dict, Generic, + Iterable, List, Literal, Mapping, @@ -19,10 +20,13 @@ from httpx import Response -from weaviate.collections.batch.collection import _BatchCollectionWrapper +from weaviate.collections.batch.async_ import _BatchBaseAsync +from weaviate.collections.batch.base import _BatchDataWrapper +from weaviate.collections.batch.batch_wrapper import _ContextManagerAsync, _ContextManagerSync from weaviate.collections.batch.grpc_batch import _BatchGRPC from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC from weaviate.collections.batch.rest import _BatchREST +from weaviate.collections.batch.sync import _BatchBaseSync from weaviate.collections.classes.batch import ( BatchObjectReturn, BatchReferenceReturn, @@ -61,7 +65,6 @@ class _DataCollectionExecutor(Generic[ConnectionType, Properties]): __batch_delete: _BatchDeleteGRPC __batch_grpc: _BatchGRPC __batch_rest: _BatchREST - __batch: _BatchCollectionWrapper[Properties] def __init__( self, @@ -704,3 +707,82 @@ def __parse_vector(self, obj: Dict[str, Any], vector: VECTORS) -> Dict[str, Any] else: obj["vector"] = _get_vector_v4(vector) return obj + + def ingest( + self, objs: Iterable[Union[Properties, DataObject[Properties, Optional[ReferenceInputs]]]] + ) -> executor.Result[BatchObjectReturn]: + """Ingest multiple objects into the collection in batches. The batching is handled automatically for you by Weaviate. + + This is different from `insert_many` which sends all objects in a single batch request. Use this method when you want to insert a large number of objects without worrying about batch sizes + and whether they will fit into the maximum allowed batch size of your Weaviate instance. In addition, use this instead of `client.batch.dynamic()` or `collection.batch.dynamic()` for a more + performant dynamic batching algorithm that utilizes server-side batching. + + Args: + objs: An iterable of objects to insert. This can be either a sequence of `Properties` or `DataObject[Properties, ReferenceInputs]` + If you didn't set `data_model` then `Properties` will be `Data[str, Any]` in which case you can insert simple dictionaries here. + """ + if isinstance(self._connection, ConnectionAsync): + con = self._connection + + async def execute() -> BatchObjectReturn: + results = _BatchDataWrapper() + ctx = _ContextManagerAsync( + _BatchBaseAsync( + connection=con, + results=results, + consistency_level=self._consistency_level, + ) + ) + async with ctx as batch: + for obj in objs: + if isinstance(obj, DataObject): + await batch._add_object( + collection=self.name, + properties=cast(dict, obj.properties), + references=obj.references, + uuid=obj.uuid, + vector=obj.vector, + tenant=self._tenant, + ) + else: + await batch._add_object( + collection=self.name, + properties=cast(dict, obj), + references=None, + uuid=None, + vector=None, + tenant=self._tenant, + ) + return results.results.objs + + return execute() + + results = _BatchDataWrapper() + ctx = _ContextManagerSync( + _BatchBaseSync( + connection=self._connection, + results=results, + consistency_level=self._consistency_level, + ) + ) + with ctx as batch: + for obj in objs: + if isinstance(obj, DataObject): + batch._add_object( + collection=self.name, + properties=cast(dict, obj.properties), + references=obj.references, + uuid=obj.uuid, + vector=obj.vector, + tenant=self._tenant, + ) + else: + batch._add_object( + collection=self.name, + properties=cast(dict, obj), + references=None, + uuid=None, + vector=None, + tenant=self._tenant, + ) + return results.results.objs diff --git a/weaviate/collections/data/sync.pyi b/weaviate/collections/data/sync.pyi index eda3da21a..ab1eb3f39 100644 --- a/weaviate/collections/data/sync.pyi +++ b/weaviate/collections/data/sync.pyi @@ -1,7 +1,6 @@ import uuid as uuid_package -from typing import Generic, List, Literal, Optional, Sequence, Union, overload +from typing import Generic, Iterable, List, Literal, Optional, Sequence, Union, overload -from weaviate.collections.batch.collection import _BatchCollectionWrapper from weaviate.collections.batch.grpc_batch import _BatchGRPC from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC from weaviate.collections.batch.rest import _BatchREST @@ -28,7 +27,6 @@ class _DataCollection(Generic[Properties,], _DataCollectionExecutor[ConnectionSy __batch_delete: _BatchDeleteGRPC __batch_grpc: _BatchGRPC __batch_rest: _BatchREST - __batch: _BatchCollectionWrapper[Properties] def insert( self, @@ -79,3 +77,6 @@ class _DataCollection(Generic[Properties,], _DataCollectionExecutor[ConnectionSy def delete_many( self, where: _Filters, *, verbose: bool = False, dry_run: bool = False ) -> Union[DeleteManyReturn[List[DeleteManyObject]], DeleteManyReturn[None]]: ... + def ingest( + self, objs: Iterable[Union[Properties, DataObject[Properties, Optional[ReferenceInputs]]]] + ) -> BatchObjectReturn: ...