Skip to content

Commit c9899a4

Browse files
author
João Felipe Santos
committed
Fixed issue with blending activation, addressed comments for gating activation. Removed all runtime checks and replaced with asserts.
1 parent 2caf327 commit c9899a4

File tree

3 files changed

+103
-146
lines changed

3 files changed

+103
-146
lines changed

NAM/gating_activations.h

Lines changed: 33 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ namespace gating_activations
1414
{
1515

1616
// Default linear activation (identity function)
17-
class LinearActivation : public nam::activations::Activation
17+
class IdentityActivation : public nam::activations::Activation
1818
{
1919
public:
20-
LinearActivation() = default;
21-
~LinearActivation() = default;
20+
IdentityActivation() = default;
21+
~IdentityActivation() = default;
2222
// Inherit the default apply methods which do nothing (linear/identity)
2323
};
2424

2525
// Static instance for default activation
26-
static LinearActivation default_activation;
26+
static IdentityActivation default_activation;
2727

2828
class GatingActivation
2929
{
@@ -40,12 +40,8 @@ class GatingActivation
4040
: input_activation(input_act ? input_act : &default_activation)
4141
, gating_activation(gating_act ? gating_act : activations::Activation::get_activation("Sigmoid"))
4242
, num_input_channels(input_channels)
43-
, num_gating_channels(gating_channels)
4443
{
45-
if (num_input_channels <= 0 || num_gating_channels <= 0)
46-
{
47-
throw std::invalid_argument("GatingActivation: number of channels must be positive");
48-
}
44+
assert(num_input_channels > 0);
4945
}
5046

5147
~GatingActivation() = default;
@@ -57,20 +53,11 @@ class GatingActivation
5753
*/
5854
void apply(Eigen::MatrixXf& input, Eigen::MatrixXf& output)
5955
{
60-
// Validate input dimensions
61-
const int total_channels = num_input_channels + num_gating_channels;
62-
if (input.rows() != total_channels)
63-
{
64-
throw std::invalid_argument("GatingActivation: input matrix must have " + std::to_string(total_channels)
65-
+ " rows");
66-
}
67-
68-
// Validate output dimensions
69-
if (output.rows() != num_input_channels || output.cols() != input.cols())
70-
{
71-
throw std::invalid_argument("GatingActivation: output matrix must have " + std::to_string(num_input_channels)
72-
+ " rows and " + std::to_string(input.cols()) + " columns");
73-
}
56+
// Validate input dimensions (assert for real-time performance)
57+
const int total_channels = 2 * num_input_channels;
58+
assert(input.rows() == total_channels);
59+
assert(output.rows() == num_input_channels);
60+
assert(output.cols() == input.cols());
7461

7562
// Process column-by-column to ensure memory contiguity (important for column-major matrices)
7663
const int num_samples = input.cols();
@@ -81,20 +68,19 @@ class GatingActivation
8168
input_activation->apply(input_block);
8269

8370
// Apply activation to gating channels
84-
Eigen::MatrixXf gating_block = input.block(num_input_channels, i, num_gating_channels, 1);
71+
Eigen::MatrixXf gating_block = input.block(num_input_channels, i, num_input_channels, 1);
8572
gating_activation->apply(gating_block);
8673

8774
// Element-wise multiplication and store result
8875
// For wavenet compatibility, we assume one-to-one mapping
89-
assert(num_input_channels == num_gating_channels);
9076
output.block(0, i, num_input_channels, 1) = input_block.array() * gating_block.array();
9177
}
9278
}
9379

9480
/**
9581
* Get the total number of input channels required
9682
*/
97-
int get_total_input_channels() const { return num_input_channels + num_gating_channels; }
83+
int get_total_input_channels() const { return 2 * num_input_channels; }
9884

9985
/**
10086
* Get the number of output channels
@@ -105,7 +91,6 @@ class GatingActivation
10591
activations::Activation* input_activation;
10692
activations::Activation* gating_activation;
10793
int num_input_channels;
108-
int num_gating_channels;
10994
};
11095

11196
class BlendingActivation
@@ -115,27 +100,21 @@ class BlendingActivation
115100
* Constructor for BlendingActivation
116101
* @param input_act Activation function for input channels
117102
* @param blend_act Activation function for blending channels
118-
* @param alpha_val Blending factor (0.0 to 1.0)
119103
* @param input_channels Number of input channels
120-
* @param blend_channels Number of blending channels
121104
*/
122105
BlendingActivation(activations::Activation* input_act = nullptr, activations::Activation* blend_act = nullptr,
123-
float alpha_val = 0.5f, int input_channels = 1, int blend_channels = 1)
106+
int input_channels = 1)
124107
: input_activation(input_act ? input_act : &default_activation)
125108
, blending_activation(blend_act ? blend_act : &default_activation)
126-
, alpha(alpha_val)
127109
, num_input_channels(input_channels)
128-
, num_blend_channels(blend_channels)
129110
{
130-
// Validate alpha is in valid range
131-
if (alpha < 0.0f || alpha > 1.0f)
132-
{
133-
throw std::invalid_argument("BlendingActivation: alpha must be between 0.0 and 1.0");
134-
}
135-
if (num_input_channels <= 0 || num_blend_channels <= 0)
111+
if (num_input_channels <= 0)
136112
{
137-
throw std::invalid_argument("BlendingActivation: number of channels must be positive");
113+
throw std::invalid_argument("BlendingActivation: number of input channels must be positive");
138114
}
115+
// Initialize input buffer with correct size
116+
// Note: current code copies column-by-column so we only need (num_input_channels, 1)
117+
input_buffer.resize(num_input_channels, 1);
139118
}
140119

141120
~BlendingActivation() = default;
@@ -147,44 +126,37 @@ class BlendingActivation
147126
*/
148127
void apply(Eigen::MatrixXf& input, Eigen::MatrixXf& output)
149128
{
150-
// Validate input dimensions
151-
const int total_channels = num_input_channels + num_blend_channels;
152-
if (input.rows() != total_channels)
153-
{
154-
throw std::invalid_argument("BlendingActivation: input matrix must have " + std::to_string(total_channels)
155-
+ " rows");
156-
}
157-
158-
// Validate output dimensions
159-
if (output.rows() != num_input_channels || output.cols() != input.cols())
160-
{
161-
throw std::invalid_argument("BlendingActivation: output matrix must have " + std::to_string(num_input_channels)
162-
+ " rows and " + std::to_string(input.cols()) + " columns");
163-
}
129+
// Validate input dimensions (assert for real-time performance)
130+
const int total_channels = num_input_channels * 2; // 2*channels in, channels out
131+
assert(input.rows() == total_channels);
132+
assert(output.rows() == num_input_channels);
133+
assert(output.cols() == input.cols());
164134

165135
// Process column-by-column to ensure memory contiguity
166136
const int num_samples = input.cols();
167137
for (int i = 0; i < num_samples; i++)
168138
{
139+
// Store pre-activation input values in buffer
140+
input_buffer = input.block(0, i, num_input_channels, 1);
141+
169142
// Apply activation to input channels
170143
Eigen::MatrixXf input_block = input.block(0, i, num_input_channels, 1);
171144
input_activation->apply(input_block);
172145

173-
// Apply activation to blend channels
174-
Eigen::MatrixXf blend_block = input.block(num_input_channels, i, num_blend_channels, 1);
146+
// Apply activation to blend channels to compute alpha
147+
Eigen::MatrixXf blend_block = input.block(num_input_channels, i, num_input_channels, 1);
175148
blending_activation->apply(blend_block);
176149

177-
// Weighted blending
178-
// For wavenet compatibility, we assume one-to-one mapping
179-
assert(num_input_channels == num_blend_channels);
180-
output.block(0, i, num_input_channels, 1) = alpha * input_block + (1.0f - alpha) * blend_block;
150+
// Weighted blending: alpha * activated_input + (1 - alpha) * pre_activation_input
151+
output.block(0, i, num_input_channels, 1) =
152+
blend_block.array() * input_block.array() + (1.0f - blend_block.array()) * input_buffer.array();
181153
}
182154
}
183155

184156
/**
185157
* Get the total number of input channels required
186158
*/
187-
int get_total_input_channels() const { return num_input_channels + num_blend_channels; }
159+
int get_total_input_channels() const { return 2 * num_input_channels; }
188160

189161
/**
190162
* Get the number of output channels
@@ -194,9 +166,8 @@ class BlendingActivation
194166
private:
195167
activations::Activation* input_activation;
196168
activations::Activation* blending_activation;
197-
float alpha;
198169
int num_input_channels;
199-
int num_blend_channels;
170+
Eigen::MatrixXf input_buffer;
200171
};
201172

202173

tools/run_tests.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include "test/test_fast_lut.cpp"
1010
#include "test/test_gating_activations.cpp"
1111
#include "test/test_wavenet_gating_compatibility.cpp"
12+
#include "test/test_blending_detailed.cpp"
13+
#include "test/test_input_buffer_verification.cpp"
1214

1315
int main()
1416
{
@@ -46,18 +48,26 @@ int main()
4648
test_gating_activations::TestGatingActivation::test_with_custom_activations();
4749
test_gating_activations::TestGatingActivation::test_error_handling();
4850

49-
test_gating_activations::TestBlendingActivation::test_basic_functionality();
50-
test_gating_activations::TestBlendingActivation::test_different_alpha_values();
51-
test_gating_activations::TestBlendingActivation::test_with_custom_activations();
52-
test_gating_activations::TestBlendingActivation::test_error_handling();
53-
test_gating_activations::TestBlendingActivation::test_edge_cases();
54-
5551
// Wavenet gating compatibility tests
5652
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_wavenet_style_gating();
5753
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_column_by_column_processing();
5854
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_memory_contiguity();
5955
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_multiple_channels();
6056

57+
test_gating_activations::TestBlendingActivation::test_basic_functionality();
58+
test_gating_activations::TestBlendingActivation::test_blending_behavior();
59+
test_gating_activations::TestBlendingActivation::test_with_custom_activations();
60+
test_gating_activations::TestBlendingActivation::test_error_handling();
61+
test_gating_activations::TestBlendingActivation::test_edge_cases();
62+
63+
// Detailed blending tests
64+
test_blending_detailed::TestBlendingDetailed::test_blending_with_different_activations();
65+
test_blending_detailed::TestBlendingDetailed::test_input_buffer_usage();
66+
67+
// Input buffer verification tests
68+
test_input_buffer_verification::TestInputBufferVerification::test_buffer_stores_pre_activation_values();
69+
test_input_buffer_verification::TestInputBufferVerification::test_buffer_with_different_activations();
70+
6171
std::cout << "Success!" << std::endl;
6272
return 0;
6373
}

0 commit comments

Comments
 (0)