Skip to content

Commit cca2d2c

Browse files
authored
[Core] Align whisper closer to other multimodal models (#27292)
Signed-off-by: Russell Bryant <[email protected]>
1 parent aab0102 commit cca2d2c

File tree

2 files changed

+21
-41
lines changed

2 files changed

+21
-41
lines changed

vllm/model_executor/models/whisper.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -599,15 +599,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
599599

600600
def forward(
601601
self,
602-
input_features: torch.Tensor | list[torch.Tensor] | None,
603602
input_ids: torch.Tensor | None,
604603
positions: torch.Tensor,
604+
encoder_outputs: list[torch.Tensor],
605605
) -> torch.Tensor:
606-
encoder_outputs = self.get_encoder_outputs(input_features)
606+
assert len(encoder_outputs) in (0, 1)
607+
enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None
607608
decoder_outputs = self.decoder(
608609
input_ids=input_ids,
609610
positions=positions,
610-
encoder_hidden_states=encoder_outputs,
611+
encoder_hidden_states=enc_states,
611612
)
612613
return decoder_outputs
613614

@@ -894,13 +895,15 @@ def forward(
894895
self,
895896
input_ids: torch.Tensor,
896897
positions: torch.Tensor,
898+
encoder_outputs: list[torch.Tensor] | None = None,
897899
**kwargs,
898900
) -> torch.Tensor:
899-
audio_input = self._parse_and_validate_audio_input(**kwargs)
901+
if encoder_outputs is None:
902+
encoder_outputs = []
900903
decoder_outputs = self.model(
901-
input_features=audio_input["input_features"],
902904
input_ids=input_ids,
903905
positions=positions,
906+
encoder_outputs=encoder_outputs,
904907
)
905908
return decoder_outputs
906909

vllm/v1/worker/gpu_model_runner.py

Lines changed: 13 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1923,14 +1923,16 @@ def _batch_mm_kwargs_from_scheduler(
19231923

19241924
return mm_kwargs, mm_hashes_pos
19251925

1926-
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
1926+
def _execute_mm_encoder(
1927+
self, scheduler_output: "SchedulerOutput"
1928+
) -> list[torch.Tensor]:
19271929
# Batch the multi-modal inputs using the helper method.
19281930
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
19291931
scheduler_output
19301932
)
19311933

19321934
if not mm_kwargs:
1933-
return
1935+
return []
19341936

19351937
# Batch mm inputs as much as we can: if a request in the batch has
19361938
# multiple modalities or a different modality than the previous one,
@@ -2007,6 +2009,8 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
20072009
logger.debug("Finish execute for mm hash %s", mm_hash)
20082010
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
20092011

2012+
return encoder_outputs
2013+
20102014
def _gather_mm_embeddings(
20112015
self,
20122016
scheduler_output: "SchedulerOutput",
@@ -2095,38 +2099,6 @@ def _gather_mm_embeddings(
20952099

20962100
return mm_embeds, is_mm_embed
20972101

2098-
def _extract_encoder_inputs(
2099-
self,
2100-
scheduler_output: "SchedulerOutput",
2101-
) -> dict[str, torch.Tensor]:
2102-
"""Extract encoder inputs for encoder-decoder models.
2103-
2104-
This method extracts multimodal input features from scheduled encoder
2105-
inputs and formats them for the encoder-decoder model forward pass.
2106-
"""
2107-
# Batch the multi-modal inputs using the helper method.
2108-
mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
2109-
2110-
if not mm_kwargs:
2111-
return {}
2112-
2113-
# Group MM kwargs by modality and extract features
2114-
model = cast(SupportsMultiModal, self.model)
2115-
encoder_features = {}
2116-
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
2117-
mm_kwargs,
2118-
device=self.device,
2119-
pin_memory=self.pin_memory,
2120-
merge_by_field_config=model.merge_by_field_config,
2121-
multimodal_cpu_fields=model.multimodal_cpu_fields,
2122-
):
2123-
# Add the grouped features to encoder_features dict
2124-
# This allows the model to receive them as kwargs (e.g.,
2125-
# input_features=...)
2126-
encoder_features.update(mm_kwargs_group)
2127-
2128-
return encoder_features
2129-
21302102
def get_model(self) -> nn.Module:
21312103
# get raw model out of the cudagraph wrapper.
21322104
if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)):
@@ -2416,8 +2388,13 @@ def _preprocess(
24162388
self.model_config.is_encoder_decoder
24172389
and scheduler_output.scheduled_encoder_inputs
24182390
):
2419-
encoder_inputs = self._extract_encoder_inputs(scheduler_output)
2420-
model_kwargs.update(encoder_inputs)
2391+
# Run the encoder, just like we do with other multimodal inputs.
2392+
# For an encoder-decoder model, our processing here is a bit
2393+
# simpler, because the outputs are just passed to the decoder.
2394+
# We are not doing any prompt replacement. We also will only
2395+
# ever have a single encoder input.
2396+
encoder_outputs = self._execute_mm_encoder(scheduler_output)
2397+
model_kwargs.update({"encoder_outputs": encoder_outputs})
24212398

24222399
return (
24232400
input_ids,

0 commit comments

Comments
 (0)