
Commit 8031a06

[FEATURE] bottlenecks in WaveNet layers (#185)
* [REFINE] Update WaveNet Layer and LayerArray constructors to include bottleneck parameter
  - Modified the constructors of _Layer and _LayerArray to accept a new bottleneck parameter, enhancing the flexibility of the layer configurations.
  - Updated relevant method calls and test cases to reflect the new parameter, ensuring consistency across the codebase.
  - Adjusted JSON configuration handling to support the bottleneck parameter, maintaining backward compatibility with existing configurations.
* [FEATURE] Introduce bottleneck parameter in Layer and LayerArray tests
  - Added a bottleneck parameter to the constructors of _Layer and _LayerArray in various test cases, enhancing the flexibility of layer configurations.
  - Implemented new test cases for layers with bottleneck configurations, including both gated and non-gated scenarios.
  - Updated existing tests to utilize the bottleneck parameter, ensuring comprehensive coverage and consistency across the codebase.
* [REFINE] Temporarily disable bottleneck layer tests and update weight initialization comments
  - Commented out tests for bottleneck and gated bottleneck layers in run_tests.cpp while investigating a resize error.
  - Updated weight initialization logic in test_layer.cpp to clarify the layout for Conv1D and 1x1 convolutions, ensuring consistency with the new bottleneck parameter.
  - Adjusted comments for better clarity on weight patterns and dimensions in the test cases.
* [REFINE] Update WaveNet Layer to utilize bottleneck parameter (a toy dimension-flow sketch follows this message)
  - Adjusted the WaveNet Layer's SetMaxBufferSize and Process methods to correctly use the bottleneck parameter for resizing internal buffers.
  - Updated the handling of activation functions to ensure they operate on the correct number of channels based on the bottleneck.
  - Modified test cases to reflect changes in the Layer constructor and ensure proper functionality with the bottleneck configuration.
  - Enhanced comments for clarity regarding the internal channel structure and weight initialization in tests.
* [REFINE] Update headInput resizing in WaveNet layer test
  - Modified the headInput matrix resizing in test_layer.cpp to use the bottleneck parameter instead of channels, aligning it with the updated WaveNet layer configuration.
  - This keeps the test cases consistent with the architecture that incorporates the bottleneck parameter.
* Remove unused variable
* Add test for Layer::Process() with bottleneck configuration
  - Introduced a new test case, test_layer_bottleneck_process_realtime_safe(), to validate that Layer::Process() operates correctly when the bottleneck parameter differs from the number of channels.
  - Ensured that the test checks for memory allocation during processing, maintaining real-time safety.
  - Updated run_tests.cpp to include this new test, enhancing coverage for bottleneck scenarios in the WaveNet layer.
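For orientation, the internal channel flow this commit describes can be pictured with a toy example. This is a minimal, self-contained sketch (made-up sizes, tanh standing in for the configurable activation), not the library's Process() implementation:

```cpp
// Toy illustration only: shows how `bottleneck` narrows the gated path inside a layer.
#include <Eigen/Dense>

int main()
{
  const int channels = 8;    // residual width entering and leaving the layer
  const int bottleneck = 4;  // internal width; defaults to `channels` when not specified
  const bool gated = true;
  const int num_frames = 16;

  // The dilated conv (plus input mixin) emits 2*bottleneck rows when gated, bottleneck otherwise.
  const int z_rows = gated ? 2 * bottleneck : bottleneck;
  Eigen::MatrixXf z = Eigen::MatrixXf::Random(z_rows, num_frames);

  if (gated)
  {
    // Activation on the top half (tanh here for illustration), sigmoid gate on the
    // bottom half, then an element-wise product.
    Eigen::ArrayXXf top = z.topRows(bottleneck).array().tanh();
    Eigen::ArrayXXf gate = (1.0f + (-z.bottomRows(bottleneck).array()).exp()).inverse();
    z.topRows(bottleneck) = (top * gate).matrix();
  }

  // The 1x1 projects bottleneck rows back to `channels` for the residual path;
  // the head/skip connection receives only the bottleneck rows.
  Eigen::MatrixXf w_1x1 = Eigen::MatrixXf::Random(channels, bottleneck);
  Eigen::MatrixXf to_next_layer = w_1x1 * z.topRows(bottleneck); // channels x num_frames
  Eigen::MatrixXf to_head = z.topRows(bottleneck);               // bottleneck x num_frames
  return (to_next_layer.rows() == channels && to_head.rows() == bottleneck) ? 0 : 1;
}
```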
1 parent e3e5154 commit 8031a06


8 files changed: +340 -68 lines changed


NAM/wavenet.cpp

Lines changed: 26 additions & 19 deletions
@@ -13,12 +13,14 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize)
 {
   _conv.SetMaxBufferSize(maxBufferSize);
   _input_mixin.SetMaxBufferSize(maxBufferSize);
-  _z.resize(this->_conv.get_out_channels(), maxBufferSize);
+  const long z_channels = this->_conv.get_out_channels(); // This is 2*bottleneck when gated, bottleneck when not
+  _z.resize(z_channels, maxBufferSize);
   _1x1.SetMaxBufferSize(maxBufferSize);
   // Pre-allocate output buffers
   const long channels = this->get_channels();
   this->_output_next_layer.resize(channels, maxBufferSize);
-  this->_output_head.resize(channels, maxBufferSize);
+  // _output_head stores the activated portion: bottleneck rows (the actual bottleneck value, not doubled)
+  this->_output_head.resize(this->_bottleneck, maxBufferSize);
 }

 void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights)
@@ -30,7 +32,7 @@ void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights)

 void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames)
 {
-  const long channels = this->get_channels();
+  const long bottleneck = this->_bottleneck; // Use the actual bottleneck value, not the doubled output channels

   // Step 1: input convolutions
   this->_conv.Process(input, num_frames);
@@ -50,19 +52,20 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
     // do this column-wise:
     for (int i = 0; i < num_frames; i++)
     {
-      this->_activation->apply(this->_z.block(0, i, channels, 1));
+      this->_activation->apply(this->_z.block(0, i, bottleneck, 1));
       // TODO Need to support other activation functions here instead of hardcoded sigmoid
-      activations::Activation::get_activation("Sigmoid")->apply(this->_z.block(channels, i, channels, 1));
+      activations::Activation::get_activation("Sigmoid")->apply(this->_z.block(bottleneck, i, bottleneck, 1));
     }
-    this->_z.block(0, 0, channels, num_frames).array() *= this->_z.block(channels, 0, channels, num_frames).array();
-    _1x1.process_(_z.topRows(channels), num_frames); // Might not be RT safe
+    this->_z.block(0, 0, bottleneck, num_frames).array() *=
+      this->_z.block(bottleneck, 0, bottleneck, num_frames).array();
+    _1x1.process_(_z.topRows(bottleneck), num_frames); // Might not be RT safe
   }

   // Store output to head (skip connection: activated conv output)
   if (!this->_gated)
     this->_output_head.leftCols(num_frames).noalias() = this->_z.leftCols(num_frames);
   else
-    this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(channels).leftCols(num_frames);
+    this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(bottleneck).leftCols(num_frames);
   // Store output to next layer (residual connection: input + _1x1 output)
   this->_output_next_layer.leftCols(num_frames).noalias() =
     input.leftCols(num_frames) + _1x1.GetOutput().leftCols(num_frames);
@@ -72,15 +75,17 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
 // LayerArray =================================================================

 nam::wavenet::_LayerArray::_LayerArray(const int input_size, const int condition_size, const int head_size,
-                                       const int channels, const int kernel_size, const std::vector<int>& dilations,
-                                       const std::string activation, const bool gated, const bool head_bias,
-                                       const int groups_input, const int groups_1x1)
+                                       const int channels, const int bottleneck, const int kernel_size,
+                                       const std::vector<int>& dilations, const std::string activation,
+                                       const bool gated, const bool head_bias, const int groups_input,
+                                       const int groups_1x1)
 : _rechannel(input_size, channels, false)
-, _head_rechannel(channels, head_size, head_bias)
+, _head_rechannel(bottleneck, head_size, head_bias)
+, _bottleneck(bottleneck)
 {
   for (size_t i = 0; i < dilations.size(); i++)
-    this->_layers.push_back(
-      _Layer(condition_size, channels, kernel_size, dilations[i], activation, gated, groups_input, groups_1x1));
+    this->_layers.push_back(_Layer(
+      condition_size, channels, bottleneck, kernel_size, dilations[i], activation, gated, groups_input, groups_1x1));
 }

 void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize)
@@ -94,7 +99,7 @@ void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize)
   // Pre-allocate output buffers
   const long channels = this->_get_channels();
   this->_layer_outputs.resize(channels, maxBufferSize);
-  this->_head_inputs.resize(channels, maxBufferSize);
+  this->_head_inputs.resize(this->_bottleneck, maxBufferSize);
 }


@@ -199,9 +204,9 @@ nam::wavenet::WaveNet::WaveNet(const std::vector<nam::wavenet::LayerArrayParams>
 {
   this->_layer_arrays.push_back(nam::wavenet::_LayerArray(
     layer_array_params[i].input_size, layer_array_params[i].condition_size, layer_array_params[i].head_size,
-    layer_array_params[i].channels, layer_array_params[i].kernel_size, layer_array_params[i].dilations,
-    layer_array_params[i].activation, layer_array_params[i].gated, layer_array_params[i].head_bias,
-    layer_array_params[i].groups_input, layer_array_params[i].groups_1x1));
+    layer_array_params[i].channels, layer_array_params[i].bottleneck, layer_array_params[i].kernel_size,
+    layer_array_params[i].dilations, layer_array_params[i].activation, layer_array_params[i].gated,
+    layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_1x1));
   if (i > 0)
     if (layer_array_params[i].channels != layer_array_params[i - 1].head_size)
     {
@@ -300,8 +305,10 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
     nlohmann::json layer_config = config["layers"][i];
     const int groups = layer_config.value("groups", 1); // defaults to 1
     const int groups_1x1 = layer_config.value("groups_1x1", 1); // defaults to 1
+    const int channels = layer_config["channels"];
+    const int bottleneck = layer_config.value("bottleneck", channels); // defaults to channels if not present
     layer_array_params.push_back(nam::wavenet::LayerArrayParams(
-      layer_config["input_size"], layer_config["condition_size"], layer_config["head_size"], layer_config["channels"],
+      layer_config["input_size"], layer_config["condition_size"], layer_config["head_size"], channels, bottleneck,
       layer_config["kernel_size"], layer_config["dilations"], layer_config["activation"], layer_config["gated"],
       layer_config["head_bias"], groups, groups_1x1));
   }
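
The Factory change above keeps existing model files loadable: a layer config without a "bottleneck" key behaves exactly as before. A minimal sketch of that defaulting behavior, using a made-up layer config with nlohmann::json (not the library's actual Factory code):

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

int main()
{
  // Hypothetical layer config; "bottleneck" is optional and, when absent,
  // falls back to "channels" (mirroring layer_config.value("bottleneck", channels)).
  const auto layer_config = nlohmann::json::parse(R"({
    "input_size": 1, "condition_size": 1, "head_size": 1,
    "channels": 16, "kernel_size": 3, "dilations": [1, 2, 4],
    "activation": "ReLU", "gated": false, "head_bias": false
  })");

  const int channels = layer_config["channels"];
  const int bottleneck = layer_config.value("bottleneck", channels); // defaults to channels
  std::cout << "channels=" << channels << " bottleneck=" << bottleneck << "\n";
  return 0;
}
```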

NAM/wavenet.h

Lines changed: 20 additions & 11 deletions
@@ -16,13 +16,14 @@ namespace wavenet
 class _Layer
 {
 public:
-  _Layer(const int condition_size, const int channels, const int kernel_size, const int dilation,
+  _Layer(const int condition_size, const int channels, const int bottleneck, const int kernel_size, const int dilation,
          const std::string activation, const bool gated, const int groups_input, const int groups_1x1)
-  : _conv(channels, gated ? 2 * channels : channels, kernel_size, true, dilation, groups_input)
-  , _input_mixin(condition_size, gated ? 2 * channels : channels, false)
-  , _1x1(channels, channels, true, groups_1x1)
+  : _conv(channels, gated ? 2 * bottleneck : bottleneck, kernel_size, true, dilation, groups_input)
+  , _input_mixin(condition_size, gated ? 2 * bottleneck : bottleneck, false)
+  , _1x1(bottleneck, channels, true, groups_1x1)
   , _activation(activations::Activation::get_activation(activation)) // needs to support activations with parameters
-  , _gated(gated) {};
+  , _gated(gated)
+  , _bottleneck(bottleneck) {};
   // Resize all arrays to be able to process `maxBufferSize` frames.
   void SetMaxBufferSize(const int maxBufferSize);
   // Set the parameters of this module
@@ -71,18 +72,21 @@ class _Layer

   activations::Activation* _activation;
   const bool _gated;
+  const int _bottleneck; // Internal channel count (not doubled when gated)
 };

 class LayerArrayParams
 {
 public:
   LayerArrayParams(const int input_size_, const int condition_size_, const int head_size_, const int channels_,
-                   const int kernel_size_, const std::vector<int>&& dilations_, const std::string activation_,
-                   const bool gated_, const bool head_bias_, const int groups_input, const int groups_1x1_)
+                   const int bottleneck_, const int kernel_size_, const std::vector<int>&& dilations_,
+                   const std::string activation_, const bool gated_, const bool head_bias_, const int groups_input,
+                   const int groups_1x1_)
   : input_size(input_size_)
   , condition_size(condition_size_)
   , head_size(head_size_)
   , channels(channels_)
+  , bottleneck(bottleneck_)
   , kernel_size(kernel_size_)
   , dilations(std::move(dilations_))
   , activation(activation_)
@@ -97,6 +101,7 @@ class LayerArrayParams
   const int condition_size;
   const int head_size;
   const int channels;
+  const int bottleneck;
   const int kernel_size;
   std::vector<int> dilations;
   const std::string activation;
@@ -111,8 +116,9 @@ class _LayerArray
 {
 public:
   _LayerArray(const int input_size, const int condition_size, const int head_size, const int channels,
-              const int kernel_size, const std::vector<int>& dilations, const std::string activation, const bool gated,
-              const bool head_bias, const int groups_input, const int groups_1x1);
+              const int bottleneck, const int kernel_size, const std::vector<int>& dilations,
+              const std::string activation, const bool gated, const bool head_bias, const int groups_input,
+              const int groups_1x1);

   void SetMaxBufferSize(const int maxBufferSize);

@@ -150,12 +156,15 @@ class _LayerArray
   std::vector<_Layer> _layers;
   // Output from last layer (for next layer array)
   Eigen::MatrixXf _layer_outputs;
-  // Accumulated head inputs from all layers
+  // Accumulated head inputs from all layers (bottleneck channels)
   Eigen::MatrixXf _head_inputs;

-  // Rechannel for the head
+  // Rechannel for the head (bottleneck -> head_size)
   Conv1x1 _head_rechannel;

+  // Bottleneck size (internal channel count)
+  const int _bottleneck;
+
   long _get_channels() const;
   // Common processing logic after head inputs are set
   void ProcessInner(const Eigen::MatrixXf& layer_inputs, const Eigen::MatrixXf& condition, const int num_frames);
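
Given the constructor above, a gated layer with a bottleneck narrower than its residual width could be set up as follows. This is a hypothetical usage sketch (illustrative sizes, include path assuming this repository's layout), not a test from the commit:

```cpp
#include "NAM/wavenet.h"

int main()
{
  // Hypothetical sizes: 16 residual channels, an 8-channel internal bottleneck.
  nam::wavenet::_Layer layer(/*condition_size=*/1, /*channels=*/16, /*bottleneck=*/8,
                             /*kernel_size=*/3, /*dilation=*/2, /*activation=*/"ReLU",
                             /*gated=*/true, /*groups_input=*/1, /*groups_1x1=*/1);
  // Per the initializer list above:
  //   _conv:        16 -> 2*8 rows (gated), dilated by 2
  //   _input_mixin: 1  -> 2*8 rows
  //   _1x1:         8  -> 16 rows (back to the residual width)
  layer.SetMaxBufferSize(64); // pre-allocates _z (2*8 x 64) and _output_head (8 x 64)
  return 0;
}
```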

tools/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ set_target_properties(run_tests PROPERTIES COMPILE_OPTIONS "-O0")
 # Release/RelWithDebInfo/MinSizeRel build types automatically define NDEBUG
 # We use a compile option to undefine it, which works on GCC, Clang, and MSVC
 target_compile_options(run_tests PRIVATE
-  $<$<OR:$<CONFIG:Release>,$<CONFIG:RelWithDebInfo>,$<CONFIG:MinSizeRel>>:-U_NDEBUG>
+  $<$<OR:$<CONFIG:Release>,$<CONFIG:RelWithDebInfo>,$<CONFIG:MinSizeRel>>:-UNDEBUG>
 )

 source_group(NAM ${CMAKE_CURRENT_SOURCE_DIR} FILES ${NAM_SOURCES})

tools/run_tests.cpp

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,8 @@ int main()
   test_wavenet::test_layer::test_non_gated_layer();
   test_wavenet::test_layer::test_layer_activations();
   test_wavenet::test_layer::test_layer_multichannel();
+  test_wavenet::test_layer::test_layer_bottleneck();
+  test_wavenet::test_layer::test_layer_bottleneck_gated();
   test_wavenet::test_layer_array::test_layer_array_basic();
   test_wavenet::test_layer_array::test_layer_array_receptive_field();
   test_wavenet::test_layer_array::test_layer_array_with_head_input();
@@ -118,6 +120,7 @@ int main()
   test_wavenet::test_conv1d_grouped_process_realtime_safe();
   test_wavenet::test_conv1d_grouped_dilated_process_realtime_safe();
   test_wavenet::test_layer_process_realtime_safe();
+  test_wavenet::test_layer_bottleneck_process_realtime_safe();
   test_wavenet::test_layer_grouped_process_realtime_safe();
   test_wavenet::test_layer_array_process_realtime_safe();
   test_wavenet::test_process_realtime_safe();

tools/test/test_wavenet/test_full.cpp

Lines changed: 13 additions & 8 deletions
@@ -19,6 +19,7 @@ void test_wavenet_model()
   const int condition_size = 1;
   const int head_size = 1;
   const int channels = 1;
+  const int bottleneck = channels;
   const int kernel_size = 1;
   std::vector<int> dilations{1};
   const std::string activation = "ReLU";
@@ -29,7 +30,7 @@
   const int groups = 1;
   const int groups_1x1 = 1;

-  nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, kernel_size,
+  nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, bottleneck, kernel_size,
                                         std::move(dilations), activation, gated, head_bias, groups, groups_1x1);
   std::vector<nam::wavenet::LayerArrayParams> layer_array_params;
   layer_array_params.push_back(std::move(params));
@@ -85,15 +86,16 @@ void test_wavenet_multiple_arrays()
   std::vector<nam::wavenet::LayerArrayParams> layer_array_params;
   // First array
   std::vector<int> dilations1{1};
+  const int bottleneck = channels;
   const int groups_1x1 = 1;
   layer_array_params.push_back(nam::wavenet::LayerArrayParams(input_size, condition_size, head_size, channels,
-                                                              kernel_size, std::move(dilations1), activation, gated,
-                                                              head_bias, groups, groups_1x1));
+                                                              bottleneck, kernel_size, std::move(dilations1), activation,
+                                                              gated, head_bias, groups, groups_1x1));
   // Second array (head_size of first must match channels of second)
   std::vector<int> dilations2{1};
   layer_array_params.push_back(nam::wavenet::LayerArrayParams(head_size, condition_size, head_size, channels,
-                                                              kernel_size, std::move(dilations2), activation, gated,
-                                                              head_bias, groups, groups_1x1));
+                                                              bottleneck, kernel_size, std::move(dilations2), activation,
+                                                              gated, head_bias, groups, groups_1x1));

   std::vector<float> weights;
   // Array 0: rechannel, layer, head_rechannel
@@ -127,6 +129,7 @@ void test_wavenet_zero_input()
   const int condition_size = 1;
   const int head_size = 1;
   const int channels = 1;
+  const int bottleneck = channels;
   const int kernel_size = 1;
   std::vector<int> dilations{1};
   const std::string activation = "ReLU";
@@ -137,7 +140,7 @@ void test_wavenet_zero_input()
   const int groups = 1;
   const int groups_1x1 = 1;

-  nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, kernel_size,
+  nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, bottleneck, kernel_size,
                                         std::move(dilations), activation, gated, head_bias, groups, groups_1x1);
   std::vector<nam::wavenet::LayerArrayParams> layer_array_params;
   layer_array_params.push_back(std::move(params));
@@ -168,6 +171,7 @@ void test_wavenet_different_buffer_sizes()
   const int condition_size = 1;
   const int head_size = 1;
   const int channels = 1;
+  const int bottleneck = channels;
   const int kernel_size = 1;
   std::vector<int> dilations{1};
   const std::string activation = "ReLU";
@@ -178,7 +182,7 @@ void test_wavenet_different_buffer_sizes()
   const int groups = 1;
   const int groups_1x1 = 1;

-  nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, kernel_size,
+  nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, bottleneck, kernel_size,
                                         std::move(dilations), activation, gated, head_bias, groups, groups_1x1);
   std::vector<nam::wavenet::LayerArrayParams> layer_array_params;
   layer_array_params.push_back(std::move(params));
@@ -210,6 +214,7 @@ void test_wavenet_prewarm()
   const int condition_size = 1;
   const int head_size = 1;
   const int channels = 1;
+  const int bottleneck = channels;
   const int kernel_size = 3;
   std::vector<int> dilations{1, 2, 4};
   const std::string activation = "ReLU";
@@ -220,7 +225,7 @@ void test_wavenet_prewarm()
   const int groups = 1;
   const int groups_1x1 = 1;

-  nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, kernel_size,
+  nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, bottleneck, kernel_size,
                                         std::move(dilations), activation, gated, head_bias, groups, groups_1x1);
   std::vector<nam::wavenet::LayerArrayParams> layer_array_params;
   layer_array_params.push_back(std::move(params));
