|
7 | 7 |
|
8 | 8 | #include "kernels/layernorm_kernels.h" |
9 | 9 | #include "model_loader/state_dict.h" |
| 10 | +#include "module/module.h" |
| 11 | +#include "module/module_holder.h" |
10 | 12 |
|
11 | 13 | DECLARE_bool(disable_custom_kernels); |
12 | 14 | namespace llm { |
@@ -63,7 +65,7 @@ inline torch::Tensor layer_norm(torch::Tensor input, |
63 | 65 | // apply layer normalization over a mini-batch of inputs as described in |
64 | 66 | // the paper `Layer Normalization`: https://arxiv.org/abs/1607.06450 |
65 | 67 | // x = ((x - mean(x)) / sqrt(std(x) + eps)) * weight + bias |
66 | | -class LayerNormImpl : public torch::nn::Module { |
| 68 | +class LayerNormImpl : public llm::nn::Module { |
67 | 69 | public: |
68 | 70 | // dim: the dim over which the mean and std are calculated separately. |
69 | 71 | // eps: a value added to the denominator for numerical stability. |
@@ -140,10 +142,10 @@ class LayerNormImpl : public torch::nn::Module { |
140 | 142 | float eps_; |
141 | 143 | std::vector<int64_t> normalized_shape_; |
142 | 144 | }; |
143 | | -TORCH_MODULE(LayerNorm); |
| 145 | +LLM_MODULE(LayerNorm); |
144 | 146 |
|
145 | 147 | // Root mean square normalization |
146 | | -class RMSNormImpl : public torch::nn::Module { |
| 148 | +class RMSNormImpl : public llm::nn::Module { |
147 | 149 | public: |
148 | 150 | RMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options) |
149 | 151 | : eps_(eps) { |
@@ -191,9 +193,9 @@ class RMSNormImpl : public torch::nn::Module { |
191 | 193 | // configs |
192 | 194 | float eps_; |
193 | 195 | }; |
194 | | -TORCH_MODULE(RMSNorm); |
| 196 | +LLM_MODULE(RMSNorm); |
195 | 197 |
|
196 | | -class GemmaRMSNormImpl : public torch::nn::Module { |
| 198 | +class GemmaRMSNormImpl : public llm::nn::Module { |
197 | 199 | public: |
198 | 200 | GemmaRMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options) |
199 | 201 | : eps_(eps) { |
@@ -241,10 +243,10 @@ class GemmaRMSNormImpl : public torch::nn::Module { |
241 | 243 | // configs |
242 | 244 | float eps_; |
243 | 245 | }; |
244 | | -TORCH_MODULE(GemmaRMSNorm); |
| 246 | +LLM_MODULE(GemmaRMSNorm); |
245 | 247 |
|
246 | 248 | // Root mean square normalization |
247 | | -class RMSNormResidualImpl : public torch::nn::Module { |
| 249 | +class RMSNormResidualImpl : public llm::nn::Module { |
248 | 250 | public: |
249 | 251 | RMSNormResidualImpl(int64_t dim, |
250 | 252 | float eps, |
@@ -304,6 +306,6 @@ class RMSNormResidualImpl : public torch::nn::Module { |
304 | 306 | // configs |
305 | 307 | float eps_; |
306 | 308 | }; |
307 | | -TORCH_MODULE(RMSNormResidual); |
| 309 | +LLM_MODULE(RMSNormResidual); |
308 | 310 |
|
309 | 311 | } // namespace llm |
0 commit comments