Skip to content

Commit 9e74bb3

Browse files
committed
Add logic to handle ack messages, capping in-flight stream usage
1 parent e9d49f8 commit 9e74bb3

File tree

16 files changed

+544
-428
lines changed

16 files changed

+544
-428
lines changed

.github/workflows/main.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ env:
2626
WEAVIATE_132: 1.32.16
2727
WEAVIATE_133: 1.33.4
2828
WEAVIATE_134: 1.34.0
29+
WEAVIATE_135: 1.35.0-dev-8d38bb2.amd64
2930

3031
jobs:
3132
lint-and-format:
@@ -304,7 +305,8 @@ jobs:
304305
$WEAVIATE_131,
305306
$WEAVIATE_132,
306307
$WEAVIATE_133,
307-
$WEAVIATE_134
308+
$WEAVIATE_134,
309+
$WEAVIATE_135
308310
]
309311
steps:
310312
- name: Checkout

profiling/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ def _factory(
6161
headers=headers,
6262
additional_config=AdditionalConfig(timeout=(60, 120)), # for image tests
6363
)
64+
# client_fixture = weaviate.connect_to_weaviate_cloud(
65+
# cluster_url="flnyoj61teuw1mxfwf1fsa.c0.europe-west3.gcp.weaviate.cloud",
66+
# auth_credentials=weaviate.auth.Auth.api_key("QnVtdnlnM2RYeUh3NVlFNF82V3pqVEtoYnloMlo0MHV2R2hYMU9BUFFsR3cvUUlkUG9CTFRiQXNjam1nPV92MjAw"),
67+
# )
6468
client_fixture.collections.delete(name_fixture)
6569
if integration_config is not None:
6670
client_fixture.integrations.configure(integration_config)

profiling/test_sphere.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_sphere(collection_factory: CollectionFactory) -> None:
2828
start = time.time()
2929

3030
import_objects = 1000000
31-
with collection.batch.dynamic() as batch:
31+
with collection.batch.experimental() as batch:
3232
with open(sphere_file) as jsonl_file:
3333
for i, jsonl in enumerate(jsonl_file):
3434
if i == import_objects or batch.number_errors > 10:
@@ -46,7 +46,7 @@ def test_sphere(collection_factory: CollectionFactory) -> None:
4646
vector=json_parsed["vector"],
4747
)
4848
if i % 1000 == 0:
49-
print(f"Imported {len(collection)} objects")
49+
print(f"Imported {len(collection)} objects after processing {i} lines")
5050
assert len(collection.batch.failed_objects) == 0
5151
assert len(collection) == import_objects
5252
print(f"Imported {import_objects} objects in {time.time() - start}")

weaviate/collections/batch/base.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def prepend(self, item: List[TBatchInput]) -> None:
9595
self._lock.release()
9696

9797

98-
Ref = TypeVar("Ref", bound=Union[_BatchReference, batch_pb2.BatchReference])
98+
Ref = TypeVar("Ref", bound=BatchReference)
9999

100100

101101
class ReferencesBatchRequest(BatchRequest[Ref, BatchReferenceReturn]):
@@ -111,8 +111,9 @@ def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
111111
i = 0
112112
self._lock.acquire()
113113
while len(ret) < pop_amount and len(self._items) > 0 and i < len(self._items):
114-
if self._items[i].from_uuid not in uuid_lookup and (
115-
self._items[i].to_uuid is None or self._items[i].to_uuid not in uuid_lookup
114+
if self._items[i].from_object_uuid not in uuid_lookup and (
115+
self._items[i].to_object_uuid is None
116+
or self._items[i].to_object_uuid not in uuid_lookup
116117
):
117118
ret.append(self._items.pop(i))
118119
else:
@@ -132,7 +133,7 @@ def head(self) -> Optional[Ref]:
132133
return item
133134

134135

135-
Obj = TypeVar("Obj", bound=Union[_BatchObject, batch_pb2.BatchObject])
136+
Obj = TypeVar("Obj", bound=BatchObject)
136137

137138

138139
class ObjectsBatchRequest(Generic[Obj], BatchRequest[Obj, BatchObjectReturn]):
@@ -843,11 +844,11 @@ def __init__(
843844
batch_mode: _BatchMode,
844845
executor: ThreadPoolExecutor,
845846
vectorizer_batching: bool,
846-
objects: Optional[ObjectsBatchRequest[batch_pb2.BatchObject]] = None,
847-
references: Optional[ReferencesBatchRequest] = None,
847+
objects: Optional[ObjectsBatchRequest[BatchObject]] = None,
848+
references: Optional[ReferencesBatchRequest[BatchReference]] = None,
848849
) -> None:
849-
self.__batch_objects = objects or ObjectsBatchRequest[batch_pb2.BatchObject]()
850-
self.__batch_references = references or ReferencesBatchRequest[batch_pb2.BatchReference]()
850+
self.__batch_objects = objects or ObjectsBatchRequest[BatchObject]()
851+
self.__batch_references = references or ReferencesBatchRequest[BatchReference]()
851852

852853
self.__connection = connection
853854
self.__consistency_level: ConsistencyLevel = consistency_level or ConsistencyLevel.QUORUM
@@ -879,6 +880,10 @@ def __init__(
879880
self.__objs_cache: dict[str, BatchObject] = {}
880881
self.__refs_cache: dict[str, BatchReference] = {}
881882

883+
self.__acks_lock = threading.Lock()
884+
self.__inflight_objs: set[str] = set()
885+
self.__inflight_refs: set[str] = set()
886+
882887
# maxsize=1 so that __batch_send does not run faster than generator for __batch_recv
883888
# thereby using too much buffer in case of server-side shutdown
884889
self.__reqs: Queue[Optional[batch_pb2.BatchStreamRequest]] = Queue(maxsize=1)
@@ -1005,10 +1010,13 @@ def __batch_send(self) -> None:
10051010
return
10061011
time.sleep(refresh_time)
10071012

1013+
def __beacon(self, ref: batch_pb2.BatchReference) -> str:
1014+
return f"weaviate://localhost/{ref.from_collection}{f'#{ref.tenant}' if ref.tenant != '' else ''}/{ref.from_uuid}#{ref.name}->/{ref.to_collection}/{ref.to_uuid}"
1015+
10081016
def __generate_stream_requests(
10091017
self,
1010-
objs: List[batch_pb2.BatchObject],
1011-
refs: List[batch_pb2.BatchReference],
1018+
objects: List[BatchObject],
1019+
references: List[BatchReference],
10121020
) -> Generator[batch_pb2.BatchStreamRequest, None, None]:
10131021
per_object_overhead = 4 # extra overhead bytes per object in the request
10141022

@@ -1018,7 +1026,8 @@ def request_maker():
10181026
request = request_maker()
10191027
total_size = request.ByteSize()
10201028

1021-
for obj in objs:
1029+
for object_ in objects:
1030+
obj = self.__batch_grpc.grpc_object(object_._to_internal())
10221031
obj_size = obj.ByteSize() + per_object_overhead
10231032

10241033
if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size:
@@ -1028,8 +1037,12 @@ def request_maker():
10281037

10291038
request.data.objects.values.append(obj)
10301039
total_size += obj_size
1040+
if self.__connection._weaviate_version.is_at_least(1, 35, 0):
1041+
with self.__acks_lock:
1042+
self.__inflight_objs.add(obj.uuid)
10311043

1032-
for ref in refs:
1044+
for reference in references:
1045+
ref = self.__batch_grpc.grpc_reference(reference._to_internal())
10331046
ref_size = ref.ByteSize() + per_object_overhead
10341047

10351048
if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size:
@@ -1039,6 +1052,9 @@ def request_maker():
10391052

10401053
request.data.references.values.append(ref)
10411054
total_size += ref_size
1055+
if self.__connection._weaviate_version.is_at_least(1, 35, 0):
1056+
with self.__acks_lock:
1057+
self.__inflight_refs.add(reference._to_beacon())
10421058

10431059
if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0:
10441060
yield request
@@ -1091,6 +1107,10 @@ def __batch_recv(self) -> None:
10911107
logger.warning(
10921108
f"Updated batch size to {self.__batch_size} as per server request"
10931109
)
1110+
if message.HasField("acks"):
1111+
with self.__acks_lock:
1112+
self.__inflight_objs.difference_update(message.acks.uuids)
1113+
self.__inflight_refs.difference_update(message.acks.beacons)
10941114
if message.HasField("results"):
10951115
result_objs = BatchObjectReturn()
10961116
result_refs = BatchReferenceReturn()
@@ -1241,19 +1261,9 @@ def batch_recv_wrapper() -> None:
12411261
logger.warning(
12421262
f"Re-adding {len(self.__objs_cache)} cached objects to the batch"
12431263
)
1244-
self.__batch_objects.prepend(
1245-
[
1246-
self.__batch_grpc.grpc_object(o._to_internal())
1247-
for o in self.__objs_cache.values()
1248-
]
1249-
)
1264+
self.__batch_objects.prepend(list(self.__objs_cache.values()))
12501265
with self.__refs_cache_lock:
1251-
self.__batch_references.prepend(
1252-
[
1253-
self.__batch_grpc.grpc_reference(o._to_internal())
1254-
for o in self.__refs_cache.values()
1255-
]
1256-
)
1266+
self.__batch_references.prepend(list(self.__refs_cache.values()))
12571267
# start a new stream with a newly reconnected channel
12581268
return batch_recv_wrapper()
12591269

@@ -1307,14 +1317,14 @@ def _add_object(
13071317
uuid = str(batch_object.uuid)
13081318
with self.__uuid_lookup_lock:
13091319
self.__uuid_lookup.add(uuid)
1310-
self.__batch_objects.add(self.__batch_grpc.grpc_object(batch_object._to_internal()))
1320+
self.__batch_objects.add(batch_object)
13111321
with self.__objs_cache_lock:
13121322
self.__objs_cache[uuid] = batch_object
13131323
self.__objs_count += 1
13141324

13151325
# block if queue gets too long or weaviate is overloaded - reading files is faster them sending them so we do
13161326
# not need a long queue
1317-
while len(self.__batch_objects) >= self.__batch_size * 2:
1327+
while len(self.__inflight_objs) >= self.__batch_size * 2:
13181328
self.__check_bg_threads_alive()
13191329
time.sleep(0.01)
13201330

@@ -1352,12 +1362,13 @@ def _add_reference(
13521362
)
13531363
except ValidationError as e:
13541364
raise WeaviateBatchValidationError(repr(e))
1355-
self.__batch_references.add(
1356-
self.__batch_grpc.grpc_reference(batch_reference._to_internal())
1357-
)
1365+
self.__batch_references.add(batch_reference)
13581366
with self.__refs_cache_lock:
13591367
self.__refs_cache[batch_reference._to_beacon()] = batch_reference
13601368
self.__refs_count += 1
1369+
while len(self.__inflight_refs) >= self.__batch_size * 2:
1370+
self.__check_bg_threads_alive()
1371+
time.sleep(0.01)
13611372

13621373
def __check_bg_threads_alive(self) -> None:
13631374
if self.__any_threads_alive():

0 commit comments

Comments (0)