@@ -16,6 +16,7 @@ def tokenize_and_mask_messages_hf(
1616 messages : List [dict ],
1717 tools : Optional [List [dict ]] = None ,
1818 chat_template : Optional [str ] = None ,
19+ enable_thinking : Optional [bool ] = None ,
1920) -> Tuple [torch .Tensor , torch .Tensor , int ]:
2021 """Calculate the assistant token mask with `chat_template`.
2122
@@ -35,6 +36,7 @@ def tokenize_and_mask_messages_hf(
3536 tools = tools ,
3637 chat_template = chat_template ,
3738 add_generation_prompt = False ,
39+ enable_thinking = enable_thinking ,
3840 padding = False ,
3941 truncation = True ,
4042 return_tensors = "pt" ,
@@ -52,6 +54,7 @@ def tokenize_and_mask_messages_default(
5254 messages : List [dict ],
5355 tools : Optional [List [dict ]] = None ,
5456 chat_template : Optional [str ] = None ,
57+ enable_thinking : Optional [bool ] = None ,
5558) -> Tuple [torch .Tensor , torch .Tensor , int ]:
5659 """Calculate the assistant token mask.
5760
@@ -78,6 +81,7 @@ def tokenize_and_mask_messages_default(
7881 tools = tools ,
7982 chat_template = chat_template ,
8083 add_generation_prompt = False ,
84+ enable_thinking = enable_thinking ,
8185 padding = False ,
8286 truncation = True ,
8387 return_tensors = "pt" ,
@@ -91,6 +95,7 @@ def tokenize_and_mask_messages_default(
9195 tools = tools ,
9296 chat_template = chat_template ,
9397 add_generation_prompt = True ,
98+ enable_thinking = enable_thinking ,
9499 padding = False ,
95100 truncation = True ,
96101 return_tensors = "pt" ,
@@ -102,6 +107,7 @@ def tokenize_and_mask_messages_default(
102107 tools = tools ,
103108 chat_template = chat_template ,
104109 add_generation_prompt = False ,
110+ enable_thinking = enable_thinking ,
105111 padding = False ,
106112 truncation = True ,
107113 return_tensors = "pt" ,
0 commit comments