Merged
Changes from 5 commits
34 changes: 23 additions & 11 deletions NAM/activations.cpp
@@ -4,7 +4,9 @@ nam::activations::ActivationTanh _TANH = nam::activations::ActivationTanh();
nam::activations::ActivationFastTanh _FAST_TANH = nam::activations::ActivationFastTanh();
nam::activations::ActivationHardTanh _HARD_TANH = nam::activations::ActivationHardTanh();
nam::activations::ActivationReLU _RELU = nam::activations::ActivationReLU();
nam::activations::ActivationLeakyReLU _LEAKY_RELU = nam::activations::ActivationLeakyReLU(0.01); //FIXME does not parameterize LeakyReLU
nam::activations::ActivationLeakyReLU _LEAKY_RELU =
nam::activations::ActivationLeakyReLU(0.01); // FIXME does not parameterize LeakyReLU
nam::activations::ActivationPReLU _PRELU = nam::activations::ActivationPReLU(0.01); // Same as leaky ReLU by default
nam::activations::ActivationSigmoid _SIGMOID = nam::activations::ActivationSigmoid();
nam::activations::ActivationSwish _SWISH = nam::activations::ActivationSwish();
nam::activations::ActivationHardSwish _HARD_SWISH = nam::activations::ActivationHardSwish();
@@ -13,9 +15,10 @@ nam::activations::ActivationLeakyHardTanh _LEAKY_HARD_TANH = nam::activations::A
bool nam::activations::Activation::using_fast_tanh = false;

std::unordered_map<std::string, nam::activations::Activation*> nam::activations::Activation::_activations = {
{"Tanh", &_TANH}, {"Hardtanh", &_HARD_TANH}, {"Fasttanh", &_FAST_TANH},
{"ReLU", &_RELU}, {"LeakyReLU", &_LEAKY_RELU}, {"Sigmoid", &_SIGMOID},
{"SiLU", &_SWISH}, {"Hardswish", &_HARD_SWISH}, {"LeakyHardtanh", &_LEAKY_HARD_TANH}};
{"Tanh", &_TANH}, {"Hardtanh", &_HARD_TANH}, {"Fasttanh", &_FAST_TANH},
{"ReLU", &_RELU}, {"LeakyReLU", &_LEAKY_RELU}, {"Sigmoid", &_SIGMOID},
{"SiLU", &_SWISH}, {"Hardswish", &_HARD_SWISH}, {"LeakyHardtanh", &_LEAKY_HARD_TANH},
{"PReLU", &_PRELU}};

nam::activations::Activation* tanh_bak = nullptr;
nam::activations::Activation* sigmoid_bak = nullptr;
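For context, a minimal usage sketch of the new PReLU activation (not part of this diff): the registry above exposes "PReLU" with a single default slope of 0.01, matching LeakyReLU, while per-channel slopes require constructing ActivationPReLU directly. The slope values and buffer contents below are arbitrary.

#include <vector>
#include <Eigen/Dense>
#include "NAM/activations.h"

int main()
{
  // Two channels, each with its own negative slope (values are illustrative only).
  nam::activations::ActivationPReLU prelu(std::vector<float>{0.01f, 0.2f});

  Eigen::MatrixXf x(2, 4); // (channels, time_steps), the layout ActivationPReLU::apply() expects
  x << -1.0f, -0.5f, 0.5f, 1.0f,
       -1.0f, -0.5f, 0.5f, 1.0f;

  prelu.apply(x); // negatives in row 0 are scaled by 0.01, negatives in row 1 by 0.2
  return 0;
}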
@@ -52,13 +55,18 @@ void nam::activations::Activation::disable_fast_tanh()
void nam::activations::Activation::enable_lut(std::string function_name, float min, float max, std::size_t n_points)
{
std::function<float(float)> fn;
if (function_name == "Tanh"){
if (function_name == "Tanh")
{
fn = [](float x) { return std::tanh(x); };
tanh_bak = _activations["Tanh"];
} else if (function_name == "Sigmoid") {
}
else if (function_name == "Sigmoid")
{
fn = sigmoid;
sigmoid_bak = _activations["Sigmoid"];
} else {
}
else
{
throw std::runtime_error("Tried to enable LUT for a function other than Tanh or Sigmoid");
}
FastLUTActivation lut_activation(min, max, n_points, fn);
@@ -67,12 +75,16 @@ void nam::activations::Activation::enable_lut(std::string function_name, float m

void nam::activations::Activation::disable_lut(std::string function_name)
{
if (function_name == "Tanh"){
if (function_name == "Tanh")
{
_activations["Tanh"] = tanh_bak;
} else if (function_name == "Sigmoid") {
}
else if (function_name == "Sigmoid")
{
_activations["Sigmoid"] = sigmoid_bak;
} else {
}
else
{
throw std::runtime_error("Tried to disable LUT for a function other than Tanh or Sigmoid");
}
}
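
A minimal sketch of how the LUT switch above might be used (not part of this diff; the ranges and table sizes are arbitrary, and it assumes enable_lut()/disable_lut() are static members, mirroring the existing enable_fast_tanh()/disable_fast_tanh() toggles):

#include "NAM/activations.h"

int main()
{
  // Swap the registered Tanh and Sigmoid activations for interpolated lookup tables.
  // (enable_lut() throws for any other function name.)
  nam::activations::Activation::enable_lut("Tanh", -5.0f, 5.0f, 4096);
  nam::activations::Activation::enable_lut("Sigmoid", -10.0f, 10.0f, 4096);

  // ... run the model here; by-name lookups now resolve to the LUT-backed activations ...

  // Restore the originals that enable_lut() backed up.
  nam::activations::Activation::disable_lut("Tanh");
  nam::activations::Activation::disable_lut("Sigmoid");
  return 0;
}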

148 changes: 102 additions & 46 deletions NAM/activations.h
@@ -28,11 +28,16 @@ inline float hard_tanh(float x)

inline float leaky_hardtanh(float x, float min_val, float max_val, float min_slope, float max_slope)
{
if (x < min_val) {
if (x < min_val)
{
return (x - min_val) * min_slope + min_val;
} else if (x > max_val) {
}
else if (x > max_val)
{
return (x - max_val) * max_slope + max_val;
} else {
}
else
{
return x;
}
}
@@ -50,7 +55,7 @@ inline float fast_sigmoid(const float x)
{
return 0.5f * (fast_tanh(x * 0.5f) + 1.0f);
}

inline float leaky_relu(float x, float negative_slope)
{
return x > 0.0f ? x : negative_slope * x;
@@ -68,12 +73,17 @@ inline float swish(float x)

inline float hardswish(float x)
{
if (x <= -3.0) {
if (x <= -3.0)
{
return 0;
} else if (x >= 3.0) {
}
else if (x >= 3.0)
{
return x;
} else {
return x * (x + 3.0)/6.0;
}
else
{
return x * (x + 3.0) / 6.0;
}
}

@@ -129,7 +139,8 @@ class ActivationLeakyHardTanh : public Activation
{
public:
ActivationLeakyHardTanh() = default;
ActivationLeakyHardTanh(float min_val_, float max_val_, float min_slope_, float max_slope_) {
ActivationLeakyHardTanh(float min_val_, float max_val_, float min_slope_, float max_slope_)
{
min_val = min_val_;
max_val = max_val_;
min_slope = min_slope_;
@@ -142,6 +153,7 @@ class ActivationLeakyHardTanh : public Activation
data[pos] = leaky_hardtanh(data[pos], min_val, max_val, min_slope, max_slope);
}
}

private:
float min_val = -1.0;
float max_val = 1.0;
@@ -177,20 +189,56 @@ class ActivationLeakyReLU : public Activation
{
public:
ActivationLeakyReLU() = default;
ActivationLeakyReLU(float ns) {
negative_slope = ns;
}
ActivationLeakyReLU(float ns) { negative_slope = ns; }
void apply(float* data, long size) override
{
for (long pos = 0; pos < size; pos++)
{
data[pos] = leaky_relu(data[pos], negative_slope);
}
}

private:
float negative_slope = 0.01;
};

class ActivationPReLU : public Activation
{
public:
ActivationPReLU() = default;
ActivationPReLU(float ns)
{
negative_slopes.clear();
negative_slopes.push_back(ns);
}
ActivationPReLU(std::vector<float> ns) { negative_slopes = ns; }

void apply(Eigen::MatrixXf& matrix) override
{
// Matrix is organized as (channels, time_steps)
int n_channels = negative_slopes.size();
int actual_channels = matrix.rows();

// NOTE: this check is compiled out of release builds (assert is a no-op under NDEBUG);
// the model loader should make sure the dimensions match
assert(actual_channels == n_channels);

// Apply each negative slope to its corresponding channel
for (int channel = 0; channel < std::min(n_channels, actual_channels); channel++)
{
// Apply the negative slope to all time steps in this channel
for (int time_step = 0; time_step < matrix.cols(); time_step++)
{
matrix(channel, time_step) = leaky_relu(matrix(channel, time_step), negative_slopes[channel]);
[Review comment from Owner] Nit: Feels inefficient since there's no vectorized ops being used, but we'll see.

}
}
}

private:
std::vector<float> negative_slopes;
};
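
Regarding the reviewer's note above about the element-wise loop: one possible coefficient-wise variant (a sketch only, not part of this PR) uses the identity leaky_relu(x) = max(x, 0) + slope * min(x, 0), so each channel's row can be updated with Eigen's cwiseMax/cwiseMin instead of an inner loop. The helper name is hypothetical.

#include <cassert>
#include <vector>
#include <Eigen/Dense>

// Hypothetical helper, shown only to illustrate a row-wise formulation of PReLU.
inline void prelu_rowwise(Eigen::MatrixXf& matrix, const std::vector<float>& negative_slopes)
{
  assert(matrix.rows() == static_cast<Eigen::Index>(negative_slopes.size()));
  for (int channel = 0; channel < matrix.rows(); channel++)
  {
    // Positive part passes through; negative part is scaled by this channel's slope.
    matrix.row(channel) =
      matrix.row(channel).cwiseMax(0.0f) + negative_slopes[channel] * matrix.row(channel).cwiseMin(0.0f);
  }
}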


class ActivationSigmoid : public Activation
{
public:
@@ -230,46 +278,54 @@ class ActivationHardSwish : public Activation
class FastLUTActivation : public Activation
{
public:
FastLUTActivation(float min_x, float max_x, std::size_t size, std::function<float(float)> f)
: min_x_(min_x), max_x_(max_x), size_(size) {

step_ = (max_x - min_x) / (size - 1);
inv_step_ = 1.0f / step_;
table_.reserve(size);

for (std::size_t i = 0; i < size; ++i) {
table_.push_back(f(min_x + i * step_));
}
}
FastLUTActivation(float min_x, float max_x, std::size_t size, std::function<float(float)> f)
: min_x_(min_x)
, max_x_(max_x)
, size_(size)
{

step_ = (max_x - min_x) / (size - 1);
inv_step_ = 1.0f / step_;
table_.reserve(size);

// Fast lookup with linear interpolation
inline float lookup(float x) const {
// Clamp input to range
x = std::clamp(x, min_x_, max_x_);

// Calculate float index
float f_idx = (x - min_x_) * inv_step_;
std::size_t i = static_cast<std::size_t>(f_idx);

// Handle edge case at max_x_
if (i >= size_ - 1) return table_.back();

// Linear interpolation: y = y0 + (y1 - y0) * fractional_part
float frac = f_idx - static_cast<float>(i);
return table_[i] + (table_[i + 1] - table_[i]) * frac;
for (std::size_t i = 0; i < size; ++i)
{
table_.push_back(f(min_x + i * step_));
}
}

// Fast lookup with linear interpolation
inline float lookup(float x) const
{
// Clamp input to range
x = std::clamp(x, min_x_, max_x_);

// Calculate float index
float f_idx = (x - min_x_) * inv_step_;
std::size_t i = static_cast<std::size_t>(f_idx);

// Handle edge case at max_x_
if (i >= size_ - 1)
return table_.back();

// Vector application (Batch processing)
void apply(std::vector<float>& data) const {
for (float& val : data) {
val = lookup(val);
}
// Linear interpolation: y = y0 + (y1 - y0) * fractional_part
float frac = f_idx - static_cast<float>(i);
return table_[i] + (table_[i + 1] - table_[i]) * frac;
}

// Vector application (Batch processing)
void apply(std::vector<float>& data) const
{
for (float& val : data)
{
val = lookup(val);
}
}

private:
float min_x_, max_x_, step_, inv_step_;
size_t size_;
std::vector<float> table_;
float min_x_, max_x_, step_, inv_step_;
size_t size_;
std::vector<float> table_;
};

}; // namespace activations
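
A small standalone sketch of using the lookup table class directly (not part of this diff; the range, table size, and sample values are arbitrary):

#include <cmath>
#include <cstdio>
#include <vector>
#include "NAM/activations.h"

int main()
{
  // 1024-point table approximating tanh on [-5, 5], with linear interpolation between entries.
  nam::activations::FastLUTActivation lut(-5.0f, 5.0f, 1024, [](float x) { return std::tanh(x); });

  std::vector<float> samples = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
  lut.apply(samples); // each value is replaced by its interpolated table entry

  for (float s : samples)
    std::printf("%f\n", s);
  return 0;
}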
5 changes: 5 additions & 0 deletions tools/run_tests.cpp
@@ -21,6 +21,11 @@ int main()
test_activations::TestLeakyReLU::test_get_by_init();
test_activations::TestLeakyReLU::test_get_by_str();

test_activations::TestPReLU::test_core_function();
test_activations::TestPReLU::test_per_channel_behavior();
// This is enforced by an assert so it doesn't need to be tested
//test_activations::TestPReLU::test_wrong_number_of_channels();

test_dsp::test_construct();
test_dsp::test_get_input_level();
test_dsp::test_get_output_level();
76 changes: 76 additions & 0 deletions tools/test/test_activations.cpp
@@ -8,6 +8,7 @@
#include <cassert>
#include <string>
#include <vector>
#include <cmath>

#include "NAM/activations.h"

@@ -119,4 +120,79 @@ class TestLeakyReLU
}
};
};
class TestPReLU
{
public:
static void test_core_function()
{
// Test the basic leaky_relu function that PReLU uses
auto TestCase = [](float input, float slope, float expectedOutput) {
float actualOutput = nam::activations::leaky_relu(input, slope);
assert(actualOutput == expectedOutput);
};

// A few snapshot tests
TestCase(0.0f, 0.01f, 0.0f);
TestCase(1.0f, 0.01f, 1.0f);
TestCase(-1.0f, 0.01f, -0.01f);
TestCase(-1.0f, 0.05f, -0.05f); // Different slope
}

static void test_per_channel_behavior()
{
// Test that different slopes are applied to different channels
Eigen::MatrixXf data(2, 3); // 2 channels, 3 time steps

// Initialize with some test data
data << -1.0f, 0.5f, 1.0f, -2.0f, -0.5f, 0.0f;

// Create PReLU with different slopes for each channel
std::vector<float> slopes = {0.01f, 0.05f}; // slope 0.01 for channel 0, 0.05 for channel 1
nam::activations::ActivationPReLU prelu(slopes);

// Apply the activation
prelu.apply(data);

// Verify the results
// Channel 0 (slope = 0.01):
assert(fabs(data(0, 0) - (-0.01f)) < 1e-6); // -1.0 * 0.01 = -0.01
assert(fabs(data(0, 1) - 0.5f) < 1e-6); // 0.5 (positive, unchanged)
assert(fabs(data(0, 2) - 1.0f) < 1e-6); // 1.0 (positive, unchanged)

// Channel 1 (slope = 0.05):
assert(fabs(data(1, 0) - (-0.10f)) < 1e-6); // -2.0 * 0.05 = -0.10
assert(fabs(data(1, 1) - (-0.025f)) < 1e-6); // -0.5 * 0.05 = -0.025
assert(fabs(data(1, 2) - 0.0f) < 1e-6); // 0.0 (unchanged)
}

static void test_wrong_number_of_channels()
{
// Test that we fail when we have more channels than slopes
Eigen::MatrixXf data(3, 2); // 3 channels, 2 time steps

// Initialize with test data
data << -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f;

// Create PReLU with only 2 slopes for 3 channels
std::vector<float> slopes = {0.01f, 0.05f};
nam::activations::ActivationPReLU prelu(slopes);

// Apply the activation
bool caught = false;
try
{
prelu.apply(data);
}
catch (const std::runtime_error& e)
{
caught = true;
}
catch (...)
{
}

assert(caught);
}
};

}; // namespace test_activations