Skip to content

Commit 20a04fc

Browse files
jfsantosJoão Felipe Santosclaude
authored
Refactor factory into separate config parser and unified create_dsp() construction path (#227)
* Extract config structs and unified create_dsp() construction path Add typed config structs (LinearConfig, LSTMConfig, ConvNetConfig, WaveNetConfig) and parse_config_json() functions per architecture. Introduce ModelConfig variant, ModelMetadata, and create_dsp() for unified model construction independent of JSON parsing. Refactor get_dsp() to use the new unified path. Register Linear factory (was previously missing). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Replace ModelConfig variant with extensible abstract base class ModelConfig is now an abstract class with a virtual create() method instead of a std::variant over built-in config types. Each architecture's config struct inherits from ModelConfig and implements create(). A single ConfigParserRegistry maps architecture names to parser functions that return unique_ptr<ModelConfig>. Both built-in architectures (via ConfigParserHelper) and external ones (via factory::Helper, which now wraps into FactoryConfig) register in the same registry. This can be refactored such that we only have ConfigParserHelper, but I wanted to keep the API backwards-compatible. * Restoring missing comments --------- Co-authored-by: João Felipe Santos <santosjf@pm.me> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a0c93c0 commit 20a04fc

11 files changed

Lines changed: 422 additions & 183 deletions

File tree

NAM/convnet.cpp

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -322,25 +322,40 @@ void nam::convnet::ConvNet::_rewind_buffers_()
322322
this->Buffer::_rewind_buffers_();
323323
}
324324

325-
// Factory
326-
std::unique_ptr<nam::DSP> nam::convnet::Factory(const nlohmann::json& config, std::vector<float>& weights,
327-
const double expectedSampleRate)
325+
// Config parser
326+
nam::convnet::ConvNetConfig nam::convnet::parse_config_json(const nlohmann::json& config)
328327
{
329-
const int channels = config["channels"];
330-
const std::vector<int> dilations = config["dilations"];
331-
const bool batchnorm = config["batchnorm"];
328+
ConvNetConfig c;
329+
c.channels = config["channels"];
330+
c.dilations = config["dilations"].get<std::vector<int>>();
331+
c.batchnorm = config["batchnorm"];
332332
// Parse JSON into typed ActivationConfig at model loading boundary
333-
const activations::ActivationConfig activation_config =
334-
activations::ActivationConfig::from_json(config["activation"]);
335-
const int groups = config.value("groups", 1); // defaults to 1
333+
c.activation = activations::ActivationConfig::from_json(config["activation"]);
334+
c.groups = config.value("groups", 1); // defaults to 1
336335
// Default to 1 channel in/out for backward compatibility
337-
const int in_channels = config.value("in_channels", 1);
338-
const int out_channels = config.value("out_channels", 1);
339-
return std::make_unique<nam::convnet::ConvNet>(
340-
in_channels, out_channels, channels, dilations, batchnorm, activation_config, weights, expectedSampleRate, groups);
336+
c.in_channels = config.value("in_channels", 1);
337+
c.out_channels = config.value("out_channels", 1);
338+
return c;
339+
}
340+
341+
// ConvNetConfig::create()
342+
std::unique_ptr<nam::DSP> nam::convnet::ConvNetConfig::create(std::vector<float> weights, double sampleRate)
343+
{
344+
return std::make_unique<nam::convnet::ConvNet>(in_channels, out_channels, channels, dilations, batchnorm, activation,
345+
weights, sampleRate, groups);
346+
}
347+
348+
// Config parser for ConfigParserRegistry
349+
std::unique_ptr<nam::ModelConfig> nam::convnet::create_config(const nlohmann::json& config, double sampleRate)
350+
{
351+
(void)sampleRate;
352+
auto c = std::make_unique<ConvNetConfig>();
353+
auto parsed = parse_config_json(config);
354+
*c = parsed;
355+
return c;
341356
}
342357

343358
namespace
344359
{
345-
static nam::factory::Helper _register_ConvNet("ConvNet", nam::convnet::Factory);
360+
static nam::ConfigParserHelper _register_ConvNet("ConvNet", nam::convnet::create_config);
346361
}

NAM/convnet.h

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,27 @@ class ConvNet : public Buffer
165165
int PrewarmSamples() override { return mPrewarmSamples; };
166166
};
167167

168-
/// \brief Factory function to instantiate ConvNet from JSON
168+
/// \brief Configuration for a ConvNet model
169+
struct ConvNetConfig : public ModelConfig
170+
{
171+
int channels;
172+
std::vector<int> dilations;
173+
bool batchnorm;
174+
activations::ActivationConfig activation;
175+
int groups;
176+
int in_channels;
177+
int out_channels;
178+
179+
std::unique_ptr<DSP> create(std::vector<float> weights, double sampleRate) override;
180+
};
181+
182+
/// \brief Parse ConvNet configuration from JSON
169183
/// \param config JSON configuration object
170-
/// \param weights Model weights vector
171-
/// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown)
172-
/// \return Unique pointer to a DSP object (ConvNet instance)
173-
std::unique_ptr<DSP> Factory(const nlohmann::json& config, std::vector<float>& weights,
174-
const double expectedSampleRate);
184+
/// \return ConvNetConfig
185+
ConvNetConfig parse_config_json(const nlohmann::json& config);
186+
187+
/// \brief Config parser for ConfigParserRegistry
188+
std::unique_ptr<ModelConfig> create_config(const nlohmann::json& config, double sampleRate);
175189

176190
}; // namespace convnet
177191
}; // namespace nam

NAM/dsp.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -301,16 +301,38 @@ void nam::Linear::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num
301301
nam::Buffer::_advance_input_buffer_(num_frames);
302302
}
303303

304-
// Factory
305-
std::unique_ptr<nam::DSP> nam::linear::Factory(const nlohmann::json& config, std::vector<float>& weights,
306-
const double expectedSampleRate)
304+
// Config parser
305+
nam::linear::LinearConfig nam::linear::parse_config_json(const nlohmann::json& config)
307306
{
308-
const int receptive_field = config["receptive_field"];
309-
const bool bias = config["bias"];
307+
LinearConfig c;
308+
c.receptive_field = config["receptive_field"];
309+
c.bias = config["bias"];
310310
// Default to 1 channel in/out for backward compatibility
311-
const int in_channels = config.value("in_channels", 1);
312-
const int out_channels = config.value("out_channels", 1);
313-
return std::make_unique<nam::Linear>(in_channels, out_channels, receptive_field, bias, weights, expectedSampleRate);
311+
c.in_channels = config.value("in_channels", 1);
312+
c.out_channels = config.value("out_channels", 1);
313+
return c;
314+
}
315+
316+
// LinearConfig::create()
317+
std::unique_ptr<nam::DSP> nam::linear::LinearConfig::create(std::vector<float> weights, double sampleRate)
318+
{
319+
return std::make_unique<nam::Linear>(in_channels, out_channels, receptive_field, bias, weights, sampleRate);
320+
}
321+
322+
// Config parser for ConfigParserRegistry
323+
std::unique_ptr<nam::ModelConfig> nam::linear::create_config(const nlohmann::json& config, double sampleRate)
324+
{
325+
(void)sampleRate;
326+
auto c = std::make_unique<LinearConfig>();
327+
auto parsed = parse_config_json(config);
328+
*c = parsed;
329+
return c;
330+
}
331+
332+
// Register the config parser
333+
namespace
334+
{
335+
static nam::ConfigParserHelper _register_Linear("Linear", nam::linear::create_config);
314336
}
315337

316338
// NN modules =================================================================

NAM/dsp.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "activations.h"
1313
#include "json.hpp"
14+
#include "model_config.h"
1415

1516
#ifdef NAM_SAMPLE_FLOAT
1617
#define NAM_SAMPLE float
@@ -258,13 +259,28 @@ class Linear : public Buffer
258259

259260
namespace linear
260261
{
261-
/// \brief Factory function to instantiate Linear model from JSON
262+
263+
/// \brief Configuration for a Linear model
264+
struct LinearConfig : public ModelConfig
265+
{
266+
int receptive_field;
267+
bool bias;
268+
int in_channels;
269+
int out_channels;
270+
271+
std::unique_ptr<DSP> create(std::vector<float> weights, double sampleRate) override;
272+
};
273+
274+
/// \brief Parse Linear configuration from JSON
275+
/// \param config JSON configuration object
276+
/// \return LinearConfig
277+
LinearConfig parse_config_json(const nlohmann::json& config);
278+
279+
/// \brief Config parser for ConfigParserRegistry
262280
/// \param config JSON configuration object
263-
/// \param weights Model weights vector
264-
/// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown)
265-
/// \return Unique pointer to a DSP object (Linear instance)
266-
std::unique_ptr<DSP> Factory(const nlohmann::json& config, std::vector<float>& weights,
267-
const double expectedSampleRate);
281+
/// \param sampleRate Expected sample rate in Hz
282+
/// \return unique_ptr<ModelConfig> wrapping a LinearConfig
283+
std::unique_ptr<ModelConfig> create_config(const nlohmann::json& config, double sampleRate);
268284
} // namespace linear
269285

270286
// NN modules =================================================================

NAM/get_dsp.cpp

Lines changed: 53 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@
22
#include <iostream>
33
#include <sstream>
44
#include <stdexcept>
5-
#include <unordered_set>
65

76
#include "dsp.h"
87
#include "registry.h"
98
#include "json.hpp"
10-
#include "lstm.h"
11-
#include "convnet.h"
12-
#include "wavenet.h"
139
#include "get_dsp.h"
10+
#include "model_config.h"
1411

1512
namespace nam
1613
{
@@ -146,62 +143,69 @@ std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, dspData& returnedConf
146143
return get_dsp(conf);
147144
}
148145

149-
struct OptionalValue
146+
// =============================================================================
147+
// Unified construction path
148+
// =============================================================================
149+
150+
std::unique_ptr<ModelConfig> parse_model_config_json(const std::string& architecture, const nlohmann::json& config,
151+
double sample_rate)
152+
{
153+
return ConfigParserRegistry::instance().parse(architecture, config, sample_rate);
154+
}
155+
156+
namespace
150157
{
151-
bool have = false;
152-
double value = 0.0;
153-
};
158+
159+
void apply_metadata(DSP& dsp, const ModelMetadata& metadata)
160+
{
161+
if (metadata.loudness.has_value())
162+
dsp.SetLoudness(metadata.loudness.value());
163+
if (metadata.input_level.has_value())
164+
dsp.SetInputLevel(metadata.input_level.value());
165+
if (metadata.output_level.has_value())
166+
dsp.SetOutputLevel(metadata.output_level.value());
167+
}
168+
169+
} // anonymous namespace
170+
171+
std::unique_ptr<DSP> create_dsp(std::unique_ptr<ModelConfig> config, std::vector<float> weights,
172+
const ModelMetadata& metadata)
173+
{
174+
auto out = config->create(std::move(weights), metadata.sample_rate);
175+
apply_metadata(*out, metadata);
176+
// "pre-warm" the model to settle initial conditions
177+
// Can this be removed now that it's part of Reset()?
178+
out->prewarm();
179+
return out;
180+
}
181+
182+
// =============================================================================
183+
// get_dsp(dspData&) — now uses unified path
184+
// =============================================================================
154185

155186
std::unique_ptr<DSP> get_dsp(dspData& conf)
156187
{
157188
verify_config_version(conf.version);
158189

159-
auto& architecture = conf.architecture;
160-
nlohmann::json& config = conf.config;
161-
std::vector<float>& weights = conf.weights;
162-
OptionalValue loudness, inputLevel, outputLevel;
163-
164-
auto AssignOptional = [&conf](const std::string key, OptionalValue& v) {
165-
if (conf.metadata.find(key) != conf.metadata.end())
166-
{
167-
if (!conf.metadata[key].is_null())
168-
{
169-
v.value = conf.metadata[key];
170-
v.have = true;
171-
}
172-
}
173-
};
190+
// Extract metadata from JSON
191+
ModelMetadata metadata;
192+
metadata.version = conf.version;
193+
metadata.sample_rate = conf.expected_sample_rate;
174194

175195
if (!conf.metadata.is_null())
176196
{
177-
AssignOptional("loudness", loudness);
178-
AssignOptional("input_level_dbu", inputLevel);
179-
AssignOptional("output_level_dbu", outputLevel);
180-
}
181-
const double expectedSampleRate = conf.expected_sample_rate;
182-
183-
// Initialize using registry-based factory
184-
std::unique_ptr<DSP> out =
185-
nam::factory::FactoryRegistry::instance().create(architecture, config, weights, expectedSampleRate);
186-
187-
if (loudness.have)
188-
{
189-
out->SetLoudness(loudness.value);
190-
}
191-
if (inputLevel.have)
192-
{
193-
out->SetInputLevel(inputLevel.value);
194-
}
195-
if (outputLevel.have)
196-
{
197-
out->SetOutputLevel(outputLevel.value);
197+
auto extract = [&conf](const std::string& key) -> std::optional<double> {
198+
if (conf.metadata.find(key) != conf.metadata.end() && !conf.metadata[key].is_null())
199+
return conf.metadata[key].get<double>();
200+
return std::nullopt;
201+
};
202+
metadata.loudness = extract("loudness");
203+
metadata.input_level = extract("input_level_dbu");
204+
metadata.output_level = extract("output_level_dbu");
198205
}
199206

200-
// "pre-warm" the model to settle initial conditions
201-
// Can this be removed now that it's part of Reset()?
202-
out->prewarm();
203-
204-
return out;
207+
auto model_config = ConfigParserRegistry::instance().parse(conf.architecture, conf.config, conf.expected_sample_rate);
208+
return create_dsp(std::move(model_config), std::move(conf.weights), metadata);
205209
}
206210

207211
double get_sample_rate_from_nam_file(const nlohmann::json& j)

NAM/lstm.cpp

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -163,22 +163,38 @@ void nam::lstm::LSTM::_process_sample()
163163
this->_output.noalias() += this->_head_bias;
164164
}
165165

166-
// Factory to instantiate from nlohmann json
167-
std::unique_ptr<nam::DSP> nam::lstm::Factory(const nlohmann::json& config, std::vector<float>& weights,
168-
const double expectedSampleRate)
166+
// Config parser
167+
nam::lstm::LSTMConfig nam::lstm::parse_config_json(const nlohmann::json& config)
169168
{
170-
const int num_layers = config["num_layers"];
171-
const int input_size = config["input_size"];
172-
const int hidden_size = config["hidden_size"];
169+
LSTMConfig c;
170+
c.num_layers = config["num_layers"];
171+
c.input_size = config["input_size"];
172+
c.hidden_size = config["hidden_size"];
173173
// Default to 1 channel in/out for backward compatibility
174-
const int in_channels = config.value("in_channels", 1);
175-
const int out_channels = config.value("out_channels", 1);
176-
return std::make_unique<nam::lstm::LSTM>(
177-
in_channels, out_channels, num_layers, input_size, hidden_size, weights, expectedSampleRate);
174+
c.in_channels = config.value("in_channels", 1);
175+
c.out_channels = config.value("out_channels", 1);
176+
return c;
178177
}
179178

180-
// Register the factory
179+
// LSTMConfig::create()
180+
std::unique_ptr<nam::DSP> nam::lstm::LSTMConfig::create(std::vector<float> weights, double sampleRate)
181+
{
182+
return std::make_unique<nam::lstm::LSTM>(in_channels, out_channels, num_layers, input_size, hidden_size, weights,
183+
sampleRate);
184+
}
185+
186+
// Config parser for ConfigParserRegistry
187+
std::unique_ptr<nam::ModelConfig> nam::lstm::create_config(const nlohmann::json& config, double sampleRate)
188+
{
189+
(void)sampleRate;
190+
auto c = std::make_unique<LSTMConfig>();
191+
auto parsed = parse_config_json(config);
192+
*c = parsed;
193+
return c;
194+
}
195+
196+
// Register the config parser
181197
namespace
182198
{
183-
static nam::factory::Helper _register_LSTM("LSTM", nam::lstm::Factory);
199+
static nam::ConfigParserHelper _register_LSTM("LSTM", nam::lstm::create_config);
184200
}

NAM/lstm.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,25 @@ class LSTM : public DSP
9595
Eigen::VectorXf _output;
9696
};
9797

98-
/// \brief Factory function to instantiate LSTM from JSON
98+
/// \brief Configuration for an LSTM model
99+
struct LSTMConfig : public ModelConfig
100+
{
101+
int num_layers;
102+
int input_size;
103+
int hidden_size;
104+
int in_channels;
105+
int out_channels;
106+
107+
std::unique_ptr<DSP> create(std::vector<float> weights, double sampleRate) override;
108+
};
109+
110+
/// \brief Parse LSTM configuration from JSON
99111
/// \param config JSON configuration object
100-
/// \param weights Model weights vector
101-
/// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown)
102-
/// \return Unique pointer to a DSP object (LSTM instance)
103-
std::unique_ptr<DSP> Factory(const nlohmann::json& config, std::vector<float>& weights,
104-
const double expectedSampleRate);
112+
/// \return LSTMConfig
113+
LSTMConfig parse_config_json(const nlohmann::json& config);
114+
115+
/// \brief Config parser for ConfigParserRegistry
116+
std::unique_ptr<ModelConfig> create_config(const nlohmann::json& config, double sampleRate);
105117

106118
}; // namespace lstm
107119
}; // namespace nam

0 commit comments

Comments
 (0)