This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 141fea0

rebase and solve comments

committed · 2 parents 01bb624 + 319ac86

File tree

9 files changed (+588, -52 lines)


_torchchat_test_script.py

Lines changed: 293 additions & 0 deletions
@@ -0,0 +1,293 @@

import torch
import sys
import os

from torchtune import training
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder, FlamingoTransform
from torchtune.modules.model_fusion import DeepFusionModel

from torchchat.model import Model

import re

from typing import Dict
from torchtune.generation._generation import sample
from torchtune.training import set_default_dtype
import numpy as np
import PIL

from torchtune.data import Message

def flamingo_transform(tokenizer_path):
    return FlamingoTransform(
        tokenizer_path,
        tile_size=448,
        patch_size=14,
        max_num_tiles=4,
        max_seq_len=8192,
        encoder_max_seq_len=4100,
        image_mean=(0.48145466, 0.4578275, 0.40821073),
        image_std=(0.26862954, 0.26130258, 0.27577711),
        prompt_template=None,
    )

def padded_collate(batch, device='cuda', dtype=torch.bfloat16, padding_idx=0):
    # Placeholder Collator until https://github.com/pytorch/torchtune/pull/1156 lands
    assert len(batch) == 1, "Test collate function only supports bs = 1"
    sample = batch[0]
    sample["tokens"] = torch.Tensor(sample["tokens"])[None, ...].to(device).long()
    sample["mask"] = torch.Tensor(sample["mask"])[None, ...].to(device).bool()
    sample["encoder_input"]["images"] = torch.stack(sample["encoder_input"]["images"])[None, ...].to(device)
    sample["encoder_input"]["aspect_ratio"] = torch.stack(sample["encoder_input"]["aspect_ratio"])[None, ...].to(device)
    assert len(sample["encoder_mask"]), "Test collate function only supports 1 image per sequence"
    # Pad encoder mask to max_num_tiles sequence length (4100)
    s_x, s_y = sample["encoder_mask"][0].shape
    mask_padding = torch.zeros((s_x, 4100 - s_y), dtype=torch.bool)
    encoder_mask = torch.cat([sample["encoder_mask"][0], mask_padding], dim=1)
    sample["encoder_mask"] = encoder_mask[None, ...].to(device)
    return sample
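
# Note (illustrative, not part of this commit): the 4100 above follows from the
# transform settings -- (tile_size / patch_size) ** 2 patches + 1 CLS token
# = (448 / 14) ** 2 + 1 = 1025 tokens per tile, times max_num_tiles = 4 tiles
# = 4100, matching encoder_max_seq_len.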


_FROM_META = {
    "text_model.tok_embeddings.weight": "decoder.tok_embeddings.weight",
    "text_model.learnable_embedding.weight": "decoder.tok_embeddings.fusion_embedding.weight",
    "text_model.norm.weight": "decoder.norm.scale",
    "text_model.output.weight": "decoder.output.weight",

    "text_model.layers.{}.attention_norm.weight": "decoder.layers.{}.sa_norm.scale",
    "text_model.layers.{}.attention.wq.weight": "decoder.layers.{}.attn.q_proj.weight",
    "text_model.layers.{}.attention.wk.weight": "decoder.layers.{}.attn.k_proj.weight",
    "text_model.layers.{}.attention.wv.weight": "decoder.layers.{}.attn.v_proj.weight",
    "text_model.layers.{}.attention.wo.weight": "decoder.layers.{}.attn.output_proj.weight",
    "text_model.layers.{}.ffn_norm.weight": "decoder.layers.{}.mlp_norm.scale",
    "text_model.layers.{}.feed_forward.w1.weight": "decoder.layers.{}.mlp.w1.weight",
    "text_model.layers.{}.feed_forward.w3.weight": "decoder.layers.{}.mlp.w3.weight",
    "text_model.layers.{}.feed_forward.w2.weight": "decoder.layers.{}.mlp.w2.weight",

    "text_model.cross_attention_layers.{}.gate_attn": "decoder.layers.{}.fusion_layer.ca_scale.scale",
    "text_model.cross_attention_layers.{}.gate_ffwd": "decoder.layers.{}.fusion_layer.mlp_scale.scale",
    "text_model.cross_attention_layers.{}.attention_norm.weight": "decoder.layers.{}.fusion_layer.ca_norm.scale",
    "text_model.cross_attention_layers.{}.ffn_norm.weight": "decoder.layers.{}.fusion_layer.mlp_norm.scale",
    "text_model.cross_attention_layers.{}.attention.wq.weight": "decoder.layers.{}.fusion_layer.attn.q_proj.weight",
    "text_model.cross_attention_layers.{}.attention.wk.weight": "decoder.layers.{}.fusion_layer.attn.k_proj.weight",
    "text_model.cross_attention_layers.{}.attention.wv.weight": "decoder.layers.{}.fusion_layer.attn.v_proj.weight",
    "text_model.cross_attention_layers.{}.attention.wo.weight": "decoder.layers.{}.fusion_layer.attn.output_proj.weight",
    "text_model.cross_attention_layers.{}.attention.inner_attention.q_norm.weight": "decoder.layers.{}.fusion_layer.attn.q_norm.scale",
    "text_model.cross_attention_layers.{}.attention.inner_attention.k_norm.weight": "decoder.layers.{}.fusion_layer.attn.k_norm.scale",
    "text_model.cross_attention_layers.{}.feed_forward.w1.weight": "decoder.layers.{}.fusion_layer.mlp.w1.weight",
    "text_model.cross_attention_layers.{}.feed_forward.w3.weight": "decoder.layers.{}.fusion_layer.mlp.w3.weight",
    "text_model.cross_attention_layers.{}.feed_forward.w2.weight": "decoder.layers.{}.fusion_layer.mlp.w2.weight",

    "vision_model.vision_encoder.positional_embedding": "encoder.clip.token_pos_embedding.local_token_positional_embedding",
    "vision_model.vision_encoder.gated_positional_embedding": "encoder.clip.token_pos_embedding.global_token_positional_embedding",
    "vision_model.vision_encoder.gated_positional_embedding_gate": "encoder.clip.token_pos_embedding.gate",
    "vision_model.vision_encoder.ln_pre.weight": "encoder.clip.ln_pre.weight",
    "vision_model.vision_encoder.ln_pre.bias": "encoder.clip.ln_pre.bias",
    "vision_model.vision_encoder.ln_post.weight": "encoder.clip.ln_post.weight",
    "vision_model.vision_encoder.ln_post.bias": "encoder.clip.ln_post.bias",
    "vision_model.vision_encoder.pre_tile_pos_embed.embedding": "encoder.clip.pre_tile_pos_embed.embedding",
    "vision_model.vision_encoder.pre_tile_pos_embed.gate": "encoder.clip.pre_tile_pos_embed.gate",
    "vision_model.vision_encoder.post_tile_pos_embed.embedding": "encoder.clip.post_tile_pos_embed.embedding",
    "vision_model.vision_encoder.post_tile_pos_embed.gate": "encoder.clip.post_tile_pos_embed.gate",
    "vision_model.vision_encoder.class_embedding": "encoder.clip.cls_token_embedding.weight",
    "vision_model.vision_encoder.conv1._linear.weight": "encoder.clip.conv.weight",

    "vision_model.vision_encoder.transformer.resblocks.{}.attn.wq.weight": "encoder.clip.layers.{}.attn.q_proj.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.attn.wk.weight": "encoder.clip.layers.{}.attn.k_proj.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.attn.wv.weight": "encoder.clip.layers.{}.attn.v_proj.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.attn.wo.weight": "encoder.clip.layers.{}.attn.output_proj.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_fc.weight": "encoder.clip.layers.{}.mlp.w1.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_fc.bias": "encoder.clip.layers.{}.mlp.w1.bias",
    "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_proj.weight": "encoder.clip.layers.{}.mlp.w2.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_proj.bias": "encoder.clip.layers.{}.mlp.w2.bias",
    "vision_model.vision_encoder.transformer.resblocks.{}.ln_1.weight": "encoder.clip.layers.{}.sa_norm.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.ln_1.bias": "encoder.clip.layers.{}.sa_norm.bias",
    "vision_model.vision_encoder.transformer.resblocks.{}.ln_2.weight": "encoder.clip.layers.{}.mlp_norm.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.ln_2.bias": "encoder.clip.layers.{}.mlp_norm.bias",

    "vision_model.vision_projection.weight": "encoder.projection.output.weight",
    "vision_model.vision_projection.bias": "encoder.projection.output.bias",

    "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wq.weight": "encoder.projection.layers.{}.attn.q_proj.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wk.weight": "encoder.projection.layers.{}.attn.k_proj.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wv.weight": "encoder.projection.layers.{}.attn.v_proj.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wo.weight": "encoder.projection.layers.{}.attn.output_proj.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_fc.weight": "encoder.projection.layers.{}.mlp.w1.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_fc.bias": "encoder.projection.layers.{}.mlp.w1.bias",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_proj.weight": "encoder.projection.layers.{}.mlp.w2.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_proj.bias": "encoder.projection.layers.{}.mlp.w2.bias",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_1.weight": "encoder.projection.layers.{}.sa_norm.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_1.bias": "encoder.projection.layers.{}.sa_norm.bias",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_2.weight": "encoder.projection.layers.{}.mlp_norm.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_2.bias": "encoder.projection.layers.{}.mlp_norm.bias",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.gate_attn": "encoder.projection.layers.{}.sa_scale.scale",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.gate_ffn": "encoder.projection.layers.{}.mlp_scale.scale",
}


def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
    try:
        if any(k.isdigit() for k in key.split(".")):
            # Replace layer number with "{}" to create key for lookup
            abstract_key = re.sub(r"(\.\d+)", ".{}", key)
            layer_num = re.search(r"\d+", key).group(0)
            new_key = mapping_dict[abstract_key]
            new_key = new_key.format(layer_num)
        else:
            new_key = mapping_dict[key]
    except KeyError as e:
        raise Exception(
            f'Error converting the state dict. Found unexpected key: "{key}". '
            "Please make sure you're loading a checkpoint with the right format. "
        ) from e

    return new_key


def flamingo_meta_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Converter from Meta state dict to torchtune state dict. This handles:
    - Updating the cross attention layer numbers
    """
    converted_state_dict = {}

    for key, value in state_dict.items():
        if key == "text_model.rope.freqs":
            continue
        new_key = get_mapped_key(key, _FROM_META)
        if "cross_attention_layers" in key:
            layer = int(key.split(".")[2])
            # TODO: grab num_layers and generalize this
            new_layer = (layer + 1) * 4 - 1
            key_lst = new_key.split(".")
            key_lst[2] = str(new_layer)
            new_key = ".".join(key_lst)
        if "gate_ffwd" in key or "gate_attn" in key:
            value = value[:1]
        elif "conv1" in key:
            # TODO: get patch size and generalize
            value = value.reshape(-1, 3, 14, 14)
        converted_state_dict[new_key] = value
    return converted_state_dict


if __name__ == "__main__":
    llava3_2_dir = str(sys.argv[1])
    param_path = os.path.join(llava3_2_dir, "flamingo.json")
    tokenizer_path = os.path.join(llava3_2_dir, "tokenizer.model")
    checkpoint_path = os.path.join(llava3_2_dir, "consolidated.pth")
    image_path = os.path.join(llava3_2_dir, "dog.jpg")

    if len(sys.argv) > 2:
        device = torch.device(str(sys.argv[2]))
    elif torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")
    print(f"Loading model from {param_path}")

    dtype = torch.bfloat16
    with set_default_dtype(dtype), device:
        model = Model.from_params(param_path)

    transform = flamingo_transform(tokenizer_path)

    print(f"Loading checkpoint from {checkpoint_path}")
    state_dict = torch.load(checkpoint_path)
    print("Converting state dict into flamingo format")
    state_dict = flamingo_meta_to_tune(state_dict)
    print("Loading state dict into model")
    model.model.load_state_dict(state_dict)

    model = torch.compile(model)
    images = [PIL.Image.open(image_path)]

    dialog = [
        Message(
            role="user",
            content=[
                {"type": "image"},
                {"type": "text", "content": "What's in this image?"},
            ],
            eot=True,
        ),
        Message(role="assistant", content="")
    ]

    data = transform({"images": images, "messages": dialog}, inference=True)

    model.eval()
    with device:
        model.setup_caches(1, dtype=torch.bfloat16)

    max_generated_tokens = 100
    temperature = .6
    top_k = 500

    print("Generating...")

    generated_tokens = []
    model.reset_caches()
    with torch.no_grad():
        batch = padded_collate([data], device, dtype)
        batch.pop("mask")

        logits = model(**batch)[:, -1]
        tok = sample(logits, temperature, top_k)
        generated_tokens.append(tok.item())

        cache_mask = batch["encoder_mask"][:, -1:]
        for _ in range(max_generated_tokens):
            if tok.item() in transform.stop_tokens:
                break
            logits = model(tok, encoder_mask=cache_mask)[:, -1]
            tok = sample(logits, temperature, top_k)
            generated_tokens.append(tok.item())

    print(transform.decode(generated_tokens))


""":md
## Chat Pseudo Code

This approach guarantees that there's only one image cached at a time, so there's no need for cross-attention masking.
This works because Llama3v is trained such that each token is only allowed to attend to the previous image and the rest are
masked during training/finetuning. Since consecutive images are treated as one image for Llama3v, you can control the maximum
encoder sequence length by setting max_consecutive here, as well as by setting max_num_tiles and max_resolution for the image input.

```python
model.eval()
model.setup_caches(1, torch.bfloat16)

with torch.no_grad():
    # Prefill system prompt
    toks, _ = transform(parse_prompt(system_prompt))
    model(toks)
    while True:
        # Prefill user prompt split over images
        user_prompt = input(">>> ")
        toks, imgs = transform(parse_prompt(user_prompt))
        for i, tok in enumerate(split(toks, image_token, max_consecutive=1)):
            img = None
            if imgs is not None:
                img = imgs[i]
            reset_attn_cache(model)
            logits = model(tok, img)

        # Decode assistant response
        tok = sample_tok(logits)  # only outputs single-token logits when model.cache_enabled=True
        while tok != EOS:
            logits = model(tok)
            tok = sample_tok(logits)
            sys.stdout.buffer.write(transform.decode(tok))
```
"""

""":py"""
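
For readers following the key conversion above, here is a minimal standalone sketch (not part of the commit) of the renaming rule that `flamingo_meta_to_tune` applies: plain decoder layers keep their index, while Meta's `cross_attention_layers.{n}` keys are re-homed to the fused decoder layer `(n + 1) * 4 - 1`, i.e. every fourth layer. The two-entry mapping below is a trimmed stand-in for the full `_FROM_META` table.

```python
import re

# Trimmed, illustrative subset of the Meta -> torchtune renaming table above.
_EXAMPLE_MAP = {
    "text_model.layers.{}.attention.wq.weight": "decoder.layers.{}.attn.q_proj.weight",
    "text_model.cross_attention_layers.{}.attention.wq.weight": "decoder.layers.{}.fusion_layer.attn.q_proj.weight",
}

def rename(key: str) -> str:
    # Abstract away the layer index, then look up the target pattern.
    abstract_key = re.sub(r"(\.\d+)", ".{}", key)
    layer = int(re.search(r"\d+", key).group(0))
    if "cross_attention_layers" in key:
        # Fusion layers are interleaved after every 4th decoder layer: 3, 7, 11, ...
        layer = (layer + 1) * 4 - 1
    return _EXAMPLE_MAP[abstract_key].format(layer)

assert rename("text_model.layers.5.attention.wq.weight") == "decoder.layers.5.attn.q_proj.weight"
assert rename("text_model.cross_attention_layers.0.attention.wq.weight") == \
    "decoder.layers.3.fusion_layer.attn.q_proj.weight"
```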

distributed/parallelize_llama.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def apply_tp(
     # after we apply TP to the model. Because we don't want to change model code
     # when applying TP. We need to have change to ensure KVCache has the correct
     # size as k and v.
-    model.model.config.n_local_heads = model.model.config.n_local_heads // tp_mesh.size()
+    model.get_text_transformer_args.n_local_heads = model.get_text_transformer_args.n_local_heads // tp_mesh.size()

     # Apply tensor parallelism to every transformer block
     for transformer_block in model.layers:
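
As the surrounding comment notes, the intent of the changed line is that after tensor parallelism each rank only holds a shard of the attention heads, so the KV cache must be sized for the per-rank head count. A rough sketch of that arithmetic with made-up numbers (the cache shape layout here is illustrative, not torchchat's exact format):

```python
# Hypothetical head counts, for illustration only.
n_heads_total = 32   # attention heads in the unsharded model
tp_degree = 4        # tp_mesh.size(): number of tensor-parallel ranks

# Each rank keeps n_heads_total // tp_degree heads, so its KV cache is
# allocated for the local head count rather than the global one.
n_local_heads = n_heads_total // tp_degree
assert n_local_heads == 8

# Illustrative per-rank cache shape: (batch, local heads, max_seq_len, head_dim).
batch, max_seq_len, head_dim = 1, 8192, 128
k_cache_shape = (batch, n_local_heads, max_seq_len, head_dim)
print(k_cache_shape)  # (1, 8, 8192, 128)
```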

torchchat/cli/builder.py

Lines changed: 1 addition & 1 deletion
@@ -563,7 +563,7 @@ def _initialize_model(
         model.setup_caches(
             max_batch_size=1,
             max_seq_length=max_seq_length
-            or model.model.config.max_seq_length,
+            or model.get_text_transformer_args.max_seq_length,
         )

         model.to(dtype=builder_args.precision)

torchchat/export.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def export_for_server(
             torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
         )

-        seq = Dim("seq", min=1, max=model.model.config.max_seq_length)
+        seq = Dim("seq", min=1, max=model.get_text_transformer_args.max_seq_length)
         # Specify that the first dimension of each input is that batch size
         dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}}
     else:
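
For context on the `Dim`/`dynamic_shapes` pair in this hunk, below is a minimal, self-contained sketch of how a dynamic sequence dimension is typically declared with `torch.export`; the toy model and bounds are stand-ins, not torchchat's actual model or limits.

```python
import torch
from torch.export import Dim, export


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(128, 16)

    def forward(self, tokens, input_pos):
        # tokens: (batch, seq), input_pos: (seq,)
        return self.emb(tokens) + input_pos.to(torch.float32)[:, None]


model = TinyModel()
tokens = torch.zeros(1, 5, dtype=torch.long)
input_pos = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int)

# Mark the sequence length as dynamic: dim 1 of `tokens` and dim 0 of `input_pos`
# share one symbolic size, mirroring the dynamic_shapes dict in the diff.
seq = Dim("seq", min=2, max=2048)
dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}}

exported = export(model, (tokens, input_pos), dynamic_shapes=dynamic_shapes)
print(exported.module()(tokens, input_pos).shape)  # torch.Size([1, 5, 16])
```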

torchchat/generate.py

Lines changed: 9 additions & 4 deletions
@@ -364,6 +364,8 @@ def prefill(
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
                 # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
                 logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
+        elif self.model.config.model_type == ModelType.Flamingo:
+            logits = model(x)
         else:
             # input_pos: [B, S]
             logits = model(x, input_pos)
@@ -383,11 +385,14 @@ def decode_one_token(
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # input_pos: [B, 1]
         assert input_pos.shape[-1] == 1
-        if model.config.model_type == ModelType.Flamingo and batch is not None:
-            x = x.view(1, -1)
-            logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
+        x = x.view(1, -1)
+        if model.config.model_type == ModelType.Flamingo:
+            if batch is not None:
+                logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
+            else:
+                logits = model(x)
         else:
-            logits = model(x.view(1, -1), input_pos)
+            logits = model(x, input_pos)
         # print(f"x: {x},\n input_pos: {input_pos}\n")
         return self.sample(logits, need_probs=need_probs, **sampling_kwargs)

0 commit comments
