@@ -1353,6 +1353,7 @@ def __init__(self,
                                             prefix=f"{prefix}.kv_proj_encoder")
 
         # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.q_size = self.q_proj_decoder.output_size_per_partition
         self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
 
         if bias:
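Note on the cached `q_size`: every read of the `q_proj_decoder` property re-runs the attribute-sync loop over `named_parameters()`, so computing the size once in `__init__` lets `extra_repr` (last hunk below) read a plain attribute instead. A minimal, self-contained sketch of the pattern; the class and numbers are hypothetical, only the caching mirrors the diff:

```python
class ReprExample:
    def __init__(self, num_heads: int = 32, head_size: int = 128) -> None:
        # Computed once at construction, like `self.q_size = ...` above,
        # instead of going through a property with per-access side work.
        self.q_size = num_heads * head_size

    def extra_repr(self) -> str:
        return f"q_size={self.q_size}"

print(ReprExample().extra_repr())  # q_size=4096
```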
@@ -1364,20 +1365,31 @@ def __init__(self,
         else:
             self.bias = None
 
+    def process_weights_after_loading(self):
+        for layer in self.proj.values():
+            if self.quant_method is not None:
+                self.quant_method.process_weights_after_loading(layer)
+
     @property
     def q_proj_decoder(self) -> ColumnParallelLinear:
         layer = self.proj["q_proj_decoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="q_proj_decoder")
         return layer
 
     @property
     def kv_proj_encoder(self) -> QKVParallelLinear:
         layer = self.proj["kv_proj_encoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="kv_proj_encoder")
         return layer
 
     def sync_weight_attrs(
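The new `process_weights_after_loading` hook and the `getattr(layer, name, None)` guard go together: once a quant method repacks a sub-layer's weights, some parameter names tracked on the wrapper may no longer exist on the sub-layer, so the sync loop has to tolerate misses. A minimal sketch of how a loader might drive the hook, assuming a `torch.nn.Module` tree; `finalize_weights` and the traversal are illustrative, not vLLM's actual loader code:

```python
import torch.nn as nn

def finalize_weights(model: nn.Module) -> None:
    # After all checkpoint shards are loaded, give each layer a chance to
    # repack/transform its weights (e.g. into a quant kernel's layout).
    for module in model.modules():
        if hasattr(module, "process_weights_after_loading"):
            module.process_weights_after_loading()
```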
@@ -1466,11 +1478,14 @@ def weight_loader(self,
                  if loaded_shard_id == "q" else self.kv_proj_encoder)
         target_param = self.select_proj_params(layer, param)
         shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
-        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
+        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
+            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
+        else:
+            layer.weight_loader(target_param, loaded_weight, *shard_id_args)
 
     def extra_repr(self) -> str:
         s = f"in_features={self.input_size}"
-        s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
+        s += f", q_size={self.q_size}"
         s += f", kv_size={self.kv_size}"
         s += f", bias={self.bias is not None}"
         s += f", tp_size={get_tensor_model_parallel_world_size()}"