@@ -89,23 +89,34 @@ def create_weights(self, layer: torch.nn.Module,
8989 requires_grad = False )
9090
9191 layer .register_parameter ("weight" , weight )
92- set_weight_attrs (weight , {"input_dim" : 1 , "output_dim" : 0 })
93-
94- set_weight_attrs (weight , {"weight_loader" : weight_loader })
95-
92+ set_weight_attrs (weight , {
93+ "weight_loader" : weight_loader ,
94+ "input_dim" : 1 ,
95+ "output_dim" : 0 ,
96+ })
9697 layer .register_parameter ("input_scale" , input_scale )
97- set_weight_attrs (input_scale , {"weight_loader" : weight_loader })
98+ set_weight_attrs (input_scale , {
99+ "weight_loader" : weight_loader ,
100+ "ignore_warning" : True ,
101+ })
98102 layer .register_parameter ("input_zero_point" , input_zero_point )
99- set_weight_attrs (input_zero_point , {"weight_loader" : weight_loader })
103+ set_weight_attrs (input_zero_point , {
104+ "weight_loader" : weight_loader ,
105+ "ignore_warning" : True ,
106+ })
100107 layer .register_parameter ("weight_scale" , weight_scale )
101- set_weight_attrs (weight_scale , {"weight_loader" : weight_loader })
102108 set_weight_attrs (
103109 weight_scale , {
110+ "weight_loader" : weight_loader ,
104111 "shard_splitter" : self .scales_shard_splitter ,
105- "logical_widths" : output_partition_sizes
112+ "logical_widths" : output_partition_sizes ,
113+ "ignore_warning" : True ,
106114 })
107115 layer .register_parameter ("weight_zero_point" , weight_zero_point )
108- set_weight_attrs (weight_zero_point , {"weight_loader" : weight_loader })
116+ set_weight_attrs (weight_zero_point , {
117+ "weight_loader" : weight_loader ,
118+ "ignore_warning" : True
119+ })
109120
110121 def apply_weights (self , layer : torch .nn .Module , x : torch .Tensor ):
111122 weight = layer .weight
0 commit comments