@@ -95,7 +95,7 @@ def prepend(self, item: List[TBatchInput]) -> None:
9595 self ._lock .release ()
9696
9797
98- Ref = TypeVar ("Ref" , bound = Union [ _BatchReference , batch_pb2 . BatchReference ] )
98+ Ref = TypeVar ("Ref" , bound = BatchReference )
9999
100100
101101class ReferencesBatchRequest (BatchRequest [Ref , BatchReferenceReturn ]):
@@ -111,8 +111,9 @@ def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
111111 i = 0
112112 self ._lock .acquire ()
113113 while len (ret ) < pop_amount and len (self ._items ) > 0 and i < len (self ._items ):
114- if self ._items [i ].from_uuid not in uuid_lookup and (
115- self ._items [i ].to_uuid is None or self ._items [i ].to_uuid not in uuid_lookup
114+ if self ._items [i ].from_object_uuid not in uuid_lookup and (
115+ self ._items [i ].to_object_uuid is None
116+ or self ._items [i ].to_object_uuid not in uuid_lookup
116117 ):
117118 ret .append (self ._items .pop (i ))
118119 else :
@@ -132,7 +133,7 @@ def head(self) -> Optional[Ref]:
132133 return item
133134
134135
135- Obj = TypeVar ("Obj" , bound = Union [ _BatchObject , batch_pb2 . BatchObject ] )
136+ Obj = TypeVar ("Obj" , bound = BatchObject )
136137
137138
138139class ObjectsBatchRequest (Generic [Obj ], BatchRequest [Obj , BatchObjectReturn ]):
@@ -843,11 +844,11 @@ def __init__(
843844 batch_mode : _BatchMode ,
844845 executor : ThreadPoolExecutor ,
845846 vectorizer_batching : bool ,
846- objects : Optional [ObjectsBatchRequest [batch_pb2 . BatchObject ]] = None ,
847- references : Optional [ReferencesBatchRequest ] = None ,
847+ objects : Optional [ObjectsBatchRequest [BatchObject ]] = None ,
848+ references : Optional [ReferencesBatchRequest [ BatchReference ] ] = None ,
848849 ) -> None :
849- self .__batch_objects = objects or ObjectsBatchRequest [batch_pb2 . BatchObject ]()
850- self .__batch_references = references or ReferencesBatchRequest [batch_pb2 . BatchReference ]()
850+ self .__batch_objects = objects or ObjectsBatchRequest [BatchObject ]()
851+ self .__batch_references = references or ReferencesBatchRequest [BatchReference ]()
851852
852853 self .__connection = connection
853854 self .__consistency_level : ConsistencyLevel = consistency_level or ConsistencyLevel .QUORUM
@@ -879,6 +880,10 @@ def __init__(
879880 self .__objs_cache : dict [str , BatchObject ] = {}
880881 self .__refs_cache : dict [str , BatchReference ] = {}
881882
883+ self .__acks_lock = threading .Lock ()
884+ self .__inflight_objs : set [str ] = set ()
885+ self .__inflight_refs : set [str ] = set ()
886+
882887 # maxsize=1 so that __batch_send does not run faster than generator for __batch_recv
883888 # thereby using too much buffer in case of server-side shutdown
884889 self .__reqs : Queue [Optional [batch_pb2 .BatchStreamRequest ]] = Queue (maxsize = 1 )
@@ -1005,10 +1010,13 @@ def __batch_send(self) -> None:
10051010 return
10061011 time .sleep (refresh_time )
10071012
1013+ def __beacon (self , ref : batch_pb2 .BatchReference ) -> str :
1014+ return f"weaviate://localhost/{ ref .from_collection } { f'#{ ref .tenant } ' if ref .tenant != '' else '' } /{ ref .from_uuid } #{ ref .name } ->/{ ref .to_collection } /{ ref .to_uuid } "
1015+
10081016 def __generate_stream_requests (
10091017 self ,
1010- objs : List [batch_pb2 . BatchObject ],
1011- refs : List [batch_pb2 . BatchReference ],
1018+ objects : List [BatchObject ],
1019+ references : List [BatchReference ],
10121020 ) -> Generator [batch_pb2 .BatchStreamRequest , None , None ]:
10131021 per_object_overhead = 4 # extra overhead bytes per object in the request
10141022
@@ -1018,7 +1026,8 @@ def request_maker():
10181026 request = request_maker ()
10191027 total_size = request .ByteSize ()
10201028
1021- for obj in objs :
1029+ for object_ in objects :
1030+ obj = self .__batch_grpc .grpc_object (object_ ._to_internal ())
10221031 obj_size = obj .ByteSize () + per_object_overhead
10231032
10241033 if total_size + obj_size >= self .__batch_grpc .grpc_max_msg_size :
@@ -1028,8 +1037,12 @@ def request_maker():
10281037
10291038 request .data .objects .values .append (obj )
10301039 total_size += obj_size
1040+ if self .__connection ._weaviate_version .is_at_least (1 , 35 , 0 ):
1041+ with self .__acks_lock :
1042+ self .__inflight_objs .add (obj .uuid )
10311043
1032- for ref in refs :
1044+ for reference in references :
1045+ ref = self .__batch_grpc .grpc_reference (reference ._to_internal ())
10331046 ref_size = ref .ByteSize () + per_object_overhead
10341047
10351048 if total_size + ref_size >= self .__batch_grpc .grpc_max_msg_size :
@@ -1039,6 +1052,9 @@ def request_maker():
10391052
10401053 request .data .references .values .append (ref )
10411054 total_size += ref_size
1055+ if self .__connection ._weaviate_version .is_at_least (1 , 35 , 0 ):
1056+ with self .__acks_lock :
1057+ self .__inflight_refs .add (reference ._to_beacon ())
10421058
10431059 if len (request .data .objects .values ) > 0 or len (request .data .references .values ) > 0 :
10441060 yield request
@@ -1091,6 +1107,10 @@ def __batch_recv(self) -> None:
10911107 logger .warning (
10921108 f"Updated batch size to { self .__batch_size } as per server request"
10931109 )
1110+ if message .HasField ("acks" ):
1111+ with self .__acks_lock :
1112+ self .__inflight_objs .difference_update (message .acks .uuids )
1113+ self .__inflight_refs .difference_update (message .acks .beacons )
10941114 if message .HasField ("results" ):
10951115 result_objs = BatchObjectReturn ()
10961116 result_refs = BatchReferenceReturn ()
@@ -1241,19 +1261,9 @@ def batch_recv_wrapper() -> None:
12411261 logger .warning (
12421262 f"Re-adding { len (self .__objs_cache )} cached objects to the batch"
12431263 )
1244- self .__batch_objects .prepend (
1245- [
1246- self .__batch_grpc .grpc_object (o ._to_internal ())
1247- for o in self .__objs_cache .values ()
1248- ]
1249- )
1264+ self .__batch_objects .prepend (list (self .__objs_cache .values ()))
12501265 with self .__refs_cache_lock :
1251- self .__batch_references .prepend (
1252- [
1253- self .__batch_grpc .grpc_reference (o ._to_internal ())
1254- for o in self .__refs_cache .values ()
1255- ]
1256- )
1266+ self .__batch_references .prepend (list (self .__refs_cache .values ()))
12571267 # start a new stream with a newly reconnected channel
12581268 return batch_recv_wrapper ()
12591269
@@ -1307,14 +1317,14 @@ def _add_object(
13071317 uuid = str (batch_object .uuid )
13081318 with self .__uuid_lookup_lock :
13091319 self .__uuid_lookup .add (uuid )
1310- self .__batch_objects .add (self . __batch_grpc . grpc_object ( batch_object . _to_internal ()) )
1320+ self .__batch_objects .add (batch_object )
13111321 with self .__objs_cache_lock :
13121322 self .__objs_cache [uuid ] = batch_object
13131323 self .__objs_count += 1
13141324
13151325 # block if queue gets too long or weaviate is overloaded - reading files is faster than sending them so we do
13161326 # not need a long queue
1317- while len (self .__batch_objects ) >= self .__batch_size * 2 :
1327+ while len (self .__inflight_objs ) >= self .__batch_size * 2 :
13181328 self .__check_bg_threads_alive ()
13191329 time .sleep (0.01 )
13201330
@@ -1352,12 +1362,13 @@ def _add_reference(
13521362 )
13531363 except ValidationError as e :
13541364 raise WeaviateBatchValidationError (repr (e ))
1355- self .__batch_references .add (
1356- self .__batch_grpc .grpc_reference (batch_reference ._to_internal ())
1357- )
1365+ self .__batch_references .add (batch_reference )
13581366 with self .__refs_cache_lock :
13591367 self .__refs_cache [batch_reference ._to_beacon ()] = batch_reference
13601368 self .__refs_count += 1
1369+ while len (self .__inflight_refs ) >= self .__batch_size * 2 :
1370+ self .__check_bg_threads_alive ()
1371+ time .sleep (0.01 )
13611372
13621373 def __check_bg_threads_alive (self ) -> None :
13631374 if self .__any_threads_alive ():
0 commit comments