Skip to content

Commit 5452d73

Browse files
authored
jinja: correct stats for tojson and string filters (ggml-org#19785)
1 parent ed48378 commit 5452d73

File tree

4 files changed

+101
-7
lines changed

4 files changed

+101
-7
lines changed

common/jinja/runtime.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ value identifier::execute_impl(context & ctx) {
8585
auto builtins = global_builtins();
8686
if (!it->is_undefined()) {
8787
if (ctx.is_get_stats) {
88-
it->stats.used = true;
88+
value_t::stats_t::mark_used(it);
8989
}
9090
JJ_DEBUG("Identifier '%s' found, type = %s", val.c_str(), it->type().c_str());
9191
return it;
@@ -277,7 +277,7 @@ value binary_expression::execute_impl(context & ctx) {
277277
static value try_builtin_func(context & ctx, const std::string & name, value & input, bool undef_on_missing = false) {
278278
JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str());
279279
if (ctx.is_get_stats) {
280-
input->stats.used = true;
280+
value_t::stats_t::mark_used(input);
281281
input->stats.ops.insert(name);
282282
}
283283
auto builtins = input->get_builtins();
@@ -448,7 +448,7 @@ value for_statement::execute_impl(context & ctx) {
448448

449449
// mark the variable being iterated as used for stats
450450
if (ctx.is_get_stats) {
451-
iterable_val->stats.used = true;
451+
value_t::stats_t::mark_used(iterable_val);
452452
iterable_val->stats.ops.insert("array_access");
453453
}
454454

@@ -470,7 +470,7 @@ value for_statement::execute_impl(context & ctx) {
470470
items.push_back(std::move(tuple));
471471
}
472472
if (ctx.is_get_stats) {
473-
iterable_val->stats.used = true;
473+
value_t::stats_t::mark_used(iterable_val);
474474
iterable_val->stats.ops.insert("object_access");
475475
}
476476
} else {
@@ -480,7 +480,7 @@ value for_statement::execute_impl(context & ctx) {
480480
items.push_back(item);
481481
}
482482
if (ctx.is_get_stats) {
483-
iterable_val->stats.used = true;
483+
value_t::stats_t::mark_used(iterable_val);
484484
iterable_val->stats.ops.insert("array_access");
485485
}
486486
}
@@ -817,8 +817,9 @@ value member_expression::execute_impl(context & ctx) {
817817
}
818818

819819
if (ctx.is_get_stats && val && object && property) {
820-
val->stats.used = true;
821-
object->stats.used = true;
820+
value_t::stats_t::mark_used(val);
821+
value_t::stats_t::mark_used(object);
822+
value_t::stats_t::mark_used(property);
822823
if (is_val<value_int>(property)) {
823824
object->stats.ops.insert("array_access");
824825
} else if (is_val<value_string>(property)) {

common/jinja/value.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ static value tojson(const func_args & args) {
161161
value val_separators = args.get_kwarg_or_pos("separators", 3);
162162
value val_sort = args.get_kwarg_or_pos("sort_keys", 4);
163163
int indent = -1;
164+
if (args.ctx.is_get_stats) {
165+
// mark as used (recursively) for stats
166+
auto val_input = args.get_pos(0);
167+
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
168+
}
164169
if (is_val<value_int>(val_indent)) {
165170
indent = static_cast<int>(val_indent->as_int());
166171
}
@@ -891,6 +896,11 @@ const func_builtins & value_array_t::get_builtins() const {
891896
}},
892897
{"string", [](const func_args & args) -> value {
893898
args.ensure_vals<value_array>();
899+
if (args.ctx.is_get_stats) {
900+
// mark as used (recursively) for stats
901+
auto val_input = args.get_pos(0);
902+
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
903+
}
894904
return mk_val<value_string>(args.get_pos(0)->as_string());
895905
}},
896906
{"tojson", tojson},
@@ -1046,6 +1056,11 @@ const func_builtins & value_object_t::get_builtins() const {
10461056
{"tojson", tojson},
10471057
{"string", [](const func_args & args) -> value {
10481058
args.ensure_vals<value_object>();
1059+
if (args.ctx.is_get_stats) {
1060+
// mark as used (recursively) for stats
1061+
auto val_input = args.get_pos(0);
1062+
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
1063+
}
10491064
return mk_val<value_string>(args.get_pos(0)->as_string());
10501065
}},
10511066
{"length", [](const func_args & args) -> value {
@@ -1358,4 +1373,21 @@ std::string value_to_string_repr(const value & val) {
13581373
}
13591374
}
13601375

1376+
// stats utility
// Marks `val` as used by setting val->stats.used = true.
// When `deep` is true and `val` is an array or an object, the call recurses so that
// every array item, and every object key and value, is marked too; because `deep` is
// forwarded unchanged, nested containers are walked all the way down.
// When `deep` is false only `val` itself is touched.
// Note: `val` is dereferenced without a null check.
1377+
void value_t::stats_t::mark_used(value & val, bool deep) {
1378+
val->stats.used = true;
1379+
if (deep) {
1380+
if (is_val<value_array>(val)) {
1381+
for (auto & item : val->val_arr) {
1382+
mark_used(item, deep);
1383+
}
1384+
} else if (is_val<value_object>(val)) {
1385+
for (auto & pair : val->val_obj) {
1386+
// both the key (pair.first) and the value (pair.second) are marked
mark_used(pair.first, deep);
1387+
mark_used(pair.second, deep);
1388+
}
1389+
}
1390+
}
1391+
}
1392+
13611393
} // namespace jinja

common/jinja/value.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ struct value_t {
118118
bool used = false;
119119
// ops can be builtin calls or operators: "array_access", "object_access"
120120
std::set<std::string> ops;
121+
// utility to recursively mark value and its children as used
122+
static void mark_used(value & val, bool deep = false);
121123
} stats;
122124

123125
value_t() = default;

tests/test-jinja.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ static void test_string_methods(testing & t);
3232
static void test_array_methods(testing & t);
3333
static void test_object_methods(testing & t);
3434
static void test_hasher(testing & t);
35+
static void test_stats(testing & t);
3536
static void test_fuzzing(testing & t);
3637

3738
static bool g_python_mode = false;
@@ -70,6 +71,7 @@ int main(int argc, char *argv[]) {
7071
t.test("object methods", test_object_methods);
7172
if (!g_python_mode) {
7273
t.test("hasher", test_hasher);
74+
t.test("stats", test_stats);
7375
t.test("fuzzing", test_fuzzing);
7476
}
7577

@@ -1795,6 +1797,63 @@ static void test_hasher(testing & t) {
17951797
});
17961798
}
17971799

1800+
// Checks the `stats.used` flags recorded on context values after a template is
// executed with ctx.is_get_stats enabled: values reached through member/index
// access are expected to be flagged, untouched siblings are expected to stay
// unflagged, and children of a value piped through `tojson` are expected to be flagged.
static void test_stats(testing & t) {
1801+
// Helper: tokenizes and parses `tmpl`, stores `vars` in the context under the
// global name "val", executes the program with is_get_stats set, and returns
// the "val" global so the caller can inspect its stats.
static auto get_stats = [](const std::string & tmpl, const json & vars) -> jinja::value {
1802+
jinja::lexer lexer;
1803+
auto lexer_res = lexer.tokenize(tmpl);
1804+
1805+
jinja::program prog = jinja::parse_from_tokens(lexer_res);
1806+
1807+
jinja::context ctx(tmpl);
1808+
jinja::global_from_json(ctx, json{{ "val", vars }}, true);
1809+
ctx.is_get_stats = true;
1810+
1811+
jinja::runtime runtime(ctx);
1812+
runtime.execute(prog);
1813+
1814+
return ctx.get_val("val");
1815+
};
1816+
1817+
t.test("stats", [](testing & t) {
1818+
jinja::value val = get_stats(
1819+
"{{val.num}} "
1820+
"{{val.str}} "
1821+
"{{val.arr[0]}} "
1822+
"{{val.obj.key1}} "
1823+
"{{val.nested | tojson}}",
1824+
// Note: the json below will be wrapped inside "val" in the context
1825+
json{
1826+
{"num", 1},
1827+
{"str", "abc"},
1828+
{"arr", json::array({1, 2, 3})},
1829+
{"obj", json::object({{"key1", 1}, {"key2", 2}, {"key3", 3}})},
1830+
{"nested", json::object({
1831+
{"inner_key1", json::array({1, 2})},
1832+
{"inner_key2", json::object({{"a", "x"}, {"b", "y"}})}
1833+
})},
1834+
// NOTE(review): "mixed" is not referenced by the template above and no
// assertion below inspects it — TODO confirm whether coverage was intended.
{"mixed", json::object({
1835+
{"used", 1},
1836+
{"unused", 2},
1837+
})},
1838+
}
1839+
);
1840+
1841+
// scalars read directly through member access
t.assert_true("num is used", val->at("num")->stats.used);
1842+
t.assert_true("str is used", val->at("str")->stats.used);
1843+
1844+
// only the indexed element of the array should be flagged
t.assert_true("arr is used", val->at("arr")->stats.used);
1845+
t.assert_true("arr[0] is used", val->at("arr")->at(0)->stats.used);
1846+
t.assert_true("arr[1] is not used", !val->at("arr")->at(1)->stats.used);
1847+
1848+
// only the accessed key of the object should be flagged
t.assert_true("obj is used", val->at("obj")->stats.used);
1849+
t.assert_true("obj.key1 is used", val->at("obj")->at("key1")->stats.used);
1850+
t.assert_true("obj.key2 is not used", !val->at("obj")->at("key2")->stats.used);
1851+
1852+
// children of the value passed through `tojson` should be flagged
t.assert_true("inner_key1[0] is used", val->at("nested")->at("inner_key1")->at(0)->stats.used);
1853+
t.assert_true("inner_key2.a is used", val->at("nested")->at("inner_key2")->at("a")->stats.used);
1854+
});
1855+
}
1856+
17981857
static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
17991858
t.test(name, [&tmpl, &vars, &expect](testing & t) {
18001859
jinja::lexer lexer;

0 commit comments

Comments
 (0)