Changes from all commits
125 commits
cdaef14
register MammothModa2 model in registry.py
HonestDeng Dec 16, 2025
f95173e
add code skeleton
HonestDeng Dec 16, 2025
0c0b611
add skeleton of ar and dit stage
HonestDeng Dec 16, 2025
59ba5a1
constructs ar model
HonestDeng Dec 17, 2025
fb513ce
capture hidden states using hook
HonestDeng Dec 17, 2025
7baa5e5
add input processors
HonestDeng Dec 17, 2025
4f25a05
implement DiT stage
HonestDeng Dec 17, 2025
a68cdc0
remove code of capturing history hidden state
HonestDeng Dec 17, 2025
0e007c0
delete redundant code
HonestDeng Dec 17, 2025
b6c8802
implement MammothModa2ARForConditionalGeneration using qwen2
HonestDeng Dec 17, 2025
a3e28ad
delete useless entry
HonestDeng Dec 17, 2025
20a8a87
Fix MammothModa2 processor/tokenizer in spawn workers
HonestDeng Dec 18, 2025
7a40266
Fix AutoConfig mapping for Mammoth VL subconfigs
HonestDeng Dec 18, 2025
890ff4c
Load config.json successfully
HonestDeng Dec 18, 2025
0d535f6
Add minimal Mammoth text token step debug script
HonestDeng Dec 18, 2025
7371f98
Make Mammoth token-step script fail fast on missing vLLM platform
HonestDeng Dec 18, 2025
e653884
Handle OmniOutput in Mammoth compute_logits
HonestDeng Dec 18, 2025
8eab22b
Fix MammothModa2 wrapper load_weights prefix and AR LM compat
HonestDeng Dec 18, 2025
e3b7a7b
Handle vLLM passing input_ids=None in Mammoth LM
HonestDeng Dec 18, 2025
392d683
Use omni AR worker in Mammoth token-step; fix logits and OmniOutput
HonestDeng Dec 19, 2025
299fe59
Expose VL token ids on Mammothmoda2Config for mrope
HonestDeng Dec 19, 2025
7fd44f9
Add MammothModa2 Omni pipeline runner and text decode
HonestDeng Dec 19, 2025
c889d5d
Add image input support to MammothModa2 Omni example
HonestDeng Dec 19, 2025
2a8081b
Add MammothModa2 unified entry + t2i pipeline scaffold
HonestDeng Dec 20, 2025
2ea2b78
Limit MammothModa2 AR max_model_len to reduce KV cache
HonestDeng Dec 20, 2025
0c52878
Fix MammothModa2 MoE helper for 2D hidden_states
HonestDeng Dec 20, 2025
0f56070
Now we can generate image, but still bugs exist
HonestDeng Dec 21, 2025
a614fd9
insert eol token
HonestDeng Dec 21, 2025
df6c532
mammoth_moda2: build DiT condition from AR hidden states
HonestDeng Dec 22, 2025
4025a12
mammoth_moda2: wire condition into DiT stage
HonestDeng Dec 22, 2025
f4b2a2a
generation_runner: pass runtime additional_information to models
HonestDeng Dec 22, 2025
aa35552
mammoth_moda2: align gen token ids to available hidden states
HonestDeng Dec 22, 2025
559538a
mammoth_moda2: keep additional_information serializable
HonestDeng Dec 22, 2025
a6e524a
mammoth_moda2: fix DiT conditioning and RoPE freqs
HonestDeng Dec 22, 2025
000c8d6
examples: simplify MammothModa2 default prompt
HonestDeng Dec 22, 2025
677d671
transfer height and weight params
HonestDeng Dec 22, 2025
ca9e6a9
delete useless logic
HonestDeng Dec 22, 2025
887a10d
delete backward-compatible codes
HonestDeng Dec 22, 2025
d86b20a
mammoth_moda2: align ar2dit masks with upstream
HonestDeng Dec 22, 2025
fd21182
mammoth_moda2: add DiT CFG params and guidance
HonestDeng Dec 23, 2025
fac1191
examples: derive ar grid from image size
HonestDeng Dec 23, 2025
386b33c
delete backward-compatible code
HonestDeng Dec 23, 2025
73715eb
delete useless arguments
HonestDeng Dec 23, 2025
d3632c3
construct dummy run params
HonestDeng Dec 23, 2025
06e7b6c
delete useless code
HonestDeng Dec 23, 2025
cc7c2f8
move hard-code from runner
HonestDeng Dec 23, 2025
1e670d7
simplify code
HonestDeng Dec 23, 2025
41f96f8
generate eoi token
HonestDeng Dec 23, 2025
cc4f945
simplify code in ar2dit
HonestDeng Dec 23, 2025
fe238e9
delete useless file
HonestDeng Dec 23, 2025
c750217
recover arg_utils.py
HonestDeng Dec 24, 2025
27b5ce3
merge main branch
HonestDeng Dec 24, 2025
2f73e5c
Fix multimodal hooks and mrope handling
HonestDeng Dec 24, 2025
37e0950
delete Chinese comment
HonestDeng Dec 24, 2025
30761cb
simplify code
HonestDeng Dec 25, 2025
7369cc7
delete _build_dummy_mm_embeddings function
HonestDeng Dec 25, 2025
5f1d9b8
change Chinese comments to English
HonestDeng Dec 25, 2025
d81375e
refactor example
HonestDeng Dec 25, 2025
3e38344
delete useless file and rename file
HonestDeng Dec 25, 2025
1e2d343
delete useless config file
HonestDeng Dec 26, 2025
752b2a3
delete Chinese comment
HonestDeng Dec 26, 2025
f8b5849
examples: support multi-prompt t2i outputs
HonestDeng Dec 26, 2025
0b71f18
Merge upstream/main
HonestDeng Dec 26, 2025
85e6f66
fix bug in calling _build_model_kwargs_extra
HonestDeng Dec 26, 2025
dbd18a9
examples: add MammothModa2 image summary
HonestDeng Dec 26, 2025
397ae64
avoid sampling gen token
HonestDeng Dec 26, 2025
0aef6b6
merge main branch
HonestDeng Dec 27, 2025
79022c9
compute generated_len in runner
HonestDeng Dec 27, 2025
52573d9
run pre-commit
HonestDeng Dec 27, 2025
f3882a6
rename mammothmoda2_dit to mammothmoda2_dit_layer
HonestDeng Dec 27, 2025
31ccc84
revert unrelated change
HonestDeng Dec 27, 2025
e23f8de
revert change
HonestDeng Dec 27, 2025
5bc6ee4
[Model] Support stable diffusion3 (#439)
iwzbi Dec 27, 2025
527d2b6
model forward call logic fixes (#495)
divyanshsinghvi Dec 27, 2025
5d7c6bb
[Bugfix][NPU] Add _model_forward for ModelRunner (#505)
gcanlin Dec 27, 2025
b276b85
[Core] remove deprecated code from PR391 (#502)
yinpeiqi Dec 28, 2025
063b6f2
Restore gpu_model_runner.py to upstream/main
HonestDeng Dec 29, 2025
a16cf21
remove redundant code
HonestDeng Dec 29, 2025
05feb63
remove useless code in transport and embedding
HonestDeng Dec 29, 2025
4efe564
remove useless code in TimeEmbedding
HonestDeng Dec 29, 2025
f335680
remove useless code in RMSNorm
HonestDeng Dec 29, 2025
489826f
remove useless code in diffusion_transformer.py
HonestDeng Dec 29, 2025
6943c66
mammoth_moda2: simplify DiT transformer for inference
HonestDeng Dec 29, 2025
27d2c26
merge small file to bigger one
HonestDeng Dec 29, 2025
b5b66a4
delete Chinese comment
HonestDeng Dec 29, 2025
073a6fd
change file name
HonestDeng Dec 29, 2025
0e9011d
delete useless file
HonestDeng Dec 29, 2025
287765b
rename file
HonestDeng Dec 29, 2025
a9de380
polish code
HonestDeng Dec 29, 2025
e126d6c
run pre-commit
HonestDeng Dec 30, 2025
707820d
run mkdocs
HonestDeng Dec 30, 2025
b0d2085
merge main branch
HonestDeng Dec 30, 2025
7301833
Merge branch 'main' into add-mammoth-moda2-support
hsliuustc0106 Jan 1, 2026
c0fd514
add docs supported models and examples
HonestDeng Jan 6, 2026
b911284
Adjust code position
HonestDeng Jan 16, 2026
f3f41b8
merge main branch and fix conflicts
HonestDeng Jan 16, 2026
9ad285a
Merge branch 'main' into add-mammoth-moda2-support
hsliuustc0106 Jan 17, 2026
2561d0d
merge main branch
HonestDeng Jan 23, 2026
7cf6c2c
update vllm api to 0.14.0
HonestDeng Jan 23, 2026
003c3b6
mv dit code to diffusion folder
HonestDeng Jan 23, 2026
1dff925
remove redundant definition of _process_text
HonestDeng Jan 23, 2026
147ed39
Merge remote-tracking branch 'upstream/main' into add-mammoth-moda2-s…
HonestDeng Feb 25, 2026
d6be90e
align with main branch
HonestDeng Feb 26, 2026
3c259cf
remove magic number; fix hardcoded token; add doc-string; split long …
HonestDeng Feb 28, 2026
a101834
add test file
HonestDeng Mar 1, 2026
960a2a9
replace magic number with constant var
HonestDeng Mar 1, 2026
cd95963
delete Chinese comment
HonestDeng Mar 1, 2026
8317857
refactor example file
HonestDeng Mar 1, 2026
6ec6a23
fix typo; useless code
HonestDeng Mar 1, 2026
1eb106c
delete dead code
HonestDeng Mar 1, 2026
c6deeb1
run precommit
HonestDeng Mar 1, 2026
982e321
Merge remote-tracking branch 'upstream/main' into add-mammoth-moda2-s…
HonestDeng Mar 1, 2026
dd47583
fix PP bug
HonestDeng Mar 1, 2026
7cbca14
Merge branch 'main' into add-mammoth-moda2-support
princepride Mar 2, 2026
eca6400
remove image
HonestDeng Mar 2, 2026
4aabccc
mv stage config file to stage_configs
HonestDeng Mar 2, 2026
e9e095f
mv Mammothmoda2Config to transformers_utils/configs
HonestDeng Mar 2, 2026
559d9b0
eagerly import Mammothmoda2Config
HonestDeng Mar 2, 2026
029b5c2
combine mammoth_moda2_ar.py and mammoth_moda2.py
HonestDeng Mar 2, 2026
2ea219a
rm mammothmoda2_dit_layer
HonestDeng Mar 2, 2026
6dec4ec
Validate before the reshape operation
HonestDeng Mar 2, 2026
9f5af84
remove Chinese comment
HonestDeng Mar 2, 2026
77e8d8c
add comment for unused param
HonestDeng Mar 2, 2026
5329e53
raise NotImplementedError for num_reqs > 1 and run precommit
HonestDeng Mar 2, 2026
78b23f9
open file safely
HonestDeng Mar 2, 2026
1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -34,6 +34,7 @@ th {
|`LongcatImagePipeline` | LongCat-Image | `meituan-longcat/LongCat-Image` |
|`LongCatImageEditPipeline` | LongCat-Image-Edit | `meituan-longcat/LongCat-Image-Edit` |
|`StableDiffusion3Pipeline` | Stable-Diffusion-3 | `stabilityai/stable-diffusion-3.5-medium` |
|`MammothModa2ForConditionalGeneration` | MammothModa2-Preview | `bytedance-research/MammothModa2-Preview` |
|`Flux2KleinPipeline` | FLUX.2-klein | `black-forest-labs/FLUX.2-klein-4B`, `black-forest-labs/FLUX.2-klein-9B` |
|`FluxPipeline` | FLUX.1-dev | `black-forest-labs/FLUX.1-dev` |
|`OmniGen2Pipeline` | OmniGen2 | `OmniGen2/OmniGen2` |
32 changes: 32 additions & 0 deletions examples/offline_inference/mammothmodal2_preview/README.md
@@ -0,0 +1,32 @@
# MammothModa2-Preview

## Run the examples

Download the model:
```bash
hf download bytedance-research/MammothModa2-Preview --local-dir ./MammothModa2-Preview
```

### Text-to-Image (T2I)

```bash
python examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py \
--model ./MammothModa2-Preview \
--stage-config ./vllm_omni/model_executor/stage_configs/mammoth_moda2.yaml \
--prompt "A stylish woman riding a motorcycle in NYC, movie poster style" \
--height 1024 \
--width 1024 \
--num-inference-steps 50 \
--text-guidance-scale 4.0 \
--out output.png
```

### Image Summary

```bash
python examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py \
--model ./MammothModa2-Preview \
--stage-config ./vllm_omni/model_executor/stage_configs/mammoth_moda2_ar.yaml \
--question "Summarize this image." \
--image ./image.png
```
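For reference, the AR stage's T2I prompt embeds the target grid size (image pixels divided by the 16-pixel patch size), and each grid row costs one extra end-of-line token on top of its visual tokens. A minimal, self-contained sketch of that bookkeeping (helper names here are illustrative, mirroring the example script's `_format_prompt` and `expected_grid_tokens`, not part of the vllm_omni API):

```python
_PATCH_SIZE = 16  # pixels per AR token, as in the example script


def format_ar_prompt(user_prompt: str, width: int, height: int) -> str:
    """Build the AR-stage prompt with the image grid header."""
    ar_w, ar_h = width // _PATCH_SIZE, height // _PATCH_SIZE
    return (
        "<|im_start|>system\nYou are a helpful image generator.<|im_end|>\n"
        f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
        "<|im_start|>assistant\n"
        f"<|image start|>{ar_w}*{ar_h}<|image token|>"
    )


def grid_token_budget(width: int, height: int) -> int:
    """Each of the ar_height grid rows emits ar_width visual tokens plus one EOL token."""
    ar_w, ar_h = width // _PATCH_SIZE, height // _PATCH_SIZE
    return ar_h * (ar_w + 1)


print(grid_token_budget(1024, 1024))  # 64 rows * (64 + 1) tokens = 4160
```

The script then passes this budget (plus one token for the end-of-image hidden state) as `max_tokens` to the AR stage's sampling params.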
@@ -0,0 +1,115 @@
"""
Offline inference example: MammothModa2 image summarization (single AR stage).

Example:
uv run python examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py \
--model path/to/MammothModa2-Preview \
--stage-config vllm_omni/model_executor/stage_configs/mammoth_moda2_ar.yaml \
--image /path/to/input.jpg \
--question "Please summarize the content of this image."
"""

from __future__ import annotations

import argparse
import os

from PIL import Image
from vllm import SamplingParams
from vllm.multimodal.image import convert_image_mode

from vllm_omni import Omni

DEFAULT_SYSTEM = "You are a helpful assistant."
DEFAULT_QUESTION = "Please summarize the content of this image."


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="MammothModa2 image summarization (offline, AR only).")
parser.add_argument("--model", type=str, required=True, help="Path to model directory or model id.")
parser.add_argument(
"--stage-config", type=str, required=True, help="Path to stage config yaml (single-stage AR->text)."
)
parser.add_argument("--image", type=str, required=True, help="Path to input image.")
parser.add_argument("--question", type=str, default=DEFAULT_QUESTION, help="Question/instruction for the model.")
parser.add_argument("--system", type=str, default=DEFAULT_SYSTEM, help="System prompt.")
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="Max new tokens to generate.",
)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top-p", type=float, default=0.9)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--trust-remote-code", action="store_true")
return parser.parse_args()


def build_prompt(system: str, question: str) -> str:
return (
f"<|im_start|>system\n{system}<|im_end|>\n"
"<|im_start|>user\n"
"<|vision_start|><|image_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)


def main() -> None:
args = parse_args()

if not os.path.exists(args.image):
raise FileNotFoundError(f"Image file not found: {args.image}")

pil_image = Image.open(args.image)
image_data = convert_image_mode(pil_image, "RGB")
prompt = build_prompt(args.system, args.question)

omni = Omni(
model=args.model,
stage_configs_path=args.stage_config,
trust_remote_code=args.trust_remote_code,
)
try:
sp = SamplingParams(
temperature=float(args.temperature),
top_p=float(args.top_p),
top_k=-1,
max_tokens=int(args.max_tokens),
seed=int(args.seed),
detokenize=True,
)
# NOTE: omni.generate() returns a Generator[OmniRequestOutput, None, None].
# Consume it inside the try block so the worker isn't closed early.
outputs = list(
omni.generate(
[
{
"prompt": prompt,
"multi_modal_data": {"image": image_data},
"additional_information": {"omni_task": ["chat"]},
}
],
[sp],
)
)
finally:
omni.close()

lines: list[str] = []
for stage_outputs in outputs:
req_outputs = getattr(stage_outputs, "request_output", stage_outputs)
req_outputs = req_outputs if isinstance(req_outputs, list) else [req_outputs]
for ro in req_outputs:
text = ro.outputs[0].text if getattr(ro, "outputs", None) else str(ro)
lines.append(f"request_id: {getattr(ro, 'request_id', 'unknown')}\n")
lines.append("answer:\n")
lines.append(text.strip() + "\n")
lines.append("\n")

print("\n".join(lines))


if __name__ == "__main__":
main()
@@ -0,0 +1,241 @@
"""
Offline inference example for MammothModa2 Text-to-Image (T2I) generation.
This script uses the vllm_omni.Omni pipeline with a multi-stage configuration.

Workflow:
1. Stage 0 (AR): Generates visual tokens and their corresponding hidden states.
2. Stage 1 (DiT): Consumes the hidden states as conditions to perform diffusion
and VAE decoding to produce the final image.

Example Usage:
uv run python examples/offline_inference/run_mammothmoda2_t2i.py \
--model path/to/MammothModa2-Preview \
--stage-config vllm_omni/model_executor/stage_configs/mammoth_moda2.yaml \
--prompt "A stylish woman riding a motorcycle in NYC, movie poster style" \
--out output.png
"""

from __future__ import annotations

import argparse
import json
import logging
import os
from pathlib import Path
from typing import NamedTuple

import torch
from PIL import Image
from vllm.sampling_params import SamplingParams

from vllm_omni import Omni

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_PATCH_SIZE = 16 # AR image grid patch size (pixels per token)


class T2IGenConfig(NamedTuple):
eol_token_id: int
visual_token_start_id: int
visual_token_end_id: int
top_k: int # AR sampling top-k (covers the full visual generation vocabulary)
# Qwen2.5-VL special vision tokens: <|image_pad|>, <|video_pad|>, <|vision_start|>, <|vision_end|>
visual_ids: list[int]


def load_t2i_generation_config(model_dir: str) -> T2IGenConfig:
"""Load T2I token IDs from t2i_generation_config.json and config.json."""
model_path = Path(model_dir)

gen_cfg_path = model_path / "t2i_generation_config.json"
if not gen_cfg_path.exists():
raise FileNotFoundError(f"Config not found: {gen_cfg_path}")
with gen_cfg_path.open(encoding="utf-8") as f:
gen_cfg = json.load(f)

model_cfg_path = model_path / "config.json"
if not model_cfg_path.exists():
raise FileNotFoundError(f"Config not found: {model_cfg_path}")
with model_cfg_path.open(encoding="utf-8") as f:
llm_cfg = json.load(f).get("llm_config", {})

return T2IGenConfig(
eol_token_id=int(gen_cfg["eol_token_id"]),
visual_token_start_id=int(gen_cfg["visual_token_start_id"]),
visual_token_end_id=int(gen_cfg["visual_token_end_id"]),
top_k=int(gen_cfg["top_k"]),
visual_ids=[
int(llm_cfg["image_token_id"]),
int(llm_cfg["video_token_id"]),
int(llm_cfg["vision_start_token_id"]),
int(llm_cfg["vision_end_token_id"]),
],
)


def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Run MammothModa2 T2I (AR -> DiT) with vLLM-Omni.")
p.add_argument("--model", type=str, required=True, help="Path to the model directory.")
p.add_argument("--stage-config", type=str, required=True, help="Path to the multi-stage YAML configuration.")
p.add_argument(
"--prompt",
type=str,
action="append",
default=None,
help=(
"Text prompt for image generation. Can be provided multiple times "
"to generate multiple images with shared height/width/CFG settings."
),
)
p.add_argument("--height", type=int, default=1024, help="Output image height (must be a multiple of 16).")
p.add_argument("--width", type=int, default=1024, help="Output image width (must be a multiple of 16).")
p.add_argument("--num-inference-steps", type=int, default=50, help="Number of diffusion steps for the DiT stage.")
p.add_argument(
"--text-guidance-scale", type=float, default=9.0, help="Classifier-Free Guidance (CFG) scale for DiT."
)
p.add_argument(
"--cfg-range",
type=float,
nargs=2,
default=(0.0, 1.0),
help="Relative step range [start, end] where CFG is active.",
)
p.add_argument("--out", type=str, default="output.png", help="Path to save the generated image.")
p.add_argument("--trust-remote-code", action="store_true", help="Trust remote code when loading the model.")
args = p.parse_args()
if not args.prompt:
args.prompt = ["A stylish woman with sunglasses riding a motorcycle in NYC."]
return args


def tensor_to_pil(image: torch.Tensor) -> Image.Image:
"""Convert a normalized torch tensor [-1, 1] to a PIL Image."""
if image.ndim == 4:
image = image[0]
image = image.detach().to("cpu")
image = (image / 2 + 0.5).clamp(0, 1)
image = (image * 255).to(torch.uint8)
image = image.permute(1, 2, 0).contiguous().numpy()
return Image.fromarray(image)


def _format_prompt(user_prompt: str, ar_width: int, ar_height: int) -> str:
"""Build the AR-stage prompt string including the image grid header."""
return (
"<|im_start|>system\nYou are a helpful image generator.<|im_end|>\n"
f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
"<|im_start|>assistant\n"
f"<|image start|>{ar_width}*{ar_height}<|image token|>"
)


def _collect_images(outputs: list) -> list[torch.Tensor]:
"""Extract all image tensors produced by the final (DiT) stage."""
images: list[torch.Tensor] = []
for out in outputs:
ro_list = getattr(out, "request_output", out)
if not isinstance(ro_list, list):
ro_list = [ro_list]
for ro_item in ro_list:
for completion in getattr(ro_item, "outputs", None) or []:
mm = getattr(completion, "multimodal_output", None)
if not isinstance(mm, dict) or "image" not in mm:
raise RuntimeError(f"Missing image in multimodal output: {mm}")
payload = mm["image"]
for tensor in payload if isinstance(payload, list) else [payload]:
if not isinstance(tensor, torch.Tensor):
raise TypeError(f"Expected image tensor, got {type(tensor)}")
images.append(tensor)
return images


def _save_images(images: list[torch.Tensor], out_path: str) -> list[str]:
"""Save image tensors to disk.

Single image: written to *out_path* exactly.
Multiple images: suffixed as ``<base>_0<ext>``, ``<base>_1<ext>``, …
"""
if not images:
raise RuntimeError("No images to save.")
base, ext = os.path.splitext(out_path)
ext = ext or ".png"
paths = []
for i, tensor in enumerate(images):
path = out_path if len(images) == 1 else f"{base}_{i}{ext}"
tensor_to_pil(tensor).save(path)
paths.append(path)
return paths


def main() -> None:
args = parse_args()
os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)

if args.height <= 0 or args.width <= 0:
raise ValueError(f"Height and width must be positive, got {args.height}x{args.width}")
if args.height % _PATCH_SIZE != 0 or args.width % _PATCH_SIZE != 0:
raise ValueError(f"Height and width must be multiples of {_PATCH_SIZE}, got {args.height}x{args.width}")

ar_height = args.height // _PATCH_SIZE
ar_width = args.width // _PATCH_SIZE
gen_cfg = load_t2i_generation_config(args.model)
expected_grid_tokens = ar_height * (ar_width + 1)

logger.info("Initializing Omni pipeline...")
omni = Omni(model=args.model, stage_configs_path=args.stage_config, trust_remote_code=args.trust_remote_code)
try:
ar_sampling = SamplingParams(
temperature=1.0,
top_p=1.0,
top_k=gen_cfg.top_k,
max_tokens=max(1, expected_grid_tokens + 1), # +1 for hidden state of eoi
detokenize=False,
)
dit_sampling = SamplingParams(
temperature=0.0,
top_p=1.0,
top_k=-1,
max_tokens=1,
detokenize=False,
)

additional_information = {
"omni_task": ["t2i"],
"ar_width": [ar_width],
"ar_height": [ar_height],
"eol_token_id": [gen_cfg.eol_token_id],
"visual_token_start_id": [gen_cfg.visual_token_start_id],
"visual_token_end_id": [gen_cfg.visual_token_end_id],
"image_height": [args.height],
"image_width": [args.width],
"num_inference_steps": [args.num_inference_steps],
"text_guidance_scale": [args.text_guidance_scale],
"cfg_range": [args.cfg_range[0], args.cfg_range[1]],
"visual_ids": gen_cfg.visual_ids,
}
inputs = [
{
"prompt": _format_prompt(p, ar_width, ar_height),
"additional_information": dict(additional_information),
}
for p in args.prompt
]

logger.info("Starting generation...")
# omni.generate() returns a Generator; consume it to run the full pipeline.
outputs = list(omni.generate(inputs, [ar_sampling, dit_sampling]))

logger.info("Post-processing and saving image(s)...")
for path in _save_images(_collect_images(outputs), args.out):
logger.info(f"Saved: {path}")
finally:
omni.close()


if __name__ == "__main__":
main()