
Implement streaming endpoint #99

Open
morganmcg1 wants to merge 1 commit into main from codex/add-chat-completion-streaming-functionality

Conversation


@morganmcg1 morganmcg1 commented Jul 7, 2025

Summary

  • support streaming in ChatRequest
  • add StreamingResponse endpoint
  • connect new astream pipeline to LLM streaming
  • provide LLM streaming implementations
  • test streaming via mock

Testing

  • pytest -k test_streaming_response tests/test_stream.py (fails: ModuleNotFoundError: No module named 'openai')

https://chatgpt.com/codex/tasks/task_e_686b7f40cdf4832ba2d6ca8a1b3a7570

Summary by CodeRabbit

  • New Features

    • Added support for streaming chat responses, allowing users to receive answers incrementally as they are generated.
    • Introduced a new option to enable or disable streaming in chat requests.
  • Bug Fixes

    • Improved handling of language translation during streaming for Japanese queries.
  • Tests

    • Added tests to verify the streaming response functionality in chat interactions.


coderabbitai bot commented Jul 7, 2025

Walkthrough

The changes introduce asynchronous streaming capabilities to the chat and RAG pipeline layers, allowing chat responses to be delivered incrementally as tokens. New stream methods are implemented across LLM model classes and the response synthesizer, and API endpoints now conditionally support streaming output based on a new stream flag in the request schema. Associated tests are added for streaming behavior.
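
As a usage illustration (not part of the PR), here is a minimal client-side sketch of consuming the new SSE stream, assuming the service runs locally and the query route accepts a JSON body shaped like ChatRequest; the URL, port, and route path are placeholders, not confirmed by this diff:

# Minimal client sketch for consuming the SSE stream (assumed URL and payload shape).
import asyncio
import httpx

async def consume_stream() -> None:
    payload = {
        "question": "How do I log images with wandb?",
        "application": "docs",
        "language": "en",
        "stream": True,  # new flag introduced by this PR
    }
    async with httpx.AsyncClient(timeout=None) as client:
        # "http://localhost:8000/chat/query" is an assumed mount point for the query endpoint.
        async with client.stream("POST", "http://localhost:8000/chat/query", json=payload) as response:
            async for line in response.aiter_lines():
                # The endpoint emits Server-Sent Events of the form "data: <token>".
                if line.startswith("data: "):
                    print(line[len("data: "):], end="", flush=True)

if __name__ == "__main__":
    asyncio.run(consume_stream())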

Changes

  • src/wandbot/api/routers/chat.py: Modified the query endpoint to support streaming responses via a new stream request attribute.
  • src/wandbot/chat/chat.py: Added astream method to Chat for asynchronous token streaming with translation handling.
  • src/wandbot/chat/rag.py: Added astream method to RAGPipeline for streaming tokens and constructing pipeline output.
  • src/wandbot/chat/schemas.py: Added stream: bool = False field to ChatRequest schema.
  • src/wandbot/models/llm.py: Added asynchronous stream methods to LLM model base and subclasses for token-level streaming.
  • src/wandbot/rag/response_synthesis.py: Added stream method to ResponseSynthesizer for streaming response tokens with fallback logic.
  • tests/test_stream.py: Added asynchronous test for streaming response functionality in ResponseSynthesizer.

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant API (query endpoint)
    participant Chat
    participant RAGPipeline
    participant ResponseSynthesizer
    participant LLMModel

    Client->>API (query endpoint): Send chat request (with stream flag)
    API (query endpoint)->>Chat: Create ChatRequest (with stream)
    alt stream == true
        Chat->>RAGPipeline: astream(chat_request)
        RAGPipeline->>ResponseSynthesizer: stream(retrieval_result)
        ResponseSynthesizer->>LLMModel: stream(messages)
        loop For each token
            LLMModel-->>ResponseSynthesizer: yield token
            ResponseSynthesizer-->>RAGPipeline: yield token
            RAGPipeline-->>Chat: yield token
            Chat-->>API (query endpoint): yield token
            API (query endpoint)-->>Client: yield token (SSE)
        end
    else stream == false
        Chat->>RAGPipeline: __acall__(chat_request)
        RAGPipeline->>ResponseSynthesizer: __call__(retrieval_result)
        ResponseSynthesizer->>LLMModel: create(messages)
        LLMModel-->>ResponseSynthesizer: full response
        ResponseSynthesizer-->>RAGPipeline: full response
        RAGPipeline-->>Chat: full response
        Chat-->>API (query endpoint): full response
        API (query endpoint)-->>Client: full response
    end
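
A self-contained sketch of the branch the diagram describes, with a stub token generator standing in for the real Chat/RAGPipeline wiring; the route path, payload fields, and stub names here are illustrative, not the PR's exact code:

# Sketch of the stream / non-stream branch described in the diagram above.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()

class ChatRequestSketch(BaseModel):  # stand-in for wandbot's ChatRequest schema
    question: str
    stream: bool = False  # the new flag added by this PR

async def fake_astream(question: str):
    # Stand-in for RAGPipeline/LLM streaming: yields tokens one at a time.
    for token in ["Streaming", " ", "answer", " ", "for: ", question]:
        yield token

@app.post("/query")
async def query(req: ChatRequestSketch):
    if req.stream:
        async def event_gen():
            async for token in fake_astream(req.question):
                yield f"data: {token}\n\n"  # SSE framing, matching the PR's endpoint
        return StreamingResponse(event_gen(), media_type="text/event-stream")
    # Non-streaming path: collect the full answer before responding.
    answer = "".join([t async for t in fake_astream(req.question)])
    return {"answer": answer}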

Poem

In the warren where tokens stream,
Rabbits code a brand new dream—
Responses flow, not all at once,
But hop by hop, with every nonce.
Now chats arrive in gentle parts,
Streaming answers, winning hearts!
🐇✨



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🔭 Outside diff range comments (1)
src/wandbot/models/llm.py (1)

439-442: Add error handling to the wrapper stream method.

The wrapper method should handle errors consistently with the create method.

     async def stream(self, messages: List[Dict[str, Any]], **kwargs):
-        async for token in self.model.stream(messages=messages, **kwargs):
-            yield token
+        try:
+            async for token in self.model.stream(messages=messages, **kwargs):
+                yield token
+        except Exception as e:
+            logger.error(f"LLMModel streaming error: {str(e)}")
+            # Fall back to non-streaming
+            result, _ = await self.create(messages=messages, **kwargs)
+            if result:
+                yield result
🧹 Nitpick comments (2)
src/wandbot/api/routers/chat.py (1)

48-54: Consider adding error handling and content-type headers for better streaming experience.

The streaming implementation is functional but could benefit from enhanced error handling and proper headers.

 if chat_req.stream:
     async def event_gen():
-        async for token in chat_instance.astream(chat_req):
-            yield f"data: {token}\n\n"
+        try:
+            async for token in chat_instance.astream(chat_req):
+                yield f"data: {token}\n\n"
+        except Exception as e:
+            logger.error(f"Error during streaming: {e}")
+            yield f"data: [ERROR] {str(e)}\n\n"
+        finally:
+            yield "data: [DONE]\n\n"
     
-    return StreamingResponse(event_gen(), media_type="text/event-stream")
+    return StreamingResponse(
+        event_gen(), 
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "Access-Control-Allow-Origin": "*",
+        }
+    )
src/wandbot/models/llm.py (1)

386-390: Consider implementing true streaming for Google GenAI.

The current implementation doesn't actually stream tokens incrementally - it just yields the full result at once. This defeats the purpose of streaming.

Consider investigating if Google GenAI supports streaming and implement it properly, or document why streaming isn't supported:

     async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
+        # TODO: Implement actual streaming when Google GenAI supports it
+        # For now, fall back to full response yielding
         result, _ = await self.create(messages=messages, max_tokens=max_tokens)
         if result:
             yield result
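
If the installed Google GenAI SDK exposes streaming (recent google-genai releases document an async generate_content_stream call; whether the client and version used here do is an assumption, not something this PR confirms), a true streaming implementation might look roughly like this sketch:

# Hypothetical sketch only: assumes a google-genai client with an async
# generate_content_stream method; the prompt assembly below is also simplified.
async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
    contents = "\n".join(str(m.get("content", "")) for m in messages)
    async with self.semaphore:
        response = await self.client.aio.models.generate_content_stream(
            model=self.model_name,
            contents=contents,
        )
        async for chunk in response:
            if getattr(chunk, "text", None):
                yield chunk.text
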
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 62692d8 and 249800c.

📒 Files selected for processing (7)
  • src/wandbot/api/routers/chat.py (2 hunks)
  • src/wandbot/chat/chat.py (1 hunks)
  • src/wandbot/chat/rag.py (1 hunks)
  • src/wandbot/chat/schemas.py (1 hunks)
  • src/wandbot/models/llm.py (6 hunks)
  • src/wandbot/rag/response_synthesis.py (1 hunks)
  • tests/test_stream.py (1 hunks)
🔇 Additional comments (6)
src/wandbot/chat/schemas.py (1)

49-49: LGTM! Clean schema extension for streaming support.

The addition of the stream field with a sensible default value maintains backward compatibility while enabling the new streaming functionality.
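
For context, the change is essentially a one-line Pydantic field addition; a minimal reproduction of the pattern (the non-stream fields shown are taken from the usage elsewhere in this PR, but this is a stand-in, not the actual schema file):

# Minimal reproduction of the schema pattern: an opt-in boolean with a backward-compatible default.
from typing import List, Optional, Tuple
from pydantic import BaseModel

class ChatRequestSketch(BaseModel):  # stand-in for wandbot's ChatRequest
    question: str
    chat_history: Optional[List[Tuple[str, str]]] = None
    application: Optional[str] = None
    language: str = "en"
    stream: bool = False  # new in this PR; clients that omit it keep non-streaming behavior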

src/wandbot/api/routers/chat.py (2)

2-2: LGTM! Appropriate import for streaming functionality.


40-46: LGTM! Clean ChatRequest construction with streaming support.

The updated ChatRequest construction properly includes the new stream field from the request payload.

tests/test_stream.py (1)

1-39: Streaming test coverage is solid; OpenAI dependency is already specified

The openai module is declared in both requirements.txt and pyproject.toml, so no dependency changes are needed.

Consider extending the test suite to cover the following (a sketch of the error-handling case follows this list):

  • Error handling scenarios in the streaming flow
  • Integration with the full Chat client rather than a bare ResponseSynthesizer
  • Language translation or other transformations during streaming
  • End-to-end API endpoint streaming behavior
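
A sketch of the first suggested case, in the style of tests/test_stream.py but with stub models and monkeypatched helpers; the use of ResponseSynthesizer.__new__ to skip real client construction, and the stubbed _format_input/get_messages signatures, are assumptions for illustration:

# Hypothetical test sketch: stub primary/fallback models to exercise the fallback path.
# Requires pytest-asyncio, like the existing async streaming test.
import pytest

class FailingModel:
    model_name = "primary-stub"
    async def stream(self, messages):
        raise RuntimeError("primary model down")
        yield  # unreachable; makes this an async generator

class WorkingModel:
    model_name = "fallback-stub"
    async def stream(self, messages):
        for token in ["hello", " ", "world"]:
            yield token

@pytest.mark.asyncio
async def test_stream_falls_back_on_primary_error(monkeypatch):
    from wandbot.rag.response_synthesis import ResponseSynthesizer

    synth = ResponseSynthesizer.__new__(ResponseSynthesizer)  # skip __init__ / API clients
    synth.model = FailingModel()
    synth.fallback_model = WorkingModel()
    # Stub the formatting helpers; the real signatures may differ.
    monkeypatch.setattr(synth, "_format_input", lambda inputs: {"query_str": "q", "context_str": "c"}, raising=False)
    monkeypatch.setattr(synth, "get_messages", lambda formatted: [{"role": "user", "content": "q"}], raising=False)

    tokens = [t async for t in synth.stream(inputs=None)]
    assert "".join(tokens) == "hello world"
    assert synth.stream_output["response"] == "hello world"
    assert synth.stream_output["response_model"] == "fallback-stub"
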
src/wandbot/chat/rag.py (1)

159-159: stream_output attribute is correctly defined

After inspecting src/wandbot/rag/response_synthesis.py, the ResponseSynthesizer class includes:

  • self.stream_output = { "query_str":…, "context_str":…, "response": result }

Furthermore, tests/test_stream.py validates its usage (assert synth.stream_output["response"] == "hello world"). No changes needed.

src/wandbot/models/llm.py (1)

115-119: LGTM! Good fallback implementation.

The base class stream method provides a reasonable fallback by calling the existing create method and yielding the result.
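
For readers skimming the diff, the fallback pattern described here is roughly the following (a paraphrased sketch, not the verbatim lines 115-119):

# Sketch of the base-class fallback: no native streaming, so emit the full create() result once.
async def stream(self, messages, **kwargs):
    result, _ = await self.create(messages=messages, **kwargs)
    if result:
        yield result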

Comment on lines +237 to +268
    async def astream(self, chat_request: ChatRequest):
        """Stream the chat response tokens asynchronously."""
        original_language = chat_request.language

        working_request = chat_request

        if original_language == "ja":
            translated_question = translate_ja_to_en(
                chat_request.question, self.chat_config.ja_translation_model_name
            )
            working_request = ChatRequest(
                question=translated_question,
                chat_history=chat_request.chat_history,
                application=chat_request.application,
                language="en",
            )

        async for token in self.rag_pipeline.astream(
            working_request.question, working_request.chat_history or []
        ):
            yield token

        result = self.rag_pipeline.stream_result
        result_dict = result.model_dump()

        if original_language == "ja":
            result_dict["answer"] = translate_en_to_ja(
                result_dict["answer"], self.chat_config.ja_translation_model_name
            )

        result_dict.update({"application": chat_request.application})
        self.last_stream_response = ChatResponse(**result_dict)

🛠️ Refactor suggestion

Enhance error handling and consistency with the main __acall__ method.

The streaming implementation lacks the comprehensive error handling, timing, and API status tracking present in __acall__. This could result in inconsistent behavior and poor error visibility.

Consider these improvements:

  1. Add proper exception handling with ErrorInfo tracking
  2. Include timing information using Timer
  3. Add API status tracking similar to __acall__
  4. Handle translation errors gracefully with fallback responses
 async def astream(self, chat_request: ChatRequest):
     """Stream the chat response tokens asynchronously."""
     original_language = chat_request.language
+    api_call_statuses = {}
     
     working_request = chat_request
+    
+    with Timer() as timer:
+        try:
+            # Handle Japanese translation with error handling
+            if original_language == "ja":
+                try:
+                    translated_question = translate_ja_to_en(
+                        chat_request.question, self.chat_config.ja_translation_model_name
+                    )
+                    working_request = ChatRequest(
+                        question=translated_question,
+                        chat_history=chat_request.chat_history,
+                        application=chat_request.application,
+                        language="en",
+                    )
+                except Exception as e:
+                    # Handle translation error similar to __acall__
+                    api_call_statuses["chat_success"] = False
+                    api_call_statuses["chat_error_info"] = ErrorInfo(
+                        has_error=True, error_message=str(e), 
+                        error_type=type(e).__name__, component="translation"
+                    ).model_dump()
+                    yield f"Translation error: {str(e)}"
+                    return
+                    
+            async for token in self.rag_pipeline.astream(
+                working_request.question, working_request.chat_history or []
+            ):
+                yield token
+                
+        except Exception as e:
+            # Handle streaming errors
+            api_call_statuses["chat_success"] = False
+            api_call_statuses["chat_error_info"] = ErrorInfo(
+                has_error=True, error_message=str(e),
+                error_type=type(e).__name__, component="chat"
+            ).model_dump()
+            yield f"Streaming error: {str(e)}"
+            return
+            
+        # Store final result with complete metadata
+        result = self.rag_pipeline.stream_result
+        result_dict = result.model_dump()
+        
+        # Handle response translation
+        if original_language == "ja":
+            try:
+                result_dict["answer"] = translate_en_to_ja(
+                    result_dict["answer"], self.chat_config.ja_translation_model_name
+                )
+            except Exception as e:
+                result_dict["answer"] = f"Translation error: {str(e)}\nOriginal answer: {result_dict['answer']}"
+                
+        # Update with complete metadata
+        api_call_statuses["chat_success"] = True
+        result_dict.update({
+            "application": chat_request.application,
+            "api_call_statuses": api_call_statuses,
+            "time_taken": timer.elapsed,
+            "start_time": timer.start,
+            "end_time": timer.stop,
+        })
+        
+        self.last_stream_response = ChatResponse(**result_dict)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/wandbot/chat/chat.py lines 237 to 268, the astream method lacks the
robust error handling, timing, and API status tracking found in the __acall__
method, leading to inconsistent behavior and poor error visibility. To fix this,
wrap the streaming logic in try-except blocks to catch exceptions and record
them using ErrorInfo, use a Timer to measure execution time, and update API
status accordingly. Also, handle translation errors gracefully by providing
fallback responses instead of failing outright, ensuring consistent and reliable
streaming behavior.
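
Separately from the error handling, note that astream only stores the final ChatResponse on self.last_stream_response once the generator is exhausted. A hedged sketch of how a caller (for example, the API router) might pick it up after streaming finishes; the print call stands in for whatever logging or persistence the service uses:

# Illustrative consumer: drain the token stream, then read the stored ChatResponse.
async def event_gen(chat_instance, chat_req):
    async for token in chat_instance.astream(chat_req):
        yield f"data: {token}\n\n"
    # After the generator completes, astream has populated last_stream_response.
    final = getattr(chat_instance, "last_stream_response", None)
    if final is not None:
        # e.g. persist or log the full response metadata (placeholder call)
        print(final.model_dump())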

Comment on lines +230 to +258
    async def stream(self, inputs: RetrievalResult):
        """Stream response tokens while capturing the final result."""
        formatted_input = self._format_input(inputs)
        messages = self.get_messages(formatted_input)

        result = ""
        used_model = self.model
        try:
            async for token in self.model.stream(messages=messages):
                result += token
                yield token
        except Exception as e:
            logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
            used_model = self.fallback_model
            async for token in self.fallback_model.stream(messages=messages):
                result += token
                yield token

        self.stream_output = {
            "query_str": formatted_input["query_str"],
            "context_str": formatted_input["context_str"],
            "response": result,
            "response_model": used_model.model_name,
            "response_synthesis_llm_messages": messages,
            "response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
            "api_statuses": {
                "response_synthesis_llm_api": None
            },
        }

🛠️ Refactor suggestion

Improve error handling and API status tracking in the streaming method.

The streaming implementation has several issues that could affect reliability and debugging:

  1. API status is set to None instead of proper tracking
  2. Fallback error handling doesn't reset the result accumulator
  3. No exception handling for the fallback model failure
  4. Missing comprehensive error information
 async def stream(self, inputs: RetrievalResult):
     """Stream response tokens while capturing the final result."""
     formatted_input = self._format_input(inputs)
     messages = self.get_messages(formatted_input)
     
     result = ""
     used_model = self.model
+    llm_api_status = None
+    
     try:
         async for token in self.model.stream(messages=messages):
             result += token
             yield token
+        llm_api_status = APIStatus(success=True, error_info=None)
     except Exception as e:
         logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
         used_model = self.fallback_model
-        async for token in self.fallback_model.stream(messages=messages):
-            result += token
-            yield token
+        result = ""  # Reset accumulator for fallback
+        try:
+            async for token in self.fallback_model.stream(messages=messages):
+                result += token
+                yield token
+            llm_api_status = APIStatus(success=True, error_info=None)
+        except Exception as fallback_e:
+            logger.error(f"Both primary and fallback Response Synthesizer models failed: {str(fallback_e)}")
+            llm_api_status = APIStatus(success=False, error_info=str(fallback_e))
+            raise Exception(f"Response synthesis streaming failed: {str(fallback_e)}")
     
     self.stream_output = {
         "query_str": formatted_input["query_str"],
         "context_str": formatted_input["context_str"],
         "response": result,
         "response_model": used_model.model_name,
         "response_synthesis_llm_messages": messages,
         "response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
         "api_statuses": {
-            "response_synthesis_llm_api": None
+            "response_synthesis_llm_api": llm_api_status
         },
     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Before:

    async def stream(self, inputs: RetrievalResult):
        """Stream response tokens while capturing the final result."""
        formatted_input = self._format_input(inputs)
        messages = self.get_messages(formatted_input)
        result = ""
        used_model = self.model
        try:
            async for token in self.model.stream(messages=messages):
                result += token
                yield token
        except Exception as e:
            logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
            used_model = self.fallback_model
            async for token in self.fallback_model.stream(messages=messages):
                result += token
                yield token
        self.stream_output = {
            "query_str": formatted_input["query_str"],
            "context_str": formatted_input["context_str"],
            "response": result,
            "response_model": used_model.model_name,
            "response_synthesis_llm_messages": messages,
            "response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
            "api_statuses": {
                "response_synthesis_llm_api": None
            },
        }

After:

    async def stream(self, inputs: RetrievalResult):
        """Stream response tokens while capturing the final result."""
        formatted_input = self._format_input(inputs)
        messages = self.get_messages(formatted_input)
        result = ""
        used_model = self.model
        llm_api_status = None
        try:
            async for token in self.model.stream(messages=messages):
                result += token
                yield token
            llm_api_status = APIStatus(success=True, error_info=None)
        except Exception as e:
            logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
            used_model = self.fallback_model
            result = ""  # Reset accumulator for fallback
            try:
                async for token in self.fallback_model.stream(messages=messages):
                    result += token
                    yield token
                llm_api_status = APIStatus(success=True, error_info=None)
            except Exception as fallback_e:
                logger.error(f"Both primary and fallback Response Synthesizer models failed: {str(fallback_e)}")
                llm_api_status = APIStatus(success=False, error_info=str(fallback_e))
                raise Exception(f"Response synthesis streaming failed: {str(fallback_e)}")
        self.stream_output = {
            "query_str": formatted_input["query_str"],
            "context_str": formatted_input["context_str"],
            "response": result,
            "response_model": used_model.model_name,
            "response_synthesis_llm_messages": messages,
            "response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
            "api_statuses": {
                "response_synthesis_llm_api": llm_api_status
            },
        }
🤖 Prompt for AI Agents
In src/wandbot/rag/response_synthesis.py lines 230 to 258, improve the stream
method by properly tracking API call statuses instead of setting them to None,
reset the result accumulator before using the fallback model to avoid mixing
outputs, add exception handling around the fallback model streaming to catch and
log any errors, and include detailed error information in the logs to aid
debugging and reliability.
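
The suggested diff above references an APIStatus object with success and error_info fields; rag.py already reads .success and .error_info off the retrieval api_statuses, so an equivalent model presumably exists in wandbot's schemas. If it does not, a minimal stand-in could look like this (the field types, especially error_info as a plain string, are assumptions based on how the diff uses them):

# Minimal stand-in for the APIStatus object referenced above.
from typing import Optional
from pydantic import BaseModel

class APIStatus(BaseModel):
    component: Optional[str] = None    # which API the status refers to (optional)
    success: bool = True               # matches .success accessed in rag.py
    error_info: Optional[str] = None   # matches .error_info accessed in rag.py; the real type may be richer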

Comment on lines +161 to +186
        self.stream_result = RAGPipelineOutput(
            question=enhanced_query["standalone_query"],
            answer=response["response"],
            sources="\n".join(
                [doc.metadata["source"] for doc in retrieval_result.documents]
            ),
            source_documents=response["context_str"],
            system_prompt=response["response_prompt"],
            model=response["response_model"],
            total_tokens=0,
            prompt_tokens=0,
            completion_tokens=0,
            time_taken=0,
            start_time=datetime.datetime.now(),
            end_time=datetime.datetime.now(),
            api_call_statuses={
                "web_search_success": retrieval_result.retrieval_info["api_statuses"]["web_search_api"].success,
                "reranker_api_error_info": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].error_info,
                "reranker_api_success": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].success,
                "query_enhancer_llm_api_error_info": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).error_info if enhanced_query.get("api_statuses") else None,
                "query_enhancer_llm_api_success": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).success if enhanced_query.get("api_statuses") else False,
                "embedding_api_error_info": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].error_info,
                "embedding_api_success": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].success,
            },
            response_synthesis_llm_messages=response.get("response_synthesis_llm_messages"),
        )

🛠️ Refactor suggestion

Eliminate code duplication and improve error handling.

The RAGPipelineOutput construction logic is duplicated between __acall__ and astream methods. Consider extracting this into a helper method and adding proper error handling.

+    def _build_pipeline_output(self, enhanced_query, response, retrieval_result) -> RAGPipelineOutput:
+        """Helper method to build RAGPipelineOutput consistently."""
+        return RAGPipelineOutput(
+            question=enhanced_query["standalone_query"],
+            answer=response["response"],
+            sources="\n".join(
+                [doc.metadata["source"] for doc in retrieval_result.documents]
+            ),
+            source_documents=response["context_str"],
+            system_prompt=response["response_prompt"],
+            model=response["response_model"],
+            total_tokens=0,
+            prompt_tokens=0,
+            completion_tokens=0,
+            time_taken=0,
+            start_time=datetime.datetime.now(),
+            end_time=datetime.datetime.now(),
+            api_call_statuses={
+                "web_search_success": retrieval_result.retrieval_info["api_statuses"]["web_search_api"].success,
+                "reranker_api_error_info": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].error_info,
+                "reranker_api_success": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].success,
+                "query_enhancer_llm_api_error_info": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).error_info if enhanced_query.get("api_statuses") else None,
+                "query_enhancer_llm_api_success": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).success if enhanced_query.get("api_statuses") else False,
+                "embedding_api_error_info": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].error_info,
+                "embedding_api_success": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].success,
+            },
+            response_synthesis_llm_messages=response.get("response_synthesis_llm_messages"),
+        )

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/wandbot/chat/rag.py around lines 161 to 186, the construction of the
RAGPipelineOutput object is duplicated in both __acall__ and astream methods. To
fix this, extract the RAGPipelineOutput creation logic into a separate helper
method that takes the necessary inputs and returns the constructed object. Then
replace the duplicated code in both methods with calls to this helper.
Additionally, add proper error handling within this helper to catch and manage
any exceptions during the construction process.

Comment on lines +146 to +148
    async def astream(
        self, question: str, chat_history: List[Tuple[str, str]] | None = None
    ) -> None:

⚠️ Potential issue

Fix the return type annotation for the async generator.

The method yields tokens but is annotated as returning None. This should be an async generator type.

-    async def astream(
-        self, question: str, chat_history: List[Tuple[str, str]] | None = None
-    ) -> None:
+    async def astream(
+        self, question: str, chat_history: List[Tuple[str, str]] | None = None
+    ) -> AsyncGenerator[str, None]:

You'll also need to add the import:

-from typing import Dict, List, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Before:

    from typing import Dict, List, Tuple

    async def astream(
        self, question: str, chat_history: List[Tuple[str, str]] | None = None
    ) -> None:

After:

    from typing import AsyncGenerator, Dict, List, Tuple

    async def astream(
        self, question: str, chat_history: List[Tuple[str, str]] | None = None
    ) -> AsyncGenerator[str, None]:
🤖 Prompt for AI Agents
In src/wandbot/chat/rag.py around lines 146 to 148, the async method astream is
currently annotated as returning None, but it actually yields tokens as an async
generator. Change the return type annotation to an appropriate async generator
type that reflects the yielded token type. Also, add the necessary import for
the async generator type from the typing module to support this annotation.

Comment on lines +188 to +207
    async def stream(self, messages: List[Dict[str, Any]]):
        api_params = {
            "model": self.model_name,
            "temperature": self.temperature,
            "messages": messages,
            "stream": True,
        }
        if api_params["temperature"] == 0:
            api_params["temperature"] = 0.1

        if self.model_name.startswith("o"):
            api_params.pop("temperature", None)

        async with self.semaphore:
            response = await self.client.chat.completions.create(**api_params)
            async for chunk in response:
                delta = chunk.choices[0].delta.content
                if delta:
                    yield delta


⚠️ Potential issue

Add response_model support and error handling to OpenAI streaming.

The streaming implementation doesn't handle the response_model case and lacks error handling compared to the create method.

     async def stream(self, messages: List[Dict[str, Any]]):
+        if self.response_model:
+            # For structured outputs, fall back to non-streaming
+            result, _ = await self.create(messages=messages)
+            if result:
+                yield result
+            return
+            
+        try:
             api_params = {
                 "model": self.model_name,
                 "temperature": self.temperature,
                 "messages": messages,
                 "stream": True,
             }
             if api_params["temperature"] == 0:
                 api_params["temperature"] = 0.1
     
             if self.model_name.startswith("o"):
                 api_params.pop("temperature", None)
     
             async with self.semaphore:
                 response = await self.client.chat.completions.create(**api_params)
                 async for chunk in response:
                     delta = chunk.choices[0].delta.content
                     if delta:
                         yield delta
+        except Exception as e:
+            logger.error(f"OpenAI streaming error: {str(e)}")
+            # Fall back to non-streaming
+            result, _ = await self.create(messages=messages)
+            if result:
+                yield result
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Before:

    async def stream(self, messages: List[Dict[str, Any]]):
        api_params = {
            "model": self.model_name,
            "temperature": self.temperature,
            "messages": messages,
            "stream": True,
        }
        if api_params["temperature"] == 0:
            api_params["temperature"] = 0.1
        if self.model_name.startswith("o"):
            api_params.pop("temperature", None)
        async with self.semaphore:
            response = await self.client.chat.completions.create(**api_params)
            async for chunk in response:
                delta = chunk.choices[0].delta.content
                if delta:
                    yield delta

After:

    async def stream(self, messages: List[Dict[str, Any]]):
        # If a response_model is configured, fall back to non-streaming
        if self.response_model:
            result, _ = await self.create(messages=messages)
            if result:
                yield result
            return
        try:
            api_params = {
                "model": self.model_name,
                "temperature": self.temperature,
                "messages": messages,
                "stream": True,
            }
            if api_params["temperature"] == 0:
                api_params["temperature"] = 0.1
            if self.model_name.startswith("o"):
                api_params.pop("temperature", None)
            async with self.semaphore:
                response = await self.client.chat.completions.create(**api_params)
                async for chunk in response:
                    delta = chunk.choices[0].delta.content
                    if delta:
                        yield delta
        except Exception as e:
            logger.error(f"OpenAI streaming error: {e}")
            # Fall back to non-streaming on error
            result, _ = await self.create(messages=messages)
            if result:
                yield result
🤖 Prompt for AI Agents
In src/wandbot/models/llm.py around lines 188 to 207, the stream method does not handle a configured
response_model and lacks the error handling present in the create method. Update the stream method to
check self.response_model and fall back to the non-streaming create call for structured outputs (as in
the suggestion above), and wrap the streaming call in try-except blocks so exceptions are logged and
handled gracefully, falling back to a non-streaming response on error.

Comment on lines +272 to +307
    async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
        processed_messages = []
        for msg in messages:
            if msg.get("role") == "developer":
                processed_messages.append({"role": "system", "content": msg.get("content")})
            else:
                processed_messages.append(msg)

        system_msg, chat_messages = extract_system_and_messages(processed_messages)
        api_params = {
            "model": self.model_name,
            "temperature": self.temperature,
            "messages": chat_messages,
            "max_tokens": max_tokens,
            "stream": True,
        }

        if api_params["temperature"] == 0:
            api_params["temperature"] = 0.1

        if system_msg:
            api_params["system"] = system_msg

        if self.response_model:
            api_params["messages"] += add_json_response_model_to_messages(self.response_model)

        async with self.semaphore:
            response = await self.client.messages.create(**api_params)
            async for chunk in response:
                try:
                    delta = chunk.delta.text  # type: ignore
                except AttributeError:
                    delta = ""
                if delta:
                    yield delta


🛠️ Refactor suggestion

Eliminate code duplication and improve error handling.

The message preprocessing logic is duplicated from the create method. Consider extracting this into a helper method and adding proper error handling.

+    def _preprocess_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Helper method to preprocess messages for Anthropic API."""
+        processed_messages = []
+        for msg in messages:
+            if msg.get("role") == "developer":
+                processed_messages.append({"role": "system", "content": msg.get("content")})
+                logger.debug("Converted 'developer' role to 'system' for Anthropic call.")
+            else:
+                processed_messages.append(msg)
+        return processed_messages
+
     async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
-        processed_messages = []
-        for msg in messages:
-            if msg.get("role") == "developer":
-                processed_messages.append({"role": "system", "content": msg.get("content")})
-            else:
-                processed_messages.append(msg)
+        try:
+            processed_messages = self._preprocess_messages(messages)
+            # ... rest of the implementation
+        except Exception as e:
+            logger.error(f"Anthropic streaming error: {str(e)}")
+            # Fall back to non-streaming
+            result, _ = await self.create(messages=messages, max_tokens=max_tokens)
+            if result:
+                yield result

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/wandbot/models/llm.py between lines 272 and 307, the message
preprocessing logic duplicates code from the create method and lacks robust
error handling. Refactor by extracting the message preprocessing into a separate
helper method that both stream and create methods can call. Add try-except
blocks around the preprocessing steps to catch and handle potential errors
gracefully, ensuring the stream method yields results only when preprocessing
succeeds.
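
Putting the suggested helper together with the streaming body from this PR, the refactored method might look roughly like the following sketch (inside the Anthropic model class in llm.py; the try/except fallback to create() mirrors the pattern suggested for the other providers and is an illustration, not the committed code):

    def _preprocess_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Shared preprocessing for create() and stream(): map the 'developer' role to 'system'."""
        processed_messages = []
        for msg in messages:
            if msg.get("role") == "developer":
                processed_messages.append({"role": "system", "content": msg.get("content")})
            else:
                processed_messages.append(msg)
        return processed_messages

    async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
        try:
            processed_messages = self._preprocess_messages(messages)
            system_msg, chat_messages = extract_system_and_messages(processed_messages)
            api_params = {
                "model": self.model_name,
                "temperature": self.temperature,
                "messages": chat_messages,
                "max_tokens": max_tokens,
                "stream": True,
            }
            if api_params["temperature"] == 0:
                api_params["temperature"] = 0.1
            if system_msg:
                api_params["system"] = system_msg
            if self.response_model:
                api_params["messages"] += add_json_response_model_to_messages(self.response_model)

            async with self.semaphore:
                response = await self.client.messages.create(**api_params)
                async for chunk in response:
                    # Only content deltas carry text; other event types yield nothing.
                    delta = getattr(getattr(chunk, "delta", None), "text", None)
                    if delta:
                        yield delta
        except Exception as e:
            logger.error(f"Anthropic streaming error: {str(e)}")
            # Fall back to the non-streaming create() path, as suggested above.
            result, _ = await self.create(messages=messages, max_tokens=max_tokens)
            if result:
                yield result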
