Skip to content

Commit b94ae73

Browse files
committed
Code Cleaning Done 1
Signed-off-by: Dipankar Sarkar <dipankar@qti.qualcomm.com>
1 parent 330444a commit b94ae73

File tree

5 files changed

+121
-58
lines changed

5 files changed

+121
-58
lines changed

QEfficient/transformers/cache_utils.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ def write_only(self, key_states, value_states, cache_kwargs):
157157
self.keys = key_states
158158
self.values = value_states
159159
else:
160-
# breakpoint()
161160
position_ids = cache_kwargs.get("position_ids")
162161
batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value from the kwargs
163162

@@ -192,9 +191,8 @@ def update(
192191
Return:
193192
A tuple containing the updated key and value states.
194193
"""
195-
# breakpoint()
194+
196195
# Update the cache
197-
# if not self.is_initialized:
198196

199197
if self.keys is None:
200198
self.keys = key_states
@@ -327,11 +325,10 @@ def __init__(
327325
**kwargs,
328326
):
329327
# Remove layer_classes if present to avoid duplicate argument
330-
# breakpoint()
328+
331329
kwargs.pop("layers", None)
332330
from transformers.cache_utils import Cache # Import here to avoid circular import
333331

334-
# breakpoint()
335332
layers = []
336333
# If a config is passed, use it to infer the layer types and initialize accordingly
337334
if len(layers) == 0:
@@ -373,7 +370,7 @@ def read_only(self, layer_idx, cache_kwargs):
373370
Return:
374371
A tuple containing the updated key and value states.
375372
"""
376-
# breakpoint()
373+
377374
return self.layers[layer_idx].read_only(cache_kwargs)
378375

379376
def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs):
@@ -439,18 +436,6 @@ def update3D(
439436
self.append_new_layers(layer_idx)
440437
return self.layers[layer_idx].update3D(key_states, value_states, cache_kwargs)
441438

442-
# def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
443-
# """Returns the sequence length of the cached states. A layer index can be optionally passed."""
444-
# # TODO: deprecate this function in favor of `cache_position`
445-
# breakpoint()
446-
# is_empty_layer = (
447-
# len(self.key_cache) == 0 # no cache in any layer
448-
# or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
449-
# or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
450-
# )
451-
# layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
452-
# return layer_seq_length
453-
454439

455440
class QEffEncoderDecoderCache(EncoderDecoderCache):
456441
"""

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 104 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@
116116
MistralModel,
117117
MistralRMSNorm,
118118
)
119+
from transformers.models.mistral3.modeling_mistral3 import (
120+
Mistral3ForConditionalGeneration,
121+
Mistral3Model,
122+
Mistral3RMSNorm,
123+
)
119124
from transformers.models.mixtral.modeling_mixtral import (
120125
MixtralAttention,
121126
MixtralDecoderLayer,
@@ -138,6 +143,13 @@
138143
MllamaVisionModel,
139144
)
140145
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
146+
from transformers.models.olmo2.modeling_olmo2 import (
147+
Olmo2Attention,
148+
Olmo2DecoderLayer,
149+
Olmo2ForCausalLM,
150+
Olmo2Model,
151+
Olmo2RMSNorm,
152+
)
141153
from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel
142154
from transformers.models.phi3.modeling_phi3 import (
143155
Phi3Attention,
@@ -146,6 +158,7 @@
146158
Phi3Model,
147159
Phi3RMSNorm,
148160
)
161+
from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm, PixtralVisionModel
149162
from transformers.models.qwen2.modeling_qwen2 import (
150163
Qwen2Attention,
151164
Qwen2DecoderLayer,
@@ -158,6 +171,7 @@
158171
Qwen2_5_VLAttention,
159172
Qwen2_5_VLDecoderLayer,
160173
Qwen2_5_VLForConditionalGeneration,
174+
Qwen2_5_VLModel,
161175
Qwen2_5_VLTextModel,
162176
Qwen2_5_VLVisionAttention,
163177
)
@@ -171,6 +185,15 @@
171185
Qwen3Model,
172186
Qwen3RMSNorm,
173187
)
188+
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
189+
Qwen3MoeAttention,
190+
Qwen3MoeDecoderLayer,
191+
Qwen3MoeForCausalLM,
192+
Qwen3MoeModel,
193+
Qwen3MoeRMSNorm,
194+
Qwen3MoeRotaryEmbedding,
195+
Qwen3MoeSparseMoeBlock,
196+
)
174197
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
175198
Qwen3VLForConditionalGeneration,
176199
Qwen3VLModel,
@@ -333,6 +356,11 @@
333356
QEffMistralForCausalLM,
334357
QEffMistralModel,
335358
)
359+
from QEfficient.transformers.models.mistral3.modeling_mistral3 import (
360+
QEffMistral3ForConditionalGeneration,
361+
QEffMistral3Model,
362+
QEffPixtralVisionModel,
363+
)
336364
from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import (
337365
QEffMixtralAttention,
338366
QeffMixtralDecoderLayer,
@@ -353,12 +381,25 @@
353381
QEffMllamaTextSelfAttention,
354382
QEffMllamaVisionModel,
355383
)
384+
from QEfficient.transformers.models.molmo.modeling_molmo import (
385+
QEffMolmo,
386+
QEffMolmoBlock,
387+
QEffMolmoModel,
388+
QEffMolmoSequentialBlock,
389+
QEffMultiHeadDotProductAttention,
390+
)
356391
from QEfficient.transformers.models.mpt.modeling_mpt import (
357392
QEffMptAttention,
358393
QEffMptBlock,
359394
QEffMptForCausalLM,
360395
QEFfMptModel,
361396
)
397+
from QEfficient.transformers.models.olmo2.modeling_olmo2 import (
398+
QEffOlmo2Attention,
399+
QEffOlmo2DecoderLayer,
400+
QEffOlmo2ForCausalLM,
401+
QEffOlmo2Model,
402+
)
362403
from QEfficient.transformers.models.phi.modeling_phi import (
363404
QEffPhiAttention,
364405
QEffPhiDecoderLayer,
@@ -381,9 +422,10 @@
381422
QEffQwen2_5_VisionTransformerPretrainedModel,
382423
QEffQwen2_5_VLAttention,
383424
QEffQwen2_5_VLDecoderLayer,
425+
QEffQwen2_5_VLModel,
384426
QEffQwen2_5_VLTextModel,
385-
# QEffQwen2_5_VLModel,
386427
QEffQwen2_5_VLVisionAttention,
428+
QEffQwen_2_5_vl_DecoderWrapper,
387429
QEffQwen_2_5_vl_ForConditionalGeneration,
388430
)
389431
from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
@@ -392,13 +434,20 @@
392434
QEffQwen3ForCausalLM,
393435
QEffQwen3Model,
394436
)
437+
from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import (
438+
QEffQwen3MoeAttention,
439+
QEffQwen3MoeDecoderLayer,
440+
QEffQwen3MoeForCausalLM,
441+
QEffQwen3MoeModel,
442+
QEffQwen3MoeRotaryEmbedding,
443+
QEffQwen3MoeSparseMoeBlock,
444+
)
395445
from QEfficient.transformers.models.qwen3_vl.modeling_qwen3_vl import (
396446
QEffQwen3VLForConditionalGeneration,
397447
QEffQwen3VLModel,
398448
QEffQwen3VLTextAttention,
399449
QEffQwen3VLTextDecoderLayer,
400450
QEffQwen3VLTextModel,
401-
# QEffQwen3VLTextRotaryEmbedding,
402451
QEffQwen3VLVisionAttention,
403452
QEffQwen3VLVisionModel,
404453
)
@@ -436,16 +485,20 @@ class CustomOpsTransform(ModuleMappingTransform):
436485
LlamaRMSNorm: CustomRMSNormAIC,
437486
Llama4TextRMSNorm: CustomRMSNormAIC,
438487
MistralRMSNorm: CustomRMSNormAIC,
488+
Mistral3RMSNorm: CustomRMSNormAIC,
439489
MixtralRMSNorm: CustomRMSNormAIC,
440490
Phi3RMSNorm: CustomRMSNormAIC,
441491
Qwen2RMSNorm: CustomRMSNormAIC,
442492
Qwen3RMSNorm: CustomRMSNormAIC,
443493
Qwen2_5RMSNorm: CustomRMSNormAIC,
444494
MllamaTextRMSNorm: CustomRMSNormAIC,
445495
GraniteRMSNorm: CustomRMSNormAIC,
496+
PixtralRMSNorm: CustomRMSNormAIC,
446497
GraniteMoeRMSNorm: CustomRMSNormAIC,
447-
Qwen3VLTextRMSNorm: CustomRMSNormAIC,
498+
Qwen3MoeRMSNorm: CustomRMSNormAIC,
448499
Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC,
500+
Qwen3VLTextRMSNorm: CustomRMSNormAIC,
501+
Olmo2RMSNorm: CustomRMSNormAIC,
449502
}
450503

451504

@@ -498,12 +551,12 @@ class KVCacheTransform(ModuleMappingTransform):
498551
GemmaModel: QEffGemmaModel,
499552
GemmaForCausalLM: QEffGemmaForCausalLM,
500553
# Qwen3Moe
501-
# Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM,
502-
# Qwen3MoeModel: QEffQwen3MoeModel,
503-
# Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer,
504-
# Qwen3MoeAttention: QEffQwen3MoeAttention,
505-
# Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding,
506-
# Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
554+
Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM,
555+
Qwen3MoeModel: QEffQwen3MoeModel,
556+
Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer,
557+
Qwen3MoeAttention: QEffQwen3MoeAttention,
558+
Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding,
559+
Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
507560
# Gemma2
508561
Gemma2Attention: QEffGemma2Attention,
509562
Gemma2DecoderLayer: QEffGemma2DecoderLayer,
@@ -551,6 +604,9 @@ class KVCacheTransform(ModuleMappingTransform):
551604
MistralDecoderLayer: QEffMistralDecoderLayer,
552605
MistralModel: QEffMistralModel,
553606
MistralForCausalLM: QEffMistralForCausalLM,
607+
# Mistral3
608+
Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration,
609+
Mistral3Model: QEffMistral3Model,
554610
# Mixtral
555611
MixtralAttention: QEffMixtralAttention,
556612
MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock,
@@ -572,36 +628,34 @@ class KVCacheTransform(ModuleMappingTransform):
572628
PhiDecoderLayer: QEffPhiDecoderLayer,
573629
PhiModel: QEffPhiModel,
574630
PhiForCausalLM: QEffPhiForCausalLM,
631+
# Pixtral
632+
PixtralVisionModel: QEffPixtralVisionModel,
575633
# Qwen2
576634
Qwen2Attention: QEffQwen2Attention,
577635
Qwen2DecoderLayer: QEffQwen2DecoderLayer,
578636
Qwen2Model: QEffQwen2Model,
579637
Qwen2ForCausalLM: QEffQwen2ForCausalLM,
580-
# Qwen2.5 VL
581-
Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration,
582-
# Qwen2_5_VLModel: QEffQwen2_5_VLModel,
583-
Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel,
584638
# Qwen3
585639
Qwen3Attention: QEffQwen3Attention,
586640
Qwen3DecoderLayer: QEffQwen3DecoderLayer,
587641
Qwen3Model: QEffQwen3Model,
588642
Qwen3ForCausalLM: QEffQwen3ForCausalLM,
589643
# Qwen2.5 VL
590-
# Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration,
591-
# Qwen2_5_VLModel: QEffQwen2_5_VLModel,
644+
Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration,
645+
Qwen2_5_VLModel: QEffQwen2_5_VLModel,
592646
Qwen2_5_VLAttention: QEffQwen2_5_VLAttention,
593647
Qwen2_5_VLDecoderLayer: QEffQwen2_5_VLDecoderLayer,
594648
Qwen2_5_VisionTransformerPretrainedModel: QEffQwen2_5_VisionTransformerPretrainedModel,
595649
Qwen2_5_VLVisionAttention: QEffQwen2_5_VLVisionAttention,
596-
# Qwen3vl
650+
Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel,
651+
# Qwen3 VL
597652
Qwen3VLForConditionalGeneration: QEffQwen3VLForConditionalGeneration,
598653
Qwen3VLModel: QEffQwen3VLModel,
599654
Qwen3VLTextAttention: QEffQwen3VLTextAttention,
600655
Qwen3VLTextDecoderLayer: QEffQwen3VLTextDecoderLayer,
601656
Qwen3VLVisionAttention: QEffQwen3VLVisionAttention,
602657
Qwen3VLVisionModel: QEffQwen3VLVisionModel,
603658
Qwen3VLTextModel: QEffQwen3VLTextModel,
604-
# Qwen3VLTextRotaryEmbedding: QEffQwen3VLTextRotaryEmbedding, # reusing decoder layer for rotary embedding as they are tightly coupled in forward pass
605659
# Starcoder2
606660
Starcoder2Attention: QEffStarcoder2Attention,
607661
Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer,
@@ -612,6 +666,11 @@ class KVCacheTransform(ModuleMappingTransform):
612666
GPTBigCodeBlock: QEffGPTBigCodeBlock,
613667
GPTBigCodeModel: QEffGPTBigCodeModel,
614668
GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM,
669+
# Olmo2
670+
Olmo2Attention: QEffOlmo2Attention,
671+
Olmo2DecoderLayer: QEffOlmo2DecoderLayer,
672+
Olmo2Model: QEffOlmo2Model,
673+
Olmo2ForCausalLM: QEffOlmo2ForCausalLM,
615674
# Whisper encoder and decoder layers
616675
WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding,
617676
WhisperAttention: QEffWhisperAttention,
@@ -681,7 +740,7 @@ class SpDTransform:
681740
# Llama
682741
QEffLlamaForCausalLM,
683742
QEffQwen2ForCausalLM,
684-
# QEffQwen3ForCausalLM,
743+
QEffQwen3ForCausalLM,
685744
}
686745

687746
@classmethod
@@ -747,7 +806,7 @@ class SamplerTransform:
747806
QEffMptForCausalLM,
748807
QEffPhi3ForCausalLM,
749808
QEffQwen2ForCausalLM,
750-
# QEffQwen_2_5_vl_DecoderWrapper,
809+
QEffQwen_2_5_vl_DecoderWrapper,
751810
}
752811

753812
@classmethod
@@ -793,6 +852,32 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
793852
"get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder,
794853
},
795854
"InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
855+
# Mapping for Molmo
856+
"MolmoForCausalLM": {
857+
"forward": QEffMolmoModel.forward,
858+
"get_qeff_vision_encoder": QEffMolmoModel.get_qeff_vision_encoder,
859+
"get_qeff_language_decoder": QEffMolmoModel.get_qeff_language_decoder,
860+
"get_specializations": QEffMolmoModel.get_specializations,
861+
"get_onnx_dynamic_axes": QEffMolmoModel.get_onnx_dynamic_axes,
862+
"get_output_names": QEffMolmoModel.get_output_names,
863+
"get_dummy_inputs": QEffMolmoModel.get_dummy_inputs,
864+
"get_inputs_info": QEffMolmoModel.get_inputs_info,
865+
},
866+
"RMSLayerNorm": {"forward": CustomRMSNormAIC.forward},
867+
# "MolmoForCausalLM": {"forward": QEffMolmoForCausalLM.forward},
868+
"Molmo": {"forward": QEffMolmo.forward},
869+
"MolmoSequentialBlock": {
870+
"forward": QEffMolmoSequentialBlock.forward,
871+
"attention": QEffMolmoBlock.attention,
872+
"__qeff_init__": QEffMolmoBlock.__qeff_init__,
873+
},
874+
"MolmoBlock": {
875+
"attention": QEffMolmoBlock.attention,
876+
"__qeff_init__": QEffMolmoBlock.__qeff_init__,
877+
},
878+
"MultiHeadDotProductAttention": {
879+
"forward": QEffMultiHeadDotProductAttention.forward,
880+
},
796881
# Mapping for grok1 model
797882
"Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward},
798883
"Grok1Model": {

QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,6 @@ def eager_attention_forward(
309309

310310

311311
class QEffQwen3VLTextAttention(Qwen3VLTextAttention):
312-
313312
def forward(
314313
self,
315314
hidden_states: torch.Tensor,
@@ -1024,7 +1023,6 @@ def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size=
10241023
inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1
10251024
)
10261025

1027-
10281026
return inputs
10291027

10301028
def get_inputs_info(self):

0 commit comments

Comments (0)