Commit 021fd32
Bump torchtune pin to a 9-24 commit; Update Flamingo Definition (#1195)
* update flamingo model for tune
* 1/n flamingo e2e ppl
* flamingo e2e enable
* bump up tune version
* remove hacky cache size, add comment for magic number
* dtype set for input
* manually cast dtype
* extra config for deep fusion module
1 parent f0a03a7 commit 021fd32

File tree

4 files changed: +73 -28 lines

  .gitignore
  install/install_requirements.sh
  torchchat/generate.py
  torchchat/model.py

.gitignore

Lines changed: 3 additions & 0 deletions
```diff
@@ -30,3 +30,6 @@ system_info.txt
 # build artifacts
 checkpoints/
 exportedModels/
+
+# test script
+_torchchat_test_script.py
```

install/install_requirements.sh

Lines changed: 6 additions & 5 deletions
```diff
@@ -52,10 +52,6 @@ PYTORCH_NIGHTLY_VERSION=dev20240901
 # Nightly version for torchvision
 VISION_NIGHTLY_VERSION=dev20240901
 
-# Nightly version for torchtune
-TUNE_NIGHTLY_VERSION=dev20240916
-
-
 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
 (
   set -x
@@ -76,7 +72,6 @@ fi
 REQUIREMENTS_TO_INSTALL=(
   torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
   torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
-  torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
 )
 
 # Install the requirements. --extra-index-url tells pip to look for package
@@ -92,6 +87,12 @@ REQUIREMENTS_TO_INSTALL=(
   $PIP_EXECUTABLE install torchao=="0.5.0"
 )
 
+# Rely on the latest tochtune for flamingo support
+(
+  set -x
+  $PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git@18efc81dda1c537bb7c25058ff059b4623ccff58
+)
+
 if [[ -x "$(command -v nvidia-smi)" ]]; then
   (
     set -x
```
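Not part of the commit, but a quick way to confirm that the environment ended up with the git-pinned torchtune build rather than the old `0.3.0.devYYYYMMDD` nightly: for a pip VCS install, the package metadata records the exact source URL and commit (PEP 610 `direct_url.json`). A minimal sketch of that check:

```python
# Hypothetical post-install check (not included in this commit).
# For a package installed via `pip install git+...`, pip writes direct_url.json
# into the .dist-info directory with the resolved git URL and commit hash.
from importlib import metadata

dist = metadata.distribution("torchtune")
print(dist.version)                         # reported package version
print(dist.read_text("direct_url.json"))    # None for non-VCS installs; git URL + commit otherwise
```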

torchchat/generate.py

Lines changed: 39 additions & 15 deletions
```diff
@@ -21,7 +21,7 @@
 import torch._inductor.config
 
 try:
-    from _torchchat_test_script import flamingo_transform, padded_collate
+    from _torchchat_test_script import flamingo_transform
 except ImportError:
     pass
 
@@ -38,8 +38,9 @@
 from torchchat.utils.device_info import get_device_info
 
 # torchtune model definition dependencies
-from torchtune.data import Message
-from torchtune.generation._generation import sample as tune_sample
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.generation import sample as tune_sample
 from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.training import set_default_dtype
 
@@ -357,15 +358,25 @@ def prefill(
 
         if batch is not None:
             # TODO: Verify sequential prefill works with multimodal models
-            logits = model(**batch)[:, -1]
-            return tune_sample(logits, 0, 500)
+            tokens = batch["tokens"]
+            if 'encoder_input' in batch:
+                encoder_input = batch['encoder_input']
+            else:
+                encoder_input = None
+
+            seq_len = tokens.size(1)
+            mask = batch["causal_mask"][None, :seq_len]
+            encoder_mask = batch["encoder_mask"]
+            input_pos = input_pos.view(1, -1)
+            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+            return tune_sample(logits, temperature=0, top_k=500)
         elif sequential_prefill:
             for i in range(width):
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
                 # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
                 logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
         elif self.model.config.model_type == ModelType.Flamingo:
-            logits = model(x)
+            assert False, "Flamingo requires batch"
         else:
             # input_pos: [B, S]
             logits = model(x, input_pos)
@@ -387,10 +398,10 @@ def decode_one_token(
         assert input_pos.shape[-1] == 1
         x = x.view(1, -1)
         if model.config.model_type == ModelType.Flamingo:
-            if batch is not None:
-                logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
-            else:
-                logits = model(x)
+            assert batch is not None, "Flamingo requires batch"
+            mask = batch["causal_mask"][None, input_pos.item(), None, :]
+            encoder_mask = batch["encoder_mask"][:, -1:]
+            logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
         else:
             logits = model(x, input_pos)
         # print(f"x: {x},\n input_pos: {input_pos}\n")
@@ -593,7 +604,8 @@ def generate(
             self.is_torchtune_model
             or self.model.config.model_type == ModelType.Flamingo
         ):
-            model.setup_caches(max_batch_size=1, dtype=self.dtype)
+            # 6404 is one-gpu affordable max_seq_length for single image input
+            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
         else:
             model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
@@ -742,10 +754,22 @@ def chat(
             ]
 
             transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
-            data = transform({"messages": messages}, inference=True)
-            batch = padded_collate([data], self.builder_args.device)
-            batch.pop("mask")
-            encoded = batch["tokens"]
+
+            with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype):
+                data = transform({"messages": messages}, inference=True)
+                batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+                # set_default_dtype can not handle the dtype of the image tensor inside the batch; need to manually cast it
+                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
+                seq_len = len(data["tokens"])
+                total_response_length = seq_len + generator_args.max_new_tokens
+                batch["causal_mask"] = torch.tril(
+                    torch.ones(
+                        size=(total_response_length, total_response_length),
+                        dtype=torch.bool,
+                    )
+                )
+                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                encoded = batch["tokens"].view(-1)
 
         else:
             encoded = self.encode_tokens(
```
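An aside, not part of the diff: the mask bookkeeping above is easy to check in isolation. `chat()` builds one boolean lower-triangular mask covering prompt plus generated tokens, `prefill()` slices its first `seq_len` rows, and each `decode_one_token()` step slices the single row for the current position. The sketch below reproduces that slicing with toy sizes. The `6404` passed as `encoder_max_seq_len` is also consistent with a single image tiled into at most 4 tiles of 40x40 patches plus one class token each (4 x 1601 = 6404), assuming the Llama 3.2 Vision defaults of 560-pixel tiles and 14-pixel patches; the commit message only calls it a "magic number", so treat that reading as an inference.

```python
# Standalone sketch (toy sizes, not torchchat code) of the causal-mask slicing
# used by the new Flamingo prefill/decode path.
import torch

prompt_len = 5            # stand-in for len(data["tokens"])
max_new_tokens = 3        # stand-in for generator_args.max_new_tokens
total_response_length = prompt_len + max_new_tokens

# One lower-triangular boolean mask covering the whole response, as in chat().
causal_mask = torch.tril(
    torch.ones(size=(total_response_length, total_response_length), dtype=torch.bool)
)

# Prefill: take the first seq_len rows and add a batch dim -> [1, seq_len, total].
prefill_mask = causal_mask[None, :prompt_len]
assert prefill_mask.shape == (1, prompt_len, total_response_length)

# Decode: one row per step, selected by the current position -> [1, 1, total].
for pos in range(prompt_len, total_response_length):
    input_pos = torch.tensor([pos])
    step_mask = causal_mask[None, input_pos.item(), None, :]
    assert step_mask.shape == (1, 1, total_response_length)
    # Position `pos` may attend to everything up to and including itself, nothing after.
    assert step_mask[0, 0, : pos + 1].all() and not step_mask[0, 0, pos + 1 :].any()
```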

torchchat/model.py

Lines changed: 25 additions & 8 deletions
```diff
@@ -458,6 +458,13 @@ def build_model(self) -> nn.Module:
                 modules[name] = module_class(TransformerArgs.from_params(config_args))
             else:
                 modules[name] = module_class(**config_args)
+
+        # Temporary add extra params to the DeepFusionModel.
+        # TODO: Remove it once we can make fusion model configurable in model_param.
+        if recipe.fusion_class == DeepFusionModel:
+            modules["encoder_trainable"] = False
+            modules["decoder_trainable"] = False
+            modules["fusion_trainable"] = False
 
         return recipe.fusion_class(**modules)
 
@@ -535,18 +542,28 @@ def reset_caches(self):
 class FlamingoModel(Model):
     def forward(
         self,
-        tokens: Tensor,
-        encoder_input: Optional[Dict[str, Tensor]] = None,
-        encoder_mask: Optional[Tensor] = None,
+        tokens: torch.Tensor,
+        *,
+        mask: Optional[torch.Tensor] = None,
+        encoder_input: Optional[Dict] = None,
+        encoder_mask: Optional[torch.Tensor] = None,
+        input_pos: Optional[torch.Tensor] = None,
     ) -> Tensor:
-        if encoder_input is None:
-            return self.model(tokens, encoder_mask=encoder_mask)
         return self.model(
-            tokens, encoder_input=encoder_input, encoder_mask=encoder_mask
+            tokens,
+            mask=mask,
+            encoder_input=encoder_input,
+            encoder_mask=encoder_mask,
+            input_pos=input_pos,
         )
 
-    def setup_caches(self, max_batch_size, dtype):
-        self.model.setup_caches(max_batch_size, dtype=dtype)
+    def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
+        self.model.setup_caches(
+            batch_size=batch_size,
+            dtype=dtype,
+            encoder_max_seq_len=encoder_max_seq_len,
+            decoder_max_seq_len=decoder_max_seq_len,
+        )
 
     def reset_caches(self):
         self.model.reset_caches()
```
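For readers who only need the call shape: the wrapper's `forward` is now keyword-only after `tokens`, and `setup_caches` takes `batch_size` plus separate encoder/decoder lengths instead of `max_batch_size`. Below is a minimal, self-contained sketch of that interface using a dummy module; `DummyDeepFusion` is a toy stand-in, not the real torchtune fusion model or torchchat wrapper.

```python
# Toy stand-in showing the call pattern generate.py now uses against FlamingoModel:
# keyword-only forward arguments and explicit encoder/decoder cache lengths.
from typing import Dict, Optional

import torch
from torch import nn


class DummyDeepFusion(nn.Module):
    """Stand-in for the deep fusion model wrapped by FlamingoModel (hypothetical)."""

    vocab_size = 128

    def setup_caches(self, *, batch_size: int, dtype: torch.dtype,
                     encoder_max_seq_len: int, decoder_max_seq_len: int) -> None:
        print(f"caches: bs={batch_size} enc={encoder_max_seq_len} dec={decoder_max_seq_len}")

    def forward(self, tokens: torch.Tensor, *,
                mask: Optional[torch.Tensor] = None,
                encoder_input: Optional[Dict] = None,
                encoder_mask: Optional[torch.Tensor] = None,
                input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Fake logits with shape [batch, seq_len, vocab_size].
        return torch.zeros(tokens.size(0), tokens.size(1), self.vocab_size)


model = DummyDeepFusion()

# New cache setup: batch_size (not max_batch_size) plus encoder/decoder lengths.
model.setup_caches(batch_size=1, dtype=torch.bfloat16,
                   encoder_max_seq_len=6404, decoder_max_seq_len=64)

# Prefill-style call: everything after `tokens` is passed by keyword.
tokens = torch.zeros(1, 8, dtype=torch.long)
mask = torch.tril(torch.ones(64, 64, dtype=torch.bool))[None, :8]
logits = model(tokens=tokens, mask=mask, encoder_input=None,
               encoder_mask=None, input_pos=torch.arange(8).view(1, -1))[:, -1]
print(logits.shape)  # torch.Size([1, 128])
```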
