This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 53f9d34

initial test

1 parent 77bac00 · commit 53f9d34

2 files changed: +50 −41 lines

torchchat/generate.py

Lines changed: 42 additions & 15 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import base64
 import itertools
 import logging
 import os
@@ -12,6 +13,7 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -733,6 +735,9 @@ def _callback(self, x, *, buffer, done_generating):
             buffer.clear()
         # print(, end='', flush=True)

+    def print_m(self, message):
+        print(message.role, [t["type"] if t["type"] != "text" else t for t in message.content])
+
     def _gen_model_input(
         self,
         prompt: Union[str | List[Any]],
@@ -775,6 +780,7 @@ def _gen_model_input(
         assert (
             image_prompts is None or len(image_prompts) == 1
         ), "At most one image is supported at the moment"
+
         if image_prompts and isinstance(image_prompts[0], str):
            images = [Image.open(image_prompts[0])]
         else:
@@ -783,24 +789,45 @@ def _gen_model_input(
         assert (
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(
-            prompt, str
-        ), "(Currently) prompt must be a str for Flamingo models"

-        is_multimodal = images is not None
-        content = [{"type": "text", "content": prompt}]
+        image_found = False
+        messages = []
+        for message in prompt:
+            if isinstance(message["content"], str):
+                messages.append(Message(**message))
+
+            elif isinstance(message["content"], list):
+                images = None
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        prompt_arg = content_dict["text"]
+                    elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        images = [Image.open(BytesIO(base64_decoded))]
+                        image_found = True

-        if is_multimodal:
-            content = [{"type": "image", "content": images[0]}] + content
+                is_multimodal = images is not None
+                content = [{"type": "text", "content": prompt_arg}]

-        messages = [
-            Message(
-                role="user",
-                content=content,
-                eot=True,
-            ),
-            Message(role="assistant", content=""),
-        ]
+                if is_multimodal:
+                    content = [{"type": "image", "content": images[0]}] + content
+
+                messages.append(
+                    Message(
+                        role="user",
+                        content=content,
+                    )
+                )
+
+        print("MESSAGE CONTENTS:")
+        messages.append(Message(role="assistant", content=""))
+        [self.print_m(m) for m in messages]

         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))

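For context, the change above reworks _gen_model_input so the Flamingo path accepts the same OpenAI-style message list that the text-only path already receives: each message has a role and either a plain string content or a list of content parts, where an image part carries a base64 data URI. The sketch below mirrors that decoding step in isolation; the sample messages and the in-memory test image are invented for illustration and are not part of the commit (requires Pillow).

import base64
from io import BytesIO

from PIL import Image

# Build a tiny in-memory PNG and wrap it in a data URI, standing in for the
# image a client would attach to a request.
buf = BytesIO()
Image.new("RGB", (8, 8), color="red").save(buf, format="PNG")
data_uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

# Hypothetical prompt in the shape the new loop expects: a list of
# OpenAI-style messages whose "content" is either a str or a list of parts.
prompt = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": data_uri},
        ],
    },
]

# Mirror of the decoding step added in _gen_model_input: keep only the
# payload after ";base64," and open it as a PIL image.
for message in prompt:
    if isinstance(message["content"], list):
        for part in message["content"]:
            if part["type"] == "image_url":
                payload = part["image_url"].split(";base64,")[1]
                image = Image.open(BytesIO(base64.b64decode(payload)))
                print(message["role"], image.size)  # -> user (8, 8)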
torchchat/usages/openai_api.py

Lines changed: 8 additions & 26 deletions
@@ -314,38 +314,20 @@ def _gen_model_inputs_from_openai_completion_request(
         if not isinstance(self.model, FlamingoModel):
             prompt = [
                 {"role": message["role"], "content": message["content"]}
-                for message in completion_request.messages
+                for message in messages
             ]
             return self._gen_model_input(
                 prompt=prompt, max_new_tokens=completion_request.max_tokens
             )

         # Llama 3.2 11B
-        prompt = None
-        images = None
-
-        for message in messages:
-            torchtune_contents = []
-            if isinstance(message["content"], list):
-                for content_dict in message["content"]:
-                    if content_dict["type"] == "text":
-                        assert (
-                            prompt is None
-                        ), "At most one text prompt is supported for each request"
-                        prompt = content_dict["text"]
-                    elif content_dict["type"] == "image_url":
-                        assert (
-                            images is None
-                        ), "At most one image is supported at the moment"
-
-                        base64_decoded = base64.b64decode(
-                            content_dict["image_url"].split(";base64,")[1]
-                        )
-                        images = [Image.open(BytesIO(base64_decoded))]
-
-        assert prompt is not None, "Text prompt must be specified in the request"
-
-        return self._gen_model_input(prompt, images, completion_request.max_tokens)
+
+        prompt = [
+            {"role": message["role"], "content": message["content"]}
+            for message in messages
+        ]
+
+        return self._gen_model_input(prompt=prompt, max_new_tokens=completion_request.max_tokens)

     def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.