Commit 94e04bc

feat: add state dict load and verify for module (#502)
Issues to fix:
- [ ] Qwen with multiple GPUs generates garbage.
- [ ] Qwen2 generates '!!!!!'.
- [ ] Phi2 is out of date.
Parent: 6ddb3bd


42 files changed: +645, -1026 lines

README.md

Lines changed: 2 additions & 11 deletions
@@ -32,11 +32,12 @@ ScaleLLM
 
 <div align="left">
 
-[ScaleLLM](#) is a cutting-edge inference system engineered for large language models (LLMs), designed to meet the demands of production environments. It extends its support to a wide range of popular open-source models, including [Llama3.1](https://github.com/meta-llama/llama3), [Gemma2](https://github.com/google-deepmind/gemma), Bloom, GPT-NeoX, and more.
+[ScaleLLM](#) is a cutting-edge inference system engineered for large language models (LLMs), designed to meet the demands of production environments. It extends its support to a wide range of popular open-source models, including [Llama3.1](https://github.com/meta-llama/llama3), [Gemma2](https://github.com/google-deepmind/gemma), [Phi](https://huggingface.co/microsoft/phi-2), and more.
 
 ScaleLLM is currently undergoing active development. We are fully committed to consistently enhancing its efficiency while also incorporating additional features. Feel free to explore our [**_Roadmap_**](https://github.com/vectorch-ai/ScaleLLM/issues/84) for more details.
 
 ## News:
+* [01/2025] - Optimized in-house attention kernels
 * [06/2024] - ScaleLLM is now available on [PyPI](https://pypi.org/project/scalellm/). You can install it using `pip install scalellm`.
 * [03/2024] - [Advanced features](#advanced-features) support for ✅ [CUDA graph](#cuda-graph), ✅ [prefix cache](#prefix-cache), ✅ [chunked prefill](#chunked-prefill) and ✅ [speculative decoding](#speculative-decoding).
 * [11/2023] - [First release](https://github.com/vectorch-ai/ScaleLLM/releases/tag/v0.0.1) with support for popular [open-source models](#supported-models).
@@ -274,21 +275,11 @@ Quantization is a crucial process for reducing the memory footprint of models. S
 
 | Models | Tensor Parallel | Quantization | Chat API | HF models examples |
 | :--------: | :-------------: | :----------: | :------: | :---------------------------:|
-| Aquila | Yes | Yes | Yes | [BAAI/Aquila-7B](https://huggingface.co/BAAI/Aquila-7B), [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) |
-| Bloom | Yes | Yes | No | [bigscience/bloom](https://huggingface.co/bigscience/bloom) |
-| Baichuan | Yes | Yes | Yes | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) |
-| ChatGLM4/3 | Yes | Yes | Yes | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
 | Gemma2 | Yes | Yes | Yes | [google/gemma-2-2b](https://huggingface.co/google/gemma-2-2b) |
-| GPT_j | Yes | Yes | No | [EleutherAI/gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b) |
-| GPT_NeoX | Yes | Yes | No | [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b) |
 | GPT2 | Yes | Yes | No | [gpt2](https://huggingface.co/gpt2)|
-| InternLM | Yes | Yes | Yes | [internlm/internlm-7b](https://huggingface.co/internlm/internlm-7b) |
 | Llama3/2 | Yes | Yes | Yes | [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct), [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) |
-| Mistral | Yes | Yes | Yes | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) |
-| MPT | Yes | Yes | Yes | [mosaicml/mpt-30b](https://huggingface.co/mosaicml/mpt-30b) |
 | Phi2 | Yes | Yes | No | [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) |
 | Qwen2 | Yes | Yes | Yes | [Qwen/Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat) |
-| Yi | Yes | Yes | Yes |[01-ai/Yi-6B](https://huggingface.co/01-ai/Yi-6B), [01-ai/Yi-34B-Chat-4bits](https://huggingface.co/01-ai/Yi-34B-Chat-4bits), [01-ai/Yi-6B-200K](https://huggingface.co/01-ai/Yi-6B-200K) |
 
 If your model is not included in the supported list, we are more than willing to assist you. Please feel free to create a request for adding a new model on [GitHub Issues](https://github.com/vectorch-ai/ScaleLLM/issues).

src/layers/embedding.h

Lines changed: 14 additions & 73 deletions
@@ -33,26 +33,6 @@ class EmbeddingImpl : public Module {
     return F::embedding(input, weight_);
   }
 
-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    const auto weight = state_dict.get_tensor("weight");
-    if (weight.defined()) {
-      CHECK_EQ(weight_.sizes(), weight.sizes())
-          << "weight size mismatch for " << name();
-      weight_.copy_(weight);
-      is_loaded_ = true;
-    }
-  }
-
-  // whether the weight is loaded
-  void verify_loaded_weights(const std::string& prefix) const {
-    CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
-  }
-
-  void pretty_print(std::ostream& stream) const override {
-    stream << name() << " " << weight_.sizes() << " " << weight_.device();
-  }
-
   // return the weight (for testing)
   torch::Tensor weight() const { return weight_; }

@@ -73,15 +53,19 @@ class ParallelEmbeddingImpl : public Module {
                        const ParallelArgs& parallel_args,
                        const torch::TensorOptions& options)
       : parallel_args_(parallel_args) {
+    const auto rank = parallel_args_.rank();
     const auto world_size = parallel_args_.world_size();
     CHECK(embedding_dim % world_size == 0)
         << "out_features " << embedding_dim << " not divisible by world_size "
        << world_size;
     const int64_t embedding_dim_per_partition = embedding_dim / world_size;
 
     // register the weight parameter
-    weight_ = register_parameter(
+    weight_ = register_sharded_parameter(
         "weight",
+        /*dim=*/1,
+        rank,
+        world_size,
         torch::empty({num_embeddings, embedding_dim_per_partition}, options));
   }

@@ -96,30 +80,6 @@ class ParallelEmbeddingImpl : public Module {
     return output;
   }
 
-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    const auto weight = state_dict.get_sharded_tensor(
-        "weight",
-        /*dim=*/1,
-        /*rank=*/parallel_args_.rank(),
-        /*world_size=*/parallel_args_.world_size());
-    if (weight.defined()) {
-      CHECK_EQ(weight_.sizes(), weight.sizes())
-          << "weight size mismatch for " << name();
-      weight_.copy_(weight);
-      is_loaded_ = true;
-    }
-  }
-
-  // whether the weight is loaded
-  void verify_loaded_weights(const std::string& prefix) const {
-    CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
-  }
-
-  void pretty_print(std::ostream& stream) const override {
-    stream << name() << " " << weight_.sizes() << " " << weight_.device();
-  }
-
   // return the weight (for testing)
   torch::Tensor weight() const { return weight_; }

@@ -143,14 +103,19 @@ class VocabParallelEmbeddingImpl : public Module {
                             const ParallelArgs& parallel_args,
                             const torch::TensorOptions& options)
       : parallel_args_(parallel_args) {
-    const int64_t num_embeddings_per_partition =
-        num_embeddings / parallel_args_.world_size();
-    start_index_ = num_embeddings_per_partition * parallel_args_.rank();
+    const auto rank = parallel_args_.rank();
+    const auto world_size = parallel_args_.world_size();
+
+    const int64_t num_embeddings_per_partition = num_embeddings / world_size;
+    start_index_ = num_embeddings_per_partition * rank;
     end_index_ = start_index_ + num_embeddings_per_partition;
 
     // register the weight parameter
-    weight_ = register_parameter(
+    weight_ = register_sharded_parameter(
         "weight",
+        /*dim=*/0,
+        rank,
+        world_size,
         torch::empty({num_embeddings_per_partition, embedding_dim}, options));
   }

@@ -174,30 +139,6 @@ class VocabParallelEmbeddingImpl : public Module {
     return reduce_from_model_parallel_region(output, parallel_args_);
   }
 
-  // load the weight from the checkpoint
-  void load_state_dict(const StateDict& state_dict) {
-    const auto weight = state_dict.get_sharded_tensor(
-        "weight",
-        /*dim=*/0,
-        /*rank=*/parallel_args_.rank(),
-        /*world_size=*/parallel_args_.world_size());
-    if (weight.defined()) {
-      CHECK_EQ(weight_.sizes(), weight.sizes())
-          << "weight size mismatch for " << name();
-      weight_.copy_(weight);
-      is_loaded_ = true;
-    }
-  }
-
-  // whether the weight is loaded
-  void verify_loaded_weights(const std::string& prefix = "") const {
-    CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
-  }
-
-  void pretty_print(std::ostream& stream) const override {
-    stream << name() << " " << weight_.sizes() << " " << weight_.device();
-  }
-
   // return the weight (for testing)
   torch::Tensor weight() const { return weight_; }
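
The change is identical across all three embedding variants: the hand-written load_state_dict / verify_loaded_weights / pretty_print boilerplate is deleted, and the shard metadata (dim, rank, world_size) moves into register_sharded_parameter so the base Module can load and verify every parameter generically. The Module side of that contract is not part of this diff, so the sketch below is only a guess at the mechanism; SketchModule, params_, and specs_ are assumed names, while StateDict::get_sharded_tensor and CHECK_EQ both appear in the removed code above.

#include <cstdint>
#include <string>
#include <unordered_map>
#include <torch/torch.h>
// StateDict comes from the repo's model loader; CHECK_EQ from glog.

struct ShardSpec {
  int64_t dim;         // dimension the parameter is sharded along
  int32_t rank;        // this process's shard index
  int32_t world_size;  // total number of shards
};

class SketchModule {
 public:
  // Record the tensor together with how it is sharded across ranks.
  torch::Tensor register_sharded_parameter(const std::string& name,
                                           int64_t dim,
                                           int32_t rank,
                                           int32_t world_size,
                                           torch::Tensor param) {
    params_[name] = param;
    specs_[name] = ShardSpec{dim, rank, world_size};
    return param;
  }

  // Generic loader: each parameter carries its own shard spec, so no
  // per-module load_state_dict override is needed any more.
  size_t load(const StateDict& state_dict) {
    size_t loaded = 0;
    for (auto& [name, param] : params_) {
      const auto& spec = specs_.at(name);
      const auto tensor = state_dict.get_sharded_tensor(
          name, spec.dim, spec.rank, spec.world_size);
      if (tensor.defined()) {
        CHECK_EQ(param.sizes(), tensor.sizes())
            << "weight size mismatch for " << name;
        param.copy_(tensor);
        ++loaded;
      }
    }
    return loaded;
  }

 private:
  std::unordered_map<std::string, torch::Tensor> params_;
  std::unordered_map<std::string, ShardSpec> specs_;
};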

src/layers/fused_linear.cpp

Lines changed: 13 additions & 10 deletions
@@ -13,11 +13,13 @@ namespace llm {
 FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl(
     int64_t in_features,
     const std::vector<int64_t>& out_features_vec,
+    const std::vector<std::string>& prefixes,
     bool bias,
     bool gather_output,
     const QuantArgs& quant_args,
     const ParallelArgs& parallel_args,
     const torch::TensorOptions& options) {
+  prefixes_ = prefixes;
   // check if the linear layers can be fused
   fused_ = quant_args.can_be_fused();
   if (fused_) {

@@ -72,28 +74,29 @@ std::vector<torch::Tensor> FusedColumnParallelLinearImpl::forward(
   return outputs;
 }
 
-void FusedColumnParallelLinearImpl::load_state_dict(
-    const StateDict& state_dict,
-    const std::vector<std::string>& prefixes) {
+size_t FusedColumnParallelLinearImpl::load(const StateDict& state_dict,
+                                           const std::string&) {
   if (fused_) {
-    fused_linear_->load_state_dict(state_dict, prefixes);
+    fused_linear_->load_state_dict(state_dict, prefixes_);
   } else {
-    CHECK_EQ(parallel_linears_.size(), prefixes.size());
+    CHECK_EQ(parallel_linears_.size(), prefixes_.size());
     for (size_t i = 0; i < parallel_linears_.size(); ++i) {
-      parallel_linears_[i]->load_state_dict(state_dict.select(prefixes[i]));
+      parallel_linears_[i]->load_state_dict(state_dict.select(prefixes_[i]));
     }
   }
+  return 0;
 }
 
-void FusedColumnParallelLinearImpl::verify_loaded_weights(
-    const std::string& prefix) const {
+bool FusedColumnParallelLinearImpl::verify(
+    const std::string& name_prefix) const {
   if (fused_) {
-    fused_linear_->verify_loaded_weights(prefix);
+    fused_linear_->verify_loaded_weights(name_prefix);
   } else {
     for (const auto& parallel_linear : parallel_linears_) {
-      parallel_linear->verify_loaded_weights(prefix);
+      parallel_linear->verify_loaded_weights(name_prefix);
     }
   }
+  return true;
 }
 
 } // namespace llm
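
Because the new Module::load interface takes only a StateDict and a name prefix, the per-branch prefixes that load_state_dict used to receive as an argument are now captured at construction time instead. A hypothetical construction site, assuming a q/k/v projection fused into a single GEMM; the dimensions, prefix strings, and the in-scope quant_args / parallel_args / options variables are illustrative, not taken from this diff.

// Hypothetical usage; every argument value below is illustrative.
FusedColumnParallelLinearImpl qkv_proj(
    /*in_features=*/4096,
    /*out_features_vec=*/{4096, 1024, 1024},         // q, k, v widths
    /*prefixes=*/{"q_proj.", "k_proj.", "v_proj."},  // captured for load()
    /*bias=*/false,
    /*gather_output=*/false,
    quant_args,
    parallel_args,
    options);

// The loader can now drive the fused layer like any other module:
qkv_proj.load(state_dict.select("self_attn."));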

src/layers/fused_linear.h

Lines changed: 9 additions & 4 deletions
@@ -16,6 +16,7 @@ class FusedColumnParallelLinearImpl : public Module {
  public:
  FusedColumnParallelLinearImpl(int64_t in_features,
                                const std::vector<int64_t>& out_features,
+                               const std::vector<std::string>& prefixes,
                                bool bias,
                                bool gather_output,
                                const QuantArgs& quant_args,

@@ -24,11 +25,13 @@ class FusedColumnParallelLinearImpl : public Module {
 
   std::vector<torch::Tensor> forward(torch::Tensor input);
 
-  // load_state_dict for fused weights
-  void load_state_dict(const StateDict& state_dict,
-                       const std::vector<std::string>& prefixes);
+  // load weights from the checkpoint, override this method if necessary
+  // returns the number of loaded parameters
+  size_t load(const StateDict& state_dict,
+              const std::string& name_prefix = std::string()) override;
 
-  void verify_loaded_weights(const std::string& prefix = "") const;
+  // verify whether the weights are loaded, override this method if necessary
+  bool verify(const std::string& name_prefix = std::string()) const override;
 
   // whether the linear layer is fused
   bool fused() const { return fused_; }

@@ -43,6 +46,8 @@ class FusedColumnParallelLinearImpl : public Module {
   // sizes for each split
   std::vector<int64_t> split_sizes_;
 
+  std::vector<std::string> prefixes_;
+
   // whether the linear layer is fused
   bool fused_ = false;
 };
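
The load/verify pair is the uniform per-module contract this commit introduces: load returns the number of parameters it consumed and verify reports whether everything arrived, so a checkpoint loader can walk the module tree without knowing any layer's internals. A minimal sketch of such a driver, assuming a named_children() traversal helper that this diff does not show.

// Sketch of a generic loading driver built on the load()/verify() pair.
// named_children() is an assumed helper, not part of this diff.
size_t load_recursively(Module& module,
                        const StateDict& state_dict,
                        const std::string& prefix) {
  size_t loaded = module.load(state_dict, prefix);
  for (auto& [name, child] : module.named_children()) {
    loaded += load_recursively(
        *child, state_dict.select(name + "."), prefix + name + ".");
  }
  CHECK(module.verify(prefix)) << "weights not fully loaded under " << prefix;
  return loaded;
}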

src/layers/linear_impl.cpp

Lines changed: 16 additions & 4 deletions
@@ -18,6 +18,7 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(
     const ParallelArgs& parallel_args,
     const torch::TensorOptions& options)
     : gather_output_(gather_output), parallel_args_(parallel_args) {
+  const auto rank = parallel_args_.rank();
   const auto world_size = parallel_args_.world_size();
   CHECK(out_features % world_size == 0)
       << "out_features " << out_features << " not divisible by world_size "

@@ -26,13 +27,20 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(
 
   // Note: torch.nn.functional.linear performs XA^T + b and as a result
   // we allocate the transpose.
-  weight_ = register_parameter(
+  weight_ = register_sharded_parameter(
       "weight",
+      /*dim=*/0,
+      rank,
+      world_size,
       torch::empty({out_features_per_partition, in_features}, options));
 
   if (bias) {
-    bias_ = register_parameter(
-        "bias", torch::empty({out_features_per_partition}, options));
+    bias_ = register_sharded_parameter(
+        "bias",
+        /*dim=*/0,
+        rank,
+        world_size,
+        torch::empty({out_features_per_partition}, options));
   }
 }

@@ -93,14 +101,18 @@ RowParallelLinearImpl::RowParallelLinearImpl(
     const torch::TensorOptions& options)
     : input_is_parallelized_(input_is_parallelized),
       parallel_args_(parallel_args) {
+  const auto rank = parallel_args_.rank();
   const auto world_size = parallel_args_.world_size();
   CHECK(in_features % world_size == 0)
       << "in_features " << in_features << " not divisible by world_size "
       << world_size;
   const int64_t in_features_per_partition = in_features / world_size;
   // Allocate the transpose since linear performs XA^T.
-  weight_ = register_parameter(
+  weight_ = register_sharded_parameter(
       "weight",
+      /*dim=*/1,
+      rank,
+      world_size,
       torch::empty({out_features, in_features_per_partition}, options));
 
   if (bias) {
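
The dim argument encodes the parallelism scheme: a column-parallel linear splits output features, so its [out, in] weight (and its bias) shard along dim 0, while a row-parallel linear splits input features, so its weight shards along dim 1; the row-parallel bias, applied after the all-reduce, is typically replicated, which is consistent with its registration being untouched here. Concretely, the slice a sharded load implies looks roughly like this (illustrative, assuming the checkpoint holds the full unsharded tensor):

// Illustrative slicing for a sharded load: rank r of w ranks takes an
// equal chunk of the full checkpoint tensor along the registered dim.
torch::Tensor shard(const torch::Tensor& full,
                    int64_t dim,
                    int32_t rank,
                    int32_t world_size) {
  // Divisibility is CHECKed upstream in the constructors above.
  const int64_t chunk = full.size(dim) / world_size;
  return full.narrow(dim, /*start=*/rank * chunk, /*length=*/chunk);
}

// ColumnParallelLinear weight [out/w, in]: shard(full_weight, /*dim=*/0, rank, w)
// RowParallelLinear    weight [out, in/w]: shard(full_weight, /*dim=*/1, rank, w)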

src/layers/linear_impl.h

Lines changed: 0 additions & 8 deletions
@@ -42,10 +42,6 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl {
         << "bias is not loaded for " << prefix + "bias";
   }
 
-  void pretty_print(std::ostream& stream) const override {
-    stream << name() << " " << weight_.sizes() << " " << weight_.device();
-  }
-
   // return the weight (for testing)
   torch::Tensor weight() const { return weight_; }

@@ -95,10 +91,6 @@ class RowParallelLinearImpl : public ParallelLinearImpl {
         << "bias is not loaded for " << prefix + "bias";
   }
 
-  void pretty_print(std::ostream& stream) const override {
-    stream << name() << " " << weight_.sizes() << " " << weight_.device();
-  }
-
   // return the weight (for testing)
   torch::Tensor weight() const { return weight_; }
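
With parameters registered centrally, the per-module pretty_print overrides become redundant as well, since a base-class default can enumerate whatever was registered. This diff only shows the removals, so the following default is an assumption about where that logic moved; params_ and name() are assumed Module internals.

// Hypothetical base-class default replacing the deleted overrides.
void Module::pretty_print(std::ostream& stream) const {
  stream << name();
  for (const auto& [param_name, param] : params_) {
    stream << " " << param_name << param.sizes() << " " << param.device();
  }
}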
