Skip to content

Commit 69bb9d6

Browse files
committed
Add optional noise embeddings during quantization
1 parent 5857ea9 commit 69bb9d6

File tree

5 files changed

+34
-8
lines changed

5 files changed

+34
-8
lines changed

exllamav2/architecture.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,9 @@ class Params:
241241
# Tensors are transposed in original model weights
242242
self.orig_weights_transposed = False
243243

244+
# Add noise rows to calibration while quantizing
245+
self.standard_calib_noise = None
246+
244247
# Mistral
245248

246249
if arch_string == "MistralForCausalLM":

exllamav2/conversion/convert_exl2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ def save_job():
240240
else:
241241

242242
print(f" -- Tokenizing samples (measurement)...")
243-
tokenize(job, save_job, tokenizer, measure = True)
243+
noise_rows = config.arch.standard_calib_noise
244+
tokenize(job, save_job, tokenizer, measure = True, noise_rows = noise_rows)
244245
job["progress"] = "initial_embeddings"
245246
save_job()
246247

@@ -285,7 +286,8 @@ def save_job():
285286
if progress == "tokens_cal":
286287

287288
print(f" -- Tokenizing samples...")
288-
tokenize(job, save_job, tokenizer)
289+
noise_rows = config.arch.standard_calib_noise
290+
tokenize(job, save_job, tokenizer, noise_rows = noise_rows)
289291
job["progress"] = "embeddings"
290292
save_job()
291293

exllamav2/conversion/measure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def embeddings(job, save_fn, model, measure = False):
8080

8181
module.load()
8282
input_ids[input_ids >= module.native_vocab_size] = 0
83-
hidden_state = module.forward(input_ids)
83+
hidden_state = module.forward(input_ids, negative_ids_noise = True)
8484
module.unload()
8585

8686
embeddings_dict = { f"row.{i:05}": hidden_state[i:i+1, :, :].contiguous() for i in range(hidden_state.shape[0]) }

exllamav2/conversion/tokenize.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_tokens(num_rows, length, filename, tokenizer):
3636
return all_tokens
3737

3838

39-
def tokenize(job, save_fn, tokenizer, measure = False):
39+
def tokenize(job, save_fn, tokenizer, measure = False, noise_rows = None):
4040

4141
print_stage(job, "Tokenizing (1)" if measure else "Tokenizing (2)", 0, 1)
4242

@@ -47,7 +47,7 @@ def tokenize(job, save_fn, tokenizer, measure = False):
4747
length = job["measurement_length"] if measure else job["length"]
4848
cal_tokens = get_tokens(rows, length, cal_ds, tokenizer)
4949
else:
50-
cal_tokens = get_standard_calibration(job, measure, tokenizer)
50+
cal_tokens = get_standard_calibration(job, measure, tokenizer, noise_rows)
5151
if measure:
5252
job["measurement_rows"] = cal_tokens.shape[0]
5353
else:
@@ -61,7 +61,7 @@ def tokenize(job, save_fn, tokenizer, measure = False):
6161
print_stage(job, "Tokenizing (1)" if measure else "Tokenizing (2)", 1, 1)
6262

6363

64-
def get_standard_calibration(job, measure, tokenizer):
64+
def get_standard_calibration(job, measure, tokenizer, noise_rows = None):
6565

6666
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "standard_cal_data")
6767
file_c4 =os.path.join(data_dir, "c4.utf8")
@@ -80,6 +80,10 @@ def get_standard_calibration(job, measure, tokenizer):
8080
rows_multilingual_s = 1 if measure else 5
8181
rows_technical = 2 if measure else 10
8282
rows_random = 2
83+
if noise_rows is not None:
84+
rows_noise = noise_rows[0] if measure else noise_rows[1]
85+
else:
86+
rows_noise = 0
8387

8488
ctx = min(2048, job["measurement_length"] if measure else job["length"])
8589

@@ -189,6 +193,11 @@ def get_standard_calibration(job, measure, tokenizer):
189193
for i in range(rows_technical):
190194
rows.append(tokenized_rows[i:i+1])
191195

196+
# Noise: rows_noise rows (0 unless the architecture defines standard_calib_noise)
197+
198+
for i in range(rows_noise):
199+
rows.append(torch.neg(torch.ones_like(rows[-1])))
200+
192201
# for idx, r in enumerate(rows):
193202
# print("------------------------------------------------------------------------------")
194203
# print(idx)

exllamav2/embedding.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,10 @@ def forward(
109109

110110
cfg = self.model.config
111111

112-
# If input IDs contain negative values, assume they are padding tokens from a model with not pad_token_id
113-
# defined
112+
# If input IDs contain negative values, assume they are padding tokens from a model with no pad_token_id
113+
# defined or noise values for quantizing
114114

115+
input_ids = hidden_states
115116
hidden_states = hidden_states.clamp(min = 0)
116117

117118
# Apply indexed embeddings
@@ -185,6 +186,17 @@ def forward(
185186
if self.archparams.normalize_embeddings:
186187
hidden_states *= cfg.hidden_size ** 0.5
187188

189+
# Negative tokens during quantization are noise tokens
190+
191+
if kwargs.get("negative_ids_noise"):
192+
mask = (input_ids < 0).unsqueeze(-1)
193+
unmasked_values = hidden_states[~mask.expand_as(hidden_states)].float()
194+
mean, std = unmasked_values.mean(), unmasked_values.std()
195+
noise = torch.randn_like(hidden_states, dtype = torch.float)
196+
noise = noise * std + mean
197+
noise = noise.half()
198+
hidden_states = torch.where(mask, noise, hidden_states)
199+
188200
# Move to pinned temp buffer for TP
189201

190202
if self.is_tp:

0 commit comments

Comments (0)