@@ -453,25 +453,30 @@ def __init__(
         context_dim: int,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
         self.proj = ColumnParallelLinear(self.hidden_size,
                                          self.hidden_size,
                                          bias=bias,
-                                         gather_output=True)
+                                         gather_output=True,
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.proj")
         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             context_dim,
             self.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.act_fn = SiluAndMul()
         self.extra_activation_func = nn.GELU()
@@ -661,6 +666,7 @@ def __init__(
             context_dim=vision_config.intermediate_size,
             quant_config=quant_config,
             bias=False,
+            prefix=f"{prefix}.merger",
         )
         self.embeddings = Glm4vVisionEmbeddings(vision_config)