Commit 8987a3c

Authored by sdatkinson and João Felipe Santos
[FEATURE] Integrate gating & blending activations into WaveNet (#193)
* Replaced manual gating/blending with Gating/BlendingActivation classes.
* Deleted debugging stuff and replaced throw with assert.
* Remove backward compatibility constructors from _Layer, LayerArrayParams, and _LayerArray
  - Removed boolean gated parameter constructors from _Layer class
  - Removed boolean gated parameter constructors from LayerArrayParams struct
  - Removed boolean gated parameter constructors from _LayerArray class
  - Updated all test files to use GatingMode enum instead of bool gated
  - Backward compatibility logic retained in Factory function for JSON parsing
  - All tests pass after updating to use GatingMode enum
* Formatting
* Enhance _Layer constructor to validate gating and blending activations
  - Added validation checks in the _Layer constructor to ensure that gating and blending activations are not provided when the gating mode is set to NONE.
  - Updated the constructor signature to remove default values for gating and blending activations, requiring explicit values during instantiation.
  - Adjusted related test cases to reflect the new constructor requirements, ensuring comprehensive coverage of the updated functionality.
  - Improved error handling by throwing std::invalid_argument exceptions for invalid activation configurations.
* Refactor _Layer constructor to use secondary_activation parameter
  - Updated the _Layer constructor to replace gating and blending activation parameters with a single secondary_activation parameter.
  - Added validation to ensure secondary_activation is provided for GATED and BLENDED modes.
  - Adjusted related test cases to reflect the new constructor signature and ensure proper functionality.
  - Removed default values for gating and blending activations, enforcing explicit activation configuration during instantiation.
* Note on gating/blending activations

---------

Co-authored-by: João Felipe Santos <santosjf@pm.me>
1 parent 57dfd6b · commit 8987a3c

File tree: 11 files changed, +520 -127 lines

NAM/gating_activations.h

Lines changed: 17 additions & 7 deletions
@@ -37,7 +37,13 @@ class GatingActivation
   , gating_activation(gating_act)
   , num_channels(input_channels)
   {
-    assert(num_channels > 0);
+    if (num_channels <= 0)
+    {
+      throw std::invalid_argument("GatingActivation: number of input channels must be positive");
+    }
+    // Initialize input buffer with correct size
+    // Note: current code copies column-by-column so we only need (num_channels, 1)
+    input_buffer.resize(num_channels, 1);
   }

   ~GatingActivation() = default;
@@ -47,7 +53,8 @@ class GatingActivation
    * @param input Input matrix with shape (input_channels + gating_channels) x num_samples
    * @param output Output matrix with shape input_channels x num_samples
    */
-  void apply(Eigen::MatrixXf& input, Eigen::MatrixXf& output)
+  template <typename InputDerived, typename OutputDerived>
+  void apply(const Eigen::MatrixBase<InputDerived>& input, Eigen::MatrixBase<OutputDerived>& output)
   {
     // Validate input dimensions (assert for real-time performance)
     const int total_channels = 2 * num_channels;
@@ -59,6 +66,9 @@ class GatingActivation
     const int num_samples = input.cols();
     for (int i = 0; i < num_samples; i++)
     {
+      // Store pre-activation input values in buffer to avoid overwriting issues
+      input_buffer = input.block(0, i, num_channels, 1);
+
       // Apply activation to input channels
       Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1);
       input_activation->apply(input_block);
@@ -87,6 +97,7 @@ class GatingActivation
   activations::Activation* input_activation;
   activations::Activation* gating_activation;
   int num_channels;
+  Eigen::MatrixXf input_buffer;
 };

 class BlendingActivation
@@ -103,10 +114,8 @@ class BlendingActivation
   , blending_activation(blend_act)
   , num_channels(input_channels)
   {
-    if (num_channels <= 0)
-    {
-      throw std::invalid_argument("BlendingActivation: number of input channels must be positive");
-    }
+    assert(num_channels > 0);
+
     // Initialize input buffer with correct size
     // Note: current code copies column-by-column so we only need (num_channels, 1)
     input_buffer.resize(num_channels, 1);
@@ -119,7 +128,8 @@ class BlendingActivation
    * @param input Input matrix with shape (input_channels + blend_channels) x num_samples
    * @param output Output matrix with shape input_channels x num_samples
    */
-  void apply(Eigen::MatrixXf& input, Eigen::MatrixXf& output)
+  template <typename InputDerived, typename OutputDerived>
+  void apply(const Eigen::MatrixBase<InputDerived>& input, Eigen::MatrixBase<OutputDerived>& output)
   {
     // Validate input dimensions (assert for real-time performance)
     const int total_channels = num_channels * 2; // 2*channels in, channels out
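Both classes share the same channel layout: the top half of the input rows feeds the primary activation and the bottom half feeds the secondary one. Gating multiplies the two activated halves element-wise; blending instead forms a weighted average of them (per the GatingMode comments in wavenet.h). Below is a minimal standalone sketch of the gated computation, assuming Tanh as the primary and Sigmoid as the secondary activation for concreteness; gate_columns is an illustrative helper, not part of the library:

#include <Eigen/Dense>

// Sketch only: out.col(i) = tanh(top half) * sigmoid(bottom half), element-wise.
Eigen::MatrixXf gate_columns(const Eigen::MatrixXf& z)
{
  const Eigen::Index n = z.rows() / 2; // top n rows: primary; bottom n rows: gate
  Eigen::MatrixXf out(n, z.cols());
  for (Eigen::Index i = 0; i < z.cols(); i++) // column-wise: contiguous in column-major storage
  {
    const Eigen::ArrayXf primary = z.block(0, i, n, 1).array().tanh();
    const Eigen::ArrayXf gate = 1.0f / (1.0f + (-z.block(n, i, n, 1).array()).exp());
    out.col(i) = (primary * gate).matrix();
  }
  return out;
}

The column-by-column loop mirrors the headers' buffer comments: for Eigen's default column-major storage, columns are memory-contiguous while row blocks are not.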

NAM/wavenet.cpp

Lines changed: 78 additions & 25 deletions
@@ -58,29 +58,38 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
     this->_conv.GetOutput().leftCols(num_frames) + _input_mixin.GetOutput().leftCols(num_frames);

   // Step 2 & 3: activation and 1x1
-  if (!this->_gated)
+  //
+  // A note about the gating/blending activations:
+  // They take 2x dimension as input.
+  // The top channels are for the "primary" activation and will be in-place modified for the final result.
+  // The bottom channels are for the "secondary" activation and should not be used post-activation.
+  if (this->_gating_mode == GatingMode::NONE)
   {
     this->_activation->apply(this->_z.leftCols(num_frames));
     _1x1.process_(_z, num_frames);
   }
-  else
+  else if (this->_gating_mode == GatingMode::GATED)
   {
-    // CAREFUL: .topRows() and .bottomRows() won't be memory-contiguous for a column-major matrix (Issue 125). Need to
-    // do this column-wise:
-    for (int i = 0; i < num_frames; i++)
-    {
-      this->_activation->apply(this->_z.block(0, i, bottleneck, 1));
-      // TODO Need to support other activation functions here instead of hardcoded sigmoid
-      activations::Activation::get_activation("Sigmoid")->apply(this->_z.block(bottleneck, i, bottleneck, 1));
-    }
-    this->_z.block(0, 0, bottleneck, num_frames).array() *=
-      this->_z.block(bottleneck, 0, bottleneck, num_frames).array();
-    _1x1.process_(_z.topRows(bottleneck), num_frames); // Might not be RT safe
+    // Use the GatingActivation class
+    // Extract the blocks first to avoid temporary reference issues
+    auto input_block = this->_z.leftCols(num_frames);
+    auto output_block = this->_z.topRows(bottleneck).leftCols(num_frames);
+    this->_gating_activation->apply(input_block, output_block);
+    _1x1.process_(this->_z.topRows(bottleneck), num_frames);
+  }
+  else if (this->_gating_mode == GatingMode::BLENDED)
+  {
+    // Use the BlendingActivation class
+    // Extract the blocks first to avoid temporary reference issues
+    auto input_block = this->_z.leftCols(num_frames);
+    auto output_block = this->_z.topRows(bottleneck).leftCols(num_frames);
+    this->_blending_activation->apply(input_block, output_block);
+    _1x1.process_(this->_z.topRows(bottleneck), num_frames);
   }

   if (this->_head1x1)
   {
-    if (!this->_gated)
+    if (this->_gating_mode == GatingMode::NONE)
       this->_head1x1->process_(this->_z.leftCols(num_frames), num_frames);
     else
       this->_head1x1->process(this->_z.topRows(bottleneck).leftCols(num_frames), num_frames);
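To make the in-place layout in the note above concrete, here is a toy shape walk-through (numbers made up for illustration; this is not library code):

// Suppose bottleneck = 4 and num_frames = 8, so _z carries 2 * bottleneck = 8 rows.
Eigen::MatrixXf z = Eigen::MatrixXf::Random(8, 8); // stands in for this->_z
auto input_block = z.leftCols(8);                  // all 8 rows: primary on top, secondary below
auto output_block = z.topRows(4).leftCols(8);      // result overwrites the top rows in place
// _gating_activation->apply(input_block, output_block);
// Afterwards only z.topRows(4) is meaningful; that block is what _1x1 consumes.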
@@ -89,7 +98,7 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
   else
   {
     // Store output to head (skip connection: activated conv output)
-    if (!this->_gated)
+    if (this->_gating_mode == GatingMode::NONE)
       this->_output_head.leftCols(num_frames).noalias() = this->_z.leftCols(num_frames);
     else
       this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(bottleneck).leftCols(num_frames);
@@ -105,15 +114,16 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
 nam::wavenet::_LayerArray::_LayerArray(const int input_size, const int condition_size, const int head_size,
                                        const int channels, const int bottleneck, const int kernel_size,
                                        const std::vector<int>& dilations, const std::string activation,
-                                       const bool gated, const bool head_bias, const int groups_input,
-                                       const int groups_1x1, const Head1x1Params& head1x1_params)
+                                       const GatingMode gating_mode, const bool head_bias, const int groups_input,
+                                       const int groups_1x1, const Head1x1Params& head1x1_params,
+                                       const std::string& secondary_activation)
 : _rechannel(input_size, channels, false)
 , _head_rechannel(bottleneck, head_size, head_bias)
 , _bottleneck(bottleneck)
 {
   for (size_t i = 0; i < dilations.size(); i++)
-    this->_layers.push_back(_Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation, gated,
-                                   groups_input, groups_1x1, head1x1_params));
+    this->_layers.push_back(_Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation,
+                                   gating_mode, groups_input, groups_1x1, head1x1_params, secondary_activation));
 }

 void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize)
@@ -263,9 +273,9 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels,
     this->_layer_arrays.push_back(nam::wavenet::_LayerArray(
       layer_array_params[i].input_size, layer_array_params[i].condition_size, layer_array_params[i].head_size,
       layer_array_params[i].channels, layer_array_params[i].bottleneck, layer_array_params[i].kernel_size,
-      layer_array_params[i].dilations, layer_array_params[i].activation, layer_array_params[i].gated,
+      layer_array_params[i].dilations, layer_array_params[i].activation, layer_array_params[i].gating_mode,
       layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_1x1,
-      layer_array_params[i].head1x1_params));
+      layer_array_params[i].head1x1_params, layer_array_params[i].secondary_activation));
     if (i > 0)
       if (layer_array_params[i].channels != layer_array_params[i - 1].head_size)
       {
@@ -468,7 +478,50 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
     const int kernel_size = layer_config["kernel_size"];
     const auto dilations = layer_config["dilations"];
     const std::string activation = layer_config["activation"].get<std::string>();
-    const bool gated = layer_config["gated"];
+    // Parse gating mode - support both old "gated" boolean and new "gating_mode" string
+    GatingMode gating_mode = GatingMode::NONE;
+    std::string secondary_activation;
+
+    if (layer_config.find("gating_mode") != layer_config.end())
+    {
+      std::string gating_mode_str = layer_config["gating_mode"].get<std::string>();
+      if (gating_mode_str == "gated")
+      {
+        gating_mode = GatingMode::GATED;
+        secondary_activation = layer_config["secondary_activation"].get<std::string>();
+      }
+      else if (gating_mode_str == "blended")
+      {
+        gating_mode = GatingMode::BLENDED;
+        secondary_activation = layer_config["secondary_activation"].get<std::string>();
+      }
+      else if (gating_mode_str == "none")
+      {
+        gating_mode = GatingMode::NONE;
+        secondary_activation.clear();
+      }
+      else
+        throw std::runtime_error("Invalid gating_mode: " + gating_mode_str);
+    }
+    else if (layer_config.find("gated") != layer_config.end())
+    {
+      // Backward compatibility: convert old "gated" boolean to new enum
+      bool gated = layer_config["gated"];
+      gating_mode = gated ? GatingMode::GATED : GatingMode::NONE;
+      if (gated)
+      {
+        secondary_activation = "Sigmoid";
+      }
+      else
+      {
+        secondary_activation.clear();
+      }
+    }
+    else
+    {
+      throw std::invalid_argument("No information on gating mode found for layer array " + std::to_string(i));
+    }
+
     const bool head_bias = layer_config["head_bias"];

     // Parse head1x1 parameters
@@ -477,9 +530,9 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
     int head1x1_groups = layer_config.value("head1x1_groups", 1);
     nam::wavenet::Head1x1Params head1x1_params(head1x1_active, head1x1_out_channels, head1x1_groups);

-    layer_array_params.push_back(nam::wavenet::LayerArrayParams(input_size, condition_size, head_size, channels,
-                                                                bottleneck, kernel_size, dilations, activation, gated,
-                                                                head_bias, groups, groups_1x1, head1x1_params));
+    layer_array_params.push_back(nam::wavenet::LayerArrayParams(
+      input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, activation, gating_mode,
+      head_bias, groups, groups_1x1, head1x1_params, secondary_activation));
   }
   const bool with_head = !config["head"].is_null();
   const float head_scale = config["head_scale"];
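Given the parsing above, a layer-array entry in the model JSON can declare its mode in either form. A hypothetical new-style excerpt (values illustrative; other required keys such as input_size, condition_size, and channels omitted):

{
  "activation": "Tanh",
  "gating_mode": "blended",
  "secondary_activation": "Sigmoid",
  "kernel_size": 3,
  "dilations": [1, 2, 4, 8],
  "head_bias": false
}

The legacy form still parses: "gated": true maps to GatingMode::GATED with a "Sigmoid" secondary activation, "gated": false maps to NONE, and a config carrying neither "gating_mode" nor "gated" now throws std::invalid_argument.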

NAM/wavenet.h

Lines changed: 58 additions & 12 deletions
@@ -3,17 +3,33 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include <stdexcept>

 #include "json.hpp"
 #include <Eigen/Dense>

 #include "dsp.h"
 #include "conv1d.h"
+#include "gating_activations.h"

 namespace nam
 {
 namespace wavenet
 {
+
+// Gating mode for WaveNet layers
+enum class GatingMode
+{
+  NONE, // No gating or blending
+  GATED, // Traditional gating (element-wise multiplication)
+  BLENDED // Blending (weighted average)
+};
+
+// Helper function for backward compatibility with boolean gated parameter
+inline GatingMode gating_mode_from_bool(bool gated)
+{
+  return gated ? GatingMode::GATED : GatingMode::NONE;
+}
 // Parameters for head1x1 configuration
 struct Head1x1Params
 {
@@ -32,20 +48,42 @@ struct Head1x1Params
 class _Layer
 {
 public:
+  // New constructor with GatingMode enum and configurable activations
   _Layer(const int condition_size, const int channels, const int bottleneck, const int kernel_size, const int dilation,
-         const std::string activation, const bool gated, const int groups_input, const int groups_1x1,
-         const Head1x1Params& head1x1_params)
-  : _conv(channels, gated ? 2 * bottleneck : bottleneck, kernel_size, true, dilation)
-  , _input_mixin(condition_size, gated ? 2 * bottleneck : bottleneck, false)
+         const std::string activation, const GatingMode gating_mode, const int groups_input, const int groups_1x1,
+         const Head1x1Params& head1x1_params, const std::string& secondary_activation)
+  : _conv(channels, (gating_mode != GatingMode::NONE) ? 2 * bottleneck : bottleneck, kernel_size, true, dilation)
+  , _input_mixin(condition_size, (gating_mode != GatingMode::NONE) ? 2 * bottleneck : bottleneck, false)
   , _1x1(bottleneck, channels, groups_1x1)
   , _activation(activations::Activation::get_activation(activation)) // needs to support activations with parameters
-  , _gated(gated)
+  , _gating_mode(gating_mode)
   , _bottleneck(bottleneck)
   {
     if (head1x1_params.active)
     {
       _head1x1 = std::make_unique<Conv1x1>(bottleneck, head1x1_params.out_channels, true, head1x1_params.groups);
     }
+
+    // Validate & initialize gating/blending activation
+    if (gating_mode == GatingMode::GATED)
+    {
+      if (secondary_activation.empty())
+        throw std::invalid_argument("secondary_activation must be provided for gated mode");
+      _gating_activation = std::make_unique<gating_activations::GatingActivation>(
+        _activation, activations::Activation::get_activation(secondary_activation), bottleneck);
+    }
+    else if (gating_mode == GatingMode::BLENDED)
+    {
+      if (secondary_activation.empty())
+        throw std::invalid_argument("secondary_activation must be provided for blended mode");
+      _blending_activation = std::make_unique<gating_activations::BlendingActivation>(
+        _activation, activations::Activation::get_activation(secondary_activation), bottleneck);
+    }
+    else
+    {
+      if (!secondary_activation.empty())
+        throw std::invalid_argument("secondary_activation provided for none mode");
+    }
   };

   // Resize all arrays to be able to process `maxBufferSize` frames.
@@ -97,17 +135,22 @@ class _Layer
   Eigen::MatrixXf _output_head;

   activations::Activation* _activation;
-  const bool _gated;
+  const GatingMode _gating_mode;
   const int _bottleneck; // Internal channel count (not doubled when gated)
+
+  // Gating/blending activation objects
+  std::unique_ptr<gating_activations::GatingActivation> _gating_activation;
+  std::unique_ptr<gating_activations::BlendingActivation> _blending_activation;
 };

 class LayerArrayParams
 {
 public:
   LayerArrayParams(const int input_size_, const int condition_size_, const int head_size_, const int channels_,
                    const int bottleneck_, const int kernel_size_, const std::vector<int>&& dilations_,
-                   const std::string activation_, const bool gated_, const bool head_bias_, const int groups_input,
-                   const int groups_1x1_, const Head1x1Params& head1x1_params_)
+                   const std::string activation_, const GatingMode gating_mode_, const bool head_bias_,
+                   const int groups_input, const int groups_1x1_, const Head1x1Params& head1x1_params_,
+                   const std::string& secondary_activation_)
   : input_size(input_size_)
   , condition_size(condition_size_)
   , head_size(head_size_)
@@ -116,11 +159,12 @@ class LayerArrayParams
   , kernel_size(kernel_size_)
   , dilations(std::move(dilations_))
   , activation(activation_)
-  , gated(gated_)
+  , gating_mode(gating_mode_)
   , head_bias(head_bias_)
   , groups_input(groups_input)
   , groups_1x1(groups_1x1_)
   , head1x1_params(head1x1_params_)
+  , secondary_activation(secondary_activation_)
   {
   }

@@ -132,21 +176,23 @@ class LayerArrayParams
   const int kernel_size;
   std::vector<int> dilations;
   const std::string activation;
-  const bool gated;
+  const GatingMode gating_mode;
   const bool head_bias;
   const int groups_input;
   const int groups_1x1;
   const Head1x1Params head1x1_params;
+  const std::string secondary_activation;
 };

 // An array of layers with the same channels, kernel sizes, activations.
 class _LayerArray
 {
 public:
+  // New constructor with GatingMode enum and configurable activations
   _LayerArray(const int input_size, const int condition_size, const int head_size, const int channels,
               const int bottleneck, const int kernel_size, const std::vector<int>& dilations,
-              const std::string activation, const bool gated, const bool head_bias, const int groups_input,
-              const int groups_1x1, const Head1x1Params& head1x1_params);
+              const std::string activation, const GatingMode gating_mode, const bool head_bias, const int groups_input,
+              const int groups_1x1, const Head1x1Params& head1x1_params, const std::string& secondary_activation);

   void SetMaxBufferSize(const int maxBufferSize);

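For reference, a hypothetical construction against the new signatures; every argument value below is illustrative, and "Tanh"/"Sigmoid" are assumed to be registered activation names:

#include <vector>
#include "wavenet.h" // path as used within the NAM sources

std::vector<int> dilations{1, 2, 4, 8};
nam::wavenet::Head1x1Params head1x1_params(/*active=*/false, /*out_channels=*/1, /*groups=*/1);
nam::wavenet::_LayerArray layer_array(
  /*input_size=*/1, /*condition_size=*/1, /*head_size=*/8, /*channels=*/16,
  /*bottleneck=*/16, /*kernel_size=*/3, dilations, /*activation=*/"Tanh",
  nam::wavenet::GatingMode::GATED, /*head_bias=*/false, /*groups_input=*/1,
  /*groups_1x1=*/1, head1x1_params, /*secondary_activation=*/"Sigmoid");
// GatingMode::NONE requires an empty secondary_activation; GATED and BLENDED
// require a non-empty one. Violating either rule throws std::invalid_argument.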
tools/run_tests.cpp

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,7 @@
 #include "test/test_blending_detailed.cpp"
 #include "test/test_input_buffer_verification.cpp"
 #include "test/test_lstm.cpp"
+#include "test/test_wavenet_configurable_gating.cpp"

 int main()
 {
@@ -177,6 +178,9 @@ int main()
   test_input_buffer_verification::TestInputBufferVerification::test_buffer_stores_pre_activation_values();
   test_input_buffer_verification::TestInputBufferVerification::test_buffer_with_different_activations();

+  // Configurable gating/blending tests
+  run_configurable_gating_tests();
+
   test_get_dsp::test_gets_input_level();
   test_get_dsp::test_gets_output_level();
   test_get_dsp::test_null_input_level();
