
Commit 78debce

remove global var for tokenizer type + patch tokenizer to allow list of sequences
1 parent 6fe1646 commit 78debce

1 file changed

dist_run.py

Lines changed: 25 additions & 29 deletions
@@ -12,13 +12,13 @@
 import os
 from enum import auto, Enum
 from pathlib import Path
-from types import SimpleNamespace
+from types import SimpleNamespace, MethodType
 from typing import Any, Dict, List, Optional, Tuple

 import torch
 import torch.distributed as dist
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
-from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
+from torchchat.cli.builder import TokenizerArgs

 # TODO - these are not distributed specific, consider moving to new package
 from torchchat.distributed.checkpoint_utils import (
@@ -50,7 +50,6 @@


 logger = SingletonLogger.get_logger()
-_tokenizer_type = None  # global variable to store the tokenizer type

 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -61,11 +60,6 @@
 }


-class TokenizerType(Enum):
-    Tiktoken = auto()
-    SentencePiece = auto()
-
-
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -82,14 +76,29 @@ def _create_device_mesh(mesh_dimensions):
 def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
     return SimpleNamespace(**dictionary)

+def _patch_tokenizer(tokenizer):
+    """Patch the tokenizer to support decoding of token ids."""
+    if isinstance(tokenizer, TiktokenTokenizer):
+        # Patch tiktokenizer to allow a list of sequences.
+        # TODO: Upstream to tokenizer modules
+        old_decode = tokenizer.decode
+
+        def decode(self, token_ids: List[int | List[int]], *args, **kwargs) -> str | List[str]:
+            if len(token_ids) < 1:
+                return ""
+            if isinstance(token_ids[0], list):
+                return [old_decode(t, *args, **kwargs) for t in token_ids]
+            else:
+                return old_decode(token_ids, *args, **kwargs)
+
+        tokenizer.decode = MethodType(decode, tokenizer)
+    return tokenizer

 def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
-    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
-
-    global _tokenizer_type
+    """Builds a tokenizer for the given model name"""

     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
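
The hunk above relies on types.MethodType to rebind decode as a bound method on the single tokenizer instance, leaving the TiktokenTokenizer class itself untouched. A minimal, self-contained sketch of the same pattern, using a hypothetical ToyTokenizer in place of the real tokenizer:

from types import MethodType
from typing import List


class ToyTokenizer:
    # Hypothetical stand-in: decodes one flat list of token ids.
    def decode(self, token_ids: List[int]) -> str:
        return "".join(chr(t) for t in token_ids)


tok = ToyTokenizer()
old_decode = tok.decode  # bound method, so it still closes over this instance


def decode(self, token_ids, *args, **kwargs):
    # Accept either one sequence or a list of sequences.
    if len(token_ids) < 1:
        return ""
    if isinstance(token_ids[0], list):
        return [old_decode(t, *args, **kwargs) for t in token_ids]
    return old_decode(token_ids, *args, **kwargs)


tok.decode = MethodType(decode, tok)  # patches this instance only

print(tok.decode([72, 105]))              # "Hi"
print(tok.decode([[72, 105], [79, 75]]))  # ["Hi", "OK"]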
@@ -112,20 +121,14 @@ def _build_chat_tokenizer(
     }
     args = dict_to_args(tokenconfig)
     tokenizer_args = TokenizerArgs.from_args(args)
-    tokenizer = _initialize_tokenizer(tokenizer_args)
+    tokenizer = tokenizer_args.t
     assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}"
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
-    # set global variable _tokenizer_type
-    if isinstance(tokenizer, TiktokenTokenizer):
-        _tokenizer_type = TokenizerType.Tiktoken
-    elif isinstance(tokenizer, SentencePieceProcessor):
-        _tokenizer_type = TokenizerType.SentencePiece
-    else:
-        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")

-    logger.info(f"tokenizer type = {_tokenizer_type}")
+    tokenizer = _patch_tokenizer(tokenizer)
+
     return tokenizer


@@ -568,15 +571,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # token ids. Thus cat'ing along dim 1.
     res = torch.cat(res, dim=1)
     res_list = res.tolist()
-    if _tokenizer_type == TokenizerType.Tiktoken:
-        # For TiktokenTokenizer, we need to decode prompt by prompt.
-        # TODO: is there a better way to do this?
-        responses = [tokenizer.decode(sequence) for sequence in res_list]
-    elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
-        # For SentencePieceProcessor, we can decode the entire 2D list at once.
-        responses = tokenizer.decode(res_list)
-    else:
-        raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
+
+    responses = tokenizer.decode(res_list)

     # Show prompts and responses
     for prompt_text, response_text in zip(prompt, responses):
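
With the patch applied, the branch on tokenizer type disappears: SentencePieceProcessor already accepts a 2D list of token ids, and the patched TiktokenTokenizer now iterates over the inner sequences itself. A usage sketch under those assumptions, where tokenizer stands for the object returned by _build_chat_tokenizer and res for the gathered output tensor:

import torch

# Hypothetical result tensor: 2 prompts, 4 generated tokens each.
res = torch.tensor([[72, 105, 33, 33], [79, 75, 46, 46]])
res_list = res.tolist()  # 2D list: one inner list of token ids per prompt

# One code path for both tokenizer families.
responses = tokenizer.decode(res_list)  # one decoded string per prompt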
