Skip to content

Commit c6deeb1

Browse files
committed
run precommit
Signed-off-by: HonestDeng <2958906959@qq.com>
1 parent 1eb106c commit c6deeb1

File tree

5 files changed

+96
-49
lines changed

5 files changed

+96
-49
lines changed

examples/offline_inference/mammothmodal2_preview/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,6 @@ python run_mammothmoda2_t2i.py \
3232
python run_mammothmoda2_image_summarize.py \
3333
--model ./MammothModa2-Preview \
3434
--stage-config ./mammoth_moda2_image_summarize.yaml \
35-
--question "Summerize this image." \
35+
--question "Summarize this image." \
3636
--image ./image.png
3737
```

examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,18 @@
2626
def parse_args() -> argparse.Namespace:
2727
parser = argparse.ArgumentParser(description="MammothModa2 image summarization (offline, AR only).")
2828
parser.add_argument("--model", type=str, required=True, help="Path to model directory or model id.")
29-
parser.add_argument("--stage-config", type=str, required=True, help="Path to stage config yaml (single-stage AR->text).")
29+
parser.add_argument(
30+
"--stage-config", type=str, required=True, help="Path to stage config yaml (single-stage AR->text)."
31+
)
3032
parser.add_argument("--image", type=str, required=True, help="Path to input image.")
3133
parser.add_argument("--question", type=str, default=DEFAULT_QUESTION, help="Question/instruction for the model.")
3234
parser.add_argument("--system", type=str, default=DEFAULT_SYSTEM, help="System prompt.")
33-
parser.add_argument("--max-tokens", type=int, default=512, help="Max new tokens to generate.",)
35+
parser.add_argument(
36+
"--max-tokens",
37+
type=int,
38+
default=512,
39+
help="Max new tokens to generate.",
40+
)
3441
parser.add_argument("--temperature", type=float, default=0.2)
3542
parser.add_argument("--top-p", type=float, default=0.9)
3643
parser.add_argument("--seed", type=int, default=42)

examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@
3535
# ---------------------------------------------------------------------------
3636
# Constants
3737
# ---------------------------------------------------------------------------
38-
_PATCH_SIZE = 16 # AR image grid patch size (pixels per token)
38+
_PATCH_SIZE = 16 # AR image grid patch size (pixels per token)
3939

4040

4141
class T2IGenConfig(NamedTuple):
4242
eol_token_id: int
4343
visual_token_start_id: int
4444
visual_token_end_id: int
45-
top_k: int # AR sampling top-k (covers the full visual generation vocabulary)
45+
top_k: int # AR sampling top-k (covers the full visual generation vocabulary)
4646
# Qwen2.5-VL special vision tokens: <|image_pad|>, <|video_pad|>, <|vision_start|>, <|vision_end|>
4747
visual_ids: list[int]
4848

@@ -80,17 +80,30 @@ def load_t2i_generation_config(model_dir: str) -> T2IGenConfig:
8080
def parse_args() -> argparse.Namespace:
8181
p = argparse.ArgumentParser(description="Run MammothModa2 T2I (AR -> DiT) with vLLM-Omni.")
8282
p.add_argument("--model", type=str, required=True, help="Path to the model directory.")
83-
p.add_argument("--stage-config", type=str, required=True,help="Path to the multi-stage YAML configuration.")
84-
p.add_argument("--prompt", type=str, action="append", default=None,
83+
p.add_argument("--stage-config", type=str, required=True, help="Path to the multi-stage YAML configuration.")
84+
p.add_argument(
85+
"--prompt",
86+
type=str,
87+
action="append",
88+
default=None,
8589
help=(
8690
"Text prompt for image generation. Can be provided multiple times "
87-
"to generate multiple images with shared height/width/CFG settings."),
91+
"to generate multiple images with shared height/width/CFG settings."
92+
),
8893
)
8994
p.add_argument("--height", type=int, default=1024, help="Output image height (must be a multiple of 16).")
9095
p.add_argument("--width", type=int, default=1024, help="Output image width (must be a multiple of 16).")
9196
p.add_argument("--num-inference-steps", type=int, default=50, help="Number of diffusion steps for the DiT stage.")
92-
p.add_argument("--text-guidance-scale", type=float, default=9.0, help="Classifier-Free Guidance (CFG) scale for DiT.")
93-
p.add_argument("--cfg-range", type=float, nargs=2, default=(0.0, 1.0), help="Relative step range [start, end] where CFG is active.",)
97+
p.add_argument(
98+
"--text-guidance-scale", type=float, default=9.0, help="Classifier-Free Guidance (CFG) scale for DiT."
99+
)
100+
p.add_argument(
101+
"--cfg-range",
102+
type=float,
103+
nargs=2,
104+
default=(0.0, 1.0),
105+
help="Relative step range [start, end] where CFG is active.",
106+
)
94107
p.add_argument("--out", type=str, default="output.png", help="Path to save the generated image.")
95108
p.add_argument("--trust-remote-code", action="store_true", help="Trust remote code when loading the model.")
96109
args = p.parse_args()
@@ -128,12 +141,12 @@ def _collect_images(outputs: list) -> list[torch.Tensor]:
128141
if not isinstance(ro_list, list):
129142
ro_list = [ro_list]
130143
for ro_item in ro_list:
131-
for completion in (getattr(ro_item, "outputs", None) or []):
144+
for completion in getattr(ro_item, "outputs", None) or []:
132145
mm = getattr(completion, "multimodal_output", None)
133146
if not isinstance(mm, dict) or "image" not in mm:
134147
raise RuntimeError(f"Missing image in multimodal output: {mm}")
135148
payload = mm["image"]
136-
for tensor in (payload if isinstance(payload, list) else [payload]):
149+
for tensor in payload if isinstance(payload, list) else [payload]:
137150
if not isinstance(tensor, torch.Tensor):
138151
raise TypeError(f"Expected image tensor, got {type(tensor)}")
139152
images.append(tensor)
@@ -183,16 +196,22 @@ def main() -> None:
183196
detokenize=False,
184197
)
185198
dit_sampling = SamplingParams(
186-
temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1, detokenize=False,
199+
temperature=0.0,
200+
top_p=1.0,
201+
top_k=-1,
202+
max_tokens=1,
203+
detokenize=False,
187204
)
188205

189206
additional_information = {
190207
"omni_task": ["t2i"],
191-
"ar_width": [ar_width], "ar_height": [ar_height],
208+
"ar_width": [ar_width],
209+
"ar_height": [ar_height],
192210
"eol_token_id": [gen_cfg.eol_token_id],
193211
"visual_token_start_id": [gen_cfg.visual_token_start_id],
194212
"visual_token_end_id": [gen_cfg.visual_token_end_id],
195-
"image_height": [args.height], "image_width": [args.width],
213+
"image_height": [args.height],
214+
"image_width": [args.width],
196215
"num_inference_steps": [args.num_inference_steps],
197216
"text_guidance_scale": [args.text_guidance_scale],
198217
"cfg_range": [args.cfg_range[0], args.cfg_range[1]],

tests/e2e/offline_inference/test_mammoth_moda2.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@
3232
# Token ID constants (Qwen2.5-VL base tokenizer + MammothModa2 gen vocab)
3333
# ---------------------------------------------------------------------------
3434
# Qwen2.5-VL special vision token IDs (match Mammothmoda2Qwen2_5_VLTextConfig defaults)
35-
_IMAGE_TOKEN_ID = 151655 # "<|image_pad|>"
36-
_VIDEO_TOKEN_ID = 151656 # "<|video_pad|>"
37-
_VISION_START_TOKEN_ID = 151652 # "<|vision_start|>"
38-
_VISION_END_TOKEN_ID = 151653 # "<|vision_end|>"
35+
_IMAGE_TOKEN_ID = 151655 # "<|image_pad|>"
36+
_VIDEO_TOKEN_ID = 151656 # "<|video_pad|>"
37+
_VISION_START_TOKEN_ID = 151652 # "<|vision_start|>"
38+
_VISION_END_TOKEN_ID = 151653 # "<|vision_end|>"
3939
# MammothModa2 generation vocab (from t2i_generation_config.json)
40-
_BASE_VOCAB_SIZE = 152064 # Qwen2.5 base vocab size; also used as eol_token_id
41-
_VISUAL_TOKEN_START_ID = 152072 # first visual generation token
42-
_VISUAL_TOKEN_END_ID = 168456 # last visual generation token
43-
_GEN_VOCAB_SIZE = 32800 # size of the visual generation vocabulary
40+
_BASE_VOCAB_SIZE = 152064 # Qwen2.5 base vocab size; also used as eol_token_id
41+
_VISUAL_TOKEN_START_ID = 152072 # first visual generation token
42+
_VISUAL_TOKEN_END_ID = 168456 # last visual generation token
43+
_GEN_VOCAB_SIZE = 32800 # size of the visual generation vocabulary
4444
# AR stage image grid: each token covers _AR_PATCH_SIZE x _AR_PATCH_SIZE pixels
4545
_AR_PATCH_SIZE = 16
4646
# AR sampling top-k covers the full visual generation vocabulary
@@ -49,7 +49,9 @@
4949
_EXAMPLE_DIR = Path(__file__).resolve().parents[3] / "examples" / "offline_inference" / "mammothmodal2_preview"
5050
MODEL_PATH = os.environ.get("MAMMOTHMODA2_MODEL_PATH", str(_EXAMPLE_DIR / "MammothModa2-Preview"))
5151
T2I_STAGE_CONFIG = os.environ.get("MAMMOTHMODA2_T2I_STAGE_CONFIG", str(_EXAMPLE_DIR / "mammoth_moda2_t2i.yaml"))
52-
SUMMARIZE_STAGE_CONFIG = os.environ.get("MAMMOTHMODA2_SUMMARIZE_STAGE_CONFIG", str(_EXAMPLE_DIR / "mammoth_moda2_image_summarize.yaml"))
52+
SUMMARIZE_STAGE_CONFIG = os.environ.get(
53+
"MAMMOTHMODA2_SUMMARIZE_STAGE_CONFIG", str(_EXAMPLE_DIR / "mammoth_moda2_image_summarize.yaml")
54+
)
5355

5456

5557
def _load_t2i_gen_config(model_dir: str) -> dict:
@@ -78,7 +80,9 @@ class TestConfigParsing:
7880
def test_autoconfig_registration(self):
7981
"""AutoConfig should resolve 'mammothmoda2' model_type."""
8082
from transformers import AutoConfig
83+
8184
from vllm_omni.model_executor.models.mammoth_moda2.config import Mammothmoda2Config # noqa: F401
85+
8286
cfg = AutoConfig.for_model(
8387
model_type="mammothmoda2",
8488
llm_config={"model_type": "mammothmoda2_qwen2_5_vl"},
@@ -88,24 +92,33 @@ def test_autoconfig_registration(self):
8892
def test_dual_vocab_size_computation(self):
8993
"""With extra_gen_vocab=True: vocab_size == gen_vocab_start_index + gen_vocab_size."""
9094
from vllm_omni.model_executor.models.mammoth_moda2.config import Mammothmoda2Qwen2_5_VLTextConfig
91-
tc = Mammothmoda2Qwen2_5_VLTextConfig(vocab_size=_BASE_VOCAB_SIZE, extra_gen_vocab=True, gen_vocab_size=_GEN_VOCAB_SIZE)
95+
96+
tc = Mammothmoda2Qwen2_5_VLTextConfig(
97+
vocab_size=_BASE_VOCAB_SIZE, extra_gen_vocab=True, gen_vocab_size=_GEN_VOCAB_SIZE
98+
)
9299
assert tc.gen_vocab_start_index == _BASE_VOCAB_SIZE
93100
assert tc.vocab_size == _BASE_VOCAB_SIZE + _GEN_VOCAB_SIZE
94101

95102
def test_proxy_properties(self):
96103
"""Top-level config should proxy token IDs from llm_config."""
97104
from vllm_omni.model_executor.models.mammoth_moda2.config import Mammothmoda2Config
98-
cfg = Mammothmoda2Config(llm_config={
99-
"model_type": "mammothmoda2_qwen2_5_vl",
100-
"image_token_id": _IMAGE_TOKEN_ID, "video_token_id": _VIDEO_TOKEN_ID,
101-
"vision_start_token_id": _VISION_START_TOKEN_ID, "vision_end_token_id": _VISION_END_TOKEN_ID,
102-
})
105+
106+
cfg = Mammothmoda2Config(
107+
llm_config={
108+
"model_type": "mammothmoda2_qwen2_5_vl",
109+
"image_token_id": _IMAGE_TOKEN_ID,
110+
"video_token_id": _VIDEO_TOKEN_ID,
111+
"vision_start_token_id": _VISION_START_TOKEN_ID,
112+
"vision_end_token_id": _VISION_END_TOKEN_ID,
113+
}
114+
)
103115
assert cfg.image_token_id == _IMAGE_TOKEN_ID
104116
assert cfg.video_token_id == _VIDEO_TOKEN_ID
105117

106118
def test_missing_llm_config_raises(self):
107119
"""Proxy property access with llm_config=None should raise AttributeError."""
108120
from vllm_omni.model_executor.models.mammoth_moda2.config import Mammothmoda2Config
121+
109122
with pytest.raises(AttributeError, match="llm_config is None"):
110123
_ = Mammothmoda2Config(llm_config=None).image_token_id
111124

@@ -153,18 +166,22 @@ def _stage(ar_outputs: list) -> list:
153166
def _p(image_height: int = 512, image_width: int = 512, visual_ids: list[int] | None = None, **kw) -> dict:
154167
if visual_ids is None:
155168
visual_ids = [_IMAGE_TOKEN_ID, _VIDEO_TOKEN_ID, _VISION_START_TOKEN_ID, _VISION_END_TOKEN_ID]
156-
return {"additional_information": {
157-
"omni_task": ["t2i"],
158-
"ar_width": [image_width // _AR_PATCH_SIZE], "ar_height": [image_height // _AR_PATCH_SIZE],
159-
"eol_token_id": [kw.get("eol_token_id", _BASE_VOCAB_SIZE)],
160-
"visual_token_start_id": [kw.get("visual_token_start_id", _VISUAL_TOKEN_START_ID)],
161-
"visual_token_end_id": [kw.get("visual_token_end_id", _VISUAL_TOKEN_END_ID)],
162-
"image_height": [image_height], "image_width": [image_width],
163-
"num_inference_steps": [kw.get("num_inference_steps", 50)],
164-
"text_guidance_scale": [kw.get("text_guidance_scale", 9.0)],
165-
"cfg_range": list(kw.get("cfg_range", [0.0, 1.0])),
166-
"visual_ids": visual_ids,
167-
}}
169+
return {
170+
"additional_information": {
171+
"omni_task": ["t2i"],
172+
"ar_width": [image_width // _AR_PATCH_SIZE],
173+
"ar_height": [image_height // _AR_PATCH_SIZE],
174+
"eol_token_id": [kw.get("eol_token_id", _BASE_VOCAB_SIZE)],
175+
"visual_token_start_id": [kw.get("visual_token_start_id", _VISUAL_TOKEN_START_ID)],
176+
"visual_token_end_id": [kw.get("visual_token_end_id", _VISUAL_TOKEN_END_ID)],
177+
"image_height": [image_height],
178+
"image_width": [image_width],
179+
"num_inference_steps": [kw.get("num_inference_steps", 50)],
180+
"text_guidance_scale": [kw.get("text_guidance_scale", 9.0)],
181+
"cfg_range": list(kw.get("cfg_range", [0.0, 1.0])),
182+
"visual_ids": visual_ids,
183+
}
184+
}
168185

169186

170187
class TestAR2DitProcessor:
@@ -184,7 +201,9 @@ def test_basic_output_structure(self):
184201
def test_embed_shapes_and_dtype(self):
185202
"""text/image condition embeds must be 2D float32 with correct leading dim."""
186203
n_gen = 30
187-
ar_out = _mock_ar(list(range(15)), list(range(_VISUAL_TOKEN_START_ID, _VISUAL_TOKEN_START_ID + n_gen)) + [0], hidden_dim=128)
204+
ar_out = _mock_ar(
205+
list(range(15)), list(range(_VISUAL_TOKEN_START_ID, _VISUAL_TOKEN_START_ID + n_gen)) + [0], hidden_dim=128
206+
)
188207
info = ar2dit(_stage([ar_out]), engine_input_source=[0], prompts=[_p()])[0]["additional_information"]
189208
assert info["image_prompt_embeds"].shape == (n_gen, 128)
190209
assert info["text_prompt_embeds"].dtype == torch.float32
@@ -227,8 +246,9 @@ def test_visual_ids_excluded_from_text_embeds(self):
227246
ar.prompt_token_ids = prompt_ids
228247
ar.outputs = [c]
229248
ar.request_id = "req-visual"
230-
info = ar2dit(_stage([ar]), engine_input_source=[0],
231-
prompts=[_p(visual_ids=visual_ids)])[0]["additional_information"]
249+
info = ar2dit(_stage([ar]), engine_input_source=[0], prompts=[_p(visual_ids=visual_ids)])[0][
250+
"additional_information"
251+
]
232252
assert info["text_prompt_embeds"].shape[0] == 3
233253

234254

@@ -331,9 +351,7 @@ def test_mammothmoda2_t2i_e2e():
331351
f"Expected image tensor, got {type(img_tensor)}"
332352
)
333353
# DiT output is (C, H*2, W*2) or (1, C, H*2, W*2)
334-
assert img_tensor.ndim in (3, 4), (
335-
f"Expected 3D or 4D image tensor, got {img_tensor.ndim}D"
336-
)
354+
assert img_tensor.ndim in (3, 4), f"Expected 3D or 4D image tensor, got {img_tensor.ndim}D"
337355
found_image = True
338356

339357
assert found_image, "No image tensor found in pipeline output"
@@ -350,6 +368,7 @@ class TestStageConfigValidation:
350368
def test_t2i_config_two_stages(self):
351369
"""T2I YAML must define exactly 2 stages (AR->latent, DiT->image) with correct wiring."""
352370
import yaml
371+
353372
if not Path(T2I_STAGE_CONFIG).exists():
354373
pytest.skip(f"Not found: {T2I_STAGE_CONFIG}")
355374
with open(T2I_STAGE_CONFIG) as f:
@@ -366,6 +385,7 @@ def test_t2i_config_two_stages(self):
366385
def test_summarize_config_single_ar_stage(self):
367386
"""Image-summarisation YAML must be a single AR stage outputting text."""
368387
import yaml
388+
369389
if not Path(SUMMARIZE_STAGE_CONFIG).exists():
370390
pytest.skip(f"Not found: {SUMMARIZE_STAGE_CONFIG}")
371391
with open(SUMMARIZE_STAGE_CONFIG) as f:

vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2_ar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from vllm_omni.model_executor.models.output_templates import OmniOutput
5050

5151

52-
def moe_enable(moe_type, layer_type, layer_idx):
52+
def moe_enable(moe_type, layer_type, layer_idx) -> bool:
5353
"""Determine if MoE should be enabled for a specific layer type and index.
5454
5555
Args:
@@ -69,6 +69,7 @@ def moe_enable(moe_type, layer_type, layer_idx):
6969
assert moe_type in ["none", "attention", "ffn", "ffn_attention"]
7070
return layer_type in moe_type and start <= layer_idx < end
7171

72+
7273
def moe_forward(
7374
hidden_states: torch.Tensor,
7475
und_expert: Callable[[torch.Tensor], torch.Tensor],

0 commit comments

Comments
 (0)