1313from PIL import Image
1414from torch import nn
1515from torchvision import transforms
16- from transformers import Siglip2ImageProcessorFast
16+ from transformers import PretrainedConfig , Siglip2ImageProcessorFast
1717from transformers .feature_extraction_utils import BatchFeature
1818from transformers .image_utils import ImageInput
1919from transformers .tokenization_utils_base import PreTokenizedInput , TextInput
4444from vllm .model_executor .models .interfaces import (
4545 MultiModalEmbeddings ,
4646 SupportsLoRA ,
47+ SupportsMRoPE ,
4748 SupportsMultiModal ,
4849 SupportsPP ,
4950 _require_is_multimodal ,
5859from vllm .multimodal import MULTIMODAL_REGISTRY
5960from vllm .multimodal .inputs import (
6061 MultiModalDataDict ,
62+ MultiModalFeatureSpec ,
6163 MultiModalFieldConfig ,
6264 MultiModalKwargsItems ,
6365)
7375 PromptUpdateDetails ,
7476)
7577from vllm .sequence import IntermediateTensors
78+ from vllm .transformers_utils .tokenizer import get_tokenizer
7679from vllm .utils .tensor_schema import TensorSchema
7780from vllm .v1 .sample .metadata import SamplingMetadata
7881
@@ -1060,7 +1063,8 @@ def get_replacement_image(item_idx: int) -> PromptUpdateDetails:
10601063
10611064 timestep_token_num = 1
10621065 vae_token_num = _vae_token_grid_hw [0 ] * _vae_token_grid_hw [1 ]
1063- vit_token_num = _vit_token_grid_hw [0 ] * _vit_token_grid_hw [1 ]
1066+ hf_config = self .info .get_hf_config ()
1067+ vit_token_num = hf_config .vit_processor .get ("max_num_patches" , 729 )
10641068
10651069 base_size_token_id = tokenizer .convert_tokens_to_ids (f"<img_size_{ _base_size } >" )
10661070 if base_size_token_id is None :
@@ -1092,7 +1096,7 @@ def get_replacement_image(item_idx: int) -> PromptUpdateDetails:
10921096 info = HunyuanImage3ProcessingInfo ,
10931097 dummy_inputs = HunyuanImage3DummyInputsBuilder ,
10941098)
1095- class HunyuanImage3ForConditionalGeneration (nn .Module , SupportsMultiModal , SupportsLoRA , SupportsPP ):
1099+ class HunyuanImage3ForConditionalGeneration (nn .Module , SupportsMultiModal , SupportsLoRA , SupportsPP , SupportsMRoPE ):
10961100 HunyuanImage3Inputs : TypeAlias = HunyuanImage3PixelInputs
10971101
10981102 packed_modules_mapping = {
@@ -1113,9 +1117,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
11131117 config = vllm_config .model_config .hf_config
11141118 quant_config = vllm_config .quant_config
11151119
1116- # Change the rope_type from 'custom' to 'default' in the AR stage
1120+ # Use mRoPE to preserve 2D positional encoding for image tokens.
11171121 if isinstance (config .rope_parameters , dict ):
11181122 config .rope_parameters ["rope_type" ] = "default"
1123+ head_dim = getattr (
1124+ config ,
1125+ "head_dim" ,
1126+ getattr (config , "attention_head_dim" , config .hidden_size // config .num_attention_heads ),
1127+ )
1128+ config .rope_parameters ["mrope_section" ] = [0 , head_dim // 4 , head_dim // 4 ]
11191129
11201130 self .config = config
11211131 self .quant_config = quant_config
@@ -1159,6 +1169,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
11591169 # Used to embed timestep information into the input sequence.
11601170 self .timestep_emb = TimestepEmbedder (hidden_size = config .hidden_size )
11611171
1172+ tokenizer = get_tokenizer (vllm_config .model_config .tokenizer )
1173+ self ._mrope_img_token_id = tokenizer .convert_tokens_to_ids ("<img>" )
1174+ self ._mrope_boi_token_id = tokenizer .convert_tokens_to_ids ("<boi>" )
1175+ self ._mrope_eoi_token_id = tokenizer .convert_tokens_to_ids ("<eoi>" )
1176+ self ._mrope_joint_img_sep_token_id = tokenizer .convert_tokens_to_ids ("<joint_img_sep>" )
1177+ self ._mrope_max_num_patches = config .vit_processor .get ("max_num_patches" , 729 )
1178+
11621179 def _parse_and_validate_image_input (
11631180 self ,
11641181 ** kwargs : dict [str , Any ],
@@ -1323,11 +1340,6 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
13231340 # 3. ViT image embeddings
13241341 vit_embed = vit_embeddings [img_idx ]
13251342
1326- # Slice vit_embed to valid tokens
1327- h , w = vit_spatial_shapes [img_idx ]
1328- valid_tokens = int (h * w )
1329- vit_embed = vit_embed [:valid_tokens ]
1330-
13311343 stacked_embed = torch .cat ([timestep_emb , vae_token_embed , vit_embed ], dim = 0 )
13321344 combined_embeddings .append (stacked_embed )
13331345
@@ -1408,3 +1420,165 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
14081420
def get_language_model(self) -> torch.nn.Module:
    """Return the module stored in ``self.model`` (the language-model backbone)."""
    language_model = self.model
    return language_model
1423+
def get_mrope_input_positions(
    self,
    input_tokens: list[int],
    mm_features: list[MultiModalFeatureSpec] | None = None,
    *,
    hf_config: PretrainedConfig | None = None,
    image_grid_thw: list[list[int]] | torch.Tensor | None = None,
    video_grid_thw: list[list[int]] | torch.Tensor | None = None,
    second_per_grid_ts: list[float] | None = None,
    context_len: int = 0,
    seq_len: int | None = None,
    audio_feature_lengths: torch.Tensor | None = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Compute mRoPE positions for HunyuanImage-3.

    Produces a ``[3, num_tokens]`` position tensor whose rows are
    (temporal, height, width).

    For every token outside a VAE region (text, ``<boi>``, size, ratio,
    timestep, ``<joint_img_sep>``, ViT tokens, ``<eoi>``) all three rows
    carry the same running 1D position id.

    For VAE image tokens:
        row 0 (T): the 1D position id at the start of the VAE region
        row 1 (H): ``int(beta_y + row)``
        row 2 (W): ``int(beta_x + col)``
    where ``beta_y = start + (h * w - h) / 2`` and
    ``beta_x = start + (h * w - w) / 2``.
    NOTE(review): this centering is meant to mirror the reference model's
    ``build_2d_rope`` -- confirm against the upstream implementation.

    Args:
        input_tokens: Prompt token ids.
        mm_features: Multimodal feature specs; each item's data may carry a
            ``"vae_token_grid_hw"`` entry giving that image's (h, w) VAE grid.
        hf_config, image_grid_thw, video_grid_thw, second_per_grid_ts,
        audio_feature_lengths, use_audio_in_video: Accepted only for
            call-signature compatibility; not used by this model.
        context_len: First token index to keep in the returned positions.
        seq_len: One past the last token index to keep (``None`` = to end).

    Returns:
        ``(positions, delta)`` where ``positions`` is the ``long`` tensor
        sliced to ``[:, context_len:seq_len]`` and ``delta`` is the plain
        ``int`` ``max_position + 1 - len(input_tokens)`` (0 for empty input).
    """
    # Per-image VAE grid dims, in the order the images appear in mm_features.
    vae_grids: list[tuple[int, int]] = []
    for mm_feature in mm_features or ():
        mm_item = mm_feature.data
        if mm_item is None:
            continue
        vae_hw = mm_item.get_data().get("vae_token_grid_hw")
        if vae_hw is not None:
            grid = vae_hw.tolist()
            vae_grids.append((int(grid[0]), int(grid[1])))

    # Special token ids and ViT budget are cached on the instance by __init__.
    img_token_id = self._mrope_img_token_id
    eoi_token_id = self._mrope_eoi_token_id
    boi_token_id = self._mrope_boi_token_id
    joint_img_sep_token_id = self._mrope_joint_img_sep_token_id
    max_num_patches = self._mrope_max_num_patches

    coords: list[tuple[int, int, int]] = []  # one (t, h, w) triple per token
    pos = 0  # running 1D position counter
    image_idx = 0
    i = 0
    n = len(input_tokens)

    while i < n:
        if input_tokens[i] != boi_token_id:
            # Regular text token: flat 1D position.
            coords.append((pos, pos, pos))
            pos += 1
            i += 1
            continue

        # Image block layout:
        #   <boi> <size> <ratio> <img>*timestep <img>*vae
        #   <joint_img_sep> <img>*vit <eoi>
        # <boi>, <size> and <ratio> are consumed positionally (no id check),
        # stopping early only if the prompt is truncated.
        for _ in range(min(3, n - i)):
            coords.append((pos, pos, pos))
            pos += 1
            i += 1

        # Timestep placeholder (a single <img> token).
        if i < n and input_tokens[i] == img_token_id:
            coords.append((pos, pos, pos))
            pos += 1
            i += 1

        # Grid for this image; a missing grid is treated as an empty VAE region.
        vae_h, vae_w = vae_grids[image_idx] if image_idx < len(vae_grids) else (0, 0)
        image_idx += 1

        vae_start = pos
        vae_area = vae_h * vae_w
        beta_y = vae_start + (vae_area - vae_h) / 2
        beta_x = vae_start + (vae_area - vae_w) / 2
        for row in range(vae_h):
            for col in range(vae_w):
                if i < n and input_tokens[i] == img_token_id:
                    # Temporal row stays pinned to the region start.
                    coords.append((vae_start, int(beta_y + row), int(beta_x + col)))
                    i += 1
        # The counter always jumps past the full h*w region, even if fewer
        # <img> tokens were actually present.
        pos = vae_start + vae_area

        if i < n and input_tokens[i] == joint_img_sep_token_id:
            coords.append((pos, pos, pos))
            pos += 1
            i += 1

        # ViT tokens: at most max_num_patches consecutive <img> tokens, flat 1D.
        vit_consumed = 0
        while i < n and vit_consumed < max_num_patches and input_tokens[i] == img_token_id:
            coords.append((pos, pos, pos))
            pos += 1
            i += 1
            vit_consumed += 1

        if i < n and input_tokens[i] == eoi_token_id:
            coords.append((pos, pos, pos))
            pos += 1
            i += 1

    # Explicit reshape keeps the empty-prompt case well-formed as [3, 0].
    llm_positions = torch.tensor(coords, dtype=torch.long).reshape(len(coords), 3).t().contiguous()

    # .max() is undefined on an empty tensor; .item() yields the annotated int
    # instead of a 0-d tensor.
    mrope_position_delta = int(llm_positions.max().item()) + 1 - n if coords else 0

    # Slice unconditionally so context_len is honoured even when seq_len is None.
    llm_positions = llm_positions[:, context_len:seq_len]

    return llm_positions, mrope_position_delta
0 commit comments