2 changes: 1 addition & 1 deletion src/layers/attention/attention.h
@@ -10,7 +10,7 @@

namespace llm {

class AttentionImpl : public llm::nn::Module {
class AttentionImpl : public Module {
public:
AttentionImpl(int64_t n_heads,
int64_t n_kv_heads,
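A note on the change above, which repeats throughout this PR: these classes already sit inside namespace llm, so once the nn sub-namespace is folded away the base class is found by unqualified lookup and llm::nn::Module collapses to plain Module. A minimal illustration; the Module definition here is a stand-in for whatever module/module.h actually provides:

namespace llm {

class Module {};  // assumed stand-in for the base class from module/module.h

// Unqualified lookup inside namespace llm already finds llm::Module,
// so the old llm::nn:: qualification is redundant after the refactor.
class AttentionImpl : public Module {};

}  // namespace llm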
16 changes: 6 additions & 10 deletions src/layers/embedding.h
@@ -16,16 +16,14 @@ namespace llm {
// This module is often used to store word embeddings and retrieve them using
// indices.

class EmbeddingImpl : public llm::nn::Module {
class EmbeddingImpl : public Module {
public:
EmbeddingImpl(int64_t num_embeddings,
int64_t embedding_dim,
const torch::TensorOptions& options) {
// register the weight parameter
weight_ = register_parameter(
"weight",
torch::empty({num_embeddings, embedding_dim}, options),
/*requires_grad=*/false);
"weight", torch::empty({num_embeddings, embedding_dim}, options));
}

// The input to the module is a list of indices, and the output is the
@@ -68,7 +66,7 @@ class EmbeddingImpl : public llm::nn::Module {
LLM_MODULE(Embedding);

// Embedding parallelized in the embedding dimension.
class ParallelEmbeddingImpl : public llm::nn::Module {
class ParallelEmbeddingImpl : public Module {
public:
ParallelEmbeddingImpl(int64_t num_embeddings,
int64_t embedding_dim,
@@ -84,8 +82,7 @@ class ParallelEmbeddingImpl : public llm::nn::Module {
// register the weight parameter
weight_ = register_parameter(
"weight",
torch::empty({num_embeddings, embedding_dim_per_partition}, options),
/*requires_grad=*/false);
torch::empty({num_embeddings, embedding_dim_per_partition}, options));
}

// The input to the module is a list of indices, and the output is the
@@ -139,7 +136,7 @@ class ParallelEmbeddingImpl : public llm::nn::Module {
LLM_MODULE(ParallelEmbedding);

// Embedding parallelized in the vocabulary dimension
class VocabParallelEmbeddingImpl : public llm::nn::Module {
class VocabParallelEmbeddingImpl : public Module {
public:
VocabParallelEmbeddingImpl(int64_t num_embeddings,
int64_t embedding_dim,
@@ -154,8 +151,7 @@ class VocabParallelEmbeddingImpl : public llm::nn::Module {
// register the weight parameter
weight_ = register_parameter(
"weight",
torch::empty({num_embeddings_per_partition, embedding_dim}, options),
/*requires_grad=*/false);
torch::empty({num_embeddings_per_partition, embedding_dim}, options));
}

// The input to the module is a list of indices, and the output is the
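On the register_parameter calls above: every call site drops /*requires_grad=*/false, which only stays correct if the project's own Module::register_parameter treats parameters as inference-only by default. A minimal sketch of such a wrapper against libtorch; the llm::Module API shown is an assumption for illustration, not the project's actual implementation:

#include <string>
#include <unordered_map>
#include <utility>

#include <torch/torch.h>

namespace llm {

// Hypothetical sketch: register_parameter defaults requires_grad to
// false, so inference-only layers no longer spell it out per call.
class Module {
 public:
  virtual ~Module() = default;

  torch::Tensor& register_parameter(const std::string& name,
                                    torch::Tensor tensor,
                                    bool requires_grad = false) {
    tensor.set_requires_grad(requires_grad);
    params_[name] = std::move(tensor);
    return params_[name];
  }

 private:
  std::unordered_map<std::string, torch::Tensor> params_;
};

}  // namespace llm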
2 changes: 1 addition & 1 deletion src/layers/fused_linear.h
@@ -12,7 +12,7 @@

namespace llm {

class FusedColumnParallelLinearImpl : public llm::nn::Module {
class FusedColumnParallelLinearImpl : public Module {
public:
FusedColumnParallelLinearImpl(int64_t in_features,
const std::vector<int64_t>& out_features,
14 changes: 7 additions & 7 deletions src/layers/linear.h
@@ -16,7 +16,7 @@ using TensorTransform = std::function<torch::Tensor(const torch::Tensor&)>;
// an interface for parallel linear layers.
// all linear classes should inherit from this class and implement the forward
// function.
class ParallelLinearImpl : public llm::nn::Module {
class ParallelLinearImpl : public Module {
public:
~ParallelLinearImpl() override = default;

@@ -39,10 +39,10 @@ class ParallelLinearImpl : public llm::nn::Module {
}
};

class ColumnParallelLinear : public llm::nn::ModuleHolder<ParallelLinearImpl> {
class ColumnParallelLinear : public ModuleHolder<ParallelLinearImpl> {
public:
using llm::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
using Impl __attribute__((__unused__)) = ParallelLinearImpl;
using ModuleHolder<ParallelLinearImpl>::ModuleHolder;
using Impl [[maybe_unused]] = ParallelLinearImpl;

// construct a column parallel linear layer.
// choose the right implementation based on the args.
@@ -62,10 +62,10 @@ class ColumnParallelLinear : public llm::nn::ModuleHolder<ParallelLinearImpl> {
const torch::TensorOptions& options);
};

class RowParallelLinear : public llm::nn::ModuleHolder<ParallelLinearImpl> {
class RowParallelLinear : public ModuleHolder<ParallelLinearImpl> {
public:
using llm::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
using Impl __attribute__((__unused__)) = ParallelLinearImpl;
using ModuleHolder<ParallelLinearImpl>::ModuleHolder;
using Impl [[maybe_unused]] = ParallelLinearImpl;

// construct a row parallel linear layer.
// choose the right implementation based on the args.
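The other change in linear.h swaps __attribute__((__unused__)) for [[maybe_unused]]: the same suppression of unused warnings on the Impl alias, but via the standard C++17 attribute rather than a GCC/Clang extension. A standalone illustration of the form being used (not code from this repository):

// [[maybe_unused]] applies to typedef-names, so it is valid on an
// alias-declaration; __attribute__((__unused__)) is compiler-specific.
template <typename Impl>
class Holder {
 public:
  using ImplType [[maybe_unused]] = Impl;  // alias kept for generic code
};

int main() {
  [[maybe_unused]] Holder<int> holder;  // no unused-variable warning
  return 0;
}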
16 changes: 5 additions & 11 deletions src/layers/linear_impl.cpp
@@ -28,14 +28,11 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(
// we allocate the transpose.
weight_ = register_parameter(
"weight",
torch::empty({out_features_per_partition, in_features}, options),
/*requires_grad=*/false);
torch::empty({out_features_per_partition, in_features}, options));

if (bias) {
bias_ =
register_parameter("bias",
torch::empty({out_features_per_partition}, options),
/*requires_grad=*/false);
bias_ = register_parameter(
"bias", torch::empty({out_features_per_partition}, options));
}
}

@@ -104,13 +101,10 @@ RowParallelLinearImpl::RowParallelLinearImpl(
// Allocate the transpose since linear performs XA^T.
weight_ = register_parameter(
"weight",
torch::empty({out_features, in_features_per_partition}, options),
/*requires_grad=*/false);
torch::empty({out_features, in_features_per_partition}, options));

if (bias) {
bias_ = register_parameter("bias",
torch::empty({out_features}, options),
/*requires_grad=*/false);
bias_ = register_parameter("bias", torch::empty({out_features}, options));
}
}

30 changes: 11 additions & 19 deletions src/layers/normalization.h
@@ -65,7 +65,7 @@ inline torch::Tensor layer_norm(torch::Tensor input,
// apply layer normalization over a mini-batch of inputs as described in
// the paper `Layer Normalization`: https://arxiv.org/abs/1607.06450
// x = ((x - mean(x)) / sqrt(var(x) + eps)) * weight + bias
class LayerNormImpl : public llm::nn::Module {
class LayerNormImpl : public Module {
public:
// dim: the dim over which the mean and std are calculated separately.
// eps: a value added to the denominator for numerical stability.
@@ -75,13 +75,11 @@
const torch::TensorOptions& options)
: eps_(eps) {
normalized_shape_ = {dim};
weight_ = register_parameter("weight",
torch::empty(normalized_shape_, options),
/*requires_grad=*/false);
weight_ =
register_parameter("weight", torch::empty(normalized_shape_, options));
if (bias) {
bias_ = register_parameter("bias",
torch::zeros(normalized_shape_, options),
/*requires_grad=*/false);
bias_ =
register_parameter("bias", torch::zeros(normalized_shape_, options));
}
}

@@ -145,13 +143,11 @@ class LayerNormImpl : public llm::nn::Module {
LLM_MODULE(LayerNorm);

// Root mean square normalization
class RMSNormImpl : public llm::nn::Module {
class RMSNormImpl : public Module {
public:
RMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options)
: eps_(eps) {
weight_ = register_parameter("weight",
torch::empty({dim}, options),
/*requires_grad=*/false);
weight_ = register_parameter("weight", torch::empty({dim}, options));
}

torch::Tensor forward(const torch::Tensor& input) {
@@ -195,13 +191,11 @@ class RMSNormImpl : public llm::nn::Module {
};
LLM_MODULE(RMSNorm);

class GemmaRMSNormImpl : public llm::nn::Module {
class GemmaRMSNormImpl : public Module {
public:
GemmaRMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options)
: eps_(eps) {
weight_ = register_parameter("weight",
torch::empty({dim}, options),
/*requires_grad=*/false);
weight_ = register_parameter("weight", torch::empty({dim}, options));
}

torch::Tensor forward(const torch::Tensor& input) {
@@ -246,15 +240,13 @@ class GemmaRMSNormImpl : public llm::nn::Module {
LLM_MODULE(GemmaRMSNorm);

// Root mean square normalization
class RMSNormResidualImpl : public llm::nn::Module {
class RMSNormResidualImpl : public Module {
public:
RMSNormResidualImpl(int64_t dim,
float eps,
const torch::TensorOptions& options)
: eps_(eps) {
weight_ = register_parameter("weight",
torch::empty({dim}, options),
/*requires_grad=*/false);
weight_ = register_parameter("weight", torch::empty({dim}, options));
}

torch::Tensor forward(const torch::Tensor& input, torch::Tensor& residual) {
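Background for the RMSNorm variants touched above: root-mean-square normalization drops LayerNorm's mean subtraction and rescales each row by the RMS of its features, y = x / sqrt(mean(x^2) + eps) * weight. A plain-libtorch sketch of that forward pass, for orientation only; the classes in this file may dispatch to fused kernels instead:

#include <torch/torch.h>

// Illustrative RMSNorm forward in stock libtorch ops. Gemma-style
// RMSNorm conventionally scales by (1 + weight) rather than weight.
torch::Tensor rms_norm(const torch::Tensor& x,
                       const torch::Tensor& weight,
                       double eps) {
  auto variance = x.pow(2).mean(/*dim=*/-1, /*keepdim=*/true);
  return x * torch::rsqrt(variance + eps) * weight;
}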
4 changes: 2 additions & 2 deletions src/layers/pos_embedding.h
@@ -41,11 +41,11 @@ class RotaryEmbeddingImpl : public torch::nn::Module {

// RotaryEmbedding is a wrapper class that chooses the right rotary positional
// embedding implementation based on the args.
// Similar to TORCH_MODULE(RotaryEmbedding) except for the explicit constructor.
// Similar to LLM_MODULE(RotaryEmbedding) except for the explicit constructor.
class RotaryEmbedding : public torch::nn::ModuleHolder<RotaryEmbeddingImpl> {
public:
using torch::nn::ModuleHolder<RotaryEmbeddingImpl>::ModuleHolder;
using Impl __attribute__((__unused__)) = RotaryEmbeddingImpl;
using Impl [[maybe_unused]] = RotaryEmbeddingImpl;

// construct a rotary positional embedding.
// choose the right implementation based on the args.
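For context on the comment change above: TORCH_MODULE(Name) defines Name as a ModuleHolder handle over NameImpl, and LLM_MODULE is presumably the project's analogue built on the ModuleHolder from module/module_holder.h; RotaryEmbedding spells the holder out by hand only because it adds an explicit constructor. A rough sketch of what such a macro could expand to, as an assumed shape rather than the actual definition:

// Hypothetical LLM_MODULE sketch in the spirit of TORCH_MODULE:
// defines `Name` as a cheap, shared_ptr-backed handle to `NameImpl`.
// Assumes a ModuleHolder<Impl> template like the one in module/module_holder.h.
#define LLM_MODULE_IMPL(Name, ImplType)           \
  class Name : public ModuleHolder<ImplType> {    \
   public:                                        \
    using ModuleHolder<ImplType>::ModuleHolder;   \
    using Impl [[maybe_unused]] = ImplType;       \
  }

#define LLM_MODULE(Name) LLM_MODULE_IMPL(Name, Name##Impl)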
4 changes: 3 additions & 1 deletion src/layers/qkv_linear.h
@@ -6,13 +6,15 @@
#include "fused_linear.h"
#include "model_loader/state_dict.h"
#include "model_parallel/parallel_args.h"
#include "module/module.h"
#include "module/module_holder.h"
#include "quantization/quant_args.h"

namespace llm {

// a thin wrapper to handle state_dict loading for QKV with
// support for MQA/GQA
class QKVColumnParallelLinearImpl : public llm::nn::Module {
class QKVColumnParallelLinearImpl : public Module {
public:
QKVColumnParallelLinearImpl(int64_t hidden_size,
int64_t n_heads,
16 changes: 8 additions & 8 deletions src/models/aquila.h
@@ -16,12 +16,12 @@
#include "models/parameters.h"
#include "module/module.h"
#include "module/module_holder.h"
#include "module/modulelist.h"
#include "module/module_list.h"

// Aquila model compatible with huggingface weights
namespace llm::hf {

class AquilaMLPImpl : public llm::nn::Module {
class AquilaMLPImpl : public Module {
public:
AquilaMLPImpl(const ModelArgs& args,
const QuantArgs& quant_args,
@@ -82,7 +82,7 @@ class AquilaMLPImpl : public llm::nn::Module {
};
LLM_MODULE(AquilaMLP);

class AquilaAttentionImpl : public llm::nn::Module {
class AquilaAttentionImpl : public Module {
public:
AquilaAttentionImpl(const ModelArgs& args,
const QuantArgs& quant_args,
@@ -169,7 +169,7 @@ class AquilaAttentionImpl : public llm::nn::Module {
};
LLM_MODULE(AquilaAttention);

class AquilaDecoderLayerImpl : public llm::nn::Module {
class AquilaDecoderLayerImpl : public Module {
public:
AquilaDecoderLayerImpl(const ModelArgs& args,
const QuantArgs& quant_args,
@@ -229,7 +229,7 @@ class AquilaDecoderLayerImpl : public llm::nn::Module {
};
LLM_MODULE(AquilaDecoderLayer);

class AquilaModelImpl : public llm::nn::Module {
class AquilaModelImpl : public Module {
public:
AquilaModelImpl(const ModelArgs& args,
const QuantArgs& quant_args,
@@ -244,7 +244,7 @@
handler_ = AttentionHandler::create_handler_with_rope(
args, /*interleaved=*/false, options);

blocks_ = register_module("layers", llm::nn::ModuleList());
blocks_ = register_module("layers", ModuleList());
layers_.reserve(args.n_layers());
for (int32_t i = 0; i < args.n_layers(); i++) {
auto block = AquilaDecoderLayer(
@@ -298,15 +298,15 @@ class AquilaModelImpl : public llm::nn::Module {
// attention handler
std::unique_ptr<AttentionHandler> handler_{nullptr};

llm::nn::ModuleList blocks_{nullptr};
ModuleList blocks_{nullptr};
// holds the same layers as blocks_ but with a concrete type to avoid casts
std::vector<AquilaDecoderLayer> layers_;

RMSNorm norm_{nullptr};
};
LLM_MODULE(AquilaModel);

class AquilaForCausalLMImpl : public llm::nn::Module {
class AquilaForCausalLMImpl : public Module {
public:
AquilaForCausalLMImpl(const ModelArgs& args,
const QuantArgs& quant_args,
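The AquilaModelImpl hunks keep the project's dual bookkeeping: blocks_ is a type-erased ModuleList registered so parameters and submodules stay traversable, while layers_ holds the same decoder layers with their concrete type for cast-free iteration in forward(). A condensed sketch of that pattern using stock torch::nn types rather than the llm:: wrappers, with hypothetical Block/Model names and a placeholder forward:

#include <vector>

#include <torch/torch.h>

// Condensed sketch of the register-twice pattern from AquilaModelImpl,
// built on stock torch::nn types instead of the project's llm:: wrappers.
struct BlockImpl : torch::nn::Module {
  torch::Tensor forward(const torch::Tensor& x) { return x; }  // placeholder
};
TORCH_MODULE(Block);

struct ModelImpl : torch::nn::Module {
  explicit ModelImpl(int64_t n_layers) {
    blocks_ = register_module("layers", torch::nn::ModuleList());
    layers_.reserve(n_layers);
    for (int64_t i = 0; i < n_layers; ++i) {
      Block block;
      layers_.push_back(block);   // typed handle for the hot path
      blocks_->push_back(block);  // same module, registered for traversal
    }
  }

  torch::Tensor forward(torch::Tensor x) {
    for (auto& layer : layers_) {  // no per-layer type casts
      x = layer->forward(x);
    }
    return x;
  }

  torch::nn::ModuleList blocks_{nullptr};
  std::vector<Block> layers_;
};
TORCH_MODULE(Model);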
16 changes: 8 additions & 8 deletions src/models/baichuan.h
@@ -18,7 +18,7 @@
#include "models/parameters.h"
#include "module/module.h"
#include "module/module_holder.h"
#include "module/modulelist.h"
#include "module/module_list.h"

// Baichuan model compatible with huggingface weights

@@ -31,7 +31,7 @@
Baichuan2_13B,
};

class BaichuanMLPImpl : public llm::nn::Module {
class BaichuanMLPImpl : public Module {
public:
BaichuanMLPImpl(const ModelArgs& args,
const QuantArgs& quant_args,
@@ -92,7 +92,7 @@ class BaichuanMLPImpl : public llm::nn::Module {
};
LLM_MODULE(BaichuanMLP);

class BaichuanAttentionImpl : public llm::nn::Module {
class BaichuanAttentionImpl : public Module {
public:
BaichuanAttentionImpl(const ModelArgs& args,
const QuantArgs& quant_args,
@@ -179,7 +179,7 @@ class BaichuanAttentionImpl : public llm::nn::Module {
};
LLM_MODULE(BaichuanAttention);

class BaichuanDecoderLayerImpl : public llm::nn::Module {
class BaichuanDecoderLayerImpl : public Module {
public:
BaichuanDecoderLayerImpl(const ModelArgs& args,
const QuantArgs& quant_args,
@@ -247,7 +247,7 @@ class BaichuanDecoderLayerImpl : public llm::nn::Module {
};
LLM_MODULE(BaichuanDecoderLayer);

class BaichuanModelImpl : public llm::nn::Module {
class BaichuanModelImpl : public Module {
public:
BaichuanModelImpl(const ModelArgs& args,
const QuantArgs& quant_args,
@@ -271,7 +271,7 @@ class BaichuanModelImpl : public llm::nn::Module {
args, alibi_slopes, options);
}

blocks_ = register_module("layers", llm::nn::ModuleList());
blocks_ = register_module("layers", ModuleList());
layers_.reserve(args.n_layers());
for (int32_t i = 0; i < args.n_layers(); i++) {
auto block = BaichuanDecoderLayer(args,
@@ -366,7 +366,7 @@ class BaichuanModelImpl : public llm::nn::Module {
std::unique_ptr<AttentionHandler> handler_{nullptr};

// parameter members, must be registered
llm::nn::ModuleList blocks_{nullptr};
ModuleList blocks_{nullptr};
// holds the same layers as blocks_ but with a concrete type to avoid casts
std::vector<BaichuanDecoderLayer> layers_;

@@ -375,7 +375,7 @@ class BaichuanModelImpl : public llm::nn::Module {
};
LLM_MODULE(BaichuanModel);

class BaichuanForCausalLMImpl : public llm::nn::Module {
class BaichuanForCausalLMImpl : public Module {
public:
BaichuanForCausalLMImpl(const ModelArgs& args,
const QuantArgs& quant_args,