NAM/activations.cpp (3 additions, 1 deletion)

@@ -5,6 +5,7 @@ nam::activations::ActivationFastTanh _FAST_TANH = nam::activations::ActivationFastTanh();
nam::activations::ActivationHardTanh _HARD_TANH = nam::activations::ActivationHardTanh();
nam::activations::ActivationReLU _RELU = nam::activations::ActivationReLU();
nam::activations::ActivationLeakyReLU _LEAKY_RELU = nam::activations::ActivationLeakyReLU(0.01); //FIXME does not parameterize LeakyReLU
nam::activations::ActivationPReLU _PRELU = nam::activations::ActivationPReLU(0.01); //Same as leaky ReLU by default
nam::activations::ActivationSigmoid _SIGMOID = nam::activations::ActivationSigmoid();
nam::activations::ActivationSwish _SWISH = nam::activations::ActivationSwish();
nam::activations::ActivationHardSwish _HARD_SWISH = nam::activations::ActivationHardSwish();
@@ -15,7 +16,8 @@ bool nam::activations::Activation::using_fast_tanh = false;
std::unordered_map<std::string, nam::activations::Activation*> nam::activations::Activation::_activations = {
{"Tanh", &_TANH}, {"Hardtanh", &_HARD_TANH}, {"Fasttanh", &_FAST_TANH},
{"ReLU", &_RELU}, {"LeakyReLU", &_LEAKY_RELU}, {"Sigmoid", &_SIGMOID},
{"SiLU", &_SWISH}, {"Hardswish", &_HARD_SWISH}, {"LeakyHardtanh", &_LEAKY_HARD_TANH}};
{"SiLU", &_SWISH}, {"Hardswish", &_HARD_SWISH}, {"LeakyHardtanh", &_LEAKY_HARD_TANH},
{"PReLU", &_PRELU}};

nam::activations::Activation* tanh_bak = nullptr;
nam::activations::Activation* sigmoid_bak = nullptr;
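With the new map entry in place, "PReLU" becomes resolvable by name like the other activations. Below is a minimal usage sketch (not part of the PR); it assumes the library's usual Activation::get_activation(name) lookup over the map shown above, so adjust if the accessor differs in your version of the header:

#include <iostream>
#include "NAM/activations.h"

int main()
{
  // Assumption: the core library exposes a by-name lookup over the
  // _activations map; substitute the real accessor if it is named differently.
  nam::activations::Activation* act = nam::activations::Activation::get_activation("PReLU");

  float buffer[4] = {-1.0f, -0.5f, 0.5f, 1.0f};
  act->apply(buffer, 4); // the registered default slope is 0.01, matching LeakyReLU

  for (float x : buffer)
    std::cout << x << " "; // expect: -0.01 -0.005 0.5 1
  std::cout << std::endl;
  return 0;
}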
NAM/activations.h (53 additions, 0 deletions)

@@ -191,6 +191,59 @@ class ActivationLeakyReLU : public Activation
float negative_slope = 0.01;
};

class ActivationPReLU : public Activation
{
public:
ActivationPReLU() = default;
ActivationPReLU(float ns) {
negative_slopes.clear();
negative_slopes.push_back(ns);
}
ActivationPReLU(std::vector<float> ns) {
negative_slopes = ns;
}

void apply(Eigen::MatrixXf& matrix) override
{
// Matrix is organized as (channels, time_steps)
int n_channels = negative_slopes.size();
int actual_channels = matrix.rows();

if (actual_channels > n_channels)
{
throw std::runtime_error("Number of channels in PReLU activation different from input matrix");
}

// Apply each negative slope to its corresponding channel
for (int channel = 0; channel < std::min(n_channels, actual_channels); channel++)
{
// Apply the negative slope to all time steps in this channel
for (int time_step = 0; time_step < matrix.cols(); time_step++)
{
matrix(channel, time_step) = leaky_relu(matrix(channel, time_step), negative_slopes[channel]);
[Owner, inline review comment] Nit: Feels inefficient since there's no vectorized ops being used, but we'll see.
}
}
}

void apply(float* data, long size) override
{
// Fallback that behaves like leaky_relu using the first slope; not ideal, since a whole vector is kept for a single element
if (!negative_slopes.empty())
{
float slope = negative_slopes[0]; // Use first slope as fallback
for (long pos = 0; pos < size; pos++)
{
data[pos] = leaky_relu(data[pos], slope);
}
} else {
throw std::runtime_error("negative_slopes not initialized");
}
}
private:
std::vector<float> negative_slopes;
};


class ActivationSigmoid : public Activation
{
public:
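As an aside on the owner's inline nit about the scalar loops in ActivationPReLU::apply: the same per-channel transform, prelu(x) = max(x, 0) + slope * min(x, 0), can be written with Eigen's element-wise operations so each row (channel) is processed without an explicit inner loop over time steps. The following is a standalone sketch, not part of this PR; the helper name prelu_rows is invented for illustration, and it assumes one slope per row:

#include <Eigen/Dense>
#include <iostream>
#include <vector>

// Vectorized per-channel PReLU: max(x, 0) + slope * min(x, 0), applied row by row.
// Assumes slopes.size() == matrix.rows().
void prelu_rows(Eigen::MatrixXf& matrix, const std::vector<float>& slopes)
{
  for (int ch = 0; ch < matrix.rows(); ch++)
  {
    auto row = matrix.row(ch).array();
    // Element-wise, so reading and writing the same row is alias-safe in Eigen.
    matrix.row(ch) = (row.max(0.0f) + slopes[ch] * row.min(0.0f)).matrix();
  }
}

int main()
{
  Eigen::MatrixXf data(2, 3); // 2 channels, 3 time steps
  data << -1.0f, 0.5f, 1.0f,
          -2.0f, -0.5f, 0.0f;
  prelu_rows(data, {0.01f, 0.05f});
  std::cout << data << std::endl; // rows: -0.01 0.5 1 and -0.1 -0.025 0
  return 0;
}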
tools/run_tests.cpp (4 additions, 0 deletions)

@@ -21,6 +21,10 @@ int main()
test_activations::TestLeakyReLU::test_get_by_init();
test_activations::TestLeakyReLU::test_get_by_str();

test_activations::TestPReLU::test_core_function();
test_activations::TestPReLU::test_per_channel_behavior();
test_activations::TestPReLU::test_wrong_number_of_channels();

test_dsp::test_construct();
test_dsp::test_get_input_level();
test_dsp::test_get_output_level();
tools/test/test_activations.cpp (77 additions, 0 deletions)

@@ -8,6 +8,7 @@
#include <cassert>
#include <string>
#include <vector>
#include <cmath>

#include "NAM/activations.h"

@@ -119,4 +120,80 @@ class TestLeakyReLU
}
};
};
class TestPReLU
{
public:
static void test_core_function()
{
// Test the basic leaky_relu function that PReLU uses
auto TestCase = [](float input, float slope, float expectedOutput) {
float actualOutput = nam::activations::leaky_relu(input, slope);
assert(actualOutput == expectedOutput);
};

// A few snapshot tests
TestCase(0.0f, 0.01f, 0.0f);
TestCase(1.0f, 0.01f, 1.0f);
TestCase(-1.0f, 0.01f, -0.01f);
TestCase(-1.0f, 0.05f, -0.05f); // Different slope
}

static void test_per_channel_behavior()
{
// Test that different slopes are applied to different channels
Eigen::MatrixXf data(2, 3); // 2 channels, 3 time steps

// Initialize with some test data
data << -1.0f, 0.5f, 1.0f,
-2.0f, -0.5f, 0.0f;

// Create PReLU with different slopes for each channel
std::vector<float> slopes = {0.01f, 0.05f}; // slope 0.01 for channel 0, 0.05 for channel 1
nam::activations::ActivationPReLU prelu(slopes);

// Apply the activation
prelu.apply(data);

// Verify the results
// Channel 0 (slope = 0.01):
assert(fabs(data(0, 0) - (-0.01f)) < 1e-6); // -1.0 * 0.01 = -0.01
assert(fabs(data(0, 1) - 0.5f) < 1e-6); // 0.5 (positive, unchanged)
assert(fabs(data(0, 2) - 1.0f) < 1e-6); // 1.0 (positive, unchanged)

// Channel 1 (slope = 0.05):
assert(fabs(data(1, 0) - (-0.10f)) < 1e-6); // -2.0 * 0.05 = -0.10
assert(fabs(data(1, 1) - (-0.025f)) < 1e-6); // -0.5 * 0.05 = -0.025
assert(fabs(data(1, 2) - 0.0f) < 1e-6); // 0.0 (unchanged)
[Owner, inline review comment] Strange that these would pass if I'm not mistaken...
}

static void test_wrong_number_of_channels()
{
// Test that we fail when we have more channels than slopes
Eigen::MatrixXf data(3, 2); // 3 channels, 2 time steps

// Initialize with test data
data << -1.0f, -1.0f,
-1.0f, 1.0f,
1.0f, 1.0f;

// Create PReLU with only 2 slopes for 3 channels
std::vector<float> slopes = {0.01f, 0.05f};
nam::activations::ActivationPReLU prelu(slopes);

// Apply the activation
bool caught = false;
try
{
prelu.apply(data);
} catch (const std::runtime_error& e) {
caught = true;
} catch (...) {

}

assert(caught == true);
}

};

}; // namespace test_activations