This repository was archived by the owner on Sep 10, 2025. It is now read-only.
98 changes: 64 additions & 34 deletions torchchat/generate.py
@@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import base64
import itertools
import logging
import os
@@ -12,6 +13,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from io import BytesIO
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -600,9 +602,7 @@ def generate(

if len(prompt.shape) > 1:
prompt = prompt.squeeze(0)
T = prompt.size(0)
max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
T_new = T + max_new_tokens
prompt_length = prompt.size(0)
# set up caches only if first inference
if start_pos == 0:
model = model.to(device=device)
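
As an aside, a minimal sketch (made-up token IDs, not from this PR) of how the renamed prompt_length now drives only the prefill positions, while the cache size comes from max_seq_length:

import torch

prompt = torch.tensor([[128000, 9906, 1917]])  # hypothetical encoded prompt with a leading batch dim
if len(prompt.shape) > 1:
    prompt = prompt.squeeze(0)                 # flatten to 1-D, as in generate() above
prompt_length = prompt.size(0)                 # 3

start_pos = 0
# prefill positions cover exactly the prompt tokens
input_pos = torch.arange(start_pos, prompt_length + start_pos, dtype=torch.int)  # tensor([0, 1, 2], dtype=torch.int32)
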
@@ -616,7 +616,7 @@
batch_size=1,
dtype=self.dtype,
encoder_max_seq_len=6404,
decoder_max_seq_len=T_new,
decoder_max_seq_len=max_seq_length,
)
else:
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
@@ -629,7 +629,7 @@
model.reset_caches()

input_pos = torch.arange(
start_pos, T + start_pos, device=device, dtype=torch.int
start_pos, prompt_length + start_pos, device=device, dtype=torch.int
)

prefill_t0 = time.perf_counter()
@@ -655,7 +655,7 @@
# max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)

input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
input_pos = torch.tensor([start_pos + prompt_length], device=device, dtype=torch.int)
accept_counts = [0] * (
speculate_k + 1
) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +678,7 @@
)

accept_counts[len(next_tokens) - 1] += 1
num_added = min(T_new - input_pos - 1, len(next_tokens))
num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
for token in next_tokens[:num_added,]:
callback(token)
yield token, None
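
For context, a toy illustration (hypothetical values) of the accept_counts bookkeeping used in the speculative-decoding branch above: accept_counts[i] tallies the steps in which i + 1 draft tokens were accepted.

speculate_k = 3
accept_counts = [0] * (speculate_k + 1)   # [0, 0, 0, 0]

next_tokens = [17, 42]                    # hypothetical: 2 draft tokens accepted this step
accept_counts[len(next_tokens) - 1] += 1  # [0, 1, 0, 0]
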
@@ -734,13 +734,14 @@ def _callback(self, x, *, buffer, done_generating):
if len(buffer) == 4 or done_generating:
print("".join(buffer), end="", flush=True)
buffer.clear()
# print(, end='', flush=True)

def _gen_model_input(
self,
prompt: Union[str | List[Any]],
image_prompts: Optional[List[str | Image.Image]] = None,
max_new_tokens: Optional[int] = None,
max_seq_len: Optional[int] = 2048,
) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
"""
Convert prompt and image prompts into consumable model input args.
@@ -757,7 +758,7 @@ def _gen_model_input(
Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
"""

# Not Llama 3.2 11B
# Text-Only model
if self.model.config.model_type != ModelType.Flamingo:
# Single String prompt
if isinstance(prompt, str):
@@ -778,6 +779,7 @@
assert (
image_prompts is None or len(image_prompts) == 1
), "At most one image is supported at the moment"

if image_prompts and isinstance(image_prompts[0], str):
images = [Image.open(image_prompts[0])]
else:
@@ -786,24 +788,41 @@
assert (
max_new_tokens is not None
), "max_new_tokens must be specified for Flamingo models"
assert isinstance(
prompt, str
), "(Currently) prompt must be a str for Flamingo models"

is_multimodal = images is not None
content = [{"type": "text", "content": prompt}]

if is_multimodal:
content = [{"type": "image", "content": images[0]}] + content

messages = [
Message(
role="user",
content=content,
eot=True,
),
Message(role="assistant", content=""),
]
image_found = False
messages = []
for message in prompt:
if isinstance(message["content"], str):
messages.append(Message(**message))

elif isinstance(message["content"], list):
images = None
for content_dict in message["content"]:
if content_dict["type"] == "text":
prompt_arg = content_dict["text"]
elif content_dict["type"] == "image_url":
assert (
images is None
), "At most one image is supported at the moment"

base64_decoded = base64.b64decode(
content_dict["image_url"].split(";base64,")[1]
)
images = [Image.open(BytesIO(base64_decoded))]
image_found = True

is_multimodal = images is not None
content = [{"type": "text", "content": prompt_arg}]
if is_multimodal:
content = [{"type": "image", "content": images[0]}] + content

messages.append(
Message(
role="user",
content=content,
)
)

transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))

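For reference, a hypothetical request payload (file name and text are placeholders) in the shape this parsing loop expects: content is either a plain string or a list of {"type": ...} entries, with images passed as base64 data-URL strings under "image_url".

import base64

with open("cat.jpg", "rb") as f:                    # hypothetical local image
    encoded_image = base64.b64encode(f.read()).decode("utf-8")

prompt = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url", "image_url": f"data:image/jpeg;base64,{encoded_image}"},
        ],
    },
]
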
@@ -812,27 +831,37 @@
with device, set_default_dtype(self.dtype):
data = transform({"messages": messages}, inference=True)

if is_multimodal:
if image_found:
batch = padded_collate_tiled_images_and_mask(
[data], pad_direction="left", pad_max_images=1
)
encoded = batch.pop("tokens").to(device).view(-1)
seq_len = encoded.size(0)
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
self.dtype
)

else:
encoded = torch.tensor(data["tokens"], device=device).view(-1)
seq_len = encoded.size(0)
batch = {}

total_response_length = seq_len + max_new_tokens
batch["causal_mask"] = torch.tril(
torch.ones(
size=(total_response_length, total_response_length),
dtype=torch.bool,
)
total_response_length = max_seq_len + max_new_tokens
batch["causal_mask"] = torch.nn.functional.pad(
torch.tril(
torch.ones(
size=(total_response_length, total_response_length),
dtype=torch.bool,
)
),
(
0,
max_seq_len - total_response_length,
0,
max_seq_len - total_response_length,
),
value=0,
)

logging.debug(encoded)
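
As a toy illustration (made-up sizes, not the values used above) of the padding idiom: torch.nn.functional.pad on a boolean lower-triangular mask grows it with False rows/columns for positive pads and crops it for negative pads, so the result is always a square mask with the target side length.

import torch
import torch.nn.functional as F

tril = torch.tril(torch.ones(4, 4, dtype=torch.bool))

grown = F.pad(tril, (0, 2, 0, 2), value=0)      # 6x6: 4x4 causal block, False elsewhere
cropped = F.pad(tril, (0, -1, 0, -1), value=0)  # 3x3: top-left corner of the original mask

print(grown.shape, cropped.shape)  # torch.Size([6, 6]) torch.Size([3, 3])
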
@@ -849,6 +878,7 @@ def chat(
generator_args.prompt,
generator_args.image_prompts,
generator_args.max_new_tokens,
generator_args.max_seq_length,
)

model_size = sum(
36 changes: 10 additions & 26 deletions torchchat/usages/openai_api.py
@@ -316,38 +316,22 @@ def _gen_model_inputs_from_openai_completion_request(
if not isinstance(self.model, FlamingoModel):
prompt = [
{"role": message["role"], "content": message["content"]}
for message in completion_request.messages
for message in messages
]
return self._gen_model_input(
prompt=prompt, max_new_tokens=completion_request.max_tokens
)

# Llama 3.2 11B
prompt = None
images = None

for message in messages:
torchtune_contents = []
if isinstance(message["content"], list):
for content_dict in message["content"]:
if content_dict["type"] == "text":
assert (
prompt is None
), "At most one text prompt is supported for each request"
prompt = content_dict["text"]
elif content_dict["type"] == "image_url":
assert (
images is None
), "At most one image is supported at the moment"

base64_decoded = base64.b64decode(
content_dict["image_url"].split(";base64,")[1]
)
images = [Image.open(BytesIO(base64_decoded))]

assert prompt is not None, "Text prompt must be specified in the request"

return self._gen_model_input(prompt, images, completion_request.max_tokens)

prompt = [
{"role": message["role"], "content": message["content"]}
for message in messages
]

return self._gen_model_input(
prompt=prompt, max_new_tokens=completion_request.max_tokens
)

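A hedged sketch (hypothetical request contents) of what the unified handler now forwards for a text-only chat: the same list of role/content dicts goes to _gen_model_input regardless of model family.

messages = [
    {"role": "user", "content": "Write a haiku about autumn."},  # hypothetical request message
]

prompt = [
    {"role": message["role"], "content": message["content"]}
    for message in messages
]
# prompt == [{"role": "user", "content": "Write a haiku about autumn."}]
# then: self._gen_model_input(prompt=prompt, max_new_tokens=completion_request.max_tokens)
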
def chunked_completion(self, completion_request: CompletionRequest):
"""Handle a chat completion request and yield a chunked response.