Skip to content

Commit fc42332

Browse files
Added one-hot fix for MoE model with subfunction (quic#777)
Signed-off-by: Abhishek Kumar Singh <sabhis@qti.qualcomm.com>
1 parent e8e5c43 commit fc42332

File tree

4 files changed

+112
-85
lines changed

4 files changed

+112
-85
lines changed

QEfficient/transformers/models/granitemoe/modeling_granitemoe.py

Lines changed: 100 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,20 @@
88
from typing import List, Optional, Tuple, Type, Union
99

1010
import torch
11-
import torch.nn.functional as F
1211
from torch import nn
1312
from transformers.cache_utils import Cache, StaticCache
1413
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
1514
from transformers.modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
1615
from transformers.models.granitemoe.modeling_granitemoe import (
1716
GraniteMoeAttention,
1817
GraniteMoeConfig,
18+
GraniteMoeDecoderLayer,
1919
GraniteMoeForCausalLM,
2020
GraniteMoeModel,
2121
GraniteMoeMoE,
2222
GraniteMoeParallelExperts,
2323
GraniteMoeRotaryEmbedding,
2424
GraniteMoeTopKGating,
25-
load_balancing_loss_func,
26-
logger,
2725
repeat_kv,
2826
rotate_half,
2927
)
@@ -198,6 +196,88 @@ def eager_attention_forward(
198196
return attn_output, attn_weights
199197

200198

199+
class QEffGraniteMoeDecoderLayer(GraniteMoeDecoderLayer):
200+
"""
201+
Copied from GraniteMoeDecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoe/modeling_granitemoe.py
202+
The only differences are:
203+
- adds a new `batch_idx` argument for the CB (continuous batching) models, although it is not supported yet.
204+
"""
205+
206+
def forward(
207+
self,
208+
hidden_states: torch.Tensor,
209+
attention_mask: Optional[torch.Tensor] = None,
210+
position_ids: Optional[torch.LongTensor] = None,
211+
past_key_value: Optional[Cache] = None,
212+
output_attentions: Optional[bool] = False,
213+
use_cache: Optional[bool] = False,
214+
cache_position: Optional[torch.LongTensor] = None,
215+
output_router_logits: Optional[bool] = False,
216+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
217+
**kwargs,
218+
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
219+
"""
220+
Args:
221+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
222+
attention_mask (`torch.FloatTensor`, *optional*):
223+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
224+
query_sequence_length, key_sequence_length)` if default attention is used.
225+
output_attentions (`bool`, *optional*):
226+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
227+
returned tensors for more detail.
228+
use_cache (`bool`, *optional*):
229+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
230+
(see `past_key_values`).
231+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
232+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
233+
Indices depicting the position of the input sequence tokens in the sequence
234+
output_router_logits (`bool`, *optional*):
235+
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
236+
should not be returned during inference.
237+
position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
238+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
239+
with `head_dim` being the embedding dimension of each attention head.
240+
kwargs (`dict`, *optional*):
241+
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
242+
into the model
243+
"""
244+
residual = hidden_states
245+
246+
hidden_states = self.input_layernorm(hidden_states)
247+
248+
# Self Attention
249+
hidden_states, self_attn_weights = self.self_attn(
250+
hidden_states=hidden_states,
251+
attention_mask=attention_mask,
252+
position_ids=position_ids,
253+
past_key_value=past_key_value,
254+
output_attentions=output_attentions,
255+
use_cache=use_cache,
256+
cache_position=cache_position,
257+
position_embeddings=position_embeddings,
258+
**kwargs,
259+
)
260+
261+
hidden_states = residual + hidden_states * self.residual_multiplier
262+
263+
# Fully Connected
264+
residual = hidden_states
265+
hidden_states = self.post_attention_layernorm(hidden_states)
266+
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
267+
268+
hidden_states = residual + hidden_states * self.residual_multiplier
269+
270+
outputs = (hidden_states,)
271+
272+
if output_attentions:
273+
outputs += (self_attn_weights,)
274+
275+
if output_router_logits:
276+
outputs += (router_logits,)
277+
278+
return outputs
279+
280+
201281
class QEffGraniteMoeModel(GraniteMoeModel):
202282
"""Copied from GraniteMoeModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoe/modeling_granitemoe.py
203283
The only differences are:
@@ -227,39 +307,19 @@ def forward(
227307
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
228308
)
229309
use_cache = use_cache if use_cache is not None else self.config.use_cache
230-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
231310

232311
if (input_ids is None) ^ (inputs_embeds is not None):
233312
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
234313

235-
if self.gradient_checkpointing and self.training and use_cache:
236-
logger.warning_once(
237-
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
238-
)
239-
use_cache = False
240-
241314
if inputs_embeds is None:
242315
inputs_embeds = self.embed_tokens(input_ids)
243316

244317
inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama
245318

246-
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
247-
# if not isinstance(past_key_values, (type(None), Cache)):
248-
# raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
249-
250-
# if use_cache and past_key_values is None:
251-
# past_key_values = QEffDynamicCache()
252-
319+
return_legacy_cache = False
253320
if use_cache and not isinstance(past_key_values, Cache):
254-
if past_key_values is None:
255-
past_key_values = QEffDynamicCache()
256-
else:
257-
past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
258-
logger.warning_once(
259-
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
260-
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
261-
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
262-
)
321+
return_legacy_cache = True
322+
past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
263323

264324
if cache_position is None:
265325
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -321,18 +381,15 @@ def forward(
321381
if output_hidden_states:
322382
all_hidden_states += (hidden_states,)
323383

324-
if not return_dict:
325-
return tuple(
326-
v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
327-
)
384+
if return_legacy_cache:
385+
past_key_values = past_key_values.to_legacy_cache()
328386

329-
output = MoeModelOutputWithPast(
387+
return MoeModelOutputWithPast(
330388
last_hidden_state=hidden_states,
331389
past_key_values=past_key_values,
332390
hidden_states=all_hidden_states,
333391
attentions=all_self_attns,
334392
)
335-
return output if return_dict else output.to_tuple()
336393

337394
def _update_causal_mask(
338395
self,
@@ -435,7 +492,13 @@ def forward(self, hidden_states):
435492
logits = self.layer(hidden_states).float()
436493
top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=1) # [num_tokens, top_k]
437494
top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]
438-
expert_mask = F.one_hot(top_k_indices, num_classes=self.num_experts).permute(2, 1, 0)
495+
496+
B, K = top_k_indices.shape
497+
E = int(self.num_experts)
498+
flat = top_k_indices.reshape(-1)
499+
mask = torch.zeros((B * K, E), dtype=torch.int64, device=top_k_indices.device)
500+
mask[torch.arange(B * K, device=flat.device), flat] = 1
501+
expert_mask = mask.view(B, K, E).permute(2, 1, 0)
439502
return top_k_gates, expert_mask, logits, self.num_experts
440503

441504

@@ -511,14 +574,9 @@ def forward(
511574
comp_ctx_lengths: Optional[torch.LongTensor] = None,
512575
batch_index: Optional[torch.LongTensor] = None,
513576
inputs_embeds: Optional[torch.FloatTensor] = None,
514-
labels: Optional[torch.LongTensor] = None,
515577
use_cache: Optional[bool] = None,
516-
output_attentions: Optional[bool] = None,
517578
output_hidden_states: Optional[bool] = None,
518-
output_router_logits: Optional[bool] = None,
519-
return_dict: Optional[bool] = None,
520579
cache_position: Optional[torch.LongTensor] = None,
521-
logits_to_keep: Union[int, torch.Tensor] = 0,
522580
**kwargs,
523581
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
524582
r"""
@@ -551,11 +609,9 @@ def forward(
551609
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
552610
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
553611
```"""
554-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
555612
output_hidden_states = (
556613
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
557614
)
558-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
559615

560616
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
561617
outputs = self.model(
@@ -567,57 +623,21 @@ def forward(
567623
batch_index=batch_index,
568624
inputs_embeds=inputs_embeds,
569625
use_cache=use_cache,
570-
output_attentions=output_attentions,
571626
output_hidden_states=output_hidden_states,
572-
return_dict=return_dict,
573627
cache_position=cache_position,
574628
**kwargs,
575629
)
576630

577-
hidden_states = outputs[0]
578631
# Cast to INT32 to avoid issue while running in ONNXRT
579632
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
580-
hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
581-
582-
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
583-
logits = self.lm_head(hidden_states[:, slice_indices, :])
584-
logits = logits / self.config.logits_scaling
585-
586-
loss = None
587-
if labels is not None:
588-
# Upcast to float if we need to compute the loss to avoid potential precision issues
589-
logits = logits.float()
590-
# Flatten the tokens
591-
loss = self.loss_function(
592-
logits,
593-
labels,
594-
vocab_size=self.config.vocab_size,
595-
**kwargs,
596-
)
597-
598-
aux_loss = None
599-
if output_router_logits:
600-
aux_loss = load_balancing_loss_func(
601-
outputs.router_logits if return_dict else outputs[-1],
602-
self.num_experts,
603-
self.num_experts_per_tok,
604-
attention_mask,
605-
)
606-
if labels is not None:
607-
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
608-
609-
if not return_dict:
610-
output = (logits,) + outputs[1:]
611-
if output_router_logits:
612-
output = (aux_loss,) + output
613-
return (loss,) + output if loss is not None else output
633+
hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
634+
logits = self.lm_head(hidden_states).float()
635+
# logits = logits / self.config.logits_scaling
614636

615637
return MoeCausalLMOutputWithPast(
616-
loss=loss,
617-
aux_loss=aux_loss,
638+
loss=None,
618639
logits=logits,
619640
past_key_values=outputs.past_key_values,
620641
hidden_states=outputs.hidden_states,
621642
attentions=outputs.attentions,
622-
router_logits=outputs.router_logits,
623643
)

QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
219219

220220
# One hot encode the selected experts to create an expert mask
221221
# this will be used to easily index which expert is going to be solicited
222-
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
222+
# selected_experts: [B, K]
223+
B, K = selected_experts.shape
224+
E = int(self.num_experts)
225+
flat = selected_experts.reshape(-1)
226+
mask = torch.zeros((B * K, E), dtype=torch.int64)
227+
mask[torch.arange(B * K), flat] = 1
228+
mask_bke = mask.view(B, K, E)
229+
expert_mask = mask_bke.permute(2, 1, 0)
223230

224231
# Loop over all available experts in the model and perform the computation on each expert
225232
for expert_idx in range(self.num_experts):

QEfficient/utils/torch_patches.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
import torch.onnx.utils as onnx_utils
1212
from torch import _C
1313

14-
from QEfficient.utils.logging_utils import logger
15-
1614
# Store original references before patching
1715
_original_setup_trace_module_map = onnx_utils._setup_trace_module_map
1816
_original_get_module_attributes = getattr(onnx_utils, "_get_module_attributes", None)
@@ -43,7 +41,8 @@ def _track_module_attributes_forward_hook(module, input, output):
4341
onnx_attrs = {} # HACK: to reduce export time # TODO: study behaviour across models
4442
_C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
4543
except Exception:
46-
logger.warning("Failed to track ONNX scope attributes, Skipping this step.")
44+
# Silently skip: scope-attribute tracking is best-effort and not required for export.
45+
pass
4746

4847
for m in model.modules():
4948
m.register_forward_hook(_track_module_attributes_forward_hook)

tests/transformers/models/test_subfunction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
2424
("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
2525
("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
26-
# ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
26+
("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
2727
("mpt", 256, 2, 4, 128, 512, 127, {}),
2828
("phi", 256, 2, 4, 128, 512, 127, {}),
2929
("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}),
@@ -34,6 +34,7 @@
3434
("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
3535
("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}),
3636
("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
37+
("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
3738
]
3839

3940
configs = [

0 commit comments

Comments
 (0)