Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
//
// Licensed under the BSD 3-Clause License
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand Down Expand Up @@ -35,8 +35,8 @@ class DefaultModel : public Model {
return grammarinator::util::random_weighted_choice(weights);
}

bool quantify(const Rule* node, int idx, int cnt, int start, int stop) override {
return grammarinator::util::random_real(0.0, 1.0) > 0.5;
bool quantify(const Rule* node, int idx, int cnt, int start, int stop, double prob = 0.5) override {
return grammarinator::util::random_real(0.0, 1.0) < prob;
}

std::string charset(const Rule* node, int idx, const std::vector<std::string>& chars) override {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
//
// Licensed under the BSD 3-Clause License
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand Down Expand Up @@ -30,7 +30,7 @@ class Model {
virtual ~Model() = default;

virtual int choice(const Rule* node, int idx, const std::vector<double>& weights) = 0;
virtual bool quantify(const Rule* node, int idx, int cnt, int start, int stop) = 0;
virtual bool quantify(const Rule* node, int idx, int cnt, int start, int stop, double prob = 0.5) = 0;
virtual std::string charset(const Rule* node, int idx, const std::vector<std::string>& chars) = 0;
};

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
//
// Licensed under the BSD 3-Clause License
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand All @@ -23,15 +23,19 @@ namespace runtime {
*/
class WeightedModel : public Model {
public:
using WeightMapKey = std::tuple<std::string, size_t, size_t>;
using WeightMap = std::map<WeightMapKey, double>;
using AltMapKey = std::tuple<std::string, size_t, size_t>;
using AltMap = std::map<AltMapKey, double>;

using QuantMapKey = std::tuple<std::string, size_t>;
using QuantMap = std::map<QuantMapKey, double>;

private:
Model* model;
const WeightMap& weights;
const AltMap& weights;
const QuantMap& probs;

public:
explicit WeightedModel(Model* model, const WeightMap& weights = {}) noexcept : Model(), model(model), weights(weights) {}
explicit WeightedModel(Model* model, const AltMap& weights = {}, const QuantMap& probs = {}) noexcept : Model(), model(model), weights(weights), probs(probs) {}
WeightedModel(const WeightedModel& other) = delete;
WeightedModel& operator=(const WeightedModel& other) = delete;
WeightedModel(WeightedModel&& other) = delete;
Expand All @@ -41,14 +45,15 @@ class WeightedModel : public Model {
int choice(const Rule* node, int idx, const std::vector<double>& cweights) override {
std::vector<double> multiplied_weights(cweights.size());
for (size_t i = 0; i < cweights.size(); ++i) {
auto it = weights.find(WeightMapKey(node->name, idx, i));
auto it = weights.find(AltMapKey(node->name, idx, i));
multiplied_weights[i] = cweights[i] * (it != weights.end() ? it->second : 1.0);
}
return model->choice(node, idx, multiplied_weights);
}

bool quantify(const Rule* node, int idx, int cnt, int start, int stop) override {
return model->quantify(node, idx, cnt, start, stop);
bool quantify(const Rule* node, int idx, int cnt, int start, int stop, double prob = 0.5) override {
auto it = probs.find(QuantMapKey(node->name, idx));
return model->quantify(node, idx, cnt, start, stop, it != probs.end() ? it->second : prob);
}

std::string charset(const Rule* node, int idx, const std::vector<std::string>& chars) override {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
//
// Licensed under the BSD 3-Clause License
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand All @@ -23,16 +23,17 @@ namespace tool {
template<class GeneratorClass, class ModelClass = runtime::DefaultModel, class... ListenerClasses>
class DefaultGeneratorFactory : public GeneratorFactory<GeneratorClass> {
private:
runtime::WeightedModel::WeightMap weights;
runtime::WeightedModel::AltMap weights;
runtime::WeightedModel::QuantMap probs;

public:
explicit DefaultGeneratorFactory(const runtime::WeightedModel::WeightMap& weights = {})
: weights(weights) {}
explicit DefaultGeneratorFactory(const runtime::WeightedModel::AltMap& weights = {}, const runtime::WeightedModel::QuantMap& probs = {})
: weights(weights), probs(probs) {}

GeneratorClass operator()(const runtime::RuleSize& limit = runtime::RuleSize::max()) {
runtime::Model* model = new ModelClass();
if (!weights.empty()) {
model = new runtime::WeightedModel(model, weights);
if (!weights.empty() || !probs.empty()) {
model = new runtime::WeightedModel(model, weights, probs);
}
std::vector<runtime::Listener*> listeners = {(new ListenerClasses())...};
return GeneratorClass(model, listeners, limit);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
//
// Licensed under the BSD 3-Clause License
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand All @@ -23,7 +23,7 @@ class JsonWeightLoader {
JsonWeightLoader(JsonWeightLoader&& other) = delete;
JsonWeightLoader& operator=(JsonWeightLoader&& other) = delete;

void load(const std::string& fn, runtime::WeightedModel::WeightMap& weights) {
void load(const std::string& fn, runtime::WeightedModel::AltMap& weights, runtime::WeightedModel::QuantMap& probs) {
std::ifstream wf(fn);
if (!wf) {
GRAMMARINATOR_LOG_FATAL("Failed to open the weights JSON file for reading: {}", fn);
Expand All @@ -36,11 +36,20 @@ class JsonWeightLoader {
return;
}

for (auto& [rule, alts] : data.items()) {
for (auto& [alternation_idx, alternatives] : alts.items()) {
if (data.contains("alts") && data["alts"].is_object()) {
for (auto& [rule, alts] : data["alts"].items()) {
for (auto& [alternation_idx, alternatives] : alts.items()) {
for (auto& [alternative_idx, w] : alternatives.items()) {
weights[{rule, static_cast<size_t>(std::stoul(alternation_idx)), static_cast<size_t>(std::stoul(alternative_idx))}] = w.get<double>();
}
}
}
}
if (data.contains("quants") && data["quants"].is_object()) {
for (auto& [rule, quants] : data["quants"].items()) {
for (auto& [quantifier_idx, quant] : quants.items()) {
probs[{rule, static_cast<size_t>(std::stoul(quantifier_idx))}] = quant.get<double>();
}
}
}
}
Expand Down
13 changes: 7 additions & 6 deletions grammarinator-cxx/libgrlf/grlf.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
//
// Licensed under the BSD 3-Clause License
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand Down Expand Up @@ -33,7 +33,8 @@ struct {
int max_tokens = 0;
int max_depth = 0;
int memo_size = 0;
runtime::WeightedModel::WeightMap weights;
runtime::WeightedModel::AltMap weights;
runtime::WeightedModel::QuantMap probs;
} settings;

void initialize_int_arg(const std::string& arg, const std::string& name, int& dest) {
Expand Down Expand Up @@ -81,11 +82,11 @@ void initialize_double_arg(const std::string& arg, const std::string& name, doub
}
}

void initialize_weights_arg(const std::string& arg, const std::string& name, runtime::WeightedModel::WeightMap& weights) {
void initialize_weights_arg(const std::string& arg, const std::string& name, runtime::WeightedModel::AltMap& weights, runtime::WeightedModel::QuantMap& probs) {
std::string prefix = "-" + name + "=";
if (arg.rfind(prefix, 0) == 0) {
std::string weights_path = arg.substr(prefix.length());
JsonWeightLoader().load(weights_path, weights);
JsonWeightLoader().load(weights_path, weights, probs);
}
}

Expand Down Expand Up @@ -116,7 +117,7 @@ grammarinator::tool::LibFuzzerTool<grammarinator::tool::DefaultGeneratorFactory<
libfuzzer_tool() {
static const GRAMMARINATOR_TREECODEC treeCodec;
static grammarinator::tool::LibFuzzerTool<grammarinator::tool::DefaultGeneratorFactory<GRAMMARINATOR_GENERATOR, GRAMMARINATOR_MODEL, GRAMMARINATOR_LISTENER>>
tool(grammarinator::tool::DefaultGeneratorFactory<GRAMMARINATOR_GENERATOR, GRAMMARINATOR_MODEL, GRAMMARINATOR_LISTENER>(settings.weights),
tool(grammarinator::tool::DefaultGeneratorFactory<GRAMMARINATOR_GENERATOR, GRAMMARINATOR_MODEL, GRAMMARINATOR_LISTENER>(settings.weights, settings.probs),
GRAMMARINATOR_GENERATOR::_default_rule,
grammarinator::runtime::RuleSize(settings.max_depth > 0 ? settings.max_depth : grammarinator::runtime::RuleSize::max().depth,
settings.max_tokens > 0 ? settings.max_tokens : grammarinator::runtime::RuleSize::max().tokens),
Expand Down Expand Up @@ -150,7 +151,7 @@ int GrammarinatorInitialize(int* argc, char*** argv) {
initialize_int_arg((*argv)[i], "max_tokens", settings.max_tokens);
initialize_int_arg((*argv)[i], "max_depth", settings.max_depth);
initialize_int_arg((*argv)[i], "memo_size", settings.memo_size);
initialize_weights_arg((*argv)[i], "weights", settings.weights);
initialize_weights_arg((*argv)[i], "weights", settings.weights, settings.probs);
}
}
return 0;
Expand Down
11 changes: 6 additions & 5 deletions grammarinator-cxx/tools/generate.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2025 Renata Hodovan, Akos Kiss.
// Copyright (c) 2025-2026 Renata Hodovan, Akos Kiss.
//
// Licensed under the BSD 3-Clause License
// <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand Down Expand Up @@ -71,7 +71,7 @@ int main(int argc, char **argv) {
cxxopts::value<int>()->default_value(std::to_string(RuleSize::max().tokens)),
"NUM")
("weights",
"JSON file defining custom weights for alternatives",
"JSON file defining custom weights for alternatives and quantifiers",
cxxopts::value<std::string>(),
"FILE")
("p,population",
Expand Down Expand Up @@ -158,9 +158,10 @@ int main(int argc, char **argv) {
TreeCodec* tree_codec = std::get<1>(tree_format_it->second)();

// Parse optional custom weights from JSON
runtime::WeightedModel::WeightMap weights;
runtime::WeightedModel::AltMap weights;
runtime::WeightedModel::QuantMap probs;
if (args.count("weights")) {
JsonWeightLoader().load(args["weights"].as<std::string>(), weights);
JsonWeightLoader().load(args["weights"].as<std::string>(), weights, probs);
}

auto allowlist = args.count("allowlist")
Expand All @@ -179,7 +180,7 @@ int main(int argc, char **argv) {

FilePopulation *population = args.count("population") ? new FilePopulation(args["population"].as<std::string>(), tree_extension, *tree_codec) : nullptr;
int seed = args.count("random-seed") ? args["random-seed"].as<int>() : std::random_device()();
GeneratorTool generator(DefaultGeneratorFactory<GRAMMARINATOR_GENERATOR, GRAMMARINATOR_MODEL, GRAMMARINATOR_LISTENER>(weights), // generator_factory
GeneratorTool generator(DefaultGeneratorFactory<GRAMMARINATOR_GENERATOR, GRAMMARINATOR_MODEL, GRAMMARINATOR_LISTENER>(weights, probs), // generator_factory
args.count("stdout") ? "" : args["out"].as<std::string>(), // out_format
args.count("rule") ? args["rule"].as<std::string>() : "", // rule
RuleSize(args["max-depth"].as<int>(), args["max-tokens"].as<int>()), // limit
Expand Down
18 changes: 14 additions & 4 deletions grammarinator/generate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2025 Renata Hodovan, Akos Kiss.
# Copyright (c) 2017-2026 Renata Hodovan, Akos Kiss.
#
# Licensed under the BSD 3-Clause License
# <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand Down Expand Up @@ -43,12 +43,21 @@ def process_args(args):
raise ValueError('Custom weights should point to an existing JSON file.')

with open(args.weights, 'r') as f:
weights = {}
for rule, alts in json.load(f).items():
weights, probs = {}, {}
data = json.load(f)

for rule, alts in data.get('alts', {}).items():
for alternation_idx, alternatives in alts.items():
for alternative_idx, w in alternatives.items():
weights[(rule, int(alternation_idx), int(alternative_idx))] = w

for rule, quants in data.get('quants', {}).items():
for quantifier_idx, quant in quants.items():
probs[(rule, int(quantifier_idx))] = quant
args.weights = weights
args.probs = probs
else:
args.probs = None

if args.population:
args.population = abspath(args.population)
Expand All @@ -58,6 +67,7 @@ def generator_tool_helper(args, lock=None):
return GeneratorTool(generator_factory=DefaultGeneratorFactory(args.generator,
model_class=args.model,
weights=args.weights,
probs=args.probs,
listener_classes=args.listener),
rule=args.rule, out_format=args.out, lock=lock,
limit=RuleSize(depth=args.max_depth, tokens=args.max_tokens),
Expand Down Expand Up @@ -99,7 +109,7 @@ def execute():
parser.add_argument('--max-tokens', default=RuleSize.max.tokens, type=int, metavar='NUM',
help='maximum token number during generation (default: %(default)f).')
parser.add_argument('-w', '--weights', metavar='FILE',
help='JSON file defining custom weights for alternatives.')
help='JSON file defining custom weights for alternatives and quantifiers.')

# Evolutionary settings.
parser.add_argument('--population', metavar='DIR',
Expand Down
6 changes: 3 additions & 3 deletions grammarinator/runtime/default_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2025 Renata Hodovan, Akos Kiss.
# Copyright (c) 2020-2026 Renata Hodovan, Akos Kiss.
#
# Licensed under the BSD 3-Clause License
# <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand All @@ -25,7 +25,7 @@ def choice(self, node: Rule, idx: int, weights: list[float]) -> int:
# assert sum(weights) > 0, 'Sum of weights is zero.'
return random.choices(range(len(weights)), weights=weights)[0]

def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float) -> bool:
def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float, prob: float = 0.5) -> bool:
"""
After generating the minimum expected items (``start``) and before
reaching the maximum expected items (``stop``), quantify decides about
Expand All @@ -34,7 +34,7 @@ def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float
Parameters ``node``, ``idx``, ``cnt``, ``start``, and ``stop`` are
unused.
"""
return bool(random.getrandbits(1))
return random.random() < prob

def charset(self, node: Rule, idx: int, chars: tuple[int, ...]) -> str:
"""
Expand Down
6 changes: 3 additions & 3 deletions grammarinator/runtime/dispatching_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2025 Renata Hodovan, Akos Kiss.
# Copyright (c) 2020-2026 Renata Hodovan, Akos Kiss.
#
# Licensed under the BSD 3-Clause License
# <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand Down Expand Up @@ -28,14 +28,14 @@ def choice(self, node: Rule, idx: int, weights: list[float]) -> int:
name = 'choice_' + node.name
return (getattr(self, name) if hasattr(self, name) else super().choice)(node, idx, weights)

def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float) -> bool:
def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float, prob: float = 0.5) -> bool:
"""
Trampoline to the ``quantify_{node.name}`` method of the subclassed
model, if it exists. Otherwise, it calls the default implementation
(:meth:`DefaultModel.quantify`).
"""
name = 'quantify_' + node.name
return (getattr(self, name) if hasattr(self, name) else super().quantify)(node, idx, cnt, start, stop)
return (getattr(self, name) if hasattr(self, name) else super().quantify)(node, idx, cnt, start, stop, prob)

def charset(self, node: Rule, idx: int, chars: tuple[int, ...]) -> str:
"""
Expand Down
6 changes: 4 additions & 2 deletions grammarinator/runtime/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023-2025 Renata Hodovan, Akos Kiss.
# Copyright (c) 2023-2026 Renata Hodovan, Akos Kiss.
#
# Licensed under the BSD 3-Clause License
# <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand Down Expand Up @@ -30,7 +30,7 @@ def choice(self, node: Rule, idx: int, weights: list[float]) -> int:
"""
raise NotImplementedError()

def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float) -> bool:
def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float, prob: float = 0.5) -> bool:
"""
Guide the loop of subtree quantification. This has to make a binary
decision to tell whether to enable the next iteration or stop the loop.
Expand All @@ -44,6 +44,8 @@ def quantify(self, node: Rule, idx: int, cnt: int, start: int, stop: int | float
between ``start`` (inclusive) and ``stop`` (exclusive).
:param start: Lower bound of the quantification range.
:param stop: Upper bound of the quantification range.
:param prob: Predefined probability of enabling the next iteration
(between 0 and 1).
:return: Boolean value enabling the next iteration or stopping it.
"""
raise NotImplementedError()
Expand Down
Loading