Skip to content

Commit bb24592

Browse files
wangxiongtsywang96
andauthored
[Qwen3-Omni] fixed _get_feat_extract_output_lengths function (#31007)
Signed-off-by: Xiong Wang <[email protected]> Signed-off-by: Roger Wang <[email protected]> Co-authored-by: Roger Wang <[email protected]>
1 parent 369f47a commit bb24592

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

vllm/model_executor/models/qwen3_omni_moe_thinker.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
118118
output_lengths = (
119119
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
120120
)
121-
return feat_lengths, output_lengths
121+
return output_lengths
122122

123123

124124
class Qwen3_VisionPatchEmbed(nn.Module):
@@ -921,13 +921,11 @@ def _get_prompt_updates(
921921
if audio_feature_lengths is None and feature_attention_mask is None:
922922
audio_output_lengths = []
923923
elif audio_feature_lengths is not None:
924-
_, audio_output_lens = _get_feat_extract_output_lengths(
925-
audio_feature_lengths
926-
)
924+
audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
927925
audio_output_lengths = audio_output_lens.tolist()
928926
elif feature_attention_mask is not None:
929927
assert isinstance(feature_attention_mask, torch.Tensor)
930-
_, audio_output_lens = _get_feat_extract_output_lengths(
928+
audio_output_lens = _get_feat_extract_output_lengths(
931929
feature_attention_mask.sum(-1)
932930
)
933931
audio_output_lengths = audio_output_lens.tolist()
@@ -1111,18 +1109,16 @@ def _process_audio_input(
11111109
audio_input: Qwen2_5OmniAudioFeatureInputs,
11121110
audio_hashes: list[str] | None = None,
11131111
cached_audio_features: torch.Tensor | None = None,
1114-
) -> torch.Tensor:
1112+
) -> tuple[torch.Tensor, ...]:
11151113
input_features = audio_input["input_features"]
11161114
audio_feature_lengths = audio_input["audio_feature_lengths"]
11171115

1118-
audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
1119-
audio_feature_lengths
1120-
)
1116+
audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)
11211117

11221118
audio_outputs = self.audio_tower(
11231119
input_features.to(self.audio_tower.dtype),
11241120
feature_lens=audio_feature_lengths,
1125-
aftercnn_lens=audio_feat_lengths,
1121+
aftercnn_lens=audio_output_lengths,
11261122
)
11271123
audio_features = audio_outputs.last_hidden_state
11281124
return audio_features.split(audio_output_lengths.tolist())
@@ -1579,7 +1575,7 @@ def get_mrope_input_positions(
15791575
+ st_idx
15801576
)
15811577
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
1582-
_, audio_len = _get_feat_extract_output_lengths(
1578+
audio_len = _get_feat_extract_output_lengths(
15831579
audio_feature_lengths[audio_idx]
15841580
)
15851581
llm_pos_ids = (
@@ -1700,7 +1696,7 @@ def get_mrope_input_positions(
17001696
llm_pos_ids_list.append(bos_block)
17011697
llm_pos_ids_list.append(bos_block)
17021698
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
1703-
_, audio_len = _get_feat_extract_output_lengths(
1699+
audio_len = _get_feat_extract_output_lengths(
17041700
audio_feature_lengths[audio_idx]
17051701
)
17061702
audio_llm_pos_ids = (

0 commit comments

Comments
 (0)