|
29 | 29 | from typing import Dict, List, Union |
30 | 30 |
|
31 | 31 | import triton_python_backend_utils as pb_utils |
| 32 | +from vllm.config import VllmConfig |
32 | 33 | from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase |
33 | 34 | from vllm.engine.metrics import Stats as VllmStats |
34 | 35 | from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets |
@@ -163,11 +164,13 @@ def __init__(self, labels: List[str], max_model_len: int): |
163 | 164 | class VllmStatLogger(VllmStatLoggerBase): |
164 | 165 | """StatLogger is used as an adapter between vLLM stats collector and Triton metrics provider.""" |
165 | 166 |
|
166 | | - def __init__(self, labels: Dict, max_model_len: int, log_logger) -> None: |
| 167 | + def __init__(self, labels: Dict, vllm_config: VllmConfig, log_logger) -> None: |
167 | 168 | # Tracked stats over current local logging interval. |
168 | 169 | # local_interval not used here. It's for vLLM logs to stdout. |
169 | | - super().__init__(local_interval=0) |
170 | | - self.metrics = TritonMetrics(labels, max_model_len) |
| 170 | + super().__init__(local_interval=0, vllm_config=vllm_config) |
| 171 | + self.metrics = TritonMetrics( |
| 172 | + labels=labels, max_model_len=vllm_config.model_config.max_model_len |
| 173 | + ) |
171 | 174 | self.log_logger = log_logger |
172 | 175 |
|
173 | 176 | # Starting the metrics thread. It allows vLLM to keep making progress |
|
0 commit comments