
Commit 40b4284

[Bugfix] Handle process_weights_after_loading for QKVCrossParallelLinear (#15328)
Signed-off-by: Isotr0py <[email protected]>
1 parent 4ebc0b9 · commit 40b4284

File tree

3 files changed: +33 -6 lines changed


vllm/model_executor/layers/linear.py

Lines changed: 21 additions & 6 deletions
@@ -1353,6 +1353,7 @@ def __init__(self,
                                            prefix=f"{prefix}.kv_proj_encoder")
 
         # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.q_size = self.q_proj_decoder.output_size_per_partition
         self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
 
         if bias:
@@ -1364,20 +1365,31 @@ def __init__(self,
         else:
             self.bias = None
 
+    def process_weights_after_loading(self):
+        for layer in self.proj.values():
+            if self.quant_method is not None:
+                self.quant_method.process_weights_after_loading(layer)
+
     @property
     def q_proj_decoder(self) -> ColumnParallelLinear:
         layer = self.proj["q_proj_decoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="q_proj_decoder")
         return layer
 
     @property
     def kv_proj_encoder(self) -> QKVParallelLinear:
         layer = self.proj["kv_proj_encoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="kv_proj_encoder")
         return layer
 
     def sync_weight_attrs(
@@ -1466,11 +1478,14 @@ def weight_loader(self,
                  if loaded_shard_id == "q" else self.kv_proj_encoder)
         target_param = self.select_proj_params(layer, param)
         shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
-        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
+        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
+            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
+        else:
+            layer.weight_loader(target_param, loaded_weight, *shard_id_args)
 
     def extra_repr(self) -> str:
         s = f"in_features={self.input_size}"
-        s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
+        s += f", q_size={self.q_size}"
         s += f", kv_size={self.kv_size}"
         s += f", bias={self.bias is not None}"
         s += f", tp_size={get_tensor_model_parallel_world_size()}"

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 3 additions & 0 deletions
@@ -254,6 +254,7 @@ def create_weights(
                 weight_loader=weight_loader,
             )
             scale[:] = torch.finfo(torch.float32).min
+            set_weight_attrs(scale, {"scale_type": "weight_scale"})
             layer.register_parameter("weight_scale", scale)
         else:
             assert self.quant_config.activation_scheme == "dynamic"
@@ -268,6 +269,7 @@ def create_weights(
                 weight_loader=weight_loader,
             )
             scale[:] = torch.finfo(torch.float32).min
+            set_weight_attrs(scale, {"scale_type": "weight_scale"})
             # The weight_scale_inv name is intentional for deepseekv3
             layer.register_parameter("weight_scale_inv", scale)
 
@@ -278,6 +280,7 @@ def create_weights(
                 weight_loader=weight_loader)
 
             scale[:] = torch.finfo(torch.float32).min
+            set_weight_attrs(scale, {"scale_type": "input_scale"})
             layer.register_parameter("input_scale", scale)
         else:
             layer.register_parameter("input_scale", None)

vllm/model_executor/model_loader/loader.py

Lines changed: 9 additions & 0 deletions
@@ -33,11 +33,15 @@
     get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.model_executor.layers.linear import (LinearBase,
                                                MergedColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
+# yapf: enable
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase)
 from vllm.model_executor.model_loader.tensorizer import (
@@ -160,6 +164,11 @@ def _initialize_model(
 def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
                                    target_device: torch.device) -> None:
     for _, module in model.named_modules():
+        if isinstance(module, QKVCrossParallelLinear):
+            # NOTE(Isotr0py): special case for cross QKV layer because
+            # q and kv proj aren't registered as submodules intentionally
+            module.process_weights_after_loading()
+            continue
         quant_method = getattr(module, "quant_method", None)
         if isinstance(quant_method, QuantizeMethodBase):
             # When quant methods need to process weights after loading
