141 changes: 96 additions & 45 deletions torchchat/generate.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import base64
 import itertools
 import logging
 import os
@@ -12,6 +13,7 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -600,9 +602,8 @@ def generate(

         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
-        T = prompt.size(0)
-        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-        T_new = T + max_new_tokens
+        prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)

Review comment (Contributor), on the removed `T_new` line: Is this line necessary to remove?

Review comment (Contributor): This seems like it would be a problem with long prompt lengths.
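With the decoder cache now sized up front, `max_new_tokens` is clamped so that prompt plus generated tokens never overrun the cache; for long prompts the budget can shrink to near zero, which is the reviewers' concern. A minimal sketch with hypothetical numbers (not from this PR):

```python
# Hypothetical numbers; mirrors the clamp in the diff above.
max_seq_length = 2048  # fixed decoder cache size
start_pos = 0          # first inference call
prompt_length = 2000   # a long prompt

requested_new_tokens = 300
max_new_tokens = min(requested_new_tokens, max_seq_length - start_pos - prompt_length)
print(max_new_tokens)  # 48: generation is silently truncated to fit the cache
```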
@@ -616,7 +617,7 @@ def generate(
                     batch_size=1,
                     dtype=self.dtype,
                     encoder_max_seq_len=6404,
-                    decoder_max_seq_len=T_new,
+                    decoder_max_seq_len=max_seq_length,
                 )
             else:
                 model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
@@ -629,7 +630,7 @@ def generate(
                 model.reset_caches()

         input_pos = torch.arange(
-            start_pos, T + start_pos, device=device, dtype=torch.int
+            start_pos, prompt_length + start_pos, device=device, dtype=torch.int
         )

         prefill_t0 = time.perf_counter()
@@ -655,7 +656,9 @@ def generate(
         # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
         callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)

-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
         accept_counts = [0] * (
             speculate_k + 1
         )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +681,7 @@ def generate(
                 )

                 accept_counts[len(next_tokens) - 1] += 1
-                num_added = min(T_new - input_pos - 1, len(next_tokens))
+                num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
                 for token in next_tokens[:num_added,]:
                     callback(token)
                     yield token, None
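Since `T_new` (prompt length plus budget) no longer exists, the clamp on accepted speculative tokens is expressed in terms of `max_new_tokens`. A toy sketch of the bookkeeping, treating `input_pos` as a plain int for illustration (the real code uses a one-element tensor):

```python
# Hypothetical values illustrating the num_added clamp above.
speculate_k = 4
accept_counts = [0] * (speculate_k + 1)  # index i counts steps where i+1 tokens were kept

max_new_tokens = 10
input_pos = 8                  # decoding position reached so far
next_tokens = [101, 102, 103]  # tokens produced by this speculative step

accept_counts[len(next_tokens) - 1] += 1
num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
print(num_added)  # 1: only one of the three tokens still fits the budget
```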
@@ -741,6 +744,7 @@ def _gen_model_input(
         prompt: Union[str | List[Any]],
         image_prompts: Optional[List[str | Image.Image]] = None,
         max_new_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
         Convert prompt and image prompts into consumable model input args.
@@ -757,7 +761,7 @@ def _gen_model_input(
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """

-        # Not Llama 3.2 11B
+        # Text-Only model
         if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
@@ -778,32 +782,69 @@ def _gen_model_input(
         assert (
             image_prompts is None or len(image_prompts) == 1
         ), "At most one image is supported at the moment"

         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
-            images = image_prompts
+            images = None

         assert (
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(
-            prompt, str
-        ), "(Currently) prompt must be a str for Flamingo models"
-
-        is_multimodal = images is not None
-        content = [{"type": "text", "content": prompt}]
-
-        if is_multimodal:
-            content = [{"type": "image", "content": images[0]}] + content
-
-        messages = [
-            Message(
-                role="user",
-                content=content,
-                eot=True,
-            ),
-            Message(role="assistant", content=""),
-        ]
+        image_found = False
+        messages = []
+        for message in prompt:
Review comment (@Jack-Khuu, Contributor, Oct 4, 2024): This would be `torchchat.py generate Llama3.2-11B`, right? Since it sends `prompt: str` and uses the `image_prompt` field, you might need to "create" a container prompt from those two.

Review comment (Contributor): Or chat calls a curried version of this function that creates the format before calling this function.
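The container the first comment describes is just the message-list shape this loop iterates over. A hypothetical helper (not part of this PR) sketching that wrapping, which is essentially what the `chat()` change further down does inline:

```python
from typing import Any, Dict, List

def wrap_prompt(prompt: str) -> List[Dict[str, Any]]:
    # Hypothetical container: the plain string becomes a one-message list;
    # the image from image_prompts is attached inside _gen_model_input itself.
    return [{"role": "user", "content": prompt}]

messages = wrap_prompt("What is in this image?")
```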

+            if isinstance(message["content"], str):
+                if not image_found and image_prompts:
+                    messages.append(
+                        Message(
+                            role=message["role"],
+                            content=[
+                                {"type": "image", "content": images[0]},
+                                {"type": "text", "content": message["content"]},
+                            ],
+                        )
+                    )
+                    image_found = True
+                else:
+                    messages.append(Message(**message))
+
+            elif isinstance(message["content"], list):
+                images = None
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        prompt_arg = content_dict["text"]
+                    elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        images = [Image.open(BytesIO(base64_decoded))]
+                        image_found = True
+
+                is_multimodal = images is not None
+                content = [{"type": "text", "content": prompt_arg}]
+
+                if is_multimodal:
+                    content = [{"type": "image", "content": images[0]}] + content
+
+                messages.append(
+                    Message(
+                        role=message["role"],
+                        content=content,
+                    )
+                )
+
+        messages.append(
+            Message(
+                role="assistant",
+                content="",
+            )
+        )

         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
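The `image_url` branch above assumes an OpenAI-style data URI (`data:image/...;base64,<payload>`). A self-contained sketch of that round trip, using a hypothetical local file:

```python
import base64
from io import BytesIO

from PIL import Image

# Encode a local image (hypothetical path) as a data URI...
with open("cat.png", "rb") as f:
    data_uri = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")

# ...and decode it the same way the loop above does:
payload = data_uri.split(";base64,")[1]
image = Image.open(BytesIO(base64.b64decode(payload)))
print(image.size)
```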

@@ -812,7 +853,7 @@ def _gen_model_input(
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)

-        if is_multimodal:
+        if image_found:
             batch = padded_collate_tiled_images_and_mask(
                 [data], pad_direction="left", pad_max_images=1
             )
@@ -822,17 +863,27 @@ def _gen_model_input(
             batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                 self.dtype
             )
+
         else:
             encoded = torch.tensor(data["tokens"], device=device).view(-1)
             seq_len = encoded.size(0)
             batch = {}

         total_response_length = seq_len + max_new_tokens
-        batch["causal_mask"] = torch.tril(
-            torch.ones(
-                size=(total_response_length, total_response_length),
-                dtype=torch.bool,
-            )
-        )
+        batch["causal_mask"] = torch.nn.functional.pad(
+            torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            ),
+            (
+                0,
+                max_seq_len - total_response_length,
+                0,
+                max_seq_len - total_response_length,
+            ),
+            value=0,
+        )

         logging.debug(encoded)
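Why the pad: the text-only causal mask used to be sized to `seq_len + max_new_tokens`, but the decoder cache is now allocated at `max_seq_len`, so the mask is padded with False out to the cache shape. A toy-sized sketch:

```python
import torch
import torch.nn.functional as F

total_response_length = 4  # toy stand-in for seq_len + max_new_tokens
max_seq_len = 6            # toy stand-in for the cache size

mask = torch.tril(
    torch.ones(total_response_length, total_response_length, dtype=torch.bool)
)
# Pad right and bottom with False so the mask matches the cache shape.
mask = F.pad(
    mask,
    (0, max_seq_len - total_response_length, 0, max_seq_len - total_response_length),
    value=0,
)
print(mask.shape)  # torch.Size([6, 6])
```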
@@ -845,12 +896,6 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")

-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-        )
-
         model_size = sum(
             [
                 p.numel() * p.dtype.itemsize
@@ -896,6 +941,12 @@ def chat(
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            [{"role": "user", "content": generator_args.prompt}],
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )

         if generator_args.chat_mode:
             print(
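With this reordering, `chat()` computes `max_seq_length` first and then wraps the plain string prompt into the message-list container that `_gen_model_input` now expects. For comparison, a list-content message of the kind the OpenAI-style branch above parses would look like this (hypothetical values, truncated base64 payload):

```python
# Shape only; the payload is a placeholder, not a decodable image.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": "data:image/png;base64,iVBORw0KGgo..."},
    ],
}
```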
@@ -907,16 +958,16 @@ def chat(
         if get_system_prompt == "y" or get_system_prompt == "Y":
             self.system_prompt = input("What is your system prompt? \n")

-        elif not generator_args.is_torchtune_model:
-            max_seq_length = min(
-                encoded.size(0) + generator_args.max_new_tokens,
-                (
-                    text_transformer_args.block_size
-                    if text_transformer_args is not None
-                    else 2048
-                ),
-                max_seq_length,
-            )
+        # elif not generator_args.is_torchtune_model:
+        #     max_seq_length = min(
+        #         encoded.size(0) + generator_args.max_new_tokens,
+        #         (
+        #             text_transformer_args.block_size
+        #             if text_transformer_args is not None
+        #             else 2048
+        #         ),
+        #         max_seq_length,
+        #     )

         max_seq_length = (
             max_seq_length + self.speculative_builder_args.speculate_k + 1