
Commit 5fb014d

jakelorocco authored and tuliocoppola committed
fix: async overhaul; create global event loop; add client cache (generative-computing#186)
* fix: async overhaul; create global event loop; add client cache
* fix: watsonx test cicd check
* feat: add client cache to openai and simplify setup
* fix: add additional test for client cache
1 parent 691b308 commit 5fb014d

File tree

13 files changed: +403 additions, -73 deletions

mellea/backends/litellm.py

Lines changed: 29 additions & 1 deletion
@@ -22,7 +22,7 @@
     convert_tools_to_json,
 )
 from mellea.backends.types import ModelOption
-from mellea.helpers.async_helpers import send_to_queue
+from mellea.helpers.async_helpers import get_current_event_loop, send_to_queue
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.helpers.openai_compatible_helpers import (
     chat_completion_delta_merge,
@@ -105,6 +105,8 @@ def __init__(
             ModelOption.STREAM: "stream",
         }
 
+        self._past_event_loops: set[int] = set()
+
     def generate_from_context(
         self,
         action: Component | CBlock,
@@ -270,6 +272,11 @@ def _generate_from_chat_context_standard(
 
         model_specific_options = self._make_backend_specific_and_remove(model_opts)
 
+        if self._has_potential_event_loop_errors():
+            FancyLogger().get_logger().warning(
+                "There is a known bug with litellm. This generation call may fail. If it does, you should ensure that you are either running only synchronous Mellea functions or running async Mellea functions from one asyncio.run() call."
+            )
+
         chat_response: Coroutine[
             Any, Any, litellm.ModelResponse | litellm.ModelResponseStream  # type: ignore
         ] = litellm.acompletion(
@@ -488,3 +495,24 @@ def _extract_model_tool_requests(
         if len(model_tool_calls) > 0:
             return model_tool_calls
         return None
+
+    def _has_potential_event_loop_errors(self) -> bool:
+        """In some cases litellm doesn't create a new async client. There doesn't appear to be any way for us to force that behavior. As a result, log a warning for known cases.
+
+        This whole function can be removed once the bug is fixed: https://github.com/BerriAI/litellm/issues/15294.
+        """
+        # Async clients are tied to event loops.
+        key = id(get_current_event_loop())
+
+        has_potential_issue = False
+        if (
+            len(self._past_event_loops) > 0
+            and key not in self._past_event_loops
+            and "watsonx/" in str(self.model_id)
+        ):
+            has_potential_issue = True
+
+        # Add this loop to the known set.
+        self._past_event_loops.add(key)
+
+        return has_potential_issue
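
For orientation, a minimal sketch of the call pattern this new warning targets. It is illustrative only: the backend class name, import path, constructor signature, and model id below are assumed, not taken from this diff.

    import asyncio

    from mellea.backends.litellm import LiteLLMBackend  # class name/path assumed

    backend = LiteLLMBackend(model_id="watsonx/ibm/granite-3-8b-instruct")  # model id assumed

    async def generate_once():
        ...  # any call that reaches _generate_from_chat_context_standard

    # Each asyncio.run() creates a fresh event loop, but litellm may keep an async
    # client bound to the first loop (BerriAI/litellm#15294). On the second run,
    # _has_potential_event_loop_errors() sees a new loop id plus a "watsonx/" model
    # id and logs the warning before the generation call is made.
    asyncio.run(generate_once())
    asyncio.run(generate_once())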

mellea/backends/ollama.py

Lines changed: 22 additions & 5 deletions
@@ -18,7 +18,11 @@
     add_tools_from_model_options,
 )
 from mellea.backends.types import ModelOption
-from mellea.helpers.async_helpers import send_to_queue
+from mellea.helpers.async_helpers import (
+    ClientCache,
+    get_current_event_loop,
+    send_to_queue,
+)
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import (
     CBlock,
@@ -69,6 +73,11 @@ def __init__(
         self._base_url = base_url
         self._client = ollama.Client(base_url)
 
+        self._client_cache = ClientCache(2)
+
+        # Call once to set up an async client and prepopulate the cache.
+        _ = self._async_client
+
         if not self._check_ollama_server():
             err = f"could not create OllamaModelBackend: ollama server not running at {base_url}"
             FancyLogger.get_logger().error(err)
@@ -181,6 +190,17 @@ def _pull_ollama_model(self) -> bool:
         except ollama.ResponseError:
             return False
 
+    @property
+    def _async_client(self) -> ollama.AsyncClient:
+        """Ollama's client gets tied to a specific event loop. Reset it if needed here."""
+        key = id(get_current_event_loop())
+
+        _async_client = self._client_cache.get(key)
+        if _async_client is None:
+            _async_client = ollama.AsyncClient(self._base_url)
+            self._client_cache.put(key, _async_client)
+        return _async_client
+
     def _simplify_and_merge(
         self, model_options: dict[str, Any] | None
     ) -> dict[str, Any]:
@@ -318,13 +338,10 @@ def generate_from_chat_context(
         add_tools_from_context_actions(tools, [action])
         FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
 
-        # Ollama ties its async client to an event loop so we have to create it here.
-        async_client = ollama.AsyncClient(self._base_url)
-
         # Generate a chat response from ollama, using the chat messages. Can be either type since stream is passed as a model option.
         chat_response: Coroutine[
             Any, Any, AsyncIterator[ollama.ChatResponse] | ollama.ChatResponse
-        ] = async_client.chat(
+        ] = self._async_client.chat(
             model=self._get_ollama_model_id(),
             messages=conversation,
             tools=list(tools.values()),
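
A short sketch of how the new _async_client property behaves across event loops. Only the class name OllamaModelBackend appears in this diff; the import path and constructor arguments below are assumed.

    import asyncio

    from mellea.backends.ollama import OllamaModelBackend  # import path assumed

    backend = OllamaModelBackend(base_url="http://localhost:11434")  # signature assumed

    async def show_client():
        # Inside a running loop, the cache key is id(asyncio.get_running_loop()).
        print(id(backend._async_client))

    asyncio.run(show_client())  # first loop: a new ollama.AsyncClient is cached
    asyncio.run(show_client())  # a new loop generally gets a fresh cached client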

mellea/backends/openai.py

Lines changed: 27 additions & 6 deletions
@@ -29,7 +29,11 @@
     convert_tools_to_json,
 )
 from mellea.backends.types import ModelOption
-from mellea.helpers.async_helpers import send_to_queue
+from mellea.helpers.async_helpers import (
+    ClientCache,
+    get_current_event_loop,
+    send_to_queue,
+)
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.helpers.openai_compatible_helpers import (
     chat_completion_delta_merge,
@@ -164,18 +168,35 @@ def __init__(
         else:
             self._api_key = api_key
 
-        openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)
+        self._openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)
 
         self._client = openai.OpenAI(  # type: ignore
-            api_key=self._api_key, base_url=self._base_url, **openai_client_kwargs
-        )
-        self._async_client = openai.AsyncOpenAI(
-            api_key=self._api_key, base_url=self._base_url, **openai_client_kwargs
+            api_key=self._api_key, base_url=self._base_url, **self._openai_client_kwargs
         )
 
+        self._client_cache = ClientCache(2)
+
+        # Call once to create an async_client and populate the cache.
+        _ = self._async_client
+
         # ALoras that have been loaded for this model.
         self._aloras: dict[str, OpenAIAlora] = {}
 
+    @property
+    def _async_client(self) -> openai.AsyncOpenAI:
+        """OpenAI's client usually handles changing event loops but explicitly handle it here for edge cases."""
+        key = id(get_current_event_loop())
+
+        _async_client = self._client_cache.get(key)
+        if _async_client is None:
+            _async_client = openai.AsyncOpenAI(
+                api_key=self._api_key,
+                base_url=self._base_url,
+                **self._openai_client_kwargs,
+            )
+            self._client_cache.put(key, _async_client)
+        return _async_client
+
     @staticmethod
     def filter_openai_client_kwargs(**kwargs) -> dict:
         """Filter kwargs to only include valid OpenAI client parameters."""

mellea/backends/watsonx.py

Lines changed: 27 additions & 20 deletions
@@ -22,7 +22,11 @@
     convert_tools_to_json,
 )
 from mellea.backends.types import ModelOption
-from mellea.helpers.async_helpers import send_to_queue
+from mellea.helpers.async_helpers import (
+    ClientCache,
+    get_current_event_loop,
+    send_to_queue,
+)
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.helpers.openai_compatible_helpers import (
     chat_completion_delta_merge,
@@ -93,15 +97,12 @@ def __init__(
         self._project_id = os.environ.get("WATSONX_PROJECT_ID")
 
         self._creds = Credentials(url=base_url, api_key=api_key)
-        _client = APIClient(credentials=self._creds)
-        self._model_inference = ModelInference(
-            model_id=self._get_watsonx_model_id(),
-            api_client=_client,
-            credentials=self._creds,
-            project_id=self._project_id,
-            params=self.model_options,
-            **kwargs,
-        )
+        self._kwargs = kwargs
+
+        self._client_cache = ClientCache(2)
+
+        # Call once to set up the model inference and prepopulate the cache.
+        _ = self._model
 
         # A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent.
         # These are usually values that must be extracted before hand or that are common among backend providers.
@@ -134,16 +135,22 @@ def __init__(
 
     @property
     def _model(self) -> ModelInference:
-        """Watsonx's client gets tied to a specific event loop. Reset it here."""
-        _client = APIClient(credentials=self._creds)
-        self._model_inference = ModelInference(
-            model_id=self._get_watsonx_model_id(),
-            api_client=_client,
-            credentials=self._creds,
-            project_id=self._project_id,
-            params=self.model_options,
-        )
-        return self._model_inference
+        """Watsonx's client gets tied to a specific event loop. Reset it if needed here."""
+        key = id(get_current_event_loop())
+
+        _model_inference = self._client_cache.get(key)
+        if _model_inference is None:
+            _client = APIClient(credentials=self._creds)
+            _model_inference = ModelInference(
+                model_id=self._get_watsonx_model_id(),
+                api_client=_client,
+                credentials=self._creds,
+                project_id=self._project_id,
+                params=self.model_options,
+                **self._kwargs,
+            )
+            self._client_cache.put(key, _model_inference)
+        return _model_inference
 
     def _get_watsonx_model_id(self) -> str:
         """Gets the watsonx model id from the model_id that was provided in the constructor. Raises AssertionError if the ModelIdentifier does not provide a watsonx_name."""

mellea/helpers/async_helpers.py

Lines changed: 52 additions & 1 deletion
@@ -1,8 +1,9 @@
 """Async helper functions."""
 
 import asyncio
+from collections import OrderedDict
 from collections.abc import AsyncIterator, Coroutine
-from typing import Any
+from typing import Any, TypeVar
 
 from mellea.stdlib.base import ModelOutputThunk
 
@@ -46,3 +47,53 @@ async def wait_for_all_mots(mots: list[ModelOutputThunk]):
         coroutines.append(mot.avalue())
 
     await asyncio.gather(*coroutines)
+
+
+def get_current_event_loop() -> None | asyncio.AbstractEventLoop:
+    """Get the current event loop without having to catch exceptions."""
+    loop = None
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        pass
+    return loop
+
+
+class ClientCache:
+    """A simple [LRU](https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_Recently_Used_(LRU)) cache.
+
+    Used to keep track of clients for backends where the client is tied to a specific event loop.
+    """
+
+    def __init__(self, capacity: int):
+        """Initializes the LRU cache with a certain capacity.
+
+        The `ClientCache` either contains a value or it doesn't.
+        """
+        self.capacity = capacity
+        self.cache: OrderedDict = OrderedDict()
+
+    def current_size(self):
+        """Just return the size of the key set. This isn't necessarily safe."""
+        return len(self.cache.keys())
+
+    def get(self, key: int) -> Any | None:
+        """Gets a value from the cache."""
+        if key not in self.cache:
+            return None
+        else:
+            # Move the accessed item to the end (most recent)
+            value = self.cache.pop(key)
+            self.cache[key] = value
+            return value
+
+    def put(self, key: int, value: Any):
+        """Put a value into the cache."""
+        if key in self.cache:
+            # If the key exists, move it to the end (most recent)
+            self.cache.pop(key)
+        elif len(self.cache) >= self.capacity:
+            # If the cache is full, remove the least recently used item
+            self.cache.popitem(last=False)
+        # Add the new key-value pair to the end (most recent)
+        self.cache[key] = value
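
The cache semantics are easiest to see in isolation. A self-contained usage sketch of the two helpers added above:

    from mellea.helpers.async_helpers import ClientCache, get_current_event_loop

    print(get_current_event_loop())  # None here: no event loop is running in sync code

    cache = ClientCache(2)
    cache.put(1, "client-a")
    cache.put(2, "client-b")
    cache.get(1)              # touch key 1; key 2 is now least recently used
    cache.put(3, "client-c")  # at capacity, so key 2 is evicted

    print(cache.get(2))                # None
    print(cache.get(1), cache.get(3))  # client-a client-c
    print(cache.current_size())        # 2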

Lines changed: 86 additions & 0 deletions (new file)

@@ -0,0 +1,86 @@
+"""Helper for event loop management. Allows consistently running async generate requests in sync code."""
+
+import asyncio
+import threading
+from collections.abc import Coroutine
+from typing import Any, TypeVar
+
+R = TypeVar("R")
+
+
+class _EventLoopHandler:
+    """A class that handles the event loop for Mellea code. Do not directly instantiate this. Use `_run_async_in_thread`."""
+
+    def __init__(self):
+        """Instantiates an EventLoopHandler. Used to ensure consistency when calling async code from sync code in Mellea.
+
+        Do not instantiate this class. Rely on the exported `_run_async_in_thread` function.
+        """
+        self._event_loop = asyncio.new_event_loop()
+        self._thread: threading.Thread = threading.Thread(
+            target=self._event_loop.run_forever, daemon=True
+        )
+        self._thread.start()
+
+    def __del__(self):
+        """Delete the event loop handler."""
+        self._close_event_loop()
+
+    def _close_event_loop(self) -> None:
+        """Called when deleting the event loop handler. Cleans up the event loop and thread."""
+        if self._event_loop:
+            try:
+                tasks = asyncio.all_tasks(self._event_loop)
+                for task in tasks:
+                    task.cancel()
+
+                async def finalize_tasks():
+                    # TODO: We can log errors here if needed.
+                    await asyncio.gather(*tasks, return_exceptions=True)
+
+                out = asyncio.run_coroutine_threadsafe(
+                    finalize_tasks(), self._event_loop
+                )
+
+                # Timeout if needed.
+                out.result(5)
+            except Exception:
+                pass
+
+            # Finally stop the event loop for this session.
+            self._event_loop.stop()
+
+    def __call__(self, co: Coroutine[Any, Any, R]) -> R:
+        """Runs the coroutine in the event loop."""
+        return asyncio.run_coroutine_threadsafe(co, self._event_loop).result()
+
+
+# Instantiate this class once. It will not be re-instantiated.
+__event_loop_handler = _EventLoopHandler()
+
+
+def _run_async_in_thread(co: Coroutine[Any, Any, R]) -> R:
+    """Call to run async code from synchronous code in Mellea.
+
+    In Mellea, we utilize async code underneath sync code to speed up
+    inference requests. This puts us in a difficult situation since most
+    api providers and sdks use async clients that get bound to a specific event
+    loop to make requests. These clients are typically long-lasting and sometimes
+    cannot be easily reinstantiated on demand to avoid these issues.
+    By declaring a single event loop for these async requests,
+    Mellea avoids these client issues.
+
+    Note: This implementation requires that sessions/backends be run only through
+    the top-level / session sync or async interfaces, not both. You will need to
+    reinstantiate your backend if switching between the two.
+
+    Args:
+        co: coroutine to run
+
+    Returns:
+        output of the coroutine
+    """
+    return __event_loop_handler(co)
+
+
+__all__ = ["_run_async_in_thread"]
