Skip to content

Commit 8ef6048

Browse files
CSWYF3634076 authored and skyloevil committed
[BugFix][Model] Fix Ernie4.5-VL hanging on long inputs (vllm-project#24074)
Signed-off-by: wangyafeng <[email protected]>
1 parent e46c334 commit 8ef6048

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

vllm/model_executor/models/ernie45_vl.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@
6666

6767
logger = init_logger(__name__)
6868

69-
_MAX_FRAMES_PER_VIDEO = 16
70-
7169
# === Vision Transformer === #
7270

7371

@@ -839,6 +837,15 @@ def get_image_processor(self, **kwargs: object):
839837
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
840838
return {"image": None, "video": None}
841839

840+
def get_mm_max_tokens_per_item(
841+
self,
842+
seq_len: int,
843+
mm_counts: Mapping[str, int],
844+
) -> Mapping[str, int]:
845+
max_image_tokens = self.get_max_image_tokens()
846+
max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
847+
return {"image": max_image_tokens, "video": max_video_tokens}
848+
842849
def _get_vision_info(
843850
self,
844851
*,
@@ -964,8 +971,7 @@ def get_num_frames_with_most_features(
964971
max_image_tokens = self.get_max_image_tokens() * max_images
965972
max_total_frames = self._get_max_video_frames(seq_len -
966973
max_image_tokens)
967-
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
968-
_MAX_FRAMES_PER_VIDEO)
974+
max_frames_per_video = max_total_frames // max(max_videos, 1)
969975

970976
return max(max_frames_per_video, 2)
971977

vllm/model_executor/models/ernie45_vl_moe.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,13 @@ def forward(
287287
if self.has_shared_experts:
288288
shared_output = self.shared_experts(hidden_states)
289289

290-
if visual_token_mask is not None and visual_token_mask.any():
291-
# assert visual_token_mask.shape[0] != hidden_states.shape[0]
290+
if visual_token_mask is not None and visual_token_mask.all():
291+
# only vision modal input
292+
router_logits, _ = self.vision_experts_gate(hidden_states)
293+
final_hidden_states = self.vision_experts(
294+
hidden_states=hidden_states, router_logits=router_logits)
295+
elif visual_token_mask is not None and visual_token_mask.any():
296+
# text and vision modals input
292297
visual_token_mask = visual_token_mask.repeat(
293298
1, self.hidden_size).bool()
294299
text_token_mask = ~visual_token_mask
@@ -310,7 +315,7 @@ def forward(
310315
hidden_states=vision_hidden_states,
311316
router_logits=vision_router_logits).flatten()
312317
else:
313-
# text modal input processing directly
318+
# only text modal input
314319
text_router_logits, _ = self.text_experts_gate(hidden_states)
315320

316321
final_hidden_states = self.text_experts(

0 commit comments

Comments (0)