Skip to content

Commit f2c53ef

Browse files
committed
Remove (experimental) Q-cache calibration feature
1 parent a029bcd commit f2c53ef

File tree

7 files changed

+25
-197
lines changed

7 files changed

+25
-197
lines changed

exllamav2/cache.py

Lines changed: 0 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -460,14 +460,6 @@ def __init__(
460460
devs = self.model.get_cache_devices() if self.fixed_device is None else [self.fixed_device]
461461
for device in devs: self.touch_device(device)
462462

463-
# Calibration mode
464-
465-
self.calibrated = False
466-
self.calibrating = False
467-
self.calibration_rows = [0] * cfg.num_hidden_layers
468-
self.calibration_k = {}
469-
self.calibration_v = {}
470-
471463

472464
def touch_device(self, device):
473465

@@ -516,15 +508,9 @@ def get_kv_state(
516508
block_table if block_table is not None else none_tensor,
517509
# none_tensor,
518510
# none_tensor
519-
self.calibration_k[layer_idx] if self.calibrated else none_tensor,
520-
self.calibration_v[layer_idx] if self.calibrated else none_tensor,
521511
self.wbits
522512
)
523513

524-
# if self.calibrated:
525-
# temp_key_state *= self.calibration_k[layer_idx]
526-
# temp_value_state *= self.calibration_v[layer_idx]
527-
528514
return temp_key_state, temp_value_state
529515

530516

@@ -551,10 +537,6 @@ def store_kv_state(
551537
device = self.model.cache_map.get(layer_idx, self.fixed_device)
552538
temp_key_state, temp_value_state = self.temp_tensors[device]
553539

554-
# if self.calibrated:
555-
# temp_key_state /= self.calibration_k[layer_idx]
556-
# temp_value_state /= self.calibration_v[layer_idx]
557-
558540
ext_c.fp16_to_q_kv(
559541
temp_key_state,
560542
self.key_states[layer_idx],
@@ -570,40 +552,9 @@ def store_kv_state(
570552
block_table if block_table is not None else none_tensor,
571553
# none_tensor,
572554
# none_tensor
573-
self.calibration_k[layer_idx] if self.calibrated else none_tensor,
574-
self.calibration_v[layer_idx] if self.calibrated else none_tensor,
575555
self.wbits
576556
)
577557

578-
# Collect calibration data
579-
580-
if self.calibrating:
581-
582-
cfg = self.model.config
583-
584-
if layer_idx not in self.calibration_k:
585-
self.calibration_k[layer_idx] = torch.zeros(
586-
(cfg.num_key_value_heads, cfg.head_dim,),
587-
dtype = torch.float,
588-
device = temp_key_state.device
589-
)
590-
self.calibration_v[layer_idx] = torch.zeros(
591-
(cfg.num_key_value_heads, cfg.head_dim,),
592-
dtype = torch.float,
593-
device = temp_key_state.device
594-
)
595-
596-
b, l, h, d = temp_key_state.shape
597-
cal_k = self.calibration_k[layer_idx]
598-
cal_v = self.calibration_v[layer_idx]
599-
cal_k_input = temp_key_state[:, offset:offset+width, :, :].view(b * width, h * d)
600-
cal_v_input = temp_value_state[:, offset:offset+width, :, :].view(b * width, h * d)
601-
cal_k_sum = torch.norm(cal_k_input, p = 1, dim = 0, dtype = torch.float)
602-
cal_v_sum = torch.norm(cal_v_input, p = 1, dim = 0, dtype = torch.float)
603-
cal_k.add_(cal_k_sum.view(h, d))
604-
cal_v.add_(cal_v_sum.view(h, d))
605-
self.calibration_rows[layer_idx] += width
606-
607558

608559
def footprint(self) -> list[int]:
609560

@@ -623,57 +574,13 @@ def footprint(self) -> list[int]:
623574

624575

625576
def clone(self) -> ExLlamaV2Cache_Q4:
626-
627577
new = ExLlamaV2Cache_Q4(self.model, self.batch_size, self.max_seq_len, self)
628578
return new
629579

630-
631580
def all_tensors(self):
632581
return self.key_states + self.value_states + self.key_scales + self.value_scales
633582

634583

635-
def calibrate(self,
636-
tokenizer: ExLlamaV2Tokenizer,
637-
num_batches = 8,
638-
num_samples_per_batch = 256
639-
):
640-
"""
641-
Unfinished
642-
"""
643-
644-
assert self.max_seq_len >= num_samples_per_batch, \
645-
f"Cache max_seq_len must be at least {num_samples_per_batch} to calibrate."
646-
647-
self.calibrating = True
648-
torch.manual_seed(123)
649-
650-
for _ in range(num_batches):
651-
652-
input_ids = torch.randint(
653-
low = 0,
654-
high = tokenizer.get_vocab_size() - 1,
655-
size = (1, num_samples_per_batch),
656-
dtype = torch.long
657-
)
658-
659-
self.reset()
660-
self.model.forward(input_ids, preprocess_only = True, cache = self)
661-
662-
self.calibrating = False
663-
664-
for i in range(self.model.config.num_hidden_layers):
665-
cal_k = self.calibration_k[i] / self.calibration_rows[i] # self.calibration_k[i].mean()
666-
cal_v = self.calibration_v[i] / self.calibration_rows[i] # self.calibration_v[i].mean()
667-
cal_k = cal_k ** (1/8)
668-
cal_v = cal_v ** (1/8)
669-
cal_k = cal_k.half() * (-1)
670-
cal_v = cal_v.half() * (-1)
671-
self.calibration_k[i] = cal_k
672-
self.calibration_v[i] = cal_v
673-
self.calibrating = False
674-
# self.calibrated = True
675-
676-
677584
class ExLlamaV2Cache_Q4(ExLlamaV2Cache_Q):
678585

679586
def __init__(

0 commit comments

Comments (0)