This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit a290249

Merge branch 'main' into lessw2020/input_pos
2 parents: 31fb1cf + d0993b3

File tree: 6 files changed, +212 -111 lines

torchchat/cli/convert_hf_checkpoint.py

Lines changed: 32 additions & 2 deletions
@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import glob
 import json
 import os
 import re
@@ -41,7 +42,12 @@ def convert_hf_checkpoint(
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    model_map_json = model_dir / "pytorch_model.bin.index.json"
+    model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
+    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
+    if len(model_map_json_matches):
+        model_map_json = model_map_json_matches[0]
+    else:
+        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
@@ -96,9 +102,33 @@ def permute(w, n_heads):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(
+
+        # The state_dict can be loaded from either a torch zip file or
+        # safetensors. We take our best guess from the name and try all
+        # possibilities
+        load_pt_mmap = lambda: torch.load(
             str(file), map_location="cpu", mmap=True, weights_only=True
         )
+        load_pt_no_mmap = lambda: torch.load(
+            str(file), map_location="cpu", mmap=False, weights_only=True
+        )
+        def load_safetensors():
+            import safetensors.torch
+            with open(file, "rb") as handle:
+                return safetensors.torch.load(handle.read())
+        if "safetensors" in str(file):
+            loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
+        else:
+            loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]
+
+        state_dict = None
+        for loader in loaders:
+            try:
+                state_dict = loader()
+                break
+            except Exception:
+                continue
+        assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
     final_result = {}
     for key, value in merged_result.items():
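The loop above guesses the checkpoint format from the filename but falls back to the other loaders if the first attempt fails. A minimal standalone sketch of that fallback pattern, assuming torch and safetensors are installed (the helper name load_state_dict_any is illustrative, not part of the commit):

from pathlib import Path

import torch


def load_state_dict_any(file: Path):
    # Candidate loaders; the filename decides which one to try first.
    def load_pt_mmap():
        return torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)

    def load_pt_no_mmap():
        return torch.load(str(file), map_location="cpu", mmap=False, weights_only=True)

    def load_safetensors():
        import safetensors.torch

        with open(file, "rb") as handle:
            return safetensors.torch.load(handle.read())

    if "safetensors" in str(file):
        loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
    else:
        loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]

    # Return the first successful result; fail loudly only if all loaders fail.
    for loader in loaders:
        try:
            return loader()
        except Exception:
            continue
    raise RuntimeError(f"Unable to load tensors from {file}")

The mmap'd torch loader is tried first for torch-style shards to keep peak host memory low; the safetensors path reads the file into bytes before parsing.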

torchchat/cli/download.py

Lines changed: 28 additions & 2 deletions
@@ -22,18 +22,44 @@
 def _download_hf_snapshot(
     model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
 ):
-    from huggingface_hub import snapshot_download
+    from huggingface_hub import model_info, snapshot_download
     from requests.exceptions import HTTPError
 
     # Download and store the HF model artifacts.
     print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
     try:
+        # Fetch the info about the model's repo
+        model_info = model_info(model_config.distribution_path, token=hf_token)
+        model_fnames = [f.rfilename for f in model_info.siblings]
+
+        # Check the model config for preference between safetensors and pth
+        has_pth = any(f.endswith(".pth") for f in model_fnames)
+        has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
+
+        # If told to prefer safetensors, ignore pth files
+        if model_config.prefer_safetensors:
+            if not has_safetensors:
+                print(
+                    f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.",
+                    file=sys.stderr,
+                )
+                exit(1)
+            ignore_patterns = "*.pth"
+
+        # If the model has both, prefer pth files over safetensors
+        elif has_pth and has_safetensors:
+            ignore_patterns = "*safetensors*"
+
+        # Otherwise, download everything
+        else:
+            ignore_patterns = None
+
         snapshot_download(
             model_config.distribution_path,
             local_dir=artifact_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
-            ignore_patterns="*safetensors*",
+            ignore_patterns=ignore_patterns,
         )
     except HTTPError as e:
         if e.response.status_code == 401:  # Missing HuggingFace CLI login.
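The branching above reduces to a single ignore_patterns value passed to snapshot_download. Restated as a small pure function for readability, under the same assumptions (filenames come from the Hub repo listing, prefer_safetensors is the new config flag); pick_ignore_patterns itself is illustrative and not part of the commit:

from typing import List, Optional


def pick_ignore_patterns(model_fnames: List[str], prefer_safetensors: bool) -> Optional[str]:
    has_pth = any(f.endswith(".pth") for f in model_fnames)
    has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)

    if prefer_safetensors:
        if not has_safetensors:
            raise RuntimeError("prefer_safetensors is set but the repo has no .safetensors files")
        return "*.pth"          # keep safetensors, skip pth
    if has_pth and has_safetensors:
        return "*safetensors*"  # default: pth wins when both formats exist
    return None                 # single-format repo: download everything


assert pick_ignore_patterns(["model.pth", "model.safetensors"], False) == "*safetensors*"
assert pick_ignore_patterns(["model.safetensors"], True) == "*.pth"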

torchchat/generate.py

Lines changed: 96 additions & 45 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import base64
 import itertools
 import logging
 import os
@@ -12,6 +13,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -600,9 +602,8 @@ def generate(
 
         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
-        T = prompt.size(0)
-        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-        T_new = T + max_new_tokens
+        prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)
@@ -616,7 +617,7 @@
                         batch_size=1,
                         dtype=self.dtype,
                         encoder_max_seq_len=6404,
-                        decoder_max_seq_len=T_new,
+                        decoder_max_seq_len=max_seq_length,
                     )
                 else:
                     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
@@ -629,7 +630,7 @@
             model.reset_caches()
 
         input_pos = torch.arange(
-            start_pos, T + start_pos, device=device, dtype=torch.int
+            start_pos, prompt_length + start_pos, device=device, dtype=torch.int
        )
 
         prefill_t0 = time.perf_counter()
@@ -655,7 +656,9 @@
         # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
         callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)
 
-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
         accept_counts = [0] * (
             speculate_k + 1
         )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +681,7 @@
                 )
 
                 accept_counts[len(next_tokens) - 1] += 1
-                num_added = min(T_new - input_pos - 1, len(next_tokens))
+                num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
                 for token in next_tokens[:num_added,]:
                     callback(token)
                     yield token, None
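Renaming T to prompt_length makes the token-budget arithmetic in the hunks above easier to follow: the generation budget is the cache capacity minus what the cache already holds. A worked example with made-up sizes:

# Illustrative numbers only.
max_seq_length = 2048      # decoder cache capacity
start_pos = 100            # tokens already in the cache from earlier turns
prompt_length = 412        # tokens in the new prompt
requested = 2000           # caller-requested max_new_tokens

max_new_tokens = min(requested, max_seq_length - start_pos - prompt_length)
assert max_new_tokens == 1536  # 2048 - 100 - 412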
@@ -741,6 +744,7 @@ def _gen_model_input(
         prompt: Union[str | List[Any]],
         image_prompts: Optional[List[str | Image.Image]] = None,
         max_new_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
         Convert prompt and image prompts into consumable model input args.
@@ -757,7 +761,7 @@ def _gen_model_input(
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """
 
-        # Not Llama 3.2 11B
+        # Text-Only model
        if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
@@ -778,32 +782,69 @@
         assert (
             image_prompts is None or len(image_prompts) == 1
         ), "At most one image is supported at the moment"
+
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
-            images = image_prompts
+            images = None
 
         assert (
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(
-            prompt, str
-        ), "(Currently) prompt must be a str for Flamingo models"
 
-        is_multimodal = images is not None
-        content = [{"type": "text", "content": prompt}]
+        image_found = False
+        messages = []
+        for message in prompt:
+            if isinstance(message["content"], str):
+                if not image_found and image_prompts:
+                    messages.append(
+                        Message(
+                            role=message["role"],
+                            content=[
+                                {"type": "image", "content": images[0]},
+                                {"type": "text", "content": message["content"]},
+                            ],
+                        )
+                    )
+                    image_found = True
+                else:
+                    messages.append(Message(**message))
+
+            elif isinstance(message["content"], list):
+                images = None
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        prompt_arg = content_dict["text"]
+                    elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        images = [Image.open(BytesIO(base64_decoded))]
+                        image_found = True
+
+                is_multimodal = images is not None
+                content = [{"type": "text", "content": prompt_arg}]
+
+                if is_multimodal:
+                    content = [{"type": "image", "content": images[0]}] + content
 
-        if is_multimodal:
-            content = [{"type": "image", "content": images[0]}] + content
+                messages.append(
+                    Message(
+                        role=message["role"],
+                        content=content,
+                    )
+                )
 
-        messages = [
+        messages.append(
             Message(
-                role="user",
-                content=content,
-                eot=True,
-            ),
-            Message(role="assistant", content=""),
-        ]
+                role="assistant",
+                content="",
+            )
+        )
 
         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
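The new list-content branch accepts API-style chat messages in which an image arrives as a base64 data URL under an image_url part. A self-contained sketch of just that decoding step: the tiny PNG and the message dict are fabricated for illustration, and the image_url value is assumed to be the raw data-URL string, as in the branch above.

import base64
from io import BytesIO

from PIL import Image

# Build a 1x1 PNG in memory and wrap it the way an API-style message would.
buf = BytesIO()
Image.new("RGB", (1, 1), color=(255, 0, 0)).save(buf, format="PNG")
content_dict = {
    "type": "image_url",
    "image_url": "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode(),
}

# Mirror of the decode above: strip the data-URL prefix, decode, open with PIL.
base64_decoded = base64.b64decode(content_dict["image_url"].split(";base64,")[1])
image = Image.open(BytesIO(base64_decoded))
assert image.size == (1, 1)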

@@ -812,7 +853,7 @@ def _gen_model_input(
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
+            if image_found:
                 batch = padded_collate_tiled_images_and_mask(
                     [data], pad_direction="left", pad_max_images=1
                 )
@@ -822,17 +863,27 @@
                 batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                     self.dtype
                 )
+
             else:
                 encoded = torch.tensor(data["tokens"], device=device).view(-1)
                 seq_len = encoded.size(0)
                 batch = {}
 
                 total_response_length = seq_len + max_new_tokens
-                batch["causal_mask"] = torch.tril(
-                    torch.ones(
-                        size=(total_response_length, total_response_length),
-                        dtype=torch.bool,
-                    )
+                batch["causal_mask"] = torch.nn.functional.pad(
+                    torch.tril(
+                        torch.ones(
+                            size=(total_response_length, total_response_length),
+                            dtype=torch.bool,
+                        )
+                    ),
+                    (
+                        0,
+                        max_seq_len - total_response_length,
+                        0,
+                        max_seq_len - total_response_length,
+                    ),
+                    value=0,
                 )
 
         logging.debug(encoded)
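The causal-mask change pads the (total_response_length x total_response_length) lower-triangular mask out to the new max_seq_len argument, so its shape matches a cache sized for the full sequence rather than just this response. A standalone sketch of the shape behaviour with made-up sizes:

import torch

seq_len, max_new_tokens, max_seq_len = 5, 3, 16    # illustrative sizes
total_response_length = seq_len + max_new_tokens   # 8

mask = torch.nn.functional.pad(
    torch.tril(
        torch.ones(
            size=(total_response_length, total_response_length),
            dtype=torch.bool,
        )
    ),
    # (left, right, top, bottom) padding over the last two dimensions.
    (0, max_seq_len - total_response_length, 0, max_seq_len - total_response_length),
    value=0,
)

assert mask.shape == (max_seq_len, max_seq_len)
assert mask[:8, :8].equal(torch.tril(torch.ones(8, 8, dtype=torch.bool)))
assert not mask[8:, :].any() and not mask[:, 8:].any()  # padded region stays False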
@@ -845,12 +896,6 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
 
-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-        )
-
         model_size = sum(
             [
                 p.numel() * p.dtype.itemsize
@@ -896,6 +941,12 @@
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            [{"role": "user", "content": generator_args.prompt}],
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )
 
         if generator_args.chat_mode:
             print(
@@ -907,16 +958,16 @@
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")
 
-        elif not generator_args.is_torchtune_model:
-            max_seq_length = min(
-                encoded.size(0) + generator_args.max_new_tokens,
-                (
-                    text_transformer_args.block_size
-                    if text_transformer_args is not None
-                    else 2048
-                ),
-                max_seq_length,
-            )
+        # elif not generator_args.is_torchtune_model:
+        #     max_seq_length = min(
+        #         encoded.size(0) + generator_args.max_new_tokens,
+        #         (
+        #             text_transformer_args.block_size
+        #             if text_transformer_args is not None
+        #             else 2048
+        #         ),
+        #         max_seq_length,
+        #     )
 
         max_seq_length = (
             max_seq_length + self.speculative_builder_args.speculate_k + 1

torchchat/model_config/model_config.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ class ModelConfig:
     checkpoint_file: str = field(default="model.pth")
     tokenizer_file: str = field(default="tokenizer.model")
     transformer_params_key: str = field(default=None)
+    prefer_safetensors: bool = field(default=False)
 
 
 # Keys are stored in lowercase.
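Because the new field defaults to False, existing model configs keep the current pth-first download behaviour and a config must opt in explicitly. A trimmed stand-in dataclass showing the effect (this sketch is not the real ModelConfig; only the fields visible in the hunk are reproduced):

from dataclasses import dataclass, field


@dataclass
class ModelConfigSketch:
    checkpoint_file: str = field(default="model.pth")
    tokenizer_file: str = field(default="tokenizer.model")
    prefer_safetensors: bool = field(default=False)


assert ModelConfigSketch().prefer_safetensors is False           # old configs unchanged
assert ModelConfigSketch(prefer_safetensors=True).prefer_safetensors  # explicit opt-in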
