1+ import  base64 
2+ import  logging 
13import  time 
4+ from  pathlib  import  Path 
5+ 
26import  streamlit  as  st 
7+ 
8+ logger  =  logging .getLogger (__name__ )
9+ logger .setLevel (logging .DEBUG )
10+ 
311from  openai  import  OpenAI 
412
513st .title ("torchchat" )
614
715start_state  =  [
816    {
917        "role" : "system" ,
10-         "content" : "You're an  assistant. Answer questions directly, be brief, and  have fun." ,
18+         "content" : "You're a helpful  assistant -  have fun." ,
1119    },
1220    {"role" : "assistant" , "content" : "How can I help you?" },
1321]
1422
23+ st .session_state .uploader_key  =  0 
24+ 
25+ 
26+ def  reset_per_message_state ():
27+     # Catch all function for anything that should be reset between each message. 
28+     _update_uploader_key ()
29+ 
30+ 
31+ def  _update_uploader_key ():
32+     # Increment the uploader key to reset the file uploader after each message. 
33+     st .session_state .uploader_key  =  int (time .time ())
34+ 
35+ 
1536with  st .sidebar :
37+     # API Configuration 
38+     api_base_url  =  st .text_input (
39+         label = "API Base URL" ,
40+         value = "http://127.0.0.1:5000/v1" ,
41+         help = "The base URL for the OpenAI API to connect to" ,
42+     )
43+ 
44+     st .divider ()
45+     temperature  =  st .slider (
46+         "Temperature" , min_value = 0.0 , max_value = 1.0 , value = 1.0 , step = 0.01 
47+     )
48+ 
1649    response_max_tokens  =  st .slider (
1750        "Max Response Tokens" , min_value = 10 , max_value = 1000 , value = 250 , step = 10 
1851    )
1952    if  st .button ("Reset Chat" , type = "primary" ):
2053        st .session_state ["messages" ] =  start_state 
2154
55+     image_prompts  =  st .file_uploader (
56+         "Image Prompts" ,
57+         type = ["jpeg" ],
58+         accept_multiple_files = True ,
59+         key = st .session_state .uploader_key ,
60+     )
61+ 
62+     for  image  in  image_prompts :
63+         st .image (image )
64+ 
65+ 
66+ client  =  OpenAI (
67+     base_url = api_base_url ,
68+     api_key = "813" ,  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string. 
69+ )
70+ 
2271if  "messages"  not  in   st .session_state :
2372    st .session_state ["messages" ] =  start_state 
2473
2574
2675for  msg  in  st .session_state .messages :
27-     st .chat_message (msg ["role" ]).write (msg ["content" ])
76+     with  st .chat_message (msg ["role" ]):
77+         if  type (msg ["content" ]) is  list :
78+             for  content  in  msg ["content" ]:
79+                 if  content ["type" ] ==  "image_url" :
80+                     extension  =  (
81+                         content ["image_url" ].split (";base64" )[0 ].split ("image/" )[1 ]
82+                     )
83+                     base64_repr  =  content ["image_url" ].split ("base64," )[1 ]
84+                     st .image (base64 .b64decode (base64_repr ))
85+                 else :
86+                     st .write (content ["text" ])
87+         elif  type (msg ["content" ]) is  dict :
88+             if  msg ["content" ]["type" ] ==  "image_url" :
89+                 st .image (msg ["content" ]["image_url" ])
90+             else :
91+                 st .write (msg ["content" ]["text" ])
92+         elif  type (msg ["content" ]) is  str :
93+             st .write (msg ["content" ])
94+         else :
95+             st .write (f"Unhandled content type: { type (msg ['content' ])}  " )
96+ 
2897
2998if  prompt  :=  st .chat_input ():
30-     client  =  OpenAI (
31-         base_url = "http://127.0.0.1:5000/v1" ,
32-         api_key = "813" ,  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string. 
33-     )
99+     user_message  =  {"role" : "user" , "content" : [{"type" : "text" , "text" : prompt }]}
100+ 
101+     if  image_prompts :
102+         for  image_prompt  in  image_prompts :
103+             extension  =  Path (image_prompt .name ).suffix .strip ("." )
104+             image_bytes  =  image_prompt .getvalue ()
105+             base64_encoded  =  base64 .b64encode (image_bytes ).decode ("utf-8" )
106+             user_message ["content" ].append (
107+                 {
108+                     "type" : "image_url" ,
109+                     "image_url" : f"data:image/{ extension }  ;base64,{ base64_encoded }  " ,
110+                 }
111+             )
112+     st .session_state .messages .append (user_message )
113+ 
114+     with  st .chat_message ("user" ):
115+         st .write (prompt )
116+         for  img  in  image_prompts :
117+             st .image (img )
34118
35-     st . session_state . messages . append ({ "role" :  "user" ,  "content" :  prompt }) 
36-     st . chat_message ( "user" ). write ( prompt )
119+     image_prompts   =   None 
120+     reset_per_message_state ( )
37121
38122    with  st .chat_message ("assistant" ), st .status (
39123        "Generating... " , expanded = True 
@@ -53,15 +137,20 @@ def get_streamed_completion(completion_generator):
53137                state = "complete" ,
54138            )
55139
56-         response  =  st .write_stream (
57-             get_streamed_completion (
58-                 client .chat .completions .create (
59-                     model = "llama3" ,
60-                     messages = st .session_state .messages ,
61-                     max_tokens = response_max_tokens ,
62-                     stream = True ,
140+         try :
141+             response  =  st .write_stream (
142+                 get_streamed_completion (
143+                     client .chat .completions .create (
144+                         model = "llama3" ,
145+                         messages = st .session_state .messages ,
146+                         max_tokens = response_max_tokens ,
147+                         temperature = temperature ,
148+                         stream = True ,
149+                     )
63150                )
64-             )
65-         )[0 ]
151+             )[0 ]
152+         except  Exception  as  e :
153+             response  =  st .error (f"Error: { e }  " )
154+             print (e )
66155
67156    st .session_state .messages .append ({"role" : "assistant" , "content" : response })
0 commit comments