Skip to content

Commit 0ffa4ea

Browse files
Fix for Qwen 2.5 VL with subfunction (quic#733)
Signed-off-by: Abhishek Kumar Singh <sabhis@qti.qualcomm.com>
1 parent 3a8e5e9 commit 0ffa4ea

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -912,9 +912,16 @@ def get_decoder_layer_classes_for_export(model: nn.Module) -> set:
912912

913913
# Filter to only include classes that are actually used in the current model
914914
model_decoder_classes = set()
915-
for module in model.modules():
916-
if module.__class__ in decoder_layer_classes:
917-
model_decoder_classes.add(module.__class__)
915+
model_class_name = model.__class__.__name__
916+
if "EncoderWrapper" in model_class_name:
917+
model_decoder_classes.update(
918+
module.__class__ for module in model.modules() if "Qwen2_5_VLVisionBlock" in module.__class__.__name__
919+
)
920+
return model_decoder_classes
921+
922+
model_decoder_classes.update(
923+
module.__class__ for module in model.modules() if module.__class__ in decoder_layer_classes
924+
)
918925

919926
return model_decoder_classes
920927

QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu
7474
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
7575
"""
7676

77-
mrope_section = mrope_section * 2
7877
cos = cos[position_ids]
7978
sin = sin[position_ids]
80-
81-
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
82-
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
79+
cos = torch.cat([cos[0, ..., 0:32], cos[1, ..., 32:80], cos[2, ..., 80:128]], dim=-1).unsqueeze(unsqueeze_dim)
80+
sin = torch.cat([sin[0, ..., 0:32], sin[1, ..., 32:80], sin[2, ..., 80:128]], dim=-1).unsqueeze(unsqueeze_dim)
8381

8482
q_embed = (q * cos) + (rotate_half(q) * sin)
8583
k_embed = (k * cos) + (rotate_half(k) * sin)

0 commit comments

Comments
 (0)