Skip to content

Commit 1bae71c

Browse files
authored
Retries via LiteLLM RetryPolicy (#1866)
* Retr Signed-off-by: dbczumar <[email protected]> * works Signed-off-by: dbczumar <[email protected]> * retry Signed-off-by: dbczumar <[email protected]> * Retry Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * Rename Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * Update test_lm.py * Update test_lm.py * Update test_lm.py * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * Make tests more robust Signed-off-by: dbczumar <[email protected]> * Update pyproject.toml --------- Signed-off-by: dbczumar <[email protected]>
1 parent 7c2f604 commit 1bae71c

File tree

7 files changed

+138
-18
lines changed

7 files changed

+138
-18
lines changed

dspy/clients/lm.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pydantic
1212
import ujson
1313
from cachetools import LRUCache, cached
14+
from litellm import RetryPolicy
1415

1516
from dspy.adapters.base import Adapter
1617
from dspy.clients.openai import OpenAIProvider
@@ -36,7 +37,7 @@ def __init__(
3637
max_tokens: int = 1000,
3738
cache: bool = True,
3839
callbacks: Optional[List[BaseCallback]] = None,
39-
num_retries: int = 3,
40+
num_retries: int = 8,
4041
provider=None,
4142
finetuning_model: Optional[str] = None,
4243
launch_kwargs: Optional[dict[str, Any]] = None,
@@ -102,14 +103,13 @@ def __call__(self, prompt=None, messages=None, **kwargs):
102103
outputs = [
103104
{
104105
"text": c.message.content if hasattr(c, "message") else c["text"],
105-
"logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"]
106+
"logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"],
106107
}
107108
for c in response["choices"]
108109
]
109110
else:
110111
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
111112

112-
113113
# Logging, with removed api key & where `cost` is None on cache hit.
114114
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
115115
entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
@@ -310,8 +310,12 @@ def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
310310

311311
def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
    """Execute a chat completion via ``litellm.completion`` with DSPy's retry policy.

    Args:
        request: Keyword arguments forwarded verbatim to ``litellm.completion``
            (model, messages, etc.).
        num_retries: Number of retries for transiently failing requests; converted
            into a LiteLLM ``RetryPolicy`` by ``_get_litellm_retry_policy``.
        cache: LiteLLM cache directives. The default presumably disables LiteLLM's
            own caching so callers control caching separately — TODO confirm against
            ``cached_litellm_completion``.

    Returns:
        The ``litellm.completion`` response object.
    """
    return litellm.completion(
        cache=cache,
        retry_policy=_get_litellm_retry_policy(num_retries),
        # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
        # to completion()), the default value of max_retries is non-zero for certain providers, and
        # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
        max_retries=0,
        **request,
    )
317321

@@ -344,6 +348,32 @@ def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"n
344348
api_key=api_key,
345349
api_base=api_base,
346350
prompt=prompt,
347-
num_retries=num_retries,
351+
retry_policy=_get_litellm_retry_policy(num_retries),
352+
# In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
353+
# to completion()), the default value of max_retries is non-zero for certain providers, and
354+
# max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
355+
max_retries=0,
348356
**request,
349357
)
358+
359+
360+
def _get_litellm_retry_policy(num_retries: int) -> RetryPolicy:
    """
    Build a LiteLLM RetryPolicy for retrying requests when transient API errors occur.

    Args:
        num_retries: The number of times to retry a request if it fails transiently due to
                     network error, rate limiting, etc. Requests are retried with exponential
                     backoff.

    Returns:
        A LiteLLM RetryPolicy instance.
    """
    # Retry only error classes that are plausibly transient; bad requests and
    # invalid auth credentials will fail identically on every attempt, so those
    # retry counts are pinned to zero.
    transient_retry_fields = (
        "TimeoutErrorRetries",
        "RateLimitErrorRetries",
        "InternalServerErrorRetries",
        "ContentPolicyViolationErrorRetries",
    )
    policy_kwargs = {field: num_retries for field in transient_retry_fields}
    policy_kwargs["BadRequestErrorRetries"] = 0
    policy_kwargs["AuthenticationErrorRetries"] = 0
    return RetryPolicy(**policy_kwargs)

poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ dependencies = [
3838
"pydantic~=2.0",
3939
"jinja2",
4040
"magicattr~=0.1.6",
41-
"litellm",
41+
"litellm==1.55.3",
4242
"diskcache",
4343
"json-repair",
4444
"tenacity>=8.2.3",
@@ -132,7 +132,7 @@ pgvector = { version = "^0.2.5", optional = true }
132132
llama-index = { version = "^0.10.30", optional = true }
133133
jinja2 = "^3.1.3"
134134
magicattr = "^0.1.6"
135-
litellm = { version = "==1.53.7", extras = ["proxy"] }
135+
litellm = { version = "==1.55.3", extras = ["proxy"] }
136136
diskcache = "^5.6.0"
137137
json-repair = "^0.30.0"
138138
tenacity = ">=8.2.3"

requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ anyio
22
asyncer==0.0.8
33
backoff
44
cachetools
5+
cloudpickle
56
datasets
67
diskcache
78
httpx
9+
jinja2
810
joblib~=1.3
911
json-repair
10-
litellm[proxy]==1.53.7
12+
litellm[proxy]==1.55.3
1113
magicattr~=0.1.6
1214
openai
1315
optuna
@@ -18,5 +20,3 @@ requests
1820
tenacity>=8.2.3
1921
tqdm
2022
ujson
21-
cloudpickle
22-
jinja2

tests/clients/test_lm.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
import dspy
8-
from tests.test_utils.server import litellm_test_server
8+
from tests.test_utils.server import litellm_test_server, read_litellm_test_server_request_logs
99

1010

1111
def test_chat_lms_can_be_queried(litellm_test_server):
@@ -84,3 +84,75 @@ class ResponseFormat(pydantic.BaseModel):
8484
response_format=ResponseFormat,
8585
)
8686
lm("Query")
87+
88+
89+
@pytest.mark.parametrize(
    ("error_code", "expected_exception", "expected_num_retries"),
    [
        ("429", litellm.RateLimitError, 2),
        ("504", litellm.Timeout, 3),
        # Don't retry on user errors
        ("400", litellm.BadRequestError, 0),
        ("401", litellm.AuthenticationError, 0),
        # TODO: LiteLLM retry logic isn't implemented properly for internal server errors
        # and content policy violations, both of which may be transient and should be retried
        # ("content-policy-violation", litellm.BadRequestError, 1),
        # ("500", litellm.InternalServerError, 0, 1),
    ],
)
def test_lm_chat_calls_are_retried_for_expected_failures(
    litellm_test_server,
    error_code,
    expected_exception,
    expected_num_retries,
):
    """Chat-mode LM calls retry transient errors the configured number of times
    and do not retry user errors, verified by counting requests recorded in the
    local LiteLLM test server's log file.
    """
    api_base, server_log_file_path = litellm_test_server

    openai_lm = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        num_retries=expected_num_retries,
        model_type="chat",
    )
    # The test server raises the exception matching `error_code` on every attempt,
    # so the call ultimately fails even after all retries are exhausted.
    with pytest.raises(expected_exception):
        openai_lm(error_code)

    request_logs = read_litellm_test_server_request_logs(server_log_file_path)
    assert len(request_logs) == expected_num_retries + 1  # 1 initial request + expected_num_retries retries
123+
124+
125+
@pytest.mark.parametrize(
    ("error_code", "expected_exception", "expected_num_retries"),
    [
        ("429", litellm.RateLimitError, 2),
        ("504", litellm.Timeout, 3),
        # Don't retry on user errors
        ("400", litellm.BadRequestError, 0),
        ("401", litellm.AuthenticationError, 0),
        # TODO: LiteLLM retry logic isn't implemented properly for internal server errors
        # and content policy violations, both of which may be transient and should be retried
        # ("content-policy-violation", litellm.BadRequestError, 2),
        # ("500", litellm.InternalServerError, 0, 2),
    ],
)
def test_lm_text_calls_are_retried_for_expected_failures(
    litellm_test_server,
    error_code,
    expected_exception,
    expected_num_retries,
):
    """Text-mode LM calls retry transient errors the configured number of times
    and do not retry user errors, verified by counting requests recorded in the
    local LiteLLM test server's log file.
    """
    api_base, server_log_file_path = litellm_test_server

    openai_lm = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        num_retries=expected_num_retries,
        model_type="text",
    )
    # The test server raises the exception matching `error_code` on every attempt,
    # so the call ultimately fails even after all retries are exhausted.
    with pytest.raises(expected_exception):
        openai_lm(error_code)

    request_logs = read_litellm_test_server_request_logs(server_log_file_path)
    assert len(request_logs) == expected_num_retries + 1  # 1 initial request + expected_num_retries retries

tests/test_utils/server/litellm_server.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,38 @@
1010
class DSPyTestModel(CustomLLM):
    """LiteLLM custom provider used by the test server: logs every request to a
    file (so tests can count attempts, e.g. for retry behavior) and returns a
    mocked response or a canned exception driven by the request content.
    """

    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        # Record the raw request kwargs before responding, so retries are countable.
        _append_request_to_log_file(kwargs)
        return _get_mock_llm_response(kwargs)

    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
        # Async variant; same logging + mock-response behavior as `completion`.
        _append_request_to_log_file(kwargs)
        return _get_mock_llm_response(kwargs)
1818

1919

20-
def _get_mock_llm_response():
20+
def _get_mock_llm_response(request_kwargs):
    """Return a canned completion response, or raise a test exception if the
    request's message content asks for one (e.g. contains "429" or "504").

    `request_kwargs` are the kwargs of the incoming completion call.
    """
    _throw_exception_based_on_content_if_applicable(request_kwargs)
    return litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello world"}],
        # mock_response makes litellm return a stubbed response without
        # contacting a real provider.
        mock_response="Hi!",
    )
2627

2728

29+
def _throw_exception_based_on_content_if_applicable(request_kwargs):
    """
    Throws an exception, for testing purposes, based on the content of the request message.
    """
    model = request_kwargs["model"]
    content = request_kwargs["messages"][0]["content"]
    # Map each trigger substring to a factory for the exception the test server
    # should raise; the first trigger found in the message content wins.
    error_factories = (
        ("429", lambda: litellm.RateLimitError(message="Rate limit exceeded", llm_provider=None, model=model)),
        ("504", lambda: litellm.Timeout("Request timed out!", llm_provider=None, model=model)),
        ("400", lambda: litellm.BadRequestError(message="Bad request", llm_provider=None, model=model)),
        ("401", lambda: litellm.AuthenticationError(message="Authentication error", llm_provider=None, model=model)),
    )
    for trigger, make_error in error_factories:
        if trigger in content:
            raise make_error()
43+
44+
2845
def _append_request_to_log_file(completion_kwargs):
2946
log_file_path = os.environ.get(LITELLM_TEST_SERVER_LOG_FILE_PATH_ENV_VAR)
3047
if log_file_path is None:

tests/test_utils/server/litellm_server_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ model_list:
77
model: "dspy-test-provider/dspy-test-model"
88

99
litellm_settings:
10+
num_retries: 0
1011
custom_provider_map:
1112
- {
1213
"provider": "dspy-test-provider",

0 commit comments

Comments
 (0)