diff --git a/chat/components/chat.py b/chat/components/chat.py index f3367cb..bf99bce 100644 --- a/chat/components/chat.py +++ b/chat/components/chat.py @@ -52,10 +52,7 @@ def message(qa: QA) -> rx.Component: def chat() -> rx.Component: """List all the messages in a single conversation.""" return rx.auto_scroll( - rx.foreach( - State.chats[State.current_chat], - message, - ), + rx.foreach(State.selected_chat, message), flex="1", padding="8px", ) diff --git a/chat/components/navbar.py b/chat/components/navbar.py index af7e439..71be5bf 100644 --- a/chat/components/navbar.py +++ b/chat/components/navbar.py @@ -19,7 +19,7 @@ def sidebar_chat(chat: str) -> rx.Component: rx.button( rx.icon( tag="trash", - on_click=State.delete_chat, + on_click=State.delete_chat(chat), stroke_width=1, ), width="20%", @@ -27,7 +27,8 @@ def sidebar_chat(chat: str) -> rx.Component: color_scheme="red", ), width="100%", - ) + ), + key=chat, ) diff --git a/chat/state.py b/chat/state.py index c6f09aa..26e9ffd 100644 --- a/chat/state.py +++ b/chat/state.py @@ -16,16 +16,13 @@ class QA(TypedDict): answer: str -DEFAULT_CHATS = { - "Intros": [], -} - - class State(rx.State): """The app state.""" # A dict from the chat name to the list of questions and answers. - chats: dict[str, list[QA]] = DEFAULT_CHATS + _chats: dict[str, list[QA]] = { + "Intros": [], + } # The current chat name. current_chat = "Intros" @@ -41,15 +38,31 @@ def create_chat(self): """Create a new chat.""" # Add the new chat to the list of chats. self.current_chat = self.new_chat_name - self.chats[self.new_chat_name] = [] + self._chats[self.new_chat_name] = [] + + @rx.var + def selected_chat(self) -> list[QA]: + """Get the list of questions and answers for the current chat. + + Returns: + The list of questions and answers. + """ + return ( + self._chats[self.current_chat] if self.current_chat in self._chats else [] + ) @rx.event - def delete_chat(self): + def delete_chat(self, chat_name: str): """Delete the current chat.""" - del self.chats[self.current_chat] - if len(self.chats) == 0: - self.chats = DEFAULT_CHATS - self.current_chat = list(self.chats.keys())[0] + if chat_name not in self._chats: + return + del self._chats[chat_name] + if len(self._chats) == 0: + self._chats = { + "Intros": [], + } + if self.current_chat not in self._chats: + self.current_chat = list(self._chats.keys())[0] @rx.event def set_chat(self, chat_name: str): @@ -76,7 +89,7 @@ def chat_titles(self) -> list[str]: Returns: The list of chat names. """ - return list(self.chats.keys()) + return list(self._chats.keys()) @rx.event async def process_question(self, form_data: dict[str, Any]): @@ -100,7 +113,7 @@ async def openai_process_question(self, question: str): # Add the question to the list of questions. qa = QA(question=question, answer="") - self.chats[self.current_chat].append(qa) + self._chats[self.current_chat].append(qa) # Clear the input and start the processing. self.processing = True @@ -113,7 +126,7 @@ async def openai_process_question(self, question: str): "content": "You are a friendly chatbot named Reflex. Respond in markdown.", } ] - for qa in self.chats[self.current_chat]: + for qa in self._chats[self.current_chat]: messages.append({"role": "user", "content": qa["question"]}) messages.append({"role": "assistant", "content": qa["answer"]}) @@ -133,13 +146,13 @@ async def openai_process_question(self, question: str): answer_text = item.choices[0].delta.content # Ensure answer_text is not None before concatenation if answer_text is not None: - self.chats[self.current_chat][-1]["answer"] += answer_text + self._chats[self.current_chat][-1]["answer"] += answer_text else: # Handle the case where answer_text is None, perhaps log it or assign a default value # For example, assigning an empty string if answer_text is None answer_text = "" - self.chats[self.current_chat][-1]["answer"] += answer_text - self.chats = self.chats + self._chats[self.current_chat][-1]["answer"] += answer_text + self._chats = self._chats yield # Toggle the processing flag.