Skip to content

Commit 12bf10c

Browse files
hyukn authored and videodanchik committed
[TRTLLM-9615][feat] Support synchronization through PP ranks in the distributed tuning system (NVIDIA#10011)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Daniil Kulko <kulkodaniil@gmail.com>
1 parent 88c4955 commit 12bf10c

File tree

3 files changed

+64
-6
lines changed

3 files changed

+64
-6
lines changed

tensorrt_llm/_torch/autotuner.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from tensorrt_llm.logger import logger
2323
from tensorrt_llm.mapping import Mapping
2424

25+
# Unique tag to avoid collisions with other comms
26+
PP_COMM_TAG_AUTOTUNING = 30000
27+
2528

2629
class DistributedTuningStrategy(enum.Enum):
2730
"""
@@ -358,7 +361,7 @@ class AutoTunerProfilingCache:
358361
"""
359362

360363
def __init__(self):
361-
self.cache = {}
364+
self.cache: Dict[Tuple, Tuple] = dict()
362365

363366
# Cache metadata for local storage and validation
364367
self.lib_version = tensorrt_llm.__version__
@@ -430,7 +433,7 @@ def get_cache_key(
430433
),
431434
)
432435

433-
def merge_cache_data(self, cache_data: Dict[str, Any]):
436+
def merge_cache_data(self, cache_data: Dict[Tuple, Tuple]):
434437
self.cache.update(cache_data)
435438

436439
def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]:
@@ -615,7 +618,10 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
615618
self._last_capture: Optional['AutoTuner.TacticsCapture'] = None
616619

617620
# Dsitributed tuning state
621+
self._map_op_to_distributed_strategy: Dict[
622+
str, DistributedTuningStrategy] = {}
618623
self._dist: Optional[Distributed] = None
624+
self._has_received_cache: bool = False
619625
self.mapping: Mapping = Mapping()
620626

621627
@classmethod
@@ -624,9 +630,6 @@ def get(cls):
624630
cls._instance = AutoTuner()
625631
return cls._instance
626632

627-
def set_mapping(self, mapping: Mapping = None):
628-
self.mapping = mapping
629-
630633
class TacticsCapture:
631634
"""Object returned by capture() that can be iterated to get all tactic combinations.
632635
@@ -797,10 +800,18 @@ def choose_one(
797800
if self.is_tuning_mode and is_cache_hit:
798801
return (runners[best_runner_id], best_tactic)
799802

803+
# PP rank does not have cache hit, so we try to receive the cache from the previous rank
804+
# Notice that only under tuning mode, pp_recv will be called
805+
self.cache_pp_recv()
806+
800807
assert len(runners) > 0, "At least one runner is required"
801808
assert all([isinstance(r, TunableRunner) for r in runners]), \
802809
"All Given runners must be subclass of TunableRunner"
803810

811+
# Record the distributed tuning strategy for the custom_op
812+
self._map_op_to_distributed_strategy[
813+
custom_op] = tuning_config.distributed_tuning_strategy
814+
804815
tuning_start_time = time.perf_counter()
805816
profiles = self._optimization_profiles(tuning_config, inputs)
806817

@@ -1507,3 +1518,32 @@ def _should_current_rank_tune(self,
15071518
f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent"
15081519
)
15091520
return True
1521+
1522+
def cache_pp_recv(self):
1523+
if self.mapping.has_pp() and not self.mapping.is_first_pp_rank(
1524+
) and not self._has_received_cache:
1525+
self._debug_logger(
1526+
f"[AutoTuner] Receiving cache data from previous pp rank {self.mapping.prev_pp_rank()}"
1527+
)
1528+
profiling_cache = self._dist.recv_object(
1529+
src=self.mapping.prev_pp_rank(),
1530+
tag=PP_COMM_TAG_AUTOTUNING,
1531+
)
1532+
# Guarantee that only receive cache once during a single warm-up run
1533+
# Notice that this flag should be reset after each warm-up run because isend is always called
1534+
self._has_received_cache = True
1535+
self.profiling_cache.merge_cache_data(profiling_cache)
1536+
1537+
def cache_pp_send(self):
1538+
if self.mapping.has_pp() and not self.mapping.is_last_pp_rank():
1539+
self._debug_logger(
1540+
f"[AutoTuner] Sending cache data to next pp rank {self.mapping.next_pp_rank()}"
1541+
)
1542+
self._dist.isend_object(
1543+
self.profiling_cache.cache,
1544+
dest=self.mapping.next_pp_rank(),
1545+
tag=PP_COMM_TAG_AUTOTUNING,
1546+
).wait()
1547+
1548+
def clean_pp_flag(self):
1549+
self._has_received_cache = False

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,14 @@ def _(
693693

694694
class NVFP4GemmUnifiedRunner(TunableRunner):
695695
runner_dict = dict()
696+
tuning_config = TuningConfig(
697+
dynamic_tensor_specs=(DynamicTensorSpec(
698+
0, 0, get_last_power_of_2_num_tokens_buckets,
699+
last_positive_power_of_2), ),
700+
constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
701+
# nested tuning should always be independent
702+
distributed_tuning_strategy=DistributedTuningStrategy.INDEPENDENT,
703+
)
696704

697705
def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
698706
allowed_backends: List[str]):
@@ -943,7 +951,7 @@ def nvfp4_gemm(
943951
_, best_tactic = tuner.choose_one(
944952
"trtllm::nvfp4_gemm::gemm",
945953
[runner],
946-
FP4GemmRunner.
954+
NVFP4GemmUnifiedRunner.
947955
tuning_config, # All runners use the same tuning_config
948956
[act_fp4, weight, act_sf, weight_scale, alpha],
949957
)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,9 +667,19 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager):
667667
if self.is_draft_model and isinstance(
668668
spec_resource_manager, Eagle3ResourceManager):
669669
spec_resource_manager.is_first_draft = True
670+
670671
self.forward(batch,
671672
new_tensors_device=None,
672673
resource_manager=resource_manager)
674+
675+
# pp_recv in AutoTuner choose_one will never be called if there is no tuning op during the forward pass.
676+
# So we need to make an extra call to consume the previous rank's pp_send to guarantee that the previous rank's pp_send is released.
677+
AutoTuner.get().cache_pp_recv()
678+
# Send the cache after the tuning process to the next PP rank
679+
AutoTuner.get().cache_pp_send()
680+
# Clean the pp flag to avoid deadlock with synchronous send/recv
681+
AutoTuner.get().clean_pp_flag()
682+
673683
torch.cuda.synchronize()
674684

675685
logger.info(

0 commit comments

Comments
 (0)