Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 974a250

Browse files
author
vmpuri
committed
UI and API implementation for base64 encoded image input
1 parent 03c9819 commit 974a250

File tree

3 files changed

+241
-41
lines changed

3 files changed

+241
-41
lines changed

torchchat/usages/browser.py

Lines changed: 99 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,110 @@
1-
import time
1+
import base64
22
import streamlit as st
3+
import time
4+
from pathlib import Path
5+
import logging
6+
7+
logger = logging.getLogger(__name__)
8+
logger.setLevel(logging.DEBUG)
9+
310
from openai import OpenAI
411

512
st.title("torchchat")
613

714
start_state = [
815
{
916
"role": "system",
10-
"content": "You're an assistant. Answer questions directly, be brief, and have fun.",
17+
"content": "You're a helpful assistant - have fun.",
1118
},
1219
{"role": "assistant", "content": "How can I help you?"},
1320
]
1421

22+
if "uploader_key" not in st.session_state:
23+
st.session_state["uploader_key"] = 0
24+
25+
def reset_per_message_state():
26+
# Catch all function for anything that should be reset between each message.
27+
_update_uploader_key()
28+
29+
def _update_uploader_key():
30+
# Increment the uploader key to reset the file uploader after each message.
31+
st.session_state.uploader_key += 1
32+
print("Uplaoder key", st.session_state.uploader_key)
33+
1534
with st.sidebar:
35+
# API Configuration
36+
<<<<<<< Updated upstream
37+
api_base_url = st.text_input(label="API Base URL", value="http://127.0.0.1:5000/v1", help="The base URL for the OpenAI API to connect to")
38+
=======
39+
api_base_url = st.text_input(label="API Base URL", value="http://puri.devvm2162.cco0:8085/v1", help="The base URL for the OpenAI API to connect to")
40+
>>>>>>> Stashed changes
41+
42+
st.divider()
43+
temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01)
44+
45+
1646
response_max_tokens = st.slider(
1747
"Max Response Tokens", min_value=10, max_value=1000, value=250, step=10
1848
)
1949
if st.button("Reset Chat", type="primary"):
2050
st.session_state["messages"] = start_state
2151

52+
image_prompts = st.file_uploader("Image Prompts", type=["jpeg"], accept_multiple_files=True, key=f"uploader_{st.session_state.uploader_key}")
53+
54+
for image in image_prompts:
55+
st.image(image)
56+
57+
58+
client = OpenAI(
59+
base_url=api_base_url,
60+
api_key="813", # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
61+
)
62+
63+
64+
2265
if "messages" not in st.session_state:
2366
st.session_state["messages"] = start_state
2467

2568

2669
for msg in st.session_state.messages:
27-
st.chat_message(msg["role"]).write(msg["content"])
70+
with st.chat_message(msg["role"]):
71+
if type(msg["content"]) is list:
72+
for content in msg["content"]:
73+
if content["type"] == "image_url":
74+
extension = content["image_url"].split(";base64")[0].split("image/")[1]
75+
base64_repr = content["image_url"].split("base64,")[1]
76+
st.image(base64.b64decode(base64_repr))
77+
else:
78+
st.write(content["text"])
79+
elif type(msg["content"]) is dict:
80+
if msg["content"]["type"] == "image_url":
81+
st.image(msg["content"]["image_url"])
82+
else:
83+
st.write(msg["content"]["text"])
84+
elif type(msg["content"]) is str:
85+
st.write(msg["content"])
86+
else:
87+
st.write(f"no clue breh {type(msg['content'])} {msg['content']}")
2888

29-
if prompt := st.chat_input():
30-
client = OpenAI(
31-
base_url="http://127.0.0.1:5000/v1",
32-
api_key="813", # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
33-
)
3489

35-
st.session_state.messages.append({"role": "user", "content": prompt})
36-
st.chat_message("user").write(prompt)
90+
if prompt := st.chat_input():
91+
user_message = {"role": "user", "content": [{"type":"text", "text": prompt}]}
92+
93+
if image_prompts:
94+
for image_prompt in image_prompts:
95+
extension = Path(image_prompt.name).suffix.strip(".")
96+
image_bytes = image_prompt.getvalue()
97+
base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
98+
user_message["content"].append({"type":"image_url", "image_url": f"data:image/{extension};base64,{base64_encoded}"})
99+
_update_uploader_key()
100+
st.session_state.messages.append(user_message)
101+
102+
with st.chat_message("user"):
103+
st.write(prompt)
104+
for img in image_prompts:
105+
st.image(img)
106+
107+
image_prompts = None
37108

38109
with st.chat_message("assistant"), st.status(
39110
"Generating... ", expanded=True
@@ -53,15 +124,24 @@ def get_streamed_completion(completion_generator):
53124
state="complete",
54125
)
55126

56-
response = st.write_stream(
57-
get_streamed_completion(
58-
client.chat.completions.create(
59-
model="llama3",
60-
messages=st.session_state.messages,
61-
max_tokens=response_max_tokens,
62-
stream=True,
127+
try:
128+
response = st.write_stream(
129+
get_streamed_completion(
130+
client.chat.completions.create(
131+
model="llama3",
132+
messages=st.session_state.messages,
133+
max_tokens=response_max_tokens,
134+
temperature=temperature,
135+
stream=True,
136+
)
63137
)
64-
)
65-
)[0]
138+
)[0]
139+
except Exception as e:
140+
response = st.error(f"Error: {e}")
141+
print(e)
142+
143+
144+
66145

146+
67147
st.session_state.messages.append({"role": "assistant", "content": response})

0 commit comments

Comments
 (0)