
Commit 48f3090

[NIXL][Misc] Expose metrics from NIXL for logging to CLI (#25388)
Signed-off-by: NickLucche <[email protected]>
1 parent 0e93ac0 commit 48f3090

4 files changed, +127 -28 lines changed


requirements/kv_connectors.txt

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 lmcache
-nixl >= 0.5.1 # Required for disaggregated prefill
+nixl >= 0.6.0 # Required for disaggregated prefill
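
The pin moves from nixl >= 0.5.1 to >= 0.6.0 because the connector below imports the telemetry binding (nixl._bindings.nixlXferTelemetry), which older releases do not expose. A minimal sketch to check a local install, mirroring the lazy import guard in nixl_connector.py (the version floor is an assumption taken from this pin):

# Sketch: verify the installed nixl ships the telemetry binding used below.
try:
    from nixl._bindings import nixlXferTelemetry  # noqa: F401
    print("NIXL telemetry binding available")
except ImportError:
    print("nixl missing or older than 0.6.0; upgrade to use transfer telemetry")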

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 54 additions & 11 deletions
@@ -57,6 +57,26 @@ def clear_kv_transfer():
     ensure_kv_transfer_shutdown()


+def get_default_xfer_telemetry(xferDurationS: float = 1,
+                               postDurationS: float = 1,
+                               totalBytes: int = 1,
+                               descCount: int = 1) -> dict:
+
+    class AttributeDict(dict):
+        __slots__ = ()
+        __getattr__ = dict.__getitem__
+        __setattr__ = dict.__setitem__  # type: ignore[assignment]
+
+    # We can't instantiate nixlXferTelemetry because it's read only and
+    # ray env does not have NIXL, so we must fake it
+    return AttributeDict(
+        xferDuration=xferDurationS * 1e6,  # in us
+        postDuration=postDurationS * 1e6,  # in us
+        totalBytes=totalBytes,
+        descCount=descCount,
+    )
+
+
 class FakeNixlWrapper:
     """Mock implementation of NixlWrapper for testing.

@@ -132,6 +152,9 @@ def make_prepped_xfer(self,
     def transfer(self, handle: int) -> str:
         return "PROC"

+    def get_xfer_telemetry(self, handle: int) -> dict:
+        return get_default_xfer_telemetry()
+
     ############################################################
     # Follow are for changing the behavior during testing.
     ############################################################
@@ -169,6 +192,11 @@ def _make_fake_nixl_pkg():
         with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
             f.write(stub)

+        # Mock nixlXferTelemetry class
+        pkg_root2 = os.path.join(td, "nixl", "_bindings")
+        os.makedirs(pkg_root2, exist_ok=True)
+        with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
+            f.write("class nixlXferTelemetry: pass")
         # touch parent package
         open(os.path.join(td, "nixl", "__init__.py"), "w").close()
         yield td
@@ -575,7 +603,7 @@ def test_kv_connector_stats(dist_init):

     # Verify stats values are recorded
     assert not stats_after_transfer.is_empty()
-    assert stats_after_transfer.data["num_successful_transfers"] == 1
+    assert stats_after_transfer.num_successful_transfers == 1

     # Verify stats are reset after retrieval
     stats_after_reset = connector.get_kv_connector_stats()
@@ -599,16 +627,21 @@ def test_kv_connector_stats_aggregation():

     # Record different transfers on each worker
     # Worker 1: 2 transfers
-    worker1_stats.record_transfer()
-    worker1_stats.record_transfer()
+    stats = get_default_xfer_telemetry()
+    worker1_stats.record_transfer(stats)
+    worker1_stats.record_transfer(stats)

     # Worker 2: 1 transfer
-    worker2_stats.record_transfer()
+    worker2_stats.record_transfer(stats)

     # Worker 3: 3 transfers
-    worker3_stats.record_transfer()
-    worker3_stats.record_transfer()
-    worker3_stats.record_transfer()
+    stats = get_default_xfer_telemetry(xferDurationS=2,
+                                       postDurationS=2,
+                                       totalBytes=2,
+                                       descCount=2)
+    worker3_stats.record_transfer(stats)
+    worker3_stats.record_transfer(stats)
+    worker3_stats.record_transfer(stats)

     # Create ModelRunnerOutput instances for each worker
     worker_outputs = []
@@ -636,7 +669,12 @@ def test_kv_connector_stats_aggregation():
         aggregated_output.kv_connector_output.kv_connector_stats
     assert isinstance(kv_connector_stats, NixlKVConnectorStats)
     # Number of total transfers across all workers.
-    assert kv_connector_stats.data["num_successful_transfers"] == 6
+    assert kv_connector_stats.num_successful_transfers == 6
+    # Logging proc, call reduce() to get CLI-friendly stats.
+    cli_stats = kv_connector_stats.reduce()
+    assert cli_stats["Avg xfer time (ms)"] == 1500.0
+    assert cli_stats["Avg post time (ms)"] == 1500.0
+    assert cli_stats["Avg number of descriptors"] == 1.5


 def test_multi_kv_connector_stats_aggregation():
@@ -649,6 +687,7 @@ def test_multi_kv_connector_stats_aggregation():

     from dataclasses import dataclass

+    # Mock a KVConnectorStats class for testing aggregation over connectors.
     @dataclass
     class FooKVConnectorStats(KVConnectorStats):

@@ -676,7 +715,7 @@ def make_multi_stats(nixl_count: int,
         if nixl_count > 0:
             nixl_stats = NixlKVConnectorStats()
             for _ in range(nixl_count):
-                nixl_stats.record_transfer()
+                nixl_stats.record_transfer(get_default_xfer_telemetry())
             data["NixlConnector"] = nixl_stats
         if foo_count > 0:
             foo_stats = FooKVConnectorStats()
@@ -712,8 +751,10 @@ def make_multi_stats(nixl_count: int,
     assert isinstance(kv_connector_stats, MultiKVConnectorStats)

     # Validate per-connector totals across workers
-    assert kv_connector_stats["NixlConnector"].data[
-        "num_successful_transfers"] == 5
+    assert isinstance(kv_connector_stats["NixlConnector"],
+                      NixlKVConnectorStats)
+    assert kv_connector_stats["NixlConnector"].num_successful_transfers == 5
+    assert isinstance(kv_connector_stats["FooConnector"], FooKVConnectorStats)
     assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6


@@ -755,6 +796,8 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
         "working_dir": working_dir,  # ship fake nixl package
         "env_vars": {
             "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
+            # TODO: for ray to carry over, remove once we set
+            "NIXL_TELEMETRY_ENABLE": "1",
         },
     }
     ray.init(runtime_env=runtime_env)
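
The AttributeDict helper above is only needed because the real nixlXferTelemetry object is read-only and NIXL is absent in the Ray test environment; a plain dict that forwards attribute access lets the fake be read exactly like the real object (res.xferDuration, res.totalBytes, ...) while remaining picklable. A minimal illustration of the same pattern:

# Minimal sketch of the AttributeDict pattern used by get_default_xfer_telemetry:
# attribute access is forwarded to the underlying dict entries.
class AttributeDict(dict):
    __slots__ = ()
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__  # type: ignore[assignment]


fake = AttributeDict(xferDuration=1e6, postDuration=1e6, totalBytes=1, descCount=1)
assert fake.xferDuration == fake["xferDuration"] == 1e6  # same entry, two views

The expected reductions in test_kv_connector_stats_aggregation follow from this helper: workers 1 and 2 record three transfers at 1 s and worker 3 records three at 2 s, so the mean transfer (and post) time is 9 s / 6 = 1.5 s = 1500 ms, and the mean descriptor count is (3*1 + 3*2) / 6 = 1.5.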

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 69 additions & 13 deletions
@@ -4,6 +4,7 @@
 import copy
 import logging
 import math
+import os
 import queue
 import threading
 import time
@@ -54,10 +55,12 @@
 # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
 try:
     from nixl._api import nixl_agent as NixlWrapper
+    from nixl._bindings import nixlXferTelemetry
     logger.info("NIXL is available")
 except ImportError:
     logger.warning("NIXL is not available")
     NixlWrapper = None
+    nixlXferTelemetry = None

 try:
     from nixl._api import nixl_agent_config
@@ -476,6 +479,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.nixl_backends = \
             vllm_config.kv_transfer_config.get_from_extra_config(
                 "backends", ["UCX"])
+        # TODO temporary, once nixl allows for telemetry flag in config
+        # (next release), we can remove this env var.
+        os.environ["NIXL_TELEMETRY_ENABLE"] = "1"
         # Agent.
         non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
         if nixl_agent_config is None:
@@ -1175,9 +1181,10 @@ def _pop_done_transfers(
         for handle, _xfer_stime in handles:
             xfer_state = self.nixl_wrapper.check_xfer_state(handle)
             if xfer_state == "DONE":
+                # Get telemetry from NIXL
+                res = self.nixl_wrapper.get_xfer_telemetry(handle)
+                self.xfer_stats.record_transfer(res)
                 self.nixl_wrapper.release_xfer_handle(handle)
-                # TODO (NickLucche) Get from NIXL telemetry once integrated
-                self.xfer_stats.record_transfer()
             elif xfer_state == "PROC":
                 in_progress = True
                 continue
@@ -1449,32 +1456,81 @@ class NixlKVConnectorStats(KVConnectorStats):
     """Container for transfer performance metrics"""

     def __post_init__(self):
-        if "num_successful_transfers" not in self.data:
-            self.data["num_successful_transfers"] = 0
+        if not self.data:
+            # Empty container init, no data is passed in.
+            self.reset()

     def reset(self):
-        self.data = {"num_successful_transfers": 0}
+        # Must be serializable
+        self.data: dict[str, list[float]] = {
+            "transfer_duration": [],
+            "post_duration": [],
+            "bytes_transferred": [],
+            "num_descriptors": [],
+        }

-    def record_transfer(self):
-        # TODO: record actual transfer stats when available
-        self.data["num_successful_transfers"] += 1
+    def record_transfer(self, res: nixlXferTelemetry):
+        # Keep metrics units consistent with rest of the code: time us->s
+        self.data["transfer_duration"].append(res.xferDuration / 1e6)
+        self.data["post_duration"].append(res.postDuration / 1e6)
+        self.data["bytes_transferred"].append(res.totalBytes)
+        self.data["num_descriptors"].append(res.descCount)

     def clone_and_reset(self) -> "NixlKVConnectorStats":
         old = copy.copy(self)
         self.reset()
         return old

     def is_empty(self) -> bool:
-        return self.data["num_successful_transfers"] == 0
+        return self.num_successful_transfers == 0

     def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
         if not other.is_empty():
-            self.data["num_successful_transfers"] += other.data[
-                "num_successful_transfers"]
+            for k, v in other.data.items():
+                accumulator = self.data[k]
+                assert isinstance(accumulator, list)
+                accumulator.extend(v)
         return self

     def reduce(self) -> dict[str, Union[int, float]]:
-        # TODO: reduce stats to a single value, calculate latency/throughput
+        # Compute compact representative stats suitable for CLI logging
+        if self.is_empty():
+            return {
+                "Num successful transfers": 0,
+                "Avg xfer time (ms)": 0,
+                "P90 xfer time (ms)": 0,
+                "Avg post time (ms)": 0,
+                "P90 post time (ms)": 0,
+                "Avg MB per transfer": 0,
+                "Throughput (MB/s)": 0,
+                "Avg number of descriptors": 0,
+            }
+
+        xfer_time = np.asarray(self.data["transfer_duration"])
+        post_time = np.asarray(self.data["post_duration"])
+        # Convert to MB for CLI logging.
+        mb = np.asarray(self.data["bytes_transferred"]) / 2**20
+        descs = np.asarray(self.data["num_descriptors"], dtype=np.uint32)
+        n = len(descs)
+        assert n == self.num_successful_transfers
+
+        total_mb = mb.sum()
+        avg_mb = total_mb / n
+
+        total_time_seconds = xfer_time.sum()
+        throughput_mb_s = total_mb / total_time_seconds
+
         return {
-            "num_successful_transfers": self.data["num_successful_transfers"]
+            "Num successful transfers": n,
+            "Avg xfer time (ms)": round(xfer_time.mean() * 1e3, 3),
+            "P90 xfer time (ms)": round(np.percentile(xfer_time, 90) * 1e3, 3),
+            "Avg post time (ms)": round(post_time.mean() * 1e3, 3),
+            "P90 post time (ms)": round(np.percentile(post_time, 90) * 1e3, 3),
+            "Avg MB per transfer": round(avg_mb, 3),
+            "Throughput (MB/s)": round(throughput_mb_s, 3),
+            "Avg number of descriptors": round(descs.mean(), 1),
         }
+
+    @property
+    def num_successful_transfers(self) -> int:
+        return len(self.data["transfer_duration"])
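
Taken together, record_transfer() accumulates raw per-transfer telemetry and reduce() turns it into the CLI-friendly dictionary the logger prints. A usage sketch (assumes a vLLM build containing this commit; fake_telemetry is a hypothetical stand-in for the real nixlXferTelemetry object, which reports durations in microseconds):

# Sketch: feed fake telemetry into NixlKVConnectorStats and reduce it,
# mirroring what _pop_done_transfers() and the stats logger do.
from types import SimpleNamespace

from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlKVConnectorStats)


def fake_telemetry(seconds: float) -> SimpleNamespace:
    # Hypothetical stand-in: same attribute names as nixlXferTelemetry.
    return SimpleNamespace(xferDuration=seconds * 1e6,  # us
                           postDuration=seconds * 1e6,  # us
                           totalBytes=2**20,            # 1 MiB
                           descCount=1)


stats = NixlKVConnectorStats()
for seconds in (1, 1, 1, 2, 2, 2):  # 3 transfers at 1 s, 3 at 2 s
    stats.record_transfer(fake_telemetry(seconds))

# Mean of (1, 1, 1, 2, 2, 2) s is 1.5 s -> "Avg xfer time (ms)" == 1500.0
print(stats.reduce())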

vllm/v1/metrics/loggers.py

Lines changed: 3 additions & 3 deletions
@@ -62,7 +62,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         self.prefix_caching_metrics = PrefixCachingMetrics()
         self.spec_decoding_logging = SpecDecodingLogging()
         kv_tranfer_config = self.vllm_config.kv_transfer_config
-        self.kv_transfer_logging = KVConnectorLogging(kv_tranfer_config)
+        self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
         self.last_prompt_throughput: float = 0.0
         self.last_generation_throughput: float = 0.0
@@ -101,7 +101,7 @@ def record(self,
             self.spec_decoding_logging.observe(
                 scheduler_stats.spec_decoding_stats)
         if kv_connector_stats := scheduler_stats.kv_connector_stats:
-            self.kv_transfer_logging.observe(kv_connector_stats)
+            self.kv_connector_logging.observe(kv_connector_stats)
         self.last_scheduler_stats = scheduler_stats

     def log(self):
@@ -140,7 +140,7 @@ def log(self):
             self.prefix_caching_metrics.hit_rate * 100,
         )
         self.spec_decoding_logging.log(log_fn=log_fn)
-        self.kv_transfer_logging.log(log_fn=log_fn)
+        self.kv_connector_logging.log(log_fn=log_fn)

     def log_engine_initialized(self):
         if self.vllm_config.cache_config.num_gpu_blocks:
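
The rename from kv_transfer_logging to kv_connector_logging is purely cosmetic; the call pattern stays the same: observe() on every record() and log(log_fn=...) on the periodic logging pass. A hypothetical stand-in (not vLLM's actual KVConnectorLogging) showing how a logger built on KVConnectorStats.reduce() fits that pattern:

# Hypothetical stand-in illustrating the observe()/log(log_fn=...) call
# pattern used by LoggingStatLogger above; not the real KVConnectorLogging.
import logging
from typing import Callable

logger = logging.getLogger(__name__)


class MiniConnectorLogging:
    def __init__(self) -> None:
        self._latest = None

    def observe(self, kv_connector_stats) -> None:
        # Keep the latest snapshot handed over with the scheduler stats.
        self._latest = kv_connector_stats

    def log(self, log_fn: Callable = logger.info) -> None:
        if self._latest is None or self._latest.is_empty():
            return
        # reduce() yields the CLI-friendly dict defined in nixl_connector.py.
        log_fn("KV connector transfer metrics: %s", self._latest.reduce())
        self._latest = None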
