@@ -1923,14 +1923,16 @@ def _batch_mm_kwargs_from_scheduler(
19231923
19241924 return mm_kwargs , mm_hashes_pos
19251925
1926- def _execute_mm_encoder (self , scheduler_output : "SchedulerOutput" ):
1926+ def _execute_mm_encoder (
1927+ self , scheduler_output : "SchedulerOutput"
1928+ ) -> list [torch .Tensor ]:
19271929 # Batch the multi-modal inputs using the helper method.
19281930 mm_kwargs , mm_hashes_pos = self ._batch_mm_kwargs_from_scheduler (
19291931 scheduler_output
19301932 )
19311933
19321934 if not mm_kwargs :
1933- return
1935+ return []
19341936
19351937 # Batch mm inputs as much as we can: if a request in the batch has
19361938 # multiple modalities or a different modality than the previous one,
@@ -2007,6 +2009,8 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
20072009 logger .debug ("Finish execute for mm hash %s" , mm_hash )
20082010 self .maybe_save_ec_to_connector (self .encoder_cache , mm_hash )
20092011
2012+ return encoder_outputs
2013+
20102014 def _gather_mm_embeddings (
20112015 self ,
20122016 scheduler_output : "SchedulerOutput" ,
@@ -2095,38 +2099,6 @@ def _gather_mm_embeddings(
20952099
20962100 return mm_embeds , is_mm_embed
20972101
2098- def _extract_encoder_inputs (
2099- self ,
2100- scheduler_output : "SchedulerOutput" ,
2101- ) -> dict [str , torch .Tensor ]:
2102- """Extract encoder inputs for encoder-decoder models.
2103-
2104- This method extracts multimodal input features from scheduled encoder
2105- inputs and formats them for the encoder-decoder model forward pass.
2106- """
2107- # Batch the multi-modal inputs using the helper method.
2108- mm_kwargs , _ = self ._batch_mm_kwargs_from_scheduler (scheduler_output )
2109-
2110- if not mm_kwargs :
2111- return {}
2112-
2113- # Group MM kwargs by modality and extract features
2114- model = cast (SupportsMultiModal , self .model )
2115- encoder_features = {}
2116- for _ , _ , mm_kwargs_group in group_mm_kwargs_by_modality (
2117- mm_kwargs ,
2118- device = self .device ,
2119- pin_memory = self .pin_memory ,
2120- merge_by_field_config = model .merge_by_field_config ,
2121- multimodal_cpu_fields = model .multimodal_cpu_fields ,
2122- ):
2123- # Add the grouped features to encoder_features dict
2124- # This allows the model to receive them as kwargs (e.g.,
2125- # input_features=...)
2126- encoder_features .update (mm_kwargs_group )
2127-
2128- return encoder_features
2129-
21302102 def get_model (self ) -> nn .Module :
21312103 # get raw model out of the cudagraph wrapper.
21322104 if isinstance (self .model , (CUDAGraphWrapper , UBatchWrapper )):
@@ -2416,8 +2388,13 @@ def _preprocess(
24162388 self .model_config .is_encoder_decoder
24172389 and scheduler_output .scheduled_encoder_inputs
24182390 ):
2419- encoder_inputs = self ._extract_encoder_inputs (scheduler_output )
2420- model_kwargs .update (encoder_inputs )
2391+ # Run the encoder, just like we do with other multimodal inputs.
2392+ # For an encoder-decoder model, our processing here is a bit
2393+ # simpler, because the outputs are just passed to the decoder.
2394+ # We are not doing any prompt replacement. We will also only
2395+ # ever have a single encoder input.
2396+ encoder_outputs = self ._execute_mm_encoder (scheduler_output )
2397+ model_kwargs .update ({"encoder_outputs" : encoder_outputs })
24212398
24222399 return (
24232400 input_ids ,
0 commit comments