Skip to content

Commit 8317857

Browse files
committed
refactor example file
Signed-off-by: HonestDeng <2958906959@qq.com>
1 parent cd95963 commit 8317857

File tree

2 files changed

+127
-200
lines changed

2 files changed

+127
-200
lines changed

examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,42 +25,12 @@
2525

2626
def parse_args() -> argparse.Namespace:
2727
parser = argparse.ArgumentParser(description="MammothModa2 image summarization (offline, AR only).")
28-
parser.add_argument(
29-
"--model",
30-
type=str,
31-
required=True,
32-
help="Path to model directory or model id.",
33-
)
34-
parser.add_argument(
35-
"--stage-config",
36-
type=str,
37-
required=True,
38-
help="Path to stage config yaml (single-stage AR->text).",
39-
)
40-
parser.add_argument(
41-
"--image",
42-
type=str,
43-
required=True,
44-
help="Path to input image.",
45-
)
46-
parser.add_argument(
47-
"--question",
48-
type=str,
49-
default=DEFAULT_QUESTION,
50-
help="Question/instruction for the model.",
51-
)
52-
parser.add_argument(
53-
"--system",
54-
type=str,
55-
default=DEFAULT_SYSTEM,
56-
help="System prompt.",
57-
)
58-
parser.add_argument(
59-
"--max-tokens",
60-
type=int,
61-
default=512,
62-
help="Max new tokens to generate.",
63-
)
28+
parser.add_argument("--model", type=str, required=True, help="Path to model directory or model id.")
29+
parser.add_argument("--stage-config", type=str, required=True, help="Path to stage config yaml (single-stage AR->text).")
30+
parser.add_argument("--image", type=str, required=True, help="Path to input image.")
31+
parser.add_argument("--question", type=str, default=DEFAULT_QUESTION, help="Question/instruction for the model.")
32+
parser.add_argument("--system", type=str, default=DEFAULT_SYSTEM, help="System prompt.")
33+
parser.add_argument("--max-tokens", type=int, default=512, help="Max new tokens to generate.")
6434
parser.add_argument("--temperature", type=float, default=0.2)
6535
parser.add_argument("--top-p", type=float, default=0.9)
6636
parser.add_argument("--seed", type=int, default=42)

examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py

Lines changed: 121 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import logging
2222
import os
2323
from pathlib import Path
24+
from typing import NamedTuple
2425

2526
import torch
2627
from PIL import Image
@@ -31,78 +32,65 @@
3132
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
3233
logger = logging.getLogger(__name__)
3334

34-
35-
def load_t2i_generation_config(model_dir: str) -> tuple[int, int, int]:
36-
"""Load T2I token ranges from t2i_generation_config.json."""
37-
cfg_path = Path(model_dir) / "t2i_generation_config.json"
38-
if not cfg_path.exists():
39-
raise FileNotFoundError(f"Config not found: {cfg_path}")
40-
41-
with cfg_path.open("r", encoding="utf-8") as f:
42-
cfg = json.load(f)
43-
44-
return (
45-
int(cfg["eol_token_id"]),
46-
int(cfg["visual_token_start_id"]),
47-
int(cfg["visual_token_end_id"]),
35+
# ---------------------------------------------------------------------------
36+
# Constants
37+
# ---------------------------------------------------------------------------
38+
_PATCH_SIZE = 16 # AR image grid patch size (pixels per token)
39+
40+
41+
class T2IGenConfig(NamedTuple):
42+
eol_token_id: int
43+
visual_token_start_id: int
44+
visual_token_end_id: int
45+
top_k: int # AR sampling top-k (covers the full visual generation vocabulary)
46+
# Qwen2.5-VL special vision tokens: <|image_pad|>, <|video_pad|>, <|vision_start|>, <|vision_end|>
47+
visual_ids: list[int]
48+
49+
50+
def load_t2i_generation_config(model_dir: str) -> T2IGenConfig:
51+
"""Load T2I token IDs from t2i_generation_config.json and config.json."""
52+
model_path = Path(model_dir)
53+
54+
gen_cfg_path = model_path / "t2i_generation_config.json"
55+
if not gen_cfg_path.exists():
56+
raise FileNotFoundError(f"Config not found: {gen_cfg_path}")
57+
with gen_cfg_path.open(encoding="utf-8") as f:
58+
gen_cfg = json.load(f)
59+
60+
model_cfg_path = model_path / "config.json"
61+
if not model_cfg_path.exists():
62+
raise FileNotFoundError(f"Config not found: {model_cfg_path}")
63+
with model_cfg_path.open(encoding="utf-8") as f:
64+
llm_cfg = json.load(f).get("llm_config", {})
65+
66+
return T2IGenConfig(
67+
eol_token_id=int(gen_cfg["eol_token_id"]),
68+
visual_token_start_id=int(gen_cfg["visual_token_start_id"]),
69+
visual_token_end_id=int(gen_cfg["visual_token_end_id"]),
70+
top_k=int(gen_cfg["top_k"]),
71+
visual_ids=[
72+
int(llm_cfg["image_token_id"]),
73+
int(llm_cfg["video_token_id"]),
74+
int(llm_cfg["vision_start_token_id"]),
75+
int(llm_cfg["vision_end_token_id"]),
76+
],
4877
)
4978

5079

5180
def parse_args() -> argparse.Namespace:
5281
p = argparse.ArgumentParser(description="Run MammothModa2 T2I (AR -> DiT) with vLLM-Omni.")
53-
p.add_argument(
54-
"--model",
55-
type=str,
56-
required=True,
57-
help="Path to the model directory.",
58-
)
59-
p.add_argument(
60-
"--stage-config",
61-
type=str,
62-
required=True,
63-
help="Path to the multi-stage YAML configuration.",
64-
)
65-
p.add_argument(
66-
"--prompt",
67-
type=str,
68-
action="append",
69-
default=None,
82+
p.add_argument("--model", type=str, required=True, help="Path to the model directory.")
83+
p.add_argument("--stage-config", type=str, required=True, help="Path to the multi-stage YAML configuration.")
84+
p.add_argument("--prompt", type=str, action="append", default=None,
7085
help=(
7186
"Text prompt for image generation. Can be provided multiple times "
72-
"to generate multiple images with shared height/width/CFG settings."
73-
),
74-
)
75-
p.add_argument(
76-
"--height",
77-
type=int,
78-
default=1024,
79-
help="Output image height (must be a multiple of 16).",
80-
)
81-
p.add_argument(
82-
"--width",
83-
type=int,
84-
default=1024,
85-
help="Output image width (must be a multiple of 16).",
86-
)
87-
p.add_argument(
88-
"--num-inference-steps",
89-
type=int,
90-
default=50,
91-
help="Number of diffusion steps for the DiT stage.",
92-
)
93-
p.add_argument(
94-
"--text-guidance-scale",
95-
type=float,
96-
default=9.0,
97-
help="Classifier-Free Guidance (CFG) scale for DiT.",
98-
)
99-
p.add_argument(
100-
"--cfg-range",
101-
type=float,
102-
nargs=2,
103-
default=(0.0, 1.0),
104-
help="Relative step range [start, end] where CFG is active.",
87+
"to generate multiple images with shared height/width/CFG settings."),
10588
)
89+
p.add_argument("--height", type=int, default=1024, help="Output image height (must be a multiple of 16).")
90+
p.add_argument("--width", type=int, default=1024, help="Output image width (must be a multiple of 16).")
91+
p.add_argument("--num-inference-steps", type=int, default=50, help="Number of diffusion steps for the DiT stage.")
92+
p.add_argument("--text-guidance-scale", type=float, default=9.0, help="Classifier-Free Guidance (CFG) scale for DiT.")
93+
p.add_argument("--cfg-range", type=float, nargs=2, default=(0.0, 1.0), help="Relative step range [start, end] where CFG is active.")
10694
p.add_argument("--out", type=str, default="output.png", help="Path to save the generated image.")
10795
p.add_argument("--trust-remote-code", action="store_true", help="Trust remote code when loading the model.")
10896
args = p.parse_args()
@@ -122,140 +110,109 @@ def tensor_to_pil(image: torch.Tensor) -> Image.Image:
122110
return Image.fromarray(image)
123111

124112

113+
def _format_prompt(user_prompt: str, ar_width: int, ar_height: int) -> str:
114+
"""Build the AR-stage prompt string including the image grid header."""
115+
return (
116+
"<|im_start|>system\nYou are a helpful image generator.<|im_end|>\n"
117+
f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
118+
"<|im_start|>assistant\n"
119+
f"<|image start|>{ar_width}*{ar_height}<|image token|>"
120+
)
121+
122+
123+
def _collect_images(outputs: list) -> list[torch.Tensor]:
124+
"""Extract all image tensors produced by the final (DiT) stage."""
125+
images: list[torch.Tensor] = []
126+
for out in outputs:
127+
ro_list = getattr(out, "request_output", out)
128+
if not isinstance(ro_list, list):
129+
ro_list = [ro_list]
130+
for ro_item in ro_list:
131+
for completion in (getattr(ro_item, "outputs", None) or []):
132+
mm = getattr(completion, "multimodal_output", None)
133+
if not isinstance(mm, dict) or "image" not in mm:
134+
raise RuntimeError(f"Missing image in multimodal output: {mm}")
135+
payload = mm["image"]
136+
for tensor in (payload if isinstance(payload, list) else [payload]):
137+
if not isinstance(tensor, torch.Tensor):
138+
raise TypeError(f"Expected image tensor, got {type(tensor)}")
139+
images.append(tensor)
140+
return images
141+
142+
143+
def _save_images(images: list[torch.Tensor], out_path: str) -> list[str]:
144+
"""Save image tensors to disk.
145+
146+
Single image: written to *out_path* exactly.
147+
Multiple images: suffixed as ``<base>_0<ext>``, ``<base>_1<ext>``, …
148+
"""
149+
if not images:
150+
raise RuntimeError("No images to save.")
151+
base, ext = os.path.splitext(out_path)
152+
ext = ext or ".png"
153+
paths = []
154+
for i, tensor in enumerate(images):
155+
path = out_path if len(images) == 1 else f"{base}_{i}{ext}"
156+
tensor_to_pil(tensor).save(path)
157+
paths.append(path)
158+
return paths
159+
160+
125161
def main() -> None:
126162
args = parse_args()
127163
os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
128164

129165
if args.height <= 0 or args.width <= 0:
130166
raise ValueError(f"Height and width must be positive, got {args.height}x{args.width}")
131-
if args.height % 16 != 0 or args.width % 16 != 0:
132-
raise ValueError(f"Height and width must be multiples of 16, got {args.height}x{args.width}")
167+
if args.height % _PATCH_SIZE != 0 or args.width % _PATCH_SIZE != 0:
168+
raise ValueError(f"Height and width must be multiples of {_PATCH_SIZE}, got {args.height}x{args.width}")
133169

134-
ar_height = args.height // 16
135-
ar_width = args.width // 16
136-
137-
eol_token_id, visual_start, visual_end = load_t2i_generation_config(args.model)
170+
ar_height = args.height // _PATCH_SIZE
171+
ar_width = args.width // _PATCH_SIZE
172+
gen_cfg = load_t2i_generation_config(args.model)
138173
expected_grid_tokens = ar_height * (ar_width + 1)
139174

140-
def _format_prompt(user_prompt: str) -> str:
141-
return (
142-
"<|im_start|>system\nYou are a helpful image generator.<|im_end|>\n"
143-
f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
144-
"<|im_start|>assistant\n"
145-
f"<|image start|>{ar_width}*{ar_height}<|image token|>"
146-
)
147-
148175
logger.info("Initializing Omni pipeline...")
149176
omni = Omni(model=args.model, stage_configs_path=args.stage_config, trust_remote_code=args.trust_remote_code)
150-
151177
try:
152178
ar_sampling = SamplingParams(
153179
temperature=1.0,
154180
top_p=1.0,
155-
top_k=2048,
156-
# +1 for generating hidden state of eoi
157-
max_tokens=max(1, expected_grid_tokens + 1),
181+
top_k=gen_cfg.top_k,
182+
max_tokens=max(1, expected_grid_tokens + 1), # +1 for hidden state of eoi
158183
detokenize=False,
159184
)
160-
161185
dit_sampling = SamplingParams(
162-
temperature=0.0,
163-
top_p=1.0,
164-
top_k=-1,
165-
max_tokens=1,
166-
detokenize=False,
186+
temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1, detokenize=False,
167187
)
168188

169-
logger.info("Starting generation...")
170-
shared_additional_information = {
189+
additional_information = {
171190
"omni_task": ["t2i"],
172-
"ar_width": [ar_width],
173-
"ar_height": [ar_height],
174-
"eol_token_id": [eol_token_id],
175-
"visual_token_start_id": [visual_start],
176-
"visual_token_end_id": [visual_end],
177-
"image_height": [args.height],
178-
"image_width": [args.width],
191+
"ar_width": [ar_width], "ar_height": [ar_height],
192+
"eol_token_id": [gen_cfg.eol_token_id],
193+
"visual_token_start_id": [gen_cfg.visual_token_start_id],
194+
"visual_token_end_id": [gen_cfg.visual_token_end_id],
195+
"image_height": [args.height], "image_width": [args.width],
179196
"num_inference_steps": [args.num_inference_steps],
180197
"text_guidance_scale": [args.text_guidance_scale],
181198
"cfg_range": [args.cfg_range[0], args.cfg_range[1]],
182-
# ["<|image_pad|>", "<|video_pad|>", "<|vision_start|>", "<|vision_end|>"]
183-
"visual_ids": [151655, 151656, 151652, 151653,]
199+
"visual_ids": gen_cfg.visual_ids,
184200
}
185201
inputs = [
186202
{
187-
"prompt": _format_prompt(p),
188-
"additional_information": dict(shared_additional_information),
203+
"prompt": _format_prompt(p, ar_width, ar_height),
204+
"additional_information": dict(additional_information),
189205
}
190206
for p in args.prompt
191207
]
192208

193-
# NOTE: omni.generate() returns a Generator[OmniRequestOutput, None, None].
194-
# Consume it to actually run the pipeline and obtain final outputs.
209+
logger.info("Starting generation...")
210+
# omni.generate() returns a Generator; consume it to run the full pipeline.
195211
outputs = list(omni.generate(inputs, [ar_sampling, dit_sampling]))
196212

197213
logger.info("Post-processing and saving image(s)...")
198-
out_base, out_ext = os.path.splitext(args.out)
199-
saved_paths: list[str] = []
200-
201-
# Flatten to (image_tensor, suffix) list so we can decide filenames.
202-
images_to_save: list[tuple[torch.Tensor, str]] = []
203-
for out_idx, out in enumerate(outputs):
204-
ro = getattr(out, "request_output", out)
205-
ro_list = ro if isinstance(ro, list) else [ro]
206-
if not ro_list:
207-
raise RuntimeError("Empty request_output from final stage.")
208-
209-
req_id = getattr(out, "request_id", None)
210-
req_suffix = f"_{req_id}" if isinstance(req_id, str) and req_id else f"_{out_idx}"
211-
212-
for sample_idx, ro_item in enumerate(ro_list):
213-
completion_outputs = getattr(ro_item, "outputs", None)
214-
if not isinstance(completion_outputs, list) or not completion_outputs:
215-
raise RuntimeError(f"Unexpected RequestOutput.outputs: {type(completion_outputs)} {completion_outputs}")
216-
217-
for completion_idx, completion in enumerate(completion_outputs):
218-
mm = getattr(completion, "multimodal_output", None)
219-
if not isinstance(mm, dict) or "image" not in mm:
220-
raise RuntimeError(
221-
"Unexpected completion multimodal output: "
222-
f"{type(mm)} {mm}, completion={completion}"
223-
)
224-
225-
img_payload = mm["image"]
226-
img_list = img_payload if isinstance(img_payload, list) else [img_payload]
227-
for img_idx, img_tensor in enumerate(img_list):
228-
if not isinstance(img_tensor, torch.Tensor):
229-
raise TypeError(f"Expected image tensor, got {type(img_tensor)}")
230-
suffix_parts = [req_suffix]
231-
if len(ro_list) > 1:
232-
suffix_parts.append(f"_s{sample_idx}")
233-
if len(completion_outputs) > 1:
234-
suffix_parts.append(f"_c{completion_idx}")
235-
if len(img_list) > 1:
236-
suffix_parts.append(f"_i{img_idx}")
237-
images_to_save.append((img_tensor, "".join(suffix_parts)))
238-
239-
# If there's only one image, respect `--out` exactly.
240-
if len(images_to_save) == 1:
241-
img_tensor, _ = images_to_save[0]
242-
pil = tensor_to_pil(img_tensor)
243-
pil.save(args.out)
244-
saved_paths.append(args.out)
245-
else:
246-
if not out_ext:
247-
out_ext = ".png"
248-
for img_tensor, suffix in images_to_save:
249-
out_path = f"{out_base}{suffix}{out_ext}"
250-
pil = tensor_to_pil(img_tensor)
251-
pil.save(out_path)
252-
saved_paths.append(out_path)
253-
254-
for p in saved_paths:
255-
logger.info(f"Successfully saved generated image to: {p}")
256-
257-
except Exception as e:
258-
logger.exception(f"An error occurred during generation: {e}")
214+
for path in _save_images(_collect_images(outputs), args.out):
215+
logger.info(f"Saved: {path}")
259216
finally:
260217
omni.close()
261218

0 commit comments

Comments
 (0)