
Commit 8cbcd1c

add unittests
1 parent 2500606 commit 8cbcd1c

4 files changed: +204 -138 lines changed


src/layers/linear/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,8 @@ cc_test(
   NAME
     linear_test
   SRCS
-    linear_test.cpp
+    parallel_linear_test.cpp
+    multi_parallel_linear_test.cpp
     qkv_parallel_linear_test.cpp
   DEPS
     :linear
src/layers/linear/multi_parallel_linear_test.cpp

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
#include "multi_parallel_linear.h"

#include <c10/core/Device.h>
#include <c10/core/ScalarType.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <torch/torch.h>

#include <cstddef>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/HashStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>

#include "model_loader/state_dict.h"

namespace llm {

TEST(MultiParallelLinearTest, FusedColumnParallelLinear) {
  // test load state dict for linear
  const int64_t in_features = 10;
  const int64_t out_features = 40;

  torch::Device device(torch::kCPU);
  torch::ScalarType dtype(torch::kFloat);
  const auto options = torch::dtype(dtype).device(device);

  std::vector<int64_t> out_features_vec = {
      out_features, out_features, out_features};
  std::vector<std::string> prefixes = {"query.", "key.", "value."};

  std::unordered_map<std::string, torch::Tensor> state_dict_data;
  // Allocate transposed weight matrix
  state_dict_data["query.weight"] = torch::randn({out_features, in_features});
  state_dict_data["key.weight"] = torch::randn({out_features, in_features});
  state_dict_data["value.weight"] = torch::randn({out_features, in_features});

  // weight is not sharded
  StateDict state_dict(state_dict_data);

  // test load weight
  {
    ParallelArgs parallel_args(0, 1, nullptr);
    FusedColumnParallelLinearImpl linear(in_features,
                                         out_features_vec,
                                         prefixes,
                                         /*bias=*/false,
                                         /*gather_output=*/false,
                                         parallel_args,
                                         options);
    // test load fused weight
    EXPECT_EQ(linear.load(state_dict), 3);

    for (const auto& prefix : prefixes) {
      auto named_parameters = linear.named_parameters(/*recurse=*/false);
      const auto key = detail::join_name(prefix, "weight");
      ASSERT_TRUE(named_parameters.contains(key));

      const auto& loaded_weight = named_parameters[key];
      EXPECT_EQ(loaded_weight.sizes(),
                torch::IntArrayRef({out_features, in_features}));
      EXPECT_TRUE(torch::equal(loaded_weight, state_dict_data[key]));
    }

    // verify the fused weight
    const auto loaded_fused_weight = linear.weight();
    const auto desired_fused_weight =
        torch::cat({state_dict_data["query.weight"],
                    state_dict_data["key.weight"],
                    state_dict_data["value.weight"]},
                   /*dim=*/0);
    EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight));
  }

  // test load weight with 4 shards
  const int32_t num_shards = 4;
  for (int32_t shard_id = 0; shard_id < num_shards; ++shard_id) {
    ParallelArgs parallel_args_0(shard_id, num_shards, nullptr);
    FusedColumnParallelLinearImpl linear(in_features,
                                         out_features_vec,
                                         prefixes,
                                         /*bias=*/false,
                                         /*gather_output=*/false,
                                         parallel_args_0,
                                         options);
    EXPECT_EQ(linear.load(state_dict), 3);

    auto named_parameters = linear.named_parameters(/*recurse=*/false);

    // check size for each prefix
    for (const auto& prefix : prefixes) {
      const auto key = detail::join_name(prefix, "weight");
      ASSERT_TRUE(named_parameters.contains(key));

      const auto& loaded_weight = named_parameters[key];
      EXPECT_EQ(loaded_weight.sizes(),
                torch::IntArrayRef({out_features / num_shards, in_features}));
      EXPECT_TRUE(torch::equal(
          loaded_weight, state_dict_data[key].chunk(num_shards, 0)[shard_id]));
    }

    // shard weight then cat
    auto sharded_query_weight =
        state_dict_data["query.weight"].chunk(num_shards, 0)[shard_id];
    auto sharded_key_weight =
        state_dict_data["key.weight"].chunk(num_shards, 0)[shard_id];
    auto sharded_value_weight =
        state_dict_data["value.weight"].chunk(num_shards, 0)[shard_id];

    // verify the fused weight
    const auto loaded_fused_weight = linear.weight();
    auto desired_fused_weight = torch::cat(
        {sharded_query_weight, sharded_key_weight, sharded_value_weight},
        /*dim=*/0);

    EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight));
  }
}

TEST(MultiParallelLinearTest, GroupedColumnParallelLinear) {
  const int64_t in_features = 10;
  const int64_t out_features = 40;
  std::vector<int64_t> out_features_vec = {
      out_features, out_features, out_features};
  std::vector<std::string> prefixes = {"query.", "key.", "value."};

  torch::Device device(torch::kCPU);
  torch::ScalarType dtype(torch::kFloat);
  const auto options = torch::dtype(dtype).device(device);

  std::unordered_map<std::string, torch::Tensor> state_dict_data;
  // Allocate transposed weight matrix
  state_dict_data["query.weight"] = torch::randn({out_features, in_features});
  state_dict_data["key.weight"] = torch::randn({out_features, in_features});
  state_dict_data["value.weight"] = torch::randn({out_features, in_features});
  // weight is not sharded
  StateDict state_dict(state_dict_data);

  // test load weight
  {
    ParallelArgs parallel_args(0, 1, nullptr);
    GroupedColumnParallelLinearImpl linear(in_features,
                                           out_features_vec,
                                           prefixes,
                                           /*bias=*/false,
                                           /*gather_output=*/false,
                                           parallel_args,
                                           options);
    // test load grouped weight
    EXPECT_EQ(linear.load(state_dict), 3);

    auto named_parameters = linear.named_parameters(/*recurse=*/true);
    for (size_t i = 0; i < prefixes.size(); ++i) {
      const auto prefix = "linear_" + std::to_string(i) + "." + prefixes[i];
      const auto key = detail::join_name(prefix, "weight");
      ASSERT_TRUE(named_parameters.contains(key));

      const auto& loaded_weight = named_parameters[key];

      const auto sd_key = detail::join_name(prefixes[i], "weight");

      EXPECT_EQ(loaded_weight.sizes(),
                torch::IntArrayRef({out_features, in_features}));
      EXPECT_TRUE(torch::equal(loaded_weight, state_dict_data[sd_key]));
    }
  }

  // test load weight with 4 shards
  const int32_t num_shards = 4;
  for (int32_t shard_id = 0; shard_id < num_shards; ++shard_id) {
    ParallelArgs parallel_args(shard_id, num_shards, nullptr);
    GroupedColumnParallelLinearImpl linear(in_features,
                                           out_features_vec,
                                           prefixes,
                                           /*bias=*/false,
                                           /*gather_output=*/false,
                                           parallel_args,
                                           options);
    EXPECT_EQ(linear.load(state_dict), 3);
    auto named_parameters = linear.named_parameters(/*recurse=*/true);
    // check size for each prefix
    for (size_t i = 0; i < prefixes.size(); ++i) {
      const auto prefix = "linear_" + std::to_string(i) + "." + prefixes[i];
      const auto key = detail::join_name(prefix, "weight");
      ASSERT_TRUE(named_parameters.contains(key));

      const auto& loaded_weight = named_parameters[key];
      EXPECT_EQ(loaded_weight.sizes(),
                torch::IntArrayRef({out_features / num_shards, in_features}));
      const auto sd_key = detail::join_name(prefixes[i], "weight");
      EXPECT_TRUE(
          torch::equal(loaded_weight,
                       state_dict_data[sd_key].chunk(num_shards, 0)[shard_id]));
    }
  }
}

}  // namespace llm
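The fused test's per-shard assertions reduce to a small tensor identity: on shard shard_id, the layer's fused weight should equal the concatenation, along dim 0, of shard shard_id of each per-prefix weight, so every shard ends up with out_features / num_shards rows per prefix. A standalone libtorch sketch of that check (illustrative only; it uses plain tensors rather than the layer classes touched by this commit):

#include <torch/torch.h>

#include <cassert>
#include <cstdint>

int main() {
  const int64_t in_features = 10;
  const int64_t out_features = 40;
  const int32_t num_shards = 4;

  // Unsharded per-prefix weights, mirroring the test's state dict entries.
  auto query = torch::randn({out_features, in_features});
  auto key = torch::randn({out_features, in_features});
  auto value = torch::randn({out_features, in_features});

  for (int32_t shard_id = 0; shard_id < num_shards; ++shard_id) {
    // Shard each weight along the output dimension (dim 0), then fuse the
    // shards, the same way the test builds desired_fused_weight.
    auto fused_shard = torch::cat({query.chunk(num_shards, 0)[shard_id],
                                   key.chunk(num_shards, 0)[shard_id],
                                   value.chunk(num_shards, 0)[shard_id]},
                                  /*dim=*/0);
    // Each shard carries out_features / num_shards rows per prefix.
    assert(fused_shard.size(0) == 3 * out_features / num_shards);
    assert(fused_shard.size(1) == in_features);
  }
  return 0;
}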

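The grouped test, by contrast, looks up parameters under keys such as linear_0.query.weight, which matches how libtorch composes names for registered children: anything registered via register_module("linear_0", ...) contributes a "linear_0." prefix in named_parameters(/*recurse=*/true). A minimal sketch of that naming behavior, assuming a made-up GroupedLinearImpl module rather than the GroupedColumnParallelLinearImpl under test:

#include <torch/torch.h>

#include <iostream>
#include <string>
#include <vector>

// Made-up module: one child linear per group, registered as "linear_<i>",
// so recursive parameter names pick up the "linear_<i>." prefix.
struct GroupedLinearImpl : torch::nn::Module {
  GroupedLinearImpl(int64_t in_features, int64_t out_features, size_t n) {
    for (size_t i = 0; i < n; ++i) {
      auto layer = torch::nn::Linear(
          torch::nn::LinearOptions(in_features, out_features).bias(false));
      register_module("linear_" + std::to_string(i), layer);
      layers_.push_back(layer);
    }
  }
  std::vector<torch::nn::Linear> layers_;
};

int main() {
  GroupedLinearImpl grouped(10, 40, 3);
  // Prints "linear_0.weight", "linear_1.weight", "linear_2.weight",
  // each with shape [40, 10].
  for (const auto& p : grouped.named_parameters(/*recurse=*/true)) {
    std::cout << p.key() << " " << p.value().sizes() << std::endl;
  }
  return 0;
}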
src/layers/linear/parallel_linear.cpp

Lines changed: 0 additions & 31 deletions
@@ -215,37 +215,6 @@ std::shared_ptr<ParallelLinearImpl> create_row_parallel_linear(
       parallel_args,
       options);
 }
-
-// std::shared_ptr<MultiParallelLinearImpl> create_multi_column_parallel_linear(
-//     int64_t in_features,
-//     const std::vector<int64_t>& out_features,
-//     const std::vector<std::string>& prefixes,
-//     bool bias,
-//     bool gather_output,
-//     const QuantArgs& quant_args,
-//     const ParallelArgs& parallel_args,
-//     const torch::TensorOptions& options) {
-//   // check if the linear layers can be fused
-//   const bool fused = quant_args.can_be_fused();
-//   std::shared_ptr<MultiParallelLinearImpl> impl;
-//   if (fused) {
-//     return std::make_shared<FusedColumnParallelLinearImpl>(in_features,
-//                                                            out_features,
-//                                                            prefixes,
-//                                                            bias,
-//                                                            gather_output,
-//                                                            parallel_args,
-//                                                            options);
-//   }
-
-//   return std::make_shared<GroupedColumnParallelLinearImpl>(in_features,
-//                                                            out_features,
-//                                                            prefixes,
-//                                                            bias,
-//                                                            gather_output,
-//                                                            parallel_args,
-//                                                            options);
-// }
 }  // namespace
 
 // Linear layer with column parallelism.
src/layers/linear/parallel_linear_test.cpp

Lines changed: 5 additions & 106 deletions
@@ -1,20 +1,21 @@
+#include "parallel_linear.h"
+
 #include <c10/core/Device.h>
 #include <c10/core/ScalarType.h>
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 #include <torch/torch.h>
 
+#include <cstddef>
 #include <torch/csrc/distributed/c10d/FileStore.hpp>
 #include <torch/csrc/distributed/c10d/HashStore.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
 
 #include "model_loader/state_dict.h"
-#include "multi_parallel_linear.h"
-#include "parallel_linear.h"
 
 namespace llm {
 
-TEST(LinearTest, RowParallelLoadWeight) {
+TEST(ParallelLinearTest, RowParallelLinear) {
   // test load state dict for row parallel linear
   const int64_t in_features = 10;
   const int64_t out_features = 20;
@@ -79,7 +80,7 @@ TEST(LinearTest, RowParallelLoadWeight) {
   }
 }
 
-TEST(LinearTest, ColumnParallelLoadWeight) {
+TEST(ParallelLinearTest, ColumnParallelLinear) {
   // test load state dict for linear
   const int64_t in_features = 10;
   const int64_t out_features = 20;
@@ -136,106 +137,4 @@ TEST(LinearTest, ColumnParallelLoadWeight) {
   }
 }
 
-TEST(LinearTest, ColumnParallelLoadFusedWeight) {
-  // test load state dict for linear
-  const int64_t in_features = 10;
-  const int64_t out_features = 40;
-
-  torch::Device device(torch::kCPU);
-  torch::ScalarType dtype(torch::kFloat);
-  const auto options = torch::dtype(dtype).device(device);
-
-  std::vector<int64_t> out_features_vec = {
-      out_features, out_features, out_features};
-  std::vector<std::string> prefixes = {"query.", "key.", "value."};
-
-  std::unordered_map<std::string, torch::Tensor> state_dict_data;
-  // Allocate transposed weight matrix
-  state_dict_data["query.weight"] = torch::randn({out_features, in_features});
-  state_dict_data["key.weight"] = torch::randn({out_features, in_features});
-  state_dict_data["value.weight"] = torch::randn({out_features, in_features});
-
-  // weight is not sharded
-  StateDict state_dict(state_dict_data);
-
-  // test load weight
-  {
-    ParallelArgs parallel_args(0, 1, nullptr);
-    FusedColumnParallelLinearImpl linear(in_features,
-                                         out_features_vec,
-                                         prefixes,
-                                         /*bias=*/false,
-                                         /*gather_output=*/false,
-                                         parallel_args,
-                                         options);
-    // test load fused weight
-    EXPECT_EQ(linear.load(state_dict), 3);
-
-    for (const auto& prefix : prefixes) {
-      auto named_parameters = linear.named_parameters(/*recurse=*/false);
-      const auto key = detail::join_name(prefix, "weight");
-      ASSERT_TRUE(named_parameters.contains(key));
-
-      const auto& loaded_weight = named_parameters[key];
-      EXPECT_EQ(loaded_weight.sizes(),
-                torch::IntArrayRef({out_features, in_features}));
-      EXPECT_TRUE(torch::equal(loaded_weight, state_dict_data[key]));
-    }
-
-    // verify the fused weight
-    const auto loaded_fused_weight = linear.weight();
-    const auto desired_fused_weight =
-        torch::cat({state_dict_data["query.weight"],
-                    state_dict_data["key.weight"],
-                    state_dict_data["value.weight"]},
-                   /*dim=*/0);
-    EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight));
-  }
-
-  // test load weight with 4 shards
-  const int32_t num_shards = 4;
-  for (int32_t shard_id = 0; shard_id < num_shards; ++shard_id) {
-    ParallelArgs parallel_args_0(shard_id, num_shards, nullptr);
-    FusedColumnParallelLinearImpl linear(in_features,
-                                         out_features_vec,
-                                         prefixes,
-                                         /*bias=*/false,
-                                         /*gather_output=*/false,
-                                         parallel_args_0,
-                                         options);
-    EXPECT_EQ(linear.load(state_dict), 3);
-
-    auto named_parameters = linear.named_parameters(/*recurse=*/false);
-
-    // check size for each prefix
-    for (const auto& prefix : prefixes) {
-      auto named_parameters = linear.named_parameters(/*recurse=*/false);
-      const auto key = detail::join_name(prefix, "weight");
-      ASSERT_TRUE(named_parameters.contains(key));
-
-      const auto& loaded_weight = named_parameters[key];
-      EXPECT_EQ(loaded_weight.sizes(),
-                torch::IntArrayRef({out_features / num_shards, in_features}));
-      EXPECT_TRUE(torch::equal(
-          loaded_weight, state_dict_data[key].chunk(num_shards, 0)[shard_id]));
-    }
-
-    // shard weight then cat
-    auto sharded_query_weight =
-        state_dict_data["query.weight"].chunk(num_shards, 0)[shard_id];
-    auto sharded_key_weight =
-        state_dict_data["key.weight"].chunk(num_shards, 0)[shard_id];
-    auto sharded_value_weight =
-        state_dict_data["value.weight"].chunk(num_shards, 0)[shard_id];
-
-    // verify the fused weight
-    const auto loaded_fused_weight = linear.weight();
-    auto desired_fused_weight = torch::cat(
-        {sharded_query_weight, sharded_key_weight, sharded_value_weight},
-        /*dim=*/0);
-
-    EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight));
-  }
-}
-
 }  // namespace llm
