Commit 60a1192

feat: port torch::nn::module as llm::nn::module (#500)

1 parent: 0f4d591

35 files changed: +2039 −274 lines
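The new module/module.h and module/module_holder.h headers introduced by this commit are among the 35 changed files but are not excerpted below. Judging from how the ported code uses them (LLM_MODULE(...) replacing TORCH_MODULE(...), llm::nn::ModuleHolder<Impl> replacing torch::nn::ModuleHolder<Impl>), the design presumably mirrors LibTorch's module-holder idiom. A minimal sketch under that assumption, not the actual contents of the commit:

// Hypothetical reconstruction of module/module.h and module/module_holder.h,
// mirroring torch::nn::Module / torch::nn::ModuleHolder / TORCH_MODULE.
#include <memory>
#include <utility>

namespace llm::nn {

// Base class for all modules, standing in for torch::nn::Module.
class Module {
 public:
  virtual ~Module() = default;
};

// Value-semantic handle over a shared Impl, mirroring torch::nn::ModuleHolder:
// constructor arguments are forwarded to the wrapped Impl, and operator->
// exposes it. The real torch holder carries more machinery (empty states,
// copy semantics); this is only the minimal core.
template <typename Impl>
class ModuleHolder {
 public:
  template <typename... Args>
  explicit ModuleHolder(Args&&... args)
      : impl_(std::make_shared<Impl>(std::forward<Args>(args)...)) {}

  explicit ModuleHolder(std::shared_ptr<Impl> impl) : impl_(std::move(impl)) {}

  Impl* operator->() { return impl_.get(); }
  const Impl* operator->() const { return impl_.get(); }

 private:
  std::shared_ptr<Impl> impl_;
};

}  // namespace llm::nn

// Mirrors TORCH_MODULE(Name): declares `Name` as a holder of `NameImpl`.
// This expansion matches the hand-written holders visible in linear.h below.
#define LLM_MODULE(Name)                                   \
  class Name : public llm::nn::ModuleHolder<Name##Impl> {  \
   public:                                                 \
    using llm::nn::ModuleHolder<Name##Impl>::ModuleHolder; \
    using Impl __attribute__((__unused__)) = Name##Impl;   \
  }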

src/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -3,6 +3,7 @@ add_subdirectory(common)
 add_subdirectory(handlers)
 add_subdirectory(kernels)
 add_subdirectory(tokenizer)
+add_subdirectory(module)
 add_subdirectory(layers)
 add_subdirectory(quantization)
 add_subdirectory(models)

src/layers/CMakeLists.txt
Lines changed: 11 additions & 8 deletions

@@ -3,7 +3,7 @@ include(cc_test)
 
 cc_library(
   NAME
-    linear
+    linear
   HDRS
     linear.h
     qkv_linear.h
@@ -21,35 +21,37 @@ cc_library(
     :model_parallel
     :quantization
     :kernels
+    :module
     glog::glog
     gflags::gflags
     torch
 )
 
 cc_library(
-  NAME
+  NAME
     pos_embedding
-  HDRS
+  HDRS
     pos_embedding.h
-  SRCS
+  SRCS
     pos_embedding.cpp
   DEPS
     :state_dict
     :memory
     :kernels
+    :module
     glog::glog
     gflags::gflags
     torch
 )
 
 cc_library(
-  NAME
+  NAME
     layers
-  HDRS
+  HDRS
     normalization.h
     embedding.h
     activation.h
-  SRCS
+  SRCS
     activation.cpp
   DEPS
     :state_dict
@@ -58,6 +60,7 @@ cc_library(
     :pos_embedding
     :attention
     :kernels
+    :module
     glog::glog
     gflags::gflags
     torch
@@ -80,4 +83,4 @@ cc_test(
 )
 
 add_subdirectory(attention)
-add_subdirectory(moe)
+add_subdirectory(moe)

(The paired -/+ lines with identical visible text are whitespace-only changes.)

src/layers/attention/attention.h
Lines changed: 4 additions & 2 deletions

@@ -5,10 +5,12 @@
 #include "layers/attention/handler.h"
 #include "memory/kv_cache.h"
 #include "models/parameters.h"
+#include "module/module.h"
+#include "module/module_holder.h"
 
 namespace llm {
 
-class AttentionImpl : public torch::nn::Module {
+class AttentionImpl : public llm::nn::Module {
  public:
   AttentionImpl(int64_t n_heads,
                 int64_t n_kv_heads,
@@ -38,6 +40,6 @@ class AttentionImpl : public torch::nn::Module {
   // sliding window for self-attention, -1 means no sliding window
   int32_t sliding_window_ = -1;
 };
-TORCH_MODULE(Attention);
+LLM_MODULE(Attention);
 
 }  // namespace llm
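After LLM_MODULE(Attention), `Attention` is a value-semantic handle whose constructor arguments forward to AttentionImpl, exactly as with the old TORCH_MODULE type, so call sites need no changes. A usage sketch with a toy module (AttentionImpl's full constructor is truncated in this diff, so all names below are hypothetical and rely on the Module/LLM_MODULE sketch above):

#include <cstdint>

class ToyNormImpl : public llm::nn::Module {
 public:
  explicit ToyNormImpl(int64_t dim) : dim_(dim) {}
  int64_t dim() const { return dim_; }

 private:
  int64_t dim_ = 0;
};
LLM_MODULE(ToyNorm);

void example() {
  ToyNorm norm(/*dim=*/4096);  // args forwarded to ToyNormImpl's constructor
  int64_t d = norm->dim();     // operator-> reaches the underlying Impl
  (void)d;
}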

src/layers/embedding.h
Lines changed: 9 additions & 8 deletions

@@ -7,14 +7,16 @@
 
 #include "model_loader/state_dict.h"
 #include "model_parallel/model_parallel.h"
+#include "module/module.h"
+#include "module/module_holder.h"
 
 namespace llm {
 
 // A simple lookup table that stores embeddings of a fixed dictionary and size.
 // This module is often used to store word embeddings and retrieve them using
 // indices.
 
-class EmbeddingImpl : public torch::nn::Module {
+class EmbeddingImpl : public llm::nn::Module {
  public:
   EmbeddingImpl(int64_t num_embeddings,
                 int64_t embedding_dim,
@@ -63,10 +65,10 @@ class EmbeddingImpl : public torch::nn::Module {
   // whether the weight is loaded
   bool is_loaded_ = false;
 };
-TORCH_MODULE(Embedding);
+LLM_MODULE(Embedding);
 
 // Embedding parallelized in the embedding dimension.
-class ParallelEmbeddingImpl : public torch::nn::Module {
+class ParallelEmbeddingImpl : public llm::nn::Module {
  public:
   ParallelEmbeddingImpl(int64_t num_embeddings,
                         int64_t embedding_dim,
@@ -134,10 +136,10 @@ class ParallelEmbeddingImpl : public torch::nn::Module {
   // parallel args
   ParallelArgs parallel_args_;
 };
-TORCH_MODULE(ParallelEmbedding);
+LLM_MODULE(ParallelEmbedding);
 
 // Embedding parallelized in the vocabulary dimension
-class VocabParallelEmbeddingImpl : public torch::nn::Module {
+class VocabParallelEmbeddingImpl : public llm::nn::Module {
  public:
   VocabParallelEmbeddingImpl(int64_t num_embeddings,
                              int64_t embedding_dim,
@@ -152,8 +154,7 @@ class VocabParallelEmbeddingImpl : public torch::nn::Module {
     // register the weight parameter
     weight_ = register_parameter(
         "weight",
-        torch::empty({num_embeddings_per_partition, embedding_dim},
-                     options),
+        torch::empty({num_embeddings_per_partition, embedding_dim}, options),
         /*requires_grad=*/false);
   }
 
@@ -218,5 +219,5 @@ class VocabParallelEmbeddingImpl : public torch::nn::Module {
   int64_t start_index_ = 0;
   int64_t end_index_ = 0;
 };
-TORCH_MODULE(VocabParallelEmbedding);
+LLM_MODULE(VocabParallelEmbedding);
 }  // namespace llm
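The num_embeddings_per_partition, start_index_, and end_index_ members visible above imply the usual vocab-parallel scheme: each rank owns a contiguous slice of the vocabulary rows and registers only that slice as its weight. A sketch of the implied partition arithmetic (an assumption; function and struct names are not from this file):

#include <cstdint>

struct VocabPartition {
  int64_t per_partition;  // rows owned by each rank
  int64_t start;          // first vocab id owned (inclusive)
  int64_t end;            // one past the last vocab id owned
};

inline VocabPartition partition_vocab(int64_t num_embeddings,
                                      int32_t rank,
                                      int32_t world_size) {
  // assumes num_embeddings is divisible by world_size, as vocab-parallel
  // embeddings typically require (or pad to)
  const int64_t per_partition = num_embeddings / world_size;
  return {per_partition, rank * per_partition, (rank + 1) * per_partition};
}
// e.g. a 32000-token vocab on 4 ranks: rank 1 owns ids [8000, 16000)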

src/layers/fused_linear.h
Lines changed: 4 additions & 2 deletions

@@ -6,11 +6,13 @@
 #include "linear.h"
 #include "model_loader/state_dict.h"
 #include "model_parallel/parallel_args.h"
+#include "module/module.h"
+#include "module/module_holder.h"
 #include "quantization/quant_args.h"
 
 namespace llm {
 
-class FusedColumnParallelLinearImpl : public torch::nn::Module {
+class FusedColumnParallelLinearImpl : public llm::nn::Module {
  public:
   FusedColumnParallelLinearImpl(int64_t in_features,
                                 const std::vector<int64_t>& out_features,
@@ -44,6 +46,6 @@ class FusedColumnParallelLinearImpl : public torch::nn::Module {
   // whether the linear layer is fused
   bool fused_ = false;
 };
-TORCH_MODULE(FusedColumnParallelLinear);
+LLM_MODULE(FusedColumnParallelLinear);
 
 }  // namespace llm
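The `std::vector<int64_t>& out_features` constructor parameter suggests that several projections sharing one input are stacked into a single matmul and split back afterwards. A hedged sketch of that unfusing step (function and parameter names are assumptions, not this file's API):

#include <torch/torch.h>
#include <vector>

// One [*, sum(out_features)] matmul replaces N smaller ones; the result is
// then split back per projection. `fused_weight` is assumed to be the
// row-concatenation of the individual projection weights.
std::vector<torch::Tensor> fused_forward(
    const torch::Tensor& input,
    const torch::Tensor& fused_weight,
    const std::vector<int64_t>& out_features) {
  auto fused = torch::nn::functional::linear(input, fused_weight);
  return fused.split_with_sizes(out_features, /*dim=*/-1);
}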

src/layers/linear.h
Lines changed: 7 additions & 6 deletions

@@ -5,6 +5,8 @@
 
 #include "model_loader/state_dict.h"
 #include "model_parallel/parallel_args.h"
+#include "module/module.h"
+#include "module/module_holder.h"
 #include "quantization/quant_args.h"
 
 namespace llm {
@@ -14,7 +16,7 @@ using TensorTransform = std::function<torch::Tensor(const torch::Tensor&)>;
 // an interface for parallel linear layer.
 // all linear classes should inherit from this class and implement the forward
 // function.
-class ParallelLinearImpl : public torch::nn::Module {
+class ParallelLinearImpl : public llm::nn::Module {
  public:
   ~ParallelLinearImpl() override = default;
 
@@ -37,10 +39,9 @@ class ParallelLinearImpl : public torch::nn::Module {
   }
 };
 
-class ColumnParallelLinear
-    : public torch::nn::ModuleHolder<ParallelLinearImpl> {
+class ColumnParallelLinear : public llm::nn::ModuleHolder<ParallelLinearImpl> {
  public:
-  using torch::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
+  using llm::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
   using Impl __attribute__((__unused__)) = ParallelLinearImpl;
 
   // construct a rotary positional embedding.
@@ -61,9 +62,9 @@ class ColumnParallelLinear
   const torch::TensorOptions& options);
 };
 
-class RowParallelLinear : public torch::nn::ModuleHolder<ParallelLinearImpl> {
+class RowParallelLinear : public llm::nn::ModuleHolder<ParallelLinearImpl> {
  public:
-  using torch::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
+  using llm::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
   using Impl __attribute__((__unused__)) = ParallelLinearImpl;
 
   // construct a rotary positional embedding.
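Unlike the LLM_MODULE classes, ColumnParallelLinear and RowParallelLinear are hand-written holders over the abstract ParallelLinearImpl, so one handle type can wrap any concrete linear implementation (for example, a quantized one). A sketch of that pattern, reusing the ModuleHolder sketched near the top of this summary; every other name below is a hypothetical stand-in:

#include <memory>
#include <utility>

struct FakeLinearImpl {                  // stand-in for ParallelLinearImpl
  virtual ~FakeLinearImpl() = default;
  virtual int forward(int x) const = 0;  // stand-in for the real forward()
};

struct PlainImpl : FakeLinearImpl {
  int forward(int x) const override { return x; }
};
struct QuantizedImpl : FakeLinearImpl {
  int forward(int x) const override { return x / 2; }  // toy behavior
};

struct FakeLinear : llm::nn::ModuleHolder<FakeLinearImpl> {
  using llm::nn::ModuleHolder<FakeLinearImpl>::ModuleHolder;
};

FakeLinear make_linear(bool quantized) {
  std::shared_ptr<FakeLinearImpl> impl;
  if (quantized) {
    impl = std::make_shared<QuantizedImpl>();
  } else {
    impl = std::make_shared<PlainImpl>();
  }
  // the ModuleHolder(shared_ptr<Impl>) constructor adopts the chosen impl
  return FakeLinear(std::move(impl));
}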

src/layers/normalization.h
Lines changed: 10 additions & 8 deletions

@@ -7,6 +7,8 @@
 
 #include "kernels/layernorm_kernels.h"
 #include "model_loader/state_dict.h"
+#include "module/module.h"
+#include "module/module_holder.h"
 
 DECLARE_bool(disable_custom_kernels);
 namespace llm {
@@ -63,7 +65,7 @@ inline torch::Tensor layer_norm(torch::Tensor input,
 // apply layer normalization over a mini-batch of inputs as described in
 // the paper `Layer Normalization`: https://arxiv.org/abs/1607.06450
 // x = ((x - mean(x)) / sqrt(std(x) + eps)) * weight + bias
-class LayerNormImpl : public torch::nn::Module {
+class LayerNormImpl : public llm::nn::Module {
  public:
   // dim: the dim over which the mean and std are calculated separately.
   // eps: a value added to the denominator for numerical stability.
@@ -140,10 +142,10 @@ class LayerNormImpl : public torch::nn::Module {
   float eps_;
   std::vector<int64_t> normalized_shape_;
 };
-TORCH_MODULE(LayerNorm);
+LLM_MODULE(LayerNorm);
 
 // Root mean square normalization
-class RMSNormImpl : public torch::nn::Module {
+class RMSNormImpl : public llm::nn::Module {
  public:
   RMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options)
       : eps_(eps) {
@@ -191,9 +193,9 @@ class RMSNormImpl : public torch::nn::Module {
   // configs
   float eps_;
 };
-TORCH_MODULE(RMSNorm);
+LLM_MODULE(RMSNorm);
 
-class GemmaRMSNormImpl : public torch::nn::Module {
+class GemmaRMSNormImpl : public llm::nn::Module {
  public:
   GemmaRMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options)
       : eps_(eps) {
@@ -241,10 +243,10 @@ class GemmaRMSNormImpl : public torch::nn::Module {
   // configs
   float eps_;
 };
-TORCH_MODULE(GemmaRMSNorm);
+LLM_MODULE(GemmaRMSNorm);
 
 // Root mean square normalization
-class RMSNormResidualImpl : public torch::nn::Module {
+class RMSNormResidualImpl : public llm::nn::Module {
  public:
   RMSNormResidualImpl(int64_t dim,
                       float eps,
@@ -304,6 +306,6 @@ class RMSNormResidualImpl : public torch::nn::Module {
   // configs
   float eps_;
 };
-TORCH_MODULE(RMSNormResidual);
+LLM_MODULE(RMSNormResidual);
 
 }  // namespace llm
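The header comment above gives the LayerNorm formula; RMSNorm, the other family in this file, drops the mean-centering and normalizes by the root mean square alone. A reference sketch of that math (the commit's actual forward bodies, including the custom-kernel path behind disable_custom_kernels, are not shown in this excerpt):

#include <torch/torch.h>

// y = x / sqrt(mean(x^2) + eps) * weight
torch::Tensor rms_norm_reference(const torch::Tensor& x,
                                 const torch::Tensor& weight,
                                 double eps) {
  // inverse RMS over the last (feature) dimension, kept for broadcasting
  auto inv_rms =
      torch::rsqrt(x.pow(2).mean(/*dim=*/-1, /*keepdim=*/true) + eps);
  return x * inv_rms * weight;
}
// GemmaRMSNormImpl presumably scales by (1 + weight) instead, since Gemma
// checkpoints store the weight as an offset from 1 -- an assumption here.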

src/layers/qkv_linear.h
Lines changed: 2 additions & 2 deletions

@@ -12,7 +12,7 @@ namespace llm {
 
 // a thin wrapper to handle state_dict loading for QKV with
 // support of MQA/GQA
-class QKVColumnParallelLinearImpl : public torch::nn::Module {
+class QKVColumnParallelLinearImpl : public llm::nn::Module {
  public:
   QKVColumnParallelLinearImpl(int64_t hidden_size,
                               int64_t n_heads,
@@ -47,6 +47,6 @@ class QKVColumnParallelLinearImpl : public torch::nn::Module {
 
   int64_t head_dim_ = 0;
 };
-TORCH_MODULE(QKVColumnParallelLinear);
+LLM_MODULE(QKVColumnParallelLinear);
 
 }  // namespace llm
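The MQA/GQA support this wrapper mentions comes down to shape bookkeeping: with n_kv_heads < n_heads, the K and V projections are smaller than Q, which the state_dict loader must reconcile. A hedged sketch of the arithmetic, assuming the common row-concatenated QKV weight layout (names and layout are assumptions, not this file's API):

#include <cstdint>

// MQA: n_kv_heads == 1; GQA: 1 < n_kv_heads < n_heads; MHA: equal counts.
struct QKVShapes {
  int64_t q_rows;      // rows of the query projection
  int64_t kv_rows;     // rows of the key (and, separately, value) projection
  int64_t fused_rows;  // rows of a fused QKV weight, if stored concatenated
};

inline QKVShapes qkv_shapes(int64_t n_heads,
                            int64_t n_kv_heads,
                            int64_t head_dim) {
  return {n_heads * head_dim,
          n_kv_heads * head_dim,
          (n_heads + 2 * n_kv_heads) * head_dim};
}
// e.g. n_heads=32, n_kv_heads=8, head_dim=128:
//   q_rows=4096, kv_rows=1024, fused_rows=6144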

src/models/CMakeLists.txt
Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@ include(cc_library)
 include(cc_test)
 
 cc_library(
-  NAME
+  NAME
     models
   HDRS
     model_args.h
@@ -18,7 +18,7 @@ cc_library(
    :quantization
    :memory
    :chat_template
+   :module
    glog::glog
    torch
 )
-

(The NAME -/+ pair is a whitespace-only change; the trailing "-" removes a blank line at the end of the file.)