88from typing import List , Optional , Tuple , Type , Union
99
1010import torch
11- import torch .nn .functional as F
1211from torch import nn
1312from transformers .cache_utils import Cache , StaticCache
1413from transformers .modeling_attn_mask_utils import AttentionMaskConverter
1514from transformers .modeling_outputs import BaseModelOutputWithPast , MoeCausalLMOutputWithPast , MoeModelOutputWithPast
1615from transformers .models .granitemoe .modeling_granitemoe import (
1716 GraniteMoeAttention ,
1817 GraniteMoeConfig ,
18+ GraniteMoeDecoderLayer ,
1919 GraniteMoeForCausalLM ,
2020 GraniteMoeModel ,
2121 GraniteMoeMoE ,
2222 GraniteMoeParallelExperts ,
2323 GraniteMoeRotaryEmbedding ,
2424 GraniteMoeTopKGating ,
25- load_balancing_loss_func ,
26- logger ,
2725 repeat_kv ,
2826 rotate_half ,
2927)
@@ -198,6 +196,88 @@ def eager_attention_forward(
198196 return attn_output , attn_weights
199197
200198
class QEffGraniteMoeDecoderLayer(GraniteMoeDecoderLayer):
    """
    Copied from GraniteMoeDecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoe/modeling_granitemoe.py
    The only differences are:
    - add new args batch idx for the CB models although its not supported yet.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        output_router_logits: Optional[bool] = False,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run one GraniteMoe decoder layer: pre-norm self-attention followed by the
        sparse-MoE feed-forward block, each added back onto its residual stream
        scaled by ``self.residual_multiplier`` (the Granite-specific twist on the
        usual pre-norm transformer layer).

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch_size, sequence_length)` if flash attention is used or
                `(batch_size, 1, query_sequence_length, key_sequence_length)` for default attention.
            position_ids (`torch.LongTensor`, *optional*): token position indices, forwarded to attention.
            past_key_value (`Cache`, *optional*): cached key/value projection states.
            output_attentions (`bool`, *optional*): when `True`, the attention weights tensor is
                appended to the returned tuple.
            use_cache (`bool`, *optional*): when `True`, key/value states are cached for fast decoding.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): indices
                depicting the position of the input sequence tokens in the sequence.
            output_router_logits (`bool`, *optional*): when `True`, the MoE router logits are
                appended to the returned tuple (useful for router-loss computation, not inference).
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): cosine and
                sine positional embeddings of shape `(batch_size, seq_len, head_dim)`.
            kwargs (`dict`, *optional*): arbitrary extra kwargs passed through to the attention module
                (used by FSDP and similar wrappers).

        Returns:
            A tuple whose first element is the layer output hidden states; attention weights and/or
            router logits follow when the corresponding flags are set.
        """
        # --- Self-attention sub-block (pre-norm) ---
        skip = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out, attn_weights = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        # Granite scales the branch output before the residual add.
        hidden_states = skip + attn_out * self.residual_multiplier

        # --- Sparse-MoE feed-forward sub-block (pre-norm) ---
        skip = hidden_states
        moe_out, router_logits = self.block_sparse_moe(self.post_attention_layernorm(hidden_states))
        hidden_states = skip + moe_out * self.residual_multiplier

        # Assemble the output tuple; optional tensors are appended in a fixed order.
        result = [hidden_states]
        if output_attentions:
            result.append(attn_weights)
        if output_router_logits:
            result.append(router_logits)
        return tuple(result)
279+
280+
201281class QEffGraniteMoeModel (GraniteMoeModel ):
202282 """Copied from GraniteMoeModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoe/modeling_granitemoe.py
203283 The only differences are:
@@ -227,39 +307,19 @@ def forward(
227307 output_hidden_states if output_hidden_states is not None else self .config .output_hidden_states
228308 )
229309 use_cache = use_cache if use_cache is not None else self .config .use_cache
230- return_dict = return_dict if return_dict is not None else self .config .use_return_dict
231310
232311 if (input_ids is None ) ^ (inputs_embeds is not None ):
233312 raise ValueError ("You must specify exactly one of input_ids or inputs_embeds" )
234313
235- if self .gradient_checkpointing and self .training and use_cache :
236- logger .warning_once (
237- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
238- )
239- use_cache = False
240-
241314 if inputs_embeds is None :
242315 inputs_embeds = self .embed_tokens (input_ids )
243316
244317 inputs_embeds = inputs_embeds * self .embedding_multiplier # main diff with Llama
245318
246- # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
247- # if not isinstance(past_key_values, (type(None), Cache)):
248- # raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
249-
250- # if use_cache and past_key_values is None:
251- # past_key_values = QEffDynamicCache()
252-
319+ return_legacy_cache = False
253320 if use_cache and not isinstance (past_key_values , Cache ):
254- if past_key_values is None :
255- past_key_values = QEffDynamicCache ()
256- else :
257- past_key_values = QEffDynamicCache .from_legacy_cache (past_key_values )
258- logger .warning_once (
259- "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
260- "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
261- "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
262- )
321+ return_legacy_cache = True
322+ past_key_values = QEffDynamicCache .from_legacy_cache (past_key_values )
263323
264324 if cache_position is None :
265325 past_seen_tokens = past_key_values .get_seq_length () if past_key_values is not None else 0
@@ -321,18 +381,15 @@ def forward(
321381 if output_hidden_states :
322382 all_hidden_states += (hidden_states ,)
323383
324- if not return_dict :
325- return tuple (
326- v for v in [hidden_states , past_key_values , all_hidden_states , all_self_attns ] if v is not None
327- )
384+ if return_legacy_cache :
385+ past_key_values = past_key_values .to_legacy_cache ()
328386
329- output = MoeModelOutputWithPast (
387+ return MoeModelOutputWithPast (
330388 last_hidden_state = hidden_states ,
331389 past_key_values = past_key_values ,
332390 hidden_states = all_hidden_states ,
333391 attentions = all_self_attns ,
334392 )
335- return output if return_dict else output .to_tuple ()
336393
337394 def _update_causal_mask (
338395 self ,
@@ -435,7 +492,13 @@ def forward(self, hidden_states):
435492 logits = self .layer (hidden_states ).float ()
436493 top_k_logits , top_k_indices = torch .topk (logits , self .top_k , dim = 1 ) # [num_tokens, top_k]
437494 top_k_gates = torch .softmax (top_k_logits , dim = 1 ).type_as (hidden_states ) # [num_tokens, top_k]
438- expert_mask = F .one_hot (top_k_indices , num_classes = self .num_experts ).permute (2 , 1 , 0 )
495+
496+ B , K = top_k_indices .shape
497+ E = int (self .num_experts )
498+ flat = top_k_indices .reshape (- 1 )
499+ mask = torch .zeros ((B * K , E ), dtype = torch .int64 , device = top_k_indices .device )
500+ mask [torch .arange (B * K , device = flat .device ), flat ] = 1
501+ expert_mask = mask .view (B , K , E ).permute (2 , 1 , 0 )
439502 return top_k_gates , expert_mask , logits , self .num_experts
440503
441504
@@ -511,14 +574,9 @@ def forward(
511574 comp_ctx_lengths : Optional [torch .LongTensor ] = None ,
512575 batch_index : Optional [torch .LongTensor ] = None ,
513576 inputs_embeds : Optional [torch .FloatTensor ] = None ,
514- labels : Optional [torch .LongTensor ] = None ,
515577 use_cache : Optional [bool ] = None ,
516- output_attentions : Optional [bool ] = None ,
517578 output_hidden_states : Optional [bool ] = None ,
518- output_router_logits : Optional [bool ] = None ,
519- return_dict : Optional [bool ] = None ,
520579 cache_position : Optional [torch .LongTensor ] = None ,
521- logits_to_keep : Union [int , torch .Tensor ] = 0 ,
522580 ** kwargs ,
523581 ) -> Union [Tuple , MoeCausalLMOutputWithPast ]:
524582 r"""
@@ -551,11 +609,9 @@ def forward(
551609 >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
552610 "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
553611 ```"""
554- output_attentions = output_attentions if output_attentions is not None else self .config .output_attentions
555612 output_hidden_states = (
556613 output_hidden_states if output_hidden_states is not None else self .config .output_hidden_states
557614 )
558- return_dict = return_dict if return_dict is not None else self .config .use_return_dict
559615
560616 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
561617 outputs = self .model (
@@ -567,57 +623,21 @@ def forward(
567623 batch_index = batch_index ,
568624 inputs_embeds = inputs_embeds ,
569625 use_cache = use_cache ,
570- output_attentions = output_attentions ,
571626 output_hidden_states = output_hidden_states ,
572- return_dict = return_dict ,
573627 cache_position = cache_position ,
574628 ** kwargs ,
575629 )
576630
577- hidden_states = outputs [0 ]
578631 # Cast to INT32 to avoid issue while running in ONNXRT
579632 logit_index = position_ids .to (torch .int32 ).argmax (1 , keepdim = True )
580- hidden_states = outputs [0 ][torch .arange (position_ids .shape [0 ]).view (- 1 , 1 ), logit_index ]
581-
582- slice_indices = slice (- logits_to_keep , None ) if isinstance (logits_to_keep , int ) else logits_to_keep
583- logits = self .lm_head (hidden_states [:, slice_indices , :])
584- logits = logits / self .config .logits_scaling
585-
586- loss = None
587- if labels is not None :
588- # Upcast to float if we need to compute the loss to avoid potential precision issues
589- logits = logits .float ()
590- # Flatten the tokens
591- loss = self .loss_function (
592- logits ,
593- labels ,
594- vocab_size = self .config .vocab_size ,
595- ** kwargs ,
596- )
597-
598- aux_loss = None
599- if output_router_logits :
600- aux_loss = load_balancing_loss_func (
601- outputs .router_logits if return_dict else outputs [- 1 ],
602- self .num_experts ,
603- self .num_experts_per_tok ,
604- attention_mask ,
605- )
606- if labels is not None :
607- loss += self .router_aux_loss_coef * aux_loss .to (loss .device ) # make sure to reside in the same device
608-
609- if not return_dict :
610- output = (logits ,) + outputs [1 :]
611- if output_router_logits :
612- output = (aux_loss ,) + output
613- return (loss ,) + output if loss is not None else output
633+ hidden_states = outputs .last_hidden_state [torch .arange (position_ids .shape [0 ]).view (- 1 , 1 ), logit_index ]
634+ logits = self .lm_head (hidden_states ).float ()
635+ # logits = logits / self.config.logits_scaling
614636
615637 return MoeCausalLMOutputWithPast (
616- loss = loss ,
617- aux_loss = aux_loss ,
638+ loss = None ,
618639 logits = logits ,
619640 past_key_values = outputs .past_key_values ,
620641 hidden_states = outputs .hidden_states ,
621642 attentions = outputs .attentions ,
622- router_logits = outputs .router_logits ,
623643 )
0 commit comments