-
Notifications
You must be signed in to change notification settings - Fork 618
[Profiler] Add Nsight Systems support for serving #1098
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d7c91aa
c88f6c5
5ee0b19
fc4d9eb
214df9c
7dc426d
9c71ca7
33e451f
b0c7853
e518553
4d5104d
076924c
27805c2
8a1e42d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,12 +10,14 @@ | |
| import multiprocessing as mp | ||
| import os | ||
| from contextlib import AbstractContextManager, nullcontext | ||
| from typing import Any | ||
|
|
||
| import torch | ||
| import zmq | ||
| from vllm.config import VllmConfig | ||
| from vllm.distributed.device_communicators.shm_broadcast import MessageQueue | ||
| from vllm.logger import init_logger | ||
| from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper | ||
| from vllm.utils.mem_utils import GiB_bytes | ||
|
|
||
| from vllm_omni.diffusion.data import ( | ||
|
|
@@ -29,7 +31,6 @@ | |
| ) | ||
| from vllm_omni.diffusion.forward_context import set_forward_context | ||
| from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager | ||
| from vllm_omni.diffusion.profiler import CurrentProfiler | ||
| from vllm_omni.diffusion.request import OmniDiffusionRequest | ||
| from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner | ||
| from vllm_omni.lora.request import LoRARequest | ||
|
|
@@ -65,6 +66,7 @@ def __init__( | |
| self.model_runner: DiffusionModelRunner | None = None | ||
| self._sleep_saved_buffers: dict[str, torch.Tensor] = {} | ||
| self.lora_manager: DiffusionLoRAManager | None = None | ||
| self.profiler: Any | None = None | ||
| self.init_device() | ||
|
|
||
| def init_device(self) -> None: | ||
|
|
@@ -89,6 +91,20 @@ def init_device(self) -> None: | |
| vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size | ||
| self.vllm_config = vllm_config | ||
|
|
||
| # Initialize profiler based on profiler_config (follows vLLM pattern) | ||
| profiler_config = vllm_config.profiler_config | ||
| if profiler_config.profiler == "torch": | ||
| worker_name = f"diffusion-rank-{self.rank}" | ||
| self.profiler = TorchProfilerWrapper( | ||
| profiler_config, | ||
| worker_name=worker_name, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing `activities` argument: self.profiler = TorchProfilerWrapper(
profiler_config,
worker_name=worker_name,
local_rank=self.local_rank,
activities=["CPU", "CUDA"], # <-- add this
). Without it, the torch profiler may not capture CUDA kernels, which defeats the purpose of nsys integration. |
||
| local_rank=self.local_rank, | ||
| ) | ||
| elif profiler_config.profiler == "cuda": | ||
| self.profiler = CudaProfilerWrapper(profiler_config) | ||
| else: | ||
| self.profiler = None | ||
|
|
||
| # Initialize distributed environment | ||
| with set_forward_context(vllm_config=vllm_config, omni_diffusion_config=self.od_config): | ||
| init_distributed_environment(world_size=world_size, rank=rank) | ||
|
|
@@ -129,15 +145,27 @@ def generate(self, request: OmniDiffusionRequest) -> DiffusionOutput: | |
| """Generate output for the given requests.""" | ||
| return self.execute_model(request, self.od_config) | ||
|
|
||
| @classmethod | ||
| def start_profile(cls, trace_path_template: str) -> str: | ||
| """Start profiling for this GPU worker.""" | ||
| return CurrentProfiler.start(trace_path_template) | ||
| def start_profile(self, trace_path_template: str = "") -> str: | ||
| """Start profiling for this GPU worker. | ||
|
|
||
| Uses vLLM's profiler wrappers based on profiler_config: | ||
| - 'torch': TorchProfilerWrapper for detailed CPU/CUDA traces | ||
| - 'cuda': CudaProfilerWrapper for nsys integration | ||
|
|
||
| @classmethod | ||
| def stop_profile(cls) -> dict | None: | ||
| Set VLLM_TORCH_CUDA_PROFILE=1 for nsys/cuda profiler, or | ||
| VLLM_TORCH_PROFILER_DIR for torch profiler. | ||
| """ | ||
| if self.profiler is not None: | ||
| self.profiler.start() | ||
| logger.info("Diffusion worker %s: profiler started", self.rank) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
| return trace_path_template | ||
|
|
||
| def stop_profile(self) -> dict | None: | ||
| """Stop profiling and return the result dictionary.""" | ||
| return CurrentProfiler.stop() | ||
| if self.profiler is not None: | ||
| self.profiler.stop() | ||
| logger.info("Diffusion worker %s: profiler stopped", self.rank) | ||
| return None | ||
|
|
||
| def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> DiffusionOutput: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
def stop_profile(self) -> dict | None:
if self.profiler is not None:
return self.profiler.stop()
return None |
||
| """Execute a forward pass by delegating to the model runner.""" | ||
|
|
@@ -149,7 +177,15 @@ def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfi | |
| if req.sampling_params.lora_request is not None: | ||
| raise | ||
| logger.warning("LoRA activation skipped: %s", exc) | ||
| return self.model_runner.execute_model(req) | ||
| profiler_context = ( | ||
| self.profiler.annotate_context_manager("diffusion_forward") if self.profiler is not None else nullcontext() | ||
| ) | ||
| with profiler_context: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good use of `nullcontext()`. |
||
| output = self.model_runner.execute_model(req) | ||
| if self.profiler is not None: | ||
| # Drive delayed start/auto-stop behavior to match vLLM's profiler wrapper. | ||
| self.profiler.step() | ||
| return output | ||
|
|
||
| def load_weights(self, weights) -> set[str]: | ||
| """Load weights by delegating to the model runner.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Since workers always return
`None` right now, the entire `for rank, res in enumerate(results)` loop body is effectively dead code (the `if res is None: continue` skips everything). This will become useful after fixing the worker's `stop_profile()` to return the wrapper's result.