Skip to content

Commit c55656c

Browse files
committed
Fix system RAM consumption while quantizing, fixes #692
1 parent c86f62c commit c55656c

File tree

1 file changed

+32
-8
lines changed

1 file changed

+32
-8
lines changed

exllamav2/embedding.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -186,16 +186,40 @@ def forward(
186186
if self.archparams.normalize_embeddings:
187187
hidden_states *= cfg.hidden_size ** 0.5
188188

189-
# Negative tokens during quantization are noise tokens
189+
# Rows with negative tokens during quantization are noise tokens
190190

191191
if kwargs.get("negative_ids_noise"):
192-
mask = (input_ids < 0).unsqueeze(-1)
193-
unmasked_values = hidden_states[~mask.expand_as(hidden_states)].float()
194-
mean, std = unmasked_values.mean(), unmasked_values.std()
195-
noise = torch.randn_like(hidden_states, dtype = torch.float)
196-
noise = noise * std + mean
197-
noise = noise.half()
198-
hidden_states = torch.where(mask, noise, hidden_states)
192+
193+
n = 0
194+
mean = torch.tensor([0.0], dtype = torch.float, device = hidden_states.device)
195+
M2 = torch.tensor([0.0], dtype = torch.float, device = hidden_states.device)
196+
197+
for i in range(input_ids.shape[0]):
198+
if input_ids[i][0] < 0:
199+
continue
200+
201+
er = hidden_states[i].float()
202+
n += er.numel()
203+
delta = er - mean
204+
mean += delta.sum() / n
205+
delta2 = er - mean
206+
M2 += (delta * delta2).sum()
207+
del er
208+
del delta
209+
del delta2
210+
211+
if n > 1:
212+
std = torch.sqrt(M2 / (n - 1))
213+
214+
for i in range(input_ids.shape[0]):
215+
if input_ids[i][0] >= 0:
216+
continue
217+
218+
er = hidden_states[i]
219+
noise = torch.randn(er.size(), dtype = torch.float, device = hidden_states.device) * std + mean
220+
er.copy_(noise.half())
221+
del er
222+
del noise
199223

200224
# Move to pinned temp buffer for TP
201225

0 commit comments

Comments
 (0)