Skip to content

Commit b4cc8f2

Browse files
committed
update qlinear
1 parent 2972034 commit b4cc8f2

File tree

2 files changed

+50
-19
lines changed

2 files changed

+50
-19
lines changed

src/quantization/qlinear_impl.cpp

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -117,37 +117,55 @@ ColumnParallelQLinearImpl::ColumnParallelQLinearImpl(
117117
quant_args.group_size() > 0 ? quant_args.group_size() : in_features;
118118
CHECK(qweight_pack_dim == 0 || qweight_pack_dim == 1)
119119
<< "qweight_pack_dim must be 0 or 1";
120-
const int64_t world_size = parallel_args.world_size();
120+
const auto rank = parallel_args_.rank();
121+
const auto world_size = parallel_args_.world_size();
121122
CHECK(out_features % world_size == 0)
122123
<< "out_features " << out_features << " not divisible by world_size "
123124
<< world_size;
124125
const int64_t out_features_per_partition = out_features / world_size;
125126
const int64_t pack_factor = 32 / bits;
126127

127128
if (qweight_pack_dim == 0) {
128-
qweight_ = register_parameter(
129+
qweight_ = register_sharded_parameter(
129130
"qweight",
131+
/*dim=*/1,
132+
rank,
133+
world_size,
130134
torch::empty({in_features / pack_factor, out_features_per_partition},
131135
options.dtype(torch::kInt32)));
132136
} else {
133-
qweight_ = register_parameter(
137+
qweight_ = register_sharded_parameter(
134138
"qweight",
139+
/*dim=*/1,
140+
rank,
141+
world_size,
135142
torch::empty({in_features, out_features_per_partition / pack_factor},
136143
options.dtype(torch::kInt32)));
137144
}
138-
qzeros_ = register_parameter(
145+
qzeros_ = register_sharded_parameter(
139146
"qzeros",
147+
/*dim=*/1,
148+
rank,
149+
world_size,
140150
torch::empty({round_up(in_features, group_size),
141151
out_features_per_partition / pack_factor},
142152
options.dtype(torch::kInt32)));
143153

144-
scales_ = register_parameter("scales",
145-
torch::empty({round_up(in_features, group_size),
146-
out_features_per_partition},
147-
options));
154+
scales_ = register_sharded_parameter(
155+
"scales",
156+
/*dim=*/1,
157+
rank,
158+
world_size,
159+
torch::empty(
160+
{round_up(in_features, group_size), out_features_per_partition},
161+
options));
148162
if (bias) {
149-
bias_ = register_parameter(
150-
"bias", torch::empty({out_features_per_partition}, options));
163+
bias_ = register_sharded_parameter(
164+
"bias",
165+
/*dim=*/0,
166+
rank,
167+
world_size,
168+
torch::empty({out_features_per_partition}, options));
151169
}
152170
}
153171

@@ -226,7 +244,8 @@ RowParallelQLinearImpl::RowParallelQLinearImpl(
226244
const auto bits = quant_args.bits();
227245
CHECK(qweight_pack_dim == 0 || qweight_pack_dim == 1)
228246
<< "qweight_pack_dim must be 0 or 1";
229-
const int64_t world_size = parallel_args.world_size();
247+
const auto rank = parallel_args_.rank();
248+
const auto world_size = parallel_args_.world_size();
230249
CHECK(in_features % world_size == 0)
231250
<< "in_features " << in_features << " not divisible by world_size "
232251
<< world_size;
@@ -236,24 +255,36 @@ RowParallelQLinearImpl::RowParallelQLinearImpl(
236255
quant_args.group_size() > 0 ? quant_args.group_size() : in_features;
237256

238257
if (qweight_pack_dim == 0) {
239-
qweight_ = register_parameter(
258+
qweight_ = register_sharded_parameter(
240259
"qweight",
260+
/*dim=*/0,
261+
rank,
262+
world_size,
241263
torch::empty({in_features_per_partition / pack_factor, out_features},
242264
options.dtype(torch::kInt32)));
243265
} else {
244-
qweight_ = register_parameter(
266+
qweight_ = register_sharded_parameter(
245267
"qweight",
268+
/*dim=*/0,
269+
rank,
270+
world_size,
246271
torch::empty({in_features_per_partition, out_features / pack_factor},
247272
options.dtype(torch::kInt32)));
248273
}
249-
qzeros_ = register_parameter(
274+
qzeros_ = register_sharded_parameter(
250275
"qzeros",
276+
/*dim=*/0,
277+
rank,
278+
world_size,
251279
torch::empty({round_up(in_features_per_partition, group_size),
252280
out_features / pack_factor},
253281
options.dtype(torch::kInt32)));
254282

255-
scales_ = register_parameter(
283+
scales_ = register_sharded_parameter(
256284
"scales",
285+
/*dim=*/0,
286+
rank,
287+
world_size,
257288
torch::empty(
258289
{round_up(in_features_per_partition, group_size), out_features},
259290
options));

src/quantization/qlinear_impl_test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ TEST(QlinearTest, ColumnParallelQuantLinear) {
4646
/*bits=*/4);
4747
weights = weights.to(torch::kCUDA);
4848

49-
qlinear.load_state_dict(*state_dict);
50-
qlinear.verify_loaded_weights();
49+
qlinear.load(*state_dict);
50+
EXPECT_TRUE(qlinear.verify());
5151

5252
auto input = torch::rand({40960, in_features}, options);
5353
auto output = qlinear.forward(input);
@@ -83,8 +83,8 @@ TEST(QlinearTest, RowParallelQuantLinear) {
8383
/*bits=*/4);
8484
weights = weights.to(torch::kCUDA);
8585

86-
qlinear.load_state_dict(*state_dict);
87-
qlinear.verify_loaded_weights();
86+
qlinear.load(*state_dict);
87+
EXPECT_TRUE(qlinear.verify());
8888

8989
auto input = torch::rand({40960, in_features}, options);
9090
auto output = qlinear.forward(input);

0 commit comments

Comments
 (0)