@@ -196,63 +196,49 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest):
196196
197197 def start_profile (self , trace_filename : str | None = None ) -> None :
198198 """
199- Start torch profiling on all diffusion workers.
199+ Start profiling on all diffusion workers.
200200
201- Creates a directory (if needed) and sets up a base filename template
202- for per-rank profiler traces (typically saved as <template>_rank<N>.json).
203-
204- Args:
205- trace_filename: Optional base filename (without extension or rank suffix).
206- If None, generates one using current timestamp.
201+ Profiling is configured via vLLM's profiler config/environment variables:
202+ - PyTorch profiler: VLLM_TORCH_PROFILER_DIR
203+ - Nsight Systems (cuda profiler): VLLM_TORCH_CUDA_PROFILE=1
207204 """
208- if trace_filename is None :
209- trace_filename = f"stage_0_diffusion_{ int (time .time ())} _rank"
210-
211- trace_dir = os .environ .get ("VLLM_TORCH_PROFILER_DIR" , "./profiles" )
212-
213- # Expand ~ and ~user, then make absolute (robust against cwd changes)
214- trace_dir = os .path .expanduser (trace_dir )
215- trace_dir = os .path .abspath (trace_dir )
216-
217- try :
218- os .makedirs (trace_dir , exist_ok = True )
219- except OSError as exc :
220- logger .error (f"Failed to create profiler directory { trace_dir } : { exc } " )
221- raise
222-
223- # Build final template path (without rank or extension — torch.profiler appends those)
224- full_template = os .path .join (trace_dir , trace_filename )
225-
226- expected_pattern = f"{ full_template } *.json"
227- logger .info (f"Starting diffusion profiling → { expected_pattern } " )
205+ if trace_filename :
206+ logger .debug (
207+ "Diffusion profiling uses vLLM profiler config; trace_filename is ignored (%s)." ,
208+ trace_filename ,
209+ )
228210
229- # Also log the absolute directory once (useful in multi-node or containers)
230- logger .debug (f"Profiler output directory: { trace_dir } " )
211+ trace_dir = os .environ .get ("VLLM_TORCH_PROFILER_DIR" )
212+ if trace_dir :
213+ trace_dir = os .path .abspath (os .path .expanduser (trace_dir ))
214+ try :
215+ os .makedirs (trace_dir , exist_ok = True )
216+ except OSError as exc :
217+ logger .error ("Failed to create profiler directory %s: %s" , trace_dir , exc )
218+ raise
219+ logger .info ("Starting diffusion profiling. Torch traces will be written under %s" , trace_dir )
220+ else :
221+ logger .info ("Starting diffusion profiling." )
231222
232223 # Propagate to all workers
233224 try :
234- self .collective_rpc (method = "start_profile" , args = ( full_template ,) )
225+ self .collective_rpc (method = "start_profile" )
235226 except Exception as e :
236227 logger .error ("Failed to start profiling on workers" , exc_info = True )
237228 raise RuntimeError (f"Could not start profiler: { e } " ) from e
238229
239230 def stop_profile (self ) -> dict :
240231 """
241- Stop profiling on all workers and collect the final trace/table paths.
242-
243- The worker (torch_profiler.py) now handles trace export, compression to .gz,
244- and deletion of the original .json file. This method only collects and
245- reports the paths returned by the workers.
232+ Stop profiling on all workers and best-effort collect any legacy outputs.
246233
247- Returns:
248- dict with keys:
249- - "traces": list of final trace file paths (usually .json.gz)
250- - "tables": list of table strings (one per rank)
234+ vLLM's profiler wrappers write traces directly to disk and do not return
235+ per-rank file paths. This method preserves backward compatibility by
236+ aggregating any dict-like results if present.
251237 """
252- logger .info ("Stopping diffusion profiling and collecting results ..." )
238+ logger .info ("Stopping diffusion profiling..." )
253239
254240 try :
255- # Give worker enough time — export + compression + table can be slow
241+ # Give workers enough time — trace flushing can be slow
256242 results = self .collective_rpc (method = "stop_profile" , timeout = 60000 )
257243 except Exception :
258244 logger .error ("Failed to stop profiling on workers" , exc_info = True )
@@ -262,54 +248,46 @@ def stop_profile(self) -> dict:
262248 successful_traces = 0
263249
264250 if not results :
265- logger .warning ("No profiling results returned from any rank" )
251+ logger.info("No profiling results returned from any rank.")
266252 return output_files
267253
268254 for rank , res in enumerate (results ):
255+ if res is None :
256+ # vLLM profiler wrappers return no per-rank payloads.
257+ continue
269258 if not isinstance (res , dict ):
270- logger .warning (f "Rank { rank } : invalid result format (got { type (res )} )" )
259+ logger.warning("Rank %s: invalid result format (got %s)", rank, type(res))
271260 continue
272261
273- # 1. Trace file — should be .json.gz if compression succeeded
274- trace_path = res .get ("trace" )
262+ trace_path = res .get ("trace" ) or res .get ("traces" )
275263 if trace_path :
276- # We trust the worker — it created/compressed the file
277- logger .info (f"[Rank { rank } ] Final trace: { trace_path } " )
278- output_files ["traces" ].append (trace_path )
279- successful_traces += 1
264+ if isinstance (trace_path , str ):
265+ output_files ["traces" ].append (trace_path )
266+ elif isinstance (trace_path , list ):
267+ output_files ["traces" ].extend (trace_path )
268+ successful_traces = len (output_files ["traces" ])
280269
281- # Optional: warn if path looks suspicious (e.g. still .json)
282- if not trace_path .endswith ((".json.gz" , ".json" )):
283- logger .warning (f"Rank { rank } : unusual trace path extension: { trace_path } " )
284-
285- # 2. Table file — plain text
286- table = res .get ("table" )
270+ table = res .get ("table" ) or res .get ("tables" )
287271 if table :
288- output_files ["tables" ].append (table )
272+ if isinstance (table , str ):
273+ output_files ["tables" ].append (table )
274+ elif isinstance (table , list ):
275+ output_files ["tables" ].extend (table )
289276
290- # Final summary logging
291- num_ranks = len (results )
292277 if successful_traces > 0 :
293- final_paths_str = ", " .join (output_files ["traces" ][:3 ])
294- if len (output_files ["traces" ]) > 3 :
295- final_paths_str += f" ... (+{ len (output_files ['traces' ]) - 3 } more)"
296-
297278 logger .info (
298- f "Profiling stopped. Collected { successful_traces } trace file(s) "
299- f"from { num_ranks } rank(s). "
300- f"Final trace paths: { final_paths_str } "
279+ "Profiling stopped. Collected %s trace file(s) from %s rank(s)." ,
280+ successful_traces ,
281+ len(results),
301282 )
302- elif output_files [ "traces" ] :
283+ else :
303284 logger .info (
304- f"Profiling stopped but no traces were successfully collected. "
305- f"Reported paths: { ', ' .join (output_files ['traces' ][:3 ])} "
306- f"{ ' ...' if len (output_files ['traces' ]) > 3 else '' } "
285+ "Profiling stopped. Traces are written by the active profiler "
286+ "(PyTorch: VLLM_TORCH_PROFILER_DIR, nsys: -o output)."
307287 )
308- else :
309- logger .info ("Profiling stopped — no trace files were collected from any rank." )
310288
311289 if output_files ["tables" ]:
312- logger .debug (f "Collected { len ( output_files [ 'tables' ]) } profiling table(s)" )
290+ logger .debug ("Collected %s profiling table(s)" , len ( output_files [ "tables" ]) )
313291
314292 return output_files
315293
0 commit comments