@@ -134,7 +134,6 @@ def __init__(self,
134
134
# PUT or PUT_ASYNC
135
135
# tensor_id: torch.Tensor
136
136
self .send_queue : deque [SendQueueItem ] = deque ()
137
- self .send_request_id_to_tensor_ids : dict [str , set [str ]] = {}
138
137
if self .send_type == "PUT_ASYNC" :
139
138
self ._send_thread = threading .Thread (target = self .send_async ,
140
139
daemon = True )
@@ -143,6 +142,7 @@ def __init__(self,
143
142
# tensor_id: torch.Tensor/(addr, dtype, shape)
144
143
self .recv_store : dict [str , Any ] = {}
145
144
self .recv_request_id_to_tensor_ids : dict [str , set [str ]] = {}
145
+ self .send_request_id_to_tensor_ids : dict [str , set [str ]] = {}
146
146
self .socks : dict [str , Any ] = {} # remote_address: client socket
147
147
self .comms : dict [str , Any ] = {} # remote_address: (ncclComm_t, rank)
148
148
@@ -223,18 +223,26 @@ def send_tensor(
223
223
# GET
224
224
with self .send_store_cv :
225
225
tensor_size = tensor .element_size () * tensor .numel ()
226
+ if tensor_size > self .buffer_size_threshold :
227
+ logger .warning (
228
+ "❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
229
+ "buffer size threshold :%d, skip send to %s, rank:%d" ,
230
+ tensor_id , tensor_size , self .buffer_size_threshold ,
231
+ remote_address , self .rank )
232
+ return False
226
233
while (self .buffer_size + tensor_size
227
234
> self .buffer_size_threshold ):
228
- oldest_tenser_id = next (iter (self .send_store ))
229
- oldest_tenser = self .send_store .pop (oldest_tenser_id )
230
- oldest_tenser_size = oldest_tenser .element_size (
231
- ) * oldest_tenser .numel ()
232
- self .buffer_size -= oldest_tenser_size
233
- logger .info (
235
+ assert len (self .send_store ) > 0
236
+ oldest_tensor_id = next (iter (self .send_store ))
237
+ oldest_tensor = self .send_store .pop (oldest_tensor_id )
238
+ oldest_tensor_size = oldest_tensor .element_size (
239
+ ) * oldest_tensor .numel ()
240
+ self .buffer_size -= oldest_tensor_size
241
+ logger .debug (
234
242
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
235
- " buffer_size:%d, oldest_tenser_size :%d, rank:%d" ,
243
+ " buffer_size:%d, oldest_tensor_size :%d, rank:%d" ,
236
244
remote_address , tensor_id , tensor_size , self .buffer_size ,
237
- oldest_tenser_size , self .rank )
245
+ oldest_tensor_size , self .rank )
238
246
239
247
self .send_store [tensor_id ] = tensor
240
248
self .buffer_size += tensor_size
0 commit comments