 #include <cmath> // expf
 #include <unordered_map>
 #include <Eigen/Dense>
+#include <algorithm> // std::clamp (used by FastLUTActivation below)
+#include <functional>
+#include <vector> // std::vector table in FastLUTActivation
 
 namespace nam
 {
@@ -25,6 +26,17 @@ inline float hard_tanh(float x)
   return t > 1 ? 1 : t;
 }
 
+// Hard-tanh with non-zero ("leaky") slopes outside [min_val, max_val].
+inline float leaky_hardtanh(float x, float min_val, float max_val, float min_slope, float max_slope)
+{
+  if (x < min_val) {
+    return (x - min_val) * min_slope + min_val;
+  } else if (x > max_val) {
+    return (x - max_val) * max_slope + max_val;
+  } else {
+    return x;
+  }
+}
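+// For illustration: with min_val = -1, max_val = 1 and both slopes 0.01,
+// leaky_hardtanh(2.0f, -1.0f, 1.0f, 0.01f, 0.01f) == 1.0f + 0.01f * (2.0f - 1.0f) == 1.01f.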
+
 inline float fast_tanh(const float x)
 {
   const float ax = fabsf(x);
@@ -38,14 +50,32 @@ inline float fast_sigmoid(const float x)
 {
   return 0.5f * (fast_tanh(x * 0.5f) + 1.0f);
 }
-
-// Assumes PyTorch default of 0.01 for negative slope. This may change to be
-// configurable in the future.
-inline float leaky_relu(float x)
+
+inline float leaky_relu(float x, float negative_slope)
 {
-  const float negative_slope = 0.01;
   return x > 0.0f ? x : negative_slope * x;
 }
+inline float leaky_relu(float x)
+{
+  return leaky_relu(x, 0.01f);
+}
+
+// Swish (a.k.a. SiLU): x * sigmoid(x).
+inline float swish(float x)
+{
+  return x * sigmoid(x);
+}
+
+// Hard-swish: cheap piecewise approximation of swish (zero below -3, identity above 3).
+inline float hardswish(float x)
+{
+  if (x <= -3.0f) {
+    return 0.0f;
+  } else if (x >= 3.0f) {
+    return x;
+  } else {
+    return x * (x + 3.0f) / 6.0f;
+  }
+}
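+// For illustration: hardswish(1.0f) == 1.0f * (1.0f + 3.0f) / 6.0f, i.e. roughly 0.667f.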
 
 class Activation
 {
@@ -64,6 +94,8 @@ class Activation
   static void enable_fast_tanh();
   static void disable_fast_tanh();
   static bool using_fast_tanh;
+  static void enable_lut(std::string function_name, float min, float max, std::size_t n_points);
+  static void disable_lut(std::string function_name);
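+  // Illustrative calls (the matching .cpp changes are not shown in this diff; "Tanh" is just an example name):
+  //   Activation::enable_lut("Tanh", -5.0f, 5.0f, 4096); // table-based lookup over [-5, 5] with 4096 points
+  //   Activation::disable_lut("Tanh");                    // revert to the exact activation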
 
 protected:
   static std::unordered_map<std::string, Activation*> _activations;
@@ -93,6 +125,30 @@ class ActivationHardTanh : public Activation
   }
 };
 
+class ActivationLeakyHardTanh : public Activation
+{
+public:
+  ActivationLeakyHardTanh() = default;
+  ActivationLeakyHardTanh(float min_val_, float max_val_, float min_slope_, float max_slope_)
+  : min_val(min_val_)
+  , max_val(max_val_)
+  , min_slope(min_slope_)
+  , max_slope(max_slope_)
+  {
+  }
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
+    {
+      data[pos] = leaky_hardtanh(data[pos], min_val, max_val, min_slope, max_slope);
+    }
+  }
+
+private:
+  float min_val = -1.0f;
+  float max_val = 1.0f;
+  float min_slope = 0.01f;
+  float max_slope = 0.01f;
+};
+
 class ActivationFastTanh : public Activation
 {
 public:
@@ -120,13 +176,19 @@ class ActivationReLU : public Activation
 class ActivationLeakyReLU : public Activation
 {
 public:
+  ActivationLeakyReLU() = default;
+  ActivationLeakyReLU(float ns)
+  : negative_slope(ns)
+  {
+  }
   void apply(float* data, long size) override
   {
     for (long pos = 0; pos < size; pos++)
     {
-      data[pos] = leaky_relu(data[pos]);
+      data[pos] = leaky_relu(data[pos], negative_slope);
     }
   }
+
+private:
+  float negative_slope = 0.01f;
 };
 
 class ActivationSigmoid : public Activation
@@ -140,5 +202,75 @@ class ActivationSigmoid : public Activation
     }
   }
 };
+
+class ActivationSwish : public Activation
+{
+public:
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
+    {
+      data[pos] = swish(data[pos]);
+    }
+  }
+};
+
+class ActivationHardSwish : public Activation
+{
+public:
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
+    {
+      data[pos] = hardswish(data[pos]);
+    }
+  }
+};
+
+class FastLUTActivation : public Activation
+{
+public:
+  FastLUTActivation(float min_x, float max_x, std::size_t size, std::function<float(float)> f)
+  : min_x_(min_x), max_x_(max_x), size_(size)
+  {
+    step_ = (max_x - min_x) / (size - 1);
+    inv_step_ = 1.0f / step_;
+    table_.reserve(size);
+
+    // Sample f at evenly spaced points over [min_x, max_x]
+    for (std::size_t i = 0; i < size; ++i) {
+      table_.push_back(f(min_x + i * step_));
+    }
+  }
+
+  // Fast lookup with linear interpolation
+  inline float lookup(float x) const {
+    // Clamp input to range
+    x = std::clamp(x, min_x_, max_x_);
+
+    // Calculate float index
+    float f_idx = (x - min_x_) * inv_step_;
+    std::size_t i = static_cast<std::size_t>(f_idx);
+
+    // Handle edge case at max_x_
+    if (i >= size_ - 1) return table_.back();
+
+    // Linear interpolation: y = y0 + (y1 - y0) * fractional_part
+    float frac = f_idx - static_cast<float>(i);
+    return table_[i] + (table_[i + 1] - table_[i]) * frac;
+  }
+
+  // Activation interface: apply the LUT in place to a raw buffer
+  void apply(float* data, long size) override
+  {
+    for (long pos = 0; pos < size; pos++)
+    {
+      data[pos] = lookup(data[pos]);
+    }
+  }
+
+  // Vector application (batch processing)
+  void apply(std::vector<float>& data) const {
+    for (float& val : data) {
+      val = lookup(val);
+    }
+  }
+
+private:
+  float min_x_, max_x_, step_, inv_step_;
+  std::size_t size_;
+  std::vector<float> table_;
+};
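+// Example usage (a sketch, not part of this header's API): a 1024-point table
+// approximating tanh on [-5, 5], evaluated by linear interpolation.
+//   FastLUTActivation lut_tanh(-5.0f, 5.0f, 1024, [](float v) { return std::tanh(v); });
+//   float y = lut_tanh.lookup(0.3f); // close to tanhf(0.3f)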
+
 }; // namespace activations
 }; // namespace nam