@@ -266,6 +266,8 @@ def event_loop_normal(self) -> None:
266266 self .model_weights_signal = np .zeros ([1 ], dtype = np .int32 )
267267 attention_dp_cached_prefill_tasks = []
268268 attention_dp_wait_prefill_iters = 0
269+ is_first_chunk_dict = {}
270+
269271 while True :
270272 if local_rank == 0 :
271273 if self .model_weights_status .value [0 ] != 0 :
@@ -350,6 +352,15 @@ def event_loop_normal(self) -> None:
350352 for request in req_dicts :
351353 if request .task_type .value == RequestType .PREFILL .value :
352354 tmp_need_cached_prefills .append (request )
355+ if request .request_id not in is_first_chunk_dict :
356+ is_first_chunk_dict [request .request_id ] = True
357+ request .is_first_chunk = True
358+ else :
359+ is_first_chunk_dict [request .request_id ] = False
360+ request .is_first_chunk = False
361+ else :
362+ is_first_chunk_dict [request .request_id ] = False
363+ request .is_first_chunk = False
353364 if tmp_need_cached_prefills :
354365 attention_dp_cached_prefill_tasks .append (tmp_need_cached_prefills )
355366 for request in tmp_need_cached_prefills :
@@ -366,12 +377,20 @@ def event_loop_normal(self) -> None:
366377 req_dicts .extend (attention_dp_cached_prefill_tasks .pop (0 ))
367378 attention_dp_wait_prefill_iters = 0
368379 else :
369- # wait until all ranks have prefill tasks or reached timeout
370- attention_dp_wait_prefill_iters += 1
371- if attention_dp_wait_prefill_iters > self .fd_config .attention_dp_time_out_iters :
372- if len (attention_dp_cached_prefill_tasks ) > 0 :
373- req_dicts .extend (attention_dp_cached_prefill_tasks .pop (0 ))
380+ if len (attention_dp_cached_prefill_tasks ) > 0 :
381+ for task in attention_dp_cached_prefill_tasks [0 ]:
382+ if not task .is_first_chunk :
383+ exist_non_first_chunk = True
384+ if exist_non_first_chunk :
385+ req_dicts .extend (attention_dp_cached_prefill_tasks .pop (0 ))
374386 attention_dp_wait_prefill_iters = 0
387+ else :
388+ # wait until all ranks have prefill tasks or reached timeout
389+ attention_dp_wait_prefill_iters += 1
390+ if attention_dp_wait_prefill_iters > self .fd_config .attention_dp_time_out_iters :
391+ if len (attention_dp_cached_prefill_tasks ) > 0 :
392+ req_dicts .extend (attention_dp_cached_prefill_tasks .pop (0 ))
393+ attention_dp_wait_prefill_iters = 0
375394
376395 if len (req_dicts ) > 0 :
377396 req_ids = [req .request_id for req in req_dicts ]
0 commit comments