diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e275fc22..e7548482 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(common) add_subdirectory(handlers) add_subdirectory(kernels) add_subdirectory(tokenizer) +add_subdirectory(module) add_subdirectory(layers) add_subdirectory(quantization) add_subdirectory(models) diff --git a/src/layers/CMakeLists.txt b/src/layers/CMakeLists.txt index de607d96..f178ce47 100644 --- a/src/layers/CMakeLists.txt +++ b/src/layers/CMakeLists.txt @@ -3,7 +3,7 @@ include(cc_test) cc_library( NAME - linear + linear HDRS linear.h qkv_linear.h @@ -21,35 +21,37 @@ cc_library( :model_parallel :quantization :kernels + :module glog::glog gflags::gflags torch ) cc_library( - NAME + NAME pos_embedding - HDRS + HDRS pos_embedding.h - SRCS + SRCS pos_embedding.cpp DEPS :state_dict :memory :kernels + :module glog::glog gflags::gflags torch ) cc_library( - NAME + NAME layers - HDRS + HDRS normalization.h embedding.h activation.h - SRCS + SRCS activation.cpp DEPS :state_dict @@ -58,6 +60,7 @@ cc_library( :pos_embedding :attention :kernels + :module glog::glog gflags::gflags torch @@ -80,4 +83,4 @@ cc_test( ) add_subdirectory(attention) -add_subdirectory(moe) \ No newline at end of file +add_subdirectory(moe) diff --git a/src/layers/attention/attention.h b/src/layers/attention/attention.h index bc453626..54773d22 100644 --- a/src/layers/attention/attention.h +++ b/src/layers/attention/attention.h @@ -5,10 +5,12 @@ #include "layers/attention/handler.h" #include "memory/kv_cache.h" #include "models/parameters.h" +#include "module/module.h" +#include "module/module_holder.h" namespace llm { -class AttentionImpl : public torch::nn::Module { +class AttentionImpl : public llm::nn::Module { public: AttentionImpl(int64_t n_heads, int64_t n_kv_heads, @@ -38,6 +40,6 @@ class AttentionImpl : public torch::nn::Module { // sliding window for self-attention, -1 means no sliding window int32_t sliding_window_ = -1; }; -TORCH_MODULE(Attention); +LLM_MODULE(Attention); } // namespace llm diff --git a/src/layers/embedding.h b/src/layers/embedding.h index 12abb6d2..b20fef70 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -7,6 +7,8 @@ #include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" +#include "module/module.h" +#include "module/module_holder.h" namespace llm { @@ -14,7 +16,7 @@ namespace llm { // This module is often used to store word embeddings and retrieve them using // indices. -class EmbeddingImpl : public torch::nn::Module { +class EmbeddingImpl : public llm::nn::Module { public: EmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim, @@ -63,10 +65,10 @@ class EmbeddingImpl : public torch::nn::Module { // whether the weight is loaded bool is_loaded_ = false; }; -TORCH_MODULE(Embedding); +LLM_MODULE(Embedding); // Embedding parallelized in the embedding dimension. 
-class ParallelEmbeddingImpl : public torch::nn::Module { +class ParallelEmbeddingImpl : public llm::nn::Module { public: ParallelEmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim, @@ -134,10 +136,10 @@ class ParallelEmbeddingImpl : public torch::nn::Module { // parallel args ParallelArgs parallel_args_; }; -TORCH_MODULE(ParallelEmbedding); +LLM_MODULE(ParallelEmbedding); // Embedding parallelized in the vocabulary dimension -class VocabParallelEmbeddingImpl : public torch::nn::Module { +class VocabParallelEmbeddingImpl : public llm::nn::Module { public: VocabParallelEmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim, @@ -152,8 +154,7 @@ class VocabParallelEmbeddingImpl : public torch::nn::Module { // register the weight parameter weight_ = register_parameter( "weight", - torch::empty({num_embeddings_per_partition, embedding_dim}, - options), + torch::empty({num_embeddings_per_partition, embedding_dim}, options), /*requires_grad=*/false); } @@ -218,5 +219,5 @@ class VocabParallelEmbeddingImpl : public torch::nn::Module { int64_t start_index_ = 0; int64_t end_index_ = 0; }; -TORCH_MODULE(VocabParallelEmbedding); +LLM_MODULE(VocabParallelEmbedding); } // namespace llm diff --git a/src/layers/fused_linear.h b/src/layers/fused_linear.h index 36136028..b63d8555 100644 --- a/src/layers/fused_linear.h +++ b/src/layers/fused_linear.h @@ -6,11 +6,13 @@ #include "linear.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" +#include "module/module.h" +#include "module/module_holder.h" #include "quantization/quant_args.h" namespace llm { -class FusedColumnParallelLinearImpl : public torch::nn::Module { +class FusedColumnParallelLinearImpl : public llm::nn::Module { public: FusedColumnParallelLinearImpl(int64_t in_features, const std::vector& out_features, @@ -44,6 +46,6 @@ class FusedColumnParallelLinearImpl : public torch::nn::Module { // whether the linear layer is fused bool fused_ = false; }; -TORCH_MODULE(FusedColumnParallelLinear); +LLM_MODULE(FusedColumnParallelLinear); } // namespace llm diff --git a/src/layers/linear.h b/src/layers/linear.h index 61079563..b65b6d4a 100644 --- a/src/layers/linear.h +++ b/src/layers/linear.h @@ -5,6 +5,8 @@ #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" +#include "module/module.h" +#include "module/module_holder.h" #include "quantization/quant_args.h" namespace llm { @@ -14,7 +16,7 @@ using TensorTransform = std::function; // an interface for parallel linear layer. // all linear classes should inherit from this class and implement the forward // function. -class ParallelLinearImpl : public torch::nn::Module { +class ParallelLinearImpl : public llm::nn::Module { public: ~ParallelLinearImpl() override = default; @@ -37,10 +39,9 @@ class ParallelLinearImpl : public torch::nn::Module { } }; -class ColumnParallelLinear - : public torch::nn::ModuleHolder { +class ColumnParallelLinear : public llm::nn::ModuleHolder { public: - using torch::nn::ModuleHolder::ModuleHolder; + using llm::nn::ModuleHolder::ModuleHolder; using Impl __attribute__((__unused__)) = ParallelLinearImpl; // construct a rotary positional embedding. 
@@ -61,9 +62,9 @@ class ColumnParallelLinear const torch::TensorOptions& options); }; -class RowParallelLinear : public torch::nn::ModuleHolder { +class RowParallelLinear : public llm::nn::ModuleHolder { public: - using torch::nn::ModuleHolder::ModuleHolder; + using llm::nn::ModuleHolder::ModuleHolder; using Impl __attribute__((__unused__)) = ParallelLinearImpl; // construct a rotary positional embedding. diff --git a/src/layers/normalization.h b/src/layers/normalization.h index 1add1fb6..748e96fe 100644 --- a/src/layers/normalization.h +++ b/src/layers/normalization.h @@ -7,6 +7,8 @@ #include "kernels/layernorm_kernels.h" #include "model_loader/state_dict.h" +#include "module/module.h" +#include "module/module_holder.h" DECLARE_bool(disable_custom_kernels); namespace llm { @@ -63,7 +65,7 @@ inline torch::Tensor layer_norm(torch::Tensor input, // apply layer normalization over a mini-batch of inputs as described in // the paper `Layer Normalization`: https://arxiv.org/abs/1607.06450 // x = ((x - mean(x)) / sqrt(std(x) + eps)) * weight + bias -class LayerNormImpl : public torch::nn::Module { +class LayerNormImpl : public llm::nn::Module { public: // dim: the dim over which the mean and std are calculated separately. // eps: a value added to the denominator for numerical stability. @@ -140,10 +142,10 @@ class LayerNormImpl : public torch::nn::Module { float eps_; std::vector normalized_shape_; }; -TORCH_MODULE(LayerNorm); +LLM_MODULE(LayerNorm); // Root mean square normalization -class RMSNormImpl : public torch::nn::Module { +class RMSNormImpl : public llm::nn::Module { public: RMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options) : eps_(eps) { @@ -191,9 +193,9 @@ class RMSNormImpl : public torch::nn::Module { // configs float eps_; }; -TORCH_MODULE(RMSNorm); +LLM_MODULE(RMSNorm); -class GemmaRMSNormImpl : public torch::nn::Module { +class GemmaRMSNormImpl : public llm::nn::Module { public: GemmaRMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options) : eps_(eps) { @@ -241,10 +243,10 @@ class GemmaRMSNormImpl : public torch::nn::Module { // configs float eps_; }; -TORCH_MODULE(GemmaRMSNorm); +LLM_MODULE(GemmaRMSNorm); // Root mean square normalization -class RMSNormResidualImpl : public torch::nn::Module { +class RMSNormResidualImpl : public llm::nn::Module { public: RMSNormResidualImpl(int64_t dim, float eps, @@ -304,6 +306,6 @@ class RMSNormResidualImpl : public torch::nn::Module { // configs float eps_; }; -TORCH_MODULE(RMSNormResidual); +LLM_MODULE(RMSNormResidual); } // namespace llm diff --git a/src/layers/qkv_linear.h b/src/layers/qkv_linear.h index 538676e0..1dab1c4b 100644 --- a/src/layers/qkv_linear.h +++ b/src/layers/qkv_linear.h @@ -12,7 +12,7 @@ namespace llm { // a thin wrapper to handle state_dict loading for QKV with // support of MQA/GQA -class QKVColumnParallelLinearImpl : public torch::nn::Module { +class QKVColumnParallelLinearImpl : public llm::nn::Module { public: QKVColumnParallelLinearImpl(int64_t hidden_size, int64_t n_heads, @@ -47,6 +47,6 @@ class QKVColumnParallelLinearImpl : public torch::nn::Module { int64_t head_dim_ = 0; }; -TORCH_MODULE(QKVColumnParallelLinear); +LLM_MODULE(QKVColumnParallelLinear); } // namespace llm diff --git a/src/models/CMakeLists.txt b/src/models/CMakeLists.txt index 582ff68a..47797106 100644 --- a/src/models/CMakeLists.txt +++ b/src/models/CMakeLists.txt @@ -2,7 +2,7 @@ include(cc_library) include(cc_test) cc_library( - NAME + NAME models HDRS model_args.h @@ -18,7 +18,7 @@ cc_library( :quantization 
:memory :chat_template + :module glog::glog torch ) - diff --git a/src/models/aquila.h b/src/models/aquila.h index 538d96f3..682456b6 100644 --- a/src/models/aquila.h +++ b/src/models/aquila.h @@ -14,11 +14,14 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // Aquila model compatible with huggingface weights namespace llm::hf { -class AquilaMLPImpl : public torch::nn::Module { +class AquilaMLPImpl : public llm::nn::Module { public: AquilaMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -77,9 +80,9 @@ class AquilaMLPImpl : public torch::nn::Module { // activation function ActFunc act_func_{nullptr}; }; -TORCH_MODULE(AquilaMLP); +LLM_MODULE(AquilaMLP); -class AquilaAttentionImpl : public torch::nn::Module { +class AquilaAttentionImpl : public llm::nn::Module { public: AquilaAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -164,9 +167,9 @@ class AquilaAttentionImpl : public torch::nn::Module { // size for q, k, v std::vector qkv_sizes_; }; -TORCH_MODULE(AquilaAttention); +LLM_MODULE(AquilaAttention); -class AquilaDecoderLayerImpl : public torch::nn::Module { +class AquilaDecoderLayerImpl : public llm::nn::Module { public: AquilaDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -224,9 +227,9 @@ class AquilaDecoderLayerImpl : public torch::nn::Module { RMSNorm post_attention_layernorm_{nullptr}; }; -TORCH_MODULE(AquilaDecoderLayer); +LLM_MODULE(AquilaDecoderLayer); -class AquilaModelImpl : public torch::nn::Module { +class AquilaModelImpl : public llm::nn::Module { public: AquilaModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -241,7 +244,7 @@ class AquilaModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = AquilaDecoderLayer( @@ -295,15 +298,15 @@ class AquilaModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; RMSNorm norm_{nullptr}; }; -TORCH_MODULE(AquilaModel); +LLM_MODULE(AquilaModel); -class AquilaForCausalLMImpl : public torch::nn::Module { +class AquilaForCausalLMImpl : public llm::nn::Module { public: AquilaForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -362,7 +365,7 @@ class AquilaForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(AquilaForCausalLM); +LLM_MODULE(AquilaForCausalLM); class AquilaChatTemplate final : public CodedChatTemplate { public: diff --git a/src/models/baichuan.h b/src/models/baichuan.h index 13e30ce1..f59700a5 100644 --- a/src/models/baichuan.h +++ b/src/models/baichuan.h @@ -16,6 +16,9 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // Baichuan model compatible with huggingface weights @@ -28,7 +31,7 @@ enum class BaichuanType : uint8_t { Baichuan2_13B, }; -class BaichuanMLPImpl : public torch::nn::Module { +class 
BaichuanMLPImpl : public llm::nn::Module { public: BaichuanMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -87,9 +90,9 @@ class BaichuanMLPImpl : public torch::nn::Module { // activation function ActFunc act_func_{nullptr}; }; -TORCH_MODULE(BaichuanMLP); +LLM_MODULE(BaichuanMLP); -class BaichuanAttentionImpl : public torch::nn::Module { +class BaichuanAttentionImpl : public llm::nn::Module { public: BaichuanAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -174,9 +177,9 @@ class BaichuanAttentionImpl : public torch::nn::Module { // size for local q, k, v std::vector qkv_sizes_; }; -TORCH_MODULE(BaichuanAttention); +LLM_MODULE(BaichuanAttention); -class BaichuanDecoderLayerImpl : public torch::nn::Module { +class BaichuanDecoderLayerImpl : public llm::nn::Module { public: BaichuanDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -242,9 +245,9 @@ class BaichuanDecoderLayerImpl : public torch::nn::Module { RMSNormResidual input_layernorm_{nullptr}; RMSNormResidual post_attention_layernorm_{nullptr}; }; -TORCH_MODULE(BaichuanDecoderLayer); +LLM_MODULE(BaichuanDecoderLayer); -class BaichuanModelImpl : public torch::nn::Module { +class BaichuanModelImpl : public llm::nn::Module { public: BaichuanModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -268,7 +271,7 @@ class BaichuanModelImpl : public torch::nn::Module { args, alibi_slopes, options); } - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = BaichuanDecoderLayer(args, @@ -363,16 +366,16 @@ class BaichuanModelImpl : public torch::nn::Module { std::unique_ptr handler_{nullptr}; // parameter members, must be registered - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; // final layer norm RMSNormResidual norm_{nullptr}; }; -TORCH_MODULE(BaichuanModel); +LLM_MODULE(BaichuanModel); -class BaichuanForCausalLMImpl : public torch::nn::Module { +class BaichuanForCausalLMImpl : public llm::nn::Module { public: BaichuanForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -460,7 +463,7 @@ class BaichuanForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(BaichuanForCausalLM); +LLM_MODULE(BaichuanForCausalLM); class BaichuanChatTemplate final : public CodedChatTemplate { public: diff --git a/src/models/bloom.h b/src/models/bloom.h index d5d61774..cd80e03a 100644 --- a/src/models/bloom.h +++ b/src/models/bloom.h @@ -13,12 +13,15 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // bloom model compatible with huggingface weights namespace llm::hf { -class BloomMLPImpl : public torch::nn::Module { +class BloomMLPImpl : public llm::nn::Module { public: BloomMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -74,9 +77,9 @@ class BloomMLPImpl : public torch::nn::Module { ActFunc act_{nullptr}; }; -TORCH_MODULE(BloomMLP); +LLM_MODULE(BloomMLP); -class BloomAttentionImpl : public torch::nn::Module { +class BloomAttentionImpl : public llm::nn::Module { public: BloomAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -168,9 +171,9 @@ class BloomAttentionImpl : public 
torch::nn::Module { int64_t hidden_size_ = 0; int64_t head_dim_ = 0; }; -TORCH_MODULE(BloomAttention); +LLM_MODULE(BloomAttention); -class BloomBlockImpl : public torch::nn::Module { +class BloomBlockImpl : public llm::nn::Module { public: BloomBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -240,9 +243,9 @@ class BloomBlockImpl : public torch::nn::Module { bool residual_post_layernorm_ = false; }; -TORCH_MODULE(BloomBlock); +LLM_MODULE(BloomBlock); -class BloomModelImpl : public torch::nn::Module { +class BloomModelImpl : public llm::nn::Module { public: BloomModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -265,7 +268,7 @@ class BloomModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_alibi( args, alibi_slopes, options); - blocks_ = register_module("h", torch::nn::ModuleList()); + blocks_ = register_module("h", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -359,16 +362,16 @@ class BloomModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; // final layer norm LayerNorm ln_f_{nullptr}; }; -TORCH_MODULE(BloomModel); +LLM_MODULE(BloomModel); -class BloomForCausalLMImpl : public torch::nn::Module { +class BloomForCausalLMImpl : public llm::nn::Module { public: BloomForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -428,7 +431,7 @@ class BloomForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(BloomForCausalLM); +LLM_MODULE(BloomForCausalLM); // register the model to make it available REGISTER_CAUSAL_MODEL(bloom, BloomForCausalLM); diff --git a/src/models/causal_lm.h b/src/models/causal_lm.h index 14f66d67..02a54e29 100644 --- a/src/models/causal_lm.h +++ b/src/models/causal_lm.h @@ -15,9 +15,9 @@ namespace llm { // An interface for causal language models that can hold different models. 
-class CausalLM : public torch::nn::Module { +class CausalLM { public: - ~CausalLM() override = default; + virtual ~CausalLM() = default; // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence diff --git a/src/models/chatglm.h b/src/models/chatglm.h index aa59bba9..e4d6ce6b 100644 --- a/src/models/chatglm.h +++ b/src/models/chatglm.h @@ -14,13 +14,16 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" #include "tokenizer/tokenizer_args.h" // ChatGLM model compatible with huggingface weights namespace llm::hf { -class ChatGLMMLPImpl : public torch::nn::Module { +class ChatGLMMLPImpl : public llm::nn::Module { public: ChatGLMMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -77,9 +80,9 @@ class ChatGLMMLPImpl : public torch::nn::Module { // calculate act(x) * y ActFunc act_with_mul_{nullptr}; }; -TORCH_MODULE(ChatGLMMLP); +LLM_MODULE(ChatGLMMLP); -class ChatGLMAttentionImpl : public torch::nn::Module { +class ChatGLMAttentionImpl : public llm::nn::Module { public: ChatGLMAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -162,9 +165,9 @@ class ChatGLMAttentionImpl : public torch::nn::Module { // size for local q, k, v std::vector qkv_sizes_; }; -TORCH_MODULE(ChatGLMAttention); +LLM_MODULE(ChatGLMAttention); -class ChatGLMBlockImpl : public torch::nn::Module { +class ChatGLMBlockImpl : public llm::nn::Module { public: ChatGLMBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -273,9 +276,9 @@ class ChatGLMBlockImpl : public torch::nn::Module { bool residual_post_layernorm_ = false; bool use_rms_norm_ = false; }; -TORCH_MODULE(ChatGLMBlock); +LLM_MODULE(ChatGLMBlock); -class ChatGLMModelImpl : public torch::nn::Module { +class ChatGLMModelImpl : public llm::nn::Module { public: ChatGLMModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -287,7 +290,7 @@ class ChatGLMModelImpl : public torch::nn::Module { args, /*interleaved=*/true, options); // register submodules - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = ChatGLMBlock( @@ -367,7 +370,7 @@ class ChatGLMModelImpl : public torch::nn::Module { std::unique_ptr handler_{nullptr}; // parameter members, must be registered - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; @@ -378,9 +381,9 @@ class ChatGLMModelImpl : public torch::nn::Module { bool post_layernorm_ = false; bool use_rms_norm_ = false; }; -TORCH_MODULE(ChatGLMModel); +LLM_MODULE(ChatGLMModel); -class ChatGLMForCausalLMImpl : public torch::nn::Module { +class ChatGLMForCausalLMImpl : public llm::nn::Module { public: ChatGLMForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -450,7 +453,7 @@ class ChatGLMForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear output_layer_{nullptr}; }; -TORCH_MODULE(ChatGLMForCausalLM); +LLM_MODULE(ChatGLMForCausalLM); class ChatGLMChatTemplate final : public CodedChatTemplate { public: diff --git a/src/models/gemma.h b/src/models/gemma.h index 111ce9d9..dc5babea 100644 --- a/src/models/gemma.h +++ b/src/models/gemma.h @@ -17,11 +17,14 @@ #include "models/model_args.h" #include "models/model_registry.h" #include 
"models/parameters.h" +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // Gemma model compatible with huggingface weight namespace llm::hf { -class GemmaMLPImpl : public torch::nn::Module { +class GemmaMLPImpl : public llm::nn::Module { public: GemmaMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -80,9 +83,9 @@ class GemmaMLPImpl : public torch::nn::Module { // activation function ActFunc act_func_{nullptr}; }; -TORCH_MODULE(GemmaMLP); +LLM_MODULE(GemmaMLP); -class GemmaAttentionImpl : public torch::nn::Module { +class GemmaAttentionImpl : public llm::nn::Module { public: GemmaAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -160,9 +163,9 @@ class GemmaAttentionImpl : public torch::nn::Module { // module members without parameters Attention atten_{nullptr}; }; -TORCH_MODULE(GemmaAttention); +LLM_MODULE(GemmaAttention); -class GemmaDecoderLayerImpl : public torch::nn::Module { +class GemmaDecoderLayerImpl : public llm::nn::Module { public: GemmaDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -229,9 +232,9 @@ class GemmaDecoderLayerImpl : public torch::nn::Module { GemmaRMSNorm post_attention_layernorm_{nullptr}; }; -TORCH_MODULE(GemmaDecoderLayer); +LLM_MODULE(GemmaDecoderLayer); -class GemmaModelImpl : public torch::nn::Module { +class GemmaModelImpl : public llm::nn::Module { public: GemmaModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -258,7 +261,7 @@ class GemmaModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = GemmaDecoderLayer( @@ -317,13 +320,13 @@ class GemmaModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; }; -TORCH_MODULE(GemmaModel); +LLM_MODULE(GemmaModel); -class GemmaForCausalLMImpl : public torch::nn::Module { +class GemmaForCausalLMImpl : public llm::nn::Module { public: GemmaForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -386,7 +389,7 @@ class GemmaForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(GemmaForCausalLM); +LLM_MODULE(GemmaForCausalLM); class GemmaChatTemplate final : public CodedChatTemplate { public: @@ -451,4 +454,4 @@ REGISTER_MODEL_ARGS(gemma, [&] { }); }); -} // namespace llm::hf \ No newline at end of file +} // namespace llm::hf diff --git a/src/models/gemma2.h b/src/models/gemma2.h index c1f1766e..c68140a1 100644 --- a/src/models/gemma2.h +++ b/src/models/gemma2.h @@ -17,11 +17,14 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // Gemma2 model compatible with huggingface weight namespace llm::hf { -class Gemma2MLPImpl : public torch::nn::Module { +class Gemma2MLPImpl : public llm::nn::Module { public: Gemma2MLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -80,9 +83,9 @@ class Gemma2MLPImpl : public torch::nn::Module { // activation function ActFunc act_func_{nullptr}; }; 
-TORCH_MODULE(Gemma2MLP); +LLM_MODULE(Gemma2MLP); -class Gemma2AttentionImpl : public torch::nn::Module { +class Gemma2AttentionImpl : public llm::nn::Module { public: Gemma2AttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -168,9 +171,9 @@ class Gemma2AttentionImpl : public torch::nn::Module { // size for q, k, v std::vector qkv_sizes_; }; -TORCH_MODULE(Gemma2Attention); +LLM_MODULE(Gemma2Attention); -class Gemma2DecoderLayerImpl : public torch::nn::Module { +class Gemma2DecoderLayerImpl : public llm::nn::Module { public: Gemma2DecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -260,9 +263,9 @@ class Gemma2DecoderLayerImpl : public torch::nn::Module { GemmaRMSNorm post_feedforward_layernorm_{nullptr}; }; -TORCH_MODULE(Gemma2DecoderLayer); +LLM_MODULE(Gemma2DecoderLayer); -class Gemma2ModelImpl : public torch::nn::Module { +class Gemma2ModelImpl : public llm::nn::Module { public: Gemma2ModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -287,7 +290,7 @@ class Gemma2ModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { // Attention Type: [LOCAL_SLIDING, Global, LOCAL_SLIDING, Global, ...] @@ -353,13 +356,13 @@ class Gemma2ModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; }; -TORCH_MODULE(Gemma2Model); +LLM_MODULE(Gemma2Model); -class Gemma2ForCausalLMImpl : public torch::nn::Module { +class Gemma2ForCausalLMImpl : public llm::nn::Module { public: Gemma2ForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -427,7 +430,7 @@ class Gemma2ForCausalLMImpl : public torch::nn::Module { float final_logit_soft_cap_{0.0f}; }; -TORCH_MODULE(Gemma2ForCausalLM); +LLM_MODULE(Gemma2ForCausalLM); class Gemma2ChatTemplate final : public CodedChatTemplate { public: @@ -487,4 +490,4 @@ REGISTER_MODEL_ARGS(gemma2, [&] { }); }); -} // namespace llm::hf \ No newline at end of file +} // namespace llm::hf diff --git a/src/models/gpt2.h b/src/models/gpt2.h index 35b33112..bb910d5d 100644 --- a/src/models/gpt2.h +++ b/src/models/gpt2.h @@ -13,12 +13,15 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // gpt2 model compatible with huggingface weights namespace llm::hf { -class GPT2MLPImpl : public torch::nn::Module { +class GPT2MLPImpl : public llm::nn::Module { public: GPT2MLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -80,9 +83,9 @@ class GPT2MLPImpl : public torch::nn::Module { ActFunc act_{nullptr}; }; -TORCH_MODULE(GPT2MLP); +LLM_MODULE(GPT2MLP); -class GPT2AttentionImpl : public torch::nn::Module { +class GPT2AttentionImpl : public llm::nn::Module { public: GPT2AttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -168,9 +171,9 @@ class GPT2AttentionImpl : public torch::nn::Module { int64_t hidden_size_ = 0; int64_t head_dim_ = 0; }; -TORCH_MODULE(GPT2Attention); +LLM_MODULE(GPT2Attention); -class GPT2BlockImpl : public torch::nn::Module { +class GPT2BlockImpl : public 
llm::nn::Module { public: GPT2BlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -230,9 +233,9 @@ class GPT2BlockImpl : public torch::nn::Module { LayerNorm ln_2_{nullptr}; }; -TORCH_MODULE(GPT2Block); +LLM_MODULE(GPT2Block); -class GPT2ModelImpl : public torch::nn::Module { +class GPT2ModelImpl : public llm::nn::Module { public: GPT2ModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -249,7 +252,7 @@ class GPT2ModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler(args, options); - blocks_ = register_module("h", torch::nn::ModuleList()); + blocks_ = register_module("h", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -310,15 +313,15 @@ class GPT2ModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; LayerNorm ln_f_{nullptr}; }; -TORCH_MODULE(GPT2Model); +LLM_MODULE(GPT2Model); -class GPT2ForCausalLMImpl : public torch::nn::Module { +class GPT2ForCausalLMImpl : public llm::nn::Module { public: GPT2ForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -378,7 +381,7 @@ class GPT2ForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(GPT2ForCausalLM); +LLM_MODULE(GPT2ForCausalLM); // register the model to make it available REGISTER_CAUSAL_MODEL(gpt2, GPT2ForCausalLM); diff --git a/src/models/gpt_j.h b/src/models/gpt_j.h index c00a0576..3f889d07 100644 --- a/src/models/gpt_j.h +++ b/src/models/gpt_j.h @@ -12,12 +12,14 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" - +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // GPTJ model compatible with huggingface weights namespace llm::hf { -class GPTJMLPImpl : public torch::nn::Module { +class GPTJMLPImpl : public llm::nn::Module { public: GPTJMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -69,9 +71,9 @@ class GPTJMLPImpl : public torch::nn::Module { ActFunc act_{nullptr}; }; -TORCH_MODULE(GPTJMLP); +LLM_MODULE(GPTJMLP); -class GPTJAttentionImpl : public torch::nn::Module { +class GPTJAttentionImpl : public llm::nn::Module { public: GPTJAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -141,9 +143,9 @@ class GPTJAttentionImpl : public torch::nn::Module { // module members without parameters Attention atten_{nullptr}; }; -TORCH_MODULE(GPTJAttention); +LLM_MODULE(GPTJAttention); -class GPTJBlockImpl : public torch::nn::Module { +class GPTJBlockImpl : public llm::nn::Module { public: GPTJBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -196,9 +198,9 @@ class GPTJBlockImpl : public torch::nn::Module { LayerNorm ln_1_{nullptr}; }; -TORCH_MODULE(GPTJBlock); +LLM_MODULE(GPTJBlock); -class GPTJModelImpl : public torch::nn::Module { +class GPTJModelImpl : public llm::nn::Module { public: GPTJModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -213,7 +215,7 @@ class GPTJModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/true, options); - blocks_ = register_module("h", torch::nn::ModuleList()); + blocks_ = register_module("h", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block 
= @@ -271,15 +273,15 @@ class GPTJModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; LayerNorm ln_f_{nullptr}; }; -TORCH_MODULE(GPTJModel); +LLM_MODULE(GPTJModel); -class GPTJForCausalLMImpl : public torch::nn::Module { +class GPTJForCausalLMImpl : public llm::nn::Module { public: GPTJForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -338,7 +340,7 @@ class GPTJForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(GPTJForCausalLM); +LLM_MODULE(GPTJForCausalLM); // register the model to make it available REGISTER_CAUSAL_MODEL(gptj, GPTJForCausalLM); diff --git a/src/models/gpt_neox.h b/src/models/gpt_neox.h index b86acb25..e870fd3f 100644 --- a/src/models/gpt_neox.h +++ b/src/models/gpt_neox.h @@ -12,12 +12,14 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" - +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // gpt-neox model compatible with huggingface weights namespace llm::hf { -class GPTNeoXMLPImpl : public torch::nn::Module { +class GPTNeoXMLPImpl : public llm::nn::Module { public: GPTNeoXMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -73,9 +75,9 @@ class GPTNeoXMLPImpl : public torch::nn::Module { ActFunc act_{nullptr}; }; -TORCH_MODULE(GPTNeoXMLP); +LLM_MODULE(GPTNeoXMLP); -class GPTNeoXAttentionImpl : public torch::nn::Module { +class GPTNeoXAttentionImpl : public llm::nn::Module { public: GPTNeoXAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -167,9 +169,9 @@ class GPTNeoXAttentionImpl : public torch::nn::Module { int64_t hidden_size_ = 0; int64_t head_dim_ = 0; }; -TORCH_MODULE(GPTNeoXAttention); +LLM_MODULE(GPTNeoXAttention); -class GPTNeoXLayerImpl : public torch::nn::Module { +class GPTNeoXLayerImpl : public llm::nn::Module { public: GPTNeoXLayerImpl(uint32_t layer_id, const ModelArgs& args, @@ -244,9 +246,9 @@ class GPTNeoXLayerImpl : public torch::nn::Module { bool use_parallel_residual_; }; -TORCH_MODULE(GPTNeoXLayer); +LLM_MODULE(GPTNeoXLayer); -class GPTNeoXModelImpl : public torch::nn::Module { +class GPTNeoXModelImpl : public llm::nn::Module { public: GPTNeoXModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -261,7 +263,7 @@ class GPTNeoXModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = GPTNeoXLayer( @@ -319,15 +321,15 @@ class GPTNeoXModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; LayerNorm final_layer_norm_{nullptr}; }; -TORCH_MODULE(GPTNeoXModel); +LLM_MODULE(GPTNeoXModel); -class GPTNeoXForCausalLMImpl : public torch::nn::Module { +class GPTNeoXForCausalLMImpl : public llm::nn::Module { public: GPTNeoXForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -386,7 +388,7 @@ class GPTNeoXForCausalLMImpl : public 
torch::nn::Module { ColumnParallelLinear embed_out_{nullptr}; }; -TORCH_MODULE(GPTNeoXForCausalLM); +LLM_MODULE(GPTNeoXForCausalLM); // register the model to make it available REGISTER_CAUSAL_MODEL(gpt_neox, GPTNeoXForCausalLM); diff --git a/src/models/internlm.h b/src/models/internlm.h index a9aba260..8958cf98 100644 --- a/src/models/internlm.h +++ b/src/models/internlm.h @@ -14,11 +14,13 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" - +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // Internlm model compatible with huggingface weights namespace llm::hf { -class InternlmMLPImpl : public torch::nn::Module { +class InternlmMLPImpl : public llm::nn::Module { public: InternlmMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -77,9 +79,9 @@ class InternlmMLPImpl : public torch::nn::Module { // calculate act(x) * y ActFunc act_func_{nullptr}; }; -TORCH_MODULE(InternlmMLP); +LLM_MODULE(InternlmMLP); -class InternlmAttentionImpl : public torch::nn::Module { +class InternlmAttentionImpl : public llm::nn::Module { public: InternlmAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -151,9 +153,9 @@ class InternlmAttentionImpl : public torch::nn::Module { // module members without parameters Attention atten_{nullptr}; }; -TORCH_MODULE(InternlmAttention); +LLM_MODULE(InternlmAttention); -class InternlmDecoderLayerImpl : public torch::nn::Module { +class InternlmDecoderLayerImpl : public llm::nn::Module { public: InternlmDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -211,9 +213,9 @@ class InternlmDecoderLayerImpl : public torch::nn::Module { RMSNorm post_attention_layernorm_{nullptr}; }; -TORCH_MODULE(InternlmDecoderLayer); +LLM_MODULE(InternlmDecoderLayer); -class InternlmModelImpl : public torch::nn::Module { +class InternlmModelImpl : public llm::nn::Module { public: InternlmModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -228,7 +230,7 @@ class InternlmModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = InternlmDecoderLayer( @@ -283,15 +285,15 @@ class InternlmModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; RMSNorm norm_{nullptr}; }; -TORCH_MODULE(InternlmModel); +LLM_MODULE(InternlmModel); -class InternlmForCausalLMImpl : public torch::nn::Module { +class InternlmForCausalLMImpl : public llm::nn::Module { public: InternlmForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -350,7 +352,7 @@ class InternlmForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(InternlmForCausalLM); +LLM_MODULE(InternlmForCausalLM); class InternlmDialog final : public CodedChatTemplate { public: diff --git a/src/models/llama.h b/src/models/llama.h index 0a41910b..959ba552 100644 --- a/src/models/llama.h +++ b/src/models/llama.h @@ -16,11 +16,13 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" - +#include "module/module.h" 
+#include "module/module_holder.h" +#include "module/modulelist.h" // llama2 model compatible with huggingface weights namespace llm::hf { -class LlamaMLPImpl : public torch::nn::Module { +class LlamaMLPImpl : public llm::nn::Module { public: LlamaMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -80,9 +82,9 @@ class LlamaMLPImpl : public torch::nn::Module { // activation function ActFunc act_func_{nullptr}; }; -TORCH_MODULE(LlamaMLP); +LLM_MODULE(LlamaMLP); -class LlamaAttentionImpl : public torch::nn::Module { +class LlamaAttentionImpl : public llm::nn::Module { public: LlamaAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -163,9 +165,9 @@ class LlamaAttentionImpl : public torch::nn::Module { // size for q, k, v std::vector qkv_sizes_; }; -TORCH_MODULE(LlamaAttention); +LLM_MODULE(LlamaAttention); -class LlamaDecoderLayerImpl : public torch::nn::Module { +class LlamaDecoderLayerImpl : public llm::nn::Module { public: LlamaDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -223,9 +225,9 @@ class LlamaDecoderLayerImpl : public torch::nn::Module { RMSNorm post_attention_layernorm_{nullptr}; }; -TORCH_MODULE(LlamaDecoderLayer); +LLM_MODULE(LlamaDecoderLayer); -class LlamaModelImpl : public torch::nn::Module { +class LlamaModelImpl : public llm::nn::Module { public: LlamaModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -240,7 +242,7 @@ class LlamaModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = LlamaDecoderLayer( @@ -295,15 +297,15 @@ class LlamaModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; RMSNorm norm_{nullptr}; }; -TORCH_MODULE(LlamaModel); +LLM_MODULE(LlamaModel); -class LlamaForCausalLMImpl : public torch::nn::Module { +class LlamaForCausalLMImpl : public llm::nn::Module { public: LlamaForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -362,7 +364,7 @@ class LlamaForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(LlamaForCausalLM); +LLM_MODULE(LlamaForCausalLM); class YiChatTemplate final : public CodedChatTemplate { public: diff --git a/src/models/mistral.h b/src/models/mistral.h index 77b1194b..120fcd4a 100644 --- a/src/models/mistral.h +++ b/src/models/mistral.h @@ -13,11 +13,13 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" - +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // Mistral model compatible with huggingface weights namespace llm::hf { -class MistralMLPImpl : public torch::nn::Module { +class MistralMLPImpl : public llm::nn::Module { public: MistralMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -75,9 +77,9 @@ class MistralMLPImpl : public torch::nn::Module { ActFunc act_func_{nullptr}; }; -TORCH_MODULE(MistralMLP); +LLM_MODULE(MistralMLP); -class MistralAttentionImpl : public torch::nn::Module { +class MistralAttentionImpl : public llm::nn::Module { public: MistralAttentionImpl(const ModelArgs& args, const 
QuantArgs& quant_args, @@ -154,9 +156,9 @@ class MistralAttentionImpl : public torch::nn::Module { // module members without parameters Attention atten_{nullptr}; }; -TORCH_MODULE(MistralAttention); +LLM_MODULE(MistralAttention); -class MistralDecoderLayerImpl : public torch::nn::Module { +class MistralDecoderLayerImpl : public llm::nn::Module { public: MistralDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -214,9 +216,9 @@ class MistralDecoderLayerImpl : public torch::nn::Module { RMSNorm post_attention_layernorm_{nullptr}; }; -TORCH_MODULE(MistralDecoderLayer); +LLM_MODULE(MistralDecoderLayer); -class MistralModelImpl : public torch::nn::Module { +class MistralModelImpl : public llm::nn::Module { public: MistralModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -231,7 +233,7 @@ class MistralModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = MistralDecoderLayer( @@ -286,15 +288,15 @@ class MistralModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; RMSNorm norm_{nullptr}; }; -TORCH_MODULE(MistralModel); +LLM_MODULE(MistralModel); -class MistralForCausalLMImpl : public torch::nn::Module { +class MistralForCausalLMImpl : public llm::nn::Module { public: MistralForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -353,7 +355,7 @@ class MistralForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(MistralForCausalLM); +LLM_MODULE(MistralForCausalLM); class MistralChatTemplate final : public CodedChatTemplate { public: diff --git a/src/models/models.h b/src/models/models.h index 5695892c..e569821e 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -1,19 +1,19 @@ #pragma once // list all registered models here -#include "aquila.h" // IWYU pragma: keep -#include "baichuan.h" // IWYU pargma: keep -#include "bloom.h" // IWYU pragma: keep -#include "chatglm.h" // IWYU pragma: keep -#include "gemma.h" // IWYU pragma: keep -#include "gemma2.h" // IWYU pragma: keep -#include "gpt2.h" // IWYU pragma: keep -#include "gpt_j.h" // IWYU pragma: keep -#include "gpt_neox.h" // IWYU pragma: keep -#include "internlm.h" // IWYU pragma: keep -#include "llama.h" // IWYU pragma: keep -#include "mistral.h" // IWYU pragma: keep -#include "mpt.h" // IWYU pragma: keep -#include "phi.h" // IWYU pragma: keep -#include "qwen.h" // IWYU pragma: keep -#include "qwen2.h" // IWYU pragma: keep +#include "aquila.h" // IWYU pragma: keep +#include "baichuan.h" // IWYU pargma: keep +#include "bloom.h" // IWYU pragma: keep +#include "chatglm.h" // IWYU pragma: keep +#include "gemma.h" // IWYU pragma: keep +#include "gemma2.h" // IWYU pragma: keep +#include "gpt2.h" // IWYU pragma: keep +#include "gpt_j.h" // IWYU pragma: keep +#include "gpt_neox.h" // IWYU pragma: keep +#include "internlm.h" // IWYU pragma: keep +#include "llama.h" // IWYU pragma: keep +#include "mistral.h" // IWYU pragma: keep +#include "mpt.h" // IWYU pragma: keep +#include "phi.h" // IWYU pragma: keep +#include "qwen.h" // IWYU pragma: keep 
+#include "qwen2.h" // IWYU pragma: keep diff --git a/src/models/mpt.h b/src/models/mpt.h index adf91401..86249921 100644 --- a/src/models/mpt.h +++ b/src/models/mpt.h @@ -15,11 +15,13 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" - +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // mpt model compatible with huggingface weights namespace llm::hf { -class MPTMLPImpl : public torch::nn::Module { +class MPTMLPImpl : public llm::nn::Module { public: MPTMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -74,9 +76,9 @@ class MPTMLPImpl : public torch::nn::Module { ActFunc act_{nullptr}; }; -TORCH_MODULE(MPTMLP); +LLM_MODULE(MPTMLP); -class MPTAttentionImpl : public torch::nn::Module { +class MPTAttentionImpl : public llm::nn::Module { public: MPTAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -230,9 +232,9 @@ class MPTAttentionImpl : public torch::nn::Module { int64_t hidden_size_ = 0; int64_t head_dim_ = 0; }; -TORCH_MODULE(MPTAttention); +LLM_MODULE(MPTAttention); -class MPTBlockImpl : public torch::nn::Module { +class MPTBlockImpl : public llm::nn::Module { public: MPTBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -290,9 +292,9 @@ class MPTBlockImpl : public torch::nn::Module { LayerNorm norm_2_{nullptr}; }; -TORCH_MODULE(MPTBlock); +LLM_MODULE(MPTBlock); -class MPTModelImpl : public torch::nn::Module { +class MPTModelImpl : public llm::nn::Module { public: MPTModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -310,7 +312,7 @@ class MPTModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_alibi( args, alibi_slopes, options); - blocks_ = register_module("blocks", torch::nn::ModuleList()); + blocks_ = register_module("blocks", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -388,15 +390,15 @@ class MPTModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; LayerNorm norm_f_{nullptr}; }; -TORCH_MODULE(MPTModel); +LLM_MODULE(MPTModel); -class MPTForCausalLMImpl : public torch::nn::Module { +class MPTForCausalLMImpl : public llm::nn::Module { public: MPTForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -456,7 +458,7 @@ class MPTForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(MPTForCausalLM); +LLM_MODULE(MPTForCausalLM); class MPTChatTemplate final : public CodedChatTemplate { public: diff --git a/src/models/phi.h b/src/models/phi.h index 433f65a4..ac9ba726 100644 --- a/src/models/phi.h +++ b/src/models/phi.h @@ -12,12 +12,14 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" - +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // Phi model compatible with huggingface weights namespace llm::hf { -class PhiMLPImpl : public torch::nn::Module { +class PhiMLPImpl : public llm::nn::Module { public: PhiMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -69,9 +71,9 @@ class PhiMLPImpl : public torch::nn::Module { ActFunc act_{nullptr}; }; -TORCH_MODULE(PhiMLP); +LLM_MODULE(PhiMLP); -class PhiAttentionImpl : public torch::nn::Module { +class 
PhiAttentionImpl : public llm::nn::Module { public: PhiAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -156,9 +158,9 @@ class PhiAttentionImpl : public torch::nn::Module { // size for q, k, v std::vector qkv_sizes_; }; -TORCH_MODULE(PhiAttention); +LLM_MODULE(PhiAttention); -class PhiBlockImpl : public torch::nn::Module { +class PhiBlockImpl : public llm::nn::Module { public: PhiBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -211,9 +213,9 @@ class PhiBlockImpl : public torch::nn::Module { LayerNorm ln_{nullptr}; }; -TORCH_MODULE(PhiBlock); +LLM_MODULE(PhiBlock); -class PhiModelImpl : public torch::nn::Module { +class PhiModelImpl : public llm::nn::Module { public: PhiModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -228,7 +230,7 @@ class PhiModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("h", torch::nn::ModuleList()); + blocks_ = register_module("h", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -279,13 +281,13 @@ class PhiModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; }; -TORCH_MODULE(PhiModel); +LLM_MODULE(PhiModel); -class PhiLMHeadImpl : public torch::nn::Module { +class PhiLMHeadImpl : public llm::nn::Module { public: PhiLMHeadImpl(const ModelArgs& args, const ParallelArgs& parallel_args, @@ -325,9 +327,9 @@ class PhiLMHeadImpl : public torch::nn::Module { ColumnParallelLinear linear_{nullptr}; }; -TORCH_MODULE(PhiLMHead); +LLM_MODULE(PhiLMHead); -class PhiForCausalLMImpl : public torch::nn::Module { +class PhiForCausalLMImpl : public llm::nn::Module { public: PhiForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -381,7 +383,7 @@ class PhiForCausalLMImpl : public torch::nn::Module { PhiLMHead lm_head_{nullptr}; }; -TORCH_MODULE(PhiForCausalLM); +LLM_MODULE(PhiForCausalLM); // clang-format off // register the model to make it available diff --git a/src/models/qwen.h b/src/models/qwen.h index 841f74a9..3884eb8c 100644 --- a/src/models/qwen.h +++ b/src/models/qwen.h @@ -16,12 +16,14 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" - +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // QWen model compatible with huggingface weights // adopted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py namespace llm::hf { -class QWenMLPImpl : public torch::nn::Module { +class QWenMLPImpl : public llm::nn::Module { public: QWenMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -79,9 +81,9 @@ class QWenMLPImpl : public torch::nn::Module { ActFunc act_{nullptr}; }; -TORCH_MODULE(QWenMLP); +LLM_MODULE(QWenMLP); -class QWenAttentionImpl : public torch::nn::Module { +class QWenAttentionImpl : public llm::nn::Module { public: QWenAttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -152,9 +154,9 @@ class QWenAttentionImpl : public torch::nn::Module { // module members without parameters Attention atten_{nullptr}; }; -TORCH_MODULE(QWenAttention); +LLM_MODULE(QWenAttention); -class QWenBlockImpl : public torch::nn::Module { +class QWenBlockImpl : public llm::nn::Module { public: 
QWenBlockImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -207,9 +209,9 @@ class QWenBlockImpl : public torch::nn::Module { RMSNorm ln_2_{nullptr}; }; -TORCH_MODULE(QWenBlock); +LLM_MODULE(QWenBlock); -class QWenModelImpl : public torch::nn::Module { +class QWenModelImpl : public llm::nn::Module { public: QWenModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -224,7 +226,7 @@ class QWenModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -279,15 +281,15 @@ class QWenModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; RMSNorm ln_f_{nullptr}; }; -TORCH_MODULE(QWenModel); +LLM_MODULE(QWenModel); -class QWenForCausalLMImpl : public torch::nn::Module { +class QWenForCausalLMImpl : public llm::nn::Module { public: QWenForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -346,7 +348,7 @@ class QWenForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(QWenForCausalLM); +LLM_MODULE(QWenForCausalLM); class QwenChatTemplate final : public CodedChatTemplate { public: @@ -431,4 +433,4 @@ REGISTER_TOKENIZER_ARGS(qwen, [&] { SET_ARG(pattern, pattern); }); -} // namespace llm::hf \ No newline at end of file +} // namespace llm::hf diff --git a/src/models/qwen2.h b/src/models/qwen2.h index aaa6280e..df64ccd1 100644 --- a/src/models/qwen2.h +++ b/src/models/qwen2.h @@ -17,13 +17,15 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" - +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // QWen2 model compatible with huggingface weights // ref to: // https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/qwen2/modeling_qwen2.py namespace llm::hf { -class QWen2MLPImpl : public torch::nn::Module { +class QWen2MLPImpl : public llm::nn::Module { public: QWen2MLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -82,9 +84,9 @@ class QWen2MLPImpl : public torch::nn::Module { // activation function ActFunc act_func_{nullptr}; }; -TORCH_MODULE(QWen2MLP); +LLM_MODULE(QWen2MLP); -class QWen2AttentionImpl : public torch::nn::Module { +class QWen2AttentionImpl : public llm::nn::Module { public: QWen2AttentionImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -166,9 +168,9 @@ class QWen2AttentionImpl : public torch::nn::Module { // module members without parameters Attention atten_{nullptr}; }; -TORCH_MODULE(QWen2Attention); +LLM_MODULE(QWen2Attention); -class QWen2DecoderLayerImpl : public torch::nn::Module { +class QWen2DecoderLayerImpl : public llm::nn::Module { public: QWen2DecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -234,9 +236,9 @@ class QWen2DecoderLayerImpl : public torch::nn::Module { RMSNormResidual input_layernorm_{nullptr}; RMSNormResidual post_attention_layernorm_{nullptr}; }; -TORCH_MODULE(QWen2DecoderLayer); +LLM_MODULE(QWen2DecoderLayer); -class QWen2ModelImpl : public torch::nn::Module { +class QWen2ModelImpl : public llm::nn::Module { public: 
QWen2ModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -251,7 +253,7 @@ class QWen2ModelImpl : public torch::nn::Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { int32_t sliding_window = -1; @@ -316,15 +318,15 @@ class QWen2ModelImpl : public torch::nn::Module { // attention handler std::unique_ptr handler_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; // hold same data but different type as blocks_ to avoid type cast std::vector layers_; RMSNormResidual norm_{nullptr}; }; -TORCH_MODULE(QWen2Model); +LLM_MODULE(QWen2Model); -class QWen2ForCausalLMImpl : public torch::nn::Module { +class QWen2ForCausalLMImpl : public llm::nn::Module { public: QWen2ForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -383,7 +385,7 @@ class QWen2ForCausalLMImpl : public torch::nn::Module { ColumnParallelLinear lm_head_{nullptr}; }; -TORCH_MODULE(QWen2ForCausalLM); +LLM_MODULE(QWen2ForCausalLM); class QWen2ChatTemplate final : public CodedChatTemplate { public: @@ -447,4 +449,4 @@ REGISTER_MODEL_ARGS(qwen2, [&] { SET_ARG(stop_token_ids, std::unordered_set({151644, 151645})); }); -} // namespace llm::hf \ No newline at end of file +} // namespace llm::hf diff --git a/src/models/simple_model.h b/src/models/simple_model.h index 2570910c..ae727cff 100644 --- a/src/models/simple_model.h +++ b/src/models/simple_model.h @@ -11,11 +11,13 @@ #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" - +#include "module/module.h" +#include "module/module_holder.h" +#include "module/modulelist.h" // simple model for test namespace llm { -class SimpleMLPImpl : public torch::nn::Module { +class SimpleMLPImpl : public llm::nn::Module { public: SimpleMLPImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -69,9 +71,9 @@ class SimpleMLPImpl : public torch::nn::Module { ActFunc act_func_{nullptr}; }; -TORCH_MODULE(SimpleMLP); +LLM_MODULE(SimpleMLP); -class SimpleDecoderLayerImpl : public torch::nn::Module { +class SimpleDecoderLayerImpl : public llm::nn::Module { public: SimpleDecoderLayerImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -100,9 +102,9 @@ class SimpleDecoderLayerImpl : public torch::nn::Module { SimpleMLP mlp_{nullptr}; }; -TORCH_MODULE(SimpleDecoderLayer); +LLM_MODULE(SimpleDecoderLayer); -class SimpleModelImpl : public torch::nn::Module { +class SimpleModelImpl : public llm::nn::Module { public: SimpleModelImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -113,7 +115,7 @@ class SimpleModelImpl : public torch::nn::Module { ParallelEmbedding( args.vocab_size(), args.hidden_size(), parallel_args, options)); - blocks_ = register_module("layers", torch::nn::ModuleList()); + blocks_ = register_module("layers", llm::nn::ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = SimpleDecoderLayer(args, quant_args, parallel_args, options); @@ -152,12 +154,12 @@ class SimpleModelImpl : public torch::nn::Module { private: ParallelEmbedding embed_tokens_{nullptr}; - torch::nn::ModuleList blocks_{nullptr}; + llm::nn::ModuleList blocks_{nullptr}; std::vector layers_; }; -TORCH_MODULE(SimpleModel); +LLM_MODULE(SimpleModel); -class SimpleForCausalLMImpl : public 
torch::nn::Module { +class SimpleForCausalLMImpl : public llm::nn::Module { public: SimpleForCausalLMImpl(const ModelArgs& args, const QuantArgs& quant_args, @@ -190,7 +192,7 @@ class SimpleForCausalLMImpl : public torch::nn::Module { private: SimpleModel model_{nullptr}; }; -TORCH_MODULE(SimpleForCausalLM); +LLM_MODULE(SimpleForCausalLM); REGISTER_CAUSAL_MODEL(simple, SimpleForCausalLM); REGISTER_MODEL_ARGS(simple, [&] { diff --git a/src/module/CMakeLists.txt b/src/module/CMakeLists.txt new file mode 100644 index 00000000..e06006f6 --- /dev/null +++ b/src/module/CMakeLists.txt @@ -0,0 +1,15 @@ +include(cc_library) + +cc_library( + NAME + module + HDRS + module.h + modulelist.h + module_holder.h + SRCS + module.cpp + DEPS + glog::glog + torch +) diff --git a/src/module/cloneable.h b/src/module/cloneable.h new file mode 100644 index 00000000..51821a99 --- /dev/null +++ b/src/module/cloneable.h @@ -0,0 +1,90 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include "module.h" + +namespace llm::nn { +/// The `clone()` method in the base `Module` class does not have knowledge of +/// the concrete runtime type of its subclasses. Therefore, `clone()` must +/// either be called from within the subclass, or from a base class that has +/// knowledge of the concrete type. `Cloneable` uses the CRTP to gain +/// knowledge of the subclass' static type and provide an implementation of the +/// `clone()` method. We do not want to use this pattern in the base class, +/// because then storing a module would always require templatizing it. +template +// NOLINTNEXTLINE(bugprone-exception-escape) +class Cloneable : public llm::nn::Module { + public: + using Module::Module; + + /// `reset()` must perform initialization of all members with reference + /// semantics, most importantly parameters, buffers and submodules. + virtual void reset() = 0; + + /// Performs a recursive "deep copy" of the `Module`, such that all parameters + /// and submodules in the cloned module are different from those in the + /// original module. + std::shared_ptr clone( + const std::optional& device = std::nullopt) const override { + NoGradGuard no_grad; + + const auto& self = static_cast(*this); + auto copy = std::make_shared(self); + copy->parameters_.clear(); + copy->buffers_.clear(); + copy->children_.clear(); + copy->reset(); + TORCH_CHECK(copy->parameters_.size() == parameters_.size(), + "The cloned module does not have the same number of " + "parameters as the original module after calling reset(). " + "Are you sure you called register_parameter() inside reset() " + "and not the constructor?"); + for (const auto& parameter : named_parameters(/*recurse=*/false)) { + auto& tensor = *parameter; + auto data = device && tensor.device() != *device ? tensor.to(*device) + : tensor.clone(); + copy->parameters_[parameter.key()].set_data(data); + } + TORCH_CHECK(copy->buffers_.size() == buffers_.size(), + "The cloned module does not have the same number of " + "buffers as the original module after calling reset(). " + "Are you sure you called register_buffer() inside reset() " + "and not the constructor?"); + for (const auto& buffer : named_buffers(/*recurse=*/false)) { + auto& tensor = *buffer; + auto data = device && tensor.device() != *device ? 
tensor.to(*device) + : tensor.clone(); + copy->buffers_[buffer.key()].set_data(data); + } + TORCH_CHECK(copy->children_.size() == children_.size(), + "The cloned module does not have the same number of " + "child modules as the original module after calling reset(). " + "Are you sure you called register_module() inside reset() " + "and not the constructor?"); + for (const auto& child : children_) { + copy->children_[child.key()]->clone_(*child.value(), device); + } + return copy; + } + + private: + void clone_(Module& other, const std::optional& device) final { + // Here we are *pretty* certain that `other's` type is `Derived` (because it + // was registered under the same name as `this`), but you never know what + // crazy things `reset()` does, so `dynamic_cast` just to be safe. + auto clone = std::dynamic_pointer_cast(other.clone(device)); + TORCH_CHECK(clone != nullptr, + "Attempted to clone submodule, but it is of a " + "different type than the submodule it was to be cloned into"); + static_cast(*this) = *clone; + } +}; + +} // namespace llm::nn diff --git a/src/module/module.cpp b/src/module/module.cpp new file mode 100644 index 00000000..c6a332e4 --- /dev/null +++ b/src/module/module.cpp @@ -0,0 +1,397 @@ +#include "module.h" + +#include +#include +#include + +#include +#include +#include + +namespace llm::nn { +namespace { +/// Joins names hierarchically: "name_prefix.name" if `name_prefix` is +/// non-empty, else just "name". +std::string join_name(const std::string& name_prefix, const std::string& name) { + size_t total_size = name.size(); + if (!name_prefix.empty()) { + total_size += name_prefix.size() + 1; + } + std::string full_name; + full_name.reserve(total_size); + if (!name_prefix.empty()) { + full_name += name_prefix; + full_name.push_back('.'); + } + full_name += name; + return full_name; +} +} // namespace + +Module::Module() + : parameters_("Parameter"), buffers_("Buffer"), children_("Submodule") {} + +Module::Module(std::string name) : Module() { + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) + name_ = std::move(name); +} + +const std::string& Module::name() const noexcept { + // If the name optional is empty at this point, we grab the name of the + // dynamic type via RTTI. Note that we cannot do this in the constructor, + // because in the constructor of a base class `this` always refers to the base + // type. Inheritance effectively does not work in constructors. Also this note + // from http://en.cppreference.com/w/cpp/language/typeid: + // If typeid is used on an object under construction or destruction (in a + // destructor or in a constructor, including constructor's initializer list + // or default member initializers), then the std::type_info object referred + // to by this typeid represents the class that is being constructed or + // destroyed even if it is not the most-derived class. + if (!name_.has_value()) { + name_ = c10::demangle(typeid(*this).name()); +#if defined(_WIN32) + // Windows adds "struct" or "class" as a prefix. + if (name_->find("struct ") == 0) { + name_->erase(name_->begin(), name_->begin() + 7); + } else if (name_->find("class ") == 0) { + name_->erase(name_->begin(), name_->begin() + 6); + } +#endif // defined(_WIN32) + } + return *name_; +} + +std::shared_ptr Module::clone( + const std::optional& device) const { + TORCH_CHECK( + false, + "clone() has not been implemented for ", + name(), + ". 
Subclass torch::nn::Cloneable<", + name(), + "> instead of torch::nn::Module to inherit the ability to clone."); +} + +void Module::apply(const ModuleApplyFunction& function) { + function(*this); + apply_to_submodules( + [&function](const std::string&, const std::shared_ptr& module) { + function(*module); + }); +} + +void Module::apply(const ConstModuleApplyFunction& function) const { + function(*this); + apply_to_submodules( + [&function](const std::string&, const std::shared_ptr& module) { + function(*module); + }); +} + +void Module::apply(const NamedModuleApplyFunction& function, + const std::string& name_prefix) { + function(/*name=*/name_prefix, *this); + apply_to_submodules( + [&function](const std::string& name, + const std::shared_ptr& module) { + function(name, *module); + }, + name_prefix); +} + +void Module::apply(const ConstNamedModuleApplyFunction& function, + const std::string& name_prefix) const { + function(/*name=*/name_prefix, *this); + apply_to_submodules( + [&function](const std::string& name, + const std::shared_ptr& module) { + function(name, *module); + }, + name_prefix); +} + +void Module::apply(const ModulePointerApplyFunction& function) const { + function(shared_from_this_checked()); + apply_to_submodules( + [&function](const std::string&, const std::shared_ptr& module) { + function(module); + }); +} + +void Module::apply(const NamedModulePointerApplyFunction& function, + const std::string& name_prefix) const { + function( + /*name=*/name_prefix, shared_from_this_checked()); + apply_to_submodules(function, name_prefix); +} + +std::vector Module::parameters(bool recurse) const { + return named_parameters(recurse).values(); +} + +OrderedDict Module::named_parameters(bool recurse) const { + OrderedDict result; + if (!recurse) { + for (const auto& parameter : parameters_) { + if (parameter.value().defined()) { + result.insert(parameter.key(), parameter.value()); + } + } + } else { + apply([&result](const std::string& name, const Module& module) { + for (const auto& parameter : module.named_parameters(/*recurse=*/false)) { + TORCH_INTERNAL_ASSERT(parameter.value().defined()); + result.insert(join_name(name, parameter.key()), parameter.value()); + } + }); + } + return result; +} + +std::vector Module::buffers(bool recurse) const { + return named_buffers(recurse).values(); +} + +OrderedDict Module::named_buffers(bool recurse) const { + OrderedDict result; + if (!recurse) { + for (const auto& buffer : buffers_) { + if (buffer.value().defined()) { + result.insert(buffer.key(), buffer.value()); + } + } + } else { + apply([&result](const std::string& name, const Module& module) { + for (const auto& buffer : module.named_buffers(/*recurse=*/false)) { + TORCH_INTERNAL_ASSERT(buffer.value().defined()); + result.insert(join_name(name, buffer.key()), buffer.value()); + } + }); + } + return result; +} + +std::vector> Module::modules(bool include_self) const { + std::vector> result; + if (include_self) { + apply([&result](const std::shared_ptr& module) { + result.push_back(module); + }); + } else { + apply_to_submodules( + [&result](const std::string&, const std::shared_ptr& module) { + result.push_back(module); + }); + } + return result; +} + +OrderedDict> Module::named_modules( + const std::string& name_prefix, + bool include_self) const { + OrderedDict> result; + if (include_self) { + apply( + [&result](const std::string& key, + const std::shared_ptr& module) { + result.insert(key, module); + }, + name_prefix); + } else { + apply_to_submodules( + [&result](const std::string& 
key, + const std::shared_ptr& module) { + result.insert(key, module); + }, + name_prefix); + } + return result; +} + +std::vector> Module::children() const { + return children_.values(); +} + +OrderedDict> Module::named_children() + const { + return children_; +} + +void Module::train(bool on) { + for (auto& child : children_) { + child.value()->train(on); + } + is_training_ = on; +} + +void Module::eval() { train(/*on=*/false); } + +void Module::to(torch::Device device, torch::Dtype dtype, bool non_blocking) { + to_impl(device, dtype, non_blocking); +} + +void Module::to(torch::Dtype dtype, bool non_blocking) { + to_impl(dtype, non_blocking); +} + +void Module::to(torch::Device device, bool non_blocking) { + to_impl(device, non_blocking); +} + +bool Module::is_training() const noexcept { return is_training_; } + +void Module::zero_grad(bool set_to_none) { + for (auto& child : children_) { + child.value()->zero_grad(set_to_none); + } + for (auto& parameter : named_parameters(/*recurse=*/false)) { + auto& grad = parameter->mutable_grad(); + if (grad.defined()) { + grad = grad.detach(); + + if (set_to_none) + grad.reset(); + else + grad.zero_(); + } + } +} + +void Module::save(serialize::OutputArchive& archive) const { + for (const auto& parameter : named_parameters(/*recurse=*/false)) { + archive.write(parameter.key(), parameter.value()); + } + for (const auto& buffer : named_buffers(/*recurse=*/false)) { + archive.write(buffer.key(), buffer.value(), /*is_buffer=*/true); + } + for (const auto& child : children_) { + if (child.value()->is_serializable()) { + serialize::OutputArchive child_archive(archive.compilation_unit()); + child.value()->save(child_archive); + archive.write(child.key(), child_archive); + } + } +} + +void Module::load(serialize::InputArchive& archive) { + for (auto& parameter : named_parameters(/*recurse=*/false)) { + archive.read(parameter.key(), parameter.value()); + } + for (auto& buffer : named_buffers(/*recurse=*/false)) { + archive.read(buffer.key(), buffer.value(), /*is_buffer=*/true); + } + for (const auto& child : children_) { + if (child.value()->is_serializable()) { + serialize::InputArchive child_archive; + archive.read(child.key(), child_archive); + child.value()->load(child_archive); + } + } +} + +bool Module::is_serializable() const { return true; } + +Tensor& Module::register_parameter(std::string name, + Tensor tensor, + bool requires_grad) { + TORCH_CHECK(!name.empty(), "Parameter name must not be empty"); + TORCH_CHECK(name.find('.') == std::string::npos, + "Parameter name must not contain a dot (got '", + name, + "')"); + if (!tensor.defined()) { + if (requires_grad) { + TORCH_WARN("An undefined tensor cannot require grad. 
", + "Ignoring the `requires_grad=true` function parameter."); + } + } else { + tensor.set_requires_grad(requires_grad); + } + return parameters_.insert(std::move(name), std::move(tensor)); +} + +Tensor& Module::register_buffer(std::string name, Tensor tensor) { + TORCH_CHECK(!name.empty(), "Buffer name must not be empty"); + TORCH_CHECK(name.find('.') == std::string::npos, + "Buffer name must not contain a dot (got '", + name, + "')"); + return buffers_.insert(std::move(name), std::move(tensor)); +} + +void Module::unregister_module(const std::string& name) { + TORCH_CHECK(children_.contains(name), + "No Module with name `", + name, + "` is registered"); + children_.erase(name); +} + +void Module::pretty_print(std::ostream& stream) const { stream << name(); } + +void Module::pretty_print_recursive(std::ostream& stream, + const std::string& indentation) const { + pretty_print(stream); + if (!children_.is_empty()) { + stream << "(\n"; + const std::string next_indentation = indentation + " "; + for (const auto& child : children_) { + stream << next_indentation << "(" << child.key() << "): "; + child.value()->pretty_print_recursive(stream, next_indentation); + stream << '\n'; + } + stream << indentation << ")"; + } +} + +void Module::clone_(Module& other, const std::optional& device) {} + +void Module::apply_to_submodules( + const NamedModulePointerApplyFunction& function, + const std::string& name_prefix) const { + for (const auto& child : children_) { + auto qualified_name = join_name(name_prefix, child.key()); + function(qualified_name, child.value()); + child.value()->apply_to_submodules(function, qualified_name); + } +} + +std::shared_ptr Module::shared_from_this_checked() const { + std::shared_ptr ptr; + try { + ptr = shared_from_this(); + } catch (const std::bad_weak_ptr&) { + TORCH_CHECK( + false, + "It looks like you attempted to retrieve your top-level module " + "as a shared_ptr, but it is not stored in a shared_ptr. " + "Use std::make_shared<", + name(), + "> instead of creating your module on " + "the stack, or alternatively do not try to access your top-level " + "module at all by passing /*include_self=*/false " + "to modules() or named_modules()"); + } + return std::const_pointer_cast(ptr); +} + +std::ostream& operator<<(std::ostream& stream, const nn::Module& module) { + module.pretty_print_recursive(stream, ""); + return stream; +} + +serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const std::shared_ptr& module) { + TORCH_CHECK(module != nullptr, "Cannot serialize empty module"); + module->save(archive); + return archive; +} + +serialize::InputArchive& operator>>(serialize::InputArchive& archive, + const std::shared_ptr& module) { + TORCH_CHECK(module != nullptr, "Cannot deserialize empty module"); + module->load(archive); + return archive; +} +} // namespace llm::nn diff --git a/src/module/module.h b/src/module/module.h new file mode 100644 index 00000000..afcf813e --- /dev/null +++ b/src/module/module.h @@ -0,0 +1,692 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "module_holder.h" + +namespace llm::nn { +using namespace torch; + +/// The base class for all modules in PyTorch. +/// +/// \rst +/// .. note:: +/// The design and implementation of this class is largely based on the Python +/// API. 
You may want to consult the python documentation for +/// :py:class:`pytorch:torch.nn.Module` for further clarification on certain +/// methods or behavior. +/// \endrst +/// +/// A `Module` is an abstraction over the implementation of some function or +/// algorithm, possibly associated with some persistent data. A `Module` may +/// contain further `Module`s ("submodules"), each with their own +/// implementation, persistent data and further submodules. `Module`s can thus +/// be said to form a recursive tree structure. A `Module` is registered as a +/// submodule to another `Module` by calling `register_module()`, typically from +/// within a parent module's constructor. +/// +/// A distinction is made between three kinds of persistent data that may be +/// associated with a `Module`: +/// +/// 1. *Parameters*: tensors that record gradients, typically weights updated +/// during the backward step (e.g. the `weight` of a `Linear` module), +/// 2. *Buffers*: tensors that do not record gradients, typically updated during +/// the forward step, such as running statistics (e.g. `mean` and `variance` +/// in the `BatchNorm` module), +/// 3. Any additional state, not necessarily tensors, required for the +/// implementation or configuration of a `Module`. +/// +/// The first two kinds of state are special in that they may be registered +/// with the `Module` system to allow convenient access and batch configuration. +/// For example, registered parameters in any `Module` may be iterated over via +/// the `parameters()` accessor. Further, changing the data type of a `Module`'s +/// registered parameters can be done conveniently via `Module::to()`, e.g. +/// `module->to(torch::kCUDA)` to move all parameters to GPU memory. Lastly, +/// registered parameters and buffers are handled specially during a `clone()` +/// operation, which performs a deepcopy of a cloneable `Module` hierarchy. +/// +/// Parameters are registered with a `Module` via `register_parameter`. Buffers +/// are registered separately via `register_buffer`. These methods are part of +/// the public API of `Module` and are typically invoked from within a +/// concrete `Module`s constructor. +class Module : public std::enable_shared_from_this { + public: + using ModuleApplyFunction = std::function; + using ConstModuleApplyFunction = std::function; + using NamedModuleApplyFunction = + std::function; + using ConstNamedModuleApplyFunction = + std::function; + using ModulePointerApplyFunction = + std::function&)>; + using NamedModulePointerApplyFunction = + std::function&)>; + + /// Tells the base `Module` about the name of the submodule. + explicit Module(std::string name); + + /// Constructs the module without immediate knowledge of the submodule's name. + /// The name of the submodule is inferred via RTTI (if possible) the first + /// time `.name()` is invoked. + Module(); + Module(const Module&) = default; + Module& operator=(const Module&) = default; + Module(Module&&) noexcept = default; + Module& operator=(Module&&) noexcept = default; + + virtual ~Module() = default; + + /// Returns the name of the `Module`. + /// + /// A `Module` has an associated `name`, which is a string representation of + /// the kind of concrete `Module` it represents, such as `"Linear"` for the + /// `Linear` module. Under most circumstances, this name is automatically + /// inferred via runtime type information (RTTI). 
In the unusual circumstance + /// that you have this feature disabled, you may want to manually name your + /// `Module`s by passing the string name to the `Module` base class' + /// constructor. + const std::string& name() const noexcept; + + /// Performs a recursive deep copy of the module and all its registered + /// parameters, buffers and submodules. + /// + /// Optionally, this method sets the current device + /// to the one supplied before cloning. If no device is given, each + /// parameter and buffer will be moved to the device of its source. + /// + /// \rst + /// .. attention:: + /// Attempting to call the `clone()` method inherited from the base `Module` + /// class (the one documented here) will fail. To inherit an actual + /// implementation of `clone()`, you must subclass `Cloneable`. `Cloneable` + /// is templatized on the concrete module type, and can thus properly copy a + /// `Module`. This method is provided on the base class' API solely for an + /// easier-to-use polymorphic interface. + /// \endrst + virtual std::shared_ptr clone( + const std::optional& device = std::nullopt) const; + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `Module&`. + /// + /// \rst + /// .. code-block:: cpp + /// MyModule module; + /// module->apply([](nn::Module& module) { + /// std::cout << module.name() << std::endl; + /// }); + /// \endrst + void apply(const ModuleApplyFunction& function); + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `const Module&`. + /// + /// \rst + /// .. code-block:: cpp + /// MyModule module; + /// module->apply([](const nn::Module& module) { + /// std::cout << module.name() << std::endl; + /// }); + /// \endrst + void apply(const ConstModuleApplyFunction& function) const; + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `const std::string&` for the key of the module, + /// and a `Module&`. The key of the module itself is the empty string. If + /// `name_prefix` is given, it is prepended to every key as + /// `.` (and just `name_prefix` for the module itself). + /// + /// \rst + /// .. code-block:: cpp + /// MyModule module; + /// module->apply([](const std::string& key, nn::Module& module) { + /// std::cout << key << ": " << module.name() << std::endl; + /// }); + /// \endrst + void apply(const NamedModuleApplyFunction& function, + const std::string& name_prefix = std::string()); + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `const std::string&` for the key of the module, + /// and a `const Module&`. The key of the module itself is the empty string. + /// If `name_prefix` is given, it is prepended to every key as + /// `.` (and just `name_prefix` for the module itself). + /// + /// \rst + /// .. code-block:: cpp + /// MyModule module; + /// module->apply([](const std::string& key, const nn::Module& module) { + /// std::cout << key << ": " << module.name() << std::endl; + /// }); + /// \endrst + void apply(const ConstNamedModuleApplyFunction& function, + const std::string& name_prefix = std::string()) const; + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `const std::shared_ptr&`. + /// + /// \rst + /// .. 
code-block:: cpp + /// MyModule module; + /// module->apply([](const std::shared_ptr& module) { + /// std::cout << module->name() << std::endl; + /// }); + /// \endrst + void apply(const ModulePointerApplyFunction& function) const; + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `const std::string&` for the key of the module, + /// and a `const std::shared_ptr&`. The key of the module itself is + /// the empty string. If `name_prefix` is given, it is prepended to every key + /// as + /// `.` (and just `name_prefix` for the module itself). + /// + /// \rst + /// .. code-block:: cpp + /// MyModule module; + /// module->apply([](const std::string& key, + /// const std::shared_ptr& module) { + /// std::cout << key << ": " << module->name() << std::endl; + /// }); + /// \endrst + void apply(const NamedModulePointerApplyFunction& function, + const std::string& name_prefix = std::string()) const; + + /// Returns the parameters of this `Module` and if `recurse` is true, also + /// recursively of every submodule. + std::vector parameters(bool recurse = true) const; + + /// Returns an `OrderedDict` with the parameters of this `Module` along with + /// their keys, and if `recurse` is true also recursively of every submodule. + OrderedDict named_parameters(bool recurse = true) const; + + /// Returns the buffers of this `Module` and if `recurse` is true, also + /// recursively of every submodule. + std::vector buffers(bool recurse = true) const; + + /// Returns an `OrderedDict` with the buffers of this `Module` along with + /// their keys, and if `recurse` is true also recursively of every submodule. + OrderedDict named_buffers(bool recurse = true) const; + + /// Returns the submodules of this `Module` (the entire submodule hierarchy) + /// and if `include_self` is true, also inserts a `shared_ptr` to this module + /// in the first position. + /// + /// \rst + /// .. warning:: + /// Only pass `include_self` as `true` if this `Module` is stored in a + /// `shared_ptr`! Otherwise an exception will be thrown. You may still call + /// this method with `include_self` set to false if your `Module` is not + /// stored in a `shared_ptr`. + /// \endrst + std::vector> modules(bool include_self = true) const; + + /// Returns an `OrderedDict` of the submodules of this `Module` (the entire + /// submodule hierarchy) and their keys, and if `include_self` is true, also + /// inserts a `shared_ptr` to this module in the first position. If + /// `name_prefix` is given, it is prepended to every key as + /// `.` (and just `name_prefix` for the module itself). + /// + /// \rst + /// .. warning:: + /// Only pass `include_self` as `true` if this `Module` is stored in a + /// `shared_ptr`! Otherwise an exception will be thrown. You may still call + /// this method with `include_self` set to false if your `Module` is not + /// stored in a `shared_ptr`. + /// \endrst + OrderedDict> named_modules( + const std::string& name_prefix = std::string(), + bool include_self = true) const; + + /// Returns the direct submodules of this `Module`. + std::vector> children() const; + + /// Returns an `OrderedDict` of the direct submodules of this `Module` and + /// their keys. + OrderedDict> named_children() const; + + /// Enables "training" mode. + virtual void train(bool on = true); + + /// Calls train(false) to enable "eval" mode. + /// Do not override this method, override `train()` instead. + void eval(); + + /// True if the module is in training mode. 
+ /// + /// Every `Module` has a boolean associated with it that determines whether + /// the `Module` is currently in *training* mode (set via `.train()`) or in + /// *evaluation* (inference) mode (set via `.eval()`). This property is + /// exposed via `is_training()`, and may be used by the implementation of a + /// concrete module to modify its runtime behavior. See the `BatchNorm` or + /// `Dropout` modules for examples of `Module`s that use different code paths + /// depending on this property. + virtual bool is_training() const noexcept; + + /// Recursively casts all parameters to the given `dtype` and `device`. + /// + /// If `non_blocking` is true and the source is in pinned memory and + /// destination is on the GPU or vice versa, the copy is performed + /// asynchronously with respect to the host. Otherwise, the argument has no + /// effect. + virtual void to(torch::Device device, + torch::Dtype dtype, + bool non_blocking = false); + + /// Recursively casts all parameters to the given dtype. + /// + /// If `non_blocking` is true and the source is in pinned memory and + /// destination is on the GPU or vice versa, the copy is performed + /// asynchronously with respect to the host. Otherwise, the argument has no + /// effect. + virtual void to(torch::Dtype dtype, bool non_blocking = false); + + /// Recursively moves all parameters to the given device. + /// + /// If `non_blocking` is true and the source is in pinned memory and + /// destination is on the GPU or vice versa, the copy is performed + /// asynchronously with respect to the host. Otherwise, the argument has no + /// effect. + virtual void to(torch::Device device, bool non_blocking = false); + + /// Recursively zeros out the `grad` value of each registered parameter. + virtual void zero_grad(bool set_to_none = true); + + /// Attempts to cast this `Module` to the given `ModuleType`. + /// + /// This method is useful when calling `apply()`. + /// \rst + /// .. code-block:: cpp + /// + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module->apply(initialize_weights); + /// \endrst + template + typename ModuleType::ContainedType* as() noexcept; + + /// Attempts to cast this `Module` to the given `ModuleType`. + /// + /// This method is useful when calling `apply()`. + /// \rst + /// .. code-block:: cpp + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module->apply(initialize_weights); + /// \endrst + template + const typename ModuleType::ContainedType* as() const noexcept; + + /// Attempts to cast this `Module` to the given `ModuleType`. + /// + /// This method is useful when calling `apply()`. + /// \rst + /// .. code-block:: cpp + /// + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module.apply(initialize_weights); + /// \endrst + template > + ModuleType* as() noexcept; + + /// Attempts to cast this `Module` to the given `ModuleType`. + /// + /// This method is useful when calling `apply()`. + /// \rst + /// .. 
code-block:: cpp + /// + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module.apply(initialize_weights); + /// \endrst + template > + const ModuleType* as() const noexcept; + + /// Serializes the `Module` into the given `OutputArchive`. + /// + /// If the `Module` contains unserializable submodules (e.g. + /// `nn::Functional`), those submodules are skipped when serializing. + virtual void save(serialize::OutputArchive& archive) const; + + /// Deserializes the `Module` from the given `InputArchive`. + /// + /// If the `Module` contains unserializable submodules (e.g. + /// `nn::Functional`), we don't check the existence of those submodules in the + /// `InputArchive` when deserializing. + virtual void load(serialize::InputArchive& archive); + + /// Streams a pretty representation of the `Module` into the given `stream`. + /// By default, this representation will be the name of the module (taken from + /// `name()`), followed by a recursive pretty print of all of the `Module`'s + /// submodules. + /// + /// Override this method to change the pretty print. The input + /// `stream` should be returned from the method, to allow easy chaining. + virtual void pretty_print(std::ostream& stream) const; + + /// Returns whether the `Module` is serializable. + virtual bool is_serializable() const; + + /// Registers a parameter with this `Module`. + /// + /// A parameter should be any gradient-recording tensor used in the + /// implementation of your `Module`. Registering it makes it available to + /// methods such as `parameters()`, `clone()` or `to().` + /// + /// Note that registering an undefined Tensor (e.g. + /// `module.register_parameter("param", Tensor())`) is allowed, and is + /// equivalent to `module.register_parameter("param", None)` in Python API. + /// + /// \rst + /// .. code-block:: cpp + /// + /// MyModule::MyModule() { + /// weight_ = register_parameter("weight", torch::randn({A, B})); + /// } + /// \endrst + Tensor& register_parameter(std::string name, + Tensor tensor, + bool requires_grad = true); + + /// Registers a buffer with this `Module`. + /// + /// A buffer is intended to be state in your module that does not record + /// gradients, such as running statistics. Registering it makes it available + /// to methods such as `buffers()`, `clone()` or `to(). + /// + /// \rst + /// .. code-block:: cpp + /// + /// MyModule::MyModule() { + /// mean_ = register_buffer("mean", torch::empty({num_features_})); + /// } + /// \endrst + Tensor& register_buffer(std::string name, Tensor tensor); + + /// Registers a submodule with this `Module`. + /// + /// Registering a module makes it available to methods such as `modules()`, + /// `clone()` or `to()`. + /// + /// \rst + /// .. code-block:: cpp + /// + /// MyModule::MyModule() { + /// submodule_ = register_module("linear", torch::nn::Linear(3, 4)); + /// } + /// \endrst + template + std::shared_ptr register_module( + std::string name, + std::shared_ptr module); + + /// Registers a submodule with this `Module`. + /// + /// This method deals with `ModuleHolder`s. + /// + /// Registering a module makes it available to methods such as `modules()`, + /// `clone()` or `to()`. + /// + /// \rst + /// .. 
code-block:: cpp + /// + /// MyModule::MyModule() { + /// submodule_ = register_module("linear", torch::nn::Linear(3, 4)); + /// } + /// \endrst + template + std::shared_ptr register_module( + std::string name, + ModuleHolder module_holder); + + /// Replaces a registered submodule with this `Module`. + /// + /// This takes care of the registration, if you used submodule members, you + /// should + // assign the submodule as well, i.e. use as + /// module->submodule_ = module->replace_module("linear", + /// torch::nn::Linear(3, 4)); + /// It only works when a module of the name is already registered. + /// + /// This is useful for replacing a module after initialization, e.g. + /// for finetuning. + template + std::shared_ptr replace_module( + const std::string& name, + std::shared_ptr module); + + /// Replaces a registered submodule with this `Module`. + /// This method deals with `ModuleHolder`s. + /// + /// This takes care of the registration, if you used submodule members, you + /// should + // assign the submodule as well, i.e. use as + /// module->submodule_ = module->replace_module("linear", linear_holder); + /// It only works when a module of the name is already registered. + /// + /// This is useful for replacing a module after initialization, e.g. + /// for finetuning. + template + std::shared_ptr replace_module( + const std::string& name, + ModuleHolder module_holder); + + /// Unregisters a submodule from this `Module`. If there is no such module + /// with `name` an exception is thrown. + void unregister_module(const std::string& name); + + protected: + /// The following three functions allow a module with default arguments in its + /// forward method to be used in a Sequential module. + /// You should NEVER override these functions manually. Instead, you should + /// use the `FORWARD_HAS_DEFAULT_ARGS` macro. + virtual bool _forward_has_default_args() { return false; } + + virtual unsigned int _forward_num_required_args() { + TORCH_CHECK( + false, + "torch::nn::Module subclass that has default arguments in `forward` " + "method ", + "must override `_forward_num_required_args` method. Please use ", + "`FORWARD_HAS_DEFAULT_ARGS` macro to do so."); + } + + // virtual std::vector _forward_populate_default_args( + // std::vector&& arguments) { + // TORCH_CHECK( + // false, + // "torch::nn::Module subclass that has default arguments in `forward` " + // "method ", + // "must override `_forward_populate_default_args` method. Please use ", + // "`FORWARD_HAS_DEFAULT_ARGS` macro to do so."); + // } + + /// The registered parameters of this `Module`. + /// Inorder to access parameters_ in ParameterDict and ParameterList + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + OrderedDict parameters_; + + private: + // Friend classes. + + template + friend class Cloneable; + + template + friend struct AnyModuleHolder; + + /// Pretty prints the given `Module` into the `ostream`. + TORCH_API friend std::ostream& operator<<(std::ostream& stream, + const nn::Module& module); + + // data parallel using this method to configure gradient edges during the + // replicate step. + template + friend void replicate_grad_edges( + const std::shared_ptr& module, + const std::vector>& replicas, + const std::vector& devices); + + // Private methods. + + /// Used in the implementation of `Cloneable`. + virtual void clone_(Module& other, const std::optional& device); + + /// The implementation of the various `to()` methods. + template + void to_impl(Ts&&... 
ts); + + /// Implements pretty printing the module hierarchy. + void pretty_print_recursive(std::ostream& stream, + const std::string& indentation) const; + + /// Applies the `function` to every submodule recursively, starting at this + /// `Module`'s children (thus not including the module itself). + void apply_to_submodules( + const NamedModulePointerApplyFunction& function, + const std::string& name_prefix = std::string()) const; + + /// Returns a shared_ptr to `this` in a safe (checked) way. + std::shared_ptr shared_from_this_checked() const; + + /// The registered buffers of this `Module`. + OrderedDict buffers_; + + /// The registered (direct) submodules of this `Module`. + OrderedDict> children_; + + /// The module's name (e.g. "LSTM"). + mutable std::optional name_; + + /// Whether the module is in training mode. + bool is_training_{true}; +}; + +/// Serialize a `Module` pointer into an `OutputArchive`. +TORCH_API serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const std::shared_ptr& module); + +/// Deserializes a `Module` from an `InputArchive`. +TORCH_API serialize::InputArchive& operator>>( + serialize::InputArchive& archive, + const std::shared_ptr& module); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +typename ModuleType::ContainedType* Module::as() noexcept { + // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for + // `Linear`, since `LinearImpl` inherits `nn::Module`. + return as(); +} + +template +const typename ModuleType::ContainedType* Module::as() const noexcept { + // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for + // `Linear`, since `LinearImpl` inherits `nn::Module`. + return as(); +} + +template +ModuleType* Module::as() noexcept { + return dynamic_cast(this); +} + +template +const ModuleType* Module::as() const noexcept { + return dynamic_cast(this); +} + +template +std::shared_ptr Module::register_module( + std::string name, + std::shared_ptr module) { + TORCH_CHECK(!name.empty(), "Submodule name must not be empty"); + TORCH_CHECK(name.find('.') == std::string::npos, + "Submodule name must not contain a dot (got '", + name, + "')"); + auto& base_module = children_.insert(std::move(name), std::move(module)); + return std::dynamic_pointer_cast(base_module); +} + +template +std::shared_ptr Module::register_module( + std::string name, + ModuleHolder module_holder) { + return register_module(std::move(name), module_holder.ptr()); +} + +template +std::shared_ptr Module::replace_module( + const std::string& name, + std::shared_ptr module) { + auto& base_module = (children_[name] = std::move(module)); + return std::dynamic_pointer_cast(base_module); +} + +template +std::shared_ptr Module::replace_module( + const std::string& name, + ModuleHolder module_holder) { + return replace_module(name, module_holder.ptr()); +} + +template +void Module::to_impl(Ts&&... ts) { + // First call `to()` on every child module. + for (auto& child : children_) { + child.value()->to(ts...); + } + // Then move every parameter to the new dtype/device. + for (auto& parameter : named_parameters(/*recurse=*/false)) { + parameter->set_data(parameter->to(ts...)); + } + // Then move every buffer to the new dtype/device. 
+ for (auto& buffer : named_buffers(/*recurse=*/false)) { + buffer->set_data(buffer->to(ts...)); + } +} + +} // namespace llm::nn diff --git a/src/module/module_holder.h b/src/module/module_holder.h new file mode 100644 index 00000000..36689aa0 --- /dev/null +++ b/src/module/module_holder.h @@ -0,0 +1,184 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace llm { +namespace detail { +// Dump all the template metaprogramming in this file. +#include "pimpl-inl.h" +} // namespace detail + +namespace nn { +using namespace torch; + +/// A `ModuleHolder` is essentially a wrapper around `std::shared_ptr` where +/// `M` is an `nn::Module` subclass, with convenient constructors defined for +/// the kind of constructions we want to allow for our modules. +template +class ModuleHolder : detail::ModuleHolderIndicator { + protected: + /// The module pointer this class wraps. + /// NOTE: Must be placed at the top of the class so that we can use it with + /// trailing return types below. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::shared_ptr impl_; + + public: + using ContainedType = Contained; + + /// Default constructs the contained module if if has a default constructor, + /// else produces a static error. + /// + /// NOTE: This uses the behavior of template + /// classes in C++ that constructors (or any methods) are only compiled when + /// actually used. + ModuleHolder() : impl_(default_construct()) { + static_assert( + std::is_default_constructible_v, + "You are trying to default construct a module which has " + "no default constructor. Use = nullptr to give it the empty state " + "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`)."); + } + + /// Constructs the `ModuleHolder` with an empty contained value. Access to + /// the underlying module is not permitted and will throw an exception, until + /// a value is assigned. + /* implicit */ ModuleHolder(std::nullptr_t) : impl_(nullptr) {} + + /// Constructs the `ModuleHolder` with a contained module, forwarding all + /// arguments to its constructor. + template ::value && + (sizeof...(Tail) == 0))>> + explicit ModuleHolder(Head&& head, Tail&&... tail) + : impl_(new Contained(std::forward(head), + std::forward(tail)...)) {} + + /// Constructs the `ModuleHolder` from a pointer to the contained type. + /// Example: `Linear(std::make_shared(...))`. + /* implicit */ ModuleHolder(std::shared_ptr module) + : impl_(std::move(module)) {} + + /// Returns true if the `ModuleHolder` contains a module, or false if it is + /// `nullptr`. + explicit operator bool() const noexcept { return !is_empty(); } + + /// Forwards to the contained module. + Contained* operator->() { return get(); } + + /// Forwards to the contained module. + const Contained* operator->() const { return get(); } + + /// Returns a reference to the contained module. + Contained& operator*() { return *get(); } + + /// Returns a const reference to the contained module. + const Contained& operator*() const { return *get(); } + + /// Returns a shared pointer to the underlying module. + const std::shared_ptr& ptr() const { + TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); + return impl_; + } + + /// Returns a pointer to the underlying module. + Contained* get() { + TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); + return impl_.get(); + } + + /// Returns a const pointer to the underlying module. 
+ const Contained* get() const { + TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); + return impl_.get(); + } + + /// Calls the `forward()` method of the contained module. + template + auto operator()(Args&&... args) + -> detail::return_type_of_forward_t { + // This will not compile if the module does not have a `forward()` method + // (as expected). + // NOTE: `std::forward` is qualified to prevent VS2017 emitting + // error C2872: 'std': ambiguous symbol + return impl_->forward(::std::forward(args)...); + } + + /// Forwards to the subscript operator of the contained module. + /// NOTE: std::forward is qualified to prevent VS2017 emitting + /// error C2872: 'std': ambiguous symbol + template + decltype(auto) operator[](Arg&& arg) { + return (*impl_)[::std::forward(arg)]; + } + + /// Returns true if the `ModuleHolder` does not contain a module. + bool is_empty() const noexcept { return impl_ == nullptr; } + + private: + template + std::shared_ptr default_construct() { + if constexpr (std::is_default_constructible_v) { + return std::make_shared(); + } else { + return nullptr; + } + } +}; + +/// Pretty prints the given `Module` into the `ostream`. +template +std::ostream& operator<<(std::ostream& stream, + const nn::ModuleHolder& module) { + return stream << *module; +} + +/// Serializes a `ModuleHolder` into an `OutputArchive`. +template +serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const nn::ModuleHolder& module) { + return archive << module.ptr(); +} + +/// Deserializes a `ModuleHolder` from an `InputArchive`. +template +serialize::InputArchive& operator>>(serialize::InputArchive& archive, + nn::ModuleHolder& module) { + return archive >> module.ptr(); +} + +} // namespace nn +} // namespace llm + +// Workaround for CUDA 10.2 and below not allowing attribute unused on +// using declarations. +#ifdef __CUDACC__ +#define UNUSED_EXCEPT_CUDA +#else +#define UNUSED_EXCEPT_CUDA [[maybe_unused]] +#endif + +/// Defines a class `Name` which inherits from `nn::ModuleHolder` to provide a +/// wrapper over a `std::shared_ptr`. +/// `Impl` is a type alias for `ImplType` which provides a way to call static +/// method of `ImplType`. +#define LLM_MODULE_IMPL(Name, ImplType) \ + class Name : public llm::nn::ModuleHolder { /* NOLINT */ \ + public: \ + using llm::nn::ModuleHolder::ModuleHolder; \ + using Impl TORCH_UNUSED_EXCEPT_CUDA = ImplType; \ + } + +/// Like `TORCH_MODULE_IMPL`, but defaults the `ImplType` name to `Impl`. +#define LLM_MODULE(Name) LLM_MODULE_IMPL(Name, Name##Impl) diff --git a/src/module/modulelist.h b/src/module/modulelist.h new file mode 100644 index 00000000..e96f4f97 --- /dev/null +++ b/src/module/modulelist.h @@ -0,0 +1,258 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include "cloneable.h" +#include "module.h" +#include "module_holder.h" + +namespace llm::nn { + +/// A list of `Module`s that registers its elements. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::ModuleList mlist( +/// torch::nn::Linear(3, 4), +/// torch::nn::BatchNorm1d(4), +/// torch::nn::Dropout(0.5) +/// ); +/// +/// for (const auto &module : *mlist) { +/// module->pretty_print(std::cout); +/// } +/// +/// \endrst +/// +/// Why should you use `ModuleList` instead of a simple `std::vector`? 
The value +/// a `ModuleList` provides over manually calling a sequence of modules is that +/// it allows treating the whole container *as a single module*, such that +/// performing a transformation on the `ModuleList` applies to each of the +/// modules it stores (which are each a registered submodule of the +/// `ModuleList`). For example, calling +/// `.to(torch::kCUDA)` on a `ModuleList` will move each module in the list to +/// CUDA memory. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::ModuleList mlist( +/// torch::nn::Linear(3, 4), +/// torch::nn::BatchNorm1d(4), +/// torch::nn::Dropout(0.5) +/// ); +/// +/// // Convert all modules to CUDA. +/// mlist->to(torch::kCUDA); +/// +/// \endrst +/// +/// Finally, `ModuleList` provides a lightweight container API, such as allowing +/// iteration over submodules, positional access, adding a new module after +/// construction via `push_back`, as well as joining two `ModuleList`s via +/// `extend`. +class ModuleListImpl : public Cloneable { + public: + using Iterator = std::vector>::iterator; + using ConstIterator = std::vector>::const_iterator; + + ModuleListImpl() = default; + + /// Constructs the `ModuleList` from a variadic list of modules. + template + explicit ModuleListImpl(Modules&&... modules) { + modules_.reserve(sizeof...(Modules)); + push_back_var(std::forward(modules)...); + } + + /// Special cloning function for `ModuleList` because it does not use + /// `reset()`. + std::shared_ptr clone( + const std::optional& device = std::nullopt) const override { + auto clone = std::make_shared(); + for (const auto& module : modules_) { + clone->push_back(module->clone(device)); + } + return clone; + } + + /// `reset()` is empty for `ModuleList`, since it does not have parameters of + /// its own. + void reset() override {} + + /// Pretty prints the `ModuleList` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::ModuleList"; + } + + void push_back(std::shared_ptr module) { + modules_.push_back(std::move(module)); + const auto index = modules_.size() - 1; + register_module(std::to_string(index), modules_[index]); + } + + /// Adds a new `Module` to the `ModuleList` container, moving or copying + /// it into a `shared_ptr` internally. This method allows passing value types, + /// and letting the container deal with the boxing. + template > + void push_back(M&& module) { + using Type = std::remove_reference_t; + push_back(std::make_shared(std::forward(module))); + } + + /// Unwraps the contained module of a `ModuleHolder` and adds it to the + /// `ModuleList`. + template + void push_back(const ModuleHolder& module_holder) { + push_back(module_holder.ptr()); + } + + /// Iterates over the container and calls `push_back()` on each value. + template + void extend(const Container& container) { + for (const auto& module : container) { + push_back(module); + } + } + + /// Returns an iterator to the start of the `ModuleList`. + Iterator begin() { return modules_.begin(); } + + /// Returns a const iterator to the start of the `ModuleList`. + ConstIterator begin() const { return modules_.begin(); } + + /// Returns an iterator to the end of the `ModuleList`. + Iterator end() { return modules_.end(); } + + /// Returns a const iterator to the end of the `ModuleList`. + ConstIterator end() const { return modules_.end(); } + + /// Attempts to return the module at the given index as the requested type. 
+ /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + T& at(size_t index) { + static_assert(torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + auto module = modules_[index]->as(); + TORCH_CHECK(module, + "Unable to cast module[", + index, + "] to ", + c10::demangle(typeid(T).name())); + return *module; + } + + /// Attempts to return the module at the given index as the requested type. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + const T& at(size_t index) const { + static_assert(torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + const auto module = modules_[index]->as(); + TORCH_CHECK(module, + "Unable to cast module[", + index, + "] to ", + c10::demangle(typeid(T).name())); + return *module; + } + + /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the + /// underlying module at the given index. Throws an exception if the index is + /// out of bounds. + std::shared_ptr ptr(size_t index) const { + TORCH_CHECK(index < size(), "Index out of range"); + return modules_[index]; + } + + /// Attempts to return a `std::shared_ptr` whose type is the one provided. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + std::shared_ptr ptr(size_t index) const { + static_assert(torch::detail::is_module::value, + "Can only call ModuleList::ptr with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + return std::dynamic_pointer_cast(modules_[index]); + } + + /// Like `ptr(index)`. + std::shared_ptr operator[](size_t index) const { + // This is the only method we can call without a type. + return ptr(index); + } + + /// The current size of the `ModuleList` container. + size_t size() const noexcept { return modules_.size(); } + + /// True if there are no modules in the `ModuleList`. + bool is_empty() const noexcept { return size() == 0; } + + void insert(size_t index, std::shared_ptr module) { + TORCH_CHECK(index <= size(), "Index out of range"); + + if (index == size()) + push_back(std::move(module)); + else { + modules_.insert(modules_.begin() + Iterator::difference_type(index), + std::move(module)); + + for (const auto i : c10::irange(index, size() - 1)) { + (void)i; // Suppress unused variable warning + replace_module(std::to_string(index), modules_[index]); + } + register_module(std::to_string(size() - 1), modules_.back()); + } + } + + /// Unwraps the contained module of a `ModuleHolder` and inserts it in the + /// `ModuleList`. + template + void insert(size_t index, const ModuleHolder& module_holder) { + insert(index, module_holder.ptr()); + } + + /// inserts a new `Module` to the `ModuleList` container, moving or copying + /// it into a `shared_ptr` internally. This method allows passing value types, + /// and letting the container deal with the boxing. + template > + void insert(size_t index, M&& module) { + using Type = std::remove_reference_t; + insert(index, std::make_shared(std::forward(module))); + } + + private: + template + void push_back_var(Head&& head, Tail&&... tail) { + push_back(std::forward(head)); + // Recursively calls this method, until the parameter pack only thas this + // entry left. Then calls `push_back()` a final time (above). 
+ push_back_var(std::forward(tail)...); + } + + /// The base case, when the list of modules is empty. + void push_back_var() {} + + // Box the AnyModules to give ModuleList reference semantics, like the rest of + // the API. Note that this is not required otherwise, this could just be a + // `vector`. + std::vector> modules_; +}; + +/// A `ModuleHolder` subclass for `ModuleListImpl`. +/// See the documentation for `ModuleListImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +LLM_MODULE(ModuleList); + +} // namespace llm::nn diff --git a/src/module/pimpl-inl.h b/src/module/pimpl-inl.h new file mode 100644 index 00000000..c5d91f34 --- /dev/null +++ b/src/module/pimpl-inl.h @@ -0,0 +1,76 @@ +// This class exists only to do SFINAE on abstract types `T` that are really +// `ModuleHolder`, because there's no good way to say that `T` is a +// `ModuleHolder` over some unknown type `ModuleType`. With this, you can do +// `enable_if_t>`. +struct ModuleHolderIndicator {}; + +// A type trait that is true for types that are `ModuleHolder`s. +template +using is_module_holder = + std::is_base_of>; + +template +using disable_if_module_holder_t = + std::enable_if_t::value>; + +// A collection of templates that answer the question whether a type `T` is a +// `ModuleHolder`, and if so whether its contained type is of type `C`. This is +// tricky because it is hard to short circuit in template metaprogramming. A +// naive and incorrect solution to this problem would be something like +// `disable_if::value && typename T::ContainedType == C>`. +// This would disable all types that are not `ModuleHolder`s, because even +// though the `is_module_holder::value` may be `false` for such types the +// `T::ContainedType` access would be ill-formed and thus fail the whole +// expression by the rules of SFINAE. Instead we have to use template +// specialization to statically branch on the first condition +// (`is_module_holder`) and are only then allowed to query +// `T::ContainedType` in the branch for which the condition was true. + +// Base template. +template +struct is_module_holder_of_impl; + +// False branch. `T` is not a `ModuleHolder` and thus not a `ModuleHolder` with +// contained type `C`. +template +struct is_module_holder_of_impl : std::false_type {}; + +// True branch. `T` is a `ModuleHolder` and thus we can legit access its +// `ContainedType` and compare it against `C`. +template +struct is_module_holder_of_impl + : std::is_same {}; + +// Helper template. +template +struct is_module_holder_of + : is_module_holder_of_impl::value, + std::decay_t, + std::decay_t> {}; + +// A collection of templates that allow deducing the return type of the +// `forward()` method, but only if a module actually has a `forward()` method, +// and otherwise deduces to the type `void`. + +template +struct return_type_of_forward_impl; + +template +struct return_type_of_forward_impl { + using type = decltype(::std::declval().forward(::std::declval()...)); +}; + +template +struct return_type_of_forward_impl { + using type = void; +}; + +template +using return_type_of_forward = + return_type_of_forward_impl::value, + C, + Args...>; + +template +using return_type_of_forward_t = + typename return_type_of_forward::type;
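
Reviewer note, a minimal sketch of how the pieces introduced in src/module/ are meant to be used together. Everything named Toy* (ToyLinearImpl, ToyStackImpl, the shapes, main) is hypothetical and exists only for illustration; the llm::nn::Module, LLM_MODULE, ModuleHolder, and ModuleList APIs, and the ModuleList-plus-typed-vector registration pattern, are taken from the patch above (e.g. PhiModelImpl, QWenModelImpl). This is a hedged sketch of intended usage, not code from the patch, and it has not been compiled against it.

    #include <iostream>
    #include <vector>

    #include <torch/torch.h>

    #include "module/module.h"
    #include "module/module_holder.h"
    #include "module/modulelist.h"

    namespace llm {

    // A toy leaf module: derives from llm::nn::Module instead of
    // torch::nn::Module and registers its parameter so that
    // named_parameters()/to() can see it.
    class ToyLinearImpl : public llm::nn::Module {
     public:
      ToyLinearImpl(int64_t in_features, int64_t out_features) {
        weight_ = register_parameter("weight",
                                     torch::randn({out_features, in_features}),
                                     /*requires_grad=*/false);
      }
      torch::Tensor forward(const torch::Tensor& x) {
        return torch::matmul(x, weight_.t());
      }

     private:
      torch::Tensor weight_;
    };
    // Same role as TORCH_MODULE: defines ToyLinear as a
    // ModuleHolder<ToyLinearImpl> wrapper.
    LLM_MODULE(ToyLinear);

    // A toy container following the pattern used by the models above: an
    // llm::nn::ModuleList for registration/ownership plus a typed vector so
    // forward() can call the blocks without casting.
    class ToyStackImpl : public llm::nn::Module {
     public:
      ToyStackImpl(int64_t dim, int64_t n_layers) {
        blocks_ = register_module("layers", llm::nn::ModuleList());
        layers_.reserve(n_layers);
        for (int64_t i = 0; i < n_layers; ++i) {
          auto block = ToyLinear(dim, dim);
          layers_.push_back(block);
          blocks_->push_back(block);
        }
      }
      torch::Tensor forward(torch::Tensor x) {
        for (auto& layer : layers_) {
          x = layer(x);  // ModuleHolder::operator() forwards to Impl::forward.
        }
        return x;
      }

     private:
      llm::nn::ModuleList blocks_{nullptr};
      std::vector<ToyLinear> layers_;
    };
    LLM_MODULE(ToyStack);

    }  // namespace llm

    int main() {
      llm::ToyStack stack(/*dim=*/8, /*n_layers=*/2);
      auto y = stack(torch::ones({4, 8}));
      std::cout << "output shape: " << y.sizes() << '\n';
      // Parameters registered on submodules show up with hierarchical keys,
      // e.g. "layers.0.weight".
      for (const auto& p : stack->named_parameters()) {
        std::cout << p.key() << ": " << p.value().sizes() << '\n';
      }
      return 0;
    }

If the copied torch::nn semantics hold, the call stack(...) type-checks because the traits in pimpl-inl.h deduce the holder's return type from ToyStackImpl::forward (return_type_of_forward_t), and register_module accepts the holder directly because is_module_holder distinguishes holders from raw module pointers.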