Skip to content

Commit dc95ccf

Browse files
committed
Code Cleaning Done 1
Signed-off-by: Dipankar Sarkar <dipankar@qti.qualcomm.com>
1 parent e205446 commit dc95ccf

File tree

5 files changed

+121
-58
lines changed

5 files changed

+121
-58
lines changed

QEfficient/transformers/cache_utils.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ def write_only(self, key_states, value_states, cache_kwargs):
157157
self.keys = key_states
158158
self.values = value_states
159159
else:
160-
# breakpoint()
161160
position_ids = cache_kwargs.get("position_ids")
162161
batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value from the kwargs
163162

@@ -192,9 +191,8 @@ def update(
192191
Return:
193192
A tuple containing the updated key and value states.
194193
"""
195-
# breakpoint()
194+
196195
# Update the cache
197-
# if not self.is_initialized:
198196

199197
if self.keys is None:
200198
self.keys = key_states
@@ -327,11 +325,10 @@ def __init__(
327325
**kwargs,
328326
):
329327
# Remove layer_classes if present to avoid duplicate argument
330-
# breakpoint()
328+
331329
kwargs.pop("layers", None)
332330
from transformers.cache_utils import Cache # Import here to avoid circular import
333331

334-
# breakpoint()
335332
layers = []
336333
# If a config is passed, use it to infer the layer types and initialize accordingly
337334
if len(layers) == 0:
@@ -373,7 +370,7 @@ def read_only(self, layer_idx, cache_kwargs):
373370
Return:
374371
A tuple containing the updated key and value states.
375372
"""
376-
# breakpoint()
373+
377374
return self.layers[layer_idx].read_only(cache_kwargs)
378375

379376
def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs):
@@ -439,18 +436,6 @@ def update3D(
439436
self.append_new_layers(layer_idx)
440437
return self.layers[layer_idx].update3D(key_states, value_states, cache_kwargs)
441438

442-
# def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
443-
# """Returns the sequence length of the cached states. A layer index can be optionally passed."""
444-
# # TODO: deprecate this function in favor of `cache_position`
445-
# breakpoint()
446-
# is_empty_layer = (
447-
# len(self.key_cache) == 0 # no cache in any layer
448-
# or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
449-
# or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
450-
# )
451-
# layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
452-
# return layer_seq_length
453-
454439

455440
class QEffEncoderDecoderCache(EncoderDecoderCache):
456441
"""

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 104 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@
113113
MistralModel,
114114
MistralRMSNorm,
115115
)
116+
from transformers.models.mistral3.modeling_mistral3 import (
117+
Mistral3ForConditionalGeneration,
118+
Mistral3Model,
119+
Mistral3RMSNorm,
120+
)
116121
from transformers.models.mixtral.modeling_mixtral import (
117122
MixtralAttention,
118123
MixtralDecoderLayer,
@@ -135,6 +140,13 @@
135140
MllamaVisionModel,
136141
)
137142
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
143+
from transformers.models.olmo2.modeling_olmo2 import (
144+
Olmo2Attention,
145+
Olmo2DecoderLayer,
146+
Olmo2ForCausalLM,
147+
Olmo2Model,
148+
Olmo2RMSNorm,
149+
)
138150
from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel
139151
from transformers.models.phi3.modeling_phi3 import (
140152
Phi3Attention,
@@ -143,6 +155,7 @@
143155
Phi3Model,
144156
Phi3RMSNorm,
145157
)
158+
from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm, PixtralVisionModel
146159
from transformers.models.qwen2.modeling_qwen2 import (
147160
Qwen2Attention,
148161
Qwen2DecoderLayer,
@@ -155,6 +168,7 @@
155168
Qwen2_5_VLAttention,
156169
Qwen2_5_VLDecoderLayer,
157170
Qwen2_5_VLForConditionalGeneration,
171+
Qwen2_5_VLModel,
158172
Qwen2_5_VLTextModel,
159173
Qwen2_5_VLVisionAttention,
160174
)
@@ -168,6 +182,15 @@
168182
Qwen3Model,
169183
Qwen3RMSNorm,
170184
)
185+
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
186+
Qwen3MoeAttention,
187+
Qwen3MoeDecoderLayer,
188+
Qwen3MoeForCausalLM,
189+
Qwen3MoeModel,
190+
Qwen3MoeRMSNorm,
191+
Qwen3MoeRotaryEmbedding,
192+
Qwen3MoeSparseMoeBlock,
193+
)
171194
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
172195
Qwen3VLForConditionalGeneration,
173196
Qwen3VLModel,
@@ -327,6 +350,11 @@
327350
QEffMistralForCausalLM,
328351
QEffMistralModel,
329352
)
353+
from QEfficient.transformers.models.mistral3.modeling_mistral3 import (
354+
QEffMistral3ForConditionalGeneration,
355+
QEffMistral3Model,
356+
QEffPixtralVisionModel,
357+
)
330358
from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import (
331359
QEffMixtralAttention,
332360
QeffMixtralDecoderLayer,
@@ -347,12 +375,25 @@
347375
QEffMllamaTextSelfAttention,
348376
QEffMllamaVisionModel,
349377
)
378+
from QEfficient.transformers.models.molmo.modeling_molmo import (
379+
QEffMolmo,
380+
QEffMolmoBlock,
381+
QEffMolmoModel,
382+
QEffMolmoSequentialBlock,
383+
QEffMultiHeadDotProductAttention,
384+
)
350385
from QEfficient.transformers.models.mpt.modeling_mpt import (
351386
QEffMptAttention,
352387
QEffMptBlock,
353388
QEffMptForCausalLM,
354389
QEFfMptModel,
355390
)
391+
from QEfficient.transformers.models.olmo2.modeling_olmo2 import (
392+
QEffOlmo2Attention,
393+
QEffOlmo2DecoderLayer,
394+
QEffOlmo2ForCausalLM,
395+
QEffOlmo2Model,
396+
)
356397
from QEfficient.transformers.models.phi.modeling_phi import (
357398
QEffPhiAttention,
358399
QEffPhiDecoderLayer,
@@ -375,9 +416,10 @@
375416
QEffQwen2_5_VisionTransformerPretrainedModel,
376417
QEffQwen2_5_VLAttention,
377418
QEffQwen2_5_VLDecoderLayer,
419+
QEffQwen2_5_VLModel,
378420
QEffQwen2_5_VLTextModel,
379-
# QEffQwen2_5_VLModel,
380421
QEffQwen2_5_VLVisionAttention,
422+
QEffQwen_2_5_vl_DecoderWrapper,
381423
QEffQwen_2_5_vl_ForConditionalGeneration,
382424
)
383425
from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
@@ -386,13 +428,20 @@
386428
QEffQwen3ForCausalLM,
387429
QEffQwen3Model,
388430
)
431+
from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import (
432+
QEffQwen3MoeAttention,
433+
QEffQwen3MoeDecoderLayer,
434+
QEffQwen3MoeForCausalLM,
435+
QEffQwen3MoeModel,
436+
QEffQwen3MoeRotaryEmbedding,
437+
QEffQwen3MoeSparseMoeBlock,
438+
)
389439
from QEfficient.transformers.models.qwen3_vl.modeling_qwen3_vl import (
390440
QEffQwen3VLForConditionalGeneration,
391441
QEffQwen3VLModel,
392442
QEffQwen3VLTextAttention,
393443
QEffQwen3VLTextDecoderLayer,
394444
QEffQwen3VLTextModel,
395-
# QEffQwen3VLTextRotaryEmbedding,
396445
QEffQwen3VLVisionAttention,
397446
QEffQwen3VLVisionModel,
398447
)
@@ -430,16 +479,20 @@ class CustomOpsTransform(ModuleMappingTransform):
430479
LlamaRMSNorm: CustomRMSNormAIC,
431480
Llama4TextRMSNorm: CustomRMSNormAIC,
432481
MistralRMSNorm: CustomRMSNormAIC,
482+
Mistral3RMSNorm: CustomRMSNormAIC,
433483
MixtralRMSNorm: CustomRMSNormAIC,
434484
Phi3RMSNorm: CustomRMSNormAIC,
435485
Qwen2RMSNorm: CustomRMSNormAIC,
436486
Qwen3RMSNorm: CustomRMSNormAIC,
437487
Qwen2_5RMSNorm: CustomRMSNormAIC,
438488
MllamaTextRMSNorm: CustomRMSNormAIC,
439489
GraniteRMSNorm: CustomRMSNormAIC,
490+
PixtralRMSNorm: CustomRMSNormAIC,
440491
GraniteMoeRMSNorm: CustomRMSNormAIC,
441-
Qwen3VLTextRMSNorm: CustomRMSNormAIC,
492+
Qwen3MoeRMSNorm: CustomRMSNormAIC,
442493
Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC,
494+
Qwen3VLTextRMSNorm: CustomRMSNormAIC,
495+
Olmo2RMSNorm: CustomRMSNormAIC,
443496
}
444497

445498

@@ -492,12 +545,12 @@ class KVCacheTransform(ModuleMappingTransform):
492545
GemmaModel: QEffGemmaModel,
493546
GemmaForCausalLM: QEffGemmaForCausalLM,
494547
# Qwen3Moe
495-
# Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM,
496-
# Qwen3MoeModel: QEffQwen3MoeModel,
497-
# Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer,
498-
# Qwen3MoeAttention: QEffQwen3MoeAttention,
499-
# Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding,
500-
# Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
548+
Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM,
549+
Qwen3MoeModel: QEffQwen3MoeModel,
550+
Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer,
551+
Qwen3MoeAttention: QEffQwen3MoeAttention,
552+
Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding,
553+
Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
501554
# Gemma2
502555
Gemma2Attention: QEffGemma2Attention,
503556
Gemma2DecoderLayer: QEffGemma2DecoderLayer,
@@ -545,6 +598,9 @@ class KVCacheTransform(ModuleMappingTransform):
545598
MistralDecoderLayer: QEffMistralDecoderLayer,
546599
MistralModel: QEffMistralModel,
547600
MistralForCausalLM: QEffMistralForCausalLM,
601+
# Mistral3
602+
Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration,
603+
Mistral3Model: QEffMistral3Model,
548604
# Mixtral
549605
MixtralAttention: QEffMixtralAttention,
550606
MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock,
@@ -566,36 +622,34 @@ class KVCacheTransform(ModuleMappingTransform):
566622
PhiDecoderLayer: QEffPhiDecoderLayer,
567623
PhiModel: QEffPhiModel,
568624
PhiForCausalLM: QEffPhiForCausalLM,
625+
# Pixtral
626+
PixtralVisionModel: QEffPixtralVisionModel,
569627
# Qwen2
570628
Qwen2Attention: QEffQwen2Attention,
571629
Qwen2DecoderLayer: QEffQwen2DecoderLayer,
572630
Qwen2Model: QEffQwen2Model,
573631
Qwen2ForCausalLM: QEffQwen2ForCausalLM,
574-
# Qwen2.5 VL
575-
Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration,
576-
# Qwen2_5_VLModel: QEffQwen2_5_VLModel,
577-
Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel,
578632
# Qwen3
579633
Qwen3Attention: QEffQwen3Attention,
580634
Qwen3DecoderLayer: QEffQwen3DecoderLayer,
581635
Qwen3Model: QEffQwen3Model,
582636
Qwen3ForCausalLM: QEffQwen3ForCausalLM,
583637
# Qwen2.5 VL
584-
# Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration,
585-
# Qwen2_5_VLModel: QEffQwen2_5_VLModel,
638+
Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration,
639+
Qwen2_5_VLModel: QEffQwen2_5_VLModel,
586640
Qwen2_5_VLAttention: QEffQwen2_5_VLAttention,
587641
Qwen2_5_VLDecoderLayer: QEffQwen2_5_VLDecoderLayer,
588642
Qwen2_5_VisionTransformerPretrainedModel: QEffQwen2_5_VisionTransformerPretrainedModel,
589643
Qwen2_5_VLVisionAttention: QEffQwen2_5_VLVisionAttention,
590-
# Qwen3vl
644+
Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel,
645+
# Qwen3 VL
591646
Qwen3VLForConditionalGeneration: QEffQwen3VLForConditionalGeneration,
592647
Qwen3VLModel: QEffQwen3VLModel,
593648
Qwen3VLTextAttention: QEffQwen3VLTextAttention,
594649
Qwen3VLTextDecoderLayer: QEffQwen3VLTextDecoderLayer,
595650
Qwen3VLVisionAttention: QEffQwen3VLVisionAttention,
596651
Qwen3VLVisionModel: QEffQwen3VLVisionModel,
597652
Qwen3VLTextModel: QEffQwen3VLTextModel,
598-
# Qwen3VLTextRotaryEmbedding: QEffQwen3VLTextRotaryEmbedding, # reusing decoder layer for rotary embedding as they are tightly coupled in forward pass
599653
# Starcoder2
600654
Starcoder2Attention: QEffStarcoder2Attention,
601655
Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer,
@@ -606,6 +660,11 @@ class KVCacheTransform(ModuleMappingTransform):
606660
GPTBigCodeBlock: QEffGPTBigCodeBlock,
607661
GPTBigCodeModel: QEffGPTBigCodeModel,
608662
GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM,
663+
# Olmo2
664+
Olmo2Attention: QEffOlmo2Attention,
665+
Olmo2DecoderLayer: QEffOlmo2DecoderLayer,
666+
Olmo2Model: QEffOlmo2Model,
667+
Olmo2ForCausalLM: QEffOlmo2ForCausalLM,
609668
# Whisper encoder and decoder layers
610669
WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding,
611670
WhisperAttention: QEffWhisperAttention,
@@ -675,7 +734,7 @@ class SpDTransform:
675734
# Llama
676735
QEffLlamaForCausalLM,
677736
QEffQwen2ForCausalLM,
678-
# QEffQwen3ForCausalLM,
737+
QEffQwen3ForCausalLM,
679738
}
680739

681740
@classmethod
@@ -741,7 +800,7 @@ class SamplerTransform:
741800
QEffMptForCausalLM,
742801
QEffPhi3ForCausalLM,
743802
QEffQwen2ForCausalLM,
744-
# QEffQwen_2_5_vl_DecoderWrapper,
803+
QEffQwen_2_5_vl_DecoderWrapper,
745804
}
746805

747806
@classmethod
@@ -787,6 +846,32 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
787846
"get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder,
788847
},
789848
"InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
849+
# Mapping for Molmo
850+
"MolmoForCausalLM": {
851+
"forward": QEffMolmoModel.forward,
852+
"get_qeff_vision_encoder": QEffMolmoModel.get_qeff_vision_encoder,
853+
"get_qeff_language_decoder": QEffMolmoModel.get_qeff_language_decoder,
854+
"get_specializations": QEffMolmoModel.get_specializations,
855+
"get_onnx_dynamic_axes": QEffMolmoModel.get_onnx_dynamic_axes,
856+
"get_output_names": QEffMolmoModel.get_output_names,
857+
"get_dummy_inputs": QEffMolmoModel.get_dummy_inputs,
858+
"get_inputs_info": QEffMolmoModel.get_inputs_info,
859+
},
860+
"RMSLayerNorm": {"forward": CustomRMSNormAIC.forward},
861+
# "MolmoForCausalLM": {"forward": QEffMolmoForCausalLM.forward},
862+
"Molmo": {"forward": QEffMolmo.forward},
863+
"MolmoSequentialBlock": {
864+
"forward": QEffMolmoSequentialBlock.forward,
865+
"attention": QEffMolmoBlock.attention,
866+
"__qeff_init__": QEffMolmoBlock.__qeff_init__,
867+
},
868+
"MolmoBlock": {
869+
"attention": QEffMolmoBlock.attention,
870+
"__qeff_init__": QEffMolmoBlock.__qeff_init__,
871+
},
872+
"MultiHeadDotProductAttention": {
873+
"forward": QEffMultiHeadDotProductAttention.forward,
874+
},
790875
# Mapping for grok1 model
791876
"Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward},
792877
"Grok1Model": {

QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,6 @@ def eager_attention_forward(
309309

310310

311311
class QEffQwen3VLTextAttention(Qwen3VLTextAttention):
312-
313312
def forward(
314313
self,
315314
hidden_states: torch.Tensor,
@@ -1024,7 +1023,6 @@ def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size=
10241023
inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1
10251024
)
10261025

1027-
10281026
return inputs
10291027

10301028
def get_inputs_info(self):

0 commit comments

Comments
 (0)