@@ -267,7 +267,7 @@ def event_loop_normal(self) -> None:
267267 attention_dp_cached_prefill_tasks = []
268268 attention_dp_wait_prefill_iters = 0
269269 is_first_chunk_dict = {}
270-
270+ idx_request_id_dict = {}
271271 while True :
272272 if local_rank == 0 :
273273 if self .model_weights_status .value [0 ] != 0 :
@@ -350,6 +350,15 @@ def event_loop_normal(self) -> None:
350350 tmp_need_cached_prefills = []
351351 if len (req_dicts ) > 0 :
352352 for request in req_dicts :
353+ if request .idx not in idx_request_id_dict :
354+ idx_request_id_dict [request .idx ] = request .request_id
355+ else :
356+ if (
357+ idx_request_id_dict [request .idx ] != request .request_id
358+ ): # next request in this slot, delete data to prevent memory leak
359+ if idx_request_id_dict [request .idx ] in is_first_chunk_dict :
360+ del is_first_chunk_dict [idx_request_id_dict [request .idx ]]
361+ idx_request_id_dict [request .idx ] = request .request_id
353362 if request .task_type .value == RequestType .PREFILL .value :
354363 tmp_need_cached_prefills .append (request )
355364 if request .request_id not in is_first_chunk_dict :
@@ -371,6 +380,7 @@ def event_loop_normal(self) -> None:
371380 only_prefill_batch_list = []
372381 paddle .distributed .all_gather_object (only_prefill_batch_list , exist_prefill )
373382 if_only_prefill = all (only_prefill_batch_list )
383+ exist_non_first_chunk = False
374384 if if_only_prefill : # all ranks have prefill tasks
375385 # add a prefill task to current step
376386 if len (attention_dp_cached_prefill_tasks ) > 0 :
0 commit comments