@@ -186,16 +186,40 @@ def forward(
186
186
if self .archparams .normalize_embeddings :
187
187
hidden_states *= cfg .hidden_size ** 0.5
188
188
189
- # Negative tokens during quantization are noise tokens
189
+ # Rows with negative tokens during quantization are noise tokens
190
190
191
191
if kwargs .get ("negative_ids_noise" ):
192
- mask = (input_ids < 0 ).unsqueeze (- 1 )
193
- unmasked_values = hidden_states [~ mask .expand_as (hidden_states )].float ()
194
- mean , std = unmasked_values .mean (), unmasked_values .std ()
195
- noise = torch .randn_like (hidden_states , dtype = torch .float )
196
- noise = noise * std + mean
197
- noise = noise .half ()
198
- hidden_states = torch .where (mask , noise , hidden_states )
192
+
193
+ n = 0
194
+ mean = torch .tensor ([0.0 ], dtype = torch .float , device = hidden_states .device )
195
+ M2 = torch .tensor ([0.0 ], dtype = torch .float , device = hidden_states .device )
196
+
197
+ for i in range (input_ids .shape [0 ]):
198
+ if input_ids [i ][0 ] < 0 :
199
+ continue
200
+
201
+ er = hidden_states [i ].float ()
202
+ n += er .numel ()
203
+ delta = er - mean
204
+ mean += delta .sum () / n
205
+ delta2 = er - mean
206
+ M2 += (delta * delta2 ).sum ()
207
+ del er
208
+ del delta
209
+ del delta2
210
+
211
+ if n > 1 :
212
+ std = torch .sqrt (M2 / (n - 1 ))
213
+
214
+ for i in range (input_ids .shape [0 ]):
215
+ if input_ids [i ][0 ] >= 0 :
216
+ continue
217
+
218
+ er = hidden_states [i ]
219
+ noise = torch .randn (er .size (), dtype = torch .float , device = hidden_states .device ) * std + mean
220
+ er .copy_ (noise .half ())
221
+ del er
222
+ del noise
199
223
200
224
# Move to pinned temp buffer for TP
201
225
0 commit comments