3131 RowParallelLinear )
3232from vllm .model_executor .layers .logits_processor import LogitsProcessor
3333from vllm .model_executor .layers .quantization import QuantizationConfig
34+ from vllm .model_executor .layers .quantization .compressed_tensors .utils import (
35+ get_compressed_tensors_cache_scale )
3436from vllm .model_executor .layers .rotary_embedding import get_rope
3537from vllm .model_executor .layers .sampler import SamplerOutput , get_sampler
3638from vllm .model_executor .layers .vocab_parallel_embedding import (
3739 VocabParallelEmbedding )
38- from vllm .model_executor .model_loader .weight_utils import default_weight_loader
40+ from vllm .model_executor .model_loader .weight_utils import (
41+ default_weight_loader , maybe_remap_kv_scale_name )
3942from vllm .model_executor .sampling_metadata import SamplingMetadata
4043from vllm .sequence import IntermediateTensors
4144
@@ -326,6 +329,15 @@ def load_weights(self, weights: Iterable[Tuple[str,
326329 params_dict = dict (self .named_parameters ())
327330 loaded_params : Set [str ] = set ()
328331 for name , loaded_weight in weights :
332+ if scale_name := get_compressed_tensors_cache_scale (name ):
333+ # Loading kv cache scales for compressed-tensors quantization
334+ param = params_dict [scale_name ]
335+ weight_loader = getattr (param , "weight_loader" ,
336+ default_weight_loader )
337+ loaded_weight = loaded_weight [0 ]
338+ weight_loader (param , loaded_weight )
339+ loaded_params .add (scale_name )
340+ continue
329341 for (param_name , shard_name , shard_id ) in stacked_params_mapping :
330342 if shard_name not in name :
331343 continue
@@ -343,6 +355,10 @@ def load_weights(self, weights: Iterable[Tuple[str,
343355 # Skip loading extra bias for GPTQ models.
344356 if name .endswith (".bias" ) and name not in params_dict :
345357 continue
358+ # Remapping the name of FP8 kv-scale.
359+ name = maybe_remap_kv_scale_name (name , params_dict )
360+ if name is None :
361+ continue
346362 if is_pp_missing_parameter (name , self ):
347363 continue
348364 param = params_dict [name ]
0 commit comments