Skip to content

Commit 98e57ca

Browse files
authored
chat: fix case where template accepts type content only (ggml-org#19419)
* chat: fix case where template accepts type content only * rm stray log * reuse render_message_to_json
1 parent 262364e commit 98e57ca

File tree

5 files changed

+55
-9
lines changed

5 files changed

+55
-9
lines changed

common/chat.cpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -380,15 +380,46 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
380380
return msgs;
381381
}
382382

383-
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
383+
static json render_message_to_json(const std::vector<common_chat_msg> & msgs, const jinja::caps & c) {
384+
if (!c.supports_string_content && !c.supports_typed_content) {
385+
LOG_WRN("%s: Neither string content nor typed content is supported by the template. This is unexpected and may lead to issues.\n", __func__);
386+
}
387+
388+
bool only_string_accepted = c.supports_string_content && !c.supports_typed_content;
389+
bool only_typed_accepted = !c.supports_string_content && c.supports_typed_content;
390+
384391
json messages = json::array();
385392
for (const auto & msg : msgs) {
386-
json jmsg = msg.to_json_oaicompat(concat_typed_text);
387-
messages.push_back(jmsg);
393+
if (only_string_accepted) {
394+
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ true);
395+
messages.push_back(jmsg);
396+
} else if (only_typed_accepted) {
397+
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false);
398+
if (jmsg.at("content").is_string()) {
399+
jmsg["content"] = json::array({
400+
json{
401+
{"type", "text"},
402+
{"text", jmsg.at("content").get<std::string>()},
403+
}
404+
});
405+
}
406+
messages.push_back(jmsg);
407+
} else {
408+
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false);
409+
messages.push_back(jmsg);
410+
}
388411
}
389412
return messages;
390413
}
391414

415+
// DEPRECATED: only used in tests
416+
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
417+
jinja::caps c;
418+
c.supports_string_content = true;
419+
c.supports_typed_content = !concat_typed_text;
420+
return render_message_to_json(msgs, c);
421+
}
422+
392423
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
393424
std::vector<common_chat_tool> result;
394425

@@ -3020,7 +3051,7 @@ static common_chat_params common_chat_templates_apply_jinja(
30203051
: *tmpls->template_default;
30213052
const auto & src = tmpl.source();
30223053
const auto & caps = tmpl.original_caps();
3023-
params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
3054+
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
30243055
params.add_generation_prompt = inputs.add_generation_prompt;
30253056
params.tool_choice = inputs.tool_choice;
30263057
params.reasoning_format = inputs.reasoning_format;

common/chat.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
240240

241241
// Parses a JSON array of messages in OpenAI's chat completion API format.
242242
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);
243+
244+
// DEPRECATED: only used in tests
243245
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
244246

245247
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);

common/jinja/caps.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ static void caps_print_stats(value & v, const std::string & path) {
6363

6464
std::map<std::string, bool> caps::to_map() const {
6565
return {
66-
{"requires_typed_content", requires_typed_content},
66+
{"supports_string_content", supports_string_content},
67+
{"supports_typed_content", supports_typed_content},
6768
{"supports_tools", supports_tools},
6869
{"supports_tool_calls", supports_tool_calls},
6970
{"supports_parallel_tool_calls", supports_parallel_tool_calls},
@@ -89,7 +90,7 @@ caps caps_get(jinja::program & prog) {
8990
return v->stats.ops.find(op_name) != v->stats.ops.end();
9091
};
9192

92-
// case: typed content requirement
93+
// case: typed content support
9394
caps_try_execute(
9495
prog,
9596
[&]() {
@@ -105,12 +106,16 @@ caps caps_get(jinja::program & prog) {
105106
// tools
106107
return json{nullptr};
107108
},
108-
[&](bool, value & messages, value &) {
109+
[&](bool success, value & messages, value &) {
109110
auto & content = messages->at(0)->at("content");
110111
caps_print_stats(content, "messages[0].content");
111112
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
112113
// accessed as an array
113-
result.requires_typed_content = true;
114+
result.supports_typed_content = true;
115+
}
116+
if (!success) {
117+
// failed to execute with content as string
118+
result.supports_string_content = false;
114119
}
115120
}
116121
);

common/jinja/caps.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ struct caps {
1414
bool supports_parallel_tool_calls = true;
1515
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
1616

17-
bool requires_typed_content = false; // default: use string content
17+
// one of the 2 content capabilities must be true
18+
bool supports_string_content = true;
19+
bool supports_typed_content = false;
1820

1921
// for reporting on server
2022
std::map<std::string, bool> to_map() const;

common/jinja/runtime.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,12 @@ value for_statement::execute_impl(context & ctx) {
446446

447447
value iterable_val = iter_expr->execute(scope);
448448

449+
// mark the variable being iterated as used for stats
450+
if (ctx.is_get_stats) {
451+
iterable_val->stats.used = true;
452+
iterable_val->stats.ops.insert("array_access");
453+
}
454+
449455
if (iterable_val->is_undefined()) {
450456
JJ_DEBUG("%s", "For loop iterable is undefined, skipping loop");
451457
iterable_val = mk_val<value_array>();

0 commit comments

Comments
 (0)