Skip to content

Commit bbd50cc

Browse files
authored
[BUGFIX] Fix tokenizer bug when getting action masks with the enable_thinking argument (agentscope-ai#316)
Co-authored-by: 问昊 <[email protected]>
1 parent 332f37a commit bbd50cc

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

trinity/common/models/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def tokenize_and_mask_messages_hf(
1616
messages: List[dict],
1717
tools: Optional[List[dict]] = None,
1818
chat_template: Optional[str] = None,
19+
enable_thinking: Optional[bool] = None,
1920
) -> Tuple[torch.Tensor, torch.Tensor, int]:
2021
"""Calculate the assistant token mask with `chat_template`.
2122
@@ -35,6 +36,7 @@ def tokenize_and_mask_messages_hf(
3536
tools=tools,
3637
chat_template=chat_template,
3738
add_generation_prompt=False,
39+
enable_thinking=enable_thinking,
3840
padding=False,
3941
truncation=True,
4042
return_tensors="pt",
@@ -52,6 +54,7 @@ def tokenize_and_mask_messages_default(
5254
messages: List[dict],
5355
tools: Optional[List[dict]] = None,
5456
chat_template: Optional[str] = None,
57+
enable_thinking: Optional[bool] = None,
5558
) -> Tuple[torch.Tensor, torch.Tensor, int]:
5659
"""Calculate the assistant token mask.
5760
@@ -78,6 +81,7 @@ def tokenize_and_mask_messages_default(
7881
tools=tools,
7982
chat_template=chat_template,
8083
add_generation_prompt=False,
84+
enable_thinking=enable_thinking,
8185
padding=False,
8286
truncation=True,
8387
return_tensors="pt",
@@ -91,6 +95,7 @@ def tokenize_and_mask_messages_default(
9195
tools=tools,
9296
chat_template=chat_template,
9397
add_generation_prompt=True,
98+
enable_thinking=enable_thinking,
9499
padding=False,
95100
truncation=True,
96101
return_tensors="pt",
@@ -102,6 +107,7 @@ def tokenize_and_mask_messages_default(
102107
tools=tools,
103108
chat_template=chat_template,
104109
add_generation_prompt=False,
110+
enable_thinking=enable_thinking,
105111
padding=False,
106112
truncation=True,
107113
return_tensors="pt",

trinity/common/models/vllm_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ async def convert_messages_to_experience(
354354
messages=messages,
355355
tools=tools,
356356
chat_template=self.chat_template,
357+
enable_thinking=self.enable_thinking,
357358
) # (seq_length, ), (seq_length, )
358359
logprobs = await self.logprobs(token_ids=token_ids.tolist()) # (seq_length - 1,)
359360
return Experience(

0 commit comments

Comments (0)