@@ -624,6 +624,7 @@ def __init__(
         self,
         config: Optional[config_mllama.MllamaTextConfig] = None,
         layer_idx: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -648,12 +649,14 @@ def __init__(
             self.num_heads,
             self.num_key_value_heads,
             bias=False,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             input_is_parallel=True,
+            quant_config=quant_config,
         )
         # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
         # use huggingface's instead
@@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
708711 """Cross-attention transformer block with tanh-gated attention
709712 and feedforward."""
710713
711- def __init__ (self , config : config_mllama .MllamaTextConfig , layer_idx : int ) \
714+ def __init__ (self , config : config_mllama .MllamaTextConfig , layer_idx : int ,
715+ quant_config : Optional [QuantizationConfig ]) \
712716 -> None :
713717 super ().__init__ ()
714718 self .layer_idx = layer_idx
715719 self .cross_attn = MllamaTextCrossAttention (
716720 config = config ,
717721 layer_idx = layer_idx ,
722+ quant_config = quant_config ,
718723 )
719724
720725 self .input_layernorm = RMSNorm (config .hidden_size ,
@@ -725,6 +730,7 @@ def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
+            quant_config=quant_config,
         )
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
@@ -780,7 +786,8 @@ def __init__(self, config: config_mllama.MllamaTextConfig,
         for layer_idx in range(config.num_hidden_layers):
             if layer_idx in self.cross_attention_layers:
                 layers.append(
-                    MllamaCrossAttentionDecoderLayer(config, layer_idx))
+                    MllamaCrossAttentionDecoderLayer(
+                        config, layer_idx, quant_config=quant_config))
             else:
                 # TODO: force LlamaDecoderLayer to config.attention_bias=False
                 layers.append(
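
The pattern in this diff is to thread an optional `quant_config` from the top-level model constructor down through `MllamaCrossAttentionDecoderLayer` and `MllamaTextCrossAttention` until it reaches the parallel linear projections, so quantized checkpoints also apply to the cross-attention weights. Below is a minimal, self-contained sketch of that plumbing pattern; the `QuantizationConfig`, `QuantAwareLinear`, `CrossAttention`, and `CrossAttentionDecoderLayer` classes here are simplified placeholders for illustration, not vLLM's real implementations.

```python
from typing import Optional


class QuantizationConfig:
    """Placeholder for vLLM's quantization config object."""

    def __init__(self, method: str = "fp8"):
        self.method = method


class QuantAwareLinear:
    """Stand-in for QKVParallelLinear / RowParallelLinear: uses a quantized
    weight path only when a quant_config is supplied."""

    def __init__(self, in_features: int, out_features: int,
                 quant_config: Optional[QuantizationConfig] = None):
        self.in_features = in_features
        self.out_features = out_features
        self.quant_method = quant_config.method if quant_config else "unquantized"


class CrossAttention:
    """Mirrors MllamaTextCrossAttention: forwards quant_config to its projections."""

    def __init__(self, hidden_size: int,
                 quant_config: Optional[QuantizationConfig] = None):
        self.qkv_proj = QuantAwareLinear(hidden_size, 3 * hidden_size,
                                         quant_config=quant_config)
        self.o_proj = QuantAwareLinear(hidden_size, hidden_size,
                                       quant_config=quant_config)


class CrossAttentionDecoderLayer:
    """Mirrors MllamaCrossAttentionDecoderLayer: accepts quant_config and
    passes it one level further down."""

    def __init__(self, hidden_size: int, layer_idx: int,
                 quant_config: Optional[QuantizationConfig] = None):
        self.layer_idx = layer_idx
        self.cross_attn = CrossAttention(hidden_size, quant_config=quant_config)


if __name__ == "__main__":
    # With a quant_config, every submodule built from the top sees the same
    # config; without one, the projections fall back to unquantized weights.
    layer = CrossAttentionDecoderLayer(4096, layer_idx=3,
                                       quant_config=QuantizationConfig("fp8"))
    print(layer.cross_attn.qkv_proj.quant_method)  # -> "fp8"
```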