Skip to content

Commit d547322

Browse files
committed
Extend model's weight control
Until now, only the probabilities of alternation alternatives could be defined externally through a JSON file. With this change, the scope of model weight control is broadened to include quantifiers as well. The contents of the JSON file are now split into two top-level sections: `alts` for alternation alternatives (rule name → alternation index → alternative index → weight) and `quants` for quantifiers (rule name → quantifier index → probability). The file is still provided through the existing `--weights` command-line option. This commit also brings the new feature to the C++ implementation, adding the previously missing CLI support and parsing logic so that it matches the Python version.
1 parent 14947b3 commit d547322

File tree

20 files changed

+124
-78
lines changed

20 files changed

+124
-78
lines changed

grammarinator-cxx/libgrammarinator/include/grammarinator/runtime/DefaultModel.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
1+
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
22
//
33
// Licensed under the BSD 3-Clause License
44
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
@@ -35,8 +35,8 @@ class DefaultModel : public Model {
3535
return grammarinator::util::random_weighted_choice(weights);
3636
}
3737

38-
bool quantify(const Rule* node, int idx, int cnt, int start, int stop) override {
39-
return grammarinator::util::random_real(0.0, 1.0) > 0.5;
38+
bool quantify(const Rule* node, int idx, int cnt, int start, int stop, double prob = 0.5) override {
39+
return grammarinator::util::random_real(0.0, 1.0) < prob;
4040
}
4141

4242
std::string charset(const Rule* node, int idx, const std::vector<std::string>& chars) override {

grammarinator-cxx/libgrammarinator/include/grammarinator/runtime/Model.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
1+
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
22
//
33
// Licensed under the BSD 3-Clause License
44
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
@@ -30,7 +30,7 @@ class Model {
3030
virtual ~Model() = default;
3131

3232
virtual int choice(const Rule* node, int idx, const std::vector<double>& weights) = 0;
33-
virtual bool quantify(const Rule* node, int idx, int cnt, int start, int stop) = 0;
33+
virtual bool quantify(const Rule* node, int idx, int cnt, int start, int stop, double prob = 0.5) = 0;
3434
virtual std::string charset(const Rule* node, int idx, const std::vector<std::string>& chars) = 0;
3535
};
3636

grammarinator-cxx/libgrammarinator/include/grammarinator/runtime/WeightedModel.hpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
1+
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
22
//
33
// Licensed under the BSD 3-Clause License
44
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
@@ -23,15 +23,19 @@ namespace runtime {
2323
*/
2424
class WeightedModel : public Model {
2525
public:
26-
using WeightMapKey = std::tuple<std::string, size_t, size_t>;
27-
using WeightMap = std::map<WeightMapKey, double>;
26+
using AltMapKey = std::tuple<std::string, size_t, size_t>;
27+
using AltMap = std::map<AltMapKey, double>;
28+
29+
using QuantMapKey = std::tuple<std::string, size_t>;
30+
using QuantMap = std::map<QuantMapKey, double>;
2831

2932
private:
3033
Model* model;
31-
const WeightMap& weights;
34+
const AltMap& weights;
35+
const QuantMap& probs;
3236

3337
public:
34-
explicit WeightedModel(Model* model, const WeightMap& weights = {}) noexcept : Model(), model(model), weights(weights) {}
38+
explicit WeightedModel(Model* model, const AltMap& weights = {}, const QuantMap& probs = {}) noexcept : Model(), model(model), weights(weights), probs(probs) {}
3539
WeightedModel(const WeightedModel& other) = delete;
3640
WeightedModel& operator=(const WeightedModel& other) = delete;
3741
WeightedModel(WeightedModel&& other) = delete;
@@ -41,14 +45,15 @@ class WeightedModel : public Model {
4145
int choice(const Rule* node, int idx, const std::vector<double>& cweights) override {
4246
std::vector<double> multiplied_weights(cweights.size());
4347
for (size_t i = 0; i < cweights.size(); ++i) {
44-
auto it = weights.find(WeightMapKey(node->name, idx, i));
48+
auto it = weights.find(AltMapKey(node->name, idx, i));
4549
multiplied_weights[i] = cweights[i] * (it != weights.end() ? it->second : 1.0);
4650
}
4751
return model->choice(node, idx, multiplied_weights);
4852
}
4953

50-
bool quantify(const Rule* node, int idx, int cnt, int start, int stop) override {
51-
return model->quantify(node, idx, cnt, start, stop);
54+
bool quantify(const Rule* node, int idx, int cnt, int start, int stop, double prob = 0.5) override {
55+
auto it = probs.find(QuantMapKey(node->name, idx));
56+
return model->quantify(node, idx, cnt, start, stop, it != probs.end() ? it->second : prob);
5257
}
5358

5459
std::string charset(const Rule* node, int idx, const std::vector<std::string>& chars) override {

grammarinator-cxx/libgrammarinator/include/grammarinator/tool/DefaultGeneratorFactory.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
1+
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
22
//
33
// Licensed under the BSD 3-Clause License
44
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
@@ -23,16 +23,17 @@ namespace tool {
2323
template<class GeneratorClass, class ModelClass = runtime::DefaultModel, class... ListenerClasses>
2424
class DefaultGeneratorFactory : public GeneratorFactory<GeneratorClass> {
2525
private:
26-
runtime::WeightedModel::WeightMap weights;
26+
runtime::WeightedModel::AltMap weights;
27+
runtime::WeightedModel::QuantMap probs;
2728

2829
public:
29-
explicit DefaultGeneratorFactory(const runtime::WeightedModel::WeightMap& weights = {})
30-
: weights(weights) {}
30+
explicit DefaultGeneratorFactory(const runtime::WeightedModel::AltMap& weights = {}, const runtime::WeightedModel::QuantMap& probs = {})
31+
: weights(weights), probs(probs) {}
3132

3233
GeneratorClass operator()(const runtime::RuleSize& limit = runtime::RuleSize::max()) {
3334
runtime::Model* model = new ModelClass();
34-
if (!weights.empty()) {
35-
model = new runtime::WeightedModel(model, weights);
35+
if (!weights.empty() || !probs.empty()) {
36+
model = new runtime::WeightedModel(model, weights, probs);
3637
}
3738
std::vector<runtime::Listener*> listeners = {(new ListenerClasses())...};
3839
return GeneratorClass(model, listeners, limit);

grammarinator-cxx/libgrammarinator/include/grammarinator/tool/JsonWeightLoader.hpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
1+
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
22
//
33
// Licensed under the BSD 3-Clause License
44
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
@@ -23,7 +23,7 @@ class JsonWeightLoader {
2323
JsonWeightLoader(JsonWeightLoader&& other) = delete;
2424
JsonWeightLoader& operator=(JsonWeightLoader&& other) = delete;
2525

26-
void load(const std::string& fn, runtime::WeightedModel::WeightMap& weights) {
26+
void load(const std::string& fn, runtime::WeightedModel::AltMap& weights, runtime::WeightedModel::QuantMap& probs) {
2727
std::ifstream wf(fn);
2828
if (!wf) {
2929
GRAMMARINATOR_LOG_FATAL("Failed to open the weights JSON file for reading: {}", fn);
@@ -36,11 +36,20 @@ class JsonWeightLoader {
3636
return;
3737
}
3838

39-
for (auto& [rule, alts] : data.items()) {
40-
for (auto& [alternation_idx, alternatives] : alts.items()) {
39+
if (data.contains("alts") && data["alts"].is_object()) {
40+
for (auto& [rule, alts] : data["alts"].items()) {
41+
for (auto& [alternation_idx, alternatives] : alts.items()) {
4142
for (auto& [alternative_idx, w] : alternatives.items()) {
4243
weights[{rule, static_cast<size_t>(std::stoul(alternation_idx)), static_cast<size_t>(std::stoul(alternative_idx))}] = w.get<double>();
4344
}
45+
}
46+
}
47+
}
48+
if (data.contains("quants") && data["quants"].is_object()) {
49+
for (auto& [rule, quants] : data["quants"].items()) {
50+
for (auto& [quantifier_idx, quant] : quants.items()) {
51+
probs[{rule, static_cast<size_t>(std::stoul(quantifier_idx))}] = quant.get<double>();
52+
}
4453
}
4554
}
4655
}

grammarinator-cxx/libgrlf/grlf.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
1+
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
22
//
33
// Licensed under the BSD 3-Clause License
44
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
@@ -33,7 +33,8 @@ struct {
3333
int max_tokens = 0;
3434
int max_depth = 0;
3535
int memo_size = 0;
36-
runtime::WeightedModel::WeightMap weights;
36+
runtime::WeightedModel::AltMap weights;
37+
runtime::WeightedModel::QuantMap probs;
3738
} settings;
3839

3940
void initialize_int_arg(const std::string& arg, const std::string& name, int& dest) {
@@ -81,11 +82,11 @@ void initialize_double_arg(const std::string& arg, const std::string& name, doub
8182
}
8283
}
8384

84-
void initialize_weights_arg(const std::string& arg, const std::string& name, runtime::WeightedModel::WeightMap& weights) {
85+
void initialize_weights_arg(const std::string& arg, const std::string& name, runtime::WeightedModel::AltMap& weights, runtime::WeightedModel::QuantMap& probs) {
8586
std::string prefix = "-" + name + "=";
8687
if (arg.rfind(prefix, 0) == 0) {
8788
std::string weights_path = arg.substr(prefix.length());
88-
JsonWeightLoader().load(weights_path, weights);
89+
JsonWeightLoader().load(weights_path, weights, probs);
8990
}
9091
}
9192

@@ -116,7 +117,7 @@ grammarinator::tool::LibFuzzerTool<grammarinator::tool::DefaultGeneratorFactory<
116117
libfuzzer_tool() {
117118
static const GRAMMARINATOR_TREECODEC treeCodec;
118119
static grammarinator::tool::LibFuzzerTool<grammarinator::tool::DefaultGeneratorFactory<GRAMMARINATOR_GENERATOR, GRAMMARINATOR_MODEL, GRAMMARINATOR_LISTENER>>
119-
tool(grammarinator::tool::DefaultGeneratorFactory<GRAMMARINATOR_GENERATOR, GRAMMARINATOR_MODEL, GRAMMARINATOR_LISTENER>(settings.weights),
120+
tool(grammarinator::tool::DefaultGeneratorFactory<GRAMMARINATOR_GENERATOR, GRAMMARINATOR_MODEL, GRAMMARINATOR_LISTENER>(settings.weights, settings.probs),
120121
GRAMMARINATOR_GENERATOR::_default_rule,
121122
grammarinator::runtime::RuleSize(settings.max_depth > 0 ? settings.max_depth : grammarinator::runtime::RuleSize::max().depth,
122123
settings.max_tokens > 0 ? settings.max_tokens : grammarinator::runtime::RuleSize::max().tokens),
@@ -150,7 +151,7 @@ int GrammarinatorInitialize(int* argc, char*** argv) {
150151
initialize_int_arg((*argv)[i], "max_tokens", settings.max_tokens);
151152
initialize_int_arg((*argv)[i], "max_depth", settings.max_depth);
152153
initialize_int_arg((*argv)[i], "memo_size", settings.memo_size);
153-
initialize_weights_arg((*argv)[i], "weights", settings.weights);
154+
initialize_weights_arg((*argv)[i], "weights", settings.weights, settings.probs);
154155
}
155156
}
156157
return 0;

grammarinator-cxx/tools/generate.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
1+
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
22
//
33
// Licensed under the BSD 3-Clause License
44
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
@@ -71,7 +71,7 @@ int main(int argc, char **argv) {
7171
cxxopts::value<int>()->default_value(std::to_string(RuleSize::max().tokens)),
7272
"NUM")
7373
("weights",
74-
"JSON file defining custom weights for alternatives",
74+
"JSON file defining custom weights for alternatives and quantifiers",
7575
cxxopts::value<std::string>(),
7676
"FILE")
7777
("p,population",
@@ -158,9 +158,10 @@ int main(int argc, char **argv) {
158158
TreeCodec* tree_codec = std::get<1>(tree_format_it->second)();
159159

160160
// Parse optional custom weights from JSON
161-
runtime::WeightedModel::WeightMap weights;
161+
runtime::WeightedModel::AltMap weights;
162+
runtime::WeightedModel::QuantMap probs;
162163
if (args.count("weights")) {
163-
JsonWeightLoader().load(args["weights"].as<std::string>(), weights);
164+
JsonWeightLoader().load(args["weights"].as<std::string>(), weights, probs);
164165
}
165166

166167
auto allowlist = args.count("allowlist")
@@ -179,7 +180,7 @@ int main(int argc, char **argv) {
179180

180181
FilePopulation *population = args.count("population") ? new FilePopulation(args["population"].as<std::string>(), tree_extension, *tree_codec) : nullptr;
181182
int seed = args.count("random-seed") ? args["random-seed"].as<int>() : std::random_device()();
182-
GeneratorTool generator(DefaultGeneratorFactory<GRAMMARINATOR_GENERATOR, GRAMMARINATOR_MODEL, GRAMMARINATOR_LISTENER>(weights), // generator_factory
183+
GeneratorTool generator(DefaultGeneratorFactory<GRAMMARINATOR_GENERATOR, GRAMMARINATOR_MODEL, GRAMMARINATOR_LISTENER>(weights, probs), // generator_factory
183184
args.count("stdout") ? "" : args["out"].as<std::string>(), // out_format
184185
args.count("rule") ? args["rule"].as<std::string>() : "", // rule
185186
RuleSize(args["max-depth"].as<int>(), args["max-tokens"].as<int>()), // limit

grammarinator/generate.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2017-2025 Renata Hodovan, Akos Kiss.
1+
# Copyright (c) 2017-2026 Renata Hodovan, Akos Kiss.
22
#
33
# Licensed under the BSD 3-Clause License
44
# <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
@@ -43,12 +43,21 @@ def process_args(args):
4343
raise ValueError('Custom weights should point to an existing JSON file.')
4444

4545
with open(args.weights, 'r') as f:
46-
weights = {}
47-
for rule, alts in json.load(f).items():
46+
weights, probs = {}, {}
47+
data = json.load(f)
48+
49+
for rule, alts in data.get('alts', {}).items():
4850
for alternation_idx, alternatives in alts.items():
4951
for alternative_idx, w in alternatives.items():
5052
weights[(rule, int(alternation_idx), int(alternative_idx))] = w
53+
54+
for rule, quants in data.get('quants', {}).items():
55+
for quantifier_idx, quant in quants.items():
56+
probs[(rule, int(quantifier_idx))] = quant
5157
args.weights = weights
58+
args.probs = probs
59+
else:
60+
args.probs = None
5261

5362
if args.population:
5463
args.population = abspath(args.population)
@@ -58,6 +67,7 @@ def generator_tool_helper(args, lock=None):
5867
return GeneratorTool(generator_factory=DefaultGeneratorFactory(args.generator,
5968
model_class=args.model,
6069
weights=args.weights,
70+
probs=args.probs,
6171
listener_classes=args.listener),
6272
rule=args.rule, out_format=args.out, lock=lock,
6373
limit=RuleSize(depth=args.max_depth, tokens=args.max_tokens),
@@ -99,7 +109,7 @@ def execute():
99109
parser.add_argument('--max-tokens', default=RuleSize.max.tokens, type=int, metavar='NUM',
100110
help='maximum token number during generation (default: %(default)f).')
101111
parser.add_argument('-w', '--weights', metavar='FILE',
102-
help='JSON file defining custom weights for alternatives.')
112+
help='JSON file defining custom weights for alternatives and quantifiers.')
103113

104114
# Evolutionary settings.
105115
parser.add_argument('--population', metavar='DIR',

grammarinator/runtime/default_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020-2025 Renata Hodovan, Akos Kiss.
1+
# Copyright (c) 2020-2026 Renata Hodovan, Akos Kiss.
22
#
33
# Licensed under the BSD 3-Clause License
44
# <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
@@ -25,7 +25,7 @@ def choice(self, node: Rule, idx: int, weights: list[float]) -> int:
2525
# assert sum(weights) > 0, 'Sum of weights is zero.'
2626
return random.choices(range(len(weights)), weights=weights)[0]
2727

28-
def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float) -> bool:
28+
def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float, prob: float = 0.5) -> bool:
2929
"""
3030
After generating the minimum expected items (``start``) and before
3131
reaching the maximum expected items (``stop``), quantify decides about
@@ -34,7 +34,7 @@ def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float
3434
Parameters ``node``, ``idx``, ``cnt``, ``start``, and ``stop`` are
3535
unused.
3636
"""
37-
return bool(random.getrandbits(1))
37+
return random.random() < prob
3838

3939
def charset(self, node: Rule, idx: int, chars: tuple[int, ...]) -> str:
4040
"""

grammarinator/runtime/dispatching_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020-2025 Renata Hodovan, Akos Kiss.
1+
# Copyright (c) 2020-2026 Renata Hodovan, Akos Kiss.
22
#
33
# Licensed under the BSD 3-Clause License
44
# <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
@@ -28,14 +28,14 @@ def choice(self, node: Rule, idx: int, weights: list[float]) -> int:
2828
name = 'choice_' + node.name
2929
return (getattr(self, name) if hasattr(self, name) else super().choice)(node, idx, weights)
3030

31-
def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float) -> bool:
31+
def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float, prob: float = 0.5) -> bool:
3232
"""
3333
Trampoline to the ``quantify_{node.name}`` method of the subclassed
3434
model, if it exists. Otherwise, it calls the default implementation
3535
(:meth:`DefaultModel.quantify`).
3636
"""
3737
name = 'quantify_' + node.name
38-
return (getattr(self, name) if hasattr(self, name) else super().quantify)(node, idx, cnt, start, stop)
38+
return (getattr(self, name) if hasattr(self, name) else super().quantify)(node, idx, cnt, start, stop, prob)
3939

4040
def charset(self, node: Rule, idx: int, chars: tuple[int, ...]) -> str:
4141
"""

0 commit comments

Comments
 (0)