23 changes: 16 additions & 7 deletions src/wandbot/api/routers/chat.py
@@ -1,4 +1,5 @@
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from starlette import status

from wandbot.chat.schemas import ChatRequest, ChatResponse
@@ -36,14 +37,22 @@ async def query(request: APIQueryRequest) -> APIQueryResponse:

try:
chat_instance = chat_components["chat"]
result = await chat_instance.__acall__(
ChatRequest(
question=request.question,
chat_history=request.chat_history,
language=request.language,
application=request.application,
),
chat_req = ChatRequest(
question=request.question,
chat_history=request.chat_history,
language=request.language,
application=request.application,
stream=request.stream,
)

if chat_req.stream:
async def event_gen():
async for token in chat_instance.astream(chat_req):
yield f"data: {token}\n\n"

return StreamingResponse(event_gen(), media_type="text/event-stream")

result = await chat_instance.__acall__(chat_req)
return APIQueryResponse(**result.model_dump())
except Exception as e:
logger.error(f"Error processing chat query: {e}")
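Note (not part of the diff): a minimal client sketch for the new SSE behavior, consuming the data: <token> frames emitted by event_gen. The route path, host/port, and any payload fields beyond question/stream are assumptions.

import asyncio
import httpx

async def consume_stream() -> None:
    # Hypothetical endpoint URL; adjust to wherever the query route is actually mounted.
    payload = {"question": "How do I resume a W&B run?", "stream": True}
    async with httpx.AsyncClient(timeout=None) as client:
        # client.stream() exposes the response incrementally instead of buffering it.
        async with client.stream("POST", "http://localhost:8000/chat/query", json=payload) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    print(line[len("data: "):], end="", flush=True)

if __name__ == "__main__":
    asyncio.run(consume_stream())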
33 changes: 33 additions & 0 deletions src/wandbot/chat/chat.py
@@ -233,3 +233,36 @@ async def __acall__(self, chat_request: ChatRequest) -> ChatResponse:
@weave.op
def __call__(self, chat_request: ChatRequest) -> ChatResponse:
return run_sync(self.__acall__(chat_request))

async def astream(self, chat_request: ChatRequest):
"""Stream the chat response tokens asynchronously."""
original_language = chat_request.language

working_request = chat_request

if original_language == "ja":
translated_question = translate_ja_to_en(
chat_request.question, self.chat_config.ja_translation_model_name
)
working_request = ChatRequest(
question=translated_question,
chat_history=chat_request.chat_history,
application=chat_request.application,
language="en",
)

async for token in self.rag_pipeline.astream(
working_request.question, working_request.chat_history or []
):
yield token

result = self.rag_pipeline.stream_result
result_dict = result.model_dump()

if original_language == "ja":
result_dict["answer"] = translate_en_to_ja(
result_dict["answer"], self.chat_config.ja_translation_model_name
)

result_dict.update({"application": chat_request.application})
self.last_stream_response = ChatResponse(**result_dict)
Comment on lines +237 to +268
🛠️ Refactor suggestion

Enhance error handling and consistency with the main __acall__ method.

The streaming implementation lacks the comprehensive error handling, timing, and API status tracking present in __acall__. This could result in inconsistent behavior and poor error visibility.

Consider these improvements:

  1. Add proper exception handling with ErrorInfo tracking
  2. Include timing information using Timer
  3. Add API status tracking similar to __acall__
  4. Handle translation errors gracefully with fallback responses
 async def astream(self, chat_request: ChatRequest):
     """Stream the chat response tokens asynchronously."""
     original_language = chat_request.language
+    api_call_statuses = {}
     
     working_request = chat_request
+    
+    with Timer() as timer:
+        try:
+            # Handle Japanese translation with error handling
+            if original_language == "ja":
+                try:
+                    translated_question = translate_ja_to_en(
+                        chat_request.question, self.chat_config.ja_translation_model_name
+                    )
+                    working_request = ChatRequest(
+                        question=translated_question,
+                        chat_history=chat_request.chat_history,
+                        application=chat_request.application,
+                        language="en",
+                    )
+                except Exception as e:
+                    # Handle translation error similar to __acall__
+                    api_call_statuses["chat_success"] = False
+                    api_call_statuses["chat_error_info"] = ErrorInfo(
+                        has_error=True, error_message=str(e), 
+                        error_type=type(e).__name__, component="translation"
+                    ).model_dump()
+                    yield f"Translation error: {str(e)}"
+                    return
+                    
+            async for token in self.rag_pipeline.astream(
+                working_request.question, working_request.chat_history or []
+            ):
+                yield token
+                
+        except Exception as e:
+            # Handle streaming errors
+            api_call_statuses["chat_success"] = False
+            api_call_statuses["chat_error_info"] = ErrorInfo(
+                has_error=True, error_message=str(e),
+                error_type=type(e).__name__, component="chat"
+            ).model_dump()
+            yield f"Streaming error: {str(e)}"
+            return
+            
+        # Store final result with complete metadata
+        result = self.rag_pipeline.stream_result
+        result_dict = result.model_dump()
+        
+        # Handle response translation
+        if original_language == "ja":
+            try:
+                result_dict["answer"] = translate_en_to_ja(
+                    result_dict["answer"], self.chat_config.ja_translation_model_name
+                )
+            except Exception as e:
+                result_dict["answer"] = f"Translation error: {str(e)}\nOriginal answer: {result_dict['answer']}"
+                
+        # Update with complete metadata
+        api_call_statuses["chat_success"] = True
+        result_dict.update({
+            "application": chat_request.application,
+            "api_call_statuses": api_call_statuses,
+            "time_taken": timer.elapsed,
+            "start_time": timer.start,
+            "end_time": timer.stop,
+        })
+        
+        self.last_stream_response = ChatResponse(**result_dict)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/wandbot/chat/chat.py lines 237 to 268, the astream method lacks the
robust error handling, timing, and API status tracking found in the __acall__
method, leading to inconsistent behavior and poor error visibility. To fix this,
wrap the streaming logic in try-except blocks to catch exceptions and record
them using ErrorInfo, use a Timer to measure execution time, and update API
status accordingly. Also, handle translation errors gracefully by providing
fallback responses instead of failing outright, ensuring consistent and reliable
streaming behavior.

44 changes: 43 additions & 1 deletion src/wandbot/chat/rag.py
@@ -141,4 +141,46 @@ async def __acall__(
def __call__(
self, question: str, chat_history: List[Tuple[str, str]] | None = None
) -> RAGPipelineOutput:
return run_sync(self.__acall__(question, chat_history))

async def astream(
self, question: str, chat_history: List[Tuple[str, str]] | None = None
) -> None:
Comment on lines +146 to +148
⚠️ Potential issue

Fix the return type annotation for the async generator.

The method yields tokens but is annotated as returning None. This should be an async generator type.

-    async def astream(
-        self, question: str, chat_history: List[Tuple[str, str]] | None = None
-    ) -> None:
+    async def astream(
+        self, question: str, chat_history: List[Tuple[str, str]] | None = None
+    ) -> AsyncGenerator[str, None]:

You'll also need to add the import:

-from typing import Dict, List, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def astream(
self, question: str, chat_history: List[Tuple[str, str]] | None = None
) -> None:
-from typing import Dict, List, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple
async def astream(
self, question: str, chat_history: List[Tuple[str, str]] | None = None
) -> AsyncGenerator[str, None]:
🤖 Prompt for AI Agents
In src/wandbot/chat/rag.py around lines 146 to 148, the async method astream is
currently annotated as returning None, but it actually yields tokens as an async
generator. Change the return type annotation to an appropriate async generator
type that reflects the yielded token type. Also, add the necessary import for
the async generator type from the typing module to support this annotation.

"""Stream tokens from the response synthesizer."""
if chat_history is None:
chat_history = []

enhanced_query = await self.query_enhancer({"query": question, "chat_history": chat_history})
retrieval_result = await self.retrieval_engine.__acall__(enhanced_query)

async for token in self.response_synthesizer.stream(retrieval_result):
yield token

response = self.response_synthesizer.stream_output

self.stream_result = RAGPipelineOutput(
question=enhanced_query["standalone_query"],
answer=response["response"],
sources="\n".join(
[doc.metadata["source"] for doc in retrieval_result.documents]
),
source_documents=response["context_str"],
system_prompt=response["response_prompt"],
model=response["response_model"],
total_tokens=0,
prompt_tokens=0,
completion_tokens=0,
time_taken=0,
start_time=datetime.datetime.now(),
end_time=datetime.datetime.now(),
api_call_statuses={
"web_search_success": retrieval_result.retrieval_info["api_statuses"]["web_search_api"].success,
"reranker_api_error_info": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].error_info,
"reranker_api_success": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].success,
"query_enhancer_llm_api_error_info": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).error_info if enhanced_query.get("api_statuses") else None,
"query_enhancer_llm_api_success": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).success if enhanced_query.get("api_statuses") else False,
"embedding_api_error_info": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].error_info,
"embedding_api_success": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].success,
},
response_synthesis_llm_messages=response.get("response_synthesis_llm_messages"),
)
Comment on lines +161 to +186
🛠️ Refactor suggestion

Eliminate code duplication and improve error handling.

The RAGPipelineOutput construction logic is duplicated between __acall__ and astream methods. Consider extracting this into a helper method and adding proper error handling.

+    def _build_pipeline_output(self, enhanced_query, response, retrieval_result) -> RAGPipelineOutput:
+        """Helper method to build RAGPipelineOutput consistently."""
+        return RAGPipelineOutput(
+            question=enhanced_query["standalone_query"],
+            answer=response["response"],
+            sources="\n".join(
+                [doc.metadata["source"] for doc in retrieval_result.documents]
+            ),
+            source_documents=response["context_str"],
+            system_prompt=response["response_prompt"],
+            model=response["response_model"],
+            total_tokens=0,
+            prompt_tokens=0,
+            completion_tokens=0,
+            time_taken=0,
+            start_time=datetime.datetime.now(),
+            end_time=datetime.datetime.now(),
+            api_call_statuses={
+                "web_search_success": retrieval_result.retrieval_info["api_statuses"]["web_search_api"].success,
+                "reranker_api_error_info": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].error_info,
+                "reranker_api_success": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].success,
+                "query_enhancer_llm_api_error_info": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).error_info if enhanced_query.get("api_statuses") else None,
+                "query_enhancer_llm_api_success": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).success if enhanced_query.get("api_statuses") else False,
+                "embedding_api_error_info": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].error_info,
+                "embedding_api_success": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].success,
+            },
+            response_synthesis_llm_messages=response.get("response_synthesis_llm_messages"),
+        )

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/wandbot/chat/rag.py around lines 161 to 186, the construction of the
RAGPipelineOutput object is duplicated in both __acall__ and astream methods. To
fix this, extract the RAGPipelineOutput creation logic into a separate helper
method that takes the necessary inputs and returns the constructed object. Then
replace the duplicated code in both methods with calls to this helper.
Additionally, add proper error handling within this helper to catch and manage
any exceptions during the construction process.

1 change: 1 addition & 0 deletions src/wandbot/chat/schemas.py
@@ -46,6 +46,7 @@ class ChatRequest(BaseModel):
chat_history: List[QuestionAnswer] | None = None
application: str | None = None
language: str = "en"
stream: bool = False


class ChatResponse(BaseModel):
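Note (illustrative only): clients that omit stream keep the existing non-streaming behavior, while setting it opts into token streaming; the values below are placeholders.

from wandbot.chat.schemas import ChatRequest

non_streaming = ChatRequest(question="What is a W&B sweep?")             # stream defaults to False
streaming = ChatRequest(question="What is a W&B sweep?", stream=True)    # answer is streamed as SSE tokens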
78 changes: 74 additions & 4 deletions src/wandbot/models/llm.py
@@ -107,11 +107,16 @@ def __init__(self,
self.timeout = timeout
self.semaphore = asyncio.Semaphore(n_parallel_api_calls)

async def create(self,
messages: List[Dict[str, Any]],
async def create(self,
messages: List[Dict[str, Any]],
**kwargs) -> tuple[Union[str, BaseModel], APIStatus]:
raise NotImplementedError("Subclasses must implement create method")

async def stream(self, messages: List[Dict[str, Any]]):
result, _ = await self.create(messages=messages)
if result:
yield result

class AsyncOpenAILLMModel(BaseLLMModel):
JSON_MODELS = [
"gpt-4-", # All gpt-4- models
@@ -180,6 +185,26 @@ async def create(self,
error_info=error_info
)

async def stream(self, messages: List[Dict[str, Any]]):
api_params = {
"model": self.model_name,
"temperature": self.temperature,
"messages": messages,
"stream": True,
}
if api_params["temperature"] == 0:
api_params["temperature"] = 0.1

if self.model_name.startswith("o"):
api_params.pop("temperature", None)

async with self.semaphore:
response = await self.client.chat.completions.create(**api_params)
async for chunk in response:
delta = chunk.choices[0].delta.content
if delta:
yield delta

Comment on lines +188 to +207
⚠️ Potential issue

Add response_model support and error handling to OpenAI streaming.

The streaming implementation doesn't handle the response_model case and lacks error handling compared to the create method.

     async def stream(self, messages: List[Dict[str, Any]]):
+        if self.response_model:
+            # For structured outputs, fall back to non-streaming
+            result, _ = await self.create(messages=messages)
+            if result:
+                yield result
+            return
+            
+        try:
             api_params = {
                 "model": self.model_name,
                 "temperature": self.temperature,
                 "messages": messages,
                 "stream": True,
             }
             if api_params["temperature"] == 0:
                 api_params["temperature"] = 0.1
     
             if self.model_name.startswith("o"):
                 api_params.pop("temperature", None)
     
             async with self.semaphore:
                 response = await self.client.chat.completions.create(**api_params)
                 async for chunk in response:
                     delta = chunk.choices[0].delta.content
                     if delta:
                         yield delta
+        except Exception as e:
+            logger.error(f"OpenAI streaming error: {str(e)}")
+            # Fall back to non-streaming
+            result, _ = await self.create(messages=messages)
+            if result:
+                yield result
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def stream(self, messages: List[Dict[str, Any]]):
api_params = {
"model": self.model_name,
"temperature": self.temperature,
"messages": messages,
"stream": True,
}
if api_params["temperature"] == 0:
api_params["temperature"] = 0.1
if self.model_name.startswith("o"):
api_params.pop("temperature", None)
async with self.semaphore:
response = await self.client.chat.completions.create(**api_params)
async for chunk in response:
delta = chunk.choices[0].delta.content
if delta:
yield delta
async def stream(self, messages: List[Dict[str, Any]]):
# If a response_model is configured, fall back to non-streaming
if self.response_model:
result, _ = await self.create(messages=messages)
if result:
yield result
return
try:
api_params = {
"model": self.model_name,
"temperature": self.temperature,
"messages": messages,
"stream": True,
}
if api_params["temperature"] == 0:
api_params["temperature"] = 0.1
if self.model_name.startswith("o"):
api_params.pop("temperature", None)
async with self.semaphore:
response = await self.client.chat.completions.create(**api_params)
async for chunk in response:
delta = chunk.choices[0].delta.content
if delta:
yield delta
except Exception as e:
logger.error(f"OpenAI streaming error: {e}")
# Fall back to non-streaming on error
result, _ = await self.create(messages=messages)
if result:
yield result
🤖 Prompt for AI Agents
In src/wandbot/models/llm.py around lines 188 to 207, the stream method lacks
support for the response_model case and does not include error handling like
the create method. Update the stream method to handle response_model (for
example by falling back to the non-streaming create call when structured output
is requested), and add appropriate try-except blocks to catch and handle
exceptions during streaming, ensuring errors are logged or managed gracefully.

class AsyncAnthropicLLMModel(BaseLLMModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -244,6 +269,42 @@ async def create(self,
error_info=error_info
)

async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
processed_messages = []
for msg in messages:
if msg.get("role") == "developer":
processed_messages.append({"role": "system", "content": msg.get("content")})
else:
processed_messages.append(msg)

system_msg, chat_messages = extract_system_and_messages(processed_messages)
api_params = {
"model": self.model_name,
"temperature": self.temperature,
"messages": chat_messages,
"max_tokens": max_tokens,
"stream": True,
}

if api_params["temperature"] == 0:
api_params["temperature"] = 0.1

if system_msg:
api_params["system"] = system_msg

if self.response_model:
api_params["messages"] += add_json_response_model_to_messages(self.response_model)

async with self.semaphore:
response = await self.client.messages.create(**api_params)
async for chunk in response:
try:
delta = chunk.delta.text # type: ignore
except AttributeError:
delta = ""
if delta:
yield delta

Comment on lines +272 to +307
🛠️ Refactor suggestion

Eliminate code duplication and improve error handling.

The message preprocessing logic is duplicated from the create method. Consider extracting this into a helper method and adding proper error handling.

+    def _preprocess_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Helper method to preprocess messages for Anthropic API."""
+        processed_messages = []
+        for msg in messages:
+            if msg.get("role") == "developer":
+                processed_messages.append({"role": "system", "content": msg.get("content")})
+                logger.debug("Converted 'developer' role to 'system' for Anthropic call.")
+            else:
+                processed_messages.append(msg)
+        return processed_messages
+
     async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
-        processed_messages = []
-        for msg in messages:
-            if msg.get("role") == "developer":
-                processed_messages.append({"role": "system", "content": msg.get("content")})
-            else:
-                processed_messages.append(msg)
+        try:
+            processed_messages = self._preprocess_messages(messages)
+            # ... rest of the implementation
+        except Exception as e:
+            logger.error(f"Anthropic streaming error: {str(e)}")
+            # Fall back to non-streaming
+            result, _ = await self.create(messages=messages, max_tokens=max_tokens)
+            if result:
+                yield result

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/wandbot/models/llm.py between lines 272 and 307, the message
preprocessing logic duplicates code from the create method and lacks robust
error handling. Refactor by extracting the message preprocessing into a separate
helper method that both stream and create methods can call. Add try-except
blocks around the preprocessing steps to catch and handle potential errors
gracefully, ensuring the stream method yields results only when preprocessing
succeeds.

class AsyncGoogleGenAILLMModel(BaseLLMModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -322,6 +383,11 @@ async def create(self,
error_info=error_info
)

async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
result, _ = await self.create(messages=messages, max_tokens=max_tokens)
if result:
yield result


class LLMModel:
PROVIDER_MAP = {
@@ -344,8 +410,8 @@ def __init__(self, provider: str, **kwargs):
def model_name(self) -> str:
return self.model.model_name

async def create(self,
messages: List[Dict[str, Any]],
async def create(self,
messages: List[Dict[str, Any]],
**kwargs) -> tuple[Union[str, BaseModel], APIStatus]:
try:
async with self.model.semaphore: # Use the specific model's semaphore
@@ -369,3 +435,7 @@ async def create(self,
success=False,
error_info=error_info
)

async def stream(self, messages: List[Dict[str, Any]], **kwargs):
async for token in self.model.stream(messages=messages, **kwargs):
yield token
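Note: a hedged usage sketch of the new top-level streaming entry point. The constructor keyword names (model_name, temperature) are assumptions inferred from attributes referenced in the diff, and the model name is a placeholder.

import asyncio
from wandbot.models.llm import LLMModel

async def demo() -> None:
    # Assumed kwargs; LLMModel forwards **kwargs to the provider-specific model class.
    llm = LLMModel(provider="openai", model_name="gpt-4o-mini", temperature=0)
    messages = [{"role": "user", "content": "Say hello in five words."}]
    async for token in llm.stream(messages=messages):
        print(token, end="", flush=True)

if __name__ == "__main__":
    asyncio.run(demo())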
30 changes: 30 additions & 0 deletions src/wandbot/rag/response_synthesis.py
@@ -227,6 +227,36 @@ async def __call__(self, inputs: RetrievalResult) -> Dict[str, Any]:
}
}

async def stream(self, inputs: RetrievalResult):
"""Stream response tokens while capturing the final result."""
formatted_input = self._format_input(inputs)
messages = self.get_messages(formatted_input)

result = ""
used_model = self.model
try:
async for token in self.model.stream(messages=messages):
result += token
yield token
except Exception as e:
logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
used_model = self.fallback_model
async for token in self.fallback_model.stream(messages=messages):
result += token
yield token

self.stream_output = {
"query_str": formatted_input["query_str"],
"context_str": formatted_input["context_str"],
"response": result,
"response_model": used_model.model_name,
"response_synthesis_llm_messages": messages,
"response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
"api_statuses": {
"response_synthesis_llm_api": None
},
}
Comment on lines +230 to +258
🛠️ Refactor suggestion

Improve error handling and API status tracking in the streaming method.

The streaming implementation has several issues that could affect reliability and debugging:

  1. API status is set to None instead of proper tracking
  2. Fallback error handling doesn't reset the result accumulator
  3. No exception handling for the fallback model failure
  4. Missing comprehensive error information
 async def stream(self, inputs: RetrievalResult):
     """Stream response tokens while capturing the final result."""
     formatted_input = self._format_input(inputs)
     messages = self.get_messages(formatted_input)
     
     result = ""
     used_model = self.model
+    llm_api_status = None
+    
     try:
         async for token in self.model.stream(messages=messages):
             result += token
             yield token
+        llm_api_status = APIStatus(success=True, error_info=None)
     except Exception as e:
         logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
         used_model = self.fallback_model
-        async for token in self.fallback_model.stream(messages=messages):
-            result += token
-            yield token
+        result = ""  # Reset accumulator for fallback
+        try:
+            async for token in self.fallback_model.stream(messages=messages):
+                result += token
+                yield token
+            llm_api_status = APIStatus(success=True, error_info=None)
+        except Exception as fallback_e:
+            logger.error(f"Both primary and fallback Response Synthesizer models failed: {str(fallback_e)}")
+            llm_api_status = APIStatus(success=False, error_info=str(fallback_e))
+            raise Exception(f"Response synthesis streaming failed: {str(fallback_e)}")
     
     self.stream_output = {
         "query_str": formatted_input["query_str"],
         "context_str": formatted_input["context_str"],
         "response": result,
         "response_model": used_model.model_name,
         "response_synthesis_llm_messages": messages,
         "response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
         "api_statuses": {
-            "response_synthesis_llm_api": None
+            "response_synthesis_llm_api": llm_api_status
         },
     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def stream(self, inputs: RetrievalResult):
"""Stream response tokens while capturing the final result."""
formatted_input = self._format_input(inputs)
messages = self.get_messages(formatted_input)
result = ""
used_model = self.model
try:
async for token in self.model.stream(messages=messages):
result += token
yield token
except Exception as e:
logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
used_model = self.fallback_model
async for token in self.fallback_model.stream(messages=messages):
result += token
yield token
self.stream_output = {
"query_str": formatted_input["query_str"],
"context_str": formatted_input["context_str"],
"response": result,
"response_model": used_model.model_name,
"response_synthesis_llm_messages": messages,
"response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
"api_statuses": {
"response_synthesis_llm_api": None
},
}
async def stream(self, inputs: RetrievalResult):
"""Stream response tokens while capturing the final result."""
formatted_input = self._format_input(inputs)
messages = self.get_messages(formatted_input)
result = ""
used_model = self.model
llm_api_status = None
try:
async for token in self.model.stream(messages=messages):
result += token
yield token
llm_api_status = APIStatus(success=True, error_info=None)
except Exception as e:
logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
used_model = self.fallback_model
result = "" # Reset accumulator for fallback
try:
async for token in self.fallback_model.stream(messages=messages):
result += token
yield token
llm_api_status = APIStatus(success=True, error_info=None)
except Exception as fallback_e:
logger.error(f"Both primary and fallback Response Synthesizer models failed: {str(fallback_e)}")
llm_api_status = APIStatus(success=False, error_info=str(fallback_e))
raise Exception(f"Response synthesis streaming failed: {str(fallback_e)}")
self.stream_output = {
"query_str": formatted_input["query_str"],
"context_str": formatted_input["context_str"],
"response": result,
"response_model": used_model.model_name,
"response_synthesis_llm_messages": messages,
"response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
"api_statuses": {
"response_synthesis_llm_api": llm_api_status
},
}
🤖 Prompt for AI Agents
In src/wandbot/rag/response_synthesis.py lines 230 to 258, improve the stream
method by properly tracking API call statuses instead of setting them to None,
reset the result accumulator before using the fallback model to avoid mixing
outputs, add exception handling around the fallback model streaming to catch and
log any errors, and include detailed error information in the logs to aid
debugging and reliability.


def _format_input(self, inputs: RetrievalResult) -> Dict[str, str]:
"""Format the input data for the prompt template."""
return {
38 changes: 38 additions & 0 deletions tests/test_stream.py
@@ -0,0 +1,38 @@
import pytest

from wandbot.rag.response_synthesis import ResponseSynthesizer
from wandbot.schema.document import Document
from wandbot.schema.retrieval import RetrievalResult


async def fake_stream(*args, **kwargs):
for tok in ["hello ", "world"]:
yield tok


@pytest.mark.asyncio
async def test_streaming_response():
synth = ResponseSynthesizer(
primary_provider="openai",
primary_model_name="dummy",
primary_temperature=0,
fallback_provider="openai",
fallback_model_name="dummy",
fallback_temperature=0,
max_retries=1,
)

synth.model.stream = fake_stream # type: ignore
synth.get_messages = lambda x: []

retrieval = RetrievalResult(
documents=[Document(page_content="doc", metadata={"source": "s"})],
retrieval_info={"query": "q", "language": "en", "intents": [], "sub_queries": []},
)

tokens = []
async for token in synth.stream(retrieval):
tokens.append(token)

assert "".join(tokens) == "hello world"
assert synth.stream_output["response"] == "hello world"
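A possible companion test (not part of this PR) sketching coverage of the fallback path in ResponseSynthesizer.stream; it reuses the constructor arguments from the test above and monkeypatches both model streams.

async def failing_stream(*args, **kwargs):
    # Simulate the primary model failing before it emits any tokens.
    raise RuntimeError("primary model unavailable")
    yield  # unreachable; keeps this function an async generator


async def fallback_stream(*args, **kwargs):
    for tok in ["fallback ", "answer"]:
        yield tok


@pytest.mark.asyncio
async def test_streaming_fallback():
    synth = ResponseSynthesizer(
        primary_provider="openai",
        primary_model_name="dummy",
        primary_temperature=0,
        fallback_provider="openai",
        fallback_model_name="dummy",
        fallback_temperature=0,
        max_retries=1,
    )

    synth.model.stream = failing_stream  # type: ignore
    synth.fallback_model.stream = fallback_stream  # type: ignore
    synth.get_messages = lambda x: []

    retrieval = RetrievalResult(
        documents=[Document(page_content="doc", metadata={"source": "s"})],
        retrieval_info={"query": "q", "language": "en", "intents": [], "sub_queries": []},
    )

    tokens = [tok async for tok in synth.stream(retrieval)]

    assert "".join(tokens) == "fallback answer"
    assert synth.stream_output["response_model"] == synth.fallback_model.model_name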