@@ -14,16 +14,16 @@ namespace gating_activations
1414{
1515
1616// Default linear activation (identity function)
17- class LinearActivation : public nam ::activations::Activation
17+ class IdentityActivation : public nam ::activations::Activation
1818{
1919public:
20- LinearActivation () = default ;
21- ~LinearActivation () = default ;
20+ IdentityActivation () = default ;
21+ ~IdentityActivation () = default ;
2222 // Inherit the default apply methods which do nothing (linear/identity)
2323};
2424
2525// Static instance for default activation
26- static LinearActivation default_activation;
26+ static IdentityActivation default_activation;
2727
2828class GatingActivation
2929{
@@ -40,12 +40,8 @@ class GatingActivation
4040 : input_activation(input_act ? input_act : &default_activation)
4141 , gating_activation(gating_act ? gating_act : activations::Activation::get_activation(" Sigmoid" ))
4242 , num_input_channels(input_channels)
43- , num_gating_channels(gating_channels)
4443 {
45- if (num_input_channels <= 0 || num_gating_channels <= 0 )
46- {
47- throw std::invalid_argument (" GatingActivation: number of channels must be positive" );
48- }
44+ assert (num_input_channels > 0 );
4945 }
5046
5147 ~GatingActivation () = default ;
@@ -57,20 +53,11 @@ class GatingActivation
5753 */
5854 void apply (Eigen::MatrixXf& input, Eigen::MatrixXf& output)
5955 {
60- // Validate input dimensions
61- const int total_channels = num_input_channels + num_gating_channels;
62- if (input.rows () != total_channels)
63- {
64- throw std::invalid_argument (" GatingActivation: input matrix must have " + std::to_string (total_channels)
65- + " rows" );
66- }
67-
68- // Validate output dimensions
69- if (output.rows () != num_input_channels || output.cols () != input.cols ())
70- {
71- throw std::invalid_argument (" GatingActivation: output matrix must have " + std::to_string (num_input_channels)
72- + " rows and " + std::to_string (input.cols ()) + " columns" );
73- }
56+ // Validate input dimensions (assert for real-time performance)
57+ const int total_channels = 2 * num_input_channels;
58+ assert (input.rows () == total_channels);
59+ assert (output.rows () == num_input_channels);
60+ assert (output.cols () == input.cols ());
7461
7562 // Process column-by-column to ensure memory contiguity (important for column-major matrices)
7663 const int num_samples = input.cols ();
@@ -81,20 +68,19 @@ class GatingActivation
8168 input_activation->apply (input_block);
8269
8370 // Apply activation to gating channels
84- Eigen::MatrixXf gating_block = input.block (num_input_channels, i, num_gating_channels , 1 );
71+ Eigen::MatrixXf gating_block = input.block (num_input_channels, i, num_input_channels , 1 );
8572 gating_activation->apply (gating_block);
8673
8774 // Element-wise multiplication and store result
8875 // For wavenet compatibility, we assume one-to-one mapping
89- assert (num_input_channels == num_gating_channels);
9076 output.block (0 , i, num_input_channels, 1 ) = input_block.array () * gating_block.array ();
9177 }
9278 }
9379
9480 /* *
9581 * Get the total number of input channels required
9682 */
97- int get_total_input_channels () const { return num_input_channels + num_gating_channels ; }
83+ int get_total_input_channels () const { return 2 * num_input_channels ; }
9884
9985 /* *
10086 * Get the number of output channels
@@ -105,7 +91,6 @@ class GatingActivation
10591 activations::Activation* input_activation;
10692 activations::Activation* gating_activation;
10793 int num_input_channels;
108- int num_gating_channels;
10994};
11095
11196class BlendingActivation
@@ -115,27 +100,21 @@ class BlendingActivation
115100 * Constructor for BlendingActivation
116101 * @param input_act Activation function for input channels
117102 * @param blend_act Activation function for blending channels
118- * @param alpha_val Blending factor (0.0 to 1.0)
119103 * @param input_channels Number of input channels
120- * @param blend_channels Number of blending channels
121104 */
122105 BlendingActivation (activations::Activation* input_act = nullptr , activations::Activation* blend_act = nullptr ,
123- float alpha_val = 0 . 5f , int input_channels = 1 , int blend_channels = 1 )
106+ int input_channels = 1 )
124107 : input_activation(input_act ? input_act : &default_activation)
125108 , blending_activation(blend_act ? blend_act : &default_activation)
126- , alpha(alpha_val)
127109 , num_input_channels(input_channels)
128- , num_blend_channels(blend_channels)
129110 {
130- // Validate alpha is in valid range
131- if (alpha < 0 .0f || alpha > 1 .0f )
132- {
133- throw std::invalid_argument (" BlendingActivation: alpha must be between 0.0 and 1.0" );
134- }
135- if (num_input_channels <= 0 || num_blend_channels <= 0 )
111+ if (num_input_channels <= 0 )
136112 {
137- throw std::invalid_argument (" BlendingActivation: number of channels must be positive" );
113+ throw std::invalid_argument (" BlendingActivation: number of input channels must be positive" );
138114 }
115+ // Initialize input buffer with correct size
116+ // Note: current code copies column-by-column so we only need (num_input_channels, 1)
117+ input_buffer.resize (num_input_channels, 1 );
139118 }
140119
141120 ~BlendingActivation () = default ;
@@ -147,44 +126,37 @@ class BlendingActivation
147126 */
148127 void apply (Eigen::MatrixXf& input, Eigen::MatrixXf& output)
149128 {
150- // Validate input dimensions
151- const int total_channels = num_input_channels + num_blend_channels;
152- if (input.rows () != total_channels)
153- {
154- throw std::invalid_argument (" BlendingActivation: input matrix must have " + std::to_string (total_channels)
155- + " rows" );
156- }
157-
158- // Validate output dimensions
159- if (output.rows () != num_input_channels || output.cols () != input.cols ())
160- {
161- throw std::invalid_argument (" BlendingActivation: output matrix must have " + std::to_string (num_input_channels)
162- + " rows and " + std::to_string (input.cols ()) + " columns" );
163- }
129+ // Validate input dimensions (assert for real-time performance)
130+ const int total_channels = num_input_channels * 2 ; // 2*channels in, channels out
131+ assert (input.rows () == total_channels);
132+ assert (output.rows () == num_input_channels);
133+ assert (output.cols () == input.cols ());
164134
165135 // Process column-by-column to ensure memory contiguity
166136 const int num_samples = input.cols ();
167137 for (int i = 0 ; i < num_samples; i++)
168138 {
139+ // Store pre-activation input values in buffer
140+ input_buffer = input.block (0 , i, num_input_channels, 1 );
141+
169142 // Apply activation to input channels
170143 Eigen::MatrixXf input_block = input.block (0 , i, num_input_channels, 1 );
171144 input_activation->apply (input_block);
172145
173- // Apply activation to blend channels
174- Eigen::MatrixXf blend_block = input.block (num_input_channels, i, num_blend_channels , 1 );
146+ // Apply activation to blend channels to compute alpha
147+ Eigen::MatrixXf blend_block = input.block (num_input_channels, i, num_input_channels , 1 );
175148 blending_activation->apply (blend_block);
176149
177- // Weighted blending
178- // For wavenet compatibility, we assume one-to-one mapping
179- assert (num_input_channels == num_blend_channels);
180- output.block (0 , i, num_input_channels, 1 ) = alpha * input_block + (1 .0f - alpha) * blend_block;
150+ // Weighted blending: alpha * activated_input + (1 - alpha) * pre_activation_input
151+ output.block (0 , i, num_input_channels, 1 ) =
152+ blend_block.array () * input_block.array () + (1 .0f - blend_block.array ()) * input_buffer.array ();
181153 }
182154 }
183155
184156 /* *
185157 * Get the total number of input channels required
186158 */
187- int get_total_input_channels () const { return num_input_channels + num_blend_channels ; }
159+ int get_total_input_channels () const { return 2 * num_input_channels ; }
188160
189161 /* *
190162 * Get the number of output channels
@@ -194,9 +166,8 @@ class BlendingActivation
194166private:
195167 activations::Activation* input_activation;
196168 activations::Activation* blending_activation;
197- float alpha;
198169 int num_input_channels;
199- int num_blend_channels ;
170+ Eigen::MatrixXf input_buffer ;
200171};
201172
202173
0 commit comments