Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 26a99fc

Browse files
author
vmpuri
committed
Clear image input after submitting a chat
1 parent ad84f51 commit 26a99fc

File tree

2 files changed

+46
-40
lines changed

2 files changed

+46
-40
lines changed

torchchat/generate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -815,13 +815,13 @@ def _gen_model_input(
815815

816816
is_multimodal = images is not None
817817
content = [{"type": "text", "content": prompt_arg}]
818-
[]
818+
819819
if is_multimodal:
820820
content = [{"type": "image", "content": images[0]}] + content
821821

822822
messages.append(
823823
Message(
824-
role="user",
824+
role=message["role"],
825825
content=content,
826826
)
827827
)

torchchat/usages/browser.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,45 @@
1212

1313
st.title("torchchat")
1414

15+
1516
start_state = [
1617
{
1718
"role": "system",
18-
"content": "You're a helpful assistant - have fun.",
19+
"content": "You're a helpful assistant - be brief and have fun.",
1920
},
2021
{"role": "assistant", "content": "How can I help you?"},
2122
]
2223

23-
st.session_state.uploader_key = 0
24-
25-
26-
def reset_per_message_state():
27-
# Catch all function for anything that should be reset between each message.
28-
_update_uploader_key()
2924

25+
def reset_chat():
26+
st.session_state["messages"] = start_state
27+
st.session_state["conversation_images"] = []
3028

31-
def _update_uploader_key():
32-
# Increment the uploader key to reset the file uploader after each message.
33-
st.session_state.uploader_key = int(time.time())
29+
if "messages" not in st.session_state:
30+
st.session_state.messages = start_state
31+
if "conversation_images" not in st.session_state:
32+
st.session_state.conversation_images = []
3433

34+
def _upload_image_prompts(file_uploads):
35+
for file in file_uploads:
36+
st.session_state.conversation_images.append(file)
3537

3638
with st.sidebar:
39+
if st.button("Reset Chat", type="primary"):
40+
reset_chat()
41+
3742
# API Configuration
3843
api_base_url = st.text_input(
3944
label="API Base URL",
4045
value="http://127.0.0.1:5000/v1",
4146
help="The base URL for the OpenAI API to connect to",
4247
)
4348

49+
client = OpenAI(
50+
base_url=api_base_url,
51+
api_key="813", # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
52+
)
53+
4454
st.divider()
4555
temperature = st.slider(
4656
"Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01
@@ -49,28 +59,6 @@ def _update_uploader_key():
4959
response_max_tokens = st.slider(
5060
"Max Response Tokens", min_value=10, max_value=1000, value=250, step=10
5161
)
52-
if st.button("Reset Chat", type="primary"):
53-
st.session_state["messages"] = start_state
54-
55-
image_prompts = st.file_uploader(
56-
"Image Prompts",
57-
type=["jpeg"],
58-
accept_multiple_files=True,
59-
key=st.session_state.uploader_key,
60-
)
61-
62-
for image in image_prompts:
63-
st.image(image)
64-
65-
66-
client = OpenAI(
67-
base_url=api_base_url,
68-
api_key="813", # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
69-
)
70-
71-
if "messages" not in st.session_state:
72-
st.session_state["messages"] = start_state
73-
7462

7563
for msg in st.session_state.messages:
7664
with st.chat_message(msg["role"]):
@@ -86,6 +74,7 @@ def _update_uploader_key():
8674
st.write(content["text"])
8775
elif type(msg["content"]) is dict:
8876
if msg["content"]["type"] == "image_url":
77+
pass
8978
st.image(msg["content"]["image_url"])
9079
else:
9180
st.write(msg["content"]["text"])
@@ -98,8 +87,8 @@ def _update_uploader_key():
9887
if prompt := st.chat_input():
9988
user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
10089

101-
if image_prompts:
102-
for image_prompt in image_prompts:
90+
if len(st.session_state.conversation_images) > 0:
91+
for image_prompt in st.session_state.conversation_images:
10392
extension = Path(image_prompt.name).suffix.strip(".")
10493
image_bytes = image_prompt.getvalue()
10594
base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
@@ -113,11 +102,10 @@ def _update_uploader_key():
113102

114103
with st.chat_message("user"):
115104
st.write(prompt)
116-
for img in image_prompts:
105+
for img in st.session_state.conversation_images:
117106
st.image(img)
118-
119-
image_prompts = None
120-
reset_per_message_state()
107+
st.session_state.conversation_images = []
108+
121109

122110
with st.chat_message("assistant"), st.status(
123111
"Generating... ", expanded=True
@@ -154,3 +142,21 @@ def get_streamed_completion(completion_generator):
154142
print(e)
155143

156144
st.session_state.messages.append({"role": "assistant", "content": response})
145+
146+
# Note: This section needs to be at the end of the file to ensure that the session state is updated before the sidebar is rendered.
147+
with st.sidebar:
148+
st.divider()
149+
150+
with st.form("image_uploader", clear_on_submit=True):
151+
file_uploads = st.file_uploader(
152+
"Upload Image Prompts",
153+
type=["jpeg"],
154+
accept_multiple_files=True,
155+
)
156+
submitted = st.form_submit_button("Attach images to chat message")
157+
if submitted:
158+
_upload_image_prompts(file_uploads)
159+
160+
st.markdown("Image Prompts")
161+
for image in st.session_state.conversation_images:
162+
st.image(image)

0 commit comments

Comments
 (0)