Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions NAM/activations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ nam::activations::ActivationTanh _TANH = nam::activations::ActivationTanh();
nam::activations::ActivationFastTanh _FAST_TANH = nam::activations::ActivationFastTanh();
nam::activations::ActivationHardTanh _HARD_TANH = nam::activations::ActivationHardTanh();
nam::activations::ActivationReLU _RELU = nam::activations::ActivationReLU();
nam::activations::ActivationLeakyReLU _LEAKY_RELU = nam::activations::ActivationLeakyReLU(0.01); //FIXME does not parameterize LeakyReLU
nam::activations::ActivationLeakyReLU _LEAKY_RELU =
nam::activations::ActivationLeakyReLU(0.01); // FIXME does not parameterize LeakyReLU
nam::activations::ActivationSigmoid _SIGMOID = nam::activations::ActivationSigmoid();
nam::activations::ActivationSwish _SWISH = nam::activations::ActivationSwish();
nam::activations::ActivationHardSwish _HARD_SWISH = nam::activations::ActivationHardSwish();
Expand All @@ -13,8 +14,8 @@ nam::activations::ActivationLeakyHardTanh _LEAKY_HARD_TANH = nam::activations::A
bool nam::activations::Activation::using_fast_tanh = false;

std::unordered_map<std::string, nam::activations::Activation*> nam::activations::Activation::_activations = {
{"Tanh", &_TANH}, {"Hardtanh", &_HARD_TANH}, {"Fasttanh", &_FAST_TANH},
{"ReLU", &_RELU}, {"LeakyReLU", &_LEAKY_RELU}, {"Sigmoid", &_SIGMOID},
{"Tanh", &_TANH}, {"Hardtanh", &_HARD_TANH}, {"Fasttanh", &_FAST_TANH},
{"ReLU", &_RELU}, {"LeakyReLU", &_LEAKY_RELU}, {"Sigmoid", &_SIGMOID},
{"SiLU", &_SWISH}, {"Hardswish", &_HARD_SWISH}, {"LeakyHardtanh", &_LEAKY_HARD_TANH}};

nam::activations::Activation* tanh_bak = nullptr;
Expand Down Expand Up @@ -52,13 +53,18 @@ void nam::activations::Activation::disable_fast_tanh()
void nam::activations::Activation::enable_lut(std::string function_name, float min, float max, std::size_t n_points)
{
std::function<float(float)> fn;
if (function_name == "Tanh"){
if (function_name == "Tanh")
{
fn = [](float x) { return std::tanh(x); };
tanh_bak = _activations["Tanh"];
} else if (function_name == "Sigmoid") {
}
else if (function_name == "Sigmoid")
{
fn = sigmoid;
sigmoid_bak = _activations["Sigmoid"];
} else {
}
else
{
throw std::runtime_error("Tried to enable LUT for a function other than Tanh or Sigmoid");
}
FastLUTActivation lut_activation(min, max, n_points, fn);
Expand All @@ -67,12 +73,16 @@ void nam::activations::Activation::enable_lut(std::string function_name, float m

void nam::activations::Activation::disable_lut(std::string function_name)
{
if (function_name == "Tanh"){
if (function_name == "Tanh")
{
_activations["Tanh"] = tanh_bak;
} else if (function_name == "Sigmoid") {
}
else if (function_name == "Sigmoid")
{
_activations["Sigmoid"] = sigmoid_bak;
} else {
}
else
{
throw std::runtime_error("Tried to disable LUT for a function other than Tanh or Sigmoid");
}
}

111 changes: 65 additions & 46 deletions NAM/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@ inline float hard_tanh(float x)

inline float leaky_hardtanh(float x, float min_val, float max_val, float min_slope, float max_slope)
{
if (x < min_val) {
if (x < min_val)
{
return (x - min_val) * min_slope + min_val;
} else if (x > max_val) {
}
else if (x > max_val)
{
return (x - max_val) * max_slope + max_val;
} else {
}
else
{
return x;
}
}
Expand All @@ -50,7 +55,7 @@ inline float fast_sigmoid(const float x)
{
return 0.5f * (fast_tanh(x * 0.5f) + 1.0f);
}

inline float leaky_relu(float x, float negative_slope)
{
return x > 0.0f ? x : negative_slope * x;
Expand All @@ -68,12 +73,17 @@ inline float swish(float x)

inline float hardswish(float x)
{
if (x <= -3.0) {
if (x <= -3.0)
{
return 0;
} else if (x >= 3.0) {
}
else if (x >= 3.0)
{
return x;
} else {
return x * (x + 3.0)/6.0;
}
else
{
return x * (x + 3.0) / 6.0;
}
}

Expand Down Expand Up @@ -129,7 +139,8 @@ class ActivationLeakyHardTanh : public Activation
{
public:
ActivationLeakyHardTanh() = default;
ActivationLeakyHardTanh(float min_val_, float max_val_, float min_slope_, float max_slope_) {
ActivationLeakyHardTanh(float min_val_, float max_val_, float min_slope_, float max_slope_)
{
min_val = min_val_;
max_val = max_val_;
min_slope = min_slope_;
Expand All @@ -142,6 +153,7 @@ class ActivationLeakyHardTanh : public Activation
data[pos] = leaky_hardtanh(data[pos], min_val, max_val, min_slope, max_slope);
}
}

private:
float min_val = -1.0;
float max_val = 1.0;
Expand Down Expand Up @@ -177,16 +189,15 @@ class ActivationLeakyReLU : public Activation
{
public:
ActivationLeakyReLU() = default;
ActivationLeakyReLU(float ns) {
negative_slope = ns;
}
ActivationLeakyReLU(float ns) { negative_slope = ns; }
void apply(float* data, long size) override
{
for (long pos = 0; pos < size; pos++)
{
data[pos] = leaky_relu(data[pos], negative_slope);
}
}

private:
float negative_slope = 0.01;
};
Expand Down Expand Up @@ -230,46 +241,54 @@ class ActivationHardSwish : public Activation
class FastLUTActivation : public Activation
{
public:
FastLUTActivation(float min_x, float max_x, std::size_t size, std::function<float(float)> f)
: min_x_(min_x), max_x_(max_x), size_(size) {

step_ = (max_x - min_x) / (size - 1);
inv_step_ = 1.0f / step_;
table_.reserve(size);

for (std::size_t i = 0; i < size; ++i) {
table_.push_back(f(min_x + i * step_));
}
}
FastLUTActivation(float min_x, float max_x, std::size_t size, std::function<float(float)> f)
: min_x_(min_x)
, max_x_(max_x)
, size_(size)
{

// Fast lookup with linear interpolation
inline float lookup(float x) const {
// Clamp input to range
x = std::clamp(x, min_x_, max_x_);

// Calculate float index
float f_idx = (x - min_x_) * inv_step_;
std::size_t i = static_cast<std::size_t>(f_idx);

// Handle edge case at max_x_
if (i >= size_ - 1) return table_.back();

// Linear interpolation: y = y0 + (y1 - y0) * fractional_part
float frac = f_idx - static_cast<float>(i);
return table_[i] + (table_[i + 1] - table_[i]) * frac;
step_ = (max_x - min_x) / (size - 1);
inv_step_ = 1.0f / step_;
table_.reserve(size);

for (std::size_t i = 0; i < size; ++i)
{
table_.push_back(f(min_x + i * step_));
}
}

// Fast lookup with linear interpolation
inline float lookup(float x) const
{
// Clamp input to range
x = std::clamp(x, min_x_, max_x_);

// Vector application (Batch processing)
void apply(std::vector<float>& data) const {
for (float& val : data) {
val = lookup(val);
}
// Calculate float index
float f_idx = (x - min_x_) * inv_step_;
std::size_t i = static_cast<std::size_t>(f_idx);

// Handle edge case at max_x_
if (i >= size_ - 1)
return table_.back();

// Linear interpolation: y = y0 + (y1 - y0) * fractional_part
float frac = f_idx - static_cast<float>(i);
return table_[i] + (table_[i + 1] - table_[i]) * frac;
}

// Vector application (Batch processing)
void apply(std::vector<float>& data) const
{
for (float& val : data)
{
val = lookup(val);
}
}

private:
float min_x_, max_x_, step_, inv_step_;
size_t size_;
std::vector<float> table_;
float min_x_, max_x_, step_, inv_step_;
size_t size_;
std::vector<float> table_;
};

}; // namespace activations
Expand Down
Loading