-
Notifications
You must be signed in to change notification settings - Fork 56
Implement streaming endpoint #99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -141,4 +141,46 @@ async def __acall__( | |||||||||||||||||||
| def __call__( | ||||||||||||||||||||
| self, question: str, chat_history: List[Tuple[str, str]] | None = None | ||||||||||||||||||||
| ) -> RAGPipelineOutput: | ||||||||||||||||||||
| return run_sync(self.__acall__(question, chat_history)) | ||||||||||||||||||||
| return run_sync(self.__acall__(question, chat_history)) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| async def astream( | ||||||||||||||||||||
| self, question: str, chat_history: List[Tuple[str, str]] | None = None | ||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||
|
Comment on lines
+146
to
+148
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the return type annotation for the async generator. The method yields tokens but is annotated as returning - async def astream(
- self, question: str, chat_history: List[Tuple[str, str]] | None = None
- ) -> None:
+ async def astream(
+ self, question: str, chat_history: List[Tuple[str, str]] | None = None
+ ) -> AsyncGenerator[str, None]:You'll also need to add the import: -from typing import Dict, List, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||
| """Stream tokens from the response synthesizer.""" | ||||||||||||||||||||
| if chat_history is None: | ||||||||||||||||||||
| chat_history = [] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| enhanced_query = await self.query_enhancer({"query": question, "chat_history": chat_history}) | ||||||||||||||||||||
| retrieval_result = await self.retrieval_engine.__acall__(enhanced_query) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| async for token in self.response_synthesizer.stream(retrieval_result): | ||||||||||||||||||||
| yield token | ||||||||||||||||||||
|
|
||||||||||||||||||||
| response = self.response_synthesizer.stream_output | ||||||||||||||||||||
|
|
||||||||||||||||||||
| self.stream_result = RAGPipelineOutput( | ||||||||||||||||||||
| question=enhanced_query["standalone_query"], | ||||||||||||||||||||
| answer=response["response"], | ||||||||||||||||||||
| sources="\n".join( | ||||||||||||||||||||
| [doc.metadata["source"] for doc in retrieval_result.documents] | ||||||||||||||||||||
| ), | ||||||||||||||||||||
| source_documents=response["context_str"], | ||||||||||||||||||||
| system_prompt=response["response_prompt"], | ||||||||||||||||||||
| model=response["response_model"], | ||||||||||||||||||||
| total_tokens=0, | ||||||||||||||||||||
| prompt_tokens=0, | ||||||||||||||||||||
| completion_tokens=0, | ||||||||||||||||||||
| time_taken=0, | ||||||||||||||||||||
| start_time=datetime.datetime.now(), | ||||||||||||||||||||
| end_time=datetime.datetime.now(), | ||||||||||||||||||||
| api_call_statuses={ | ||||||||||||||||||||
| "web_search_success": retrieval_result.retrieval_info["api_statuses"]["web_search_api"].success, | ||||||||||||||||||||
| "reranker_api_error_info": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].error_info, | ||||||||||||||||||||
| "reranker_api_success": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].success, | ||||||||||||||||||||
| "query_enhancer_llm_api_error_info": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).error_info if enhanced_query.get("api_statuses") else None, | ||||||||||||||||||||
| "query_enhancer_llm_api_success": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).success if enhanced_query.get("api_statuses") else False, | ||||||||||||||||||||
| "embedding_api_error_info": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].error_info, | ||||||||||||||||||||
| "embedding_api_success": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].success, | ||||||||||||||||||||
| }, | ||||||||||||||||||||
| response_synthesis_llm_messages=response.get("response_synthesis_llm_messages"), | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
Comment on lines
+161
to
+186
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Eliminate code duplication and improve error handling. The + def _build_pipeline_output(self, enhanced_query, response, retrieval_result) -> RAGPipelineOutput:
+ """Helper method to build RAGPipelineOutput consistently."""
+ return RAGPipelineOutput(
+ question=enhanced_query["standalone_query"],
+ answer=response["response"],
+ sources="\n".join(
+ [doc.metadata["source"] for doc in retrieval_result.documents]
+ ),
+ source_documents=response["context_str"],
+ system_prompt=response["response_prompt"],
+ model=response["response_model"],
+ total_tokens=0,
+ prompt_tokens=0,
+ completion_tokens=0,
+ time_taken=0,
+ start_time=datetime.datetime.now(),
+ end_time=datetime.datetime.now(),
+ api_call_statuses={
+ "web_search_success": retrieval_result.retrieval_info["api_statuses"]["web_search_api"].success,
+ "reranker_api_error_info": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].error_info,
+ "reranker_api_success": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].success,
+ "query_enhancer_llm_api_error_info": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).error_info if enhanced_query.get("api_statuses") else None,
+ "query_enhancer_llm_api_success": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).success if enhanced_query.get("api_statuses") else False,
+ "embedding_api_error_info": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].error_info,
+ "embedding_api_success": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].success,
+ },
+ response_synthesis_llm_messages=response.get("response_synthesis_llm_messages"),
+ )
🤖 Prompt for AI Agents |
||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -107,11 +107,16 @@ def __init__(self, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.timeout = timeout | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.semaphore = asyncio.Semaphore(n_parallel_api_calls) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def create(self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| messages: List[Dict[str, Any]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def create(self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| messages: List[Dict[str, Any]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **kwargs) -> tuple[Union[str, BaseModel], APIStatus]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError("Subclasses must implement create method") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def stream(self, messages: List[Dict[str, Any]]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result, _ = await self.create(messages=messages) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if result: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield result | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class AsyncOpenAILLMModel(BaseLLMModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| JSON_MODELS = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "gpt-4-", # All gpt-4- models | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -180,6 +185,26 @@ async def create(self, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| error_info=error_info | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def stream(self, messages: List[Dict[str, Any]]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| api_params = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "model": self.model_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "temperature": self.temperature, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "messages": messages, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "stream": True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if api_params["temperature"] == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| api_params["temperature"] = 0.1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.model_name.startswith("o"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| api_params.pop("temperature", None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async with self.semaphore: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response = await self.client.chat.completions.create(**api_params) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async for chunk in response: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| delta = chunk.choices[0].delta.content | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if delta: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield delta | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+188
to
+207
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add response_model support and error handling to OpenAI streaming. The streaming implementation doesn't handle the async def stream(self, messages: List[Dict[str, Any]]):
+ if self.response_model:
+ # For structured outputs, fall back to non-streaming
+ result, _ = await self.create(messages=messages)
+ if result:
+ yield result
+ return
+
+ try:
api_params = {
"model": self.model_name,
"temperature": self.temperature,
"messages": messages,
"stream": True,
}
if api_params["temperature"] == 0:
api_params["temperature"] = 0.1
if self.model_name.startswith("o"):
api_params.pop("temperature", None)
async with self.semaphore:
response = await self.client.chat.completions.create(**api_params)
async for chunk in response:
delta = chunk.choices[0].delta.content
if delta:
yield delta
+ except Exception as e:
+ logger.error(f"OpenAI streaming error: {str(e)}")
+ # Fall back to non-streaming
+ result, _ = await self.create(messages=messages)
+ if result:
+ yield result📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class AsyncAnthropicLLMModel(BaseLLMModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(**kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -244,6 +269,42 @@ async def create(self, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| error_info=error_info | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| processed_messages = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for msg in messages: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if msg.get("role") == "developer": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| processed_messages.append({"role": "system", "content": msg.get("content")}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| processed_messages.append(msg) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| system_msg, chat_messages = extract_system_and_messages(processed_messages) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| api_params = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "model": self.model_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "temperature": self.temperature, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "messages": chat_messages, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "max_tokens": max_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "stream": True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if api_params["temperature"] == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| api_params["temperature"] = 0.1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if system_msg: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| api_params["system"] = system_msg | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.response_model: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| api_params["messages"] += add_json_response_model_to_messages(self.response_model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async with self.semaphore: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response = await self.client.messages.create(**api_params) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async for chunk in response: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| delta = chunk.delta.text # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except AttributeError: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| delta = "" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if delta: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield delta | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+272
to
+307
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Eliminate code duplication and improve error handling. The message preprocessing logic is duplicated from the + def _preprocess_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Helper method to preprocess messages for Anthropic API."""
+ processed_messages = []
+ for msg in messages:
+ if msg.get("role") == "developer":
+ processed_messages.append({"role": "system", "content": msg.get("content")})
+ logger.debug("Converted 'developer' role to 'system' for Anthropic call.")
+ else:
+ processed_messages.append(msg)
+ return processed_messages
+
async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000):
- processed_messages = []
- for msg in messages:
- if msg.get("role") == "developer":
- processed_messages.append({"role": "system", "content": msg.get("content")})
- else:
- processed_messages.append(msg)
+ try:
+ processed_messages = self._preprocess_messages(messages)
+ # ... rest of the implementation
+ except Exception as e:
+ logger.error(f"Anthropic streaming error: {str(e)}")
+ # Fall back to non-streaming
+ result, _ = await self.create(messages=messages, max_tokens=max_tokens)
+ if result:
+ yield result
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class AsyncGoogleGenAILLMModel(BaseLLMModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(**kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -322,6 +383,11 @@ async def create(self, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| error_info=error_info | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def stream(self, messages: List[Dict[str, Any]], max_tokens: int = 4000): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result, _ = await self.create(messages=messages, max_tokens=max_tokens) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if result: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield result | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class LLMModel: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PROVIDER_MAP = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -344,8 +410,8 @@ def __init__(self, provider: str, **kwargs): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def model_name(self) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.model.model_name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def create(self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| messages: List[Dict[str, Any]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def create(self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| messages: List[Dict[str, Any]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **kwargs) -> tuple[Union[str, BaseModel], APIStatus]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async with self.model.semaphore: # Use the specific model's semaphore | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -369,3 +435,7 @@ async def create(self, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| success=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| error_info=error_info | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def stream(self, messages: List[Dict[str, Any]], **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async for token in self.model.stream(messages=messages, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield token | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -227,6 +227,36 @@ async def __call__(self, inputs: RetrievalResult) -> Dict[str, Any]: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def stream(self, inputs: RetrievalResult): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Stream response tokens while capturing the final result.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| formatted_input = self._format_input(inputs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| messages = self.get_messages(formatted_input) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result = "" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| used_model = self.model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async for token in self.model.stream(messages=messages): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result += token | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield token | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| used_model = self.fallback_model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async for token in self.fallback_model.stream(messages=messages): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result += token | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield token | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.stream_output = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "query_str": formatted_input["query_str"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "context_str": formatted_input["context_str"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "response": result, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "response_model": used_model.model_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "response_synthesis_llm_messages": messages, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "api_statuses": { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "response_synthesis_llm_api": None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+230
to
+258
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Improve error handling and API status tracking in the streaming method. The streaming implementation has several issues that could affect reliability and debugging:
async def stream(self, inputs: RetrievalResult):
"""Stream response tokens while capturing the final result."""
formatted_input = self._format_input(inputs)
messages = self.get_messages(formatted_input)
result = ""
used_model = self.model
+ llm_api_status = None
+
try:
async for token in self.model.stream(messages=messages):
result += token
yield token
+ llm_api_status = APIStatus(success=True, error_info=None)
except Exception as e:
logger.warning(f"Primary Response Synthesizer model failed, trying fallback: {str(e)}")
used_model = self.fallback_model
- async for token in self.fallback_model.stream(messages=messages):
- result += token
- yield token
+ result = "" # Reset accumulator for fallback
+ try:
+ async for token in self.fallback_model.stream(messages=messages):
+ result += token
+ yield token
+ llm_api_status = APIStatus(success=True, error_info=None)
+ except Exception as fallback_e:
+ logger.error(f"Both primary and fallback Response Synthesizer models failed: {str(fallback_e)}")
+ llm_api_status = APIStatus(success=False, error_info=str(fallback_e))
+ raise Exception(f"Response synthesis streaming failed: {str(fallback_e)}")
self.stream_output = {
"query_str": formatted_input["query_str"],
"context_str": formatted_input["context_str"],
"response": result,
"response_model": used_model.model_name,
"response_synthesis_llm_messages": messages,
"response_prompt": RESPONSE_SYNTHESIS_SYSTEM_PROMPT,
"api_statuses": {
- "response_synthesis_llm_api": None
+ "response_synthesis_llm_api": llm_api_status
},
}📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _format_input(self, inputs: RetrievalResult) -> Dict[str, str]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Format the input data for the prompt template.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| import pytest | ||
|
|
||
| from wandbot.rag.response_synthesis import ResponseSynthesizer | ||
| from wandbot.schema.document import Document | ||
| from wandbot.schema.retrieval import RetrievalResult | ||
|
|
||
|
|
||
| async def fake_stream(*args, **kwargs): | ||
| for tok in ["hello ", "world"]: | ||
| yield tok | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_streaming_response(): | ||
| synth = ResponseSynthesizer( | ||
| primary_provider="openai", | ||
| primary_model_name="dummy", | ||
| primary_temperature=0, | ||
| fallback_provider="openai", | ||
| fallback_model_name="dummy", | ||
| fallback_temperature=0, | ||
| max_retries=1, | ||
| ) | ||
|
|
||
| synth.model.stream = fake_stream # type: ignore | ||
| synth.get_messages = lambda x: [] | ||
|
|
||
| retrieval = RetrievalResult( | ||
| documents=[Document(page_content="doc", metadata={"source": "s"})], | ||
| retrieval_info={"query": "q", "language": "en", "intents": [], "sub_queries": []}, | ||
| ) | ||
|
|
||
| tokens = [] | ||
| async for token in synth.stream(retrieval): | ||
| tokens.append(token) | ||
|
|
||
| assert "".join(tokens) == "hello world" | ||
| assert synth.stream_output["response"] == "hello world" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Enhance error handling and consistency with the main
__acall__method.The streaming implementation lacks the comprehensive error handling, timing, and API status tracking present in
__acall__. This could result in inconsistent behavior and poor error visibility.Consider these improvements:
ErrorInfotrackingTimer__acall__async def astream(self, chat_request: ChatRequest): """Stream the chat response tokens asynchronously.""" original_language = chat_request.language + api_call_statuses = {} working_request = chat_request + + with Timer() as timer: + try: + # Handle Japanese translation with error handling + if original_language == "ja": + try: + translated_question = translate_ja_to_en( + chat_request.question, self.chat_config.ja_translation_model_name + ) + working_request = ChatRequest( + question=translated_question, + chat_history=chat_request.chat_history, + application=chat_request.application, + language="en", + ) + except Exception as e: + # Handle translation error similar to __acall__ + api_call_statuses["chat_success"] = False + api_call_statuses["chat_error_info"] = ErrorInfo( + has_error=True, error_message=str(e), + error_type=type(e).__name__, component="translation" + ).model_dump() + yield f"Translation error: {str(e)}" + return + + async for token in self.rag_pipeline.astream( + working_request.question, working_request.chat_history or [] + ): + yield token + + except Exception as e: + # Handle streaming errors + api_call_statuses["chat_success"] = False + api_call_statuses["chat_error_info"] = ErrorInfo( + has_error=True, error_message=str(e), + error_type=type(e).__name__, component="chat" + ).model_dump() + yield f"Streaming error: {str(e)}" + return + + # Store final result with complete metadata + result = self.rag_pipeline.stream_result + result_dict = result.model_dump() + + # Handle response translation + if original_language == "ja": + try: + result_dict["answer"] = translate_en_to_ja( + result_dict["answer"], self.chat_config.ja_translation_model_name + ) + except Exception as e: + result_dict["answer"] = f"Translation error: {str(e)}\nOriginal answer: {result_dict['answer']}" + + # Update with complete metadata + api_call_statuses["chat_success"] = True + result_dict.update({ + "application": chat_request.application, + "api_call_statuses": api_call_statuses, + "time_taken": timer.elapsed, + "start_time": timer.start, + "end_time": timer.stop, + }) + + self.last_stream_response = ChatResponse(**result_dict)🤖 Prompt for AI Agents