
Commit 14d2192
Gemma Chat Template Support for LoRA Finetuning

- Add auto-detection for the Gemma turn format (<start_of_turn>model\n ... <end_of_turn>)
- Fall back to the ChatML-style format for other models
- Use the model's default chat template, i.e. no Jinja chat-template file needs to be supplied

This enables instruction finetuning on any model.
1 parent 1c32e62 commit 14d2192
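For illustration only (not part of the commit): the auto-detection is a plain substring search over the rendered conversation. The ChatML-side marker constants (START_AST, START_TAG, END_TAG) are defined elsewhere in common.cpp and do not appear in this diff, so the <|im_start|> strings below are assumptions:

    // Sketch of the detection heuristic on two hypothetical renders.
    #include <iostream>
    #include <string>

    int main() {
        const std::string gemma_render =
            "<start_of_turn>user\nHi<end_of_turn>\n"
            "<start_of_turn>model\nHello!<end_of_turn>\n";
        const std::string chatml_render =                 // markers assumed
            "<|im_start|>user\nHi<|im_end|>\n"
            "<|im_start|>assistant\nHello!<|im_end|>\n";

        for (const std::string * r : {&gemma_render, &chatml_render}) {
            // Same test the diff uses: presence of the Gemma model-turn header.
            bool is_gemma = r->find("<start_of_turn>model\n") != std::string::npos;
            std::cout << (is_gemma ? "gemma" : "chatml-style") << "\n";
        }
        return 0;
    }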

File tree

1 file changed: +65 −28 lines

common/common.cpp

Lines changed: 65 additions & 28 deletions
@@ -1636,11 +1636,21 @@ ggml_opt_dataset_t common_opt_sft_dataset_init(
         } else {
             chat_template_source.assign(std::istreambuf_iterator<char>(tmpl_file), std::istreambuf_iterator<char>());
             tmpl_file.close();
-            try {
-                chat_templates = common_chat_templates_init(llama_get_model(ctx), chat_template_source);
-            } catch (const std::exception & e) {
-                LOG_ERR("Warning: Failed to parse chat template '%s': %s\n", chat_template_path.c_str(), e.what());
-            }
+        }
+    }
+
+    try {
+        chat_templates = common_chat_templates_init(llama_get_model(ctx), chat_template_source);
+        if (chat_template_source.empty()) {
+            LOG_INF("Using model's built-in chat template\n");
+        } else {
+            LOG_INF("Using custom chat template from: %s\n", chat_template_path.c_str());
+        }
+    } catch (const std::exception & e) {
+        if (!chat_template_path.empty()) {
+            LOG_ERR("Warning: Failed to parse chat template '%s': %s\n", chat_template_path.c_str(), e.what());
+        } else {
+            LOG_ERR("Warning: Failed to initialize chat template: %s\n", e.what());
         }
     }
 
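Note on the fallback path (sketch, not from the commit): when no template file is supplied, chat_template_source stays empty and the init call above falls back to the template embedded in the model's GGUF metadata. Assuming the common_chat_templates_* API from llama.cpp's common/chat.h (exact field names here are assumptions), rendering a conversation would look roughly like:

    // Empty override string => use the model's built-in chat template (assumed).
    common_chat_templates_ptr tmpls =
        common_chat_templates_init(llama_get_model(ctx), /*chat_template_override=*/"");

    common_chat_msg user_msg;
    user_msg.role    = "user";
    user_msg.content = "What is 2+2?";

    common_chat_templates_inputs inputs;
    inputs.messages.push_back(user_msg);
    inputs.add_generation_prompt = true;

    // The rendered text ("render" in the next hunk) would come from apply().
    std::string render = common_chat_templates_apply(tmpls.get(), inputs).prompt;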
@@ -1756,33 +1766,60 @@ ggml_opt_dataset_t common_opt_sft_dataset_init(
         std::vector<Span> assistant_spans;
 
         {
-            size_t from = 0;
-            while (true) {
-                size_t open = render.find(START_AST, from);
-                if (open == std::string::npos) break;
-
-                // Include the role token ("assistant") and everything through the closing tag/newlines
-                size_t lo = open + START_TAG.size();
-                if (lo > render.size()) {
-                    lo = render.size();
-                }
+            bool is_gemma = render.find("<start_of_turn>model\n") != std::string::npos;
+
+            if (is_gemma) {
+                const std::string GEMMA_START = "<start_of_turn>model\n";
+                const std::string GEMMA_END   = "<end_of_turn>";
+
+                size_t from = 0;
+                while (true) {
+                    size_t open = render.find(GEMMA_START, from);
+                    if (open == std::string::npos) break;
+                    size_t lo = open;
+                    size_t close = render.find(GEMMA_END, lo);
+                    if (close == std::string::npos) {
+                        assistant_spans.push_back({lo, render.size()});
+                        break;
+                    }
 
-                size_t close = render.find(END_TAG, open + START_AST.size());
-                if (close == std::string::npos) {
-                    assistant_spans.push_back({lo, render.size()});
-                    break;
-                }
+                    size_t hi = close + GEMMA_END.size();
+                    if (hi < render.size() && render[hi] == '\n') {
+                        hi++;
+                    }
+                    assistant_spans.push_back({lo, std::min(hi, render.size())});
 
-                size_t hi = close + END_TAG.size();
-                if (hi <= lo) {
-                    lo = open;
-                    hi = close + END_TAG.size();
+                    from = hi;
                 }
+            } else {
+                size_t from = 0;
+                while (true) {
+                    size_t open = render.find(START_AST, from);
+                    if (open == std::string::npos) break;
+
+                    // Include the role token ("assistant") and everything through the closing tag/newlines
+                    size_t lo = open + START_TAG.size();
+                    if (lo > render.size()) {
+                        lo = render.size();
+                    }
 
-                assistant_spans.push_back({lo, std::min(hi, render.size())});
+                    size_t close = render.find(END_TAG, open + START_AST.size());
+                    if (close == std::string::npos) {
+                        assistant_spans.push_back({lo, render.size()});
+                        break;
+                    }
 
-                size_t next_from = hi;
-                from = next_from;
+                    size_t hi = close + END_TAG.size();
+                    if (hi <= lo) {
+                        lo = open;
+                        hi = close + END_TAG.size();
+                    }
+
+                    assistant_spans.push_back({lo, std::min(hi, render.size())});
+
+                    size_t next_from = hi;
+                    from = next_from;
+                }
             }
         }
 
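Self-contained demo (not from the commit) of what the Gemma branch extracts, assuming Span is a simple {lo, hi} pair of byte offsets, as its use above suggests:

    #include <algorithm>
    #include <cstdio>
    #include <string>
    #include <vector>

    struct Span { size_t lo, hi; }; // assumed shape of the Span type

    int main() {
        const std::string GEMMA_START = "<start_of_turn>model\n";
        const std::string GEMMA_END   = "<end_of_turn>";

        // Hypothetical two-turn render in Gemma format.
        const std::string render =
            "<start_of_turn>user\nWhat is 2+2?<end_of_turn>\n"
            "<start_of_turn>model\n4<end_of_turn>\n";

        std::vector<Span> assistant_spans;
        size_t from = 0;
        while (true) {
            size_t open = render.find(GEMMA_START, from);
            if (open == std::string::npos) break;
            size_t lo = open;                         // span keeps the turn header
            size_t close = render.find(GEMMA_END, lo);
            if (close == std::string::npos) {         // unterminated turn: take the rest
                assistant_spans.push_back({lo, render.size()});
                break;
            }
            size_t hi = close + GEMMA_END.size();
            if (hi < render.size() && render[hi] == '\n') hi++; // swallow trailing newline
            assistant_spans.push_back({lo, std::min(hi, render.size())});
            from = hi;
        }

        // Prints the single model-turn span, turn header included.
        for (const Span & s : assistant_spans) {
            printf("[%zu, %zu) %s", s.lo, s.hi, render.substr(s.lo, s.hi - s.lo).c_str());
        }
        return 0;
    }

Note the asymmetry between the two branches: the Gemma span starts at the <start_of_turn>model\n header itself (lo = open), while the ChatML-style span starts just past START_TAG, keeping only the role token, so the two paths mask slightly different sets of tokens around turn boundaries.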
@@ -1814,7 +1851,7 @@ ggml_opt_dataset_t common_opt_sft_dataset_init(
             LOG_WRN("Warning: Conversation %zu has zero assistant tokens after masking\n", i);
             continue;
         }
-
+
         all_tokenized_data.push_back(tokens_full);
         all_assistant_masks.push_back(assistant_mask);
     }
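The zero-assistant-token warning above comes from the masking step, which is outside this diff. A hedged sketch of the idea, assuming each token can be mapped back to the character range it covers in render (the helper below is hypothetical, not llama.cpp API):

    #include <cstddef>
    #include <vector>

    struct Span       { size_t lo, hi; };
    struct TokenRange { size_t lo, hi; }; // hypothetical: char range one token covers

    // Mark a token as "assistant" when its character range overlaps any span.
    std::vector<bool> build_assistant_mask(const std::vector<TokenRange> & toks,
                                           const std::vector<Span> & spans) {
        std::vector<bool> mask(toks.size(), false);
        for (size_t i = 0; i < toks.size(); i++) {
            for (const Span & s : spans) {
                if (toks[i].lo < s.hi && s.lo < toks[i].hi) { // half-open overlap
                    mask[i] = true;
                    break;
                }
            }
        }
        return mask;
    }

If no token overlaps any span (for example, a template that renders neither Gemma nor ChatML-style markers), the mask stays all-false and the conversation is skipped with the warning above.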
