1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(common)
add_subdirectory(handlers)
add_subdirectory(kernels)
add_subdirectory(tokenizer)
+add_subdirectory(module)
add_subdirectory(layers)
add_subdirectory(quantization)
add_subdirectory(models)
19 changes: 11 additions & 8 deletions src/layers/CMakeLists.txt
@@ -3,7 +3,7 @@ include(cc_test)

cc_library(
NAME
-linear
+linear
HDRS
linear.h
qkv_linear.h
@@ -21,35 +21,37 @@ cc_library(
:model_parallel
:quantization
:kernels
+:module
glog::glog
gflags::gflags
torch
)

cc_library(
-NAME
+NAME
pos_embedding
-HDRS
+HDRS
pos_embedding.h
-SRCS
+SRCS
pos_embedding.cpp
DEPS
:state_dict
:memory
:kernels
+:module
glog::glog
gflags::gflags
torch
)

cc_library(
-NAME
+NAME
layers
-HDRS
+HDRS
normalization.h
embedding.h
activation.h
-SRCS
+SRCS
activation.cpp
DEPS
:state_dict
@@ -58,6 +60,7 @@ cc_library(
:pos_embedding
:attention
:kernels
+:module
glog::glog
gflags::gflags
torch
@@ -80,4 +83,4 @@ cc_test(
)

add_subdirectory(attention)
-add_subdirectory(moe)
+add_subdirectory(moe)
6 changes: 4 additions & 2 deletions src/layers/attention/attention.h
@@ -5,10 +5,12 @@
#include "layers/attention/handler.h"
#include "memory/kv_cache.h"
#include "models/parameters.h"
#include "module/module.h"
#include "module/module_holder.h"

namespace llm {

-class AttentionImpl : public torch::nn::Module {
+class AttentionImpl : public llm::nn::Module {
public:
AttentionImpl(int64_t n_heads,
int64_t n_kv_heads,
@@ -38,6 +40,6 @@ class AttentionImpl : public torch::nn::Module {
// sliding window for self-attention, -1 means no sliding window
int32_t sliding_window_ = -1;
};
-TORCH_MODULE(Attention);
+LLM_MODULE(Attention);

} // namespace llm
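Note on this change: the PR replaces the libtorch base class (torch::nn::Module plus the TORCH_MODULE holder macro) with an in-tree equivalent (llm::nn::Module plus LLM_MODULE) provided by the new src/module target, and the same substitution repeats in every header below. The module headers themselves are not part of the hunks shown here, so the following is only a minimal sketch of what LLM_MODULE presumably expands to, modeled on libtorch's TORCH_MODULE macro and on the hand-written holders in linear.h further down; the macro body is an assumption, not code taken from this PR.

// Hypothetical expansion of LLM_MODULE; the real definition lives in
// src/module/module_holder.h, which is not shown in this diff.
#define LLM_MODULE_IMPL(Name, ImplType)                    \
  class Name : public llm::nn::ModuleHolder<ImplType> {    \
   public:                                                 \
    using llm::nn::ModuleHolder<ImplType>::ModuleHolder;   \
    using Impl __attribute__((__unused__)) = ImplType;     \
  }

// LLM_MODULE(Attention) would then declare Attention as a value-semantic
// handle that owns a shared AttentionImpl, just as TORCH_MODULE(Attention) did.
#define LLM_MODULE(Name) LLM_MODULE_IMPL(Name, Name##Impl)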
17 changes: 9 additions & 8 deletions src/layers/embedding.h
@@ -7,14 +7,16 @@

#include "model_loader/state_dict.h"
#include "model_parallel/model_parallel.h"
#include "module/module.h"
#include "module/module_holder.h"

namespace llm {

// A simple lookup table that stores embeddings of a fixed dictionary and size.
// This module is often used to store word embeddings and retrieve them using
// indices.

-class EmbeddingImpl : public torch::nn::Module {
+class EmbeddingImpl : public llm::nn::Module {
public:
EmbeddingImpl(int64_t num_embeddings,
int64_t embedding_dim,
@@ -63,10 +65,10 @@ class EmbeddingImpl : public torch::nn::Module {
// whether the weight is loaded
bool is_loaded_ = false;
};
-TORCH_MODULE(Embedding);
+LLM_MODULE(Embedding);

// Embedding parallelized in the embedding dimension.
-class ParallelEmbeddingImpl : public torch::nn::Module {
+class ParallelEmbeddingImpl : public llm::nn::Module {
public:
ParallelEmbeddingImpl(int64_t num_embeddings,
int64_t embedding_dim,
@@ -134,10 +136,10 @@ class ParallelEmbeddingImpl : public torch::nn::Module {
// parallel args
ParallelArgs parallel_args_;
};
-TORCH_MODULE(ParallelEmbedding);
+LLM_MODULE(ParallelEmbedding);

// Embedding parallelized in the vocabulary dimension
-class VocabParallelEmbeddingImpl : public torch::nn::Module {
+class VocabParallelEmbeddingImpl : public llm::nn::Module {
public:
VocabParallelEmbeddingImpl(int64_t num_embeddings,
int64_t embedding_dim,
@@ -152,8 +154,7 @@ class VocabParallelEmbeddingImpl : public torch::nn::Module {
// register the weight parameter
weight_ = register_parameter(
"weight",
-torch::empty({num_embeddings_per_partition, embedding_dim},
-options),
+torch::empty({num_embeddings_per_partition, embedding_dim}, options),
/*requires_grad=*/false);
}

@@ -218,5 +219,5 @@ class VocabParallelEmbeddingImpl : public torch::nn::Module {
int64_t start_index_ = 0;
int64_t end_index_ = 0;
};
-TORCH_MODULE(VocabParallelEmbedding);
+LLM_MODULE(VocabParallelEmbedding);
} // namespace llm
6 changes: 4 additions & 2 deletions src/layers/fused_linear.h
@@ -6,11 +6,13 @@
#include "linear.h"
#include "model_loader/state_dict.h"
#include "model_parallel/parallel_args.h"
#include "module/module.h"
#include "module/module_holder.h"
#include "quantization/quant_args.h"

namespace llm {

-class FusedColumnParallelLinearImpl : public torch::nn::Module {
+class FusedColumnParallelLinearImpl : public llm::nn::Module {
public:
FusedColumnParallelLinearImpl(int64_t in_features,
const std::vector<int64_t>& out_features,
@@ -44,6 +46,6 @@ class FusedColumnParallelLinearImpl : public torch::nn::Module {
// whether the linear layer is fused
bool fused_ = false;
};
-TORCH_MODULE(FusedColumnParallelLinear);
+LLM_MODULE(FusedColumnParallelLinear);

} // namespace llm
13 changes: 7 additions & 6 deletions src/layers/linear.h
@@ -5,6 +5,8 @@

#include "model_loader/state_dict.h"
#include "model_parallel/parallel_args.h"
#include "module/module.h"
#include "module/module_holder.h"
#include "quantization/quant_args.h"

namespace llm {
@@ -14,7 +16,7 @@ using TensorTransform = std::function<torch::Tensor(const torch::Tensor&)>;
// an interface for parallel linear layer.
// all linear classes should inherit from this class and implement the forward
// function.
-class ParallelLinearImpl : public torch::nn::Module {
+class ParallelLinearImpl : public llm::nn::Module {
public:
~ParallelLinearImpl() override = default;

@@ -37,10 +39,9 @@ class ParallelLinearImpl : public torch::nn::Module {
}
};

-class ColumnParallelLinear
-: public torch::nn::ModuleHolder<ParallelLinearImpl> {
+class ColumnParallelLinear : public llm::nn::ModuleHolder<ParallelLinearImpl> {
public:
-using torch::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
+using llm::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
using Impl __attribute__((__unused__)) = ParallelLinearImpl;

// construct a column parallel linear layer.
@@ -61,9 +62,9 @@
const torch::TensorOptions& options);
};

-class RowParallelLinear : public torch::nn::ModuleHolder<ParallelLinearImpl> {
+class RowParallelLinear : public llm::nn::ModuleHolder<ParallelLinearImpl> {
public:
-using torch::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
+using llm::nn::ModuleHolder<ParallelLinearImpl>::ModuleHolder;
using Impl __attribute__((__unused__)) = ParallelLinearImpl;

// construct a row parallel linear layer.
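Unlike the layers declared with LLM_MODULE, ColumnParallelLinear and RowParallelLinear keep hand-written handle classes so that one handle type can own any ParallelLinearImpl subclass; the constructor declarations in the hunk above, together with the quantization/quant_args.h include, suggest the concrete implementation is chosen at construction time. Below is a minimal sketch of what llm::nn::ModuleHolder would need to provide for these handles to compile and behave like their torch::nn counterparts; the real class lives in src/module/module_holder.h, which is not shown in this diff, so treat every member here as an assumption.

// Hypothetical core of llm::nn::ModuleHolder, mirroring torch::nn::ModuleHolder.
#include <memory>
#include <utility>

namespace llm::nn {

template <typename Impl>
class ModuleHolder {
 public:
  // Wrap an already-constructed implementation.
  explicit ModuleHolder(std::shared_ptr<Impl> impl) : impl_(std::move(impl)) {}

  // Forward constructor arguments to the implementation type.
  template <typename... Args>
  explicit ModuleHolder(Args&&... args)
      : impl_(std::make_shared<Impl>(std::forward<Args>(args)...)) {}

  // Arrow access to the implementation object.
  Impl* operator->() { return impl_.get(); }
  const Impl* operator->() const { return impl_.get(); }

  // Calling the handle forwards to the implementation's forward().
  template <typename... Args>
  auto operator()(Args&&... args) {
    return impl_->forward(std::forward<Args>(args)...);
  }

 private:
  std::shared_ptr<Impl> impl_;
};

}  // namespace llm::nn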
18 changes: 10 additions & 8 deletions src/layers/normalization.h
@@ -7,6 +7,8 @@

#include "kernels/layernorm_kernels.h"
#include "model_loader/state_dict.h"
#include "module/module.h"
#include "module/module_holder.h"

DECLARE_bool(disable_custom_kernels);
namespace llm {
@@ -63,7 +65,7 @@ inline torch::Tensor layer_norm(torch::Tensor input,
// apply layer normalization over a mini-batch of inputs as described in
// the paper `Layer Normalization`: https://arxiv.org/abs/1607.06450
// x = ((x - mean(x)) / sqrt(std(x) + eps)) * weight + bias
-class LayerNormImpl : public torch::nn::Module {
+class LayerNormImpl : public llm::nn::Module {
public:
// dim: the dim over which the mean and std are calculated separately.
// eps: a value added to the denominator for numerical stability.
@@ -140,10 +142,10 @@ class LayerNormImpl : public torch::nn::Module {
float eps_;
std::vector<int64_t> normalized_shape_;
};
-TORCH_MODULE(LayerNorm);
+LLM_MODULE(LayerNorm);

// Root mean square normalization
-class RMSNormImpl : public torch::nn::Module {
+class RMSNormImpl : public llm::nn::Module {
public:
RMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options)
: eps_(eps) {
@@ -191,9 +193,9 @@ class RMSNormImpl : public torch::nn::Module {
// configs
float eps_;
};
-TORCH_MODULE(RMSNorm);
+LLM_MODULE(RMSNorm);

-class GemmaRMSNormImpl : public torch::nn::Module {
+class GemmaRMSNormImpl : public llm::nn::Module {
public:
GemmaRMSNormImpl(int64_t dim, float eps, const torch::TensorOptions& options)
: eps_(eps) {
@@ -241,10 +243,10 @@ class GemmaRMSNormImpl : public torch::nn::Module {
// configs
float eps_;
};
-TORCH_MODULE(GemmaRMSNorm);
+LLM_MODULE(GemmaRMSNorm);

// Root mean square normalization
-class RMSNormResidualImpl : public torch::nn::Module {
+class RMSNormResidualImpl : public llm::nn::Module {
public:
RMSNormResidualImpl(int64_t dim,
float eps,
@@ -304,6 +306,6 @@ class RMSNormResidualImpl : public torch::nn::Module {
// configs
float eps_;
};
-TORCH_MODULE(RMSNormResidual);
+LLM_MODULE(RMSNormResidual);

} // namespace llm
4 changes: 2 additions & 2 deletions src/layers/qkv_linear.h
@@ -12,7 +12,7 @@ namespace llm {

// a thin wrapper to handle state_dict loading for QKV with
// support of MQA/GQA
-class QKVColumnParallelLinearImpl : public torch::nn::Module {
+class QKVColumnParallelLinearImpl : public llm::nn::Module {
public:
QKVColumnParallelLinearImpl(int64_t hidden_size,
int64_t n_heads,
@@ -47,6 +47,6 @@ class QKVColumnParallelLinearImpl : public torch::nn::Module {

int64_t head_dim_ = 0;
};
-TORCH_MODULE(QKVColumnParallelLinear);
+LLM_MODULE(QKVColumnParallelLinear);

} // namespace llm
4 changes: 2 additions & 2 deletions src/models/CMakeLists.txt
@@ -2,7 +2,7 @@ include(cc_library)
include(cc_test)

cc_library(
-NAME
+NAME
models
HDRS
model_args.h
@@ -18,7 +18,7 @@ cc_library(
:quantization
:memory
:chat_template
+:module
glog::glog
torch
)
