Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7844931

Browse files
author
vmpuri
committed
UI and API implementation for base64 encoded image input
1 parent 03c9819 commit 7844931

File tree

3 files changed

+239
-36
lines changed

3 files changed

+239
-36
lines changed

torchchat/usages/browser.py

Lines changed: 107 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,123 @@
1+
import base64
2+
import logging
13
import time
4+
from pathlib import Path
5+
26
import streamlit as st
7+
8+
logger = logging.getLogger(__name__)
9+
logger.setLevel(logging.DEBUG)
10+
311
from openai import OpenAI
412

513
st.title("torchchat")

# Seed conversation: shown on first load and restored by the "Reset Chat"
# button in the sidebar.
start_state = [
    {"role": "system", "content": "You're a helpful assistant - have fun."},
    {"role": "assistant", "content": "How can I help you?"},
]
1422

23+
# Initialize the uploader key exactly once per session. Streamlit reruns this
# script top-to-bottom on every interaction, so an unconditional assignment
# here would clobber the fresh key set by _update_uploader_key() and the file
# uploader would never actually reset after a message is sent.
if "uploader_key" not in st.session_state:
    st.session_state.uploader_key = 0
24+
25+
26+
def reset_per_message_state():
    """Catch-all for anything that should be reset between chat messages.

    Currently this only rotates the file-uploader widget key so the
    uploader is cleared after each sent message.
    """
    _update_uploader_key()
29+
30+
31+
def _update_uploader_key():
    """Rotate the file uploader's widget key to reset it after each message.

    Streamlit recreates a widget (discarding its state) whenever its key
    changes, so assigning a fresh timestamp-based key empties the uploader.
    """
    st.session_state.uploader_key = int(time.time())
    # Fix: typo'd debug print ("Uplaoder") replaced with the module logger
    # that this file already configures at DEBUG level.
    logger.debug("Uploader key: %s", st.session_state.uploader_key)
35+
36+
1537
with st.sidebar:
    # API Configuration: where the OpenAI-compatible torchchat server lives.
    api_base_url = st.text_input(
        label="API Base URL",
        value="http://127.0.0.1:5000/v1",
        help="The base URL for the OpenAI API to connect to",
    )

    st.divider()

    # Sampling controls forwarded verbatim to the chat completion request.
    temperature = st.slider(
        "Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01
    )

    response_max_tokens = st.slider(
        "Max Response Tokens", min_value=10, max_value=1000, value=250, step=10
    )
    # Restore the seed conversation defined at the top of the file.
    if st.button("Reset Chat", type="primary"):
        st.session_state["messages"] = start_state

    # Keyed by uploader_key so that rotating the key after a message is sent
    # clears the uploader (see _update_uploader_key).
    image_prompts = st.file_uploader(
        "Image Prompts",
        type=["jpeg"],
        accept_multiple_files=True,
        key=st.session_state.uploader_key,
    )

    # Preview the images that will be attached to the next message.
    for image in image_prompts:
        st.image(image)
65+
66+
67+
# OpenAI-compatible client pointed at the server chosen in the sidebar.
client = OpenAI(
    base_url=api_base_url,
    api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
)
71+
2272
if "messages" not in st.session_state:
2373
st.session_state["messages"] = start_state
2474

2575

2676
for msg in st.session_state.messages:
27-
st.chat_message(msg["role"]).write(msg["content"])
77+
with st.chat_message(msg["role"]):
78+
if type(msg["content"]) is list:
79+
for content in msg["content"]:
80+
if content["type"] == "image_url":
81+
extension = (
82+
content["image_url"].split(";base64")[0].split("image/")[1]
83+
)
84+
base64_repr = content["image_url"].split("base64,")[1]
85+
st.image(base64.b64decode(base64_repr))
86+
else:
87+
st.write(content["text"])
88+
elif type(msg["content"]) is dict:
89+
if msg["content"]["type"] == "image_url":
90+
st.image(msg["content"]["image_url"])
91+
else:
92+
st.write(msg["content"]["text"])
93+
elif type(msg["content"]) is str:
94+
st.write(msg["content"])
95+
else:
96+
st.write(f"no clue breh {type(msg['content'])} {msg['content']}")
2897

29-
if prompt := st.chat_input():
30-
client = OpenAI(
31-
base_url="http://127.0.0.1:5000/v1",
32-
api_key="813", # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
33-
)
3498

35-
st.session_state.messages.append({"role": "user", "content": prompt})
36-
st.chat_message("user").write(prompt)
99+
if prompt := st.chat_input(on_submit=reset_per_message_state):
100+
user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
101+
102+
if image_prompts:
103+
for image_prompt in image_prompts:
104+
extension = Path(image_prompt.name).suffix.strip(".")
105+
image_bytes = image_prompt.getvalue()
106+
base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
107+
user_message["content"].append(
108+
{
109+
"type": "image_url",
110+
"image_url": f"data:image/{extension};base64,{base64_encoded}",
111+
}
112+
)
113+
st.session_state.messages.append(user_message)
114+
115+
with st.chat_message("user"):
116+
st.write(prompt)
117+
for img in image_prompts:
118+
st.image(img)
119+
120+
image_prompts = None
37121

38122
with st.chat_message("assistant"), st.status(
39123
"Generating... ", expanded=True
@@ -53,15 +137,20 @@ def get_streamed_completion(completion_generator):
53137
state="complete",
54138
)
55139

56-
response = st.write_stream(
57-
get_streamed_completion(
58-
client.chat.completions.create(
59-
model="llama3",
60-
messages=st.session_state.messages,
61-
max_tokens=response_max_tokens,
62-
stream=True,
140+
try:
141+
response = st.write_stream(
142+
get_streamed_completion(
143+
client.chat.completions.create(
144+
model="llama3",
145+
messages=st.session_state.messages,
146+
max_tokens=response_max_tokens,
147+
temperature=temperature,
148+
stream=True,
149+
)
63150
)
64-
)
65-
)[0]
151+
)[0]
152+
except Exception as e:
153+
response = st.error(f"Error: {e}")
154+
print(e)
66155

67156
st.session_state.messages.append({"role": "assistant", "content": response})

torchchat/usages/openai_api.py

Lines changed: 131 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,23 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import base64
78
import os
89
import time
910
import uuid
1011

1112
from abc import ABC
1213
from dataclasses import dataclass
14+
from io import BytesIO
1315
from pwd import getpwuid
1416
from typing import Any, Dict, List, Optional, Union
1517

1618
import torch
1719

20+
from _torchchat_test_script import flamingo_transform, padded_collate
21+
from PIL import Image
22+
from torchtune.data import Message
23+
1824
from torchchat.cli.download import is_model_downloaded, load_model_configs
1925
from torchchat.generate import Generator, GeneratorArgs
2026

@@ -31,6 +37,46 @@
3137
# Message classes and associated objects - see the types of Messages under "Create Chat Completion >>> Request body >>> messages"
3238

3339

40+
@dataclass
class _ContentPart(ABC):
    """A single part of a message content field.

    See the "Assistants >>> Messages >>> Create Message >>> Request body >>> content >>> Show possible types" section of the OpenAI API docs for more details.
    """

    # Discriminator string ("text", "image_url", "image_file") set by subclasses.
    type: str


@dataclass
class ImageFile:
    # Identifier of a previously uploaded file, per the OpenAI API.
    file_id: str
    # Optional fidelity hint for the image, per the OpenAI API.
    detail: Optional[str]


@dataclass
class ImageFileContentPart(_ContentPart):
    """Content part referencing an uploaded image file."""

    type: str = "image_file"
    image_file: Optional[ImageFile] = None


@dataclass
class ImageUrl:
    # Image location; may be an http(s) URL or a base64 data URL.
    url: str
    # Optional fidelity hint for the image, per the OpenAI API.
    detail: Optional[str]


@dataclass
class ImageUrlContentPart(_ContentPart):
    """Content part referencing an image by URL."""

    type: str = "image_url"
    image_url: Optional[ImageUrl] = None


@dataclass
class TextContentPart(_ContentPart):
    """Plain-text content part."""

    text: str = ""
    type: str = "text"
78+
79+
3480
@dataclass
3581
class _AbstractMessage(ABC):
3682
"""Base class with common parameters for message types.
@@ -42,7 +88,7 @@ class _AbstractMessage(ABC):
4288
"""
4389

4490
role: str
45-
content: Optional[str] = None
91+
content: Optional[Union[List[_ContentPart], str]] = None
4692

4793

4894
@dataclass
@@ -185,7 +231,7 @@ class ChunkDelta:
185231

186232
tool_calls: Optional[List[ToolCall]]
187233
role: Optional[str]
188-
content: Optional[str]
234+
content: Optional[Union[List[_ContentPart], str]] = None
189235

190236

191237
@dataclass
@@ -232,18 +278,55 @@ def __init__(self, *args, **kwargs):
232278
"""
233279

234280
super().__init__(*args, **kwargs)
235-
self.max_seq_length = (
236-
self.model.config.transformer_args["text"].max_seq_length
237-
+ self.speculative_builder_args.speculate_k
238-
+ 1
239-
if self.draft_model is not None
240-
else self.model.config.transformer_args["text"].max_seq_length
241-
)
281+
self.max_seq_length = 128
282+
if self.model.config.transformer_args.get("text", None):
283+
self.max_seq_len = (
284+
self.model.config.transformer_args["text"].max_seq_length
285+
+ self.speculative_builder_args.speculate_k
286+
+ 1
287+
if self.draft_model is not None
288+
else self.model.config.transformer_args["text"].max_seq_length
289+
)
242290
# The System fingerprint is a unique identifier for the model and its configuration.
243291
self.system_fingerprint = (
244292
f"{self.builder_args.device}_{self.builder_args.precision}"
245293
)
246294

295+
def _openai_messages_to_torchtune(self, messages: List[_AbstractMessage]):
    """Convert a list of OpenAI API messages to a list of TorchTune messages.

    Args:
        messages: A list of OpenAI API messages.

    Returns:
        A list of Torchtune Messages, with a trailing empty assistant
        message appended so generation continues from the assistant turn.
    """
    torchtune_messages = []
    for message in messages:
        torchtune_contents = []
        if isinstance(message["content"], list):
            for content in message["content"]:
                if isinstance(content, dict):
                    # Image bytes are attached to the batch separately; the
                    # message itself only needs an image placeholder.
                    if content["type"] in ("image_url", "image_file"):
                        torchtune_contents.append({"type": "image"})
                    elif content["type"] == "text":
                        torchtune_contents.append(
                            {"type": "text", "content": content["text"]}
                        )
                elif isinstance(content, str):
                    # Fix: was {"type": "text", "text": content} — every other
                    # branch (and torchtune's Message content format) carries
                    # the text payload under the "content" key.
                    torchtune_contents.append(
                        {"type": "text", "content": content}
                    )
        else:
            torchtune_contents.append(
                {"type": "text", "content": message["content"]}
            )
        torchtune_messages.append(
            Message(role=message["role"], content=torchtune_contents, eot=True)
        )
    # Prime generation with an empty assistant turn.
    torchtune_messages.append(Message(role="assistant", content=""))
    return torchtune_messages
329+
247330
def chunked_completion(self, completion_request: CompletionRequest):
248331
"""Handle a chat completion request and yield a chunked response.
249332
@@ -271,15 +354,42 @@ def chunked_completion(self, completion_request: CompletionRequest):
271354
id = str(uuid.uuid4())
272355

273356
idx = 0
274-
tokens = self.chat_formatter.encode_dialog_prompt(
275-
dialog=[
276-
{"role": message["role"], "content": message["content"]}
277-
for message in completion_request.messages
278-
]
279-
)
357+
images = []
280358

281-
encoded = torch.tensor(tokens, dtype=torch.int, device=self.builder_args.device)
282-
print(self.tokenizer.decode(tokens))
359+
device_sync(device=self.builder_args.device)
360+
for message in completion_request.messages:
361+
contents = message["content"]
362+
if isinstance(contents, list):
363+
for content in message["content"]:
364+
if content["type"] == "image_url":
365+
base64_decoded = base64.b64decode(
366+
content["image_url"].split(";base64,")[1]
367+
)
368+
images.append(Image.open(BytesIO(base64_decoded)))
369+
print("images:", len(images), flush=True)
370+
if len(images) > 0:
371+
transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
372+
torchtune_messages = self._openai_messages_to_torchtune(
373+
completion_request.messages
374+
)
375+
data = transform(
376+
{"images": images, "messages": torchtune_messages}, inference=True
377+
)
378+
batch = padded_collate([data], self.builder_args.device)
379+
batch.pop("mask")
380+
encoded = batch["tokens"]
381+
else:
382+
tokens = self.chat_formatter.encode_dialog_prompt(
383+
dialog=[
384+
{"role": message["role"], "content": message["content"]}
385+
for message in completion_request.messages
386+
]
387+
)
388+
print("tokens:", self.tokenizer.decode(tokens), flush=True)
389+
encoded = torch.tensor(
390+
tokens, dtype=torch.int, device=self.builder_args.device
391+
)
392+
batch = None
283393

284394
start_pos = 0
285395

@@ -293,7 +403,7 @@ def chunked_completion(self, completion_request: CompletionRequest):
293403
encoded_prompt=encoded,
294404
temperature=float(completion_request.temperature),
295405
chat_mode=False,
296-
sequential_prefill=True,
406+
sequential_prefill=False,
297407
)
298408

299409
def callback(x, *, done_generating=False):
@@ -313,6 +423,7 @@ def callback(x, *, done_generating=False):
313423
draft_model=self.draft_model,
314424
speculate_k=generator_args.speculate_k,
315425
chat_mode=generator_args.chat_mode,
426+
batch=batch,
316427
callback=callback,
317428
temperature=generator_args.temperature,
318429
top_k=generator_args.top_k,
@@ -323,10 +434,12 @@ def callback(x, *, done_generating=False):
323434
):
324435
if y is None:
325436
continue
437+
326438
elif y.item() == self.tokenizer.eos_id:
327439
# Stop generation if the EOS token is generated.
328440
break
329441

442+
y = y.view(-1)
330443
# Decode the torch.Tensor token to a string and append to the buffer. Separate the sequences with a period token.
331444
content = "".join(
332445
self.tokenizer.decode([self.tokenizer.encode(".")[0]] + y.tolist())[1:]

0 commit comments

Comments
 (0)