This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit fd9f704

Build tokenizer from TokenizerArgs
1 parent 3f6fa2d commit fd9f704

File tree

3 files changed: +20 -35 lines changed


torchchat/distributed/dist_run.py

Lines changed: 12 additions & 28 deletions

@@ -102,32 +102,11 @@ def decode(
 
 
 def _build_chat_tokenizer(
-    model_name: str,
-    model_base_name: Optional[str] = None,
+    tokenizer_args: TokenizerArgs,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
     """Builds a tokenizer for the given model name"""
-
-    # Try to infer the model base name from the model name:
-    # e.g. "llama2-7b-chat" -> "llama2"
-    if model_base_name is None:
-        model_base_name = model_name.split("-")[0]
-        logger.info(
-            f"Using model base name '{model_base_name}' to build tokenizer. "
-            "If not found, please specify it using the `model_base_name` argument."
-        )
-
-    # Create base args for tokenizer
-    default_model_dir = Path(
-        os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
-    ).expanduser()
-
-    tokenconfig = {
-        "model_directory": default_model_dir,
-        "model": model_base_name,
-        "tokenizer_path": None,
-    }
-    args = dict_to_args(tokenconfig)
-    tokenizer_args = TokenizerArgs.from_args(args)
+
+    tokenizer_args = TokenizerArgs.from_args(tokenizer_args)
     tokenizer = tokenizer_args.t
     assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}"
     logger.info(

@@ -313,9 +292,14 @@ def _cleanup():
 ]
 
 
-def main(args, pipe):
+def main(
+    builder_args,
+    tokenizer_args,
+    pipe,
+):
     model_name = "llama3"  # args.model_name
-    pp_degree = args.pp
+    # print(f"{builder_args.checkpoint_path=}")
+    pp_degree = builder_args.pp
 
     rank, world_size = _init_distributed()
     logger.info(f"Worker started: {rank=}, {world_size=}")

@@ -332,7 +316,7 @@ def main(args, pipe):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    tokenizer = _build_chat_tokenizer(model_name)
+    tokenizer = _build_chat_tokenizer(tokenizer_args)
 
     set_precision(model_dtype)
     logger.info(f"Using cache precision {model_dtype}")

@@ -385,7 +369,7 @@ def main(args, pipe):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device, config, args.chpt_from)
+        _load_model_weights(model, distribution, device, config, builder_args.chpt_from)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"

torchchat/distributed/generate.py

Lines changed: 7 additions & 3 deletions

@@ -32,7 +32,9 @@ def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
     return target(*args, **kwargs)
 
 
-def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
+def _launch_distributed_inference(
+    builder_args: BuilderArgs, tokenizer_args: TokenizerArgs
+) -> tuple[List]:
     # create programmatic elastic launch
     print("Launching distributed inference ...")
 

@@ -49,7 +51,7 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
         pipes.append(server_pipe)
         proc = mp.Process(
             target=partial(_setup_env, num_processes_per_node, rank, main),
-            args=(builder_args, client_pipe),
+            args=(builder_args, tokenizer_args, client_pipe),
         )
         proc.start()

@@ -189,7 +191,9 @@ def __init__(
 
         self.check_args()
 
-        self.procs, self.pipes = _launch_distributed_inference(builder_args)
+        self.procs, self.pipes = _launch_distributed_inference(
+            builder_args, tokenizer_args
+        )
 
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.loop)
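
Taken together, the three hunks in distributed/generate.py thread the new argument from the DistributedGenerator constructor down to every worker process. A condensed sketch of the launcher under assumed surrounding code (the loop bound num_processes_per_node and pipe bookkeeping are taken from context, not shown in full in the diff):

    def _launch_distributed_inference(
        builder_args: BuilderArgs, tokenizer_args: TokenizerArgs
    ) -> tuple[List]:
        procs, pipes = [], []
        for rank in range(num_processes_per_node):  # assumed in enclosing scope
            server_pipe, client_pipe = mp.Pipe(duplex=True)
            pipes.append(server_pipe)
            proc = mp.Process(
                # _setup_env sets per-rank env vars, then invokes main()
                target=partial(_setup_env, num_processes_per_node, rank, main),
                # main() now receives tokenizer_args alongside builder_args
                args=(builder_args, tokenizer_args, client_pipe),
            )
            proc.start()
            procs.append(proc)
        return procs, pipes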

torchchat/generate.py

Lines changed: 1 addition & 4 deletions

@@ -30,8 +30,8 @@
     BuilderArgs,
     TokenizerArgs,
 )
-from torchchat.model import Model, ModelType
 from torchchat.distributed.generate import DistributedGenerator
+from torchchat.model import Model, ModelType
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
 

@@ -1228,7 +1228,6 @@ def main(args):
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
-
 
     for _ in gen.chat(generator_args):
         pass

@@ -1248,5 +1247,3 @@ def main(args):
 
         print(f"Model output: {response}")
         dist_gen.shutdown()
-
-
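
For callers, torchchat/generate.py changes only import order and whitespace: both argument bundles are already built from the parsed CLI args and handed to the distributed generator. A hedged usage sketch (import path and constructor argument order are assumptions; parameters beyond the two bundles are elided):

    from torchchat.cli.builder import BuilderArgs, TokenizerArgs  # assumed location
    from torchchat.distributed.generate import DistributedGenerator

    def run_distributed(args):
        builder_args = BuilderArgs.from_args(args)      # model/checkpoint options
        tokenizer_args = TokenizerArgs.from_args(args)  # tokenizer path/options
        # After this commit, both bundles reach every worker's main() via
        # _launch_distributed_inference.
        dist_gen = DistributedGenerator(
            builder_args,
            tokenizer_args,
            # ...remaining constructor arguments unchanged by this commit...
        )
        return dist_gen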
