Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 32241ff

Browse files
authored
[Distributed] Fix tiktokenizer decoding (#1257)
1 parent dc3d35e commit 32241ff

File tree

2 files changed

+13
-3
lines changed

2 files changed, +13 −3 lines changed

dist_run.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
442442
# New token generated each iteration
443443
# need a row dimension for each prompt in the batch
444444
new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
445-
logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
446445
# Store the generated tokens
447446
res = []
448447

@@ -519,7 +518,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
519518

520519
# Decode the output
521520
if pp_rank == last_pp_rank:
522-
# logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
523521
new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
524522
res.append(new_token)
525523
if not args.disable_in_flight_decode:
@@ -541,7 +539,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
541539
# token ids. Thus cat'ing along dim 1.
542540
res = torch.cat(res, dim=1)
543541
res_list = res.tolist()
544-
responses = tokenizer.decode(res_list)
542+
if isinstance(tokenizer, TiktokenTokenizer):
543+
# For TiktokenTokenizer, we need to decode prompt by prompt.
544+
# TODO: is there a better way to do this?
545+
responses = [tokenizer.decode(sequence) for sequence in res_list]
546+
else: # SentencePieceProcessor
547+
# For SentencePieceProcessor, we can decode the entire 2D list at once.
548+
responses = tokenizer.decode(res_list)
545549
# Show prompts and responses
546550
for prompt_text, response_text in zip(prompt, responses):
547551
logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")

torchchat/distributed/safetensor_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,19 @@ def get_hf_weight_map_and_path(
8888
raise FileNotFoundError(
8989
f"Weight index file for {model_id} does not exist in HF cache."
9090
)
91+
logger.info(
92+
f"Loading weight map from: {index_file}"
93+
)
9194
weight_map = read_weights_from_json(index_file)
9295
if weight_map is None:
9396
raise ValueError(f"Weight map not found in config file {index_file}")
9497
weight_map, new_to_old_keymap = remap_weight_keys(weight_map)
9598
weight_path = os.path.dirname(index_file)
9699
if not os.path.exists(weight_path):
97100
raise FileNotFoundError(f"Weight path {weight_path} does not exist")
101+
logger.info(
102+
f"Loading weights from: {weight_path}"
103+
)
98104
return weight_map, weight_path, new_to_old_keymap
99105

100106

0 commit comments

Comments (0)