NAM/activations.cpp (3 additions, 1 deletion)

@@ -6,6 +6,7 @@ nam::activations::ActivationHardTanh _HARD_TANH = nam::activations::ActivationHardTanh();
 nam::activations::ActivationReLU _RELU = nam::activations::ActivationReLU();
 nam::activations::ActivationLeakyReLU _LEAKY_RELU =
   nam::activations::ActivationLeakyReLU(0.01); // FIXME does not parameterize LeakyReLU
+nam::activations::ActivationPReLU _PRELU = nam::activations::ActivationPReLU(0.01); // Same as leaky ReLU by default
 nam::activations::ActivationSigmoid _SIGMOID = nam::activations::ActivationSigmoid();
 nam::activations::ActivationSwish _SWISH = nam::activations::ActivationSwish();
 nam::activations::ActivationHardSwish _HARD_SWISH = nam::activations::ActivationHardSwish();
@@ -16,7 +17,8 @@ bool nam::activations::Activation::using_fast_tanh = false;
 std::unordered_map<std::string, nam::activations::Activation*> nam::activations::Activation::_activations = {
   {"Tanh", &_TANH}, {"Hardtanh", &_HARD_TANH}, {"Fasttanh", &_FAST_TANH},
   {"ReLU", &_RELU}, {"LeakyReLU", &_LEAKY_RELU}, {"Sigmoid", &_SIGMOID},
-  {"SiLU", &_SWISH}, {"Hardswish", &_HARD_SWISH}, {"LeakyHardtanh", &_LEAKY_HARD_TANH}};
+  {"SiLU", &_SWISH}, {"Hardswish", &_HARD_SWISH}, {"LeakyHardtanh", &_LEAKY_HARD_TANH},
+  {"PReLU", &_PRELU}};

 nam::activations::Activation* tanh_bak = nullptr;
 nam::activations::Activation* sigmoid_bak = nullptr;
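With the "PReLU" entry registered, a model config that names its activation "PReLU" resolves to the shared _PRELU instance. A minimal usage sketch (not part of this PR; assumes the repository's existing Activation::get_activation lookup helper):

#include "NAM/activations.h"

int main()
{
  // Look the activation up by the string a model config would store.
  nam::activations::Activation* act = nam::activations::Activation::get_activation("PReLU");

  // Apply it in place to a (channels, time_steps) buffer. The default _PRELU
  // holds a single slope (0.01), so only a 1-channel buffer passes its assert.
  Eigen::MatrixXf x(1, 4);
  x << -1.0f, -0.5f, 0.5f, 1.0f;
  act->apply(x); // negative entries are scaled by 0.01
  return 0;
}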
NAM/activations.h (37 additions)

@@ -202,6 +202,43 @@ class ActivationLeakyReLU : public Activation
float negative_slope = 0.01;
};

class ActivationPReLU : public Activation
{
public:
ActivationPReLU() = default;
ActivationPReLU(float ns)
{
negative_slopes.clear();
negative_slopes.push_back(ns);
}
ActivationPReLU(std::vector<float> ns) { negative_slopes = ns; }

void apply(Eigen::MatrixXf& matrix) override
{
// Matrix is organized as (channels, time_steps)
const int n_channels = (int)negative_slopes.size();
const int actual_channels = (int)matrix.rows();

// NOTE: this check compiles out of release builds;
// the model loader must ensure the dimensions match.
assert(actual_channels == n_channels);

// Apply each negative slope to its corresponding channel
for (int channel = 0; channel < std::min(n_channels, actual_channels); channel++)
{
// Apply the negative slope to all time steps in this channel
for (int time_step = 0; time_step < matrix.cols(); time_step++)
{
matrix(channel, time_step) = leaky_relu(matrix(channel, time_step), negative_slopes[channel]);
[Owner review comment on the line above] Nit: Feels inefficient since there's no vectorized ops being used, but we'll see. (A vectorized sketch follows this file's diff.)
}
}
}

private:
std::vector<float> negative_slopes;
};


class ActivationSigmoid : public Activation
{
public:
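The owner's nit above points at the scalar inner loop. Each channel can instead be processed with Eigen's coefficient-wise operations, using the identity PReLU(x) = max(x, 0) + a_c * min(x, 0) for channel slope a_c. A minimal sketch of a vectorized apply (not part of this PR; assumes Eigen's cwiseMax/cwiseMin, which operate elementwise on a row):

void apply(Eigen::MatrixXf& matrix) override
{
  // One slope per channel (row); same precondition as the scalar version.
  assert(matrix.rows() == (Eigen::Index)negative_slopes.size());
  for (int channel = 0; channel < matrix.rows(); channel++)
  {
    // One vectorized pass over all time steps (columns) of this channel.
    matrix.row(channel) = matrix.row(channel).cwiseMax(0.0f)
                          + negative_slopes[channel] * matrix.row(channel).cwiseMin(0.0f);
  }
}

This keeps the per-channel slopes but lets Eigen fuse the row expression into a single pass that the compiler can auto-vectorize.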
tools/run_tests.cpp (5 additions)

@@ -21,6 +21,11 @@ int main()
test_activations::TestLeakyReLU::test_get_by_init();
test_activations::TestLeakyReLU::test_get_by_str();

test_activations::TestPReLU::test_core_function();
test_activations::TestPReLU::test_per_channel_behavior();
// apply() enforces the channel count with an assert, which aborts rather than
// throws, so the exception-based test below cannot run:
//test_activations::TestPReLU::test_wrong_number_of_channels();

test_dsp::test_construct();
test_dsp::test_get_input_level();
test_dsp::test_get_output_level();
tools/test/test_activations.cpp (76 additions)

@@ -8,6 +8,7 @@
#include <cassert>
#include <string>
#include <vector>
#include <cmath>

#include "NAM/activations.h"

@@ -119,4 +120,79 @@ class TestLeakyReLU
}
};
};
class TestPReLU
{
public:
static void test_core_function()
{
// Test the basic leaky_relu function that PReLU uses
auto TestCase = [](float input, float slope, float expectedOutput) {
float actualOutput = nam::activations::leaky_relu(input, slope);
assert(actualOutput == expectedOutput);
};

// A few snapshot tests
TestCase(0.0f, 0.01f, 0.0f);
TestCase(1.0f, 0.01f, 1.0f);
TestCase(-1.0f, 0.01f, -0.01f);
TestCase(-1.0f, 0.05f, -0.05f); // Different slope
}

static void test_per_channel_behavior()
{
// Test that different slopes are applied to different channels
Eigen::MatrixXf data(2, 3); // 2 channels, 3 time steps

// Initialize with some test data
data << -1.0f, 0.5f, 1.0f, -2.0f, -0.5f, 0.0f;

// Create PReLU with different slopes for each channel
std::vector<float> slopes = {0.01f, 0.05f}; // slope 0.01 for channel 0, 0.05 for channel 1
nam::activations::ActivationPReLU prelu(slopes);

// Apply the activation
prelu.apply(data);

// Verify the results
// Channel 0 (slope = 0.01):
assert(std::fabs(data(0, 0) - (-0.01f)) < 1e-6); // -1.0 * 0.01 = -0.01
assert(std::fabs(data(0, 1) - 0.5f) < 1e-6); // 0.5 (positive, unchanged)
assert(std::fabs(data(0, 2) - 1.0f) < 1e-6); // 1.0 (positive, unchanged)

// Channel 1 (slope = 0.05):
assert(std::fabs(data(1, 0) - (-0.10f)) < 1e-6); // -2.0 * 0.05 = -0.10
assert(std::fabs(data(1, 1) - (-0.025f)) < 1e-6); // -0.5 * 0.05 = -0.025
assert(std::fabs(data(1, 2) - 0.0f) < 1e-6); // 0.0 (unchanged)
}

static void test_wrong_number_of_channels()
{
// Test that we fail when we have more channels than slopes.
// NOTE: disabled in run_tests.cpp: apply() enforces this with an assert,
// which aborts rather than throws, so the std::runtime_error expected here
// would only be raised if apply() were changed to throw.
Eigen::MatrixXf data(3, 2); // 3 channels, 2 time steps

// Initialize with test data
data << -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f;

// Create PReLU with only 2 slopes for 3 channels
std::vector<float> slopes = {0.01f, 0.05f};
nam::activations::ActivationPReLU prelu(slopes);

// Apply the activation
bool caught = false;
try
{
prelu.apply(data);
}
catch (const std::runtime_error& e)
{
caught = true;
}
catch (...)
{
// Any other exception type leaves `caught` false, failing the assert below.
}

assert(caught);
}
};

}; // namespace test_activations