@@ -1636,11 +1636,21 @@ ggml_opt_dataset_t common_opt_sft_dataset_init(
16361636 } else {
16371637 chat_template_source.assign (std::istreambuf_iterator<char >(tmpl_file), std::istreambuf_iterator<char >());
16381638 tmpl_file.close ();
1639- try {
1640- chat_templates = common_chat_templates_init (llama_get_model (ctx), chat_template_source);
1641- } catch (const std::exception & e) {
1642- LOG_ERR (" Warning: Failed to parse chat template '%s': %s\n " , chat_template_path.c_str (), e.what ());
1643- }
1639+ }
1640+ }
1641+
1642+ try {
1643+ chat_templates = common_chat_templates_init (llama_get_model (ctx), chat_template_source);
1644+ if (chat_template_source.empty ()) {
1645+ LOG_INF (" Using model's built-in chat template\n " );
1646+ } else {
1647+ LOG_INF (" Using custom chat template from: %s\n " , chat_template_path.c_str ());
1648+ }
1649+ } catch (const std::exception & e) {
1650+ if (!chat_template_path.empty ()) {
1651+ LOG_ERR (" Warning: Failed to parse chat template '%s': %s\n " , chat_template_path.c_str (), e.what ());
1652+ } else {
1653+ LOG_ERR (" Warning: Failed to initialize chat template: %s\n " , e.what ());
16441654 }
16451655 }
16461656
@@ -1756,33 +1766,60 @@ ggml_opt_dataset_t common_opt_sft_dataset_init(
17561766 std::vector<Span> assistant_spans;
17571767
17581768 {
1759- size_t from = 0 ;
1760- while (true ) {
1761- size_t open = render.find (START_AST, from);
1762- if (open == std::string::npos) break ;
1763-
1764- // Include the role token ("assistant") and everything through the closing tag/newlines
1765- size_t lo = open + START_TAG.size ();
1766- if (lo > render.size ()) {
1767- lo = render.size ();
1768- }
1769+ bool is_gemma = render.find (" <start_of_turn>model\n " ) != std::string::npos;
1770+
1771+ if (is_gemma) {
1772+ const std::string GEMMA_START = " <start_of_turn>model\n " ;
1773+ const std::string GEMMA_END = " <end_of_turn>" ;
1774+
1775+ size_t from = 0 ;
1776+ while (true ) {
1777+ size_t open = render.find (GEMMA_START, from);
1778+ if (open == std::string::npos) break ;
1779+ size_t lo = open;
1780+ size_t close = render.find (GEMMA_END, lo);
1781+ if (close == std::string::npos) {
1782+ assistant_spans.push_back ({lo, render.size ()});
1783+ break ;
1784+ }
17691785
1770- size_t close = render. find (END_TAG, open + START_AST .size () );
1771- if (close == std::string::npos ) {
1772- assistant_spans. push_back ({lo, render. size ()}) ;
1773- break ;
1774- }
1786+ size_t hi = close + GEMMA_END .size ();
1787+ if (hi < render. size () && render[hi] == ' \n ' ) {
1788+ hi++ ;
1789+ }
1790+ assistant_spans. push_back ({lo, std::min (hi, render. size ())});
17751791
1776- size_t hi = close + END_TAG.size ();
1777- if (hi <= lo) {
1778- lo = open;
1779- hi = close + END_TAG.size ();
1792+ from = hi;
17801793 }
1794+ } else {
1795+ size_t from = 0 ;
1796+ while (true ) {
1797+ size_t open = render.find (START_AST, from);
1798+ if (open == std::string::npos) break ;
1799+
1800+ // Include the role token ("assistant") and everything through the closing tag/newlines
1801+ size_t lo = open + START_TAG.size ();
1802+ if (lo > render.size ()) {
1803+ lo = render.size ();
1804+ }
17811805
1782- assistant_spans.push_back ({lo, std::min (hi, render.size ())});
1806+ size_t close = render.find (END_TAG, open + START_AST.size ());
1807+ if (close == std::string::npos) {
1808+ assistant_spans.push_back ({lo, render.size ()});
1809+ break ;
1810+ }
17831811
1784- size_t next_from = hi;
1785- from = next_from;
1812+ size_t hi = close + END_TAG.size ();
1813+ if (hi <= lo) {
1814+ lo = open;
1815+ hi = close + END_TAG.size ();
1816+ }
1817+
1818+ assistant_spans.push_back ({lo, std::min (hi, render.size ())});
1819+
1820+ size_t next_from = hi;
1821+ from = next_from;
1822+ }
17861823 }
17871824 }
17881825
@@ -1814,7 +1851,7 @@ ggml_opt_dataset_t common_opt_sft_dataset_init(
18141851 LOG_WRN (" Warning: Conversation %zu has zero assistant tokens after masking\n " , i);
18151852 continue ;
18161853 }
1817-
1854+
18181855 all_tokenized_data.push_back (tokens_full);
18191856 all_assistant_masks.push_back (assistant_mask);
18201857 }
0 commit comments