Skip to content

Commit c0a1e13

Browse files
jfsantos (João Felipe Santos) and others authored
Added gating activation classes (#180)
Added gating activation classes (#180)

* Added gating activation classes
* Throwing exceptions instead of resizing; removed default activation
* Updated to generalize to 2*channels in, channels out. Added tests to compare to wavenet implementation
* Formatting
* Fixed issue with blending activation; addressed comments for gating activation. Removed all runtime checks and replaced with asserts.
* Formatting
* Removed extra argument that wasn't being used
* Fixed small nitpicks
* Removed default activations; moved ActivationIdentity to activations.h

Co-authored-by: João Felipe Santos <[email protected]>
1 parent 0516183 commit c0a1e13

File tree

7 files changed

+820
-1
lines changed

7 files changed

+820
-1
lines changed

NAM/activations.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ class Activation
111111
static std::unordered_map<std::string, Activation*> _activations;
112112
};
113113

114+
// identity function activation
115+
class ActivationIdentity : public nam::activations::Activation
116+
{
117+
public:
118+
ActivationIdentity() = default;
119+
~ActivationIdentity() = default;
120+
// Inherit the default apply methods which do nothing
121+
};
122+
114123
class ActivationTanh : public Activation
115124
{
116125
public:

NAM/gating_activations.h

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#pragma once
2+
3+
#include <string>
4+
#include <cmath> // expf
5+
#include <unordered_map>
6+
#include <Eigen/Dense>
7+
#include <functional>
8+
#include <stdexcept>
9+
#include "activations.h"
10+
11+
namespace nam
12+
{
13+
namespace gating_activations
14+
{
15+
16+
// Default linear activation (identity function)
17+
class IdentityActivation : public nam::activations::Activation
18+
{
19+
public:
20+
IdentityActivation() = default;
21+
~IdentityActivation() = default;
22+
// Inherit the default apply methods which do nothing (linear/identity)
23+
};
24+
25+
class GatingActivation
26+
{
27+
public:
28+
/**
29+
* Constructor for GatingActivation
30+
* @param input_act Activation function for input channels
31+
* @param gating_act Activation function for gating channels
32+
* @param input_channels Number of input channels (default: 1)
33+
* @param gating_channels Number of gating channels (default: 1)
34+
*/
35+
GatingActivation(activations::Activation* input_act, activations::Activation* gating_act, int input_channels = 1)
36+
: input_activation(input_act)
37+
, gating_activation(gating_act)
38+
, num_channels(input_channels)
39+
{
40+
assert(num_channels > 0);
41+
}
42+
43+
~GatingActivation() = default;
44+
45+
/**
46+
* Apply gating activation to input matrix
47+
* @param input Input matrix with shape (input_channels + gating_channels) x num_samples
48+
* @param output Output matrix with shape input_channels x num_samples
49+
*/
50+
void apply(Eigen::MatrixXf& input, Eigen::MatrixXf& output)
51+
{
52+
// Validate input dimensions (assert for real-time performance)
53+
const int total_channels = 2 * num_channels;
54+
assert(input.rows() == total_channels);
55+
assert(output.rows() == num_channels);
56+
assert(output.cols() == input.cols());
57+
58+
// Process column-by-column to ensure memory contiguity (important for column-major matrices)
59+
const int num_samples = input.cols();
60+
for (int i = 0; i < num_samples; i++)
61+
{
62+
// Apply activation to input channels
63+
Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1);
64+
input_activation->apply(input_block);
65+
66+
// Apply activation to gating channels
67+
Eigen::MatrixXf gating_block = input.block(num_channels, i, num_channels, 1);
68+
gating_activation->apply(gating_block);
69+
70+
// Element-wise multiplication and store result
71+
// For wavenet compatibility, we assume one-to-one mapping
72+
output.block(0, i, num_channels, 1) = input_block.array() * gating_block.array();
73+
}
74+
}
75+
76+
/**
77+
* Get the total number of input channels required
78+
*/
79+
int get_input_channels() const { return 2 * num_channels; }
80+
81+
/**
82+
* Get the number of output channels
83+
*/
84+
int get_output_channels() const { return num_channels; }
85+
86+
private:
87+
activations::Activation* input_activation;
88+
activations::Activation* gating_activation;
89+
int num_channels;
90+
};
91+
92+
class BlendingActivation
93+
{
94+
public:
95+
/**
96+
* Constructor for BlendingActivation
97+
* @param input_act Activation function for input channels
98+
* @param blend_act Activation function for blending channels
99+
* @param input_channels Number of input channels
100+
*/
101+
BlendingActivation(activations::Activation* input_act, activations::Activation* blend_act, int input_channels = 1)
102+
: input_activation(input_act)
103+
, blending_activation(blend_act)
104+
, num_channels(input_channels)
105+
{
106+
if (num_channels <= 0)
107+
{
108+
throw std::invalid_argument("BlendingActivation: number of input channels must be positive");
109+
}
110+
// Initialize input buffer with correct size
111+
// Note: current code copies column-by-column so we only need (num_channels, 1)
112+
input_buffer.resize(num_channels, 1);
113+
}
114+
115+
~BlendingActivation() = default;
116+
117+
/**
118+
* Apply blending activation to input matrix
119+
* @param input Input matrix with shape (input_channels + blend_channels) x num_samples
120+
* @param output Output matrix with shape input_channels x num_samples
121+
*/
122+
void apply(Eigen::MatrixXf& input, Eigen::MatrixXf& output)
123+
{
124+
// Validate input dimensions (assert for real-time performance)
125+
const int total_channels = num_channels * 2; // 2*channels in, channels out
126+
assert(input.rows() == total_channels);
127+
assert(output.rows() == num_channels);
128+
assert(output.cols() == input.cols());
129+
130+
// Process column-by-column to ensure memory contiguity
131+
const int num_samples = input.cols();
132+
for (int i = 0; i < num_samples; i++)
133+
{
134+
// Store pre-activation input values in buffer
135+
input_buffer = input.block(0, i, num_channels, 1);
136+
137+
// Apply activation to input channels
138+
Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1);
139+
input_activation->apply(input_block);
140+
141+
// Apply activation to blend channels to compute alpha
142+
Eigen::MatrixXf blend_block = input.block(num_channels, i, num_channels, 1);
143+
blending_activation->apply(blend_block);
144+
145+
// Weighted blending: alpha * activated_input + (1 - alpha) * pre_activation_input
146+
output.block(0, i, num_channels, 1) =
147+
blend_block.array() * input_block.array() + (1.0f - blend_block.array()) * input_buffer.array();
148+
}
149+
}
150+
151+
/**
152+
* Get the total number of input channels required
153+
*/
154+
int get_input_channels() const { return 2 * num_channels; }
155+
156+
/**
157+
* Get the number of output channels
158+
*/
159+
int get_output_channels() const { return num_channels; }
160+
161+
private:
162+
activations::Activation* input_activation;
163+
activations::Activation* blending_activation;
164+
int num_channels;
165+
Eigen::MatrixXf input_buffer;
166+
};
167+
168+
169+
}; // namespace gating_activations
170+
}; // namespace nam

tools/run_tests.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
#include "test/test_get_dsp.cpp"
88
#include "test/test_wavenet.cpp"
99
#include "test/test_fast_lut.cpp"
10+
#include "test/test_gating_activations.cpp"
11+
#include "test/test_wavenet_gating_compatibility.cpp"
12+
#include "test/test_blending_detailed.cpp"
13+
#include "test/test_input_buffer_verification.cpp"
1014

1115
int main()
1216
{
@@ -24,7 +28,7 @@ int main()
2428
test_activations::TestPReLU::test_core_function();
2529
test_activations::TestPReLU::test_per_channel_behavior();
2630
// This is enforced by an assert so it doesn't need to be tested
27-
//test_activations::TestPReLU::test_wrong_number_of_channels();
31+
// test_activations::TestPReLU::test_wrong_number_of_channels();
2832

2933
test_dsp::test_construct();
3034
test_dsp::test_get_input_level();
@@ -44,6 +48,31 @@ int main()
4448
test_lut::TestFastLUT::test_sigmoid();
4549
test_lut::TestFastLUT::test_tanh();
4650

51+
// Gating activations tests
52+
test_gating_activations::TestGatingActivation::test_basic_functionality();
53+
test_gating_activations::TestGatingActivation::test_with_custom_activations();
54+
// test_gating_activations::TestGatingActivation::test_error_handling();
55+
56+
// Wavenet gating compatibility tests
57+
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_wavenet_style_gating();
58+
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_column_by_column_processing();
59+
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_memory_contiguity();
60+
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_multiple_channels();
61+
62+
test_gating_activations::TestBlendingActivation::test_basic_functionality();
63+
test_gating_activations::TestBlendingActivation::test_blending_behavior();
64+
test_gating_activations::TestBlendingActivation::test_with_custom_activations();
65+
// test_gating_activations::TestBlendingActivation::test_error_handling();
66+
test_gating_activations::TestBlendingActivation::test_edge_cases();
67+
68+
// Detailed blending tests
69+
test_blending_detailed::TestBlendingDetailed::test_blending_with_different_activations();
70+
test_blending_detailed::TestBlendingDetailed::test_input_buffer_usage();
71+
72+
// Input buffer verification tests
73+
test_input_buffer_verification::TestInputBufferVerification::test_buffer_stores_pre_activation_values();
74+
test_input_buffer_verification::TestInputBufferVerification::test_buffer_with_different_activations();
75+
4776
std::cout << "Success!" << std::endl;
4877
return 0;
4978
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// Detailed test for BlendingActivation behavior
2+
3+
#include <cassert>
4+
#include <string>
5+
#include <vector>
6+
#include <cmath>
7+
#include <iostream>
8+
9+
#include "NAM/gating_activations.h"
10+
#include "NAM/activations.h"
11+
12+
namespace test_blending_detailed
13+
{
14+
15+
class TestBlendingDetailed
16+
{
17+
public:
18+
static void test_blending_with_different_activations()
19+
{
20+
// Test case: 2 input channels, so we need 4 total input channels (2*channels in)
21+
Eigen::MatrixXf input(4, 2); // 4 rows (2 input + 2 blending), 2 samples
22+
input << 1.0f, 2.0f, // Input channel 1
23+
3.0f, 4.0f, // Input channel 2
24+
0.5f, 0.8f, // Blending channel 1
25+
0.3f, 0.6f; // Blending channel 2
26+
27+
Eigen::MatrixXf output(2, 2); // 2 output channels, 2 samples
28+
29+
// Test with default (linear) activations
30+
nam::activations::ActivationIdentity identity_act;
31+
nam::activations::ActivationIdentity identity_blend_act;
32+
nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 2);
33+
blending_act.apply(input, output);
34+
35+
std::cout << "Blending with linear activations:" << std::endl;
36+
std::cout << "Input:" << std::endl << input << std::endl;
37+
std::cout << "Output:" << std::endl << output << std::endl;
38+
39+
// With linear activations:
40+
// alpha = blend_input (since linear activation does nothing)
41+
// output = alpha * input + (1 - alpha) * input = input
42+
// So output should equal the input channels after activation (which is the same as input)
43+
assert(fabs(output(0, 0) - 1.0f) < 1e-6);
44+
assert(fabs(output(1, 0) - 3.0f) < 1e-6);
45+
assert(fabs(output(0, 1) - 2.0f) < 1e-6);
46+
assert(fabs(output(1, 1) - 4.0f) < 1e-6);
47+
48+
// Test with sigmoid blending activation
49+
nam::activations::Activation* sigmoid_act = nam::activations::Activation::get_activation("Sigmoid");
50+
nam::gating_activations::BlendingActivation blending_act_sigmoid(&identity_act, sigmoid_act, 2);
51+
52+
Eigen::MatrixXf output_sigmoid(2, 2);
53+
blending_act_sigmoid.apply(input, output_sigmoid);
54+
55+
std::cout << "Blending with sigmoid blending activation:" << std::endl;
56+
std::cout << "Output:" << std::endl << output_sigmoid << std::endl;
57+
58+
// With sigmoid blending, alpha values should be between 0 and 1
59+
// For blend input 0.5, sigmoid(0.5) ≈ 0.622
60+
// For blend input 0.8, sigmoid(0.8) ≈ 0.690
61+
// For blend input 0.3, sigmoid(0.3) ≈ 0.574
62+
// For blend input 0.6, sigmoid(0.6) ≈ 0.646
63+
64+
float alpha0_0 = 1.0f / (1.0f + expf(-0.5f)); // sigmoid(0.5)
65+
float alpha1_0 = 1.0f / (1.0f + expf(-0.8f)); // sigmoid(0.8)
66+
float alpha0_1 = 1.0f / (1.0f + expf(-0.3f)); // sigmoid(0.3)
67+
float alpha1_1 = 1.0f / (1.0f + expf(-0.6f)); // sigmoid(0.6)
68+
69+
// Expected output: alpha * activated_input + (1 - alpha) * pre_activation_input
70+
// Since input activation is linear, activated_input = pre_activation_input = input
71+
// So output = alpha * input + (1 - alpha) * input = input
72+
// This should be the same as with linear activations
73+
assert(fabs(output_sigmoid(0, 0) - 1.0f) < 1e-6);
74+
assert(fabs(output_sigmoid(1, 0) - 3.0f) < 1e-6);
75+
assert(fabs(output_sigmoid(0, 1) - 2.0f) < 1e-6);
76+
assert(fabs(output_sigmoid(1, 1) - 4.0f) < 1e-6);
77+
78+
std::cout << "Blending detailed test passed" << std::endl;
79+
}
80+
81+
static void test_input_buffer_usage()
82+
{
83+
// Test that the input buffer is correctly storing pre-activation values
84+
Eigen::MatrixXf input(2, 1);
85+
input << 2.0f, 0.5f;
86+
87+
Eigen::MatrixXf output(1, 1);
88+
89+
// Test with ReLU activation on input (which will change values < 0 to 0)
90+
nam::activations::ActivationReLU relu_act;
91+
nam::activations::ActivationIdentity identity_act;
92+
nam::gating_activations::BlendingActivation blending_act(&relu_act, &identity_act, 1);
93+
94+
blending_act.apply(input, output);
95+
96+
// With input=2.0, ReLU(2.0)=2.0, blend=0.5
97+
// output = 0.5 * 2.0 + (1 - 0.5) * 2.0 = 0.5 * 2.0 + 0.5 * 2.0 = 2.0
98+
assert(fabs(output(0, 0) - 2.0f) < 1e-6);
99+
100+
// Test with negative input value
101+
Eigen::MatrixXf input2(2, 1);
102+
input2 << -1.0f, 0.5f;
103+
104+
Eigen::MatrixXf output2(1, 1);
105+
blending_act.apply(input2, output2);
106+
107+
// With input=-1.0, ReLU(-1.0)=0.0, blend=0.5
108+
// output = 0.5 * 0.0 + (1 - 0.5) * (-1.0) = 0.0 + 0.5 * (-1.0) = -0.5
109+
assert(fabs(output2(0, 0) - (-0.5f)) < 1e-6);
110+
111+
std::cout << "Input buffer usage test passed" << std::endl;
112+
}
113+
};
114+
115+
}; // namespace test_blending_detailed

0 commit comments

Comments
 (0)