
Commit 13208e2

Author: João Felipe Santos
Updated how activations are initialized from config. Passing an object instead of a string is now supported.
1 parent 8987a3c commit 13208e2

File tree

11 files changed: +220 −35 lines


NAM/activations.cpp

Lines changed: 53 additions & 0 deletions
@@ -31,6 +31,59 @@ nam::activations::Activation* nam::activations::Activation::get_activation(const
   return _activations[name];
 }
 
+nam::activations::Activation* nam::activations::Activation::get_activation(const nlohmann::json& activation_config)
+{
+  // If it's a string, use the existing string-based lookup
+  if (activation_config.is_string())
+  {
+    std::string name = activation_config.get<std::string>();
+    return get_activation(name);
+  }
+
+  // If it's an object, parse the activation type and parameters
+  if (activation_config.is_object())
+  {
+    std::string type = activation_config["type"].get<std::string>();
+
+    // Handle different activation types with parameters
+    if (type == "PReLU")
+    {
+      if (activation_config.find("negative_slope") != activation_config.end())
+      {
+        float negative_slope = activation_config["negative_slope"].get<float>();
+        return new ActivationPReLU(negative_slope);
+      }
+      else if (activation_config.find("negative_slopes") != activation_config.end())
+      {
+        std::vector<float> negative_slopes = activation_config["negative_slopes"].get<std::vector<float>>();
+        return new ActivationPReLU(negative_slopes);
+      }
+      // If no parameters provided, use default
+      return new ActivationPReLU(0.01);
+    }
+    else if (type == "LeakyReLU")
+    {
+      float negative_slope = activation_config.value("negative_slope", 0.01f);
+      return new ActivationLeakyReLU(negative_slope);
+    }
+    else if (type == "LeakyHardTanh")
+    {
+      float min_val = activation_config.value("min_val", -1.0f);
+      float max_val = activation_config.value("max_val", 1.0f);
+      float min_slope = activation_config.value("min_slope", 0.01f);
+      float max_slope = activation_config.value("max_slope", 0.01f);
+      return new ActivationLeakyHardTanh(min_val, max_val, min_slope, max_slope);
+    }
+    else
+    {
+      // For other activation types without parameters, use the default string-based lookup
+      return get_activation(type);
+    }
+  }
+
+  return nullptr;
+}
+
 void nam::activations::Activation::enable_fast_tanh()
 {
   nam::activations::Activation::using_fast_tanh = true;
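
For reference, a minimal sketch of how a caller might exercise the new overload. The config values below are invented for illustration (and "ReLU" is assumed to be one of the names registered in the string-based lookup); only get_activation and the ActivationPReLU constructors come from this commit, and the include path is assumed:

#include "NAM/activations.h" // assumed include path; adjust to the project layout
#include "json.hpp"

int main()
{
  // Old form: a bare string still goes through the name-based lookup.
  nlohmann::json by_name = "ReLU";
  nam::activations::Activation* relu = nam::activations::Activation::get_activation(by_name);

  // New form: an object carries the type plus its parameters.
  nlohmann::json with_params = {{"type", "PReLU"}, {"negative_slope", 0.05f}};
  nam::activations::Activation* prelu = nam::activations::Activation::get_activation(with_params);

  // PReLU also accepts one slope per channel.
  nlohmann::json per_channel = {{"type", "PReLU"}, {"negative_slopes", {0.01f, 0.02f, 0.03f}}};
  nam::activations::Activation* prelu_multi = nam::activations::Activation::get_activation(per_channel);

  return (relu && prelu && prelu_multi) ? 0 : 1;
}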

NAM/activations.h

Lines changed: 23 additions & 10 deletions
@@ -6,6 +6,7 @@
 #include <unordered_map>
 #include <Eigen/Dense>
 #include <functional>
+#include "json.hpp"
 
 namespace nam
 {
@@ -102,6 +103,7 @@ class Activation
   virtual void apply(float* data, long size) {}
 
   static Activation* get_activation(const std::string name);
+  static Activation* get_activation(const nlohmann::json& activation_config);
   static void enable_fast_tanh();
   static void disable_fast_tanh();
   static bool using_fast_tanh;
@@ -226,20 +228,31 @@ class ActivationPReLU : public Activation
   void apply(Eigen::MatrixXf& matrix) override
   {
     // Matrix is organized as (channels, time_steps)
-    int n_channels = negative_slopes.size();
-    int actual_channels = matrix.rows();
-
-    // NOTE: check not done during runtime on release builds
-    // model loader should make sure dimensions match
-    assert(actual_channels == n_channels);
-
+    unsigned long actual_channels = static_cast<unsigned long>(matrix.rows());
+
+    // Prepare the slopes for the current matrix size
+    std::vector<float> slopes_for_channels = negative_slopes;
+
+    if (slopes_for_channels.size() == 1 && actual_channels > 1)
+    {
+      // Broadcast single slope to all channels
+      float slope = slopes_for_channels[0];
+      slopes_for_channels.clear();
+      slopes_for_channels.resize(actual_channels, slope);
+    }
+    else if (slopes_for_channels.size() != actual_channels)
+    {
+      // This should not happen in normal usage, but handle gracefully
+      slopes_for_channels.resize(actual_channels, 0.01f); // Default slope
+    }
+
     // Apply each negative slope to its corresponding channel
-    for (int channel = 0; channel < std::min(n_channels, actual_channels); channel++)
+    for (unsigned long channel = 0; channel < actual_channels; channel++)
     {
       // Apply the negative slope to all time steps in this channel
-      for (int time_step = 0; time_step < matrix.rows(); time_step++)
+      for (int time_step = 0; time_step < matrix.cols(); time_step++)
       {
-        matrix(channel, time_step) = leaky_relu(matrix(channel, time_step), negative_slopes[channel]);
+        matrix(channel, time_step) = leaky_relu(matrix(channel, time_step), slopes_for_channels[channel]);
       }
     }
   }
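
A small illustration of the broadcasting behaviour added above, assuming ActivationPReLU exposes the single-slope and per-channel constructors used in activations.cpp; the matrix shapes and include path are made up for the example:

#include <vector>
#include <Eigen/Dense>
#include "NAM/activations.h" // assumed include path

int main()
{
  // Matrices are laid out as (channels, time_steps), matching the comment in apply().
  Eigen::MatrixXf x(3, 4);
  x.setConstant(-1.0f);

  // One slope, three channels: apply() broadcasts 0.1 to every channel.
  nam::activations::ActivationPReLU single(0.1f);
  single.apply(x); // every entry becomes -0.1f

  // Per-channel slopes: each row (channel) gets its own slope.
  Eigen::MatrixXf y(3, 4);
  y.setConstant(-1.0f);
  nam::activations::ActivationPReLU per_channel(std::vector<float>{0.1f, 0.2f, 0.3f});
  per_channel.apply(y); // row 0 -> -0.1f, row 1 -> -0.2f, row 2 -> -0.3f

  return 0;
}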

NAM/convnet.cpp

Lines changed: 6 additions & 6 deletions
@@ -48,15 +48,15 @@ void nam::convnet::BatchNorm::process_(Eigen::MatrixXf& x, const long i_start, c
 }
 
 void nam::convnet::ConvNetBlock::set_weights_(const int in_channels, const int out_channels, const int _dilation,
-                                              const bool batchnorm, const std::string activation, const int groups,
+                                              const bool batchnorm, const nlohmann::json activation_config, const int groups,
                                               std::vector<float>::iterator& weights)
 {
   this->_batchnorm = batchnorm;
   // HACK 2 kernel
   this->conv.set_size_and_weights_(in_channels, out_channels, 2, _dilation, !batchnorm, groups, weights);
   if (this->_batchnorm)
     this->batchnorm = BatchNorm(out_channels, weights);
-  this->activation = activations::Activation::get_activation(activation);
+  this->activation = activations::Activation::get_activation(activation_config);
 }
 
 void nam::convnet::ConvNetBlock::SetMaxBufferSize(const int maxBufferSize)
@@ -173,7 +173,7 @@ void nam::convnet::_Head::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf
 }
 
 nam::convnet::ConvNet::ConvNet(const int in_channels, const int out_channels, const int channels,
-                               const std::vector<int>& dilations, const bool batchnorm, const std::string activation,
+                               const std::vector<int>& dilations, const bool batchnorm, const nlohmann::json activation_config,
                                std::vector<float>& weights, const double expected_sample_rate, const int groups)
 : Buffer(in_channels, out_channels, *std::max_element(dilations.begin(), dilations.end()), expected_sample_rate)
 {
@@ -183,7 +183,7 @@ nam::convnet::ConvNet::ConvNet(const int in_channels, const int out_channels, co
   // First block takes in_channels input, subsequent blocks take channels input
   for (size_t i = 0; i < dilations.size(); i++)
     this->_blocks[i].set_weights_(
-      i == 0 ? in_channels : channels, channels, dilations[i], batchnorm, activation, groups, it);
+      i == 0 ? in_channels : channels, channels, dilations[i], batchnorm, activation_config, groups, it);
   // Only need _block_vals for the head (one entry)
   // Conv1D layers manage their own buffers now
   this->_block_vals.resize(1);
@@ -327,13 +327,13 @@ std::unique_ptr<nam::DSP> nam::convnet::Factory(const nlohmann::json& config, st
   const int channels = config["channels"];
   const std::vector<int> dilations = config["dilations"];
   const bool batchnorm = config["batchnorm"];
-  const std::string activation = config["activation"];
+  const nlohmann::json activation_config = config["activation"];
   const int groups = config.value("groups", 1); // defaults to 1
   // Default to 1 channel in/out for backward compatibility
   const int in_channels = config.value("in_channels", 1);
   const int out_channels = config.value("out_channels", 1);
   return std::make_unique<nam::convnet::ConvNet>(
-    in_channels, out_channels, channels, dilations, batchnorm, activation, weights, expectedSampleRate, groups);
+    in_channels, out_channels, channels, dilations, batchnorm, activation_config, weights, expectedSampleRate, groups);
 }
 
 namespace
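
With this change a ConvNet config can carry its activation either way. A hypothetical fragment for illustration (field values invented; only keys read by the Factory above are used):

#include "json.hpp"

// Old style: activation as a plain name.
nlohmann::json config_old = nlohmann::json::parse(R"({
  "channels": 16,
  "dilations": [1, 2, 4, 8],
  "batchnorm": true,
  "activation": "ReLU"
})");

// New style: activation as an object with parameters.
nlohmann::json config_new = nlohmann::json::parse(R"({
  "channels": 16,
  "dilations": [1, 2, 4, 8],
  "batchnorm": true,
  "activation": {"type": "LeakyHardTanh", "min_val": -1.0, "max_val": 1.0, "min_slope": 0.02}
})");

// Both land in the same place in the Factory:
//   const nlohmann::json activation_config = config["activation"];
// and get_activation() dispatches on is_string() / is_object().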

NAM/convnet.h

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@ class ConvNetBlock
 public:
   ConvNetBlock() {};
   void set_weights_(const int in_channels, const int out_channels, const int _dilation, const bool batchnorm,
-                    const std::string activation, const int groups, std::vector<float>::iterator& weights);
+                    const nlohmann::json activation_config, const int groups, std::vector<float>::iterator& weights);
   void SetMaxBufferSize(const int maxBufferSize);
   // Process input matrix directly (new API, similar to WaveNet)
   void Process(const Eigen::MatrixXf& input, const int num_frames);
@@ -78,7 +78,7 @@ class ConvNet : public Buffer
 {
 public:
   ConvNet(const int in_channels, const int out_channels, const int channels, const std::vector<int>& dilations,
-          const bool batchnorm, const std::string activation, std::vector<float>& weights,
+          const bool batchnorm, const nlohmann::json activation_config, std::vector<float>& weights,
           const double expected_sample_rate = -1.0, const int groups = 1);
   ~ConvNet() = default;
NAM/wavenet.cpp

Lines changed: 5 additions & 5 deletions
@@ -113,7 +113,7 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
 
 nam::wavenet::_LayerArray::_LayerArray(const int input_size, const int condition_size, const int head_size,
                                        const int channels, const int bottleneck, const int kernel_size,
-                                       const std::vector<int>& dilations, const std::string activation,
+                                       const std::vector<int>& dilations, const nlohmann::json activation_config,
                                        const GatingMode gating_mode, const bool head_bias, const int groups_input,
                                        const int groups_1x1, const Head1x1Params& head1x1_params,
                                        const std::string& secondary_activation)
@@ -122,7 +122,7 @@ nam::wavenet::_LayerArray::_LayerArray(const int input_size, const int condition
 , _bottleneck(bottleneck)
 {
   for (size_t i = 0; i < dilations.size(); i++)
-    this->_layers.push_back(_Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation,
+    this->_layers.push_back(_Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation_config,
                                    gating_mode, groups_input, groups_1x1, head1x1_params, secondary_activation));
 }
 
@@ -273,7 +273,7 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels,
   this->_layer_arrays.push_back(nam::wavenet::_LayerArray(
     layer_array_params[i].input_size, layer_array_params[i].condition_size, layer_array_params[i].head_size,
     layer_array_params[i].channels, layer_array_params[i].bottleneck, layer_array_params[i].kernel_size,
-    layer_array_params[i].dilations, layer_array_params[i].activation, layer_array_params[i].gating_mode,
+    layer_array_params[i].dilations, layer_array_params[i].activation_config, layer_array_params[i].gating_mode,
     layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_1x1,
     layer_array_params[i].head1x1_params, layer_array_params[i].secondary_activation));
   if (i > 0)
@@ -477,7 +477,7 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
   const int head_size = layer_config["head_size"];
   const int kernel_size = layer_config["kernel_size"];
   const auto dilations = layer_config["dilations"];
-  const std::string activation = layer_config["activation"].get<std::string>();
+  const nlohmann::json activation_config = layer_config["activation"];
   // Parse gating mode - support both old "gated" boolean and new "gating_mode" string
   GatingMode gating_mode = GatingMode::NONE;
   std::string secondary_activation;
@@ -531,7 +531,7 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
   nam::wavenet::Head1x1Params head1x1_params(head1x1_active, head1x1_out_channels, head1x1_groups);
 
   layer_array_params.push_back(nam::wavenet::LayerArrayParams(
-    input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, activation, gating_mode,
+    input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, activation_config, gating_mode,
     head_bias, groups, groups_1x1, head1x1_params, secondary_activation));
 }
 const bool with_head = !config["head"].is_null();

NAM/wavenet.h

Lines changed: 6 additions & 6 deletions
@@ -50,12 +50,12 @@ class _Layer
 public:
   // New constructor with GatingMode enum and configurable activations
   _Layer(const int condition_size, const int channels, const int bottleneck, const int kernel_size, const int dilation,
-         const std::string activation, const GatingMode gating_mode, const int groups_input, const int groups_1x1,
+         const nlohmann::json activation_config, const GatingMode gating_mode, const int groups_input, const int groups_1x1,
          const Head1x1Params& head1x1_params, const std::string& secondary_activation)
   : _conv(channels, (gating_mode != GatingMode::NONE) ? 2 * bottleneck : bottleneck, kernel_size, true, dilation)
   , _input_mixin(condition_size, (gating_mode != GatingMode::NONE) ? 2 * bottleneck : bottleneck, false)
   , _1x1(bottleneck, channels, groups_1x1)
-  , _activation(activations::Activation::get_activation(activation)) // needs to support activations with parameters
+  , _activation(activations::Activation::get_activation(activation_config)) // now supports activations with parameters
   , _gating_mode(gating_mode)
   , _bottleneck(bottleneck)
   {
@@ -148,7 +148,7 @@ class LayerArrayParams
 public:
   LayerArrayParams(const int input_size_, const int condition_size_, const int head_size_, const int channels_,
                    const int bottleneck_, const int kernel_size_, const std::vector<int>&& dilations_,
-                   const std::string activation_, const GatingMode gating_mode_, const bool head_bias_,
+                   const nlohmann::json activation_, const GatingMode gating_mode_, const bool head_bias_,
                    const int groups_input, const int groups_1x1_, const Head1x1Params& head1x1_params_,
                    const std::string& secondary_activation_)
   : input_size(input_size_)
@@ -158,7 +158,7 @@ class LayerArrayParams
   , bottleneck(bottleneck_)
   , kernel_size(kernel_size_)
   , dilations(std::move(dilations_))
-  , activation(activation_)
+  , activation_config(activation_)
   , gating_mode(gating_mode_)
   , head_bias(head_bias_)
   , groups_input(groups_input)
@@ -175,7 +175,7 @@ class LayerArrayParams
   const int bottleneck;
   const int kernel_size;
   std::vector<int> dilations;
-  const std::string activation;
+  const nlohmann::json activation_config;
   const GatingMode gating_mode;
   const bool head_bias;
   const int groups_input;
@@ -191,7 +191,7 @@ class _LayerArray
   // New constructor with GatingMode enum and configurable activations
   _LayerArray(const int input_size, const int condition_size, const int head_size, const int channels,
               const int bottleneck, const int kernel_size, const std::vector<int>& dilations,
-              const std::string activation, const GatingMode gating_mode, const bool head_bias, const int groups_input,
+              const nlohmann::json activation_config, const GatingMode gating_mode, const bool head_bias, const int groups_input,
               const int groups_1x1, const Head1x1Params& head1x1_params, const std::string& secondary_activation);
 
   void SetMaxBufferSize(const int maxBufferSize);

build/.gitignore

Lines changed: 0 additions & 4 deletions
This file was deleted.

tools/run_tests.cpp

Lines changed: 9 additions & 0 deletions
@@ -44,6 +44,15 @@ int main()
   // This is enforced by an assert so it doesn't need to be tested
   // test_activations::TestPReLU::test_wrong_number_of_channels();
 
+  // JSON activation parsing tests
+  test_activations::TestJSONActivationParsing::test_string_activation();
+  test_activations::TestJSONActivationParsing::test_json_prelu_single_slope();
+  test_activations::TestJSONActivationParsing::test_json_prelu_multi_slope();
+  test_activations::TestJSONActivationParsing::test_json_leaky_relu();
+  test_activations::TestJSONActivationParsing::test_json_leaky_hardtanh();
+  test_activations::TestJSONActivationParsing::test_json_unknown_activation();
+  test_activations::TestJSONActivationParsing::test_functional_verification();
+
   test_dsp::test_construct();
   test_dsp::test_get_input_level();
   test_dsp::test_get_output_level();
