Skip to content

Commit 450c8d6

Browse files
committed
Addressing Review Comments 2
Signed-off-by: Dipankar Sarkar <dipankar@qti.qualcomm.com>
1 parent 9328991 commit 450c8d6

File tree

3 files changed

+12
-112
lines changed

3 files changed

+12
-112
lines changed

QEfficient/generation/embedding_handler.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -252,29 +252,10 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -
252252

253253
# Process image and text
254254
inputs = self._processor(images=image, text=prompt, return_tensors="pt")
255-
if (
256-
hasattr(self._qeff_model.model.config, "model_type")
257-
and self._qeff_model.model.config.model_type == "qwen2_5_vl"
258-
):
259-
inputs = self._qeff_model.model.prepare_inputs_for_generation(
260-
inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
261-
)
262-
263-
if (
264-
hasattr(self._qeff_model.model.config, "model_type")
265-
and self._qeff_model.model.config.model_type == "qwen3_vl_moe"
266-
):
255+
if (hasattr(self._qeff_model.model.config, "model_type")and self._qeff_model.model.config.model_type in {"qwen2_5_vl", "qwen3_vl_moe", "qwen3_vl"}):
267256
inputs = self._qeff_model.model.prepare_inputs_for_generation(
268-
inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
269-
)
270-
271-
if (
272-
hasattr(self._qeff_model.model.config, "model_type")
273-
and self._qeff_model.model.config.model_type == "qwen3_vl"
274-
):
275-
inputs = self._qeff_model.model.prepare_inputs_for_generation(
276-
inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
277-
)
257+
inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
258+
)
278259

279260
# Convert to float32 if needed
280261
if "pixel_values" in inputs:
@@ -426,7 +407,7 @@ def setup_vision_buffers(self):
426407
buffers = {}
427408
for output_name, shape in shapes.items():
428409
# Create placeholder with appropriate dtype
429-
if "vision_embeds" or "deepstack_features" in output_name:
410+
if "vision_embeds" in output_name or "deepstack_features" in output_name:
430411
buffers[output_name] = np.zeros(shape, dtype=np.float16)
431412
else:
432413
buffers[output_name] = np.zeros(shape, dtype=np.float32)

QEfficient/generation/vlm_generation.py

Lines changed: 7 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,7 @@ def __init__(
146146
)
147147

148148
# Vision-specific initialization
149-
self.is_qwen2_5_vl = (
150-
hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl"
151-
)
152-
self.is_qwen3_vl_moe = (
153-
hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl_moe"
154-
)
155-
self.is_qwen3_vl = (
156-
hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl"
157-
)
149+
self.is_qwen_vl = (hasattr(qeff_model.model.config, "model_type")and qeff_model.model.config.model_type in {"qwen2_5_vl", "qwen3_vl_moe", "qwen3_vl"})
158150
self.qeff_model = qeff_model
159151
self.processor = processor
160152
self.tokenizer = tokenizer
@@ -262,37 +254,12 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):
262254
outputs, position_ids, generation_len = self.run_prefill(
263255
next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1)
264256
)
265-
if self.is_qwen2_5_vl:
266-
_ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id)
267-
elif self.is_qwen3_vl_moe:
268-
_ = self.update_decode_inputs_qwen3_vl_moe(outputs, position_ids, generation_len, decode_batch_id)
269-
elif self.is_qwen3_vl:
270-
_ = self.update_decode_inputs_qwen3_vl_moe(outputs, position_ids, generation_len, decode_batch_id)
257+
if self.is_qwen_vl:
258+
_ = self.update_decode_inputs_qwen_vl(outputs, position_ids, generation_len, decode_batch_id)
271259
else:
272260
_ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)
273261

274-
def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len, decode_batch_id=None):
275-
"""
276-
Updates the decode input with the generated values.
277-
Args:
278-
outputs (dict): The outputs of the model.
279-
position_ids (array): The position IDs.
280-
generation_len (int): The generation length.
281-
decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None.
282-
283-
Returns:
284-
next_token_id (array): The next token ID.
285-
"""
286-
next_token_id = self._fetch_next_token_id(outputs)
287-
288-
# Store the generated values.
289-
self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id
290-
self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1)
291-
self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1)
292-
self.generation_len[decode_batch_id or slice(None)] = generation_len
293-
return next_token_id
294-
295-
def update_decode_inputs_qwen3_vl_moe(self, outputs, position_ids, generation_len, decode_batch_id=None):
262+
def update_decode_inputs_qwen_vl(self, outputs, position_ids, generation_len, decode_batch_id=None):
296263
"""
297264
Updates the decode input with the generated values.
298265
Args:
@@ -313,26 +280,6 @@ def update_decode_inputs_qwen3_vl_moe(self, outputs, position_ids, generation_le
313280
self.generation_len[decode_batch_id or slice(None)] = generation_len
314281
return next_token_id
315282

316-
def update_decode_inputs_qwen3_vl(self, outputs, position_ids, generation_len, decode_batch_id=None):
317-
"""
318-
Updates the decode input with the generated values.
319-
Args:
320-
outputs (dict): The outputs of the model.
321-
position_ids (array): The position IDs.
322-
generation_len (int): The generation length.
323-
decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None.
324-
325-
Returns:
326-
next_token_id (array): The next token ID.
327-
"""
328-
next_token_id = self._fetch_next_token_id(outputs)
329-
330-
# Store the generated values.
331-
self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id
332-
self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1)
333-
self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1)
334-
self.generation_len[decode_batch_id or slice(None)] = generation_len
335-
return next_token_id
336283

337284
def _execute_chunked_prefill(
338285
self,
@@ -632,11 +579,7 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,
632579
max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len)
633580

634581
self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length)
635-
if self.is_qwen2_5_vl:
636-
self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64)
637-
if self.is_qwen3_vl_moe:
638-
self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64)
639-
if self.is_qwen3_vl:
582+
if self.is_qwen_vl:
640583
self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64)
641584
# Create prompt queue
642585
prompt_queue = deque(vision_prompts)
@@ -744,16 +687,8 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation
744687
generation_len_final = self._fetch_generation_len(generation_len, max_gen_len)
745688

746689
# Update decode inputs
747-
if self.is_qwen2_5_vl:
748-
self.update_decode_inputs_qwen2_5_vl(
749-
outputs, position_ids_decode, generation_len_final, decode_batch_id
750-
)
751-
elif self.is_qwen3_vl_moe:
752-
self.update_decode_inputs_qwen3_vl_moe(
753-
outputs, position_ids_decode, generation_len_final, decode_batch_id
754-
)
755-
elif self.is_qwen3_vl:
756-
self.update_decode_inputs_qwen3_vl(
690+
if self.is_qwen_vl:
691+
self.update_decode_inputs_qwen_vl(
757692
outputs, position_ids_decode, generation_len_final, decode_batch_id
758693
)
759694
else:

QEfficient/transformers/cache_utils.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ def update(
192192
A tuple containing the updated key and value states.
193193
"""
194194
# Update the cache
195-
# if not self.is_initialized:
196195

197196
if self.keys is None:
198197
self.keys = key_states
@@ -336,17 +335,13 @@ def __init__(
336335
layer_class_to_replicate=QEffDynamicLayer,
337336
offloading=offloading,
338337
offload_only_non_sliding=offload_only_non_sliding,
339-
# args=args,
340-
# kwargs=kwargs,
341338
)
342339
else:
343340
Cache.__init__(
344341
self,
345342
layers=layers,
346343
offloading=offloading,
347344
offload_only_non_sliding=offload_only_non_sliding,
348-
# args=args,
349-
# kwargs=kwargs,
350345
)
351346

352347
if ddp_cache_data is not None:
@@ -434,18 +429,7 @@ def update3D(
434429
self.append_new_layers(layer_idx)
435430
return self.layers[layer_idx].update3D(key_states, value_states, cache_kwargs)
436431

437-
# def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
438-
# """Returns the sequence length of the cached states. A layer index can be optionally passed."""
439-
# # TODO: deprecate this function in favor of `cache_position`
440-
# breakpoint()
441-
# is_empty_layer = (
442-
# len(self.key_cache) == 0 # no cache in any layer
443-
# or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
444-
# or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
445-
# )
446-
# layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
447-
# return layer_seq_length
448-
432+
449433

450434
class QEffEncoderDecoderCache(EncoderDecoderCache):
451435
"""

0 commit comments

Comments
 (0)