
Commit 1e78f59

Merge branch 'main' into kv_cache
2 parents: 07c21c1 + 16b3d64

File tree

3 files changed: +252 −39 lines


torchchat/generate.py

Lines changed: 9 additions & 4 deletions
@@ -364,6 +364,8 @@ def prefill(
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
                 # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
                 logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
+        elif self.model.config.model_type == ModelType.Flamingo:
+            logits = model(x)
         else:
             # input_pos: [B, S]
             logits = model(x, input_pos)
@@ -383,11 +385,14 @@ def decode_one_token(
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # input_pos: [B, 1]
         assert input_pos.shape[-1] == 1
-        if model.config.model_type == ModelType.Flamingo and batch is not None:
-            x = x.view(1, -1)
-            logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
+        x = x.view(1, -1)
+        if model.config.model_type == ModelType.Flamingo:
+            if batch is not None:
+                logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
+            else:
+                logits = model(x)
         else:
-            logits = model(x.view(1, -1), input_pos)
+            logits = model(x, input_pos)
         # print(f"x: {x},\n input_pos: {input_pos}\n")
         return self.sample(logits, need_probs=need_probs, **sampling_kwargs)
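Note on the new Flamingo branch in decode_one_token: batch["encoder_mask"][:, -1:] keeps only the cross-attention mask row for the newest decode position. A minimal illustration of that slice, assuming a boolean mask of shape [batch, text_len, encoder_len] (the shape is an assumption for illustration, not taken from this diff):

    import torch

    # Assumed shape [batch, text_len, encoder_len]; illustrative only.
    encoder_mask = torch.ones(1, 5, 17, dtype=torch.bool)
    last_step_mask = encoder_mask[:, -1:]  # [1, 1, 17]: mask for the single newest token
    print(last_step_mask.shape)  # torch.Size([1, 1, 17])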

torchchat/usages/browser.py

Lines changed: 106 additions & 17 deletions
@@ -1,39 +1,123 @@
+import base64
+import logging
 import time
+from pathlib import Path
+
 import streamlit as st
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
 from openai import OpenAI
 
 st.title("torchchat")
 
 start_state = [
     {
         "role": "system",
-        "content": "You're an assistant. Answer questions directly, be brief, and have fun.",
+        "content": "You're a helpful assistant - have fun.",
     },
     {"role": "assistant", "content": "How can I help you?"},
 ]
 
+st.session_state.uploader_key = 0
+
+
+def reset_per_message_state():
+    # Catch all function for anything that should be reset between each message.
+    _update_uploader_key()
+
+
+def _update_uploader_key():
+    # Increment the uploader key to reset the file uploader after each message.
+    st.session_state.uploader_key = int(time.time())
+
+
 with st.sidebar:
+    # API Configuration
+    api_base_url = st.text_input(
+        label="API Base URL",
+        value="http://127.0.0.1:5000/v1",
+        help="The base URL for the OpenAI API to connect to",
+    )
+
+    st.divider()
+    temperature = st.slider(
+        "Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01
+    )
+
     response_max_tokens = st.slider(
         "Max Response Tokens", min_value=10, max_value=1000, value=250, step=10
     )
     if st.button("Reset Chat", type="primary"):
         st.session_state["messages"] = start_state
 
+    image_prompts = st.file_uploader(
+        "Image Prompts",
+        type=["jpeg"],
+        accept_multiple_files=True,
+        key=st.session_state.uploader_key,
+    )
+
+    for image in image_prompts:
+        st.image(image)
+
+
+client = OpenAI(
+    base_url=api_base_url,
+    api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
+)
+
 if "messages" not in st.session_state:
     st.session_state["messages"] = start_state
 
 
 for msg in st.session_state.messages:
-    st.chat_message(msg["role"]).write(msg["content"])
+    with st.chat_message(msg["role"]):
+        if type(msg["content"]) is list:
+            for content in msg["content"]:
+                if content["type"] == "image_url":
+                    extension = (
+                        content["image_url"].split(";base64")[0].split("image/")[1]
+                    )
+                    base64_repr = content["image_url"].split("base64,")[1]
+                    st.image(base64.b64decode(base64_repr))
+                else:
+                    st.write(content["text"])
+        elif type(msg["content"]) is dict:
+            if msg["content"]["type"] == "image_url":
+                st.image(msg["content"]["image_url"])
+            else:
+                st.write(msg["content"]["text"])
+        elif type(msg["content"]) is str:
+            st.write(msg["content"])
+        else:
+            st.write(f"Unhandled content type: {type(msg['content'])}")
 
 
 if prompt := st.chat_input():
-    client = OpenAI(
-        base_url="http://127.0.0.1:5000/v1",
-        api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
-    )
+    user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
+
+    if image_prompts:
+        for image_prompt in image_prompts:
+            extension = Path(image_prompt.name).suffix.strip(".")
+            image_bytes = image_prompt.getvalue()
+            base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
+            user_message["content"].append(
+                {
+                    "type": "image_url",
+                    "image_url": f"data:image/{extension};base64,{base64_encoded}",
+                }
+            )
+    st.session_state.messages.append(user_message)
+
+    with st.chat_message("user"):
+        st.write(prompt)
+        for img in image_prompts:
+            st.image(img)
 
-    st.session_state.messages.append({"role": "user", "content": prompt})
-    st.chat_message("user").write(prompt)
+    image_prompts = None
+    reset_per_message_state()
 
     with st.chat_message("assistant"), st.status(
         "Generating... ", expanded=True
@@ -53,15 +137,20 @@ def get_streamed_completion(completion_generator):
                 state="complete",
             )
 
-        response = st.write_stream(
-            get_streamed_completion(
-                client.chat.completions.create(
-                    model="llama3",
-                    messages=st.session_state.messages,
-                    max_tokens=response_max_tokens,
-                    stream=True,
+        try:
+            response = st.write_stream(
+                get_streamed_completion(
+                    client.chat.completions.create(
+                        model="llama3",
+                        messages=st.session_state.messages,
+                        max_tokens=response_max_tokens,
+                        temperature=temperature,
+                        stream=True,
+                    )
                 )
-            )
-        )[0]
+            )[0]
+        except Exception as e:
+            response = st.error(f"Error: {e}")
+            print(e)
 
     st.session_state.messages.append({"role": "assistant", "content": response})
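The rendering loop above parses image attachments back out of the data URLs that the submit path builds with f"data:image/{extension};base64,{base64_encoded}". A small round-trip sketch of that encoding and the matching split logic (the fake JPEG bytes are a stand-in):

    import base64

    # Stand-in payload; any bytes work for the round trip.
    image_bytes = b"\xff\xd8\xfffake-jpeg-bytes"
    data_url = f"data:image/jpeg;base64,{base64.b64encode(image_bytes).decode('utf-8')}"

    # Mirror the splits used in the rendering loop.
    extension = data_url.split(";base64")[0].split("image/")[1]  # "jpeg"
    decoded = base64.b64decode(data_url.split("base64,")[1])
    assert extension == "jpeg" and decoded == image_bytes

The hunk header also references get_streamed_completion, whose body lies outside this diff. A minimal sketch of what such a generator could look like, assuming OpenAI-python streaming chunks that expose choices[0].delta.content (an assumption, not confirmed by this diff):

    def get_streamed_completion(completion_generator):
        # Yield text deltas from a streaming chat completion so that
        # st.write_stream can render them incrementally.
        for chunk in completion_generator:
            delta = chunk.choices[0].delta.content
            if delta is not None:  # the final chunk may carry no content
                yield delta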
