69 changes: 62 additions & 7 deletions torchchat/usages/openai_api.py
@@ -19,7 +19,8 @@

from PIL import Image

-from torchtune.data import Message, padded_collate
+from torchtune.data import Message, padded_collate_tiled_images_and_mask

from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform

from torchchat.cli.download import is_model_downloaded, load_model_configs
@@ -288,13 +289,50 @@ def __init__(self, *args, **kwargs):
                 else self.model.text_transformer_args.max_seq_length
Contributor: One of the things on our list is having a unified configuration system for both tune-backend and chat-backend models, so we can get rid of the try/except here.
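As a purely illustrative sketch of what such a unified lookup could look like (hypothetical helper and attribute names, not code from this PR):

    # Hypothetical sketch: one accessor covering both tune-backend and
    # chat-backend model configs, with an explicit default instead of try/except.
    def get_max_seq_length(model, default: int = 2048) -> int:
        text_args = getattr(model, "text_transformer_args", None)
        if text_args is not None and getattr(text_args, "max_seq_length", None) is not None:
            return text_args.max_seq_length
        # "config" is an assumed attribute name for tune-backend models.
        config = getattr(model, "config", None)
        if config is not None and getattr(config, "max_seq_length", None) is not None:
            return config.max_seq_length
        return default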

             )
         except:
-            # can not find max_seq_length in model config, use default value
-            self.max_seq_length = 128
Contributor: haha, now I can see where the 128 comes from.

Contributor: It happens to be the same as the head_dim, which made it very tricky to trace, LOL.

Contributor Author: yep.

Contributor Author: Thanks to @iseeyuan for spotting it.

+            self.max_seq_length = 2048
+            print(f"can not find max_seq_length in model config, use default value: {self.max_seq_length}")
         # The System fingerprint is a unique identifier for the model and its configuration.
         self.system_fingerprint = (
             f"{self.builder_args.device}_{self.builder_args.precision}"
         )

+    def _openai_messages_to_torchtune_messages(
+        self, messages: List[_AbstractMessage]
+    ) -> List[Message]:
+        """Convert a list of OpenAI API messages to a list of TorchTune messages.
+
+        Args:
+            messages: A list of OpenAI API messages.
+
+        Returns:
+            A list of Torchtune Messages.
+        """
+        torchtune_messages = []
+        for message in messages:
+            torchtune_contents = []
+            if isinstance(message["content"], list):
+                for content_dict in message["content"]:
+                    converted_content = []
+                    if content_dict["type"] == "text":
+                        converted_content.append(
+                            {"type": "text", "content": content_dict["text"]}
+                        )
+                    elif content_dict["type"] == "image_url":
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        image = Image.open(BytesIO(base64_decoded))
+                        converted_content.append(
+                            {
+                                "type": "image",
+                                "content": image,
+                            }
+                        )
+            torchtune_messages.append(
+                Message(role=message["role"], content=converted_content, eot=False)
+            )
+        return torchtune_messages
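As a rough usage sketch of the new helper (the handler instance `api`, the prompt text, and the base64 payload are placeholders, not values from this PR):

    # Illustrative only: one OpenAI-style user message mixing a text part and an
    # image_url part given as a base64 data URL; "<...>" stands in for real bytes.
    openai_messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": "data:image/png;base64,<...>"},
            ],
        }
    ]
    # `api` is assumed to be an instance of the handler class this method lives on.
    torchtune_messages = api._openai_messages_to_torchtune_messages(openai_messages)
    # The result is a list of torchtune Message objects that can be passed to
    # llama3_2_vision_transform together with the decoded images (see below).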

     def _openai_messages_to_torchtune(
         self, messages: List[_AbstractMessage]
     ) -> List[Message]:
@@ -376,15 +414,32 @@ def chunked_completion(self, completion_request: CompletionRequest):
             transform = llama3_2_vision_transform(
                 str(self.tokenizer_args.tokenizer_path)
             )
-            torchtune_messages = self._openai_messages_to_torchtune(
+            torchtune_messages = self._openai_messages_to_torchtune_messages(
                 completion_request.messages
             )
             data = transform(
                 {"images": images, "messages": torchtune_messages}, inference=True
             )
-            batch = padded_collate([data], self.builder_args.device)
-            batch.pop("mask")
-            encoded = batch["tokens"]
+            seq_len = len(data["tokens"])
+            total_response_length = seq_len + completion_request.max_tokens
+            causal_mask = torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            )
+            input_pos = torch.arange(total_response_length)
+
+            with torch.no_grad():
+                with torch.device(self.builder_args.device):
+                    batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.builder_args.precision)
+                    batch["causal_mask"] = causal_mask
+                    batch["input_pos"] = input_pos[None, :seq_len]
+                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+
+            #batch = padded_collate([data], self.builder_args.device)
+            encoded = batch["tokens"].view(-1)
         else:
             tokens = self.chat_formatter.encode_dialog_prompt(
                 dialog=[
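For context, the multimodal path exercised by this diff reads the image_url field as a plain base64 data-URL string. A minimal client-side request might look like the following sketch (the endpoint path, port, model name, and payload values are assumptions, not taken from this PR):

    # Illustrative client sketch; assumes an OpenAI-compatible torchchat server is
    # running locally. Endpoint, port, and model name are placeholders.
    import requests

    payload = {
        "model": "llama3.2-11B",
        "max_tokens": 128,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image."},
                    # "<...>" stands in for real base64-encoded image bytes.
                    {"type": "image_url", "image_url": "data:image/png;base64,<...>"},
                ],
            }
        ],
    }
    response = requests.post("http://localhost:5000/v1/chat/completions", json=payload)
    print(response.json())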