Commit 6ea03b1

Authored by João Felipe Santos (jfsantos)
Adding activation functions and fast LUT implementation (#177)
* Added a few new activation functions. LeakyReLU is broken because of how activation classes are pre-instantiated in activations.cpp.
* Fixed initialization issues with LeakyReLU; this will require changes to code currently using the default constructor.
* Added tests for fast_lut.
* Addressed review comments.
* Refactored FastLUT into an activation class and added functions to enable/disable the LUT for tanh and sigmoid.

---------

Co-authored-by: João Felipe Santos <[email protected]>
1 parent e5cc355 · commit 6ea03b1

File tree

7 files changed (+221, -12 lines)

NAM/activations.cpp

Lines changed: 35 additions & 2 deletions
@@ -4,16 +4,21 @@ nam::activations::ActivationTanh _TANH = nam::activations::ActivationTanh();
 nam::activations::ActivationFastTanh _FAST_TANH = nam::activations::ActivationFastTanh();
 nam::activations::ActivationHardTanh _HARD_TANH = nam::activations::ActivationHardTanh();
 nam::activations::ActivationReLU _RELU = nam::activations::ActivationReLU();
-nam::activations::ActivationLeakyReLU _LEAKY_RELU = nam::activations::ActivationLeakyReLU();
+nam::activations::ActivationLeakyReLU _LEAKY_RELU = nam::activations::ActivationLeakyReLU(0.01); //FIXME does not parameterize LeakyReLU
 nam::activations::ActivationSigmoid _SIGMOID = nam::activations::ActivationSigmoid();
+nam::activations::ActivationSwish _SWISH = nam::activations::ActivationSwish();
+nam::activations::ActivationHardSwish _HARD_SWISH = nam::activations::ActivationHardSwish();
+nam::activations::ActivationLeakyHardTanh _LEAKY_HARD_TANH = nam::activations::ActivationLeakyHardTanh();

 bool nam::activations::Activation::using_fast_tanh = false;

 std::unordered_map<std::string, nam::activations::Activation*> nam::activations::Activation::_activations = {
   {"Tanh", &_TANH}, {"Hardtanh", &_HARD_TANH}, {"Fasttanh", &_FAST_TANH},
-  {"ReLU", &_RELU}, {"LeakyReLU", &_LEAKY_RELU}, {"Sigmoid", &_SIGMOID}};
+  {"ReLU", &_RELU}, {"LeakyReLU", &_LEAKY_RELU}, {"Sigmoid", &_SIGMOID},
+  {"SiLU", &_SWISH}, {"Hardswish", &_HARD_SWISH}, {"LeakyHardtanh", &_LEAKY_HARD_TANH}};

 nam::activations::Activation* tanh_bak = nullptr;
+nam::activations::Activation* sigmoid_bak = nullptr;

 nam::activations::Activation* nam::activations::Activation::get_activation(const std::string name)
 {
@@ -43,3 +48,31 @@ void nam::activations::Activation::disable_fast_tanh()
     _activations["Tanh"] = tanh_bak;
   }
 }
+
+void nam::activations::Activation::enable_lut(std::string function_name, float min, float max, std::size_t n_points)
+{
+  std::function<float(float)> fn;
+  if (function_name == "Tanh"){
+    fn = [](float x) { return std::tanh(x); };
+    tanh_bak = _activations["Tanh"];
+  } else if (function_name == "Sigmoid") {
+    fn = sigmoid;
+    sigmoid_bak = _activations["Sigmoid"];
+  } else {
+    throw std::runtime_error("Tried to enable LUT for a function other than Tanh or Sigmoid");
+  }
+  FastLUTActivation lut_activation(min, max, n_points, fn);
+  _activations[function_name] = &lut_activation;
+}
+
+void nam::activations::Activation::disable_lut(std::string function_name)
+{
+  if (function_name == "Tanh"){
+    _activations["Tanh"] = tanh_bak;
+  } else if (function_name == "Sigmoid") {
+    _activations["Sigmoid"] = sigmoid_bak;
+  } else {
+    throw std::runtime_error("Tried to disable LUT for a function other than Tanh or Sigmoid");
+  }
+}
+
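A minimal sketch, not part of this commit, of how client code might call the new LUT switches defined above; the wrapper function name is made up, and the 1024-point table over [-8, 8] mirrors the range and size used in the new tests below:

#include "NAM/activations.h"

void use_lut_for_tanh()
{
  // Swap the "Tanh" registry entry for a 1024-point lookup table over [-8, 8];
  // lookup() clamps inputs that fall outside that range.
  nam::activations::Activation::enable_lut("Tanh", -8.0f, 8.0f, 1024);

  // ... process audio with models that request "Tanh" ...

  // Restore the exact std::tanh-based activation afterwards.
  nam::activations::Activation::disable_lut("Tanh");
}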

NAM/activations.h

Lines changed: 138 additions & 6 deletions
@@ -4,6 +4,7 @@
 #include <cmath> // expf
 #include <unordered_map>
 #include <Eigen/Dense>
+#include <functional>

 namespace nam
 {
@@ -25,6 +26,17 @@ inline float hard_tanh(float x)
   return t > 1 ? 1 : t;
 }

+inline float leaky_hardtanh(float x, float min_val, float max_val, float min_slope, float max_slope)
+{
+  if (x < min_val) {
+    return (x - min_val) * min_slope + min_val;
+  } else if (x > max_val) {
+    return (x - max_val) * max_slope + max_val;
+  } else {
+    return x;
+  }
+}
+
 inline float fast_tanh(const float x)
 {
   const float ax = fabsf(x);
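A quick numeric check of the piecewise leaky_hardtanh definition above (illustrative only, not part of the commit; the values use the defaults declared by the new ActivationLeakyHardTanh class further down: min_val = -1, max_val = 1, both slopes = 0.01):

#include <cassert>
#include <cmath>
#include "NAM/activations.h"

void check_leaky_hardtanh()
{
  using nam::activations::leaky_hardtanh;
  // Inside [-1, 1] the function is the identity.
  assert(leaky_hardtanh(0.5f, -1.0f, 1.0f, 0.01f, 0.01f) == 0.5f);
  // Below min_val it continues with slope 0.01: (-2 - (-1)) * 0.01 + (-1) = -1.01.
  assert(std::abs(leaky_hardtanh(-2.0f, -1.0f, 1.0f, 0.01f, 0.01f) - (-1.01f)) < 1e-6f);
  // Above max_val it continues with slope 0.01: (3 - 1) * 0.01 + 1 = 1.02.
  assert(std::abs(leaky_hardtanh(3.0f, -1.0f, 1.0f, 0.01f, 0.01f) - 1.02f) < 1e-6f);
}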
@@ -38,14 +50,32 @@ inline float fast_sigmoid(const float x)
 {
   return 0.5f * (fast_tanh(x * 0.5f) + 1.0f);
 }
-
-// Assumes PyTorch default of 0.01 for negative slope. This may change to be
-// configurable in the future.
-inline float leaky_relu(float x)
+
+inline float leaky_relu(float x, float negative_slope)
 {
-  const float negative_slope = 0.01;
   return x > 0.0f ? x : negative_slope * x;
 }
+inline float leaky_relu(float x)
+{
+  return leaky_relu(x, 0.01);
+}
+
+
+inline float swish(float x)
+{
+  return x * sigmoid(x);
+}
+
+inline float hardswish(float x)
+{
+  if (x <= -3.0) {
+    return 0;
+  } else if (x >= 3.0) {
+    return x;
+  } else {
+    return x * (x + 3.0)/6.0;
+  }
+}

 class Activation
 {
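A few spot checks of the new pointwise helpers above (illustrative only, not part of the commit). swish matches the usual SiLU definition x * sigmoid(x), and hardswish follows the familiar piecewise approximation:

#include <cassert>
#include <cmath>
#include "NAM/activations.h"

void check_swish_and_hardswish()
{
  using namespace nam::activations;
  // swish(x) = x * sigmoid(x): zero at 0, approximately x for large positive x.
  assert(swish(0.0f) == 0.0f);
  assert(std::abs(swish(10.0f) - 10.0f) < 1e-3f);
  // hardswish: 0 at or below -3, x at or above 3, x * (x + 3) / 6 in between,
  // e.g. hardswish(1) = 1 * 4 / 6, roughly 0.6667.
  assert(hardswish(-4.0f) == 0.0f);
  assert(hardswish(4.0f) == 4.0f);
  assert(std::abs(hardswish(1.0f) - 2.0f / 3.0f) < 1e-6f);
  // The single-argument leaky_relu overload just forwards the default slope of 0.01.
  assert(leaky_relu(-1.0f) == leaky_relu(-1.0f, 0.01f));
}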
@@ -64,6 +94,8 @@ class Activation
   static void enable_fast_tanh();
   static void disable_fast_tanh();
   static bool using_fast_tanh;
+  static void enable_lut(std::string function_name, float min, float max, std::size_t n_points);
+  static void disable_lut(std::string function_name);

 protected:
   static std::unordered_map<std::string, Activation*> _activations;
@@ -93,6 +125,30 @@ class ActivationHardTanh : public Activation
   }
 };

+class ActivationLeakyHardTanh : public Activation
+{
+public:
+  ActivationLeakyHardTanh() = default;
+  ActivationLeakyHardTanh(float min_val_, float max_val_, float min_slope_, float max_slope_) {
+    min_val = min_val_;
+    max_val = max_val_;
+    min_slope = min_slope_;
+    max_slope = max_slope_;
+  }
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
+    {
+      data[pos] = leaky_hardtanh(data[pos], min_val, max_val, min_slope, max_slope);
+    }
+  }
+private:
+  float min_val = -1.0;
+  float max_val = 1.0;
+  float min_slope = 0.01;
+  float max_slope = 0.01;
+};
+
 class ActivationFastTanh : public Activation
 {
 public:
@@ -120,13 +176,19 @@ class ActivationReLU : public Activation
 class ActivationLeakyReLU : public Activation
 {
 public:
+  ActivationLeakyReLU() = default;
+  ActivationLeakyReLU(float ns) {
+    negative_slope = ns;
+  }
   void apply(float* data, long size) override
   {
     for (long pos = 0; pos < size; pos++)
     {
-      data[pos] = leaky_relu(data[pos]);
+      data[pos] = leaky_relu(data[pos], negative_slope);
     }
   }
+private:
+  float negative_slope = 0.01;
 };

 class ActivationSigmoid : public Activation
@@ -140,5 +202,75 @@ class ActivationSigmoid : public Activation
     }
   }
 };
+
+class ActivationSwish : public Activation
+{
+public:
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
+    {
+      data[pos] = swish(data[pos]);
+    }
+  }
+};
+
+class ActivationHardSwish : public Activation
+{
+public:
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
+    {
+      data[pos] = hardswish(data[pos]);
+    }
+  }
+};
+
+class FastLUTActivation : public Activation
+{
+public:
+  FastLUTActivation(float min_x, float max_x, std::size_t size, std::function<float(float)> f)
+    : min_x_(min_x), max_x_(max_x), size_(size) {
+
+    step_ = (max_x - min_x) / (size - 1);
+    inv_step_ = 1.0f / step_;
+    table_.reserve(size);
+
+    for (std::size_t i = 0; i < size; ++i) {
+      table_.push_back(f(min_x + i * step_));
+    }
+  }
+
+  // Fast lookup with linear interpolation
+  inline float lookup(float x) const {
+    // Clamp input to range
+    x = std::clamp(x, min_x_, max_x_);
+
+    // Calculate float index
+    float f_idx = (x - min_x_) * inv_step_;
+    std::size_t i = static_cast<std::size_t>(f_idx);
+
+    // Handle edge case at max_x_
+    if (i >= size_ - 1) return table_.back();
+
+    // Linear interpolation: y = y0 + (y1 - y0) * fractional_part
+    float frac = f_idx - static_cast<float>(i);
+    return table_[i] + (table_[i + 1] - table_[i]) * frac;
+  }
+
+  // Vector application (Batch processing)
+  void apply(std::vector<float>& data) const {
+    for (float& val : data) {
+      val = lookup(val);
+    }
+  }
+
+private:
+  float min_x_, max_x_, step_, inv_step_;
+  size_t size_;
+  std::vector<float> table_;
+};
+
 }; // namespace activations
 }; // namespace nam
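FastLUTActivation is self-contained, so it can wrap any scalar function, not only the tanh and sigmoid paths wired up in enable_lut. A minimal sketch (not part of the commit; wrapping swish here is purely illustrative):

#include <cmath>
#include <cstdio>
#include "NAM/activations.h"

int main()
{
  // 1024-point table for swish over [-8, 8], the same range and size the new tests use.
  nam::activations::FastLUTActivation lut_swish(
    -8.0f, 8.0f, 1024, [](float x) { return x / (1.0f + std::exp(-x)); });

  const float x = 1.25f;
  const float exact = x / (1.0f + std::exp(-x));
  // The table lookup should agree with the exact value to roughly 1e-3,
  // the same tolerance the new tests assert for tanh and sigmoid.
  std::printf("exact = %f, lut = %f\n", exact, lut_swish.lookup(x));
  return 0;
}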

NAM/wavenet.cpp

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ void nam::wavenet::_Layer::process_(const Eigen::MatrixXf& input, const Eigen::M
   for (int i = 0; i < num_frames; i++)
   {
     this->_activation->apply(this->_z.block(0, i, channels, 1));
+    // TODO Need to support other activation functions here instead of hardcoded sigmoid
     activations::Activation::get_activation("Sigmoid")->apply(this->_z.block(channels, i, channels, 1));
   }
   this->_z.block(0, 0, channels, num_frames).array() *= this->_z.block(channels, 0, channels, num_frames).array();
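For context on the TODO above: in the gated path _z holds 2 * channels rows per frame; the layer's configurable activation is applied to the top half, a hard-coded sigmoid gate to the bottom half, and the two halves are then multiplied elementwise after the loop. In rough pseudocode (illustrative, not part of the commit):

// top    = activation(z[0 : channels, i])          // the layer's configured activation
// gate   = sigmoid(z[channels : 2*channels, i])    // currently hard-coded, hence the TODO
// output = top * gate                              // elementwise product over all frames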

NAM/wavenet.h

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ class _Layer
   : _conv(channels, gated ? 2 * channels : channels, kernel_size, true, dilation)
   , _input_mixin(condition_size, gated ? 2 * channels : channels, false)
   , _1x1(channels, channels, true)
-  , _activation(activations::Activation::get_activation(activation))
+  , _activation(activations::Activation::get_activation(activation)) // needs to support activations with parameters
   , _gated(gated) {};
   // Resize all arrays to be able to process `maxBufferSize` frames.
   void SetMaxBufferSize(const int maxBufferSize);

tools/run_tests.cpp

Lines changed: 5 additions & 1 deletion
@@ -6,6 +6,7 @@
 #include "test/test_dsp.cpp"
 #include "test/test_get_dsp.cpp"
 #include "test/test_wavenet.cpp"
+#include "test/test_fast_lut.cpp"

 int main()
 {
@@ -35,6 +36,9 @@ int main()

   test_wavenet::test_gated();

+  test_lut::TestFastLUT::test_sigmoid();
+  test_lut::TestFastLUT::test_tanh();
+
   std::cout << "Success!" << std::endl;
   return 0;
-}
+}

tools/test/test_activations.cpp

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ class TestLeakyReLU
   static void test_core_function()
   {
     auto TestCase = [](float input, float expectedOutput) {
-      float actualOutput = nam::activations::leaky_relu(input);
+      float actualOutput = nam::activations::leaky_relu(input, 0.01);
       assert(actualOutput == expectedOutput);
     };
     // A few snapshot tests
@@ -84,7 +84,7 @@ class TestLeakyReLU

   static void test_get_by_init()
   {
-    auto a = nam::activations::ActivationLeakyReLU();
+    auto a = nam::activations::ActivationLeakyReLU(0.01);
     _test_class(&a);
   }

tools/test/test_fast_lut.cpp

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+#include <cassert>
+#include <string>
+#include <vector>
+#include <cmath>
+
+#include "NAM/activations.h"
+
+namespace test_lut {
+
+float sigmoid(float x) {
+  return 1.0f / (1.0f + std::exp(-x));
+}
+
+class TestFastLUT
+{
+public:
+  static void test_sigmoid()
+  {
+    // create a LUT for sigmoid from -8.0 to 8.0 with 1024 samples
+    nam::activations::FastLUTActivation lut_sigmoid(-8.0f, 8.0f, 1024, [](float x) {
+      return 1.0f / (1.0f + expf(-x));
+    });
+
+    float input = 1.25f;
+    assert(abs(sigmoid(input) - lut_sigmoid.lookup(input)) < 1e-3);
+  }
+  static void test_tanh()
+  {
+    // create a LUT for tanh from -8.0 to 8.0 with 1024 samples
+    nam::activations::FastLUTActivation lut_tanh(-8.0f, 8.0f, 1024, [](float x) {
+      return std::tanh(x);
+    });
+
+    float input = 1.25f;
+    assert(abs(std::tanh(input) - lut_tanh.lookup(input)) < 1e-3);
+  }
+};
+}
+
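A rough back-of-the-envelope check (not part of the commit) of why the 1e-3 tolerance above is comfortable: piecewise-linear interpolation on a uniform grid with step h has worst-case error of about h^2 / 8 * max|f''|, and for 1024 points over [-8, 8] that is on the order of 1e-5 for tanh:

#include <cstdio>

int main()
{
  const float h = 16.0f / 1023.0f;    // grid step for 1024 points over [-8, 8]
  const float max_d2_tanh = 0.77f;    // approximate max |tanh''(x)|, reached near x = +/-0.66
  // Worst-case piecewise-linear interpolation error bound: h^2 / 8 * max|f''|.
  std::printf("tanh LUT error bound ~ %g\n", h * h / 8.0f * max_d2_tanh); // roughly 2.4e-05
  return 0;
}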
