Skip to content

Commit 20dfc2f

Browse files
committed
refactor reference display logic and session state updates for visited documents and links
1 parent aaf46d1 commit 20dfc2f

File tree

2 files changed

+47
-12
lines changed

2 files changed

+47
-12
lines changed

pages/ask_uos_chat.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from redisvl.query import CountQuery, FilterQuery, TextQuery # type: ignore
3232
from redisvl.query.filter import Tag # type: ignore
3333

34+
# TODO : Remove all the display references logic once streamlit integrates pull request Fix st.chat_input collapse after submit #12081
35+
3436

3537
class LimitedRedisChatMessageHistory(RedisChatMessageHistory):
3638

@@ -207,8 +209,20 @@ def display_chat_messages(self):
207209
st.session_state["messages"] = []
208210
# all messages from the history, see ROLES
209211
messages = history.messages
212+
num_msgs = len(messages)
213+
exist_references = bool(
214+
st.session_state.get("visited_docs")
215+
or st.session_state.get("visited_links")
216+
)
217+
218+
def display_references():
219+
"""Display references if they exist."""
220+
if st.session_state.get("visited_links", None):
221+
self.display_visited_links()
222+
elif st.session_state.get("visited_docs", None):
223+
self.display_visited_docs()
210224

211-
for m in messages:
225+
for idx, m in enumerate(messages):
212226
role = m.type
213227
if role == ROLES[1]: # "human"
214228
st.session_state["messages"].append(m)
@@ -223,6 +237,20 @@ def display_chat_messages(self):
223237
else:
224238
st.session_state["messages"].append(m)
225239
with st.chat_message(role, avatar=ASSISTANT_AVATAR):
240+
241+
if exist_references:
242+
if idx == num_msgs - 1:
243+
st.markdown(m.content)
244+
display_references()
245+
# since this is the last message; break the loop
246+
break
247+
# the only way two ai msg are consecutive is if the last message is a summary
248+
if idx == num_msgs - 2:
249+
if messages[idx + 1].type == ROLES[0]:
250+
st.markdown(m.content)
251+
display_references()
252+
continue
253+
226254
st.write(m.content)
227255

228256
else:
@@ -240,6 +268,7 @@ def handle_user_input(self):
240268
placeholder=session_state["_"]("Message"),
241269
key=f"chat_input_{st.session_state.input_key_counter}",
242270
):
271+
243272
if not session_state.feedback_saved:
244273
self.log_feedback()
245274
st.session_state.feedback_saved = False
@@ -258,7 +287,6 @@ def handle_user_input(self):
258287
self.generate_response(prompt)
259288

260289
st.session_state.input_key_counter += 1
261-
262290
st.rerun() # Rerun to update the chat messages and input field
263291

264292
def get_agent(self):
@@ -469,18 +497,22 @@ def _get_stream():
469497

470498
self.store_response(response, prompt, graph)
471499
if graph._visited_docs():
472-
self.display_visited_docs()
500+
st.session_state["visited_docs"] = (
501+
graph._visited_docs.format_references()
502+
)
503+
graph._visited_docs.clear()
504+
# self.display_visited_docs()
473505

474506
if graph._visited_links:
475-
self.display_visited_links()
476-
477-
# self.store_response(response, prompt)
507+
st.session_state["visited_links"] = graph._visited_links
508+
# self.display_visited_links()
478509

479510
def display_visited_docs(self):
480511
"""Display the documents visited for the current user query."""
481512

482-
graph = self.get_agent()
483-
references = graph._visited_docs.format_references()
513+
# graph = self.get_agent()
514+
# references = graph._visited_docs.format_references()
515+
references = st.session_state["visited_docs"]
484516
reference_examination_regulations = "https://www.uni-osnabrueck.de/studium/im-studium/zugangs-zulassungs-und-pruefungsordnungen/"
485517
message = session_state["_"](
486518
"The information provided draws on the documents below that can be found in the [University Website]({}). We encourage you to visit the site to explore these resources for additional details and insights!"
@@ -498,14 +530,13 @@ def display_visited_docs(self):
498530
# TODO: Remove page numbers, these are wrong. Temporary
499531
# st.markdown(f"- **{key}**, **{page_label}**: {page_list}")
500532
st.markdown(f"- **{key}**")
501-
graph._visited_docs.clear()
533+
st.session_state["visited_docs"] = None
502534

503535
def display_visited_links(self):
504536
"""Display the links visited for the current user query."""
505537

506-
graph = self.get_agent()
507538
with st.expander(session_state["_"]("Sources")):
508-
for link in graph._visited_links:
539+
for link in st.session_state["visited_links"]:
509540
st.markdown(
510541
f"""
511542
@@ -516,6 +547,7 @@ def display_visited_links(self):
516547
""",
517548
unsafe_allow_html=True,
518549
)
550+
st.session_state["visited_links"] = None
519551

520552
def store_response(
521553
self,
@@ -801,7 +833,8 @@ def show_delete_button(self):
801833

802834
def run(self):
803835
"""Main method to run the application logic."""
804-
st.title("ask.UOS")
836+
with st.container(key="page-header-container"):
837+
st.title("ask.UOS")
805838
initialize_session_sate()
806839
RemoveEmptyElementContainer()
807840
# Get or create user ID using our method

pages/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def initialize_session_sate() -> None:
2525
"agent": None,
2626
"ask_uos_user_id": None,
2727
"input_key_counter": 0,
28+
"visited_docs": None,
29+
"visited_links": None,
2830
}
2931

3032
for key, value in defaults.items():

0 commit comments

Comments
 (0)