diff --git a/src/layers/attention/attention.h b/src/layers/attention/attention.h index 54773d22..64c6bb5e 100644 --- a/src/layers/attention/attention.h +++ b/src/layers/attention/attention.h @@ -10,7 +10,7 @@ namespace llm { -class AttentionImpl : public llm::nn::Module { +class AttentionImpl : public Module { public: AttentionImpl(int64_t n_heads, int64_t n_kv_heads, diff --git a/src/layers/embedding.h b/src/layers/embedding.h index b20fef70..9abe231d 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -16,16 +16,14 @@ namespace llm { // This module is often used to store word embeddings and retrieve them using // indices. -class EmbeddingImpl : public llm::nn::Module { +class EmbeddingImpl : public Module { public: EmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim, const torch::TensorOptions& options) { // register the weight parameter weight_ = register_parameter( - "weight", - torch::empty({num_embeddings, embedding_dim}, options), - /*requires_grad=*/false); + "weight", torch::empty({num_embeddings, embedding_dim}, options)); } // The input to the module is a list of indices, and the output is the @@ -68,7 +66,7 @@ class EmbeddingImpl : public llm::nn::Module { LLM_MODULE(Embedding); // Embedding parallelized in the embedding dimension. -class ParallelEmbeddingImpl : public llm::nn::Module { +class ParallelEmbeddingImpl : public Module { public: ParallelEmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim, @@ -84,8 +82,7 @@ class ParallelEmbeddingImpl : public llm::nn::Module { // register the weight parameter weight_ = register_parameter( "weight", - torch::empty({num_embeddings, embedding_dim_per_partition}, options), - /*requires_grad=*/false); + torch::empty({num_embeddings, embedding_dim_per_partition}, options)); } // The input to the module is a list of indices, and the output is the @@ -139,7 +136,7 @@ class ParallelEmbeddingImpl : public llm::nn::Module { LLM_MODULE(ParallelEmbedding); // Embedding parallelized in the vocabulary dimension -class VocabParallelEmbeddingImpl : public llm::nn::Module { +class VocabParallelEmbeddingImpl : public Module { public: VocabParallelEmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim, @@ -154,8 +151,7 @@ class VocabParallelEmbeddingImpl : public llm::nn::Module { // register the weight parameter weight_ = register_parameter( "weight", - torch::empty({num_embeddings_per_partition, embedding_dim}, options), - /*requires_grad=*/false); + torch::empty({num_embeddings_per_partition, embedding_dim}, options)); } // The input to the module is a list of indices, and the output is the diff --git a/src/layers/fused_linear.h b/src/layers/fused_linear.h index b63d8555..191922e6 100644 --- a/src/layers/fused_linear.h +++ b/src/layers/fused_linear.h @@ -12,7 +12,7 @@ namespace llm { -class FusedColumnParallelLinearImpl : public llm::nn::Module { +class FusedColumnParallelLinearImpl : public Module { public: FusedColumnParallelLinearImpl(int64_t in_features, const std::vector& out_features, diff --git a/src/layers/linear.h b/src/layers/linear.h index b65b6d4a..99da103b 100644 --- a/src/layers/linear.h +++ b/src/layers/linear.h @@ -16,7 +16,7 @@ using TensorTransform = std::function; // an interface for parallel linear layer. // all linear classes should inherit from this class and implement the forward // function. 
-class ParallelLinearImpl : public llm::nn::Module { +class ParallelLinearImpl : public Module { public: ~ParallelLinearImpl() override = default; @@ -39,10 +39,10 @@ class ParallelLinearImpl : public llm::nn::Module { } }; -class ColumnParallelLinear : public llm::nn::ModuleHolder { +class ColumnParallelLinear : public ModuleHolder { public: - using llm::nn::ModuleHolder::ModuleHolder; - using Impl __attribute__((__unused__)) = ParallelLinearImpl; + using ModuleHolder::ModuleHolder; + using Impl [[maybe_unused]] = ParallelLinearImpl; // construct a rotary positional embedding. // chose right implementation based on the args. @@ -62,10 +62,10 @@ class ColumnParallelLinear : public llm::nn::ModuleHolder { const torch::TensorOptions& options); }; -class RowParallelLinear : public llm::nn::ModuleHolder { +class RowParallelLinear : public ModuleHolder { public: - using llm::nn::ModuleHolder::ModuleHolder; - using Impl __attribute__((__unused__)) = ParallelLinearImpl; + using ModuleHolder::ModuleHolder; + using Impl [[maybe_unused]] = ParallelLinearImpl; // construct a rotary positional embedding. // chose right implementation based on the args. diff --git a/src/layers/linear_impl.cpp b/src/layers/linear_impl.cpp index 4138b347..3f301868 100644 --- a/src/layers/linear_impl.cpp +++ b/src/layers/linear_impl.cpp @@ -28,14 +28,11 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl( // we allocate the transpose. weight_ = register_parameter( "weight", - torch::empty({out_features_per_partition, in_features}, options), - /*requires_grad=*/false); + torch::empty({out_features_per_partition, in_features}, options)); if (bias) { - bias_ = - register_parameter("bias", - torch::empty({out_features_per_partition}, options), - /*requires_grad=*/false); + bias_ = register_parameter( + "bias", torch::empty({out_features_per_partition}, options)); } } @@ -104,13 +101,10 @@ RowParallelLinearImpl::RowParallelLinearImpl( // Allocate the transpose since linear performs XA^T. weight_ = register_parameter( "weight", - torch::empty({out_features, in_features_per_partition}, options), - /*requires_grad=*/false); + torch::empty({out_features, in_features_per_partition}, options)); if (bias) { - bias_ = register_parameter("bias", - torch::empty({out_features}, options), - /*requires_grad=*/false); + bias_ = register_parameter("bias", torch::empty({out_features}, options)); } } diff --git a/src/layers/normalization.h b/src/layers/normalization.h index 748e96fe..8996ce78 100644 --- a/src/layers/normalization.h +++ b/src/layers/normalization.h @@ -65,7 +65,7 @@ inline torch::Tensor layer_norm(torch::Tensor input, // apply layer normalization over a mini-batch of inputs as described in // the paper `Layer Normalization`: https://arxiv.org/abs/1607.06450 // x = ((x - mean(x)) / sqrt(std(x) + eps)) * weight + bias -class LayerNormImpl : public llm::nn::Module { +class LayerNormImpl : public Module { public: // dim: the dim over which the mean and std are calculated separately. // eps: a value added to the denominator for numerical stability. 
@@ -75,13 +75,11 @@ class LayerNormImpl : public llm::nn::Module { const torch::TensorOptions& options) : eps_(eps) { normalized_shape_ = {dim}; - weight_ = register_parameter("weight", - torch::empty(normalized_shape_, options), - /*requires_grad=*/false); + weight_ = + register_parameter("weight", torch::empty(normalized_shape_, options)); if (bias) { - bias_ = register_parameter("bias", - torch::zeros(normalized_shape_, options), - /*requires_grad=*/false); + bias_ = + register_parameter("bias", torch::zeros(normalized_shape_, options)); } } @@ -145,13 +143,11 @@ class LayerNormImpl : public llm::nn::Module { LLM_MODULE(LayerNorm); // Root mean square normalization -class RMSNormImpl : public llm::nn::Module { +class RMSNormImpl : public Module { public: RMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options) : eps_(eps) { - weight_ = register_parameter("weight", - torch::empty({dim}, options), - /*requires_grad=*/false); + weight_ = register_parameter("weight", torch::empty({dim}, options)); } torch::Tensor forward(const torch::Tensor& input) { @@ -195,13 +191,11 @@ class RMSNormImpl : public llm::nn::Module { }; LLM_MODULE(RMSNorm); -class GemmaRMSNormImpl : public llm::nn::Module { +class GemmaRMSNormImpl : public Module { public: GemmaRMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options) : eps_(eps) { - weight_ = register_parameter("weight", - torch::empty({dim}, options), - /*requires_grad=*/false); + weight_ = register_parameter("weight", torch::empty({dim}, options)); } torch::Tensor forward(const torch::Tensor& input) { @@ -246,15 +240,13 @@ class GemmaRMSNormImpl : public llm::nn::Module { LLM_MODULE(GemmaRMSNorm); // Root mean square normalization -class RMSNormResidualImpl : public llm::nn::Module { +class RMSNormResidualImpl : public Module { public: RMSNormResidualImpl(int64_t dim, float eps, const torch::TensorOptions& options) : eps_(eps) { - weight_ = register_parameter("weight", - torch::empty({dim}, options), - /*requires_grad=*/false); + weight_ = register_parameter("weight", torch::empty({dim}, options)); } torch::Tensor forward(const torch::Tensor& input, torch::Tensor& residual) { diff --git a/src/layers/pos_embedding.h b/src/layers/pos_embedding.h index dc3b9504..5d5b3bba 100644 --- a/src/layers/pos_embedding.h +++ b/src/layers/pos_embedding.h @@ -41,11 +41,11 @@ class RotaryEmbeddingImpl : public torch::nn::Module { // RotaryEmbedding is a wrapper class that chooses the right rotary positional // embedding implementation based on the args. -// Similar to TORCH_MODULE(RotaryEmbedding) except of the explicit constructor. +// Similar to LLM_MODULE(RotaryEmbedding) except of the explicit constructor. class RotaryEmbedding : public torch::nn::ModuleHolder { public: using torch::nn::ModuleHolder::ModuleHolder; - using Impl __attribute__((__unused__)) = RotaryEmbeddingImpl; + using Impl [[maybe_unused]] = RotaryEmbeddingImpl; // construct a rotary positional embedding. // chose right implementation based on the args. 
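The change repeated across the layer files above is the register_parameter() call: the explicit /*requires_grad=*/false argument is dropped because the refactored Module (see the module.cpp hunk near the end of this diff) now registers every parameter frozen. Below is a minimal sketch of a layer written against the new API, using the Module, register_parameter and LLM_MODULE pieces shown in this diff; ToyNormImpl is an illustrative name and is not part of the patch.

#include <torch/torch.h>

#include "module/module.h"
#include "module/module_holder.h"

namespace llm {

class ToyNormImpl : public Module {
 public:
  ToyNormImpl(int64_t dim, const torch::TensorOptions& options) {
    // no /*requires_grad=*/false argument: parameters are registered frozen
    weight_ = register_parameter("weight", torch::ones({dim}, options));
  }

  // element-wise scale, just enough to exercise the registered weight
  torch::Tensor forward(const torch::Tensor& input) { return input * weight_; }

 private:
  torch::Tensor weight_;
};
LLM_MODULE(ToyNorm);

}  // namespace llm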
diff --git a/src/layers/qkv_linear.h b/src/layers/qkv_linear.h index 1dab1c4b..2e4e2e39 100644 --- a/src/layers/qkv_linear.h +++ b/src/layers/qkv_linear.h @@ -6,13 +6,15 @@ #include "fused_linear.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" +#include "module/module.h" +#include "module/module_holder.h" #include "quantization/quant_args.h" namespace llm { // a thin wrapper to handle state_dict loading for QKV with // support of MQA/GQA -class QKVColumnParallelLinearImpl : public llm::nn::Module { +class QKVColumnParallelLinearImpl : public Module { public: QKVColumnParallelLinearImpl(int64_t hidden_size, int64_t n_heads, diff --git a/src/models/aquila.h b/src/models/aquila.h index 682456b6..59e7b3e4 100644 --- a/src/models/aquila.h +++ b/src/models/aquila.h @@ -16,12 +16,12 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // Aquila model compatible with huggingface weights namespace llm::hf { -class AquilaMLPImpl : public llm::nn::Module { +class AquilaMLPImpl : public Module { public: AquilaMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -82,7 +82,7 @@ class AquilaMLPImpl : public llm::nn::Module { }; LLM_MODULE(AquilaMLP); -class AquilaAttentionImpl : public llm::nn::Module { +class AquilaAttentionImpl : public Module { public: AquilaAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -169,7 +169,7 @@ class AquilaAttentionImpl : public llm::nn::Module { }; LLM_MODULE(AquilaAttention); -class AquilaDecoderLayerImpl : public llm::nn::Module { +class AquilaDecoderLayerImpl : public Module { public: AquilaDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -229,7 +229,7 @@ class AquilaDecoderLayerImpl : public llm::nn::Module { }; LLM_MODULE(AquilaDecoderLayer); -class AquilaModelImpl : public llm::nn::Module { +class AquilaModelImpl : public Module { public: AquilaModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -244,7 +244,7 @@ class AquilaModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = AquilaDecoderLayer( @@ -298,7 +298,7 @@ class AquilaModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -306,7 +306,7 @@ class AquilaModelImpl : public llm::nn::Module { }; LLM_MODULE(AquilaModel); -class AquilaForCausalLMImpl : public llm::nn::Module { +class AquilaForCausalLMImpl : public Module { public: AquilaForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/baichuan.h b/src/models/baichuan.h index f59700a5..db77fcdc 100644 --- a/src/models/baichuan.h +++ b/src/models/baichuan.h @@ -18,7 +18,7 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // Baichuan model compatible with huggingface weights @@ -31,7 +31,7 @@ enum class BaichuanType : uint8_t { Baichuan2_13B, }; -class BaichuanMLPImpl : public llm::nn::Module { +class BaichuanMLPImpl : public Module { 
public: BaichuanMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -92,7 +92,7 @@ class BaichuanMLPImpl : public llm::nn::Module { }; LLM_MODULE(BaichuanMLP); -class BaichuanAttentionImpl : public llm::nn::Module { +class BaichuanAttentionImpl : public Module { public: BaichuanAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -179,7 +179,7 @@ class BaichuanAttentionImpl : public llm::nn::Module { }; LLM_MODULE(BaichuanAttention); -class BaichuanDecoderLayerImpl : public llm::nn::Module { +class BaichuanDecoderLayerImpl : public Module { public: BaichuanDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -247,7 +247,7 @@ class BaichuanDecoderLayerImpl : public llm::nn::Module { }; LLM_MODULE(BaichuanDecoderLayer); -class BaichuanModelImpl : public llm::nn::Module { +class BaichuanModelImpl : public Module { public: BaichuanModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -271,7 +271,7 @@ class BaichuanModelImpl : public llm::nn::Module { args, alibi_slopes, options); } - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = BaichuanDecoderLayer(args, @@ -366,7 +366,7 @@ class BaichuanModelImpl : public llm::nn::Module { std::unique_ptr handler_{nullptr}; // parameter members, must be registered - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -375,7 +375,7 @@ class BaichuanModelImpl : public llm::nn::Module { }; LLM_MODULE(BaichuanModel); -class BaichuanForCausalLMImpl : public llm::nn::Module { +class BaichuanForCausalLMImpl : public Module { public: BaichuanForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/bloom.h b/src/models/bloom.h index cd80e03a..0499993b 100644 --- a/src/models/bloom.h +++ b/src/models/bloom.h @@ -15,13 +15,13 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // bloom model compatible with huggingface weights namespace llm::hf { -class BloomMLPImpl : public llm::nn::Module { +class BloomMLPImpl : public Module { public: BloomMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -79,7 +79,7 @@ class BloomMLPImpl : public llm::nn::Module { }; LLM_MODULE(BloomMLP); -class BloomAttentionImpl : public llm::nn::Module { +class BloomAttentionImpl : public Module { public: BloomAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -173,7 +173,7 @@ class BloomAttentionImpl : public llm::nn::Module { }; LLM_MODULE(BloomAttention); -class BloomBlockImpl : public llm::nn::Module { +class BloomBlockImpl : public Module { public: BloomBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -245,7 +245,7 @@ class BloomBlockImpl : public llm::nn::Module { }; LLM_MODULE(BloomBlock); -class BloomModelImpl : public llm::nn::Module { +class BloomModelImpl : public Module { public: BloomModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -268,7 +268,7 @@ class BloomModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_alibi( args, alibi_slopes, options); - blocks_ = register_module("h", llm::nn::ModuleList()); + blocks_ = register_module("h", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); 
i++) { auto block = @@ -362,7 +362,7 @@ class BloomModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -371,7 +371,7 @@ class BloomModelImpl : public llm::nn::Module { }; LLM_MODULE(BloomModel); -class BloomForCausalLMImpl : public llm::nn::Module { +class BloomForCausalLMImpl : public Module { public: BloomForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/chatglm.h b/src/models/chatglm.h index e4d6ce6b..cd8c71b3 100644 --- a/src/models/chatglm.h +++ b/src/models/chatglm.h @@ -16,14 +16,14 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" #include "tokenizer/tokenizer_args.h" // ChatGLM model compatible with huggingface weights namespace llm::hf { -class ChatGLMMLPImpl : public llm::nn::Module { +class ChatGLMMLPImpl : public Module { public: ChatGLMMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -82,7 +82,7 @@ class ChatGLMMLPImpl : public llm::nn::Module { }; LLM_MODULE(ChatGLMMLP); -class ChatGLMAttentionImpl : public llm::nn::Module { +class ChatGLMAttentionImpl : public Module { public: ChatGLMAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -167,7 +167,7 @@ class ChatGLMAttentionImpl : public llm::nn::Module { }; LLM_MODULE(ChatGLMAttention); -class ChatGLMBlockImpl : public llm::nn::Module { +class ChatGLMBlockImpl : public Module { public: ChatGLMBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -278,7 +278,7 @@ class ChatGLMBlockImpl : public llm::nn::Module { }; LLM_MODULE(ChatGLMBlock); -class ChatGLMModelImpl : public llm::nn::Module { +class ChatGLMModelImpl : public Module { public: ChatGLMModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -290,7 +290,7 @@ class ChatGLMModelImpl : public llm::nn::Module { args, /*interleaved=*/true, options); // register submodules - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = ChatGLMBlock( @@ -370,7 +370,7 @@ class ChatGLMModelImpl : public llm::nn::Module { std::unique_ptr handler_{nullptr}; // parameter members, must be registered - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -383,7 +383,7 @@ class ChatGLMModelImpl : public llm::nn::Module { }; LLM_MODULE(ChatGLMModel); -class ChatGLMForCausalLMImpl : public llm::nn::Module { +class ChatGLMForCausalLMImpl : public Module { public: ChatGLMForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/gemma.h b/src/models/gemma.h index dc5babea..dcb2c692 100644 --- a/src/models/gemma.h +++ b/src/models/gemma.h @@ -19,12 +19,12 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // Gemma model compatible with huggingface weight namespace llm::hf { -class GemmaMLPImpl : public llm::nn::Module { +class GemmaMLPImpl : public Module { public: GemmaMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -85,7 +85,7 @@ class GemmaMLPImpl : public llm::nn::Module { }; 
LLM_MODULE(GemmaMLP); -class GemmaAttentionImpl : public llm::nn::Module { +class GemmaAttentionImpl : public Module { public: GemmaAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -165,7 +165,7 @@ class GemmaAttentionImpl : public llm::nn::Module { }; LLM_MODULE(GemmaAttention); -class GemmaDecoderLayerImpl : public llm::nn::Module { +class GemmaDecoderLayerImpl : public Module { public: GemmaDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -234,7 +234,7 @@ class GemmaDecoderLayerImpl : public llm::nn::Module { }; LLM_MODULE(GemmaDecoderLayer); -class GemmaModelImpl : public llm::nn::Module { +class GemmaModelImpl : public Module { public: GemmaModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -261,7 +261,7 @@ class GemmaModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = GemmaDecoderLayer( @@ -320,13 +320,13 @@ class GemmaModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; }; LLM_MODULE(GemmaModel); -class GemmaForCausalLMImpl : public llm::nn::Module { +class GemmaForCausalLMImpl : public Module { public: GemmaForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/gemma2.h b/src/models/gemma2.h index c68140a1..6282af6e 100644 --- a/src/models/gemma2.h +++ b/src/models/gemma2.h @@ -19,12 +19,12 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // Gemma2 model compatible with huggingface weight namespace llm::hf { -class Gemma2MLPImpl : public llm::nn::Module { +class Gemma2MLPImpl : public Module { public: Gemma2MLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -85,7 +85,7 @@ class Gemma2MLPImpl : public llm::nn::Module { }; LLM_MODULE(Gemma2MLP); -class Gemma2AttentionImpl : public llm::nn::Module { +class Gemma2AttentionImpl : public Module { public: Gemma2AttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -173,7 +173,7 @@ class Gemma2AttentionImpl : public llm::nn::Module { }; LLM_MODULE(Gemma2Attention); -class Gemma2DecoderLayerImpl : public llm::nn::Module { +class Gemma2DecoderLayerImpl : public Module { public: Gemma2DecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -265,7 +265,7 @@ class Gemma2DecoderLayerImpl : public llm::nn::Module { }; LLM_MODULE(Gemma2DecoderLayer); -class Gemma2ModelImpl : public llm::nn::Module { +class Gemma2ModelImpl : public Module { public: Gemma2ModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -290,7 +290,7 @@ class Gemma2ModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { // Attention Type: [LOCAL_SLIDING, Global, LOCAL_SLIDING, Global, ...] 
@@ -356,13 +356,13 @@ class Gemma2ModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; }; LLM_MODULE(Gemma2Model); -class Gemma2ForCausalLMImpl : public llm::nn::Module { +class Gemma2ForCausalLMImpl : public Module { public: Gemma2ForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/gpt2.h b/src/models/gpt2.h index bb910d5d..afde1084 100644 --- a/src/models/gpt2.h +++ b/src/models/gpt2.h @@ -15,13 +15,13 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // gpt2 model compatible with huggingface weights namespace llm::hf { -class GPT2MLPImpl : public llm::nn::Module { +class GPT2MLPImpl : public Module { public: GPT2MLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -85,7 +85,7 @@ class GPT2MLPImpl : public llm::nn::Module { }; LLM_MODULE(GPT2MLP); -class GPT2AttentionImpl : public llm::nn::Module { +class GPT2AttentionImpl : public Module { public: GPT2AttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -173,7 +173,7 @@ class GPT2AttentionImpl : public llm::nn::Module { }; LLM_MODULE(GPT2Attention); -class GPT2BlockImpl : public llm::nn::Module { +class GPT2BlockImpl : public Module { public: GPT2BlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -235,7 +235,7 @@ class GPT2BlockImpl : public llm::nn::Module { }; LLM_MODULE(GPT2Block); -class GPT2ModelImpl : public llm::nn::Module { +class GPT2ModelImpl : public Module { public: GPT2ModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -252,7 +252,7 @@ class GPT2ModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler(args, options); - blocks_ = register_module("h", llm::nn::ModuleList()); + blocks_ = register_module("h", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -313,7 +313,7 @@ class GPT2ModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -321,7 +321,7 @@ class GPT2ModelImpl : public llm::nn::Module { }; LLM_MODULE(GPT2Model); -class GPT2ForCausalLMImpl : public llm::nn::Module { +class GPT2ForCausalLMImpl : public Module { public: GPT2ForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/gpt_j.h b/src/models/gpt_j.h index 3f889d07..3905860b 100644 --- a/src/models/gpt_j.h +++ b/src/models/gpt_j.h @@ -14,12 +14,12 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // GPTJ model compatible with huggingface weights namespace llm::hf { -class GPTJMLPImpl : public llm::nn::Module { +class GPTJMLPImpl : public Module { public: GPTJMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -73,7 +73,7 @@ class GPTJMLPImpl : public llm::nn::Module { }; LLM_MODULE(GPTJMLP); -class GPTJAttentionImpl : public llm::nn::Module { +class GPTJAttentionImpl : public Module { public: GPTJAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -145,7 +145,7 @@ class GPTJAttentionImpl : public llm::nn::Module 
{ }; LLM_MODULE(GPTJAttention); -class GPTJBlockImpl : public llm::nn::Module { +class GPTJBlockImpl : public Module { public: GPTJBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -200,7 +200,7 @@ class GPTJBlockImpl : public llm::nn::Module { }; LLM_MODULE(GPTJBlock); -class GPTJModelImpl : public llm::nn::Module { +class GPTJModelImpl : public Module { public: GPTJModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -215,7 +215,7 @@ class GPTJModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/true, options); - blocks_ = register_module("h", llm::nn::ModuleList()); + blocks_ = register_module("h", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -273,7 +273,7 @@ class GPTJModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -281,7 +281,7 @@ class GPTJModelImpl : public llm::nn::Module { }; LLM_MODULE(GPTJModel); -class GPTJForCausalLMImpl : public llm::nn::Module { +class GPTJForCausalLMImpl : public Module { public: GPTJForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/gpt_neox.h b/src/models/gpt_neox.h index e870fd3f..65a2a1a3 100644 --- a/src/models/gpt_neox.h +++ b/src/models/gpt_neox.h @@ -14,12 +14,12 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // gpt-neox model compatible with huggingface weights namespace llm::hf { -class GPTNeoXMLPImpl : public llm::nn::Module { +class GPTNeoXMLPImpl : public Module { public: GPTNeoXMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -77,7 +77,7 @@ class GPTNeoXMLPImpl : public llm::nn::Module { }; LLM_MODULE(GPTNeoXMLP); -class GPTNeoXAttentionImpl : public llm::nn::Module { +class GPTNeoXAttentionImpl : public Module { public: GPTNeoXAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -171,7 +171,7 @@ class GPTNeoXAttentionImpl : public llm::nn::Module { }; LLM_MODULE(GPTNeoXAttention); -class GPTNeoXLayerImpl : public llm::nn::Module { +class GPTNeoXLayerImpl : public Module { public: GPTNeoXLayerImpl(uint32_t layer_id, const ModelArgs& args, @@ -248,7 +248,7 @@ class GPTNeoXLayerImpl : public llm::nn::Module { }; LLM_MODULE(GPTNeoXLayer); -class GPTNeoXModelImpl : public llm::nn::Module { +class GPTNeoXModelImpl : public Module { public: GPTNeoXModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -263,7 +263,7 @@ class GPTNeoXModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = GPTNeoXLayer( @@ -321,7 +321,7 @@ class GPTNeoXModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -329,7 +329,7 @@ class GPTNeoXModelImpl : public llm::nn::Module { }; LLM_MODULE(GPTNeoXModel); -class GPTNeoXForCausalLMImpl : public 
llm::nn::Module { +class GPTNeoXForCausalLMImpl : public Module { public: GPTNeoXForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/internlm.h b/src/models/internlm.h index 8958cf98..8f100c32 100644 --- a/src/models/internlm.h +++ b/src/models/internlm.h @@ -16,11 +16,11 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // Internlm model compatible with huggingface weights namespace llm::hf { -class InternlmMLPImpl : public llm::nn::Module { +class InternlmMLPImpl : public Module { public: InternlmMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -81,7 +81,7 @@ class InternlmMLPImpl : public llm::nn::Module { }; LLM_MODULE(InternlmMLP); -class InternlmAttentionImpl : public llm::nn::Module { +class InternlmAttentionImpl : public Module { public: InternlmAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -155,7 +155,7 @@ class InternlmAttentionImpl : public llm::nn::Module { }; LLM_MODULE(InternlmAttention); -class InternlmDecoderLayerImpl : public llm::nn::Module { +class InternlmDecoderLayerImpl : public Module { public: InternlmDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -215,7 +215,7 @@ class InternlmDecoderLayerImpl : public llm::nn::Module { }; LLM_MODULE(InternlmDecoderLayer); -class InternlmModelImpl : public llm::nn::Module { +class InternlmModelImpl : public Module { public: InternlmModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -230,7 +230,7 @@ class InternlmModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = InternlmDecoderLayer( @@ -285,7 +285,7 @@ class InternlmModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -293,7 +293,7 @@ class InternlmModelImpl : public llm::nn::Module { }; LLM_MODULE(InternlmModel); -class InternlmForCausalLMImpl : public llm::nn::Module { +class InternlmForCausalLMImpl : public Module { public: InternlmForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/llama.h b/src/models/llama.h index 959ba552..120ff48d 100644 --- a/src/models/llama.h +++ b/src/models/llama.h @@ -18,11 +18,11 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // llama2 model compatible with huggingface weights namespace llm::hf { -class LlamaMLPImpl : public llm::nn::Module { +class LlamaMLPImpl : public Module { public: LlamaMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -84,7 +84,7 @@ class LlamaMLPImpl : public llm::nn::Module { }; LLM_MODULE(LlamaMLP); -class LlamaAttentionImpl : public llm::nn::Module { +class LlamaAttentionImpl : public Module { public: LlamaAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -167,7 +167,7 @@ class LlamaAttentionImpl : public llm::nn::Module { }; LLM_MODULE(LlamaAttention); -class LlamaDecoderLayerImpl : public llm::nn::Module { +class LlamaDecoderLayerImpl 
: public Module { public: LlamaDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -227,7 +227,7 @@ class LlamaDecoderLayerImpl : public llm::nn::Module { }; LLM_MODULE(LlamaDecoderLayer); -class LlamaModelImpl : public llm::nn::Module { +class LlamaModelImpl : public Module { public: LlamaModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -242,7 +242,7 @@ class LlamaModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = LlamaDecoderLayer( @@ -297,7 +297,7 @@ class LlamaModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -305,7 +305,7 @@ class LlamaModelImpl : public llm::nn::Module { }; LLM_MODULE(LlamaModel); -class LlamaForCausalLMImpl : public llm::nn::Module { +class LlamaForCausalLMImpl : public Module { public: LlamaForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/mistral.h b/src/models/mistral.h index 120fcd4a..9b8e8ad0 100644 --- a/src/models/mistral.h +++ b/src/models/mistral.h @@ -15,11 +15,11 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // Mistral model compatible with huggingface weights namespace llm::hf { -class MistralMLPImpl : public llm::nn::Module { +class MistralMLPImpl : public Module { public: MistralMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -79,7 +79,7 @@ class MistralMLPImpl : public llm::nn::Module { }; LLM_MODULE(MistralMLP); -class MistralAttentionImpl : public llm::nn::Module { +class MistralAttentionImpl : public Module { public: MistralAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -158,7 +158,7 @@ class MistralAttentionImpl : public llm::nn::Module { }; LLM_MODULE(MistralAttention); -class MistralDecoderLayerImpl : public llm::nn::Module { +class MistralDecoderLayerImpl : public Module { public: MistralDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -218,7 +218,7 @@ class MistralDecoderLayerImpl : public llm::nn::Module { }; LLM_MODULE(MistralDecoderLayer); -class MistralModelImpl : public llm::nn::Module { +class MistralModelImpl : public Module { public: MistralModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -233,7 +233,7 @@ class MistralModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = MistralDecoderLayer( @@ -288,7 +288,7 @@ class MistralModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -296,7 +296,7 @@ class MistralModelImpl : public llm::nn::Module { }; LLM_MODULE(MistralModel); -class MistralForCausalLMImpl : public 
llm::nn::Module { +class MistralForCausalLMImpl : public Module { public: MistralForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/model_registry.h b/src/models/model_registry.h index b472fd50..a15f7138 100644 --- a/src/models/model_registry.h +++ b/src/models/model_registry.h @@ -88,7 +88,6 @@ class ModelRegistry { const ParallelArgs& parallel_args, \ const torch::TensorOptions& options) { \ ModelClass model(args, quant_args, parallel_args, options); \ - model->eval(); \ return std::make_unique>( \ std::move(model), options); \ }); \ diff --git a/src/models/mpt.h b/src/models/mpt.h index 86249921..b435b828 100644 --- a/src/models/mpt.h +++ b/src/models/mpt.h @@ -17,11 +17,11 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // mpt model compatible with huggingface weights namespace llm::hf { -class MPTMLPImpl : public llm::nn::Module { +class MPTMLPImpl : public Module { public: MPTMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -78,7 +78,7 @@ class MPTMLPImpl : public llm::nn::Module { }; LLM_MODULE(MPTMLP); -class MPTAttentionImpl : public llm::nn::Module { +class MPTAttentionImpl : public Module { public: MPTAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -234,7 +234,7 @@ class MPTAttentionImpl : public llm::nn::Module { }; LLM_MODULE(MPTAttention); -class MPTBlockImpl : public llm::nn::Module { +class MPTBlockImpl : public Module { public: MPTBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -294,7 +294,7 @@ class MPTBlockImpl : public llm::nn::Module { }; LLM_MODULE(MPTBlock); -class MPTModelImpl : public llm::nn::Module { +class MPTModelImpl : public Module { public: MPTModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -312,7 +312,7 @@ class MPTModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_alibi( args, alibi_slopes, options); - blocks_ = register_module("blocks", llm::nn::ModuleList()); + blocks_ = register_module("blocks", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -390,7 +390,7 @@ class MPTModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -398,7 +398,7 @@ class MPTModelImpl : public llm::nn::Module { }; LLM_MODULE(MPTModel); -class MPTForCausalLMImpl : public llm::nn::Module { +class MPTForCausalLMImpl : public Module { public: MPTForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/phi.h b/src/models/phi.h index ac9ba726..99e8dc07 100644 --- a/src/models/phi.h +++ b/src/models/phi.h @@ -14,12 +14,12 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // Phi model compatible with huggingface weights namespace llm::hf { -class PhiMLPImpl : public llm::nn::Module { +class PhiMLPImpl : public Module { public: PhiMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -73,7 +73,7 @@ class PhiMLPImpl : public llm::nn::Module { }; LLM_MODULE(PhiMLP); -class PhiAttentionImpl : public llm::nn::Module { +class PhiAttentionImpl : public Module { public: PhiAttentionImpl(const ModelArgs& args, const QuantArgs& 
quant_args, @@ -160,7 +160,7 @@ class PhiAttentionImpl : public llm::nn::Module { }; LLM_MODULE(PhiAttention); -class PhiBlockImpl : public llm::nn::Module { +class PhiBlockImpl : public Module { public: PhiBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -215,7 +215,7 @@ class PhiBlockImpl : public llm::nn::Module { }; LLM_MODULE(PhiBlock); -class PhiModelImpl : public llm::nn::Module { +class PhiModelImpl : public Module { public: PhiModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -230,7 +230,7 @@ class PhiModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("h", llm::nn::ModuleList()); + blocks_ = register_module("h", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -281,13 +281,13 @@ class PhiModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; }; LLM_MODULE(PhiModel); -class PhiLMHeadImpl : public llm::nn::Module { +class PhiLMHeadImpl : public Module { public: PhiLMHeadImpl(const ModelArgs& args, const ParallelArgs& parallel_args, @@ -329,7 +329,7 @@ class PhiLMHeadImpl : public llm::nn::Module { }; LLM_MODULE(PhiLMHead); -class PhiForCausalLMImpl : public llm::nn::Module { +class PhiForCausalLMImpl : public Module { public: PhiForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/qwen.h b/src/models/qwen.h index 3884eb8c..d1867e4e 100644 --- a/src/models/qwen.h +++ b/src/models/qwen.h @@ -18,12 +18,12 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // QWen model compatible with huggingface weights // adopted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py namespace llm::hf { -class QWenMLPImpl : public llm::nn::Module { +class QWenMLPImpl : public Module { public: QWenMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -83,7 +83,7 @@ class QWenMLPImpl : public llm::nn::Module { }; LLM_MODULE(QWenMLP); -class QWenAttentionImpl : public llm::nn::Module { +class QWenAttentionImpl : public Module { public: QWenAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -156,7 +156,7 @@ class QWenAttentionImpl : public llm::nn::Module { }; LLM_MODULE(QWenAttention); -class QWenBlockImpl : public llm::nn::Module { +class QWenBlockImpl : public Module { public: QWenBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -211,7 +211,7 @@ class QWenBlockImpl : public llm::nn::Module { }; LLM_MODULE(QWenBlock); -class QWenModelImpl : public llm::nn::Module { +class QWenModelImpl : public Module { public: QWenModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -226,7 +226,7 @@ class QWenModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -281,7 +281,7 @@ class QWenModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + 
ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -289,7 +289,7 @@ class QWenModelImpl : public llm::nn::Module { }; LLM_MODULE(QWenModel); -class QWenForCausalLMImpl : public llm::nn::Module { +class QWenForCausalLMImpl : public Module { public: QWenForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/qwen2.h b/src/models/qwen2.h index df64ccd1..d93de88a 100644 --- a/src/models/qwen2.h +++ b/src/models/qwen2.h @@ -19,13 +19,13 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // QWen2 model compatible with huggingface weights // ref to: // https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/qwen2/modeling_qwen2.py namespace llm::hf { -class QWen2MLPImpl : public llm::nn::Module { +class QWen2MLPImpl : public Module { public: QWen2MLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -86,7 +86,7 @@ class QWen2MLPImpl : public llm::nn::Module { }; LLM_MODULE(QWen2MLP); -class QWen2AttentionImpl : public llm::nn::Module { +class QWen2AttentionImpl : public Module { public: QWen2AttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -170,7 +170,7 @@ class QWen2AttentionImpl : public llm::nn::Module { }; LLM_MODULE(QWen2Attention); -class QWen2DecoderLayerImpl : public llm::nn::Module { +class QWen2DecoderLayerImpl : public Module { public: QWen2DecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -238,7 +238,7 @@ class QWen2DecoderLayerImpl : public llm::nn::Module { }; LLM_MODULE(QWen2DecoderLayer); -class QWen2ModelImpl : public llm::nn::Module { +class QWen2ModelImpl : public Module { public: QWen2ModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -253,7 +253,7 @@ class QWen2ModelImpl : public llm::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { int32_t sliding_window = -1; @@ -318,7 +318,7 @@ class QWen2ModelImpl : public llm::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -326,7 +326,7 @@ class QWen2ModelImpl : public llm::nn::Module { }; LLM_MODULE(QWen2Model); -class QWen2ForCausalLMImpl : public llm::nn::Module { +class QWen2ForCausalLMImpl : public Module { public: QWen2ForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/models/simple_model.h b/src/models/simple_model.h index ae727cff..dc22d3a0 100644 --- a/src/models/simple_model.h +++ b/src/models/simple_model.h @@ -13,11 +13,11 @@ #include "models/parameters.h" #include "module/module.h" #include "module/module_holder.h" -#include "module/modulelist.h" +#include "module/module_list.h" // simple model for test namespace llm { -class SimpleMLPImpl : public llm::nn::Module { +class SimpleMLPImpl : public Module { public: SimpleMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -73,7 +73,7 @@ class SimpleMLPImpl : public llm::nn::Module { LLM_MODULE(SimpleMLP); -class SimpleDecoderLayerImpl : public llm::nn::Module { +class SimpleDecoderLayerImpl : public 
Module { public: SimpleDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -104,7 +104,7 @@ class SimpleDecoderLayerImpl : public llm::nn::Module { LLM_MODULE(SimpleDecoderLayer); -class SimpleModelImpl : public llm::nn::Module { +class SimpleModelImpl : public Module { public: SimpleModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -115,7 +115,7 @@ class SimpleModelImpl : public llm::nn::Module { ParallelEmbedding( args.vocab_size(), args.hidden_size(), parallel_args, options)); - blocks_ = register_module("layers", llm::nn::ModuleList()); + blocks_ = register_module("layers", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = SimpleDecoderLayer(args, quant_args, parallel_args, options); @@ -154,12 +154,12 @@ class SimpleModelImpl : public llm::nn::Module { private: ParallelEmbedding embed_tokens_{nullptr}; - llm::nn::ModuleList blocks_{nullptr}; + ModuleList blocks_{nullptr}; std::vector layers_; }; LLM_MODULE(SimpleModel); -class SimpleForCausalLMImpl : public llm::nn::Module { +class SimpleForCausalLMImpl : public Module { public: SimpleForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, diff --git a/src/module/CMakeLists.txt b/src/module/CMakeLists.txt index e06006f6..ceaab617 100644 --- a/src/module/CMakeLists.txt +++ b/src/module/CMakeLists.txt @@ -5,7 +5,7 @@ cc_library( module HDRS module.h - modulelist.h + module_list.h module_holder.h SRCS module.cpp diff --git a/src/module/cloneable.h b/src/module/cloneable.h deleted file mode 100644 index 51821a99..00000000 --- a/src/module/cloneable.h +++ /dev/null @@ -1,90 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -#include -#include - -#include "module.h" - -namespace llm::nn { -/// The `clone()` method in the base `Module` class does not have knowledge of -/// the concrete runtime type of its subclasses. Therefore, `clone()` must -/// either be called from within the subclass, or from a base class that has -/// knowledge of the concrete type. `Cloneable` uses the CRTP to gain -/// knowledge of the subclass' static type and provide an implementation of the -/// `clone()` method. We do not want to use this pattern in the base class, -/// because then storing a module would always require templatizing it. -template -// NOLINTNEXTLINE(bugprone-exception-escape) -class Cloneable : public llm::nn::Module { - public: - using Module::Module; - - /// `reset()` must perform initialization of all members with reference - /// semantics, most importantly parameters, buffers and submodules. - virtual void reset() = 0; - - /// Performs a recursive "deep copy" of the `Module`, such that all parameters - /// and submodules in the cloned module are different from those in the - /// original module. - std::shared_ptr clone( - const std::optional& device = std::nullopt) const override { - NoGradGuard no_grad; - - const auto& self = static_cast(*this); - auto copy = std::make_shared(self); - copy->parameters_.clear(); - copy->buffers_.clear(); - copy->children_.clear(); - copy->reset(); - TORCH_CHECK(copy->parameters_.size() == parameters_.size(), - "The cloned module does not have the same number of " - "parameters as the original module after calling reset(). " - "Are you sure you called register_parameter() inside reset() " - "and not the constructor?"); - for (const auto& parameter : named_parameters(/*recurse=*/false)) { - auto& tensor = *parameter; - auto data = device && tensor.device() != *device ? 
tensor.to(*device) - : tensor.clone(); - copy->parameters_[parameter.key()].set_data(data); - } - TORCH_CHECK(copy->buffers_.size() == buffers_.size(), - "The cloned module does not have the same number of " - "buffers as the original module after calling reset(). " - "Are you sure you called register_buffer() inside reset() " - "and not the constructor?"); - for (const auto& buffer : named_buffers(/*recurse=*/false)) { - auto& tensor = *buffer; - auto data = device && tensor.device() != *device ? tensor.to(*device) - : tensor.clone(); - copy->buffers_[buffer.key()].set_data(data); - } - TORCH_CHECK(copy->children_.size() == children_.size(), - "The cloned module does not have the same number of " - "child modules as the original module after calling reset(). " - "Are you sure you called register_module() inside reset() " - "and not the constructor?"); - for (const auto& child : children_) { - copy->children_[child.key()]->clone_(*child.value(), device); - } - return copy; - } - - private: - void clone_(Module& other, const std::optional& device) final { - // Here we are *pretty* certain that `other's` type is `Derived` (because it - // was registered under the same name as `this`), but you never know what - // crazy things `reset()` does, so `dynamic_cast` just to be safe. - auto clone = std::dynamic_pointer_cast(other.clone(device)); - TORCH_CHECK(clone != nullptr, - "Attempted to clone submodule, but it is of a " - "different type than the submodule it was to be cloned into"); - static_cast(*this) = *clone; - } -}; - -} // namespace llm::nn diff --git a/src/module/module.cpp b/src/module/module.cpp index c6a332e4..45f2546c 100644 --- a/src/module/module.cpp +++ b/src/module/module.cpp @@ -1,14 +1,10 @@ #include "module.h" -#include -#include -#include - -#include -#include #include -namespace llm::nn { +namespace llm { +using namespace torch; + namespace { /// Joins names hierarchically: "name_prefix.name" if `name_prefix` is /// non-empty, else just "name". @@ -38,93 +34,13 @@ Module::Module(std::string name) : Module() { const std::string& Module::name() const noexcept { // If the name optional is empty at this point, we grab the name of the - // dynamic type via RTTI. Note that we cannot do this in the constructor, - // because in the constructor of a base class `this` always refers to the base - // type. Inheritance effectively does not work in constructors. Also this note - // from http://en.cppreference.com/w/cpp/language/typeid: - // If typeid is used on an object under construction or destruction (in a - // destructor or in a constructor, including constructor's initializer list - // or default member initializers), then the std::type_info object referred - // to by this typeid represents the class that is being constructed or - // destroyed even if it is not the most-derived class. + // dynamic type via RTTI. if (!name_.has_value()) { name_ = c10::demangle(typeid(*this).name()); -#if defined(_WIN32) - // Windows adds "struct" or "class" as a prefix. - if (name_->find("struct ") == 0) { - name_->erase(name_->begin(), name_->begin() + 7); - } else if (name_->find("class ") == 0) { - name_->erase(name_->begin(), name_->begin() + 6); - } -#endif // defined(_WIN32) } return *name_; } -std::shared_ptr Module::clone( - const std::optional& device) const { - TORCH_CHECK( - false, - "clone() has not been implemented for ", - name(), - ". 
Subclass torch::nn::Cloneable<", - name(), - "> instead of torch::nn::Module to inherit the ability to clone."); -} - -void Module::apply(const ModuleApplyFunction& function) { - function(*this); - apply_to_submodules( - [&function](const std::string&, const std::shared_ptr& module) { - function(*module); - }); -} - -void Module::apply(const ConstModuleApplyFunction& function) const { - function(*this); - apply_to_submodules( - [&function](const std::string&, const std::shared_ptr& module) { - function(*module); - }); -} - -void Module::apply(const NamedModuleApplyFunction& function, - const std::string& name_prefix) { - function(/*name=*/name_prefix, *this); - apply_to_submodules( - [&function](const std::string& name, - const std::shared_ptr& module) { - function(name, *module); - }, - name_prefix); -} - -void Module::apply(const ConstNamedModuleApplyFunction& function, - const std::string& name_prefix) const { - function(/*name=*/name_prefix, *this); - apply_to_submodules( - [&function](const std::string& name, - const std::shared_ptr& module) { - function(name, *module); - }, - name_prefix); -} - -void Module::apply(const ModulePointerApplyFunction& function) const { - function(shared_from_this_checked()); - apply_to_submodules( - [&function](const std::string&, const std::shared_ptr& module) { - function(module); - }); -} - -void Module::apply(const NamedModulePointerApplyFunction& function, - const std::string& name_prefix) const { - function( - /*name=*/name_prefix, shared_from_this_checked()); - apply_to_submodules(function, name_prefix); -} - std::vector Module::parameters(bool recurse) const { return named_parameters(recurse).values(); } @@ -217,96 +133,78 @@ OrderedDict> Module::named_children() return children_; } -void Module::train(bool on) { - for (auto& child : children_) { - child.value()->train(on); - } - is_training_ = on; +void Module::apply(const ModuleApplyFunction& function) { + function(*this); + apply_to_submodules( + [&function](const std::string&, const std::shared_ptr& module) { + function(*module); + }); } -void Module::eval() { train(/*on=*/false); } - -void Module::to(torch::Device device, torch::Dtype dtype, bool non_blocking) { - to_impl(device, dtype, non_blocking); +void Module::apply(const ConstModuleApplyFunction& function) const { + function(*this); + apply_to_submodules( + [&function](const std::string&, const std::shared_ptr& module) { + function(*module); + }); } -void Module::to(torch::Dtype dtype, bool non_blocking) { - to_impl(dtype, non_blocking); +void Module::apply(const NamedModuleApplyFunction& function, + const std::string& name_prefix) { + function(/*name=*/name_prefix, *this); + apply_to_submodules( + [&function](const std::string& name, + const std::shared_ptr& module) { + function(name, *module); + }, + name_prefix); } -void Module::to(torch::Device device, bool non_blocking) { - to_impl(device, non_blocking); +void Module::apply(const ConstNamedModuleApplyFunction& function, + const std::string& name_prefix) const { + function(/*name=*/name_prefix, *this); + apply_to_submodules( + [&function](const std::string& name, + const std::shared_ptr& module) { + function(name, *module); + }, + name_prefix); } -bool Module::is_training() const noexcept { return is_training_; } - -void Module::zero_grad(bool set_to_none) { - for (auto& child : children_) { - child.value()->zero_grad(set_to_none); - } - for (auto& parameter : named_parameters(/*recurse=*/false)) { - auto& grad = parameter->mutable_grad(); - if (grad.defined()) { - grad = 
grad.detach(); +void Module::apply(const ModulePointerApplyFunction& function) const { + function(shared_from_this_checked()); + apply_to_submodules( + [&function](const std::string&, const std::shared_ptr& module) { + function(module); + }); +} - if (set_to_none) - grad.reset(); - else - grad.zero_(); - } - } +void Module::apply(const NamedModulePointerApplyFunction& function, + const std::string& name_prefix) const { + function( + /*name=*/name_prefix, shared_from_this_checked()); + apply_to_submodules(function, name_prefix); } -void Module::save(serialize::OutputArchive& archive) const { - for (const auto& parameter : named_parameters(/*recurse=*/false)) { - archive.write(parameter.key(), parameter.value()); - } - for (const auto& buffer : named_buffers(/*recurse=*/false)) { - archive.write(buffer.key(), buffer.value(), /*is_buffer=*/true); - } - for (const auto& child : children_) { - if (child.value()->is_serializable()) { - serialize::OutputArchive child_archive(archive.compilation_unit()); - child.value()->save(child_archive); - archive.write(child.key(), child_archive); - } - } +void Module::to(torch::Device device, torch::Dtype dtype, bool non_blocking) { + to_impl(device, dtype, non_blocking); } -void Module::load(serialize::InputArchive& archive) { - for (auto& parameter : named_parameters(/*recurse=*/false)) { - archive.read(parameter.key(), parameter.value()); - } - for (auto& buffer : named_buffers(/*recurse=*/false)) { - archive.read(buffer.key(), buffer.value(), /*is_buffer=*/true); - } - for (const auto& child : children_) { - if (child.value()->is_serializable()) { - serialize::InputArchive child_archive; - archive.read(child.key(), child_archive); - child.value()->load(child_archive); - } - } +void Module::to(torch::Dtype dtype, bool non_blocking) { + to_impl(dtype, non_blocking); } -bool Module::is_serializable() const { return true; } +void Module::to(torch::Device device, bool non_blocking) { + to_impl(device, non_blocking); +} -Tensor& Module::register_parameter(std::string name, - Tensor tensor, - bool requires_grad) { +Tensor& Module::register_parameter(std::string name, Tensor tensor) { TORCH_CHECK(!name.empty(), "Parameter name must not be empty"); TORCH_CHECK(name.find('.') == std::string::npos, "Parameter name must not contain a dot (got '", name, "')"); - if (!tensor.defined()) { - if (requires_grad) { - TORCH_WARN("An undefined tensor cannot require grad. 
", - "Ignoring the `requires_grad=true` function parameter."); - } - } else { - tensor.set_requires_grad(requires_grad); - } + tensor.set_requires_grad(false); return parameters_.insert(std::move(name), std::move(tensor)); } @@ -344,8 +242,7 @@ void Module::pretty_print_recursive(std::ostream& stream, } } -void Module::clone_(Module& other, const std::optional& device) {} - +// NOLINTNEXTLINE(misc-no-recursion) void Module::apply_to_submodules( const NamedModulePointerApplyFunction& function, const std::string& name_prefix) const { @@ -375,23 +272,8 @@ std::shared_ptr Module::shared_from_this_checked() const { return std::const_pointer_cast(ptr); } -std::ostream& operator<<(std::ostream& stream, const nn::Module& module) { +std::ostream& operator<<(std::ostream& stream, const Module& module) { module.pretty_print_recursive(stream, ""); return stream; } - -serialize::OutputArchive& operator<<( - serialize::OutputArchive& archive, - const std::shared_ptr& module) { - TORCH_CHECK(module != nullptr, "Cannot serialize empty module"); - module->save(archive); - return archive; -} - -serialize::InputArchive& operator>>(serialize::InputArchive& archive, - const std::shared_ptr& module) { - TORCH_CHECK(module != nullptr, "Cannot deserialize empty module"); - module->load(archive); - return archive; -} -} // namespace llm::nn +} // namespace llm diff --git a/src/module/module.h b/src/module/module.h index afcf813e..69a63109 100644 --- a/src/module/module.h +++ b/src/module/module.h @@ -1,80 +1,28 @@ #pragma once #include -#include -#include -#include #include -#include #include #include #include -#include #include #include -#include #include "module_holder.h" -namespace llm::nn { -using namespace torch; +namespace llm { -/// The base class for all modules in PyTorch. -/// -/// \rst -/// .. note:: -/// The design and implementation of this class is largely based on the Python -/// API. You may want to consult the python documentation for -/// :py:class:`pytorch:torch.nn.Module` for further clarification on certain -/// methods or behavior. -/// \endrst +/// The base class for all modules. /// /// A `Module` is an abstraction over the implementation of some function or -/// algorithm, possibly associated with some persistent data. A `Module` may -/// contain further `Module`s ("submodules"), each with their own -/// implementation, persistent data and further submodules. `Module`s can thus -/// be said to form a recursive tree structure. A `Module` is registered as a +/// algorithm. A `Module` may contain further `Module`s ("submodules"), each +/// with their own implementation and further submodules. `Module`s can thus be +/// said to form a recursive tree structure. A `Module` is registered as a /// submodule to another `Module` by calling `register_module()`, typically from /// within a parent module's constructor. -/// -/// A distinction is made between three kinds of persistent data that may be -/// associated with a `Module`: -/// -/// 1. *Parameters*: tensors that record gradients, typically weights updated -/// during the backward step (e.g. the `weight` of a `Linear` module), -/// 2. *Buffers*: tensors that do not record gradients, typically updated during -/// the forward step, such as running statistics (e.g. `mean` and `variance` -/// in the `BatchNorm` module), -/// 3. Any additional state, not necessarily tensors, required for the -/// implementation or configuration of a `Module`. 
-/// -/// The first two kinds of state are special in that they may be registered -/// with the `Module` system to allow convenient access and batch configuration. -/// For example, registered parameters in any `Module` may be iterated over via -/// the `parameters()` accessor. Further, changing the data type of a `Module`'s -/// registered parameters can be done conveniently via `Module::to()`, e.g. -/// `module->to(torch::kCUDA)` to move all parameters to GPU memory. Lastly, -/// registered parameters and buffers are handled specially during a `clone()` -/// operation, which performs a deepcopy of a cloneable `Module` hierarchy. -/// -/// Parameters are registered with a `Module` via `register_parameter`. Buffers -/// are registered separately via `register_buffer`. These methods are part of -/// the public API of `Module` and are typically invoked from within a -/// concrete `Module`s constructor. class Module : public std::enable_shared_from_this { public: - using ModuleApplyFunction = std::function; - using ConstModuleApplyFunction = std::function; - using NamedModuleApplyFunction = - std::function; - using ConstNamedModuleApplyFunction = - std::function; - using ModulePointerApplyFunction = - std::function&)>; - using NamedModulePointerApplyFunction = - std::function&)>; - /// Tells the base `Module` about the name of the submodule. explicit Module(std::string name); @@ -82,6 +30,8 @@ class Module : public std::enable_shared_from_this { /// The name of the submodule is inferred via RTTI (if possible) the first /// time `.name()` is invoked. Module(); + + // default copy/move constructors and assignment operators Module(const Module&) = default; Module& operator=(const Module&) = default; Module(Module&&) noexcept = default; @@ -90,148 +40,29 @@ class Module : public std::enable_shared_from_this { virtual ~Module() = default; /// Returns the name of the `Module`. - /// - /// A `Module` has an associated `name`, which is a string representation of - /// the kind of concrete `Module` it represents, such as `"Linear"` for the - /// `Linear` module. Under most circumstances, this name is automatically - /// inferred via runtime type information (RTTI). In the unusual circumstance - /// that you have this feature disabled, you may want to manually name your - /// `Module`s by passing the string name to the `Module` base class' - /// constructor. const std::string& name() const noexcept; - /// Performs a recursive deep copy of the module and all its registered - /// parameters, buffers and submodules. - /// - /// Optionally, this method sets the current device - /// to the one supplied before cloning. If no device is given, each - /// parameter and buffer will be moved to the device of its source. - /// - /// \rst - /// .. attention:: - /// Attempting to call the `clone()` method inherited from the base `Module` - /// class (the one documented here) will fail. To inherit an actual - /// implementation of `clone()`, you must subclass `Cloneable`. `Cloneable` - /// is templatized on the concrete module type, and can thus properly copy a - /// `Module`. This method is provided on the base class' API solely for an - /// easier-to-use polymorphic interface. - /// \endrst - virtual std::shared_ptr clone( - const std::optional& device = std::nullopt) const; - - /// Applies the `function` to the `Module` and recursively to every submodule. - /// The function must accept a `Module&`. - /// - /// \rst - /// .. 
code-block:: cpp - /// MyModule module; - /// module->apply([](nn::Module& module) { - /// std::cout << module.name() << std::endl; - /// }); - /// \endrst - void apply(const ModuleApplyFunction& function); - - /// Applies the `function` to the `Module` and recursively to every submodule. - /// The function must accept a `const Module&`. - /// - /// \rst - /// .. code-block:: cpp - /// MyModule module; - /// module->apply([](const nn::Module& module) { - /// std::cout << module.name() << std::endl; - /// }); - /// \endrst - void apply(const ConstModuleApplyFunction& function) const; - - /// Applies the `function` to the `Module` and recursively to every submodule. - /// The function must accept a `const std::string&` for the key of the module, - /// and a `Module&`. The key of the module itself is the empty string. If - /// `name_prefix` is given, it is prepended to every key as - /// `.` (and just `name_prefix` for the module itself). - /// - /// \rst - /// .. code-block:: cpp - /// MyModule module; - /// module->apply([](const std::string& key, nn::Module& module) { - /// std::cout << key << ": " << module.name() << std::endl; - /// }); - /// \endrst - void apply(const NamedModuleApplyFunction& function, - const std::string& name_prefix = std::string()); - - /// Applies the `function` to the `Module` and recursively to every submodule. - /// The function must accept a `const std::string&` for the key of the module, - /// and a `const Module&`. The key of the module itself is the empty string. - /// If `name_prefix` is given, it is prepended to every key as - /// `.` (and just `name_prefix` for the module itself). - /// - /// \rst - /// .. code-block:: cpp - /// MyModule module; - /// module->apply([](const std::string& key, const nn::Module& module) { - /// std::cout << key << ": " << module.name() << std::endl; - /// }); - /// \endrst - void apply(const ConstNamedModuleApplyFunction& function, - const std::string& name_prefix = std::string()) const; - - /// Applies the `function` to the `Module` and recursively to every submodule. - /// The function must accept a `const std::shared_ptr&`. - /// - /// \rst - /// .. code-block:: cpp - /// MyModule module; - /// module->apply([](const std::shared_ptr& module) { - /// std::cout << module->name() << std::endl; - /// }); - /// \endrst - void apply(const ModulePointerApplyFunction& function) const; - - /// Applies the `function` to the `Module` and recursively to every submodule. - /// The function must accept a `const std::string&` for the key of the module, - /// and a `const std::shared_ptr&`. The key of the module itself is - /// the empty string. If `name_prefix` is given, it is prepended to every key - /// as - /// `.` (and just `name_prefix` for the module itself). - /// - /// \rst - /// .. code-block:: cpp - /// MyModule module; - /// module->apply([](const std::string& key, - /// const std::shared_ptr& module) { - /// std::cout << key << ": " << module->name() << std::endl; - /// }); - /// \endrst - void apply(const NamedModulePointerApplyFunction& function, - const std::string& name_prefix = std::string()) const; - /// Returns the parameters of this `Module` and if `recurse` is true, also /// recursively of every submodule. - std::vector parameters(bool recurse = true) const; + std::vector parameters(bool recurse = true) const; /// Returns an `OrderedDict` with the parameters of this `Module` along with /// their keys, and if `recurse` is true also recursively of every submodule. 
- OrderedDict named_parameters(bool recurse = true) const; + torch::OrderedDict named_parameters( + bool recurse = true) const; /// Returns the buffers of this `Module` and if `recurse` is true, also /// recursively of every submodule. - std::vector buffers(bool recurse = true) const; + std::vector buffers(bool recurse = true) const; /// Returns an `OrderedDict` with the buffers of this `Module` along with /// their keys, and if `recurse` is true also recursively of every submodule. - OrderedDict named_buffers(bool recurse = true) const; + torch::OrderedDict named_buffers( + bool recurse = true) const; /// Returns the submodules of this `Module` (the entire submodule hierarchy) /// and if `include_self` is true, also inserts a `shared_ptr` to this module /// in the first position. - /// - /// \rst - /// .. warning:: - /// Only pass `include_self` as `true` if this `Module` is stored in a - /// `shared_ptr`! Otherwise an exception will be thrown. You may still call - /// this method with `include_self` set to false if your `Module` is not - /// stored in a `shared_ptr`. - /// \endrst std::vector> modules(bool include_self = true) const; /// Returns an `OrderedDict` of the submodules of this `Module` (the entire @@ -239,15 +70,7 @@ class Module : public std::enable_shared_from_this { /// inserts a `shared_ptr` to this module in the first position. If /// `name_prefix` is given, it is prepended to every key as /// `.` (and just `name_prefix` for the module itself). - /// - /// \rst - /// .. warning:: - /// Only pass `include_self` as `true` if this `Module` is stored in a - /// `shared_ptr`! Otherwise an exception will be thrown. You may still call - /// this method with `include_self` set to false if your `Module` is not - /// stored in a `shared_ptr`. - /// \endrst - OrderedDict> named_modules( + torch::OrderedDict> named_modules( const std::string& name_prefix = std::string(), bool include_self = true) const; @@ -256,25 +79,34 @@ class Module : public std::enable_shared_from_this { /// Returns an `OrderedDict` of the direct submodules of this `Module` and /// their keys. - OrderedDict> named_children() const; + torch::OrderedDict> named_children() + const; - /// Enables "training" mode. - virtual void train(bool on = true); + /// Applies the `function` to the `Module` and recursively to every submodule. + using ModuleApplyFunction = std::function; + void apply(const ModuleApplyFunction& function); - /// Calls train(false) to enable "eval" mode. - /// Do not override this method, override `train()` instead. - void eval(); + using ConstModuleApplyFunction = std::function; + void apply(const ConstModuleApplyFunction& function) const; - /// True if the module is in training mode. - /// - /// Every `Module` has a boolean associated with it that determines whether - /// the `Module` is currently in *training* mode (set via `.train()`) or in - /// *evaluation* (inference) mode (set via `.eval()`). This property is - /// exposed via `is_training()`, and may be used by the implementation of a - /// concrete module to modify its runtime behavior. See the `BatchNorm` or - /// `Dropout` modules for examples of `Module`s that use different code paths - /// depending on this property. 
- virtual bool is_training() const noexcept; + using NamedModuleApplyFunction = + std::function; + void apply(const NamedModuleApplyFunction& function, + const std::string& name_prefix = std::string()); + + using ConstNamedModuleApplyFunction = + std::function; + void apply(const ConstNamedModuleApplyFunction& function, + const std::string& name_prefix = std::string()) const; + + using ModulePointerApplyFunction = + std::function&)>; + void apply(const ModulePointerApplyFunction& function) const; + + using NamedModulePointerApplyFunction = + std::function&)>; + void apply(const NamedModulePointerApplyFunction& function, + const std::string& name_prefix = std::string()) const; /// Recursively casts all parameters to the given `dtype` and `device`. /// @@ -287,114 +119,26 @@ class Module : public std::enable_shared_from_this { bool non_blocking = false); /// Recursively casts all parameters to the given dtype. - /// - /// If `non_blocking` is true and the source is in pinned memory and - /// destination is on the GPU or vice versa, the copy is performed - /// asynchronously with respect to the host. Otherwise, the argument has no - /// effect. virtual void to(torch::Dtype dtype, bool non_blocking = false); /// Recursively moves all parameters to the given device. - /// - /// If `non_blocking` is true and the source is in pinned memory and - /// destination is on the GPU or vice versa, the copy is performed - /// asynchronously with respect to the host. Otherwise, the argument has no - /// effect. virtual void to(torch::Device device, bool non_blocking = false); - /// Recursively zeros out the `grad` value of each registered parameter. - virtual void zero_grad(bool set_to_none = true); - /// Attempts to cast this `Module` to the given `ModuleType`. - /// - /// This method is useful when calling `apply()`. - /// \rst - /// .. code-block:: cpp - /// - /// void initialize_weights(nn::Module& module) { - /// torch::NoGradGuard no_grad; - /// if (auto* linear = module.as()) { - /// linear->weight.normal_(0.0, 0.02); - /// } - /// } - /// - /// MyModule module; - /// module->apply(initialize_weights); - /// \endrst template typename ModuleType::ContainedType* as() noexcept; - /// Attempts to cast this `Module` to the given `ModuleType`. - /// - /// This method is useful when calling `apply()`. - /// \rst - /// .. code-block:: cpp - /// void initialize_weights(nn::Module& module) { - /// torch::NoGradGuard no_grad; - /// if (auto* linear = module.as()) { - /// linear->weight.normal_(0.0, 0.02); - /// } - /// } - /// - /// MyModule module; - /// module->apply(initialize_weights); - /// \endrst template const typename ModuleType::ContainedType* as() const noexcept; - /// Attempts to cast this `Module` to the given `ModuleType`. - /// - /// This method is useful when calling `apply()`. - /// \rst - /// .. code-block:: cpp - /// - /// void initialize_weights(nn::Module& module) { - /// torch::NoGradGuard no_grad; - /// if (auto* linear = module.as()) { - /// linear->weight.normal_(0.0, 0.02); - /// } - /// } - /// - /// MyModule module; - /// module.apply(initialize_weights); - /// \endrst template > + typename = detail::disable_if_module_holder_t> ModuleType* as() noexcept; - /// Attempts to cast this `Module` to the given `ModuleType`. - /// - /// This method is useful when calling `apply()`. - /// \rst - /// .. 
code-block:: cpp - /// - /// void initialize_weights(nn::Module& module) { - /// torch::NoGradGuard no_grad; - /// if (auto* linear = module.as()) { - /// linear->weight.normal_(0.0, 0.02); - /// } - /// } - /// - /// MyModule module; - /// module.apply(initialize_weights); - /// \endrst template > + typename = detail::disable_if_module_holder_t> const ModuleType* as() const noexcept; - /// Serializes the `Module` into the given `OutputArchive`. - /// - /// If the `Module` contains unserializable submodules (e.g. - /// `nn::Functional`), those submodules are skipped when serializing. - virtual void save(serialize::OutputArchive& archive) const; - - /// Deserializes the `Module` from the given `InputArchive`. - /// - /// If the `Module` contains unserializable submodules (e.g. - /// `nn::Functional`), we don't check the existence of those submodules in the - /// `InputArchive` when deserializing. - virtual void load(serialize::InputArchive& archive); - /// Streams a pretty representation of the `Module` into the given `stream`. /// By default, this representation will be the name of the module (taken from /// `name()`), followed by a recursive pretty print of all of the `Module`'s @@ -404,108 +148,29 @@ class Module : public std::enable_shared_from_this { /// `stream` should be returned from the method, to allow easy chaining. virtual void pretty_print(std::ostream& stream) const; - /// Returns whether the `Module` is serializable. - virtual bool is_serializable() const; - /// Registers a parameter with this `Module`. - /// - /// A parameter should be any gradient-recording tensor used in the - /// implementation of your `Module`. Registering it makes it available to - /// methods such as `parameters()`, `clone()` or `to().` - /// - /// Note that registering an undefined Tensor (e.g. - /// `module.register_parameter("param", Tensor())`) is allowed, and is - /// equivalent to `module.register_parameter("param", None)` in Python API. - /// - /// \rst - /// .. code-block:: cpp - /// - /// MyModule::MyModule() { - /// weight_ = register_parameter("weight", torch::randn({A, B})); - /// } - /// \endrst - Tensor& register_parameter(std::string name, - Tensor tensor, - bool requires_grad = true); + torch::Tensor& register_parameter(std::string name, torch::Tensor tensor); /// Registers a buffer with this `Module`. - /// - /// A buffer is intended to be state in your module that does not record - /// gradients, such as running statistics. Registering it makes it available - /// to methods such as `buffers()`, `clone()` or `to(). - /// - /// \rst - /// .. code-block:: cpp - /// - /// MyModule::MyModule() { - /// mean_ = register_buffer("mean", torch::empty({num_features_})); - /// } - /// \endrst - Tensor& register_buffer(std::string name, Tensor tensor); + torch::Tensor& register_buffer(std::string name, torch::Tensor tensor); /// Registers a submodule with this `Module`. - /// - /// Registering a module makes it available to methods such as `modules()`, - /// `clone()` or `to()`. - /// - /// \rst - /// .. code-block:: cpp - /// - /// MyModule::MyModule() { - /// submodule_ = register_module("linear", torch::nn::Linear(3, 4)); - /// } - /// \endrst template std::shared_ptr register_module( std::string name, std::shared_ptr module); - /// Registers a submodule with this `Module`. - /// - /// This method deals with `ModuleHolder`s. - /// - /// Registering a module makes it available to methods such as `modules()`, - /// `clone()` or `to()`. - /// - /// \rst - /// .. 
code-block:: cpp - /// - /// MyModule::MyModule() { - /// submodule_ = register_module("linear", torch::nn::Linear(3, 4)); - /// } - /// \endrst template std::shared_ptr register_module( std::string name, ModuleHolder module_holder); /// Replaces a registered submodule with this `Module`. - /// - /// This takes care of the registration, if you used submodule members, you - /// should - // assign the submodule as well, i.e. use as - /// module->submodule_ = module->replace_module("linear", - /// torch::nn::Linear(3, 4)); - /// It only works when a module of the name is already registered. - /// - /// This is useful for replacing a module after initialization, e.g. - /// for finetuning. template std::shared_ptr replace_module( const std::string& name, std::shared_ptr module); - /// Replaces a registered submodule with this `Module`. - /// This method deals with `ModuleHolder`s. - /// - /// This takes care of the registration, if you used submodule members, you - /// should - // assign the submodule as well, i.e. use as - /// module->submodule_ = module->replace_module("linear", linear_holder); - /// It only works when a module of the name is already registered. - /// - /// This is useful for replacing a module after initialization, e.g. - /// for finetuning. template std::shared_ptr replace_module( const std::string& name, @@ -516,62 +181,17 @@ class Module : public std::enable_shared_from_this { void unregister_module(const std::string& name); protected: - /// The following three functions allow a module with default arguments in its - /// forward method to be used in a Sequential module. - /// You should NEVER override these functions manually. Instead, you should - /// use the `FORWARD_HAS_DEFAULT_ARGS` macro. - virtual bool _forward_has_default_args() { return false; } - - virtual unsigned int _forward_num_required_args() { - TORCH_CHECK( - false, - "torch::nn::Module subclass that has default arguments in `forward` " - "method ", - "must override `_forward_num_required_args` method. Please use ", - "`FORWARD_HAS_DEFAULT_ARGS` macro to do so."); - } - - // virtual std::vector _forward_populate_default_args( - // std::vector&& arguments) { - // TORCH_CHECK( - // false, - // "torch::nn::Module subclass that has default arguments in `forward` " - // "method ", - // "must override `_forward_populate_default_args` method. Please use ", - // "`FORWARD_HAS_DEFAULT_ARGS` macro to do so."); - // } - /// The registered parameters of this `Module`. /// Inorder to access parameters_ in ParameterDict and ParameterList // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - OrderedDict parameters_; + torch::OrderedDict parameters_; private: - // Friend classes. - - template - friend class Cloneable; - - template - friend struct AnyModuleHolder; - /// Pretty prints the given `Module` into the `ostream`. - TORCH_API friend std::ostream& operator<<(std::ostream& stream, - const nn::Module& module); - - // data parallel using this method to configure gradient edges during the - // replicate step. - template - friend void replicate_grad_edges( - const std::shared_ptr& module, - const std::vector>& replicas, - const std::vector& devices); + friend std::ostream& operator<<(std::ostream& stream, const Module& module); // Private methods. - /// Used in the implementation of `Cloneable`. - virtual void clone_(Module& other, const std::optional& device); - /// The implementation of the various `to()` methods. template void to_impl(Ts&&... 
ts); @@ -590,41 +210,28 @@ class Module : public std::enable_shared_from_this { std::shared_ptr shared_from_this_checked() const; /// The registered buffers of this `Module`. - OrderedDict buffers_; + torch::OrderedDict buffers_; /// The registered (direct) submodules of this `Module`. - OrderedDict> children_; + torch::OrderedDict> children_; /// The module's name (e.g. "LSTM"). mutable std::optional name_; - - /// Whether the module is in training mode. - bool is_training_{true}; }; -/// Serialize a `Module` pointer into an `OutputArchive`. -TORCH_API serialize::OutputArchive& operator<<( - serialize::OutputArchive& archive, - const std::shared_ptr& module); - -/// Deserializes a `Module` from an `InputArchive`. -TORCH_API serialize::InputArchive& operator>>( - serialize::InputArchive& archive, - const std::shared_ptr& module); - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template typename ModuleType::ContainedType* Module::as() noexcept { // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for - // `Linear`, since `LinearImpl` inherits `nn::Module`. + // `Linear`, since `LinearImpl` inherits `Module`. return as(); } template const typename ModuleType::ContainedType* Module::as() const noexcept { // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for - // `Linear`, since `LinearImpl` inherits `nn::Module`. + // `Linear`, since `LinearImpl` inherits `Module`. return as(); } @@ -689,4 +296,4 @@ void Module::to_impl(Ts&&... ts) { } } -} // namespace llm::nn +} // namespace llm diff --git a/src/module/module_holder.h b/src/module/module_holder.h index 36689aa0..abe131a5 100644 --- a/src/module/module_holder.h +++ b/src/module/module_holder.h @@ -1,33 +1,98 @@ #pragma once -#include -#include -#include -#include -#include - #include #include #include namespace llm { namespace detail { -// Dump all the template metaprogramming in this file. -#include "pimpl-inl.h" -} // namespace detail +struct ModuleHolderIndicator {}; + +// A type trait that is true for types that are `ModuleHolder`s. +template +using is_module_holder = + std::is_base_of>; + +template +using disable_if_module_holder_t = + std::enable_if_t::value>; + +// A collection of templates that answer the question whether a type `T` is a +// `ModuleHolder`, and if so whether its contained type is of type `C`. + +// Base template. +template +struct is_module_holder_of_impl; + +// False branch. `T` is not a `ModuleHolder` and thus not a `ModuleHolder` with +// contained type `C`. +template +struct is_module_holder_of_impl : std::false_type {}; + +// True branch. `T` is a `ModuleHolder` and thus we can legit access its +// `ContainedType` and compare it against `C`. +template +struct is_module_holder_of_impl + : std::is_same {}; + +// Helper template. +template +struct is_module_holder_of + : is_module_holder_of_impl::value, + std::decay_t, + std::decay_t> {}; + +/// Detects if a type T has a forward() method. +template +struct has_forward { + // Declare two types with differing size. + using yes = int8_t; + using no = int16_t; + + template + static yes test(decltype(&U::forward)); + template + static no test(...); + + // Finally we test statically whether the size of the type returned by the + // selected overload is the size of the `yes` type. 
+ static constexpr bool value = (sizeof(test(nullptr)) == sizeof(yes)); +}; + +// A collection of templates that allow deducing the return type of the +// `forward()` method, but only if a module actually has a `forward()` method, +// and otherwise deduces to the type `void`. + +template +struct return_type_of_forward_impl; + +template +struct return_type_of_forward_impl { + using type = decltype(::std::declval().forward(::std::declval()...)); +}; + +template +struct return_type_of_forward_impl { + using type = void; +}; + +template +using return_type_of_forward = + return_type_of_forward_impl::value, C, Args...>; -namespace nn { -using namespace torch; +template +using return_type_of_forward_t = + typename return_type_of_forward::type; + +} // namespace detail /// A `ModuleHolder` is essentially a wrapper around `std::shared_ptr` where -/// `M` is an `nn::Module` subclass, with convenient constructors defined for +/// `M` is an `Module` subclass, with convenient constructors defined for /// the kind of constructions we want to allow for our modules. template class ModuleHolder : detail::ModuleHolderIndicator { protected: /// The module pointer this class wraps. - /// NOTE: Must be placed at the top of the class so that we can use it with - /// trailing return types below. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::shared_ptr impl_; @@ -36,10 +101,6 @@ class ModuleHolder : detail::ModuleHolderIndicator { /// Default constructs the contained module if if has a default constructor, /// else produces a static error. - /// - /// NOTE: This uses the behavior of template - /// classes in C++ that constructors (or any methods) are only compiled when - /// actually used. ModuleHolder() : impl_(default_construct()) { static_assert( std::is_default_constructible_v, @@ -139,46 +200,22 @@ class ModuleHolder : detail::ModuleHolderIndicator { /// Pretty prints the given `Module` into the `ostream`. template std::ostream& operator<<(std::ostream& stream, - const nn::ModuleHolder& module) { + const ModuleHolder& module) { return stream << *module; } -/// Serializes a `ModuleHolder` into an `OutputArchive`. -template -serialize::OutputArchive& operator<<( - serialize::OutputArchive& archive, - const nn::ModuleHolder& module) { - return archive << module.ptr(); -} - -/// Deserializes a `ModuleHolder` from an `InputArchive`. -template -serialize::InputArchive& operator>>(serialize::InputArchive& archive, - nn::ModuleHolder& module) { - return archive >> module.ptr(); -} - -} // namespace nn } // namespace llm -// Workaround for CUDA 10.2 and below not allowing attribute unused on -// using declarations. -#ifdef __CUDACC__ -#define UNUSED_EXCEPT_CUDA -#else -#define UNUSED_EXCEPT_CUDA [[maybe_unused]] -#endif - -/// Defines a class `Name` which inherits from `nn::ModuleHolder` to provide a +/// Defines a class `Name` which inherits from `ModuleHolder` to provide a /// wrapper over a `std::shared_ptr`. /// `Impl` is a type alias for `ImplType` which provides a way to call static /// method of `ImplType`. 
-#define LLM_MODULE_IMPL(Name, ImplType) \ - class Name : public llm::nn::ModuleHolder { /* NOLINT */ \ - public: \ - using llm::nn::ModuleHolder::ModuleHolder; \ - using Impl TORCH_UNUSED_EXCEPT_CUDA = ImplType; \ +#define LLM_MODULE_IMPL(Name, ImplType) \ + class Name : public ModuleHolder { /* NOLINT */ \ + public: \ + using ModuleHolder::ModuleHolder; \ + using Impl [[maybe_unused]] = ImplType; \ } -/// Like `TORCH_MODULE_IMPL`, but defaults the `ImplType` name to `Impl`. +/// Like `LLM_MODULE_IMPL`, but defaults the `ImplType` name to `Impl`. #define LLM_MODULE(Name) LLM_MODULE_IMPL(Name, Name##Impl) diff --git a/src/module/modulelist.h b/src/module/module_list.h similarity index 77% rename from src/module/modulelist.h rename to src/module/module_list.h index e96f4f97..3ae3cb45 100644 --- a/src/module/modulelist.h +++ b/src/module/module_list.h @@ -7,57 +7,12 @@ #include #include -#include "cloneable.h" #include "module.h" #include "module_holder.h" -namespace llm::nn { - +namespace llm { /// A list of `Module`s that registers its elements. -/// -/// \rst -/// .. code-block:: cpp -/// -/// torch::nn::ModuleList mlist( -/// torch::nn::Linear(3, 4), -/// torch::nn::BatchNorm1d(4), -/// torch::nn::Dropout(0.5) -/// ); -/// -/// for (const auto &module : *mlist) { -/// module->pretty_print(std::cout); -/// } -/// -/// \endrst -/// -/// Why should you use `ModuleList` instead of a simple `std::vector`? The value -/// a `ModuleList` provides over manually calling a sequence of modules is that -/// it allows treating the whole container *as a single module*, such that -/// performing a transformation on the `ModuleList` applies to each of the -/// modules it stores (which are each a registered submodule of the -/// `ModuleList`). For example, calling -/// `.to(torch::kCUDA)` on a `ModuleList` will move each module in the list to -/// CUDA memory. For example: -/// -/// \rst -/// .. code-block:: cpp -/// -/// torch::nn::ModuleList mlist( -/// torch::nn::Linear(3, 4), -/// torch::nn::BatchNorm1d(4), -/// torch::nn::Dropout(0.5) -/// ); -/// -/// // Convert all modules to CUDA. -/// mlist->to(torch::kCUDA); -/// -/// \endrst -/// -/// Finally, `ModuleList` provides a lightweight container API, such as allowing -/// iteration over submodules, positional access, adding a new module after -/// construction via `push_back`, as well as joining two `ModuleList`s via -/// `extend`. -class ModuleListImpl : public Cloneable { +class ModuleListImpl : public Module { public: using Iterator = std::vector>::iterator; using ConstIterator = std::vector>::const_iterator; @@ -71,24 +26,9 @@ class ModuleListImpl : public Cloneable { push_back_var(std::forward(modules)...); } - /// Special cloning function for `ModuleList` because it does not use - /// `reset()`. - std::shared_ptr clone( - const std::optional& device = std::nullopt) const override { - auto clone = std::make_shared(); - for (const auto& module : modules_) { - clone->push_back(module->clone(device)); - } - return clone; - } - - /// `reset()` is empty for `ModuleList`, since it does not have parameters of - /// its own. - void reset() override {} - /// Pretty prints the `ModuleList` module into the given `stream`. void pretty_print(std::ostream& stream) const override { - stream << "torch::nn::ModuleList"; + stream << "ModuleList"; } void push_back(std::shared_ptr module) { @@ -255,4 +195,4 @@ class ModuleListImpl : public Cloneable { /// module storage semantics. 
LLM_MODULE(ModuleList); -} // namespace llm::nn +} // namespace llm diff --git a/src/module/pimpl-inl.h b/src/module/pimpl-inl.h deleted file mode 100644 index c5d91f34..00000000 --- a/src/module/pimpl-inl.h +++ /dev/null @@ -1,76 +0,0 @@ -// This class exists only to do SFINAE on abstract types `T` that are really -// `ModuleHolder`, because there's no good way to say that `T` is a -// `ModuleHolder` over some unknown type `ModuleType`. With this, you can do -// `enable_if_t>`. -struct ModuleHolderIndicator {}; - -// A type trait that is true for types that are `ModuleHolder`s. -template -using is_module_holder = - std::is_base_of>; - -template -using disable_if_module_holder_t = - std::enable_if_t::value>; - -// A collection of templates that answer the question whether a type `T` is a -// `ModuleHolder`, and if so whether its contained type is of type `C`. This is -// tricky because it is hard to short circuit in template metaprogramming. A -// naive and incorrect solution to this problem would be something like -// `disable_if::value && typename T::ContainedType == C>`. -// This would disable all types that are not `ModuleHolder`s, because even -// though the `is_module_holder::value` may be `false` for such types the -// `T::ContainedType` access would be ill-formed and thus fail the whole -// expression by the rules of SFINAE. Instead we have to use template -// specialization to statically branch on the first condition -// (`is_module_holder`) and are only then allowed to query -// `T::ContainedType` in the branch for which the condition was true. - -// Base template. -template -struct is_module_holder_of_impl; - -// False branch. `T` is not a `ModuleHolder` and thus not a `ModuleHolder` with -// contained type `C`. -template -struct is_module_holder_of_impl : std::false_type {}; - -// True branch. `T` is a `ModuleHolder` and thus we can legit access its -// `ContainedType` and compare it against `C`. -template -struct is_module_holder_of_impl - : std::is_same {}; - -// Helper template. -template -struct is_module_holder_of - : is_module_holder_of_impl::value, - std::decay_t, - std::decay_t> {}; - -// A collection of templates that allow deducing the return type of the -// `forward()` method, but only if a module actually has a `forward()` method, -// and otherwise deduces to the type `void`. 
- -template -struct return_type_of_forward_impl; - -template -struct return_type_of_forward_impl { - using type = decltype(::std::declval().forward(::std::declval()...)); -}; - -template -struct return_type_of_forward_impl { - using type = void; -}; - -template -using return_type_of_forward = - return_type_of_forward_impl::value, - C, - Args...>; - -template -using return_type_of_forward_t = - typename return_type_of_forward::type; diff --git a/src/quantization/qlinear_impl.cpp b/src/quantization/qlinear_impl.cpp index aa3de6fd..695d0dad 100644 --- a/src/quantization/qlinear_impl.cpp +++ b/src/quantization/qlinear_impl.cpp @@ -128,32 +128,26 @@ ColumnParallelQLinearImpl::ColumnParallelQLinearImpl( qweight_ = register_parameter( "qweight", torch::empty({in_features / pack_factor, out_features_per_partition}, - options.dtype(torch::kInt32)), - /*requires_grad=*/false); + options.dtype(torch::kInt32))); } else { qweight_ = register_parameter( "qweight", torch::empty({in_features, out_features_per_partition / pack_factor}, - options.dtype(torch::kInt32)), - /*requires_grad=*/false); + options.dtype(torch::kInt32))); } qzeros_ = register_parameter( "qzeros", torch::empty({round_up(in_features, group_size), out_features_per_partition / pack_factor}, - options.dtype(torch::kInt32)), - /*requires_grad=*/false); + options.dtype(torch::kInt32))); scales_ = register_parameter("scales", torch::empty({round_up(in_features, group_size), out_features_per_partition}, - options), - /*requires_grad=*/false); + options)); if (bias) { - bias_ = - register_parameter("bias", - torch::empty({out_features_per_partition}, options), - /*requires_grad=*/false); + bias_ = register_parameter( + "bias", torch::empty({out_features_per_partition}, options)); } } @@ -245,33 +239,27 @@ RowParallelQLinearImpl::RowParallelQLinearImpl( qweight_ = register_parameter( "qweight", torch::empty({in_features_per_partition / pack_factor, out_features}, - options.dtype(torch::kInt32)), - /*requires_grad=*/false); + options.dtype(torch::kInt32))); } else { qweight_ = register_parameter( "qweight", torch::empty({in_features_per_partition, out_features / pack_factor}, - options.dtype(torch::kInt32)), - /*requires_grad=*/false); + options.dtype(torch::kInt32))); } qzeros_ = register_parameter( "qzeros", torch::empty({round_up(in_features_per_partition, group_size), out_features / pack_factor}, - options.dtype(torch::kInt32)), - /*requires_grad=*/false); + options.dtype(torch::kInt32))); scales_ = register_parameter( "scales", torch::empty( {round_up(in_features_per_partition, group_size), out_features}, - options), - /*requires_grad=*/false); + options)); if (bias) { - bias_ = register_parameter("bias", - torch::empty({out_features}, options), - /*requires_grad=*/false); + bias_ = register_parameter("bias", torch::empty({out_features}, options)); } }
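
As a usage illustration of the refactored API (not part of the patch above): the sketch below shows how a layer would be written against `llm::Module` after this change — `register_parameter()` without the `requires_grad` argument, the `LLM_MODULE` macro, and the recursive `named_parameters()` / `to()` accessors. The layer name `MyNormImpl`/`MyNorm`, the include paths, and `main()` are invented for illustration, and it assumes `ModuleHolder` keeps its usual forwarding constructor and `operator->` so that `LLM_MODULE`-generated holders can be constructed and dereferenced as before.

#include <iostream>

#include <torch/torch.h>

#include "module/module.h"         // llm::Module (assumed path)
#include "module/module_holder.h"  // LLM_MODULE / ModuleHolder (assumed path)

namespace llm {

// A toy elementwise-scale layer, analogous to the real layers in src/layers/.
class MyNormImpl : public Module {
 public:
  MyNormImpl(int64_t dim, const torch::TensorOptions& options) {
    // register_parameter() now takes only a name and a tensor; the base class
    // marks the tensor requires_grad(false) itself, so the old
    // /*requires_grad=*/false argument is gone.
    weight_ = register_parameter("weight", torch::ones({dim}, options));
  }

  torch::Tensor forward(const torch::Tensor& x) const { return x * weight_; }

 private:
  torch::Tensor weight_;
};

// Generates a `MyNorm` holder over std::shared_ptr<MyNormImpl>, mirroring the
// LLM_MODULE(Embedding) and LLM_MODULE(ModuleList) usages in the tree.
LLM_MODULE(MyNorm);

}  // namespace llm

int main() {
  const auto options = torch::dtype(torch::kFloat).device(torch::kCPU);
  llm::MyNorm norm(/*dim=*/8, options);

  // Recursive accessors still work; named_parameters() now returns a
  // torch::OrderedDict<std::string, torch::Tensor>.
  for (const auto& p : norm->named_parameters()) {
    std::cout << p.key() << ": " << p.value().sizes() << std::endl;
  }

  std::cout << norm->forward(torch::randn({2, 8}, options)).sizes() << std::endl;

  // to(device, dtype) recursively casts every registered parameter and buffer.
  norm->to(torch::kCPU, torch::kHalf);
  return 0;
}

The example deliberately exercises only what the patch keeps: since clone(), train()/eval(), zero_grad(), and the serialization operators are removed from this Module, a layer written this way holds inference-only, non-gradient-recording parameters by construction.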