import base64
import streamlit as st
import time
from pathlib import Path
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

from openai import OpenAI
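# Streamlit chat front-end for torchchat: it talks to an OpenAI-compatible
# chat-completions endpoint (base URL set in the sidebar) and can attach JPEG
# images to a prompt as base64 data URLs.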

st.title("torchchat")

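# Every new conversation starts from this seed state: a system prompt plus an
# opening assistant greeting. The sidebar's "Reset Chat" button restores it.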
start_state = [
    {
        "role": "system",
        "content": "You're a helpful assistant - have fun.",
    },
    {"role": "assistant", "content": "How can I help you?"},
]

if "uploader_key" not in st.session_state:
    st.session_state["uploader_key"] = 0


def reset_per_message_state():
    # Catch-all for anything that should be reset between messages.
    _update_uploader_key()


def _update_uploader_key():
    # Increment the uploader key to reset the file uploader after each message.
    st.session_state.uploader_key += 1
    print("Uploader key", st.session_state.uploader_key)

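# Bumping `uploader_key` is the usual way to clear st.file_uploader between
# messages: changing a widget's key makes Streamlit build a brand-new (empty)
# widget on the next rerun.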
with st.sidebar:
    # API Configuration
    api_base_url = st.text_input(
        label="API Base URL",
        value="http://127.0.0.1:5000/v1",
        help="The base URL for the OpenAI API to connect to",
    )

    st.divider()
    temperature = st.slider(
        "Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01
    )

    response_max_tokens = st.slider(
        "Max Response Tokens", min_value=10, max_value=1000, value=250, step=10
    )
    if st.button("Reset Chat", type="primary"):
        st.session_state["messages"] = start_state

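    # Attached JPEGs are previewed here and sent along with the next prompt as
    # base64-encoded data URLs.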
    image_prompts = st.file_uploader(
        "Image Prompts",
        type=["jpeg"],
        accept_multiple_files=True,
        key=f"uploader_{st.session_state.uploader_key}",
    )

    for image in image_prompts:
        st.image(image)

client = OpenAI(
    base_url=api_base_url,
    api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
)

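# Streamlit reruns this whole script on every interaction: the client above is
# rebuilt each time, while anything kept in st.session_state (below) persists.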
if "messages" not in st.session_state:
    st.session_state["messages"] = start_state

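# Re-render the full transcript (text and images) on every run.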
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        if type(msg["content"]) is list:
            for content in msg["content"]:
                if content["type"] == "image_url":
                    extension = content["image_url"].split(";base64")[0].split("image/")[1]
                    base64_repr = content["image_url"].split("base64,")[1]
                    st.image(base64.b64decode(base64_repr))
                else:
                    st.write(content["text"])
        elif type(msg["content"]) is dict:
            if msg["content"]["type"] == "image_url":
                st.image(msg["content"]["image_url"])
            else:
                st.write(msg["content"]["text"])
        elif type(msg["content"]) is str:
            st.write(msg["content"])
        else:
            st.write(f"Unsupported content type {type(msg['content'])}: {msg['content']}")

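# Each user turn below is assembled in the OpenAI-style multi-part content
# format, roughly (illustrative values):
#   {"role": "user", "content": [
#       {"type": "text", "text": "What is in this picture?"},
#       {"type": "image_url", "image_url": "data:image/jpeg;base64,<...>"},
#   ]}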
if prompt := st.chat_input():
    user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}

    if image_prompts:
        for image_prompt in image_prompts:
            extension = Path(image_prompt.name).suffix.strip(".")
            image_bytes = image_prompt.getvalue()
            base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
            user_message["content"].append(
                {
                    "type": "image_url",
                    "image_url": f"data:image/{extension};base64,{base64_encoded}",
                }
            )
        _update_uploader_key()
    st.session_state.messages.append(user_message)

    with st.chat_message("user"):
        st.write(prompt)
        for img in image_prompts:
            st.image(img)

    image_prompts = None

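    # Stream the assistant's reply inside a status container so tokens are shown
    # as they arrive.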
    with st.chat_message("assistant"), st.status(
        "Generating... ", expanded=True
    # ... unchanged context elided in this hunk (it closes the st.status block and
    # defines get_streamed_completion(completion_generator)) ...
                state="complete",
            )

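        # st.write_stream renders the streamed chunks live and returns what was
        # written; [0] takes the first element (the generated text) so it can be
        # stored back into the history.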
        try:
            response = st.write_stream(
                get_streamed_completion(
                    client.chat.completions.create(
                        model="llama3",
                        messages=st.session_state.messages,
                        max_tokens=response_max_tokens,
                        temperature=temperature,
                        stream=True,
                    )
                )
            )[0]
        except Exception as e:
            response = st.error(f"Error: {e}")
            print(e)

    st.session_state.messages.append({"role": "assistant", "content": response})
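# A minimal way to serve this page locally, assuming an OpenAI-compatible server
# is already running at the configured base URL (the file name here is assumed):
#   streamlit run browser.py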