
Commit 32d969e

2/n llava e2e init

1 parent f52007e

7 files changed: +378, -133 lines

torchchat/cli/convert_hf_checkpoint.py: 87 additions & 32 deletions
@@ -7,10 +7,13 @@
 import os
 import re
 import sys
+import glob
 from pathlib import Path
-from typing import Optional
+from typing import Any, Dict, Optional
 
 import torch
+import safetensors.torch
+import shutil
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent
@@ -24,34 +27,34 @@ def _translate_state_dict_for_vision_model(hf_state_dict) -> Dict[str, Any]:
     translated_state_dict = {}
     hf_weight_prefix = "vision_model."
     name_mapping = {
-        f"{hf_weight_prefix}embeddings.class_embedding": "model.encoder.cls_token_embedding.weight",
-        f"{hf_weight_prefix}embeddings.position_embedding.weight": "model.encoder.token_pos_embedding.positional_embedding",
-        f"{hf_weight_prefix}embeddings.patch_embedding.weight": "model.encoder.conv.weight",
-        f"{hf_weight_prefix}pre_layrnorm.weight": "model.encoder.ln_pre.weight",
-        f"{hf_weight_prefix}pre_layrnorm.bias": "model.encoder.ln_pre.bias",
-        f"{hf_weight_prefix}post_layernorm.weight": "model.encoder.ln_post.weight",
-        f"{hf_weight_prefix}post_layernorm.bias": "model.encoder.ln_post.bias",
+        f"{hf_weight_prefix}embeddings.class_embedding": "encoder.cls_token_embedding.weight",
+        f"{hf_weight_prefix}embeddings.position_embedding.weight": "encoder.token_pos_embedding.positional_embedding",
+        f"{hf_weight_prefix}embeddings.patch_embedding.weight": "encoder.conv.weight",
+        f"{hf_weight_prefix}pre_layrnorm.weight": "encoder.ln_pre.weight",
+        f"{hf_weight_prefix}pre_layrnorm.bias": "encoder.ln_pre.bias",
+        f"{hf_weight_prefix}post_layernorm.weight": "encoder.ln_post.weight",
+        f"{hf_weight_prefix}post_layernorm.bias": "encoder.ln_post.bias",
     }
     patterns = [
         (
             rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.(k|q|v)_proj\.(weight|bias)",
-            lambda match: f"model.encoder.layers.{match.group(1)}.attn.{match.group(2)}_proj.{match.group(3)}",
+            lambda match: f"encoder.layers.{match.group(1)}.attn.{match.group(2)}_proj.{match.group(3)}",
         ),
         (
             rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.(weight|bias)",
-            lambda match: f"model.encoder.layers.{match.group(1)}.attn.output_proj.{match.group(2)}",
+            lambda match: f"encoder.layers.{match.group(1)}.attn.output_proj.{match.group(2)}",
         ),
         (
             rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.mlp\.fc(1|2)\.(weight|bias)",
-            lambda match: f"model.encoder.layers.{match.group(1)}.mlp.w{match.group(2)}.{match.group(3)}",
+            lambda match: f"encoder.layers.{match.group(1)}.mlp.w{match.group(2)}.{match.group(3)}",
         ),
         (
             rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm1\.(weight|bias)",
-            lambda match: f"model.encoder.layers.{match.group(1)}.sa_norm.{match.group(2)}",
+            lambda match: f"encoder.layers.{match.group(1)}.sa_norm.{match.group(2)}",
         ),
         (
             rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm2\.(weight|bias)",
-            lambda match: f"model.encoder.layers.{match.group(1)}.mlp_norm.{match.group(2)}",
+            lambda match: f"encoder.layers.{match.group(1)}.mlp_norm.{match.group(2)}",
         ),
     ]
     for pattern, replacement in patterns:
@@ -82,18 +85,18 @@ def _translate_state_dict_for_vision_model(hf_state_dict) -> Dict[str, Any]:
 
 def _translate_state_dict_for_text_model(hf_state_dict) -> Dict[str, Any]:
     key_map = {
-        r"model.layers.([0-9]+).self_attn.q_proj.": r"model.decoder.layers.\1.attention.wq.",
-        r"model.layers.([0-9]+).self_attn.k_proj.": r"model.decoder.layers.\1.attention.wk.",
-        r"model.layers.([0-9]+).self_attn.v_proj.": r"model.decoder.layers.\1.attention.wv.",
-        r"model.layers.([0-9]+).self_attn.o_proj.": r"model.decoder.layers.\1.attention.wo.",
-        r"model.layers.([0-9]+).input_layernorm.": r"model.decoder.layers.\1.attention_norm.",
-        r"model.layers.([0-9]+).mlp.gate_proj.": r"model.decoder.layers.\1.feed_forward.w1.",
-        r"model.layers.([0-9]+).mlp.down_proj.": r"model.decoder.layers.\1.feed_forward.w2.",
-        r"model.layers.([0-9]+).mlp.up_proj.": r"model.decoder.layers.\1.feed_forward.w3.",
-        r"model.layers.([0-9]+).post_attention_layernorm.": r"model.decoder.layers.\1.ffn_norm.",
-        r"model.norm.": r"model.decoder.norm.",
+        r"model.layers.([0-9]+).self_attn.q_proj.": r"decoder.layers.\1.attention.wq.",
+        r"model.layers.([0-9]+).self_attn.k_proj.": r"decoder.layers.\1.attention.wk.",
+        r"model.layers.([0-9]+).self_attn.v_proj.": r"decoder.layers.\1.attention.wv.",
+        r"model.layers.([0-9]+).self_attn.o_proj.": r"decoder.layers.\1.attention.wo.",
+        r"model.layers.([0-9]+).input_layernorm.": r"decoder.layers.\1.attention_norm.",
+        r"model.layers.([0-9]+).mlp.gate_proj.": r"decoder.layers.\1.feed_forward.w1.",
+        r"model.layers.([0-9]+).mlp.down_proj.": r"decoder.layers.\1.feed_forward.w2.",
+        r"model.layers.([0-9]+).mlp.up_proj.": r"decoder.layers.\1.feed_forward.w3.",
+        r"model.layers.([0-9]+).post_attention_layernorm.": r"decoder.layers.\1.ffn_norm.",
+        r"model.norm.": r"decoder.norm.",
         # r"model.embed_tokens.": r"tok_embeddings.", # load separately
-        r"lm_head.": r"model.decoder.output.",
+        r"lm_head.": r"decoder.output.",
     }
     new_state_dict = {}
     def get_new_key(old_key: str) -> str:
@@ -109,7 +112,7 @@ def get_new_key(old_key: str) -> str:
 def _translate_state_dict_for_mm_projector_model(hf_state_dict) -> Dict[str, Any]:
     new_state_dict = {}
     for old_key in hf_state_dict.keys():
-        new_key = "model.mm_projector." + old_key
+        new_key = "mm_projector." + old_key
         new_state_dict[new_key] = hf_state_dict[old_key]
     return new_state_dict
 
@@ -127,13 +130,65 @@ def split_checkpoint(llava_ckpt):
         return language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt
     language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt = split_checkpoint(llava_ckpt)
     remapped_state_dict = {
-        "model.tok_embeddings.weight": language_model_ckpt.pop("model.embed_tokens.weight"),
+        "tok_embeddings.weight": language_model_ckpt.pop("model.embed_tokens.weight"),
     }
     remapped_state_dict.update(_translate_state_dict_for_text_model(language_model_ckpt))
     remapped_state_dict.update(_translate_state_dict_for_vision_model(vision_tower_ckpt))
     remapped_state_dict.update(_translate_state_dict_for_mm_projector_model(multi_modal_ckpt))
     return remapped_state_dict
 
+
+@torch.inference_mode
+def convert_llava_checkpoint(
+    *,
+    model_dir: Optional[Path] = None,
+) -> None:
+
+    """
+    Process safetensor files from a specific directory structure and save the remapped model.
+
+    Args:
+        model_dir (str): Base directory containing the model subdirectories.
+    """
+
+    def _get_llava_files_with_pattern(pattern):
+        pattern = os.path.join(model_dir, f"models--llava-hf--llava-1.5-7b-hf/snapshots/*/{pattern}")
+        return glob.glob(pattern)
+
+    # get all safetensor files in the model directory
+    safetensor_files = _get_llava_files_with_pattern("*.safetensors")
+
+    if not safetensor_files:
+        raise ValueError("No safetensor files found.")
+
+    merged_weights = {}
+
+    # Merge safetensor files into a whole
+    for file in safetensor_files:
+        # Load weights from the current file
+        part_weights = safetensors.torch.load_file(file)
+
+        # Iterate over each weight in the current file
+        for key, value in part_weights.items():
+            if key in merged_weights:
+                # If the key already exists, concatenate tensors
+                merged_weights[key] = torch.cat((merged_weights[key], value), dim=0)
+            else:
+                # If the key does not exist, add it to the dictionary
+                merged_weights[key] = value
+
+    # Remap the checkpoint and save it as pth
+    remapped_weights = remap_llava_checkpoint(merged_weights)
+    model_path = model_dir / "model.pth"
+    torch.save(remapped_weights, model_path)
+
+    # copy tokenizer
+    tokenizer_files = _get_llava_files_with_pattern("tokenizer.model")
+    assert len(tokenizer_files) == 1, "Should get only one tokenizer file, but got {}".format(tokenizer_files)
+
+    tokenizer_path = model_dir / "tokenizer.model"
+    shutil.copy(tokenizer_files[0], tokenizer_path)
+
 
 @torch.inference_mode()
 def convert_text_only_hf_checkpoint(
@@ -245,18 +300,18 @@ def permute(w, n_heads):
 
 
 @torch.inference_mode()
-def convert_text_only_hf_checkpoint(
+def convert_hf_checkpoint(
     *,
     model_dir: Optional[Path] = None,
     model_name: Optional[str] = None,
     remove_bin_files: bool = False,
 ):
-    if model_name == "llava-1.5":
-        print("Converting LLaVA 1.5 checkpoint.")
-        print(os.listdir(model_dir))
-        exit(0)
+    print(model_name)
+    print("***********************")
+    if "llava" in model_name:
+        convert_llava_checkpoint(model_dir=model_dir)
     else:
-        convert_text_only_hf_checkpoint(model_dir, model_name, remove_bin_files)
+        convert_text_only_hf_checkpoint(model_dir=model_dir, model_name=model_name, remove_bin_files=remove_bin_files)
 
 
 if __name__ == "__main__":
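
Note on the remapping above: both translation helpers drop the old `model.` prefix so keys land directly in the torchchat `encoder.*` / `decoder.*` namespaces, and the text-model path does this with plain regex substitution over `key_map`. A minimal, self-contained sketch of that substitution idea is below; the toy `key_map` subset and the simplified `get_new_key` loop are illustrative, not the full converter.

import re

# Toy subset of the key_map used by _translate_state_dict_for_text_model (illustrative only).
key_map = {
    r"model.layers.([0-9]+).self_attn.q_proj.": r"decoder.layers.\1.attention.wq.",
    r"model.layers.([0-9]+).mlp.gate_proj.": r"decoder.layers.\1.feed_forward.w1.",
    r"model.norm.": r"decoder.norm.",
}

def get_new_key(old_key: str) -> str:
    # Apply every pattern in turn; keys matching no pattern pass through unchanged.
    for pattern, replacement in key_map.items():
        old_key = re.sub(pattern, replacement, old_key)
    return old_key

print(get_new_key("model.layers.3.self_attn.q_proj.weight"))  # decoder.layers.3.attention.wq.weight
print(get_new_key("model.norm.weight"))                       # decoder.norm.weight
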

torchchat/cli/download.py: 8 additions & 28 deletions
@@ -28,34 +28,14 @@ def _download_hf_snapshot(
     # Download and store the HF model artifacts.
     print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
     try:
-
-        import huggingface_hub
-        # Define the model name and revision
-        model_name = "llava-hf/llava-1.5-7b-hf"
-        # Download the model checkpoint
-        repo_id = model_name
-        revision = "main"  # default branch
-        # Force a fresh download
-        snapshot_dir = huggingface_hub.snapshot_download(
-            repo_id=repo_id,
-            revision=revision,
+        snapshot_download(
+            model_config.distribution_path,
             cache_dir=artifact_dir,
-            force_download=True,
+            local_dir_use_symlinks=False,
+            token=hf_token,
+            ignore_patterns=None if "llava" in model_config.name else "*safetensors*",
         )
-        print(f"Model download finished; saved under {snapshot_dir}")
-
-
-        # snapshot_download(
-        #     model_config.distribution_path,
-        #     cache_dir=artifact_dir,
-        #     local_dir_use_symlinks=False,
-        #     token=hf_token,
-        #     ignore_patterns="*safetensors*",
-        # )
-        print("*****************")
-        print(os.listdir(artifact_dir))
-        shutil.copytree(artifact_dir, "/home/gasoonjia/download/hahaha")
-        exit(0)
+
     except HTTPError as e:
         if e.response.status_code == 401: # Missing HuggingFace CLI login.
             print(
@@ -99,8 +79,8 @@ def download_and_convert(
     # location once the download and conversion is complete. This
     # allows recovery in the event that the download or conversion
     # fails unexpectedly.
-    # temp_dir = models_dir / "downloads" / model_config.name
-    temp_dir = Path("/home/gasoonjia") / "downloads" / model_config.name
+    temp_dir = models_dir / "downloads" / model_config.name
+    # temp_dir = Path("/home/gasoonjia") / "downloads" / model_config.name
 
     if os.path.isdir(temp_dir):
         shutil.rmtree(temp_dir)
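
The download change restores the plain `snapshot_download` call and keeps the safetensor shards only when the model name contains "llava"; for other models `ignore_patterns="*safetensors*"` skips them since those models are loaded from the .bin files. A hedged, standalone sketch of that call pattern follows; the wrapper name and arguments here are illustrative, not the exact torchchat signature.

from typing import Optional

from huggingface_hub import snapshot_download

def fetch_snapshot(name: str, distribution_path: str, artifact_dir: str, hf_token: Optional[str]) -> None:
    # llava needs the *.safetensors shards for checkpoint conversion; text-only models
    # rely on the .bin files, so safetensors are skipped for them to save bandwidth.
    ignore = None if "llava" in name else "*safetensors*"
    snapshot_download(
        distribution_path,            # e.g. "llava-hf/llava-1.5-7b-hf"
        cache_dir=artifact_dir,
        local_dir_use_symlinks=False,
        token=hf_token,
        ignore_patterns=ignore,
    )
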

torchchat/generate.py: 51 additions & 20 deletions
@@ -36,6 +36,7 @@
 from torchchat.model import Model, ModelType
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
+from torchchat.utils.preprocessors import llava_image_preprocess
 
 # torchtune model definition dependencies
 from torchtune.data import Message
@@ -622,6 +623,13 @@ def generate(
             sequential_prefill=sequential_prefill,
             **sampling_kwargs,
         )
+
+        # For llava, we need to extract next pos id from prefill result
+        if self.model.config.model_type == ModelType.Llava:
+            next_token, context_len = next_token
+        else:
+            next_token, context_len = next_token, T
+
         if is_speculative:
             self.prefill(
                 draft_model,
@@ -636,7 +644,7 @@
         # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
         callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)
 
-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor([start_pos + context_len], device=device, dtype=torch.int)
         accept_counts = [0] * (
             speculate_k + 1
         ) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -729,31 +737,54 @@ def chat(
             print("Builder Args:")
             print(self.builder_args)
 
-        exit(0)
-
         if generator_args.image_prompts is not None:
             print("Image prompts", generator_args.image_prompts)
-
             # Support for just the first image prompt for now
             images = [Image.open(generator_args.image_prompts[0])]
-            messages = [
-                Message(
-                    role="user",
-                    content=[
-                        {"type": "image", "content": images[0]},
-                        {"type": "text", "content": generator_args.prompt},
-                    ],
-                    eot=True,
-                ),
-                Message(role="assistant", content=""),
-            ]
 
-            transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
-            data = transform({"messages": messages}, inference=True)
-            batch = padded_collate([data], self.builder_args.device)
-            batch.pop("mask")
-            encoded = batch["tokens"]
+            assert len(images) == 1, "Only one image prompt is supported for now"
+
+            #TODO: updated encoded variable for multi-modality models to include image tokens.
+            if self.model.config.model_type == ModelType.Flamingo:
+                messages = [
+                    Message(
+                        role="user",
+                        content=[
+                            {"type": "image", "content": images[0]},
+                            {"type": "text", "content": generator_args.prompt},
+                        ],
+                        eot=True,
+                    ),
+                    Message(role="assistant", content=""),
+                ]
 
+                transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
+                data = transform({"messages": messages}, inference=True)
+                batch = padded_collate([data], self.builder_args.device)
+                batch.pop("mask")
+                encoded = batch["tokens"]
+            elif self.model.config.model_type == ModelType.Llava:
+                #TODO: double check the tokenizer.
+                def find_subtensor(tensor, target):
+                    target_len = len(target)
+                    for i in range(len(tensor) - target_len + 1):
+                        if torch.all(tensor[i:i+target_len] == target):
+                            return i
+                    return -1
+
+                input_ids = self.encode_tokens(generator_args.prompt, bos=True, device=self.builder_args.device)
+                image_token_indices = self.encode_tokens("<image>", device=self.builder_args.device)[1:]
+                index = find_subtensor(input_ids, image_token_indices)
+
+                batch = {
+                    "tokens": input_ids[:index].unsqueeze(0),
+                    "encoder_input": llava_image_preprocess(images[0], device=self.builder_args.device),
+                    "post_tokens": input_ids[index + len(image_token_indices) :].unsqueeze(0),
+                }
+                print("BATTTTTTTCHCHHHHHHHHH")
+                print(batch)
+                encoded = torch.cat([batch["tokens"].view(1, -1), batch["post_tokens"].view(1, -1)], dim=-1).view(-1)
+
         else:
             encoded = self.encode_tokens(
                 generator_args.prompt, bos=True, device=self.builder_args.device
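
In the Llava chat path above, the prompt is split around the `<image>` placeholder: the tokenized prompt is searched for the token subsequence produced by tokenizing "<image>", the ids before and after it become `tokens` and `post_tokens`, and the preprocessed image is passed as `encoder_input`. A self-contained sketch of just the splitting step is below, using toy token ids instead of the real tokenizer and llava_image_preprocess.

import torch

def find_subtensor(tensor: torch.Tensor, target: torch.Tensor) -> int:
    # Return the first index where `target` occurs as a contiguous run in `tensor`, or -1.
    target_len = len(target)
    for i in range(len(tensor) - target_len + 1):
        if torch.all(tensor[i : i + target_len] == target):
            return i
    return -1

# Toy ids: pretend 32000 is the "<image>" placeholder token.
input_ids = torch.tensor([1, 3148, 1001, 29901, 32000, 1724, 338, 445, 29973])
image_token = torch.tensor([32000])

idx = find_subtensor(input_ids, image_token)
pre_tokens = input_ids[:idx].unsqueeze(0)                       # tokens before the image
post_tokens = input_ids[idx + len(image_token):].unsqueeze(0)   # tokens after the image
print(idx, pre_tokens.shape, post_tokens.shape)  # 4 torch.Size([1, 4]) torch.Size([1, 4])
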
