
Commit 61619c1

Remove LiteLLM caching from LM (#8742)
* avoid attribute error in cache get
* Update cache.py
1 parent da482fd commit 61619c1
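
In user-facing terms, DSPy now relies solely on its own two-level cache and always sends requests to LiteLLM with caching disabled. A minimal sketch of the resulting behavior, assuming the constructor keywords shown in the diffs below; the model name and prompts are illustrative:

import dspy

# `cache_in_memory` is gone from dspy.LM; `cache=True` now routes requests
# through DSPy's own in-memory + on-disk cache, and LiteLLM's cache stays off.
lm = dspy.LM("openai/gpt-4o-mini", cache=True)  # illustrative model name

first = lm("What is the capital of France?")   # real provider call
second = lm("What is the capital of France?")  # served from DSPy's cache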

File tree (6 files changed: +10 additions, -206 deletions)

dspy/clients/__init__.py
dspy/clients/cache.py
dspy/clients/lm.py
tests/clients/test_litellm_cache.py
tests/clients/test_lm.py
tests/predict/test_predict.py

dspy/clients/__init__.py

Lines changed: 1 addition & 33 deletions
@@ -3,7 +3,6 @@
 from pathlib import Path
 
 import litellm
-from litellm.caching.caching import Cache as LitellmCache
 
 from dspy.clients.base_lm import BaseLM, inspect_history
 from dspy.clients.cache import Cache
@@ -15,23 +14,12 @@
 
 DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
 DISK_CACHE_LIMIT = int(os.environ.get("DSPY_CACHE_LIMIT", 3e10))  # 30 GB default
-
-
-def _litellm_track_cache_hit_callback(kwargs, completion_response, start_time, end_time):
-    # Access the cache_hit information
-    completion_response.cache_hit = kwargs.get("cache_hit", False)
-
-
-litellm.success_callback = [_litellm_track_cache_hit_callback]
-
-
 def configure_cache(
     enable_disk_cache: bool | None = True,
     enable_memory_cache: bool | None = True,
     disk_cache_dir: str | None = DISK_CACHE_DIR,
     disk_size_limit_bytes: int | None = DISK_CACHE_LIMIT,
     memory_max_entries: int | None = 1000000,
-    enable_litellm_cache: bool = False,
 ):
     """Configure the cache for DSPy.
 
@@ -41,27 +29,7 @@ def configure_cache(
         disk_cache_dir: The directory to store the on-disk cache.
         disk_size_limit_bytes: The size limit of the on-disk cache.
         memory_max_entries: The maximum number of entries in the in-memory cache.
-        enable_litellm_cache: Whether to enable LiteLLM cache.
     """
-    if enable_disk_cache and enable_litellm_cache:
-        raise ValueError(
-            "Cannot enable both LiteLLM and DSPy on-disk cache, please set at most one of `enable_disk_cache` or "
-            "`enable_litellm_cache` to True."
-        )
-
-    if enable_litellm_cache:
-        try:
-            litellm.cache = LitellmCache(disk_cache_dir=DISK_CACHE_DIR, type="disk")
-
-            if litellm.cache.cache.disk_cache.size_limit != DISK_CACHE_LIMIT:
-                litellm.cache.cache.disk_cache.reset("size_limit", DISK_CACHE_LIMIT)
-        except Exception as e:
-            # It's possible that users don't have the write permissions to the cache directory.
-            # In that case, we'll just disable the cache.
-            logger.warning("Failed to initialize LiteLLM cache: %s", e)
-            litellm.cache = None
-    else:
-        litellm.cache = None
 
     import dspy
 
@@ -75,7 +43,7 @@ def configure_cache(
 
 
 litellm.telemetry = False
-litellm.cache = None  # By default we disable litellm cache and use DSPy on-disk cache.
+litellm.cache = None  # By default we disable LiteLLM cache and use DSPy on-disk cache.
 
 DSPY_CACHE = Cache(
     enable_disk_cache=True,
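
For reference, a minimal sketch of calling the slimmed-down configure_cache; the keyword names follow the new signature above, and the cache directory is an illustrative placeholder:

import dspy

# `enable_litellm_cache` no longer exists; only DSPy's own cache layers are configurable.
dspy.clients.configure_cache(
    enable_disk_cache=True,
    enable_memory_cache=True,
    disk_cache_dir="/tmp/dspy_disk_cache",  # illustrative path
)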

dspy/clients/cache.py

Lines changed: 1 addition & 0 deletions
@@ -118,6 +118,7 @@ def get(self, request: dict[str, Any], ignored_args_for_cache_key: list[str] | N
         if hasattr(response, "usage"):
             # Clear the usage data when cache is hit, because no LM call is made
             response.usage = {}
+        response.cache_hit = True
         return response
 
     def put(
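
This added line takes over the job of the LiteLLM success callback removed from dspy/clients/__init__.py: the cache itself now marks hits, so no attribute error can occur when the flag is read. A standalone sketch of the pattern using a hypothetical stand-in class, not the real dspy.clients.cache.Cache:

from types import SimpleNamespace

class SketchCache:
    """Illustrative stand-in for dspy.clients.cache.Cache (not the real class)."""

    def __init__(self):
        self._store = {}

    def put(self, key, response):
        self._store[key] = response

    def get(self, key):
        response = self._store.get(key)
        if response is None:
            return None
        if hasattr(response, "usage"):
            response.usage = {}    # no tokens were spent on a cache hit
        response.cache_hit = True  # mirrors the line added in this commit
        return response

cache = SketchCache()
cache.put("req-1", SimpleNamespace(usage={"total_tokens": 42}))
hit = cache.get("req-1")
assert hit.cache_hit and hit.usage == {}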

dspy/clients/lm.py

Lines changed: 5 additions & 22 deletions
@@ -33,7 +33,6 @@ def __init__(
         temperature: float = 0.0,
         max_tokens: int = 4000,
         cache: bool = True,
-        cache_in_memory: bool = True,
         callbacks: list[BaseCallback] | None = None,
         num_retries: int = 3,
         provider: Provider | None = None,
@@ -53,7 +52,6 @@ def __init__(
             max_tokens: The maximum number of tokens to generate per response.
             cache: Whether to cache the model responses for reuse to improve performance
                 and reduce costs.
-            cache_in_memory (deprecated): To enable additional caching with LRU in memory.
             callbacks: A list of callback functions to run before and after each request.
             num_retries: The number of times to retry a request if it fails transiently due to
                 network error, rate limiting, etc. Requests are retried with exponential
@@ -66,7 +64,6 @@ def __init__(
         self.model = model
         self.model_type = model_type
         self.cache = cache
-        self.cache_in_memory = cache_in_memory
         self.provider = provider or self.infer_provider()
         self.callbacks = callbacks or []
         self.history = []
@@ -91,33 +88,21 @@ def __init__(
         else:
             self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
 
-    def _get_cached_completion_fn(self, completion_fn, cache, enable_memory_cache):
+    def _get_cached_completion_fn(self, completion_fn, cache):
         ignored_args_for_cache_key = ["api_key", "api_base", "base_url"]
-        if cache and enable_memory_cache:
+        if cache:
             completion_fn = request_cache(
                 cache_arg_name="request",
                 ignored_args_for_cache_key=ignored_args_for_cache_key,
             )(completion_fn)
-        elif cache:
-            completion_fn = request_cache(
-                cache_arg_name="request",
-                ignored_args_for_cache_key=ignored_args_for_cache_key,
-                enable_memory_cache=False,
-            )(completion_fn)
-        else:
-            completion_fn = completion_fn
 
-        if not cache or litellm.cache is None:
-            litellm_cache_args = {"no-cache": True, "no-store": True}
-        else:
-            litellm_cache_args = {"no-cache": False, "no-store": False}
+        litellm_cache_args = {"no-cache": True, "no-store": True}
 
         return completion_fn, litellm_cache_args
 
     def forward(self, prompt=None, messages=None, **kwargs):
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
-        enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)
 
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
@@ -128,7 +113,7 @@ def forward(self, prompt=None, messages=None, **kwargs):
             completion = litellm_text_completion
         elif self.model_type == "responses":
             completion = litellm_responses_completion
-        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)
+        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache)
 
         results = completion(
             request=dict(model=self.model, messages=messages, **kwargs),
@@ -145,7 +130,6 @@ def forward(self, prompt=None, messages=None, **kwargs):
     async def aforward(self, prompt=None, messages=None, **kwargs):
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
-        enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)
 
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
@@ -156,7 +140,7 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
             completion = alitellm_text_completion
         elif self.model_type == "responses":
             completion = alitellm_responses_completion
-        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)
+        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache)
 
         results = await completion(
             request=dict(model=self.model, messages=messages, **kwargs),
@@ -246,7 +230,6 @@ def dump_state(self):
             "model",
             "model_type",
            "cache",
-            "cache_in_memory",
            "num_retries",
            "finetuning_model",
            "launch_kwargs",

tests/clients/test_litellm_cache.py

Lines changed: 0 additions & 107 deletions
This file was deleted.

tests/clients/test_lm.py

Lines changed: 3 additions & 43 deletions
@@ -62,48 +62,13 @@ def test_chat_lms_can_be_queried(litellm_test_server):
     assert azure_openai_lm("azure openai query") == expected_response
 
 
-@pytest.mark.parametrize(
-    ("cache", "cache_in_memory"),
-    [
-        (True, True),
-        (True, False),
-        (False, True),
-        (False, False),
-    ],
-)
-def test_litellm_cache(litellm_test_server, cache, cache_in_memory):
-    api_base, _ = litellm_test_server
-    expected_response = ["Hi!"]
-
-    original_cache = dspy.cache
-    dspy.clients.configure_cache(
-        enable_disk_cache=False,
-        enable_memory_cache=cache_in_memory,
-        enable_litellm_cache=cache,
-    )
-
-    openai_lm = dspy.LM(
-        model="openai/dspy-test-model",
-        api_base=api_base,
-        api_key="fakekey",
-        model_type="chat",
-        cache=cache,
-        cache_in_memory=cache_in_memory,
-    )
-    assert openai_lm("openai query") == expected_response
-
-    # Reset the cache configuration
-    dspy.cache = original_cache
-
-
 def test_dspy_cache(litellm_test_server, tmp_path):
     api_base, _ = litellm_test_server
 
     original_cache = dspy.cache
     dspy.clients.configure_cache(
         enable_disk_cache=True,
         enable_memory_cache=True,
-        enable_litellm_cache=False,
         disk_cache_dir=tmp_path / ".disk_cache",
     )
     cache = dspy.cache
@@ -288,7 +253,6 @@ def test_dump_state():
         "max_tokens": 100,
         "num_retries": 10,
         "cache": True,
-        "cache_in_memory": True,
         "finetuning_model": None,
         "launch_kwargs": {"temperature": 1},
         "train_kwargs": {"temperature": 5},
@@ -377,7 +341,6 @@ async def test_async_lm_call_with_cache(tmp_path):
     dspy.clients.configure_cache(
         enable_disk_cache=True,
         enable_memory_cache=True,
-        enable_litellm_cache=False,
         disk_cache_dir=tmp_path / ".disk_cache",
     )
     cache = dspy.cache
@@ -400,11 +363,10 @@ async def test_async_lm_call_with_cache(tmp_path):
     # Second call should hit the cache, so no new call to LiteLLM is made.
     assert mock_alitellm_completion.call_count == 1
 
-    # Test that explicitly disabling memory cache works
-    await lm.acall("New query", cache_in_memory=False)
+    # A new query should result in a new LiteLLM call and a new cache entry.
+    await lm.acall("New query")
 
-    # There should be a new call to LiteLLM on new query, but the memory cache shouldn't be written to.
-    assert len(cache.memory_cache) == 1
+    assert len(cache.memory_cache) == 2
     assert mock_alitellm_completion.call_count == 2
 
     dspy.cache = original_cache
@@ -470,7 +432,6 @@ def test_responses_api(litellm_test_server):
         api_key="fakekey",
         model_type="responses",
         cache=False,
-        cache_in_memory=False,
     )
     assert lm("openai query") == [expected_text]
 
@@ -501,7 +462,6 @@ def test_responses_api_tool_calls(litellm_test_server):
         api_key="fakekey",
         model_type="responses",
         cache=False,
-        cache_in_memory=False,
     )
     assert lm("openai query") == expected_response
 
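
The rewritten assertions reflect that every cached call now also lands in the in-memory cache, so two distinct queries produce two entries. A minimal sketch of exercising acall under the new behavior, assuming a configured LM; the model name is illustrative:

import asyncio
import dspy

async def main():
    lm = dspy.LM("openai/dspy-test-model", cache=True)  # illustrative model name
    await lm.acall("First query")   # provider call, then stored in the cache
    await lm.acall("First query")   # cache hit, no second provider call
    await lm.acall("New query")     # new provider call and a second cache entry

asyncio.run(main())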

tests/predict/test_predict.py

Lines changed: 0 additions & 1 deletion
@@ -56,7 +56,6 @@ def test_lm_after_dump_and_load_state():
         "max_tokens": 100,
         "num_retries": 10,
         "cache": True,
-        "cache_in_memory": True,
         "finetuning_model": None,
         "launch_kwargs": {},
         "train_kwargs": {},
