@@ -65,7 +65,7 @@ inline torch::Tensor layer_norm(torch::Tensor input,
 // apply layer normalization over a mini-batch of inputs as described in
 // the paper `Layer Normalization`: https://arxiv.org/abs/1607.06450
 // x = ((x - mean(x)) / sqrt(var(x) + eps)) * weight + bias
-class LayerNormImpl : public llm::nn::Module {
+class LayerNormImpl : public Module {
  public:
   // dim: the dim over which the mean and std are calculated separately.
   // eps: a value added to the denominator for numerical stability.
@@ -75,13 +75,11 @@ class LayerNormImpl : public llm::nn::Module {
                 const torch::TensorOptions& options)
       : eps_(eps) {
     normalized_shape_ = {dim};
-    weight_ = register_parameter("weight",
-                                 torch::empty(normalized_shape_, options),
-                                 /*requires_grad=*/false);
+    weight_ =
+        register_parameter("weight", torch::empty(normalized_shape_, options));
     if (bias) {
-      bias_ = register_parameter("bias",
-                                 torch::zeros(normalized_shape_, options),
-                                 /*requires_grad=*/false);
+      bias_ =
+          register_parameter("bias", torch::zeros(normalized_shape_, options));
     }
   }

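Two things change in this hunk: the base class is now the unqualified Module (the code presumably already sits inside the llm::nn namespace, making the qualification redundant), and the explicit /*requires_grad=*/false argument is dropped, so the parameters now take whatever default register_parameter provides; that default is not visible in this diff. The forward pass is also outside the hunk. A minimal sketch of what a LayerNorm forward over these members could look like with LibTorch's functional API (an assumption, not the file's actual implementation):

namespace F = torch::nn::functional;

torch::Tensor forward(const torch::Tensor& input) {
  // weight_, bias_, normalized_shape_ and eps_ are the members registered
  // above; an undefined (empty) bias_ simply omits the bias term.
  return F::layer_norm(input,
                       F::LayerNormFuncOptions(normalized_shape_)
                           .weight(weight_)
                           .bias(bias_)
                           .eps(eps_));
}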
@@ -145,13 +143,11 @@ class LayerNormImpl : public llm::nn::Module {
 LLM_MODULE(LayerNorm);

 // Root mean square normalization
-class RMSNormImpl : public llm::nn::Module {
+class RMSNormImpl : public Module {
  public:
   RMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options)
       : eps_(eps) {
-    weight_ = register_parameter("weight",
-                                 torch::empty({dim}, options),
-                                 /*requires_grad=*/false);
+    weight_ = register_parameter("weight", torch::empty({dim}, options));
   }

   torch::Tensor forward(const torch::Tensor& input) {
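As with LayerNorm, the hunk ends at the forward signature. For context, standard RMSNorm (https://arxiv.org/abs/1910.07467) skips LayerNorm's mean subtraction and bias and only rescales by the root mean square; a sketch of a body implementing that formula with the members above (an assumption about this class, not the diffed code):

torch::Tensor forward(const torch::Tensor& input) {
  // x / sqrt(mean(x^2) + eps) * weight -- no centering, no bias
  auto variance = input.pow(2).mean(/*dim=*/-1, /*keepdim=*/true);
  return input * torch::rsqrt(variance + eps_) * weight_;
}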
@@ -195,13 +191,11 @@ class RMSNormImpl : public llm::nn::Module {
 };
 LLM_MODULE(RMSNorm);

-class GemmaRMSNormImpl : public llm::nn::Module {
+class GemmaRMSNormImpl : public Module {
  public:
   GemmaRMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options)
       : eps_(eps) {
-    weight_ = register_parameter("weight",
-                                 torch::empty({dim}, options),
-                                 /*requires_grad=*/false);
+    weight_ = register_parameter("weight", torch::empty({dim}, options));
   }

   torch::Tensor forward(const torch::Tensor& input) {
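Gemma's variant warrants a separate class because its reference formulation differs from plain RMSNorm in two details: it upcasts to float32 before reducing, and it scales by (1 + weight), so a zero-initialized weight acts as an identity scale. A sketch under those assumptions (the hunk again stops at the signature):

torch::Tensor forward(const torch::Tensor& input) {
  // reduce in float32 for numerical stability, cast back at the end
  auto x = input.to(torch::kFloat32);
  auto variance = x.pow(2).mean(/*dim=*/-1, /*keepdim=*/true);
  x = x * torch::rsqrt(variance + eps_);
  // scale by (1 + weight) rather than weight
  return (x * (weight_.to(torch::kFloat32) + 1.0f)).to(input.scalar_type());
}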
@@ -246,15 +240,13 @@ class GemmaRMSNormImpl : public llm::nn::Module {
 LLM_MODULE(GemmaRMSNorm);

 // Root mean square normalization
-class RMSNormResidualImpl : public llm::nn::Module {
+class RMSNormResidualImpl : public Module {
  public:
   RMSNormResidualImpl(int64_t dim,
                       float eps,
                       const torch::TensorOptions& options)
       : eps_(eps) {
-    weight_ = register_parameter("weight",
-                                 torch::empty({dim}, options),
-                                 /*requires_grad=*/false);
+    weight_ = register_parameter("weight", torch::empty({dim}, options));
   }

   torch::Tensor forward(const torch::Tensor& input, torch::Tensor& residual) {
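The non-const torch::Tensor& residual parameter suggests the fused add-then-normalize pattern common in transformer inference: the skip connection is folded into the norm, and the summed activation is handed back to the caller through residual. A sketch of that pattern using the members above (hypothetical; the actual body is outside this diff):

torch::Tensor forward(const torch::Tensor& input, torch::Tensor& residual) {
  // write input + residual back so the caller can reuse it as the
  // next layer's skip input, then RMS-normalize the sum
  residual = input + residual;
  auto variance = residual.pow(2).mean(/*dim=*/-1, /*keepdim=*/true);
  return residual * torch::rsqrt(variance + eps_) * weight_;
}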