@@ -1353,6 +1353,7 @@ def __init__(self,
                                  prefix=f"{prefix}.kv_proj_encoder")
 
         # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.q_size = self.q_proj_decoder.output_size_per_partition
         self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
 
         if bias:
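Note: the added `self.q_size` caches the decoder query projection's per-partition width at construction time. `q_proj_decoder` is a property that walks `named_parameters()` and re-syncs weight attributes on every access (next hunk), so reading the cached value later in `extra_repr` sidesteps that work. A minimal sketch of the quantity being cached, assuming the usual column-parallel split where the output dimension divides evenly across ranks (numbers illustrative only):

    # Illustrative only: a ColumnParallelLinear splits its output dimension
    # across tp_size ranks, so each rank holds this many output columns.
    tp_size = 2
    total_num_heads, head_size = 16, 64
    q_size = (total_num_heads * head_size) // tp_size  # 512 per rank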
@@ -1364,20 +1365,31 @@ def __init__(self,
         else:
             self.bias = None
 
+    def process_weights_after_loading(self):
+        for layer in self.proj.values():
+            if self.quant_method is not None:
+                self.quant_method.process_weights_after_loading(layer)
+
     @property
     def q_proj_decoder(self) -> ColumnParallelLinear:
         layer = self.proj["q_proj_decoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="q_proj_decoder")
         return layer
 
     @property
     def kv_proj_encoder(self) -> QKVParallelLinear:
         layer = self.proj["kv_proj_encoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="kv_proj_encoder")
         return layer
 
     def sync_weight_attrs(
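Note: switching from `getattr(layer, name)` to `getattr(layer, name, None)` makes the attribute sync tolerant of parameters that exist on the wrapper but not on the target sub-layer, which can happen once a quant method's `process_weights_after_loading` hook (added above) replaces or registers parameters. A self-contained demonstration of the three-argument form, unrelated to any vLLM type:

    # getattr with a default returns None instead of raising AttributeError.
    class Layer:
        pass

    print(getattr(Layer(), "weight_scale", None))  # -> None, no exception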
@@ -1466,11 +1478,14 @@ def weight_loader(self,
                  if loaded_shard_id == "q" else self.kv_proj_encoder)
         target_param = self.select_proj_params(layer, param)
         shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
-        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
+        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
+            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
+        else:
+            layer.weight_loader(target_param, loaded_weight, *shard_id_args)
 
     def extra_repr(self) -> str:
         s = f"in_features={self.input_size}"
-        s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
+        s += f", q_size={self.q_size}"
         s += f", kv_size={self.kv_size}"
         s += f", bias={self.bias is not None}"
         s += f", tp_size={get_tensor_model_parallel_world_size()}"
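Note: `WEIGHT_LOADER_V2_SUPPORTED` lists the class names of quant methods whose parameters go through the newer `weight_loader_v2` path, so the branch keys on `self.quant_method.__class__.__name__` rather than on the layer type. The empty `shard_id_args` for `"q"` reflects the two sub-layers' differing loader signatures: the decoder's `ColumnParallelLinear` loader takes no shard id, while the encoder's `QKVParallelLinear` loader expects `"k"` or `"v"`. A self-contained sketch of the name-based dispatch pattern (all names hypothetical, not vLLM APIs):

    # Hypothetical illustration of dispatching on a quant method's class name.
    SUPPORTED_V2 = ["FooQuantMethod"]

    class FooQuantMethod:
        pass

    def pick_loader(quant_method, v1, v2):
        # Mirror the membership test used in the diff above.
        return v2 if quant_method.__class__.__name__ in SUPPORTED_V2 else v1

    loader = pick_loader(FooQuantMethod(),
                         v1=lambda *args: print("v1", args),
                         v2=lambda *args: print("v2", args))
    loader("param", "weight")  # -> v2 ('param', 'weight')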