diff --git a/src/wandbot/api/routers/chat.py b/src/wandbot/api/routers/chat.py
index 30218f4913..b64da786db 100644
--- a/src/wandbot/api/routers/chat.py
+++ b/src/wandbot/api/routers/chat.py
@@ -1,4 +1,5 @@
 from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
 from starlette import status
 
 from wandbot.chat.schemas import ChatRequest, ChatResponse
@@ -36,14 +37,22 @@ async def query(request: APIQueryRequest) -> APIQueryResponse:
     try:
         chat_instance = chat_components["chat"]
 
-        result = await chat_instance.__acall__(
-            ChatRequest(
-                question=request.question,
-                chat_history=request.chat_history,
-                language=request.language,
-                application=request.application,
-            ),
+        chat_req = ChatRequest(
+            question=request.question,
+            chat_history=request.chat_history,
+            language=request.language,
+            application=request.application,
+            stream=request.stream,
         )
+
+        if chat_req.stream:
+            async def event_gen():
+                async for token in chat_instance.astream(chat_req):
+                    yield f"data: {token}\n\n"
+
+            return StreamingResponse(event_gen(), media_type="text/event-stream")
+
+        result = await chat_instance.__acall__(chat_req)
         return APIQueryResponse(**result.model_dump())
     except Exception as e:
         logger.error(f"Error processing chat query: {e}")
diff --git a/src/wandbot/chat/chat.py b/src/wandbot/chat/chat.py
index d7bc545092..1f3d59abcc 100644
--- a/src/wandbot/chat/chat.py
+++ b/src/wandbot/chat/chat.py
@@ -233,3 +233,36 @@ async def __acall__(self, chat_request: ChatRequest) -> ChatResponse:
     @weave.op
     def __call__(self, chat_request: ChatRequest) -> ChatResponse:
         return run_sync(self.__acall__(chat_request))
+
+    async def astream(self, chat_request: ChatRequest):
+        """Stream the chat response tokens asynchronously."""
+        original_language = chat_request.language
+
+        working_request = chat_request
+
+        if original_language == "ja":
+            translated_question = translate_ja_to_en(
+                chat_request.question, self.chat_config.ja_translation_model_name
+            )
+            working_request = ChatRequest(
+                question=translated_question,
+                chat_history=chat_request.chat_history,
+                application=chat_request.application,
+                language="en",
+            )
+
+        async for token in self.rag_pipeline.astream(
+            working_request.question, working_request.chat_history or []
+        ):
+            yield token
+
+        result = self.rag_pipeline.stream_result
+        result_dict = result.model_dump()
+
+        if original_language == "ja":
+            result_dict["answer"] = translate_en_to_ja(
+                result_dict["answer"], self.chat_config.ja_translation_model_name
+            )
+
+        result_dict.update({"application": chat_request.application})
+        self.last_stream_response = ChatResponse(**result_dict)
diff --git a/src/wandbot/chat/rag.py b/src/wandbot/chat/rag.py
index 6fb3e2202f..94e621ddd0 100644
--- a/src/wandbot/chat/rag.py
+++ b/src/wandbot/chat/rag.py
@@ -141,4 +141,46 @@ async def __acall__(
     def __call__(
         self, question: str, chat_history: List[Tuple[str, str]] | None = None
     ) -> RAGPipelineOutput:
-        return run_sync(self.__acall__(question, chat_history))
\ No newline at end of file
+        return run_sync(self.__acall__(question, chat_history))
+
+    async def astream(
+        self, question: str, chat_history: List[Tuple[str, str]] | None = None
+    ) -> None:
+        """Stream tokens from the response synthesizer."""
+        if chat_history is None:
+            chat_history = []
+
+        enhanced_query = await self.query_enhancer({"query": question, "chat_history": chat_history})
+        retrieval_result = await self.retrieval_engine.__acall__(enhanced_query)
+
+        async for token in self.response_synthesizer.stream(retrieval_result):
+            yield token
+
+        response = self.response_synthesizer.stream_output
+
+        self.stream_result = RAGPipelineOutput(
+            question=enhanced_query["standalone_query"],
+            answer=response["response"],
+            sources="\n".join(
+                [doc.metadata["source"] for doc in retrieval_result.documents]
+            ),
+            source_documents=response["context_str"],
+            system_prompt=response["response_prompt"],
+            model=response["response_model"],
+            total_tokens=0,
+            prompt_tokens=0,
+            completion_tokens=0,
+            time_taken=0,
+            start_time=datetime.datetime.now(),
+            end_time=datetime.datetime.now(),
+            api_call_statuses={
+                "web_search_success": retrieval_result.retrieval_info["api_statuses"]["web_search_api"].success,
+                "reranker_api_error_info": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].error_info,
+                "reranker_api_success": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].success,
+                "query_enhancer_llm_api_error_info": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).error_info if enhanced_query.get("api_statuses") else None,
+                "query_enhancer_llm_api_success": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).success if enhanced_query.get("api_statuses") else False,
+                "embedding_api_error_info": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].error_info,
+                "embedding_api_success": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].success,
+            },
+            response_synthesis_llm_messages=response.get("response_synthesis_llm_messages"),
+        )
diff --git a/src/wandbot/chat/schemas.py b/src/wandbot/chat/schemas.py
index 67607acf8f..0518ab1a13 100644
--- a/src/wandbot/chat/schemas.py
+++ b/src/wandbot/chat/schemas.py
@@ -46,6 +46,7 @@ class ChatRequest(BaseModel):
     chat_history: List[QuestionAnswer] | None = None
     application: str | None = None
     language: str = "en"
+    stream: bool = False
 
 
 class ChatResponse(BaseModel):
diff --git a/src/wandbot/models/llm.py b/src/wandbot/models/llm.py
index 36b870d079..d444c1d069 100644
--- a/src/wandbot/models/llm.py
+++ b/src/wandbot/models/llm.py
@@ -107,11 +107,16 @@ def __init__(self,
         self.timeout = timeout
         self.semaphore = asyncio.Semaphore(n_parallel_api_calls)
 
-    async def create(self, 
-                     messages: List[Dict[str, Any]], 
+    async def create(self,
+                     messages: List[Dict[str, Any]],
                      **kwargs) -> tuple[Union[str, BaseModel], APIStatus]:
         raise NotImplementedError("Subclasses must implement create method")
 
+    async def stream(self, messages: List[Dict[str, Any]]):
+        result, _ = await self.create(messages=messages)
+        if result:
+            yield result
+
 class AsyncOpenAILLMModel(BaseLLMModel):
     JSON_MODELS = [
         "gpt-4-",  # All gpt-4- models
@@ -180,6 +185,26 @@ async def create(self,
             error_info=error_info
         )
 
+    async def stream(self, messages: List[Dict[str, Any]]):
+        api_params = {
+            "model": self.model_name,
+            "temperature": self.temperature,
+            "messages": messages,
+            "stream": True,
+        }
+        if api_params["temperature"] == 0:
+            api_params["temperature"] = 0.1
+
+        if self.model_name.startswith("o"):
+            api_params.pop("temperature", None)
+
+        async with self.semaphore:
+            response = await self.client.chat.completions.create(**api_params)
+            async for chunk in response:
+                delta = chunk.choices[0].delta.content
+                if delta:
+                    yield delta
+
 class AsyncAnthropicLLMModel(BaseLLMModel):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -244,6 +269,42 @@ async def create(self,
             error_info=error_info
         )
 
+    async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
+        processed_messages = []
+        for msg in messages:
+            if msg.get("role") == "developer":
+                processed_messages.append({"role": "system", "content": msg.get("content")})
+            else:
+                processed_messages.append(msg)
+
+        system_msg, chat_messages = extract_system_and_messages(processed_messages)
+        api_params = {
+            "model": self.model_name,
+            "temperature": self.temperature,
+            "messages": chat_messages,
+            "max_tokens": max_tokens,
+            "stream": True,
+        }
+
+        if api_params["temperature"] == 0:
+            api_params["temperature"] = 0.1
+
+        if system_msg:
+            api_params["system"] = system_msg
+
+        if self.response_model:
+            api_params["messages"] += add_json_response_model_to_messages(self.response_model)
+
+        async with self.semaphore:
+            response = await self.client.messages.create(**api_params)
+            async for chunk in response:
+                try:
+                    delta = chunk.delta.text  # type: ignore
+                except AttributeError:
+                    delta = ""
+                if delta:
+                    yield delta
+
 class AsyncGoogleGenAILLMModel(BaseLLMModel):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -322,6 +383,11 @@ async def create(self,
             error_info=error_info
         )
 
+    async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
+        result, _ = await self.create(messages=messages, max_tokens=max_tokens)
+        if result:
+            yield result
+
 
 class LLMModel:
     PROVIDER_MAP = {
@@ -344,8 +410,8 @@ def __init__(self, provider: str, **kwargs):
     def model_name(self) -> str:
         return self.model.model_name
 
-    async def create(self, 
-                     messages: List[Dict[str, Any]], 
+    async def create(self,
+                     messages: List[Dict[str, Any]],
                      **kwargs) -> tuple[Union[str, BaseModel], APIStatus]:
         try:
             async with self.model.semaphore:  # Use the specific model's semaphore
@@ -369,3 +435,7 @@ async def create(self,
                 success=False,
                 error_info=error_info
             )
+
+    async def stream(self, messages: List[Dict[str, Any]], **kwargs):
+        async for token in self.model.stream(messages=messages, **kwargs):
+            yield token
diff --git a/src/wandbot/rag/response_synthesis.py b/src/wandbot/rag/response_synthesis.py
index 85454373e7..b251c57db1 100644
--- a/src/wandbot/rag/response_synthesis.py
+++ b/src/wandbot/rag/response_synthesis.py
@@ -227,6 +227,36 @@ async def __call__(self, inputs: RetrievalResult) -> Dict[str, Any]:
             }
         }
 
+    async def stream(self, inputs: RetrievalResult):
+        """Stream response tokens while capturing the final result."""
+        formatted_input = self._format_input(inputs)
+        messages = self.get_messages(formatted_input)
+
+        result = ""
+        used_model = self.model
+        try:
+            async for token in self.model.stream(messages=messages):
+                result += token
+                yield token
+        except Exception as e:
+            logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
+            used_model = self.fallback_model
+            async for token in self.fallback_model.stream(messages=messages):
+                result += token
+                yield token
+
+        self.stream_output = {
+            "query_str": formatted_input["query_str"],
+            "context_str": formatted_input["context_str"],
+            "response": result,
+            "response_model": used_model.model_name,
+            "response_synthesis_llm_messages": messages,
+            "response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
+            "api_statuses": {
+                "response_synthesis_llm_api": None
+            },
+        }
+
     def _format_input(self, inputs: RetrievalResult) -> Dict[str, str]:
         """Format the input data for the prompt template."""
         return {
diff --git a/tests/test_stream.py b/tests/test_stream.py
new file mode 100644
index 0000000000..8fcba87e4f
--- /dev/null
+++ b/tests/test_stream.py
@@ -0,0 +1,38 @@
+import pytest
+
+from wandbot.rag.response_synthesis import ResponseSynthesizer
+from wandbot.schema.document import Document
+from wandbot.schema.retrieval import RetrievalResult
+
+
+async def fake_stream(*args, **kwargs):
+    for tok in ["hello ", "world"]:
+        yield tok
+
+
+@pytest.mark.asyncio
+async def test_streaming_response():
+    synth = ResponseSynthesizer(
+        primary_provider="openai",
+        primary_model_name="dummy",
+        primary_temperature=0,
+        fallback_provider="openai",
+        fallback_model_name="dummy",
+        fallback_temperature=0,
+        max_retries=1,
+    )
+
+    synth.model.stream = fake_stream  # type: ignore
+    synth.get_messages = lambda x: []
+
+    retrieval = RetrievalResult(
+        documents=[Document(page_content="doc", metadata={"source": "s"})],
+        retrieval_info={"query": "q", "language": "en", "intents": [], "sub_queries": []},
+    )
+
+    tokens = []
+    async for token in synth.stream(retrieval):
+        tokens.append(token)
+
+    assert "".join(tokens) == "hello world"
+    assert synth.stream_output["response"] == "hello world"
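
For reference, below is a minimal client-side sketch of consuming the new streaming path over Server-Sent Events. It is not part of the patch: the endpoint URL/path and the exact request payload fields are assumptions inferred from this diff (the router prefix is not shown here), so adjust them to the actual deployment and APIQueryRequest schema.

import requests  # assumed available in the client environment

# Hypothetical endpoint path; the actual route prefix is not visible in this diff.
URL = "http://localhost:8000/chat/query"
payload = {"question": "How do I log a W&B artifact?", "language": "en", "stream": True}

with requests.post(URL, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # The server emits SSE frames of the form "data: <token>\n\n".
        if line and line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)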