3131from redisvl .query import CountQuery , FilterQuery , TextQuery # type: ignore
3232from redisvl .query .filter import Tag # type: ignore
3333
34+ # TODO : Remove all the display references logic once streamlit integrates pull request Fix st.chat_input collapse after submit #12081
35+
3436
3537class LimitedRedisChatMessageHistory (RedisChatMessageHistory ):
3638
@@ -207,8 +209,20 @@ def display_chat_messages(self):
207209 st .session_state ["messages" ] = []
208210 # all messages from the history, see ROLES
209211 messages = history .messages
212+ num_msgs = len (messages )
213+ exist_references = bool (
214+ st .session_state .get ("visited_docs" )
215+ or st .session_state .get ("visited_links" )
216+ )
217+
218+ def display_references ():
219+ """Display references if they exist."""
220+ if st .session_state .get ("visited_links" , None ):
221+ self .display_visited_links ()
222+ elif st .session_state .get ("visited_docs" , None ):
223+ self .display_visited_docs ()
210224
211- for m in messages :
225+ for idx , m in enumerate ( messages ) :
212226 role = m .type
213227 if role == ROLES [1 ]: # "human"
214228 st .session_state ["messages" ].append (m )
@@ -223,6 +237,20 @@ def display_chat_messages(self):
223237 else :
224238 st .session_state ["messages" ].append (m )
225239 with st .chat_message (role , avatar = ASSISTANT_AVATAR ):
240+
241+ if exist_references :
242+ if idx == num_msgs - 1 :
243+ st .markdown (m .content )
244+ display_references ()
245+ # since this is the last message; break the loop
246+ break
247+ # the only way two ai msg are consecutive is if the last message is a summary
248+ if idx == num_msgs - 2 :
249+ if messages [idx + 1 ].type == ROLES [0 ]:
250+ st .markdown (m .content )
251+ display_references ()
252+ continue
253+
226254 st .write (m .content )
227255
228256 else :
@@ -240,6 +268,7 @@ def handle_user_input(self):
240268 placeholder = session_state ["_" ]("Message" ),
241269 key = f"chat_input_{ st .session_state .input_key_counter } " ,
242270 ):
271+
243272 if not session_state .feedback_saved :
244273 self .log_feedback ()
245274 st .session_state .feedback_saved = False
@@ -258,7 +287,6 @@ def handle_user_input(self):
258287 self .generate_response (prompt )
259288
260289 st .session_state .input_key_counter += 1
261-
262290 st .rerun () # Rerun to update the chat messages and input field
263291
264292 def get_agent (self ):
@@ -469,18 +497,22 @@ def _get_stream():
469497
470498 self .store_response (response , prompt , graph )
471499 if graph ._visited_docs ():
472- self .display_visited_docs ()
500+ st .session_state ["visited_docs" ] = (
501+ graph ._visited_docs .format_references ()
502+ )
503+ graph ._visited_docs .clear ()
504+ # self.display_visited_docs()
473505
474506 if graph ._visited_links :
475- self .display_visited_links ()
476-
477- # self.store_response(response, prompt)
507+ st .session_state ["visited_links" ] = graph ._visited_links
508+ # self.display_visited_links()
478509
479510 def display_visited_docs (self ):
480511 """Display the documents visited for the current user query."""
481512
482- graph = self .get_agent ()
483- references = graph ._visited_docs .format_references ()
513+ # graph = self.get_agent()
514+ # references = graph._visited_docs.format_references()
515+ references = st .session_state ["visited_docs" ]
484516 reference_examination_regulations = "https://www.uni-osnabrueck.de/studium/im-studium/zugangs-zulassungs-und-pruefungsordnungen/"
485517 message = session_state ["_" ](
486518 "The information provided draws on the documents below that can be found in the [University Website]({}). We encourage you to visit the site to explore these resources for additional details and insights!"
@@ -498,14 +530,13 @@ def display_visited_docs(self):
498530 # TODO: Remove page numbers, these are wrong. Temporary
499531 # st.markdown(f"- **{key}**, **{page_label}**: {page_list}")
500532 st .markdown (f"- **{ key } **" )
501- graph . _visited_docs . clear ()
533+ st . session_state [ "visited_docs" ] = None
502534
503535 def display_visited_links (self ):
504536 """Display the links visited for the current user query."""
505537
506- graph = self .get_agent ()
507538 with st .expander (session_state ["_" ]("Sources" )):
508- for link in graph . _visited_links :
539+ for link in st . session_state [ "visited_links" ] :
509540 st .markdown (
510541 f"""
511542
@@ -516,6 +547,7 @@ def display_visited_links(self):
516547 """ ,
517548 unsafe_allow_html = True ,
518549 )
550+ st .session_state ["visited_links" ] = None
519551
520552 def store_response (
521553 self ,
@@ -801,7 +833,8 @@ def show_delete_button(self):
801833
802834 def run (self ):
803835 """Main method to run the application logic."""
804- st .title ("ask.UOS" )
836+ with st .container (key = "page-header-container" ):
837+ st .title ("ask.UOS" )
805838 initialize_session_sate ()
806839 RemoveEmptyElementContainer ()
807840 # Get or create user ID using our method
0 commit comments