# Streamlit chat front-end for torchchat's OpenAI-compatible API server.
import base64
import logging
import time
from pathlib import Path

import streamlit as st
from openai import OpenAI

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

st.title("torchchat")
# Initial conversation state: a system prompt plus an opening assistant
# greeting. Assigned to st.session_state["messages"] both on first load and
# when the user presses the sidebar's "Reset Chat" button.
start_state = [
    {
        "role": "system",
        "content": "You're a helpful assistant - have fun.",
    },
    {"role": "assistant", "content": "How can I help you?"},
]
1422
# Seed the uploader key once per session. An unconditional assignment would
# re-run on every Streamlit rerun and clobber the fresh key written by
# _update_uploader_key(), so the file uploader would never actually reset
# between messages.
if "uploader_key" not in st.session_state:
    st.session_state.uploader_key = 0
24+
25+
def reset_per_message_state():
    """Reset anything that should not survive past a single message.

    Catch-all hook; currently that is just the file-uploader widget key.
    """
    _update_uploader_key()
29+
30+
def _update_uploader_key():
    # Assign a fresh, time-based key (the original comment said "increment",
    # but the code writes the current Unix time) so the keyed file-uploader
    # widget is remounted — clearing its files — after each message.
    st.session_state.uploader_key = int(time.time())
34+
35+
with st.sidebar:
    # API Configuration: where the OpenAI-compatible torchchat server lives.
    api_base_url = st.text_input(
        label="API Base URL",
        value="http://127.0.0.1:5000/v1",
        help="The base URL for the OpenAI API to connect to",
    )

    st.divider()
    # Sampling controls forwarded with each completions request.
    temperature = st.slider(
        "Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01
    )

    response_max_tokens = st.slider(
        "Max Response Tokens", min_value=10, max_value=1000, value=250, step=10
    )
    # Restore the conversation to its initial system + greeting state.
    if st.button("Reset Chat", type="primary"):
        st.session_state["messages"] = start_state

    # Keyed by uploader_key: bumping the key remounts (and thereby clears)
    # the widget after a message is sent.
    image_prompts = st.file_uploader(
        "Image Prompts",
        type=["jpeg"],
        accept_multiple_files=True,
        key=st.session_state.uploader_key,
    )

    # Preview the pending image attachments in the sidebar.
    for image in image_prompts:
        st.image(image)
64+
65+
# Client pointed at the user-configured endpoint from the sidebar.
client = OpenAI(
    base_url=api_base_url,
    api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
)
70+
# Seed the conversation history on the first run of the session.
if "messages" not in st.session_state:
    st.session_state.messages = start_state
2473
2574
# Replay the conversation history. A message's content may be a list of
# parts (OpenAI vision-style), a single part dict, or a plain string.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        content = msg["content"]
        if isinstance(content, list):
            for part in content:
                if part["type"] == "image_url":
                    # Parts hold a data URL: "data:image/<ext>;base64,<payload>".
                    # (The original also parsed the extension here, but never
                    # used it — dropped.)
                    base64_repr = part["image_url"].split("base64,")[1]
                    st.image(base64.b64decode(base64_repr))
                else:
                    st.write(part["text"])
        elif isinstance(content, dict):
            if content["type"] == "image_url":
                # NOTE(review): unlike the list branch, this passes the raw
                # image_url value through without base64-decoding it — confirm
                # whether this branch is ever hit with a data URL.
                st.image(content["image_url"])
            else:
                st.write(content["text"])
        elif isinstance(content, str):
            st.write(content)
        else:
            st.write(f"Unhandled content type: {type(msg['content'])}")
96+
2897
if prompt := st.chat_input():
    # Assemble the outgoing user turn: the text part first, then one
    # image_url part (encoded as a data URL) per uploaded image.
    content_parts = [{"type": "text", "text": prompt}]
    for upload in image_prompts or []:
        ext = Path(upload.name).suffix.strip(".")
        encoded = base64.b64encode(upload.getvalue()).decode("utf-8")
        content_parts.append(
            {
                "type": "image_url",
                "image_url": f"data:image/{ext};base64,{encoded}",
            }
        )
    st.session_state.messages.append({"role": "user", "content": content_parts})

    # Echo the turn (text plus any attached images) into the chat pane.
    with st.chat_message("user"):
        st.write(prompt)
        for upload in image_prompts:
            st.image(upload)

    # Drop the consumed attachments and remount the uploader widget.
    image_prompts = None
    reset_per_message_state()
37121
38122 with st .chat_message ("assistant" ), st .status (
39123 "Generating... " , expanded = True
@@ -53,15 +137,20 @@ def get_streamed_completion(completion_generator):
53137 state = "complete" ,
54138 )
55139
56- response = st .write_stream (
57- get_streamed_completion (
58- client .chat .completions .create (
59- model = "llama3" ,
60- messages = st .session_state .messages ,
61- max_tokens = response_max_tokens ,
62- stream = True ,
140+ try :
141+ response = st .write_stream (
142+ get_streamed_completion (
143+ client .chat .completions .create (
144+ model = "llama3" ,
145+ messages = st .session_state .messages ,
146+ max_tokens = response_max_tokens ,
147+ temperature = temperature ,
148+ stream = True ,
149+ )
63150 )
64- )
65- )[0 ]
151+ )[0 ]
152+ except Exception as e :
153+ response = st .error (f"Error: { e } " )
154+ print (e )
66155
67156 st .session_state .messages .append ({"role" : "assistant" , "content" : response })
0 commit comments