2121import logging
2222import os
2323from pathlib import Path
24+ from typing import NamedTuple
2425
2526import torch
2627from PIL import Image
3132logging .basicConfig (level = logging .INFO , format = "%(asctime)s - %(levelname)s - %(message)s" )
3233logger = logging .getLogger (__name__ )
3334
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_PATCH_SIZE = 16  # AR image grid patch size (pixels per token)


class T2IGenConfig(NamedTuple):
    """Token IDs and sampling settings required by the T2I (AR -> DiT) pipeline."""

    eol_token_id: int
    visual_token_start_id: int
    visual_token_end_id: int
    top_k: int  # AR sampling top-k (covers the full visual generation vocabulary)
    # Qwen2.5-VL special vision tokens: <|image_pad|>, <|video_pad|>, <|vision_start|>, <|vision_end|>
    visual_ids: list[int]


def _read_json(path: Path) -> dict:
    """Read and parse a JSON file; raise FileNotFoundError if it does not exist."""
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    with path.open(encoding="utf-8") as f:
        return json.load(f)


def load_t2i_generation_config(model_dir: str) -> T2IGenConfig:
    """Load T2I token IDs from t2i_generation_config.json and config.json.

    Args:
        model_dir: Path to the model directory holding both JSON files.

    Returns:
        A T2IGenConfig with the AR-stage token IDs and top-k from
        ``t2i_generation_config.json``, plus the four special vision token
        IDs read from the ``llm_config`` section of ``config.json``.

    Raises:
        FileNotFoundError: If either config file is missing.
        KeyError: If a required key is absent from either config.
    """
    model_path = Path(model_dir)

    gen_cfg = _read_json(model_path / "t2i_generation_config.json")
    llm_cfg = _read_json(model_path / "config.json").get("llm_config", {})

    return T2IGenConfig(
        eol_token_id=int(gen_cfg["eol_token_id"]),
        visual_token_start_id=int(gen_cfg["visual_token_start_id"]),
        visual_token_end_id=int(gen_cfg["visual_token_end_id"]),
        top_k=int(gen_cfg["top_k"]),
        visual_ids=[
            int(llm_cfg["image_token_id"]),
            int(llm_cfg["video_token_id"]),
            int(llm_cfg["vision_start_token_id"]),
            int(llm_cfg["vision_end_token_id"]),
        ],
    )
4978
5079
5180def parse_args () -> argparse .Namespace :
5281 p = argparse .ArgumentParser (description = "Run MammothModa2 T2I (AR -> DiT) with vLLM-Omni." )
53- p .add_argument (
54- "--model" ,
55- type = str ,
56- required = True ,
57- help = "Path to the model directory." ,
58- )
59- p .add_argument (
60- "--stage-config" ,
61- type = str ,
62- required = True ,
63- help = "Path to the multi-stage YAML configuration." ,
64- )
65- p .add_argument (
66- "--prompt" ,
67- type = str ,
68- action = "append" ,
69- default = None ,
82+ p .add_argument ("--model" , type = str , required = True , help = "Path to the model directory." )
83+ p .add_argument ("--stage-config" , type = str , required = True ,help = "Path to the multi-stage YAML configuration." )
84+ p .add_argument ("--prompt" , type = str , action = "append" , default = None ,
7085 help = (
7186 "Text prompt for image generation. Can be provided multiple times "
72- "to generate multiple images with shared height/width/CFG settings."
73- ),
74- )
75- p .add_argument (
76- "--height" ,
77- type = int ,
78- default = 1024 ,
79- help = "Output image height (must be a multiple of 16)." ,
80- )
81- p .add_argument (
82- "--width" ,
83- type = int ,
84- default = 1024 ,
85- help = "Output image width (must be a multiple of 16)." ,
86- )
87- p .add_argument (
88- "--num-inference-steps" ,
89- type = int ,
90- default = 50 ,
91- help = "Number of diffusion steps for the DiT stage." ,
92- )
93- p .add_argument (
94- "--text-guidance-scale" ,
95- type = float ,
96- default = 9.0 ,
97- help = "Classifier-Free Guidance (CFG) scale for DiT." ,
98- )
99- p .add_argument (
100- "--cfg-range" ,
101- type = float ,
102- nargs = 2 ,
103- default = (0.0 , 1.0 ),
104- help = "Relative step range [start, end] where CFG is active." ,
87+ "to generate multiple images with shared height/width/CFG settings." ),
10588 )
89+ p .add_argument ("--height" , type = int , default = 1024 , help = "Output image height (must be a multiple of 16)." )
90+ p .add_argument ("--width" , type = int , default = 1024 , help = "Output image width (must be a multiple of 16)." )
91+ p .add_argument ("--num-inference-steps" , type = int , default = 50 , help = "Number of diffusion steps for the DiT stage." )
92+ p .add_argument ("--text-guidance-scale" , type = float , default = 9.0 , help = "Classifier-Free Guidance (CFG) scale for DiT." )
93+ p .add_argument ("--cfg-range" , type = float , nargs = 2 , default = (0.0 , 1.0 ), help = "Relative step range [start, end] where CFG is active." ,)
10694 p .add_argument ("--out" , type = str , default = "output.png" , help = "Path to save the generated image." )
10795 p .add_argument ("--trust-remote-code" , action = "store_true" , help = "Trust remote code when loading the model." )
10896 args = p .parse_args ()
@@ -122,140 +110,109 @@ def tensor_to_pil(image: torch.Tensor) -> Image.Image:
122110 return Image .fromarray (image )
123111
124112
113+ def _format_prompt (user_prompt : str , ar_width : int , ar_height : int ) -> str :
114+ """Build the AR-stage prompt string including the image grid header."""
115+ return (
116+ "<|im_start|>system\n You are a helpful image generator.<|im_end|>\n "
117+ f"<|im_start|>user\n { user_prompt } <|im_end|>\n "
118+ "<|im_start|>assistant\n "
119+ f"<|image start|>{ ar_width } *{ ar_height } <|image token|>"
120+ )
121+
122+
123+ def _collect_images (outputs : list ) -> list [torch .Tensor ]:
124+ """Extract all image tensors produced by the final (DiT) stage."""
125+ images : list [torch .Tensor ] = []
126+ for out in outputs :
127+ ro_list = getattr (out , "request_output" , out )
128+ if not isinstance (ro_list , list ):
129+ ro_list = [ro_list ]
130+ for ro_item in ro_list :
131+ for completion in (getattr (ro_item , "outputs" , None ) or []):
132+ mm = getattr (completion , "multimodal_output" , None )
133+ if not isinstance (mm , dict ) or "image" not in mm :
134+ raise RuntimeError (f"Missing image in multimodal output: { mm } " )
135+ payload = mm ["image" ]
136+ for tensor in (payload if isinstance (payload , list ) else [payload ]):
137+ if not isinstance (tensor , torch .Tensor ):
138+ raise TypeError (f"Expected image tensor, got { type (tensor )} " )
139+ images .append (tensor )
140+ return images
141+
142+
def _save_images(images: list[torch.Tensor], out_path: str) -> list[str]:
    """Save image tensors to disk.

    Single image: written to *out_path* exactly.
    Multiple images: suffixed as ``<base>_0<ext>``, ``<base>_1<ext>``, ...
    """
    if not images:
        raise RuntimeError("No images to save.")
    base, ext = os.path.splitext(out_path)
    if not ext:
        ext = ".png"
    single = len(images) == 1
    saved: list[str] = []
    for index, tensor in enumerate(images):
        target = out_path if single else f"{base}_{index}{ext}"
        tensor_to_pil(tensor).save(target)
        saved.append(target)
    return saved
159+
160+
def main() -> None:
    """CLI entry point: validate args, run the AR -> DiT pipeline, save images."""
    args = parse_args()
    # Ensure the output directory exists; dirname is "" for bare filenames,
    # so fall back to "." in that case.
    os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)

    if args.height <= 0 or args.width <= 0:
        raise ValueError(f"Height and width must be positive, got {args.height}x{args.width}")
    if args.height % _PATCH_SIZE != 0 or args.width % _PATCH_SIZE != 0:
        raise ValueError(f"Height and width must be multiples of {_PATCH_SIZE}, got {args.height}x{args.width}")

    # AR-stage grid dimensions in tokens (one token per _PATCH_SIZE pixels).
    ar_height = args.height // _PATCH_SIZE
    ar_width = args.width // _PATCH_SIZE
    gen_cfg = load_t2i_generation_config(args.model)
    # One extra token per row — presumably the end-of-line (eol) token; confirm
    # against the AR decoder's output layout.
    expected_grid_tokens = ar_height * (ar_width + 1)

    logger.info("Initializing Omni pipeline...")
    omni = Omni(model=args.model, stage_configs_path=args.stage_config, trust_remote_code=args.trust_remote_code)
    try:
        # Stage 1 (AR): stochastic sampling over the visual vocabulary,
        # capped at the full image grid plus one trailing token.
        ar_sampling = SamplingParams(
            temperature=1.0,
            top_p=1.0,
            top_k=gen_cfg.top_k,
            max_tokens=max(1, expected_grid_tokens + 1),  # +1 for hidden state of eoi
            detokenize=False,
        )
        # Stage 2 (DiT): deterministic single-step request; the diffusion
        # model produces the image, not tokens.
        dit_sampling = SamplingParams(
            temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1, detokenize=False,
        )

        # Per-request metadata consumed by the pipeline stages. Values are
        # wrapped in single-element lists — NOTE(review): this looks like the
        # pipeline's expected batch format; confirm against the Omni API.
        additional_information = {
            "omni_task": ["t2i"],
            "ar_width": [ar_width], "ar_height": [ar_height],
            "eol_token_id": [gen_cfg.eol_token_id],
            "visual_token_start_id": [gen_cfg.visual_token_start_id],
            "visual_token_end_id": [gen_cfg.visual_token_end_id],
            "image_height": [args.height], "image_width": [args.width],
            "num_inference_steps": [args.num_inference_steps],
            "text_guidance_scale": [args.text_guidance_scale],
            "cfg_range": [args.cfg_range[0], args.cfg_range[1]],
            "visual_ids": gen_cfg.visual_ids,
        }
        # One pipeline input per prompt; dict() gives each request its own
        # shallow copy of the shared metadata.
        inputs = [
            {
                "prompt": _format_prompt(p, ar_width, ar_height),
                "additional_information": dict(additional_information),
            }
            for p in args.prompt
        ]

        logger.info("Starting generation...")
        # omni.generate() returns a Generator; consume it to run the full pipeline.
        outputs = list(omni.generate(inputs, [ar_sampling, dit_sampling]))

        logger.info("Post-processing and saving image(s)...")
        for path in _save_images(_collect_images(outputs), args.out):
            logger.info(f"Saved: {path}")
    finally:
        # Always release pipeline resources, even when generation fails.
        omni.close()
261218
0 commit comments