
Commit eee8abc: initial changes
Parent: c454026


torchchat/usages/openai_api.py

Lines changed: 59 additions & 5 deletions
@@ -19,7 +19,8 @@
 
 from PIL import Image
 
-from torchtune.data import Message, padded_collate
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
 from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
 
 from torchchat.cli.download import is_model_downloaded, load_model_configs
@@ -295,6 +296,42 @@ def __init__(self, *args, **kwargs):
             f"{self.builder_args.device}_{self.builder_args.precision}"
         )
 
+    def _openai_messages_to_torchtune_messages(
+        self, messages: List[_AbstractMessage]
+    ) -> List[Message]:
+        """Convert a list of OpenAI API messages to a list of TorchTune messages.
+
+        Args:
+            messages: A list of OpenAI API messages.
+
+        Returns:
+            A list of Torchtune Messages.
+        """
+        torchtune_messages = []
+        for message in messages:
+            torchtune_contents = []
+            if isinstance(message["content"], list):
+                for content_dict in message["content"]:
+                    converted_content = []
+                    if content_dict["type"] == "text":
+                        converted_content.append(
+                            {"type": "text", "content": content_dict["text"]}
+                        )
+                    elif content_dict["type"] == "image_url":
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        converted_content.append(
+                            {
+                                "type": "image",
+                                "content": Image.open(BytesIO(base64_decoded)),
+                            }
+                        )
+            torchtune_messages.append(
+                Message(role=message["role"], content=converted_content, eot=False)
+            )
+        return torchtune_messages
+
     def _openai_messages_to_torchtune(
         self, messages: List[_AbstractMessage]
     ) -> List[Message]:
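
For context, here is a minimal, self-contained sketch of the per-part conversion the new helper performs on an OpenAI-style message. The sample message and the in-memory 1x1 PNG are fabricated for illustration, the torchtune Message wrapping is omitted, and this sketch accumulates every part into a single content list:

import base64
from io import BytesIO

from PIL import Image

# Build a tiny in-memory PNG and wrap it as an OpenAI-style data URL, standing in
# for the "image_url" content parts a client would send. Payload is fabricated.
buf = BytesIO()
Image.new("RGB", (1, 1)).save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": data_url},
    ],
}

# Mirror the helper's conversion: text parts pass through unchanged, image parts
# are base64-decoded from the data URL and opened with PIL.
converted = []
for part in message["content"]:
    if part["type"] == "text":
        converted.append({"type": "text", "content": part["text"]})
    elif part["type"] == "image_url":
        raw = base64.b64decode(part["image_url"].split(";base64,")[1])
        converted.append({"type": "image", "content": Image.open(BytesIO(raw))})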
@@ -376,15 +413,32 @@ def chunked_completion(self, completion_request: CompletionRequest):
             transform = llama3_2_vision_transform(
                 str(self.tokenizer_args.tokenizer_path)
             )
-            torchtune_messages = self._openai_messages_to_torchtune(
+            torchtune_messages = self._openai_messages_to_torchtune_messages(
                 completion_request.messages
             )
             data = transform(
                 {"images": images, "messages": torchtune_messages}, inference=True
             )
-            batch = padded_collate([data], self.builder_args.device)
-            batch.pop("mask")
-            encoded = batch["tokens"]
+            seq_len = len(data["tokens"])
+            total_response_length = seq_len + completion_request.max_tokens
+            causal_mask = torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            )
+            input_pos = torch.arange(total_response_length)
+
+            with torch.no_grad():
+                with torch.device(self.builder_args.device):
+                    batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.builder_args.precision)
+                    batch["causal_mask"] = causal_mask
+                    batch["input_pos"] = input_pos[None, :seq_len]
+                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+
+            #batch = padded_collate([data], self.builder_args.device)
+            encoded = batch["tokens"].view(-1)
         else:
             tokens = self.chat_formatter.encode_dialog_prompt(
                 dialog=[
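
For context, a small sketch of the mask and position bookkeeping the new branch sets up. The sequence lengths are made-up numbers for illustration:

import torch

# Illustrative sizes only: a 6-token prompt with room for 4 generated tokens.
seq_len, max_tokens = 6, 4
total_response_length = seq_len + max_tokens

# Lower-triangular boolean mask sized for prompt plus generation, as in the hunk above.
causal_mask = torch.tril(
    torch.ones(
        size=(total_response_length, total_response_length),
        dtype=torch.bool,
    )
)
input_pos = torch.arange(total_response_length)

print(causal_mask.shape)          # torch.Size([10, 10])
print(input_pos[None, :seq_len])  # tensor([[0, 1, 2, 3, 4, 5]]), prompt positions only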
