@@ -146,15 +146,7 @@ def __init__(
146146 )
147147
148148 # Vision-specific initialization
149- self .is_qwen2_5_vl = (
150- hasattr (qeff_model .model .config , "model_type" ) and qeff_model .model .config .model_type == "qwen2_5_vl"
151- )
152- self .is_qwen3_vl_moe = (
153- hasattr (qeff_model .model .config , "model_type" ) and qeff_model .model .config .model_type == "qwen3_vl_moe"
154- )
155- self .is_qwen3_vl = (
156- hasattr (qeff_model .model .config , "model_type" ) and qeff_model .model .config .model_type == "qwen3_vl"
157- )
149+ self .is_qwen_vl = (hasattr (qeff_model .model .config , "model_type" )and qeff_model .model .config .model_type in {"qwen2_5_vl" , "qwen3_vl_moe" , "qwen3_vl" })
158150 self .qeff_model = qeff_model
159151 self .processor = processor
160152 self .tokenizer = tokenizer
@@ -262,37 +254,12 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):
262254 outputs , position_ids , generation_len = self .run_prefill (
263255 next_prompt , generation_len , decode_batch_id = np .array (decode_batch_id , dtype = np .int64 ).reshape (1 , 1 )
264256 )
265- if self .is_qwen2_5_vl :
266- _ = self .update_decode_inputs_qwen2_5_vl (outputs , position_ids , generation_len , decode_batch_id )
267- elif self .is_qwen3_vl_moe :
268- _ = self .update_decode_inputs_qwen3_vl_moe (outputs , position_ids , generation_len , decode_batch_id )
269- elif self .is_qwen3_vl :
270- _ = self .update_decode_inputs_qwen3_vl_moe (outputs , position_ids , generation_len , decode_batch_id )
257+ if self .is_qwen_vl :
258+ _ = self .update_decode_inputs_qwen_vl (outputs , position_ids , generation_len , decode_batch_id )
271259 else :
272260 _ = self .update_decode_input (outputs , position_ids , generation_len , decode_batch_id )
273261
274- def update_decode_inputs_qwen2_5_vl (self , outputs , position_ids , generation_len , decode_batch_id = None ):
275- """
276- Updates the decode input with the generated values.
277- Args:
278- outputs (dict): The outputs of the model.
279- position_ids (array): The position IDs.
280- generation_len (int): The generation length.
281- decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None.
282-
283- Returns:
284- next_token_id (array): The next token ID.
285- """
286- next_token_id = self ._fetch_next_token_id (outputs )
287-
288- # Store the generated values.
289- self .decode_input_ids [decode_batch_id or slice (None )] = next_token_id
290- self .decode_pos_ids [:, decode_batch_id ] = position_ids .squeeze (1 )
291- self .generated_ids [decode_batch_id or slice (None ), 0 ] = next_token_id .squeeze (1 )
292- self .generation_len [decode_batch_id or slice (None )] = generation_len
293- return next_token_id
294-
295- def update_decode_inputs_qwen3_vl_moe (self , outputs , position_ids , generation_len , decode_batch_id = None ):
262+ def update_decode_inputs_qwen_vl (self , outputs , position_ids , generation_len , decode_batch_id = None ):
296263 """
297264 Updates the decode input with the generated values.
298265 Args:
@@ -313,26 +280,6 @@ def update_decode_inputs_qwen3_vl_moe(self, outputs, position_ids, generation_le
313280 self .generation_len [decode_batch_id or slice (None )] = generation_len
314281 return next_token_id
315282
316- def update_decode_inputs_qwen3_vl (self , outputs , position_ids , generation_len , decode_batch_id = None ):
317- """
318- Updates the decode input with the generated values.
319- Args:
320- outputs (dict): The outputs of the model.
321- position_ids (array): The position IDs.
322- generation_len (int): The generation length.
323- decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None.
324-
325- Returns:
326- next_token_id (array): The next token ID.
327- """
328- next_token_id = self ._fetch_next_token_id (outputs )
329-
330- # Store the generated values.
331- self .decode_input_ids [decode_batch_id or slice (None )] = next_token_id
332- self .decode_pos_ids [:, decode_batch_id ] = position_ids .squeeze (1 )
333- self .generated_ids [decode_batch_id or slice (None ), 0 ] = next_token_id .squeeze (1 )
334- self .generation_len [decode_batch_id or slice (None )] = generation_len
335- return next_token_id
336283
337284 def _execute_chunked_prefill (
338285 self ,
@@ -632,11 +579,7 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,
632579 max_gen_length = self ._ctx_len if not generation_len else max (self ._ctx_len , generation_len )
633580
634581 self .initialize_decode_inputs (num_prompts , execution_batch_size , max_gen_length )
635- if self .is_qwen2_5_vl :
636- self .decode_pos_ids = np .zeros ((4 , execution_batch_size , 1 ), np .int64 )
637- if self .is_qwen3_vl_moe :
638- self .decode_pos_ids = np .zeros ((4 , execution_batch_size , 1 ), np .int64 )
639- if self .is_qwen3_vl :
582+ if self .is_qwen_vl :
640583 self .decode_pos_ids = np .zeros ((4 , execution_batch_size , 1 ), np .int64 )
641584 # Create prompt queue
642585 prompt_queue = deque (vision_prompts )
@@ -744,16 +687,8 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation
744687 generation_len_final = self ._fetch_generation_len (generation_len , max_gen_len )
745688
746689 # Update decode inputs
747- if self .is_qwen2_5_vl :
748- self .update_decode_inputs_qwen2_5_vl (
749- outputs , position_ids_decode , generation_len_final , decode_batch_id
750- )
751- elif self .is_qwen3_vl_moe :
752- self .update_decode_inputs_qwen3_vl_moe (
753- outputs , position_ids_decode , generation_len_final , decode_batch_id
754- )
755- elif self .is_qwen3_vl :
756- self .update_decode_inputs_qwen3_vl (
690+ if self .is_qwen_vl :
691+ self .update_decode_inputs_qwen_vl (
757692 outputs , position_ids_decode , generation_len_final , decode_batch_id
758693 )
759694 else :
0 commit comments