@@ -57,6 +57,26 @@ def clear_kv_transfer():
5757 ensure_kv_transfer_shutdown ()
5858
5959
def get_default_xfer_telemetry(xferDurationS: float = 1,
                               postDurationS: float = 1,
                               totalBytes: int = 1,
                               descCount: int = 1) -> dict:
    """Build a fake transfer-telemetry object for tests.

    Args:
        xferDurationS: Transfer duration, in seconds.
        postDurationS: Post duration, in seconds.
        totalBytes: Value exposed as ``totalBytes``.
        descCount: Value exposed as ``descCount``.

    Returns:
        A ``dict`` subclass whose entries can also be read and written as
        attributes. ``xferDuration`` and ``postDuration`` hold the given
        durations converted to microseconds.
    """

    class AttributeDict(dict):
        """Dict whose keys are also accessible as attributes."""

        __slots__ = ()

        def __getattr__(self, name: str):
            # Missing attributes must surface as AttributeError (not
            # KeyError) so that hasattr(), getattr() with a default and
            # copy.deepcopy() keep working on this object.
            try:
                return self[name]
            except KeyError:
                raise AttributeError(name) from None

        __setattr__ = dict.__setitem__  # type: ignore[assignment]

    # We can't instantiate nixlXferTelemetry because it's read only and
    # ray env does not have NIXL, so we must fake it
    return AttributeDict(
        xferDuration=xferDurationS * 1e6,  # in us
        postDuration=postDurationS * 1e6,  # in us
        totalBytes=totalBytes,
        descCount=descCount,
    )
78+
79+
6080class FakeNixlWrapper :
6181 """Mock implementation of NixlWrapper for testing.
6282
@@ -132,6 +152,9 @@ def make_prepped_xfer(self,
def transfer(self, handle: int) -> str:
    """Report the fixed status string "PROC" for any ``handle``."""
    status = "PROC"
    return status
134154
def get_xfer_telemetry(self, handle: int) -> dict:
    """Return the default fake telemetry; ``handle`` is not consulted."""
    telemetry = get_default_xfer_telemetry()
    return telemetry
157+
135158 ############################################################
136159 # The following are for changing the behavior during testing.
137160 ############################################################
@@ -169,6 +192,11 @@ def _make_fake_nixl_pkg():
169192 with open (os .path .join (pkg_root , "__init__.py" ), "w" ) as f :
170193 f .write (stub )
171194
195+ # Mock nixlXferTelemetry class
196+ pkg_root2 = os .path .join (td , "nixl" , "_bindings" )
197+ os .makedirs (pkg_root2 , exist_ok = True )
198+ with open (os .path .join (pkg_root2 , "__init__.py" ), "w" ) as f :
199+ f .write ("class nixlXferTelemetry: pass" )
172200 # touch parent package
173201 open (os .path .join (td , "nixl" , "__init__.py" ), "w" ).close ()
174202 yield td
@@ -575,7 +603,7 @@ def test_kv_connector_stats(dist_init):
575603
576604 # Verify stats values are recorded
577605 assert not stats_after_transfer .is_empty ()
578- assert stats_after_transfer .data [ " num_successful_transfers" ] == 1
606+ assert stats_after_transfer .num_successful_transfers == 1
579607
580608 # Verify stats are reset after retrieval
581609 stats_after_reset = connector .get_kv_connector_stats ()
@@ -599,16 +627,21 @@ def test_kv_connector_stats_aggregation():
599627
600628 # Record different transfers on each worker
601629 # Worker 1: 2 transfers
602- worker1_stats .record_transfer ()
603- worker1_stats .record_transfer ()
630+ stats = get_default_xfer_telemetry ()
631+ worker1_stats .record_transfer (stats )
632+ worker1_stats .record_transfer (stats )
604633
605634 # Worker 2: 1 transfer
606- worker2_stats .record_transfer ()
635+ worker2_stats .record_transfer (stats )
607636
608637 # Worker 3: 3 transfers
609- worker3_stats .record_transfer ()
610- worker3_stats .record_transfer ()
611- worker3_stats .record_transfer ()
638+ stats = get_default_xfer_telemetry (xferDurationS = 2 ,
639+ postDurationS = 2 ,
640+ totalBytes = 2 ,
641+ descCount = 2 )
642+ worker3_stats .record_transfer (stats )
643+ worker3_stats .record_transfer (stats )
644+ worker3_stats .record_transfer (stats )
612645
613646 # Create ModelRunnerOutput instances for each worker
614647 worker_outputs = []
@@ -636,7 +669,12 @@ def test_kv_connector_stats_aggregation():
636669 aggregated_output .kv_connector_output .kv_connector_stats
637670 assert isinstance (kv_connector_stats , NixlKVConnectorStats )
638671 # Number of total transfers across all workers.
639- assert kv_connector_stats .data ["num_successful_transfers" ] == 6
672+ assert kv_connector_stats .num_successful_transfers == 6
673+ # Logging proc, call reduce() to get CLI-friendly stats.
674+ cli_stats = kv_connector_stats .reduce ()
675+ assert cli_stats ["Avg xfer time (ms)" ] == 1500.0
676+ assert cli_stats ["Avg post time (ms)" ] == 1500.0
677+ assert cli_stats ["Avg number of descriptors" ] == 1.5
640678
641679
642680def test_multi_kv_connector_stats_aggregation ():
@@ -649,6 +687,7 @@ def test_multi_kv_connector_stats_aggregation():
649687
650688 from dataclasses import dataclass
651689
690+ # Mock a KVConnectorStats class for testing aggregation over connectors.
652691 @dataclass
653692 class FooKVConnectorStats (KVConnectorStats ):
654693
@@ -676,7 +715,7 @@ def make_multi_stats(nixl_count: int,
676715 if nixl_count > 0 :
677716 nixl_stats = NixlKVConnectorStats ()
678717 for _ in range (nixl_count ):
679- nixl_stats .record_transfer ()
718+ nixl_stats .record_transfer (get_default_xfer_telemetry () )
680719 data ["NixlConnector" ] = nixl_stats
681720 if foo_count > 0 :
682721 foo_stats = FooKVConnectorStats ()
@@ -712,8 +751,10 @@ def make_multi_stats(nixl_count: int,
712751 assert isinstance (kv_connector_stats , MultiKVConnectorStats )
713752
714753 # Validate per-connector totals across workers
715- assert kv_connector_stats ["NixlConnector" ].data [
716- "num_successful_transfers" ] == 5
754+ assert isinstance (kv_connector_stats ["NixlConnector" ],
755+ NixlKVConnectorStats )
756+ assert kv_connector_stats ["NixlConnector" ].num_successful_transfers == 5
757+ assert isinstance (kv_connector_stats ["FooConnector" ], FooKVConnectorStats )
717758 assert kv_connector_stats ["FooConnector" ].data ["num_foo_transfers" ] == 6
718759
719760
@@ -755,6 +796,8 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
755796 "working_dir" : working_dir , # ship fake nixl package
756797 "env_vars" : {
757798 "VLLM_NIXL_ABORT_REQUEST_TIMEOUT" : str (timeout ),
799+ # TODO: set here so that it carries over to the ray workers; remove once it no longer has to be set manually.
800+ "NIXL_TELEMETRY_ENABLE" : "1" ,
758801 },
759802 }
760803 ray .init (runtime_env = runtime_env )
0 commit comments