Commit 60fb4f3

[Bugfix] Add kv cache scales to gemma2.py (#11269)
1 parent 63afbe9

1 file changed (+17, −1)

vllm/model_executor/models/gemma2.py

Lines changed: 17 additions & 1 deletion
@@ -31,11 +31,14 @@
                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    get_compressed_tensors_cache_scale)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
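
For reference, a minimal sketch of the contract get_compressed_tensors_cache_scale is expected to satisfy: given a checkpoint weight name, return the matching vLLM kv cache scale parameter name, or None when the name is not a kv cache scale. The exact suffixes vLLM matches may differ; the ".output_scale" names below are assumptions for illustration.

from typing import Optional

def get_compressed_tensors_cache_scale_sketch(name: str) -> Optional[str]:
    # Sketch only: map a compressed-tensors checkpoint scale name onto the
    # attention module's kv cache scale parameter. The ".output_scale"
    # suffixes here are assumed, not taken from vLLM's implementation.
    if name.endswith(".k_proj.output_scale"):
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".v_proj.output_scale"):
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    return None  # Not a kv cache scale name.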

@@ -326,6 +329,15 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
+            if scale_name := get_compressed_tensors_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = loaded_weight[0]
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue
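
The new branch short-circuits the usual stacked-parameter mapping: a recognized kv cache scale is loaded directly and the loop moves on. Note loaded_weight[0], which unwraps the scalar from the one-element tensor stored in the checkpoint. A self-contained sketch of that control flow, with hypothetical names and a stubbed-out loader (nothing here is vLLM's actual implementation):

from typing import Optional

import torch

def cache_scale_name(name: str) -> Optional[str]:
    # Hypothetical matcher (see the sketch above); k_proj scale only.
    if name.endswith(".k_proj.output_scale"):
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    return None

def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    # Simplified stand-in for vLLM's default loader: copy in place.
    param.data.copy_(loaded_weight)

# Hypothetical parameter and checkpoint names, for illustration only.
params_dict = {"model.layers.0.self_attn.attn.k_scale": torch.zeros(())}
weights = [("model.layers.0.self_attn.k_proj.output_scale",
            torch.tensor([0.02]))]

for name, loaded_weight in weights:
    if scale_name := cache_scale_name(name):
        param = params_dict[scale_name]
        # The checkpoint stores the scale as a 1-element tensor; take the
        # scalar before handing it to the loader.
        default_weight_loader(param, loaded_weight[0])
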
@@ -343,6 +355,10 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
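
maybe_remap_kv_scale_name covers the complementary case: FP8 checkpoints whose k/v scale names do not line up with vLLM's parameter names. Returning None tells the loop to skip a scale the model has no parameter for. A hedged sketch of that contract; vLLM's actual remapping rules may differ:

from typing import Dict, Optional

import torch

def maybe_remap_kv_scale_name_sketch(
        name: str, params_dict: Dict[str, torch.Tensor]) -> Optional[str]:
    # Sketch only: move an FP8 ".k_scale"/".v_scale" suffix under the
    # attention module, e.g. "...self_attn.k_scale" ->
    # "...self_attn.attn.k_scale".
    for suffix in (".k_scale", ".v_scale"):
        if name.endswith(suffix):
            remapped = name.replace(suffix, f".attn{suffix}")
            # None means: skip, the model has no parameter for this scale.
            return remapped if remapped in params_dict else None
    return name  # Not a kv scale; leave the name unchanged.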
