This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
3 changes: 3 additions & 0 deletions .gitignore
@@ -30,3 +30,6 @@ system_info.txt
# build artifacts
checkpoints/
exportedModels/

# test script
_torchchat_test_script.py
11 changes: 6 additions & 5 deletions install/install_requirements.sh
@@ -52,10 +52,6 @@ PYTORCH_NIGHTLY_VERSION=dev20240901
# Nightly version for torchvision
VISION_NIGHTLY_VERSION=dev20240901

# Nightly version for torchtune
TUNE_NIGHTLY_VERSION=dev20240916


# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
(
set -x
@@ -76,7 +72,6 @@ fi
REQUIREMENTS_TO_INSTALL=(
torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
)

# Install the requirements. --extra-index-url tells pip to look for package
@@ -92,6 +87,12 @@ REQUIREMENTS_TO_INSTALL=(
$PIP_EXECUTABLE install torchao=="0.5.0"
)

# Rely on the latest torchtune for Flamingo support
(
set -x
$PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git@18efc81dda1c537bb7c25058ff059b4623ccff58
)
Comment on lines +92 to +94

Contributor: This works on Mac now right?

Contributor Author: yeah and mac tests passed

if [[ -x "$(command -v nvidia-smi)" ]]; then
(
set -x
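Since the torchtune nightly pin is replaced by an install from a specific commit, a quick sanity check can confirm that the install exposes the APIs this PR relies on. This is a minimal sketch, not part of the diff; the import paths mirror the ones used in torchchat/generate.py below.

# Minimal import check (illustrative, not part of this PR): verifies that the
# pinned torchtune commit provides the modules generate.py imports below.
from torchtune.data import Message, padded_collate_tiled_images_and_mask
from torchtune.generation import sample as tune_sample
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.training import set_default_dtype

print("torchtune multimodal APIs are importable")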
54 changes: 39 additions & 15 deletions torchchat/generate.py
@@ -21,7 +21,7 @@
import torch._inductor.config

try:
from _torchchat_test_script import flamingo_transform, padded_collate
from _torchchat_test_script import flamingo_transform
except ImportError:
pass

@@ -38,8 +38,9 @@
from torchchat.utils.device_info import get_device_info

# torchtune model definition dependencies
from torchtune.data import Message
from torchtune.generation._generation import sample as tune_sample
from torchtune.data import Message, padded_collate_tiled_images_and_mask

from torchtune.generation import sample as tune_sample
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.training import set_default_dtype

@@ -357,15 +358,25 @@ def prefill(

if batch is not None:
# TODO: Verify sequential prefill works with multimodal models
logits = model(**batch)[:, -1]
return tune_sample(logits, 0, 500)
tokens = batch["tokens"]
if 'encoder_input' in batch:
encoder_input = batch['encoder_input']
else:
encoder_input = None

seq_len = tokens.size(1)
mask = batch["causal_mask"][None, :seq_len]
encoder_mask = batch["encoder_mask"]
input_pos = input_pos.view(1, -1)
logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
return tune_sample(logits, temperature=0, top_k=500)
elif sequential_prefill:
for i in range(width):
x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
# logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i])
elif self.model.config.model_type == ModelType.Flamingo:
logits = model(x)
assert False, "Flamingo requires batch"
else:
# input_pos: [B, S]
logits = model(x, input_pos)
@@ -387,10 +398,10 @@ def decode_one_token(
assert input_pos.shape[-1] == 1
x = x.view(1, -1)
if model.config.model_type == ModelType.Flamingo:
if batch is not None:
logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
else:
logits = model(x)
assert batch is not None, "Flamingo requires batch"
mask = batch["causal_mask"][None, input_pos.item(), None, :]
encoder_mask = batch["encoder_mask"][:, -1:]
logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
else:
logits = model(x, input_pos)
# print(f"x: {x},\n input_pos: {input_pos}\n")
@@ -593,7 +604,8 @@ def generate(
self.is_torchtune_model
or self.model.config.model_type == ModelType.Flamingo
):
model.setup_caches(max_batch_size=1, dtype=self.dtype)
# 6404 is the max_seq_length affordable on one GPU for a single image input
model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
else:
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if is_speculative and draft_model is not model:
@@ -742,10 +754,22 @@ def chat(
]

transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
data = transform({"messages": messages}, inference=True)
batch = padded_collate([data], self.builder_args.device)
batch.pop("mask")
encoded = batch["tokens"]

with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype):
data = transform({"messages": messages}, inference=True)
batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
# set_default_dtype cannot handle the dtype of the image tensor inside the batch; cast it manually
batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
seq_len = len(data["tokens"])
total_response_length = seq_len + generator_args.max_new_tokens
batch["causal_mask"] = torch.tril(
torch.ones(
size=(total_response_length, total_response_length),
dtype=torch.bool,
)
)
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
encoded = batch["tokens"].view(-1)

else:
encoded = self.encode_tokens(
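The encoder mask follows the same pattern: chat() trims the collated mask to the prompt length before prefill, and decode_one_token() reuses its last row for every generated token. A small sketch with assumed shapes (in the diff the real mask comes from padded_collate_tiled_images_and_mask):

import torch

# Illustrative only: shapes are assumed; the mask covers the image tokens.
prompt_len, image_seq_len = 5, 7
collated_encoder_mask = torch.ones(1, prompt_len + 3, image_seq_len, dtype=torch.bool)

# chat(): keep only the rows for the prompt tokens before prefill
encoder_mask = collated_encoder_mask[:, :prompt_len]

# decode_one_token(): the last row is reused for each newly generated token
step_encoder_mask = encoder_mask[:, -1:]

assert encoder_mask.shape == (1, prompt_len, image_seq_len)
assert step_encoder_mask.shape == (1, 1, image_seq_len)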
33 changes: 25 additions & 8 deletions torchchat/model.py
@@ -458,6 +458,13 @@ def build_model(self) -> nn.Module:
modules[name] = module_class(TransformerArgs.from_params(config_args))
else:
modules[name] = module_class(**config_args)

# Temporarily add extra params to the DeepFusionModel.
# TODO: Remove once the fusion model is configurable in model_param.
if recipe.fusion_class == DeepFusionModel:
modules["encoder_trainable"] = False
modules["decoder_trainable"] = False
modules["fusion_trainable"] = False

return recipe.fusion_class(**modules)

@@ -535,18 +542,28 @@ def reset_caches(self):
class FlamingoModel(Model):
def forward(
self,
tokens: Tensor,
encoder_input: Optional[Dict[str, Tensor]] = None,
encoder_mask: Optional[Tensor] = None,
tokens: torch.Tensor,
*,
mask: Optional[torch.Tensor] = None,
encoder_input: Optional[Dict] = None,
encoder_mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
) -> Tensor:
if encoder_input is None:
return self.model(tokens, encoder_mask=encoder_mask)
return self.model(
tokens, encoder_input=encoder_input, encoder_mask=encoder_mask
tokens,
mask=mask,
encoder_input=encoder_input,
encoder_mask=encoder_mask,
input_pos=input_pos,
)

def setup_caches(self, max_batch_size, dtype):
self.model.setup_caches(max_batch_size, dtype=dtype)
def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
self.model.setup_caches(
batch_size=batch_size,
dtype=dtype,
encoder_max_seq_len=encoder_max_seq_len,
decoder_max_seq_len=decoder_max_seq_len,
)

def reset_caches(self):
self.model.reset_caches()
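FlamingoModel.setup_caches now takes explicit encoder and decoder sequence-length budgets and forwards them to the wrapped torchtune model. A runnable sketch of that forwarding pattern; DummyFlamingo is a hypothetical stand-in for the real DeepFusionModel, and the values mirror the setup_caches call in generate.py above.

import torch

# Hedged sketch: DummyFlamingo only echoes the keyword arguments so the new
# call shape can be exercised without loading real weights.
class DummyFlamingo:
    def setup_caches(self, batch_size, dtype, *, encoder_max_seq_len=None, decoder_max_seq_len=None):
        print(f"caches: bs={batch_size}, dtype={dtype}, "
              f"enc={encoder_max_seq_len}, dec={decoder_max_seq_len}")

prompt_len, max_new_tokens = 32, 200          # assumed lengths
model = DummyFlamingo()
model.setup_caches(
    batch_size=1,
    dtype=torch.bfloat16,                     # assumed dtype
    encoder_max_seq_len=6404,                 # single-image budget used in generate.py above
    decoder_max_seq_len=prompt_len + max_new_tokens,
)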