This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Merged

Changes from 4 commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -30,3 +30,6 @@ system_info.txt
# build artifacts
checkpoints/
exportedModels/

# test script
_torchchat_test_script.py
11 changes: 6 additions & 5 deletions install/install_requirements.sh
@@ -52,10 +52,6 @@ PYTORCH_NIGHTLY_VERSION=dev20240901
# Nightly version for torchvision
VISION_NIGHTLY_VERSION=dev20240901

# Nightly version for torchtune
TUNE_NIGHTLY_VERSION=dev20240916


# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
(
set -x
@@ -76,7 +72,6 @@ fi
REQUIREMENTS_TO_INSTALL=(
torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
)

# Install the requirements. --extra-index-url tells pip to look for package
@@ -92,6 +87,12 @@ REQUIREMENTS_TO_INSTALL=(
$PIP_EXECUTABLE install torchao=="0.5.0"
)

# Rely on the latest torchtune for flamingo support
(
set -x
$PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git@18efc81dda1c537bb7c25058ff059b4623ccff58
)
Comment on lines +92 to +94

Contributor: This works on Mac now right?

Contributor Author: yeah and mac tests passed
if [[ -x "$(command -v nvidia-smi)" ]]; then
(
set -x
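A quick sanity check one might run after the script finishes, sketched here on the assumption that the pinned torchtune commit exposes the symbols this PR imports in torchchat/generate.py; it is not part of the change above.

# Sanity check for the pinned torchtune install (assumption: the pinned commit
# ships the symbols used later in this PR).
from torchtune.data import Message, padded_collate_tiled_images_and_mask  # noqa: F401
from torchtune.generation import sample  # noqa: F401
from torchtune.models.llama3 import llama3_tokenizer  # noqa: F401
from torchtune.training import set_default_dtype  # noqa: F401

print("torchtune imports OK")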
50 changes: 35 additions & 15 deletions torchchat/generate.py
@@ -21,7 +21,7 @@
import torch._inductor.config

try:
from _torchchat_test_script import flamingo_transform, padded_collate
from _torchchat_test_script import flamingo_transform
except ImportError:
pass

@@ -38,8 +38,9 @@
from torchchat.utils.device_info import get_device_info

# torchtune model definition dependencies
from torchtune.data import Message
from torchtune.generation._generation import sample as tune_sample
from torchtune.data import Message, padded_collate_tiled_images_and_mask

from torchtune.generation import sample as tune_sample
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.training import set_default_dtype

@@ -357,15 +358,25 @@ def prefill(

if batch is not None:
# TODO: Verify sequential prefill works with multimodal models
logits = model(**batch)[:, -1]
return tune_sample(logits, 0, 500)
tokens = batch["tokens"]
if 'encoder_input' in batch:
encoder_input = batch['encoder_input']
else:
encoder_input = None

seq_len = tokens.size(1)
mask = batch["causal_mask"][None, :seq_len]
encoder_mask = batch["encoder_mask"]
input_pos = input_pos.view(1, -1)
logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
return tune_sample(logits, temperature=0, top_k=500)
elif sequential_prefill:
for i in range(width):
x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
# logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i])
elif self.model.config.model_type == ModelType.Flamingo:
logits = model(x)
assert False, "Flamingo requires batch"
else:
# input_pos: [B, S]
logits = model(x, input_pos)
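A minimal sketch of the mask and position handling in the new multimodal prefill path above, using hypothetical sizes (max_seq_len=8, a 3-token prompt) rather than values from this PR: the precomputed causal_mask is a full lower-triangular boolean matrix, prefill slices out its first seq_len rows, and input_pos is reshaped to [1, seq_len].

import torch

max_seq_len, seq_len = 8, 3  # hypothetical sizes for illustration
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

prefill_mask = causal_mask[None, :seq_len]     # [1, seq_len, max_seq_len]
input_pos = torch.arange(seq_len).view(1, -1)  # [[0, 1, 2]]
print(prefill_mask.shape, input_pos.shape)     # torch.Size([1, 3, 8]) torch.Size([1, 3])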
@@ -387,10 +398,10 @@ def decode_one_token(
assert input_pos.shape[-1] == 1
x = x.view(1, -1)
if model.config.model_type == ModelType.Flamingo:
if batch is not None:
logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
else:
logits = model(x)
assert batch is not None, "Flamingo requires batch"
mask = batch["causal_mask"][None, input_pos.item(), None, :]
encoder_mask = batch["encoder_mask"][:, -1:]
logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
else:
logits = model(x, input_pos)
# print(f"x: {x},\n input_pos: {input_pos}\n")
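The single-token decode path above selects one row of the same causal mask per step; a minimal sketch with hypothetical sizes (not values from this PR):

import torch

max_seq_len = 8  # hypothetical size for illustration
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.tensor([3])  # decoding the token at position 3

step_mask = causal_mask[None, input_pos.item(), None, :]  # [1, 1, max_seq_len]
print(step_mask.shape)  # torch.Size([1, 1, 8])
print(step_mask[0, 0])  # True for positions 0..3, False for 4..7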
@@ -593,7 +604,7 @@ def generate(
self.is_torchtune_model
or self.model.config.model_type == ModelType.Flamingo
):
model.setup_caches(max_batch_size=1, dtype=self.dtype)
model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=max_seq_length-1)
Contributor: Magic number 6404?

Contributor Author: It is the size that can hold a single image input while staying affordable on one GPU.

Contributor Author: I will leave a comment
else:
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if is_speculative and draft_model is not model:
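One plausible derivation of the 6404 used for encoder_max_seq_len, in case it helps the "Magic number 6404?" thread above; the tile and patch sizes below are assumed Llama 3.2 vision defaults, not values stated in this PR.

# Hypothetical derivation of encoder_max_seq_len=6404 (assumed vision-encoder defaults).
tile_size = 560
patch_size = 14
max_num_tiles = 4

patches_per_tile = (tile_size // patch_size) ** 2      # 40 * 40 = 1600
tokens_per_tile = patches_per_tile + 1                 # +1 CLS token -> 1601
encoder_max_seq_len = tokens_per_tile * max_num_tiles  # 1601 * 4 = 6404
print(encoder_max_seq_len)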
@@ -742,10 +753,19 @@ def chat(
]

transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
data = transform({"messages": messages}, inference=True)
batch = padded_collate([data], self.builder_args.device)
batch.pop("mask")
encoded = batch["tokens"]

with torch.device(device=self.builder_args.device):
data = transform({"messages": messages}, inference=True)
batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
seq_len = len(data["tokens"])
batch["causal_mask"] = torch.tril(
torch.ones(
size=(generator_args.max_new_tokens, generator_args.max_new_tokens),
dtype=torch.bool,
)
)
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
encoded = batch["tokens"]

else:
encoded = self.encode_tokens(
26 changes: 18 additions & 8 deletions torchchat/model.py
@@ -535,18 +535,28 @@ def reset_caches(self):
class FlamingoModel(Model):
def forward(
self,
tokens: Tensor,
encoder_input: Optional[Dict[str, Tensor]] = None,
encoder_mask: Optional[Tensor] = None,
tokens: torch.Tensor,
*,
mask: Optional[torch.Tensor] = None,
encoder_input: Optional[Dict] = None,
encoder_mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
) -> Tensor:
if encoder_input is None:
return self.model(tokens, encoder_mask=encoder_mask)
return self.model(
tokens, encoder_input=encoder_input, encoder_mask=encoder_mask
tokens,
mask=mask,
encoder_input=encoder_input,
encoder_mask=encoder_mask,
input_pos=input_pos,
)

def setup_caches(self, max_batch_size, dtype):
self.model.setup_caches(max_batch_size, dtype=dtype)
def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
self.model.setup_caches(
batch_size=batch_size,
dtype=dtype,
encoder_max_seq_len=encoder_max_seq_len,
decoder_max_seq_len=decoder_max_seq_len,
)

def reset_caches(self):
self.model.reset_caches()