Skip to content

Commit 4bd2239

Browse files
authored
Add support for dense models distilled from MoE models with the same architecture (#728)
In this PR, we add support for meta-llama/Llama-Guard-4-12B, a dense model distilled from the Llama4 Scout MoE model. The changes in pytorch_transforms.py can be applied to any dense model distilled from an MoE model with a supported architecture in QEfficient. Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>
1 parent facae5f commit 4bd2239

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

QEfficient/base/pytorch_transforms.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,16 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
152152
# ---- build the textual prefix once per layer ----------
153153
if is_gpt_oss:
154154
prefix = f"model.layers.{layer_idx}.mlp.experts."
155-
experts = model_tmp.model.layers[layer_idx].mlp.experts
155+
# experts = model_tmp.model.layers[layer_idx].mlp.experts
156+
ff = model_tmp.model.layers[layer_idx].mlp
156157
else:
157158
prefix = f"model.layers.{layer_idx}.feed_forward.experts."
158-
experts = model_tmp.model.layers[layer_idx].feed_forward.experts
159+
# experts = model_tmp.model.layers[layer_idx].feed_forward.experts
160+
ff = model_tmp.model.layers[layer_idx].feed_forward
161+
162+
if not hasattr(ff, "experts"):
163+
continue
164+
experts = ff.experts
159165

160166
fused_key = prefix + "gate_up_proj"
161167
gate_key = prefix + "gate_proj"

QEfficient/transformers/models/llama4/modeling_llama4.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ def forward(
504504

505505
if past_key_value is not None:
506506
chunk_position_ids = position_ids
507-
if self.use_rope:
507+
if self.use_rope and self.config.attention_chunk_size:
508508
chunk_position_ids = torch.where(
509509
chunk_position_ids != -1, chunk_position_ids % self.config.attention_chunk_size, chunk_position_ids
510510
)
@@ -663,10 +663,16 @@ def forward(
663663
causal_mask = _create_causal_mask(
664664
position_ids=position_ids, target_length=past_key_values.layers[3].keys.shape[-2]
665665
)
666-
chunk_position_ids = torch.where(
667-
position_ids != -1, position_ids % self.config.attention_chunk_size, position_ids
668-
)
669-
target_length = min(past_key_values.layers[0].keys.shape[-2], torch.tensor(self.config.attention_chunk_size))
666+
if self.config.attention_chunk_size:
667+
chunk_position_ids = torch.where(
668+
position_ids != -1, position_ids % self.config.attention_chunk_size, position_ids
669+
)
670+
target_length = min(
671+
past_key_values.layers[0].keys.shape[-2], torch.tensor(self.config.attention_chunk_size)
672+
)
673+
else:
674+
chunk_position_ids = position_ids
675+
target_length = past_key_values.layers[0].keys.shape[-2]
670676
chunk_causal_mask = _create_causal_mask(position_ids=chunk_position_ids, target_length=target_length)
671677
causal_mask_mapping = {
672678
"full_attention": causal_mask,
@@ -798,7 +804,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
798804
is_chunked_attention = torch.tensor(
799805
[bool((i + 1) % 4) for i in range(config.num_hidden_layers)], dtype=torch.bool
800806
)
801-
attention_chunk_size = getattr(config, "attention_chunk_size", seq_len)
807+
attention_chunk_size = getattr(config, "attention_chunk_size", None) or seq_len
802808
global_cache_shape = [batch_size, n_heads, seq_len, d_head]
803809
chunked_cache_shape = [
804810
batch_size,
@@ -967,13 +973,12 @@ def get_specializations(
967973

968974
prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
969975
ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
976+
attention_chunk_size = getattr(
977+
getattr(getattr(self, "config", None), "text_config", None), "attention_chunk_size", None
978+
)
970979
chunk_ctx_len = min(
971980
ctx_len,
972-
(
973-
self.config.text_config.attention_chunk_size
974-
if hasattr(self, "config")
975-
else constants.LLAMA4_ATTENTION_CHUNK_SIZE
976-
),
981+
(attention_chunk_size if attention_chunk_size is not None else constants.LLAMA4_ATTENTION_CHUNK_SIZE),
977982
)
978983
if (
979984
prefill_seq_len > constants.LLAMA4_MAX_POSITION_EMBEDDINGS
@@ -1158,7 +1163,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
11581163
is_chunked_attention = torch.tensor(
11591164
[bool((i + 1) % 4) for i in range(config.num_hidden_layers)], dtype=torch.bool
11601165
)
1161-
attention_chunk_size = getattr(config, "attention_chunk_size", seq_len)
1166+
attention_chunk_size = getattr(config, "attention_chunk_size", None) or seq_len
11621167
global_cache_shape = [batch_size, n_heads, seq_len, d_head]
11631168
chunked_cache_shape = [
11641169
batch_size,

0 commit comments

Comments
 (0)