Commit 6ddb3bd

feat: clean up module for inference use only (#501)
1 parent 60a1192 commit 6ddb3bd

34 files changed: +384 additions, -1113 deletions

src/layers/attention/attention.h
Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 
 namespace llm {
 
-class AttentionImpl : public llm::nn::Module {
+class AttentionImpl : public Module {
  public:
   AttentionImpl(int64_t n_heads,
                 int64_t n_kv_heads,

src/layers/embedding.h
Lines changed: 6 additions & 10 deletions

@@ -16,16 +16,14 @@ namespace llm {
 // This module is often used to store word embeddings and retrieve them using
 // indices.
 
-class EmbeddingImpl : public llm::nn::Module {
+class EmbeddingImpl : public Module {
  public:
   EmbeddingImpl(int64_t num_embeddings,
                 int64_t embedding_dim,
                 const torch::TensorOptions& options) {
     // register the weight parameter
     weight_ = register_parameter(
-        "weight",
-        torch::empty({num_embeddings, embedding_dim}, options),
-        /*requires_grad=*/false);
+        "weight", torch::empty({num_embeddings, embedding_dim}, options));
   }
 
   // The input to the module is a list of indices, and the output is the

@@ -68,7 +66,7 @@ class EmbeddingImpl : public llm::nn::Module {
 LLM_MODULE(Embedding);
 
 // Embedding parallelized in the embedding dimension.
-class ParallelEmbeddingImpl : public llm::nn::Module {
+class ParallelEmbeddingImpl : public Module {
  public:
   ParallelEmbeddingImpl(int64_t num_embeddings,
                         int64_t embedding_dim,

@@ -84,8 +82,7 @@ class ParallelEmbeddingImpl : public llm::nn::Module {
     // register the weight parameter
     weight_ = register_parameter(
         "weight",
-        torch::empty({num_embeddings, embedding_dim_per_partition}, options),
-        /*requires_grad=*/false);
+        torch::empty({num_embeddings, embedding_dim_per_partition}, options));
   }
 
   // The input to the module is a list of indices, and the output is the

@@ -139,7 +136,7 @@ class ParallelEmbeddingImpl : public llm::nn::Module {
 LLM_MODULE(ParallelEmbedding);
 
 // Embedding parallelized in the vocabulary dimension
-class VocabParallelEmbeddingImpl : public llm::nn::Module {
+class VocabParallelEmbeddingImpl : public Module {
  public:
   VocabParallelEmbeddingImpl(int64_t num_embeddings,
                              int64_t embedding_dim,

@@ -154,8 +151,7 @@ class VocabParallelEmbeddingImpl : public llm::nn::Module {
     // register the weight parameter
     weight_ = register_parameter(
         "weight",
-        torch::empty({num_embeddings_per_partition, embedding_dim}, options),
-        /*requires_grad=*/false);
+        torch::empty({num_embeddings_per_partition, embedding_dim}, options));
   }
 
   // The input to the module is a list of indices, and the output is the

src/layers/fused_linear.h
Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 
 namespace llm {
 
-class FusedColumnParallelLinearImpl : public llm::nn::Module {
+class FusedColumnParallelLinearImpl : public Module {
  public:
   FusedColumnParallelLinearImpl(int64_t in_features,
                                 const std::vector<int64_t>& out_features,

src/layers/linear.h
Lines changed: 7 additions & 7 deletions

@@ -16,7 +16,7 @@ using TensorTransform = std::function<torch::Tensor(const torch::Tensor&)>;
 // an interface for parallel linear layer.
 // all linear classes should inherit from this class and implement the forward
 // function.
-class ParallelLinearImpl : public llm::nn::Module {
+class ParallelLinearImpl : public Module {
  public:
   ~ParallelLinearImpl() override = default;
 

@@ -39,10 +39,10 @@
   }
 };
 
-class ColumnParallelLinear : public llm::nn::ModuleHolder<ParallelLinearImpl> {
+class ColumnParallelLinear : public ModuleHolder<ParallelLinearImpl> {
  public:
-  using llm::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = ParallelLinearImpl;
+  using ModuleHolder<ParallelLinearImpl>::ModuleHolder;
+  using Impl [[maybe_unused]] = ParallelLinearImpl;
 
   // construct a rotary positional embedding.
   // chose right implementation based on the args.

@@ -62,10 +62,10 @@ class ColumnParallelLinear : public llm::nn::ModuleHolder<ParallelLinearImpl> {
                              const torch::TensorOptions& options);
 };
 
-class RowParallelLinear : public llm::nn::ModuleHolder<ParallelLinearImpl> {
+class RowParallelLinear : public ModuleHolder<ParallelLinearImpl> {
  public:
-  using llm::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = ParallelLinearImpl;
+  using ModuleHolder<ParallelLinearImpl>::ModuleHolder;
+  using Impl [[maybe_unused]] = ParallelLinearImpl;
 
   // construct a rotary positional embedding.
   // chose right implementation based on the args.
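
The attribute change swaps the GCC/Clang-specific __attribute__((__unused__)) for standard C++17 [[maybe_unused]], which also works on MSVC. A small standalone illustration (the ModuleHolder and ParallelLinearImpl below are stand-ins, not the project's templates):

// Sketch: [[maybe_unused]] on a type alias is standard C++17 and portable,
// unlike the compiler-specific __attribute__((__unused__)).
#include <type_traits>

struct ParallelLinearImpl {};

template <typename ImplT>
struct ModuleHolder {
  // Silences "unused type alias" warnings on any conforming C++17 compiler.
  using Impl [[maybe_unused]] = ImplT;
};

static_assert(
    std::is_same_v<ModuleHolder<ParallelLinearImpl>::Impl, ParallelLinearImpl>,
    "Impl alias still resolves to the wrapped implementation type");

int main() { return 0; }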

src/layers/linear_impl.cpp
Lines changed: 5 additions & 11 deletions

@@ -28,14 +28,11 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(
   // we allocate the transpose.
   weight_ = register_parameter(
       "weight",
-      torch::empty({out_features_per_partition, in_features}, options),
-      /*requires_grad=*/false);
+      torch::empty({out_features_per_partition, in_features}, options));
 
   if (bias) {
-    bias_ =
-        register_parameter("bias",
-                           torch::empty({out_features_per_partition}, options),
-                           /*requires_grad=*/false);
+    bias_ = register_parameter(
+        "bias", torch::empty({out_features_per_partition}, options));
   }
 }
 

@@ -104,13 +101,10 @@ RowParallelLinearImpl::RowParallelLinearImpl(
   // Allocate the transpose since linear performs XA^T.
   weight_ = register_parameter(
       "weight",
-      torch::empty({out_features, in_features_per_partition}, options),
-      /*requires_grad=*/false);
+      torch::empty({out_features, in_features_per_partition}, options));
 
   if (bias) {
-    bias_ = register_parameter("bias",
-                               torch::empty({out_features}, options),
-                               /*requires_grad=*/false);
+    bias_ = register_parameter("bias", torch::empty({out_features}, options));
   }
 }
 
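
The comment "linear performs XA^T" is why the weight is allocated as [out_features, in_features]. A quick standalone check with stock libtorch (not the project's parallel linear classes):

// Sketch: torch's linear computes y = x * W^T + b, so the stored weight
// matrix is the transpose, shaped [out_features, in_features].
#include <torch/torch.h>

#include <iostream>

int main() {
  auto x = torch::randn({4, 16});   // batch of 4, in_features = 16
  auto w = torch::randn({32, 16});  // [out_features, in_features]
  auto y = torch::nn::functional::linear(x, w);
  std::cout << y.sizes() << std::endl;  // prints [4, 32]
  return 0;
}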

src/layers/normalization.h
Lines changed: 11 additions & 19 deletions

@@ -65,7 +65,7 @@ inline torch::Tensor layer_norm(torch::Tensor input,
 // apply layer normalization over a mini-batch of inputs as described in
 // the paper `Layer Normalization`: https://arxiv.org/abs/1607.06450
 // x = ((x - mean(x)) / sqrt(std(x) + eps)) * weight + bias
-class LayerNormImpl : public llm::nn::Module {
+class LayerNormImpl : public Module {
  public:
   // dim: the dim over which the mean and std are calculated separately.
   // eps: a value added to the denominator for numerical stability.

@@ -75,13 +75,11 @@ class LayerNormImpl : public llm::nn::Module {
                 const torch::TensorOptions& options)
       : eps_(eps) {
     normalized_shape_ = {dim};
-    weight_ = register_parameter("weight",
-                                 torch::empty(normalized_shape_, options),
-                                 /*requires_grad=*/false);
+    weight_ =
+        register_parameter("weight", torch::empty(normalized_shape_, options));
     if (bias) {
-      bias_ = register_parameter("bias",
-                                 torch::zeros(normalized_shape_, options),
-                                 /*requires_grad=*/false);
+      bias_ =
+          register_parameter("bias", torch::zeros(normalized_shape_, options));
     }
   }
 

@@ -145,13 +143,11 @@ class LayerNormImpl : public llm::nn::Module {
 LLM_MODULE(LayerNorm);
 
 // Root mean square normalization
-class RMSNormImpl : public llm::nn::Module {
+class RMSNormImpl : public Module {
  public:
   RMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options)
       : eps_(eps) {
-    weight_ = register_parameter("weight",
-                                 torch::empty({dim}, options),
-                                 /*requires_grad=*/false);
+    weight_ = register_parameter("weight", torch::empty({dim}, options));
   }
 
   torch::Tensor forward(const torch::Tensor& input) {

@@ -195,13 +191,11 @@ class RMSNormImpl : public llm::nn::Module {
 };
 LLM_MODULE(RMSNorm);
 
-class GemmaRMSNormImpl : public llm::nn::Module {
+class GemmaRMSNormImpl : public Module {
  public:
   GemmaRMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options)
       : eps_(eps) {
-    weight_ = register_parameter("weight",
-                                 torch::empty({dim}, options),
-                                 /*requires_grad=*/false);
+    weight_ = register_parameter("weight", torch::empty({dim}, options));
  }
 
   torch::Tensor forward(const torch::Tensor& input) {

@@ -246,15 +240,13 @@ class GemmaRMSNormImpl : public llm::nn::Module {
 LLM_MODULE(GemmaRMSNorm);
 
 // Root mean square normalization
-class RMSNormResidualImpl : public llm::nn::Module {
+class RMSNormResidualImpl : public Module {
  public:
   RMSNormResidualImpl(int64_t dim,
                       float eps,
                       const torch::TensorOptions& options)
       : eps_(eps) {
-    weight_ = register_parameter("weight",
-                                 torch::empty({dim}, options),
-                                 /*requires_grad=*/false);
+    weight_ = register_parameter("weight", torch::empty({dim}, options));
   }
 
   torch::Tensor forward(const torch::Tensor& input, torch::Tensor& residual) {
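
For reference, RMSNorm scales the input by the reciprocal root-mean-square over the last dimension and then applies the learned weight. A minimal sketch with stock libtorch ops; the project's RMSNormImpl may differ (for example, by dispatching to a fused kernel):

// Sketch of an RMSNorm forward pass: x * rsqrt(mean(x^2) + eps) * weight,
// computed over the last dimension.
#include <torch/torch.h>

torch::Tensor rms_norm(const torch::Tensor& input,
                       const torch::Tensor& weight,
                       double eps) {
  auto variance = input.pow(2).mean(/*dim=*/-1, /*keepdim=*/true);
  return input * torch::rsqrt(variance + eps) * weight;
}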

src/layers/pos_embedding.h
Lines changed: 2 additions & 2 deletions

@@ -41,11 +41,11 @@ class RotaryEmbeddingImpl : public torch::nn::Module {
 
 // RotaryEmbedding is a wrapper class that chooses the right rotary positional
 // embedding implementation based on the args.
-// Similar to TORCH_MODULE(RotaryEmbedding) except of the explicit constructor.
+// Similar to LLM_MODULE(RotaryEmbedding) except of the explicit constructor.
 class RotaryEmbedding : public torch::nn::ModuleHolder<RotaryEmbeddingImpl> {
  public:
   using torch::nn::ModuleHolder<RotaryEmbeddingImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = RotaryEmbeddingImpl;
+  using Impl [[maybe_unused]] = RotaryEmbeddingImpl;
 
   // construct a rotary positional embedding.
   // chose right implementation based on the args.

src/layers/qkv_linear.h
Lines changed: 3 additions & 1 deletion

@@ -6,13 +6,15 @@
 #include "fused_linear.h"
 #include "model_loader/state_dict.h"
 #include "model_parallel/parallel_args.h"
+#include "module/module.h"
+#include "module/module_holder.h"
 #include "quantization/quant_args.h"
 
 namespace llm {
 
 // a thin wrapper to handle state_dict loading for QKV with
 // support of MQA/GQA
-class QKVColumnParallelLinearImpl : public llm::nn::Module {
+class QKVColumnParallelLinearImpl : public Module {
  public:
   QKVColumnParallelLinearImpl(int64_t hidden_size,
                               int64_t n_heads,
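
As background on the MQA/GQA comment: with grouped-query attention the key/value projections have fewer heads than the query projection, so a fused QKV weight has (n_heads + 2 * n_kv_heads) * head_dim output features. A general illustration, not the project's QKVColumnParallelLinearImpl:

// Sketch of fused QKV output sizing under MHA, GQA, and MQA.
#include <cassert>
#include <cstdint>

int64_t qkv_out_features(int64_t n_heads, int64_t n_kv_heads,
                         int64_t head_dim) {
  // MQA is the n_kv_heads == 1 special case; MHA is n_kv_heads == n_heads.
  return (n_heads + 2 * n_kv_heads) * head_dim;
}

int main() {
  assert(qkv_out_features(/*n_heads=*/32, /*n_kv_heads=*/32, /*head_dim=*/128)
         == 12288);                                  // MHA: 3 * 32 * 128
  assert(qkv_out_features(32, 8, 128) == 6144);      // GQA with 8 kv heads
  assert(qkv_out_features(32, 1, 128) == 4352);      // MQA
  return 0;
}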

src/models/aquila.h
Lines changed: 8 additions & 8 deletions

@@ -16,12 +16,12 @@
 #include "models/parameters.h"
 #include "module/module.h"
 #include "module/module_holder.h"
-#include "module/modulelist.h"
+#include "module/module_list.h"
 
 // Aquila model compatible with huggingface weights
 namespace llm::hf {
 
-class AquilaMLPImpl : public llm::nn::Module {
+class AquilaMLPImpl : public Module {
  public:
   AquilaMLPImpl(const ModelArgs& args,
                 const QuantArgs& quant_args,

@@ -82,7 +82,7 @@ class AquilaMLPImpl : public llm::nn::Module {
 };
 LLM_MODULE(AquilaMLP);
 
-class AquilaAttentionImpl : public llm::nn::Module {
+class AquilaAttentionImpl : public Module {
  public:
   AquilaAttentionImpl(const ModelArgs& args,
                       const QuantArgs& quant_args,

@@ -169,7 +169,7 @@ class AquilaAttentionImpl : public llm::nn::Module {
 };
 LLM_MODULE(AquilaAttention);
 
-class AquilaDecoderLayerImpl : public llm::nn::Module {
+class AquilaDecoderLayerImpl : public Module {
  public:
   AquilaDecoderLayerImpl(const ModelArgs& args,
                          const QuantArgs& quant_args,

@@ -229,7 +229,7 @@ class AquilaDecoderLayerImpl : public llm::nn::Module {
 };
 LLM_MODULE(AquilaDecoderLayer);
 
-class AquilaModelImpl : public llm::nn::Module {
+class AquilaModelImpl : public Module {
  public:
   AquilaModelImpl(const ModelArgs& args,
                   const QuantArgs& quant_args,

@@ -244,7 +244,7 @@ class AquilaModelImpl : public llm::nn::Module {
     handler_ = AttentionHandler::create_handler_with_rope(
         args, /*interleaved=*/false, options);
 
-    blocks_ = register_module("layers", llm::nn::ModuleList());
+    blocks_ = register_module("layers", ModuleList());
     layers_.reserve(args.n_layers());
     for (int32_t i = 0; i < args.n_layers(); i++) {
       auto block = AquilaDecoderLayer(

@@ -298,15 +298,15 @@ class AquilaModelImpl : public llm::nn::Module {
   // attention handler
   std::unique_ptr<AttentionHandler> handler_{nullptr};
 
-  llm::nn::ModuleList blocks_{nullptr};
+  ModuleList blocks_{nullptr};
   // hold same data but different type as blocks_ to avoid type cast
   std::vector<AquilaDecoderLayer> layers_;
 
   RMSNorm norm_{nullptr};
 };
 LLM_MODULE(AquilaModel);
 
-class AquilaForCausalLMImpl : public llm::nn::Module {
+class AquilaForCausalLMImpl : public Module {
  public:
   AquilaForCausalLMImpl(const ModelArgs& args,
                         const QuantArgs& quant_args,
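
AquilaModelImpl keeps two views of the same decoder layers: a registered ModuleList (blocks_) and a typed std::vector (layers_), so the forward pass can call the concrete layer type without casts. A sketch of that pattern using stock torch::nn types as stand-ins for the project's Module/ModuleList:

// Sketch: the ModuleList owns registration (parameters()/to()/serialization),
// while a typed vector keeps direct, cast-free access to the same modules.
#include <torch/torch.h>

#include <vector>

struct StackImpl : torch::nn::Module {
  torch::nn::ModuleList blocks{nullptr};
  std::vector<torch::nn::Linear> layers;  // same modules, concrete type

  StackImpl(int64_t dim, int64_t n_layers) {
    blocks = register_module("layers", torch::nn::ModuleList());
    layers.reserve(n_layers);
    for (int64_t i = 0; i < n_layers; ++i) {
      auto layer = torch::nn::Linear(dim, dim);
      layers.push_back(layer);
      blocks->push_back(layer);  // registered through the ModuleList
    }
  }

  torch::Tensor forward(torch::Tensor x) {
    for (auto& layer : layers) {  // typed calls, no dynamic casts
      x = layer->forward(x);
    }
    return x;
  }
};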

src/models/baichuan.h
Lines changed: 8 additions & 8 deletions

@@ -18,7 +18,7 @@
 #include "models/parameters.h"
 #include "module/module.h"
 #include "module/module_holder.h"
-#include "module/modulelist.h"
+#include "module/module_list.h"
 
 // Baichuan model compatible with huggingface weights
 

@@ -31,7 +31,7 @@ enum class BaichuanType : uint8_t {
   Baichuan2_13B,
 };
 
-class BaichuanMLPImpl : public llm::nn::Module {
+class BaichuanMLPImpl : public Module {
  public:
   BaichuanMLPImpl(const ModelArgs& args,
                   const QuantArgs& quant_args,

@@ -92,7 +92,7 @@ class BaichuanMLPImpl : public llm::nn::Module {
 };
 LLM_MODULE(BaichuanMLP);
 
-class BaichuanAttentionImpl : public llm::nn::Module {
+class BaichuanAttentionImpl : public Module {
  public:
   BaichuanAttentionImpl(const ModelArgs& args,
                         const QuantArgs& quant_args,

@@ -179,7 +179,7 @@ class BaichuanAttentionImpl : public llm::nn::Module {
 };
 LLM_MODULE(BaichuanAttention);
 
-class BaichuanDecoderLayerImpl : public llm::nn::Module {
+class BaichuanDecoderLayerImpl : public Module {
  public:
   BaichuanDecoderLayerImpl(const ModelArgs& args,
                            const QuantArgs& quant_args,

@@ -247,7 +247,7 @@ class BaichuanDecoderLayerImpl : public llm::nn::Module {
 };
 LLM_MODULE(BaichuanDecoderLayer);
 
-class BaichuanModelImpl : public llm::nn::Module {
+class BaichuanModelImpl : public Module {
  public:
   BaichuanModelImpl(const ModelArgs& args,
                     const QuantArgs& quant_args,

@@ -271,7 +271,7 @@ class BaichuanModelImpl : public llm::nn::Module {
           args, alibi_slopes, options);
     }
 
-    blocks_ = register_module("layers", llm::nn::ModuleList());
+    blocks_ = register_module("layers", ModuleList());
     layers_.reserve(args.n_layers());
     for (int32_t i = 0; i < args.n_layers(); i++) {
       auto block = BaichuanDecoderLayer(args,

@@ -366,7 +366,7 @@ class BaichuanModelImpl : public llm::nn::Module {
   std::unique_ptr<AttentionHandler> handler_{nullptr};
 
   // parameter members, must be registered
-  llm::nn::ModuleList blocks_{nullptr};
+  ModuleList blocks_{nullptr};
   // hold same data but different type as blocks_ to avoid type cast
   std::vector<BaichuanDecoderLayer> layers_;
 

@@ -375,7 +375,7 @@ class BaichuanModelImpl : public llm::nn::Module {
 };
 LLM_MODULE(BaichuanModel);
 
-class BaichuanForCausalLMImpl : public llm::nn::Module {
+class BaichuanForCausalLMImpl : public Module {
  public:
   BaichuanForCausalLMImpl(const ModelArgs& args,
                           const QuantArgs& quant_args,
