Skip to content

Commit ed2eaf5

Browse files
princeprideusberkeley
authored and committed
add mrope and change vit padding token numbers
Signed-off-by: princepride <wangzhipeng628@gmail.com>
1 parent 571eaa6 commit ed2eaf5

File tree

2 files changed

+192
-15
lines changed

2 files changed

+192
-15
lines changed

examples/offline_inference/hunyuan_image3/image_to_text.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111
"""
1212
The tencent/HunyuanImage-3.0-Instruct base model is built on the Hunyuan v1 architecture, specifically the tencent/Hunyuan-A13B-Instruct model. It utilizes two tokenizer delimiter templates:
1313
14-
1) Pretrained template, which is suitable for text-to-text scenarios:
15-
"<|startoftext|>You are a study abroad planning consultant.<|extra_4|>I don't need general comprehensive rankings. Please list the world's top ten universities for computer science based on the 2025 U.S. News subject rankings.<|extra_0|>\n"
14+
1) Pretrained template (default for gen_text mode), which concatenates system, image
15+
tokens, and user question WITHOUT role delimiters:
16+
"<|startoftext|>{system_prompt}{image_tokens}{user_question}"
1617
17-
2) Instruct template, which is designed for image-to-text scenarios:
18-
"<bos>You are an assistant for recognizing pictures, outputting text.\n\nUser: <img> Describe the content of the picture.\n\nAssistant: "
18+
Example (before image token expansion):
19+
"<|startoftext|>You are an assistant that understands images and outputs text.<img>Describe the content of the picture."
20+
21+
2) Instruct template, which uses explicit role prefixes and separators.
1922
"""
2023

2124

@@ -35,8 +38,8 @@ def parse_args() -> argparse.Namespace:
3538
parser.add_argument(
3639
"--prompt",
3740
type=str,
38-
default="<bos>You are an assistant for recognizing pictures, outputting text.\n\nUser: <img> Describe the content of the picture.\n\nAssistant: ",
39-
help="Text prompt for the model.",
41+
default="<|startoftext|>You are an assistant that understands images and outputs text.<img>Identify the animal in this image and describe this animal's characteristics in the image.",
42+
help="Pretrain template prompt: <|startoftext|>{system}<img>{question}",
4043
)
4144
return parser.parse_args()
4245

vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py

Lines changed: 183 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from PIL import Image
1414
from torch import nn
1515
from torchvision import transforms
16-
from transformers import Siglip2ImageProcessorFast
16+
from transformers import PretrainedConfig, Siglip2ImageProcessorFast
1717
from transformers.feature_extraction_utils import BatchFeature
1818
from transformers.image_utils import ImageInput
1919
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
@@ -44,6 +44,7 @@
4444
from vllm.model_executor.models.interfaces import (
4545
MultiModalEmbeddings,
4646
SupportsLoRA,
47+
SupportsMRoPE,
4748
SupportsMultiModal,
4849
SupportsPP,
4950
_require_is_multimodal,
@@ -58,6 +59,7 @@
5859
from vllm.multimodal import MULTIMODAL_REGISTRY
5960
from vllm.multimodal.inputs import (
6061
MultiModalDataDict,
62+
MultiModalFeatureSpec,
6163
MultiModalFieldConfig,
6264
MultiModalKwargsItems,
6365
)
@@ -73,6 +75,7 @@
7375
PromptUpdateDetails,
7476
)
7577
from vllm.sequence import IntermediateTensors
78+
from vllm.transformers_utils.tokenizer import get_tokenizer
7679
from vllm.utils.tensor_schema import TensorSchema
7780
from vllm.v1.sample.metadata import SamplingMetadata
7881

@@ -1060,7 +1063,8 @@ def get_replacement_image(item_idx: int) -> PromptUpdateDetails:
10601063

10611064
timestep_token_num = 1
10621065
vae_token_num = _vae_token_grid_hw[0] * _vae_token_grid_hw[1]
1063-
vit_token_num = _vit_token_grid_hw[0] * _vit_token_grid_hw[1]
1066+
hf_config = self.info.get_hf_config()
1067+
vit_token_num = hf_config.vit_processor.get("max_num_patches", 729)
10641068

10651069
base_size_token_id = tokenizer.convert_tokens_to_ids(f"<img_size_{_base_size}>")
10661070
if base_size_token_id is None:
@@ -1092,7 +1096,7 @@ def get_replacement_image(item_idx: int) -> PromptUpdateDetails:
10921096
info=HunyuanImage3ProcessingInfo,
10931097
dummy_inputs=HunyuanImage3DummyInputsBuilder,
10941098
)
1095-
class HunyuanImage3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP):
1099+
class HunyuanImage3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE):
10961100
HunyuanImage3Inputs: TypeAlias = HunyuanImage3PixelInputs
10971101

10981102
packed_modules_mapping = {
@@ -1113,9 +1117,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
11131117
config = vllm_config.model_config.hf_config
11141118
quant_config = vllm_config.quant_config
11151119

1116-
# Change the rope_type from 'custom' to 'default' in the AR stage
1120+
# Use mRoPE to preserve 2D positional encoding for image tokens.
11171121
if isinstance(config.rope_parameters, dict):
11181122
config.rope_parameters["rope_type"] = "default"
1123+
head_dim = getattr(
1124+
config,
1125+
"head_dim",
1126+
getattr(config, "attention_head_dim", config.hidden_size // config.num_attention_heads),
1127+
)
1128+
config.rope_parameters["mrope_section"] = [0, head_dim // 4, head_dim // 4]
11191129

11201130
self.config = config
11211131
self.quant_config = quant_config
@@ -1159,6 +1169,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
11591169
# Used to embed timestep information into the input sequence.
11601170
self.timestep_emb = TimestepEmbedder(hidden_size=config.hidden_size)
11611171

1172+
tokenizer = get_tokenizer(vllm_config.model_config.tokenizer)
1173+
self._mrope_img_token_id = tokenizer.convert_tokens_to_ids("<img>")
1174+
self._mrope_boi_token_id = tokenizer.convert_tokens_to_ids("<boi>")
1175+
self._mrope_eoi_token_id = tokenizer.convert_tokens_to_ids("<eoi>")
1176+
self._mrope_joint_img_sep_token_id = tokenizer.convert_tokens_to_ids("<joint_img_sep>")
1177+
self._mrope_max_num_patches = config.vit_processor.get("max_num_patches", 729)
1178+
11621179
def _parse_and_validate_image_input(
11631180
self,
11641181
**kwargs: dict[str, Any],
@@ -1323,11 +1340,6 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
13231340
# 3. ViT image embeddings
13241341
vit_embed = vit_embeddings[img_idx]
13251342

1326-
# Slice vit_embed to valid tokens
1327-
h, w = vit_spatial_shapes[img_idx]
1328-
valid_tokens = int(h * w)
1329-
vit_embed = vit_embed[:valid_tokens]
1330-
13311343
stacked_embed = torch.cat([timestep_emb, vae_token_embed, vit_embed], dim=0)
13321344
combined_embeddings.append(stacked_embed)
13331345

@@ -1408,3 +1420,165 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
14081420

14091421
def get_language_model(self) -> torch.nn.Module:
14101422
return self.model
1423+
1424+
def get_mrope_input_positions(
    self,
    input_tokens: list[int],
    mm_features: list[MultiModalFeatureSpec] | None = None,
    *,
    hf_config: PretrainedConfig,
    image_grid_thw: list[list[int]] | torch.Tensor,
    video_grid_thw: list[list[int]] | torch.Tensor,
    second_per_grid_ts: list[float] | None = None,
    context_len: int = 0,
    seq_len: int | None = None,
    audio_feature_lengths: torch.Tensor | None = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Compute mRoPE positions for HunyuanImage-3.

    Maps the original model's build_2d_rope logic into vLLM's 3-dim
    mRoPE position tensor [3, seq_len] where dim-0 is temporal (unused,
    kept equal to the flat 1D position), dim-1 is height, and dim-2 is
    width.

    For text tokens and auxiliary image tokens (timestep, ViT):
        All three dims get the same flat 1D position id.
    For VAE image tokens:
        dim-0 (T): flat 1D position id at the image start
        dim-1 (H): 2D y-position using build_2d_rope centering
        dim-2 (W): 2D x-position using build_2d_rope centering

    Args:
        input_tokens: Full token-id sequence for the request.
        mm_features: Per-image multimodal features; each item's data is
            expected to carry a "vae_token_grid_hw" entry giving the
            (height, width) of that image's VAE token grid.
        hf_config: Unused here; part of the SupportsMRoPE interface.
        image_grid_thw / video_grid_thw / second_per_grid_ts /
        audio_feature_lengths / use_audio_in_video: Unused for this
            model; accepted for interface compatibility.
        context_len / seq_len: Optional window applied to the returned
            positions (positions are sliced to [context_len:seq_len]
            when seq_len is given).

    Returns:
        (llm_positions, mrope_position_delta) where llm_positions has
        shape [3, window_len] and mrope_position_delta is a plain int
        offset for continuing decode positions past the prompt.
    """
    # Robustness: an empty prompt would make torch.Tensor.max() raise
    # on an empty tensor below; short-circuit with an empty result.
    if not input_tokens:
        return torch.zeros((3, 0), dtype=torch.long), 0

    # Extract per-image VAE grid dims from mm_features.
    vae_grids: list[tuple[int, int]] = []
    if mm_features is not None:
        for mm_feature in mm_features:
            mm_item = mm_feature.data
            if mm_item is None:
                continue
            mm_input = mm_item.get_data()
            vae_hw = mm_input.get("vae_token_grid_hw")
            if vae_hw is not None:
                grid = vae_hw.tolist()
                vae_grids.append((int(grid[0]), int(grid[1])))

    # Image-structure token ids (cached from __init__).
    img_token_id = self._mrope_img_token_id
    boi_token_id = self._mrope_boi_token_id
    eoi_token_id = self._mrope_eoi_token_id
    joint_img_sep_token_id = self._mrope_joint_img_sep_token_id
    max_num_patches = self._mrope_max_num_patches

    # Position arrays for the three mRoPE dims.
    t_pos: list[int] = []  # temporal (same as 1D for this model)
    h_pos: list[int] = []  # height
    w_pos: list[int] = []  # width

    def append_flat(p: int) -> None:
        # A token outside a VAE grid gets the same flat id on all dims.
        t_pos.append(p)
        h_pos.append(p)
        w_pos.append(p)

    pos = 0  # current 1D position counter
    image_idx = 0
    i = 0
    n = len(input_tokens)

    while i < n:
        tok = input_tokens[i]

        if tok == boi_token_id:
            # Start of an image block. Expected layout:
            # <boi> <size> <ratio> <img>*timestep <img>*vae
            # <joint_img_sep> <img>*vit <eoi>

            # <boi> token — flat position.
            append_flat(pos)
            pos += 1
            i += 1

            # <size> token.
            if i < n:
                append_flat(pos)
                pos += 1
                i += 1

            # <ratio> token.
            if i < n:
                append_flat(pos)
                pos += 1
                i += 1

            # Timestep token (a single <img> token).
            if i < n and input_tokens[i] == img_token_id:
                append_flat(pos)
                pos += 1
                i += 1

            # VAE tokens: look up this image's grid dims.
            if image_idx < len(vae_grids):
                vae_h, vae_w = vae_grids[image_idx]
            else:
                vae_h, vae_w = 0, 0
            image_idx += 1

            # Assign 2D positions to VAE tokens using the build_2d_rope
            # centering logic. NOTE(review): (wh - dim) / 2 can be a
            # half-integer; int() truncates — presumed to match the
            # reference build_2d_rope rounding, confirm against HF impl.
            region_start = pos  # position at start of VAE region
            wh = vae_w * vae_h
            beta_y = region_start + (wh - vae_h) / 2
            beta_x = region_start + (wh - vae_w) / 2

            for row in range(vae_h):
                for col in range(vae_w):
                    if i < n and input_tokens[i] == img_token_id:
                        t_pos.append(region_start)  # temporal stays flat
                        h_pos.append(int(beta_y + row))
                        w_pos.append(int(beta_x + col))
                        i += 1

            pos = region_start + wh  # advance past VAE region

            # <joint_img_sep> token.
            if i < n and input_tokens[i] == joint_img_sep_token_id:
                append_flat(pos)
                pos += 1
                i += 1

            # ViT tokens (up to max_num_patches <img> tokens) — flat 1D.
            vit_consumed = 0
            while (
                i < n
                and input_tokens[i] == img_token_id
                and vit_consumed < max_num_patches
            ):
                append_flat(pos)
                pos += 1
                i += 1
                vit_consumed += 1

            # <eoi> token.
            if i < n and input_tokens[i] == eoi_token_id:
                append_flat(pos)
                pos += 1
                i += 1

        else:
            # Regular text token — flat 1D position.
            append_flat(pos)
            pos += 1
            i += 1

    llm_positions = torch.tensor([t_pos, h_pos, w_pos], dtype=torch.long)
    # Fix: the interface declares an int delta; the raw expression
    # `llm_positions.max() + 1 - len(...)` would yield a 0-dim tensor.
    mrope_position_delta = int(llm_positions.max().item()) + 1 - n

    if seq_len is not None:
        llm_positions = llm_positions[:, context_len:seq_len]

    return llm_positions, mrope_position_delta

0 commit comments

Comments
 (0)