1- from typing import Dict , List , Optional , Union , Any
2- import numpy as np
3- from pydantic import BaseModel , Field
4- from openai import AsyncOpenAI , ChatCompletion
5- import os
61import logging
2+ import os
73from enum import Enum
8- import asyncio
9- from redis . asyncio import ConnectionPool , Redis
4+ from typing import Any
5+
106import anthropic
7+ import numpy as np
8+ from openai import AsyncOpenAI
9+ from pydantic import BaseModel , Field
10+
1111
1212# Setup logging
1313logger = logging .getLogger (__name__ )
@@ -23,16 +23,16 @@ class MemoryMessage(BaseModel):
class MemoryMessagesAndContext(BaseModel):
    """Request payload for adding messages to memory"""

    # Conversation turns to append to the session's memory.
    messages: list[MemoryMessage]
    # Optional pre-existing summary/context carried along with the messages.
    context: str | None = None
2828
2929
class MemoryResponse(BaseModel):
    """Response containing messages and context"""

    # Messages retrieved from memory for the session.
    messages: list[MemoryMessage]
    # Summarized context, when one has been generated.
    context: str | None = None
    # Token count of the returned content, when available.
    tokens: int | None = None
3636
3737
3838class SearchPayload (BaseModel ):
@@ -59,27 +59,27 @@ class RedisearchResult(BaseModel):
5959 role : str
6060 content : str
6161 dist : float
62-
63-
62+
63+
class SearchResults(BaseModel):
    """Results from a redisearch query"""

    # Matching documents, in the order returned by the search backend.
    docs: list[RedisearchResult]
    # Total number of hits reported by the index (may exceed len(docs)
    # when results are paginated — presumably; confirm against caller).
    total: int
6969
7070
class NamespaceQuery(BaseModel):
    """Query parameters for namespace"""

    # Restrict the operation to this namespace; None selects the default.
    namespace: str | None = None
7575
7676
class GetSessionsQuery(BaseModel):
    """Query parameters for getting sessions"""

    page: int = Field(default=1)  # 1-based page index
    size: int = Field(default=20)  # number of sessions per page
    # Optional namespace filter; None means no namespace restriction.
    namespace: str | None = None
8383
8484
8585class ModelProvider (str , Enum ):
@@ -260,7 +260,7 @@ def get_model_config(model_name: str) -> ModelConfig:
260260class ChatResponse :
261261 """Unified wrapper for chat responses from different providers"""
262262
263- def __init__ (self , choices : List [Any ], usage : Dict [str , int ]):
263+ def __init__ (self , choices : list [Any ], usage : dict [str , int ]):
264264 self .choices = choices or []
265265 self .usage = usage or {"total_tokens" : 0 }
266266
@@ -319,7 +319,7 @@ async def create_chat_completion(self, model: str, prompt: str) -> ChatResponse:
319319 logger .error (f"Error creating chat completion with Anthropic: { e } " )
320320 raise
321321
322- async def create_embedding (self , query_vec : List [str ]) -> np .ndarray :
322+ async def create_embedding (self , query_vec : list [str ]) -> np .ndarray :
323323 """
324324 Create embeddings for the given texts
325325 Note: Anthropic doesn't offer an embedding API, so we'll use OpenAI's
@@ -345,22 +345,27 @@ def __init__(self, api_key: str | None = None, base_url: str | None = None):
345345
346346 if openai_api_base :
347347 self .completion_client = AsyncOpenAI (
348- api_key = openai_api_key , base_url = openai_api_base
348+ api_key = openai_api_key ,
349+ base_url = openai_api_base ,
349350 )
350351 self .embedding_client = AsyncOpenAI (
351- api_key = openai_api_key , base_url = openai_api_base
352+ api_key = openai_api_key ,
353+ base_url = openai_api_base ,
352354 )
353355 else :
354356 self .completion_client = AsyncOpenAI (api_key = openai_api_key )
355357 self .embedding_client = AsyncOpenAI (api_key = openai_api_key )
356358
357359 async def create_chat_completion (
358- self , model : str , progressive_prompt : str
360+ self ,
361+ model : str ,
362+ progressive_prompt : str ,
359363 ) -> ChatResponse :
360364 """Create a chat completion using the OpenAI API"""
361365 try :
362366 response = await self .completion_client .chat .completions .create (
363- model = model , messages = [{"role" : "user" , "content" : progressive_prompt }]
367+ model = model ,
368+ messages = [{"role" : "user" , "content" : progressive_prompt }],
364369 )
365370
366371 # Convert to unified format
@@ -380,7 +385,7 @@ async def create_chat_completion(
380385 logger .error (f"Error creating chat completion with OpenAI: { e } " )
381386 raise
382387
383- async def create_embedding (self , query_vec : List [str ]) -> np .ndarray :
388+ async def create_embedding (self , query_vec : list [str ]) -> np .ndarray :
384389 """Create embeddings for the given texts"""
385390 try :
386391 embeddings = []
@@ -391,7 +396,8 @@ async def create_embedding(self, query_vec: List[str]) -> np.ndarray:
391396 for i in range (0 , len (query_vec ), batch_size ):
392397 batch = query_vec [i : i + batch_size ]
393398 response = await self .embedding_client .embeddings .create (
394- model = embedding_model , input = batch
399+ model = embedding_model ,
400+ input = batch ,
395401 )
396402 batch_embeddings = [item .embedding for item in response .data ]
397403 embeddings .extend (batch_embeddings )
@@ -408,13 +414,13 @@ class ModelClientFactory:
408414 @staticmethod
409415 async def get_client (
410416 model_name : str ,
411- ) -> Union [ OpenAIClientWrapper , AnthropicClientWrapper ] :
417+ ) -> OpenAIClientWrapper | AnthropicClientWrapper :
412418 """Get the appropriate client for a model"""
413419 model_config = get_model_config (model_name )
414420
415421 if model_config .provider == ModelProvider .OPENAI :
416422 return OpenAIClientWrapper (api_key = os .environ .get ("OPENAI_API_KEY" ))
417- elif model_config .provider == ModelProvider .ANTHROPIC :
423+ if model_config .provider == ModelProvider .ANTHROPIC :
418424 return AnthropicClientWrapper (api_key = os .environ .get ("ANTHROPIC_API_KEY" ))
419- else :
420- raise ValueError (f"Unsupported model provider: { model_config .provider } " )
425+
426+ raise ValueError (f"Unsupported model provider: { model_config .provider } " )
0 commit comments