@@ -31,7 +31,10 @@ def __init__(self, quantization_args: QuantizationArgs):
 
     @torch.no_grad()
     def forward(
-        self, observed: Tensor, g_idx: Optional[Tensor] = None
+        self,
+        observed: Tensor,
+        g_idx: Optional[Tensor] = None,
+        base_name: Optional[str] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
@@ -40,8 +43,9 @@ def forward(
         :param g_idx: optional mapping from column index to group index
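+        :param base_name: optional name of the observed parameter, used to pick the quantization dim for the CHANNEL strategy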
         :return: tuple of scale and zero point based on last observed value
         """
         self.record_observed_tokens(observed)
-        return self.get_qparams(observed=observed, g_idx=g_idx)
+        return self.get_qparams(observed=observed, g_idx=g_idx, base_name=base_name)
 
     def calculate_qparams(
         self,
@@ -66,6 +70,7 @@ def get_qparams(
         self,
         observed: Optional[Tensor] = None,
         g_idx: Optional[Tensor] = None,
+        base_name: Optional[str] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Convenience function to wrap overwritten calculate_qparams
@@ -123,26 +128,21 @@ def get_qparams(
             self._zero_point[:, group_index] = zero_point.squeeze(1)
 
         elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # assume observed is transposed, because its the output, hence use dim 0
-            # we pass in [1, 8, 2048, 128] for k_states
-            # normally per channel: (output_dim, 1) and you have as many scales as the output_dim
-            # we want 8 - num_k_head_scales? or
-            #breakpoint()
-
-            # weight --> get scales along the first dimension (output dim is first dim)
-            # weight shape (output_dim, input_dim)
-            # self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
-            # output when applied to the weight: (output_dim, 1)
-
-
-            # for outputs:
-            self._scale, self._zero_point = self.get_qparams_along_dim(observed, 2)
-            self._scale = self._scale.squeeze(1)
-            self._zero_point = self._zero_point.squeeze(1)
-            # why is the output of self._scale: [1, 1, 1]
-
-
-
+            if base_name == "output":
+                # for outputs the last dimension is the hidden dimension,
+                # e.g. a shape of [1, 1, num_key_value_heads * head_dim]
+                scale, zero_point = self.get_qparams_along_dim(
+                    observed, observed.ndim - 1
+                )
+                # squeeze to shape [num_key_value_heads * head_dim]
+                self._scale = scale.squeeze()
+                self._zero_point = zero_point.squeeze()
+            else:
+                # weight or input: weights are (output_dim, input_dim),
+                # so compute one qparam per output channel along dim 0
+                self._scale, self._zero_point = self.get_qparams_along_dim(
+                    observed, 0
+                )
 
         elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
            # use dim 1, assume observed.shape = [batch, token, hidden]
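
For reference, a minimal sketch of the dim selection the new CHANNEL branch performs, assuming a plain asymmetric min-max observer; minmax_qparams below is a hypothetical stand-in for get_qparams_along_dim, whose real implementation lives in the observer base class:

import torch

def minmax_qparams(observed: torch.Tensor, dim: int, n_bits: int = 8):
    # hypothetical stand-in: reduce over every dimension except `dim`,
    # then derive an asymmetric scale/zero-point from the min/max range
    reduce_dims = tuple(d for d in range(observed.ndim) if d != dim)
    mins = torch.amin(observed, dim=reduce_dims, keepdim=True)
    maxs = torch.amax(observed, dim=reduce_dims, keepdim=True)
    scale = (maxs - mins).clamp(min=1e-8) / (2**n_bits - 1)
    zero_point = (-mins / scale).round().to(torch.int32)
    return scale, zero_point

# outputs: last dim is the hidden dim, one qparam per hidden channel
out = torch.randn(1, 1, 8 * 128)  # [batch, seq, num_key_value_heads * head_dim]
scale, _ = minmax_qparams(out, out.ndim - 1)
print(scale.squeeze().shape)  # torch.Size([1024])

# weights: (output_dim, input_dim), one qparam per output channel (dim 0)
w = torch.randn(4096, 1024)
scale, _ = minmax_qparams(w, 0)
print(scale.shape)  # torch.Size([4096, 1])

Indexing the last dimension rather than hard-coding dim 2 keeps the branch shape-agnostic, which appears to be what the deleted "why is the output of self._scale: [1, 1, 1]" note was chasing.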
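And a hypothetical call site, assuming a calibration hook that invokes the observer once per quantized tensor and threads that tensor's base name through the new keyword (observer and attn_output are illustrative names, not from this diff):

# "output" routes the observer into the per-hidden-dim branch above;
# weights and inputs keep the per-output-channel behavior along dim 0
scale, zero_point = observer(attn_output, g_idx=None, base_name="output")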