Skip to content

Commit 39fe587

Browse files
committed
Compatible with vllm-omni 0.16.0
Signed-off-by: Chen Yang <2082464740@qq.com>
1 parent 1e86404 commit 39fe587

File tree

2 files changed

+74
-9
lines changed

2 files changed

+74
-9
lines changed

vllm_omni/diffusion/diffusion_engine.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,22 @@ def __init__(self, od_config: OmniDiffusionConfig):
6767
raise e
6868

6969
def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
70+
# Record the start time of the overall diffusion-engine execution
71+
diffusion_engine_start_time = time.time()
72+
7073
# Apply pre-processing if available
74+
preprocess_time = 0.0
7175
if self.pre_process_func is not None:
7276
preprocess_start_time = time.time()
7377
request = self.pre_process_func(request)
7478
preprocess_time = time.time() - preprocess_start_time
7579
logger.info(f"Pre-processing completed in {preprocess_time:.4f} seconds")
7680

81+
# Run diffusion inference and measure the core execution time
82+
exec_start_time = time.time()
7783
output = self.add_req_and_wait_for_response(request)
84+
exec_total_time = time.time() - exec_start_time
85+
7886
if output.error:
7987
raise Exception(f"{output.error}")
8088
logger.info("Generation completed successfully.")
@@ -102,15 +110,28 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
102110
if not isinstance(outputs, list):
103111
outputs = [outputs] if outputs is not None else []
104112

113+
105114
metrics = {
115+
"preprocess_time_ms": round(preprocess_time * 1000,2),
116+
"diffusion_engine_exec_time_ms": round((time.time() - diffusion_engine_start_time) * 1000,2),
117+
"dit_time_ms": round(exec_total_time * 1000,2),
118+
"postprocess_time_ms": round(postprocess_time * 1000,2),
106119
"image_num": int(request.sampling_params.num_outputs_per_prompt),
107120
"resolution": int(request.sampling_params.resolution),
108-
"postprocess_time_ms": postprocess_time * 1000,
121+
"denoise_time_per_step_ms": 0.0,
122+
"vae_time_ms": 0.0,
109123
}
110-
if self.pre_process_func is not None:
111-
metrics["preprocessing_time_ms"] = preprocess_time * 1000
112124

113-
# Handle single request or multiple requests
125+
126+
dit_time_seconds = metrics["dit_time_ms"] / 1000
127+
num_steps = request.sampling_params.num_inference_steps
128+
129+
if num_steps > 0:
130+
total_denoise_time = dit_time_seconds
131+
metrics["denoise_time_per_step_ms"] = round((total_denoise_time / num_steps) * 1000,2)
132+
133+
metrics["vae_time_ms"] = round(dit_time_seconds * 1000,2)
134+
114135
if len(request.prompts) == 1:
115136
# Single request: return single OmniRequestOutput
116137
prompt = request.prompts[0]
@@ -177,7 +198,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
177198
)
178199
)
179200

180-
return results
201+
return results
181202

182203
@staticmethod
183204
def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine":
@@ -378,4 +399,4 @@ def close(self) -> None:
378399
def abort(self, request_id: str | Iterable[str]) -> None:
379400
# TODO implement it
380401
logger.warning("DiffusionEngine abort is not implemented yet")
381-
pass
402+
pass

vllm_omni/outputs.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,19 +235,63 @@ def __repr__(self) -> str:
235235
"""Custom repr to properly show image count instead of image objects."""
236236
# For images, show count instead of full list
237237
images_repr = f"[{len(self.images)} PIL Images]" if self.images else "[]"
238-
239238
# Build repr string
239+
240+
def _repr_nested(obj) -> str:
241+
if isinstance(obj, list):
242+
return "[" + ", ".join(_repr_nested(x) for x in obj) + "]"
243+
if isinstance(obj, OmniRequestOutput):
244+
return obj._repr_multiline(indent=" ")
245+
return repr(obj)
246+
240247
parts = [
241248
f"request_id={self.request_id!r}",
242249
f"finished={self.finished}",
243250
f"stage_id={self.stage_id}",
244251
f"final_output_type={self.final_output_type!r}",
245-
f"request_output={self.request_output}",
252+
f"request_output={_repr_nested(self.request_output)}",
246253
f"images={images_repr}",
247254
f"prompt={self.prompt!r}",
248255
f"latents={self.latents}",
249256
f"metrics={self.metrics}",
250257
f"multimodal_output={self._multimodal_output}",
251258
]
252-
253259
return f"OmniRequestOutput({', '.join(parts)})"
260+
261+
262+
def _repr_multiline(self, indent: str = "") -> str:
263+
"""Helper to produce multi-line, indented repr for nested logging."""
264+
images_repr = f"[{len(self.images)} PIL Images]" if self.images else "[]"
265+
266+
def _repr_nested(obj, ind: str) -> str:
267+
if isinstance(obj, list):
268+
inner = ",\n".join(_repr_nested(x, ind + " ") for x in obj)
269+
return "[\n" + inner + "\n" + ind + "]"
270+
if isinstance(obj, OmniRequestOutput):
271+
return obj._repr_multiline(indent=ind + " ")
272+
return repr(obj)
273+
274+
# Format metrics with each key-value pair on a separate line
275+
if self.metrics:
276+
metrics_indent = indent + " "
277+
metrics_lines = f",\n{metrics_indent}".join(
278+
f"{k!r}: {v!r}" for k, v in self.metrics.items()
279+
)
280+
metrics_repr = f"{{\n{metrics_indent}{metrics_lines}\n{indent} }}"
281+
else:
282+
metrics_repr = "{}"
283+
284+
lines = [
285+
f"{indent}OmniRequestOutput(",
286+
f"{indent} request_id={self.request_id!r},",
287+
f"{indent} finished={self.finished},",
288+
f"{indent} stage_id={self.stage_id},",
289+
f"{indent} final_output_type={self.final_output_type!r},",
290+
f"{indent} request_output={_repr_nested(self.request_output, indent + ' ')},",
291+
f"{indent} images={images_repr},",
292+
f"{indent} prompt={self.prompt!r},",
293+
f"{indent} latents={self.latents},",
294+
f"{indent} metrics={metrics_repr},",
295+
f"{indent})",
296+
]
297+
return "\n".join(lines)

0 commit comments

Comments
 (0)