@@ -460,14 +460,6 @@ def __init__(
         devs = self.model.get_cache_devices() if self.fixed_device is None else [self.fixed_device]
         for device in devs: self.touch_device(device)
 
-        # Calibration mode
-
-        self.calibrated = False
-        self.calibrating = False
-        self.calibration_rows = [0] * cfg.num_hidden_layers
-        self.calibration_k = {}
-        self.calibration_v = {}
-
 
     def touch_device(self, device):
 
@@ -516,15 +508,9 @@ def get_kv_state(
             block_table if block_table is not None else none_tensor,
             # none_tensor,
             # none_tensor
-            self.calibration_k[layer_idx] if self.calibrated else none_tensor,
-            self.calibration_v[layer_idx] if self.calibrated else none_tensor,
             self.wbits
         )
 
-        # if self.calibrated:
-        #     temp_key_state *= self.calibration_k[layer_idx]
-        #     temp_value_state *= self.calibration_v[layer_idx]
-
         return temp_key_state, temp_value_state
 
 
@@ -551,10 +537,6 @@ def store_kv_state(
         device = self.model.cache_map.get(layer_idx, self.fixed_device)
         temp_key_state, temp_value_state = self.temp_tensors[device]
 
-        # if self.calibrated:
-        #     temp_key_state /= self.calibration_k[layer_idx]
-        #     temp_value_state /= self.calibration_v[layer_idx]
-
         ext_c.fp16_to_q_kv(
             temp_key_state,
             self.key_states[layer_idx],
@@ -570,40 +552,9 @@ def store_kv_state(
             block_table if block_table is not None else none_tensor,
             # none_tensor,
             # none_tensor
-            self.calibration_k[layer_idx] if self.calibrated else none_tensor,
-            self.calibration_v[layer_idx] if self.calibrated else none_tensor,
             self.wbits
         )
 
-        # Collect calibration data
-
-        if self.calibrating:
-
-            cfg = self.model.config
-
-            if layer_idx not in self.calibration_k:
-                self.calibration_k[layer_idx] = torch.zeros(
-                    (cfg.num_key_value_heads, cfg.head_dim,),
-                    dtype = torch.float,
-                    device = temp_key_state.device
-                )
-                self.calibration_v[layer_idx] = torch.zeros(
-                    (cfg.num_key_value_heads, cfg.head_dim,),
-                    dtype = torch.float,
-                    device = temp_key_state.device
-                )
-
-            b, l, h, d = temp_key_state.shape
-            cal_k = self.calibration_k[layer_idx]
-            cal_v = self.calibration_v[layer_idx]
-            cal_k_input = temp_key_state[:, offset:offset + width, :, :].view(b * width, h * d)
-            cal_v_input = temp_value_state[:, offset:offset + width, :, :].view(b * width, h * d)
-            cal_k_sum = torch.norm(cal_k_input, p = 1, dim = 0, dtype = torch.float)
-            cal_v_sum = torch.norm(cal_v_input, p = 1, dim = 0, dtype = torch.float)
-            cal_k.add_(cal_k_sum.view(h, d))
-            cal_v.add_(cal_v_sum.view(h, d))
-            self.calibration_rows[layer_idx] += width
-
 
     def footprint(self) -> list[int]:
 
@@ -623,57 +574,13 @@ def footprint(self) -> list[int]:
 
 
     def clone(self) -> ExLlamaV2Cache_Q4:
-
         new = ExLlamaV2Cache_Q4(self.model, self.batch_size, self.max_seq_len, self)
         return new
 
-
     def all_tensors(self):
         return self.key_states + self.value_states + self.key_scales + self.value_scales
 
 
-    def calibrate(self,
-        tokenizer: ExLlamaV2Tokenizer,
-        num_batches = 8,
-        num_samples_per_batch = 256
-    ):
-        """
-        Unfinished
-        """
-
-        assert self.max_seq_len >= num_samples_per_batch, \
-            f"Cache max_seq_len must be at least {num_samples_per_batch} to calibrate."
-
-        self.calibrating = True
-        torch.manual_seed(123)
-
-        for _ in range(num_batches):
-
-            input_ids = torch.randint(
-                low = 0,
-                high = tokenizer.get_vocab_size() - 1,
-                size = (1, num_samples_per_batch),
-                dtype = torch.long
-            )
-
-            self.reset()
-            self.model.forward(input_ids, preprocess_only = True, cache = self)
-
-        self.calibrating = False
-
-        for i in range(self.model.config.num_hidden_layers):
-            cal_k = self.calibration_k[i] / self.calibration_rows[i]  # self.calibration_k[i].mean()
-            cal_v = self.calibration_v[i] / self.calibration_rows[i]  # self.calibration_v[i].mean()
-            cal_k = cal_k ** (1 / 8)
-            cal_v = cal_v ** (1 / 8)
-            cal_k = cal_k.half() * (-1)
-            cal_v = cal_v.half() * (-1)
-            self.calibration_k[i] = cal_k
-            self.calibration_v[i] = cal_v
-        self.calibrating = False
-        # self.calibrated = True
-
-
 class ExLlamaV2Cache_Q4(ExLlamaV2Cache_Q):
 
     def __init__(
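For reference, the calibration scheme removed by this commit comes down to a small amount of tensor math: each `store_kv_state()` call accumulates a per-channel L1 norm of the keys and values being written, and `calibrate()` then averages those sums per layer and flattens them with an eighth root into per-channel fp16 scales. A minimal standalone sketch of that computation, assuming the `(batch, seq_len, num_kv_heads, head_dim)` layout implied by the deleted code; the function names here are illustrative, not exllamav2 API:

```python
import torch

def accumulate_calibration(temp_state: torch.Tensor, offset: int, width: int,
                           cal: torch.Tensor, rows: int) -> int:
    # temp_state: (batch, seq_len, num_kv_heads, head_dim) fp16 K or V block;
    # cal: (num_kv_heads, head_dim) fp32 accumulator, as in the deleted code
    b, l, h, d = temp_state.shape
    flat = temp_state[:, offset:offset + width, :, :].reshape(b * width, h * d)
    # Per-channel L1 norm over the rows written in this step, summed in fp32
    cal.add_(torch.norm(flat, p = 1, dim = 0, dtype = torch.float).view(h, d))
    return rows + width

def finalize_scales(cal: torch.Tensor, rows: int) -> torch.Tensor:
    # Mean per-channel L1 norm, compressed toward 1.0 by the eighth root;
    # the deleted code also multiplied by -1 before casting to fp16
    return ((cal / rows) ** (1 / 8)).half() * (-1)
```

Note that `self.calibrated` was never set to `True` anywhere in the removed path (the assignment stayed commented out and the docstring reads "Unfinished"), so the scales were collected but never applied when quantizing, which is presumably why the path is dropped here rather than completed.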