Skip to content

Commit 906781b

Browse files
authored
[https://nvbugs/5948539][fix] Fix disagg gen-only benchmark (NVIDIA#12091)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
1 parent be57adb commit 906781b

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1636,25 +1636,30 @@ def _prepare_and_schedule_batch(self):
16361636
self._check_disagg_gen_transfer_status()
16371637
self._check_kv_transfer_timeout()
16381638

1639-
# In gen-only benchmark mode with disaggregated serving, keep fetching
1640-
# until all real requests have arrived before adding ADP dummies.
1641-
# This ensures the benchmark starts with the exact number of real
1642-
# requests specified, since dummies only get added after this loop.
1639+
# In benchmark disagg mode, fetch requests in batches to avoid
1640+
# blocking the CTX→GEN KV cache pipeline. With ADP, fetch tp_size
1641+
# requests per batch (one per rank) for even distribution; without
1642+
# ADP, fetch 1 request per batch.
16431643
if not self.is_warmup and self.benchmark_req_queues_size > 0 \
16441644
and self.kv_cache_transceiver \
16451645
and self.num_fetch_requests < self.benchmark_req_queues_size:
1646+
batch_size = min(
1647+
self.dist.tp_size if self.enable_attention_dp else 1,
1648+
self.benchmark_req_queues_size)
1649+
fill_target = min(self.num_fetch_requests + batch_size,
1650+
self.benchmark_req_queues_size)
16461651
if self.dist.rank == 0:
16471652
logger.info(f"Starting benchmark fill loop, "
16481653
f"num_fetch_requests={self.num_fetch_requests}/"
1649-
f"{self.benchmark_req_queues_size}, "
1654+
f"{fill_target}, "
16501655
f"len(active_requests)={len(self.active_requests)}")
1651-
while self.num_fetch_requests < self.benchmark_req_queues_size:
1656+
while self.num_fetch_requests < fill_target:
16521657
iter_requests = self._fetch_and_activate_new_requests()
16531658
if self.should_stop_processing:
16541659
return None, None
16551660
new_requests += iter_requests
16561661
self.hang_detector.checkpoint()
1657-
if self.num_fetch_requests < self.benchmark_req_queues_size:
1662+
if self.num_fetch_requests < fill_target:
16581663
time.sleep(1)
16591664

16601665
iter_stats = None

0 commit comments

Comments
 (0)