@@ -117,37 +117,55 @@ ColumnParallelQLinearImpl::ColumnParallelQLinearImpl(
117117 quant_args.group_size () > 0 ? quant_args.group_size () : in_features;
118118 CHECK (qweight_pack_dim == 0 || qweight_pack_dim == 1 )
119119 << " qweight_pack_dim must be 0 or 1" ;
120- const int64_t world_size = parallel_args.world_size ();
120+ const auto rank = parallel_args_.rank ();
121+ const auto world_size = parallel_args_.world_size ();
121122 CHECK (out_features % world_size == 0 )
122123 << " out_features " << out_features << " not divisible by world_size "
123124 << world_size;
124125 const int64_t out_features_per_partition = out_features / world_size;
125126 const int64_t pack_factor = 32 / bits;
126127
127128 if (qweight_pack_dim == 0 ) {
128- qweight_ = register_parameter (
129+ qweight_ = register_sharded_parameter (
129130 " qweight" ,
131+ /* dim=*/ 1 ,
132+ rank,
133+ world_size,
130134 torch::empty ({in_features / pack_factor, out_features_per_partition},
131135 options.dtype (torch::kInt32 )));
132136 } else {
133- qweight_ = register_parameter (
137+ qweight_ = register_sharded_parameter (
134138 " qweight" ,
139+ /* dim=*/ 1 ,
140+ rank,
141+ world_size,
135142 torch::empty ({in_features, out_features_per_partition / pack_factor},
136143 options.dtype (torch::kInt32 )));
137144 }
138- qzeros_ = register_parameter (
145+ qzeros_ = register_sharded_parameter (
139146 " qzeros" ,
147+ /* dim=*/ 1 ,
148+ rank,
149+ world_size,
140150 torch::empty ({round_up (in_features, group_size),
141151 out_features_per_partition / pack_factor},
142152 options.dtype (torch::kInt32 )));
143153
144- scales_ = register_parameter (" scales" ,
145- torch::empty ({round_up (in_features, group_size),
146- out_features_per_partition},
147- options));
154+ scales_ = register_sharded_parameter (
155+ " scales" ,
156+ /* dim=*/ 1 ,
157+ rank,
158+ world_size,
159+ torch::empty (
160+ {round_up (in_features, group_size), out_features_per_partition},
161+ options));
148162 if (bias) {
149- bias_ = register_parameter (
150- " bias" , torch::empty ({out_features_per_partition}, options));
163+ bias_ = register_sharded_parameter (
164+ " bias" ,
165+ /* dim=*/ 0 ,
166+ rank,
167+ world_size,
168+ torch::empty ({out_features_per_partition}, options));
151169 }
152170}
153171
@@ -226,7 +244,8 @@ RowParallelQLinearImpl::RowParallelQLinearImpl(
226244 const auto bits = quant_args.bits ();
227245 CHECK (qweight_pack_dim == 0 || qweight_pack_dim == 1 )
228246 << " qweight_pack_dim must be 0 or 1" ;
229- const int64_t world_size = parallel_args.world_size ();
247+ const auto rank = parallel_args_.rank ();
248+ const auto world_size = parallel_args_.world_size ();
230249 CHECK (in_features % world_size == 0 )
231250 << " in_features " << in_features << " not divisible by world_size "
232251 << world_size;
@@ -236,24 +255,36 @@ RowParallelQLinearImpl::RowParallelQLinearImpl(
236255 quant_args.group_size () > 0 ? quant_args.group_size () : in_features;
237256
238257 if (qweight_pack_dim == 0 ) {
239- qweight_ = register_parameter (
258+ qweight_ = register_sharded_parameter (
240259 " qweight" ,
260+ /* dim=*/ 0 ,
261+ rank,
262+ world_size,
241263 torch::empty ({in_features_per_partition / pack_factor, out_features},
242264 options.dtype (torch::kInt32 )));
243265 } else {
244- qweight_ = register_parameter (
266+ qweight_ = register_sharded_parameter (
245267 " qweight" ,
268+ /* dim=*/ 0 ,
269+ rank,
270+ world_size,
246271 torch::empty ({in_features_per_partition, out_features / pack_factor},
247272 options.dtype (torch::kInt32 )));
248273 }
249- qzeros_ = register_parameter (
274+ qzeros_ = register_sharded_parameter (
250275 " qzeros" ,
276+ /* dim=*/ 0 ,
277+ rank,
278+ world_size,
251279 torch::empty ({round_up (in_features_per_partition, group_size),
252280 out_features / pack_factor},
253281 options.dtype (torch::kInt32 )));
254282
255- scales_ = register_parameter (
283+ scales_ = register_sharded_parameter (
256284 " scales" ,
285+ /* dim=*/ 0 ,
286+ rank,
287+ world_size,
257288 torch::empty (
258289 {round_up (in_features_per_partition, group_size), out_features},
259290 options));
0 commit comments