@@ -54,7 +54,7 @@ def __init__(
             max_tokens: The maximum number of tokens to generate per response.
             cache: Whether to cache the model responses for reuse to improve performance
                 and reduce costs.
-            cache_in_memory: To enable additional caching with LRU in memory.
+            cache_in_memory (deprecated): Whether to additionally cache responses in an in-memory LRU cache.
             callbacks: A list of callback functions to run before and after each request.
             num_retries: The number of times to retry a request if it fails transiently due to
                 network error, rate limiting, etc. Requests are retried with exponential
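For context, a minimal sketch of how these constructor options fit together, assuming the LM class in this file is exposed as dspy.LM; the model string and values are hypothetical:

    import dspy

    lm = dspy.LM(
        "openai/gpt-4o-mini",  # hypothetical model string
        max_tokens=1000,
        cache=True,       # cache responses on disk for reuse
        num_retries=3,    # transient failures retried with exponential backoff
    )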
@@ -92,44 +92,69 @@ def __init__(
         else:
             self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)

+    def _get_cached_completion_fn(self, completion_fn, cache, enable_memory_cache):
+        ignored_args_for_cache_key = ["api_key", "api_base", "base_url"]
+        if cache and enable_memory_cache:
+            # Two cache layers: in-memory LRU plus the on-disk request cache.
+            completion_fn = request_cache(
+                cache_arg_name="request",
+                ignored_args_for_cache_key=ignored_args_for_cache_key,
+            )(completion_fn)
+        elif cache:
+            # Disk cache only; the in-memory layer is disabled.
+            completion_fn = request_cache(
+                cache_arg_name="request",
+                ignored_args_for_cache_key=ignored_args_for_cache_key,
+                enable_memory_cache=False,
+            )(completion_fn)
+
+        if not cache or litellm.cache is None:
+            litellm_cache_args = {"no-cache": True, "no-store": True}
+        else:
+            litellm_cache_args = {"no-cache": False, "no-store": False}
+
+        return completion_fn, litellm_cache_args
+
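The helper added above wraps completion_fn with dspy's request_cache decorator only when caching is enabled. A minimal, self-contained sketch of that conditional-decoration pattern, using functools.lru_cache as a stand-in for request_cache:

    import functools

    def maybe_cache(fn, enabled: bool):
        # Decorate only when enabled; otherwise return fn unchanged,
        # mirroring the cache / enable_memory_cache branches above.
        if enabled:
            return functools.lru_cache(maxsize=None)(fn)
        return fn

    def square(x: int) -> int:
        return x * x

    cached_square = maybe_cache(square, enabled=True)
    assert cached_square(4) == 16  # a repeat call with 4 is served from cache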
     def forward(self, prompt=None, messages=None, **kwargs):
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
-        # disable cache will also disable in memory cache
-        cache_in_memory = cache and kwargs.pop("cache_in_memory", self.cache_in_memory)
+        enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)
+
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}

-        # Make the request and handle LRU & disk caching.
-        if cache_in_memory:
-            completion = cached_litellm_completion if self.model_type == "chat" else cached_litellm_text_completion
-
-            results = completion(
-                request=dict(model=self.model, messages=messages, **kwargs),
-                num_retries=self.num_retries,
-            )
-        else:
-            completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
+        completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
+        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)

-            results = completion(
-                request=dict(model=self.model, messages=messages, **kwargs),
-                num_retries=self.num_retries,
-                # only leverage LiteLLM cache in this case
-                cache={"no-cache": not cache, "no-store": not cache},
-            )
+        results = completion(
+            request=dict(model=self.model, messages=messages, **kwargs),
+            num_retries=self.num_retries,
+            cache=litellm_cache_args,
+        )

         if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
             settings.usage_tracker.add_usage(self.model, dict(results.usage))
         return results

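Since forward pops "cache" from kwargs before merging defaults, caching can be toggled per request. A usage sketch, assuming the lm instance from the earlier example:

    # Bypass both cache layers for this one request: the request_cache wrapper
    # is skipped and LiteLLM receives {"no-cache": True, "no-store": True}.
    result = lm.forward(prompt="What is 2 + 2?", cache=False)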
     async def aforward(self, prompt=None, messages=None, **kwargs):
-        completion = alitellm_completion if self.model_type == "chat" else alitellm_text_completion
+        # Build the request.
+        cache = kwargs.pop("cache", self.cache)
+        enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)

         messages = messages or [{"role": "user", "content": prompt}]
+        kwargs = {**self.kwargs, **kwargs}
+
+        completion = alitellm_completion if self.model_type == "chat" else alitellm_text_completion
+        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)
+
         results = await completion(
             request=dict(model=self.model, messages=messages, **kwargs),
             num_retries=self.num_retries,
+            cache=litellm_cache_args,
         )
+
+        if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
+            settings.usage_tracker.add_usage(self.model, dict(results.usage))
         return results

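aforward now mirrors forward, including the caching and usage-tracking paths, but awaits the async LiteLLM client. A usage sketch, again assuming the lm instance from above:

    import asyncio

    async def main():
        result = await lm.aforward(prompt="What is 2 + 2?")
        return result

    asyncio.run(main())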
     def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
@@ -206,22 +231,6 @@ def dump_state(self):
         return {key: getattr(self, key) for key in state_keys} | self.kwargs


-@request_cache(cache_arg_name="request", ignored_args_for_cache_key=["api_key", "api_base", "base_url"])
-def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
-    import litellm
-
-    if litellm.cache:
-        litellm_cache_args = {"no-cache": False, "no-store": False}
-    else:
-        litellm_cache_args = {"no-cache": True, "no-store": True}
-
-    return litellm_completion(
-        request,
-        cache=litellm_cache_args,
-        num_retries=num_retries,
-    )
-
-
 def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     retry_kwargs = dict(
         retry_policy=_get_litellm_retry_policy(num_retries),
@@ -267,22 +276,6 @@ async def stream_completion():
     return stream_completion()


-@request_cache(cache_arg_name="request", ignored_args_for_cache_key=["api_key", "api_base", "base_url"])
-def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
-    import litellm
-
-    if litellm.cache:
-        litellm_cache_args = {"no-cache": False, "no-store": False}
-    else:
-        litellm_cache_args = {"no-cache": True, "no-store": True}
-
-    return litellm_text_completion(
-        request,
-        num_retries=num_retries,
-        cache=litellm_cache_args,
-    )
-
-
 def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     # Extract the provider and model from the model string.
     # TODO: Not all the models are in the format of "provider/model".
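A sketch of the provider/model split that the comment and TODO above refer to; the model string is hypothetical:

    model = "openai/gpt-4o-mini"
    provider, _, model_name = model.partition("/")  # -> "openai", "gpt-4o-mini"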