
Commit 8374083

jakelorocco authored and tuliocoppola committed
feat: ollama generate_from_raw uses existing event loop (generative-computing#204)

* feat: ollama generate_from_raw uses existing event loop
* fix: add blocking prevention mech
* fix: test issues with cache

1 parent 1a7ebd6 commit 8374083
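
Background for the change: asyncio.run cannot be called from a thread whose event loop is already running; it raises RuntimeError. The pattern below is a minimal, hypothetical sketch of the detect-and-dispatch workaround this commit applies (run_maybe_nested is illustrative and not part of Mellea):

import asyncio
from concurrent.futures import ThreadPoolExecutor

def run_maybe_nested(coro):
    """Hypothetical helper: run `coro` whether or not a loop is already running."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread, so asyncio.run is safe.
        return asyncio.run(coro)
    # asyncio.run() would raise "cannot be called from a running event loop";
    # hand the coroutine to a fresh loop in a worker thread and block on it.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()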

File tree

5 files changed: +26 -23 lines


mellea/backends/ollama.py

Lines changed: 18 additions & 19 deletions
@@ -23,6 +23,7 @@
     get_current_event_loop,
     send_to_queue,
 )
+from mellea.helpers.event_loop_helper import _run_async_in_thread
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import (
     CBlock,
@@ -404,28 +405,26 @@ def _generate_from_raw(
         # See https://github.com/ollama/ollama/blob/main/docs/faq.md#how-does-ollama-handle-concurrent-requests.
         prompts = [self.formatter.print(action) for action in actions]

-        async def get_response(coroutines):
+        async def get_response():
+            # Run async so that we can make use of Ollama's concurrency.
+            coroutines: list[Coroutine[Any, Any, ollama.GenerateResponse]] = []
+            for prompt in prompts:
+                co = self._async_client.generate(
+                    model=self._get_ollama_model_id(),
+                    prompt=prompt,
+                    raw=True,
+                    think=model_opts.get(ModelOption.THINKING, None),
+                    format=format.model_json_schema() if format is not None else None,
+                    options=self._make_backend_specific_and_remove(model_opts),
+                )
+                coroutines.append(co)
+
             responses = await asyncio.gather(*coroutines, return_exceptions=True)
             return responses

-        async_client = ollama.AsyncClient(self._base_url)
-        # Run async so that we can make use of Ollama's concurrency.
-        coroutines = []
-        for prompt in prompts:
-            co = async_client.generate(
-                model=self._get_ollama_model_id(),
-                prompt=prompt,
-                raw=True,
-                think=model_opts.get(ModelOption.THINKING, None),
-                format=format.model_json_schema() if format is not None else None,
-                options=self._make_backend_specific_and_remove(model_opts),
-            )
-            coroutines.append(co)
-
-        # Revisit this once we start using async elsewhere. Only one asyncio event
-        # loop can be running in a given thread.
-        responses: list[ollama.GenerateResponse | BaseException] = asyncio.run(
-            get_response(coroutines)
+        # Run in the same event loop as other Mellea async code called from a sync function.
+        responses: list[ollama.GenerateResponse | BaseException] = _run_async_in_thread(
+            get_response()
         )

         results = []
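
The restructuring above moves coroutine creation inside get_response, so the AsyncClient calls are created and awaited on the same loop that _run_async_in_thread manages, instead of being built eagerly and handed to asyncio.run. A standalone sketch of the same gather pattern, assuming the ollama Python package (the model name and prompt list are placeholders):

import asyncio
import ollama

async def generate_all(prompts: list[str], model: str = "llama3.2"):
    # One AsyncClient shared across requests; Ollama serves them concurrently.
    client = ollama.AsyncClient()
    # raw=True sends each prompt without templating, as in _generate_from_raw.
    coros = [client.generate(model=model, prompt=p, raw=True) for p in prompts]
    # return_exceptions=True keeps one failing prompt from cancelling the rest.
    return await asyncio.gather(*coros, return_exceptions=True)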

mellea/helpers/event_loop_helper.py

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,8 @@
 from collections.abc import Coroutine
 from typing import Any, TypeVar

+from mellea.helpers.async_helpers import get_current_event_loop
+
 R = TypeVar("R")


@@ -52,6 +54,9 @@ async def finalize_tasks():

     def __call__(self, co: Coroutine[Any, Any, R]) -> R:
         """Runs the coroutine in the event loop."""
+        if self._event_loop == get_current_event_loop():
+            # If this gets called from the same event loop, launch in a separate thread to prevent blocking.
+            return _EventLoopHandler()(co)
         return asyncio.run_coroutine_threadsafe(co, self._event_loop).result()
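
The three added lines guard against a deadlock: if __call__ is entered from a coroutine already running on the handler's own loop, blocking on run_coroutine_threadsafe(...).result() would stall the very loop that must execute the coroutine. A minimal sketch of the mechanism, with BackgroundLoop standing in for _EventLoopHandler (illustrative only; it assumes the handler owns a loop on a daemon thread and skips the cleanup the real class does):

import asyncio
import threading
from collections.abc import Coroutine
from typing import Any, TypeVar

R = TypeVar("R")

class BackgroundLoop:
    """Illustrative stand-in for _EventLoopHandler: one loop, one daemon thread."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    def __call__(self, co: Coroutine[Any, Any, R]) -> R:
        try:
            current = asyncio.get_running_loop()
        except RuntimeError:
            current = None
        if current is self._loop:
            # Waiting on a future bound to our own loop would deadlock; a fresh
            # loop/thread still blocks the caller, but the coroutine can run.
            return BackgroundLoop()(co)
        return asyncio.run_coroutine_threadsafe(co, self._loop).result()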

test/backends/test_ollama.py

Lines changed: 1 addition & 2 deletions
@@ -193,8 +193,7 @@ async def get_client_async():

     fourth_client = asyncio.run(get_client_async())
     assert fourth_client in backend._client_cache.cache.values()
-    assert second_client not in backend._client_cache.cache.values()
-
+    assert len(backend._client_cache.cache.values()) == 2

 if __name__ == "__main__":
     pytest.main([__file__])

test/backends/test_openai_ollama.py

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ async def get_client_async():

     fourth_client = asyncio.run(get_client_async())
     assert fourth_client in backend._client_cache.cache.values()
-    assert second_client not in backend._client_cache.cache.values()
+    assert len(backend._client_cache.cache.values()) == 2

 if __name__ == "__main__":
     import pytest

test/backends/test_watsonx.py

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ async def get_client_async():

     fourth_client = asyncio.run(get_client_async())
     assert fourth_client in backend._client_cache.cache.values()
-    assert second_client not in backend._client_cache.cache.values()
+    assert len(backend._client_cache.cache.values()) == 2

 if __name__ == "__main__":
     import pytest
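
All three test files make the same swap. A hedged reading, consistent with the "fix: test issues with cache" note in the commit message: once clients are cached per event loop, the second client is not necessarily evicted when a new one is created, so asserting its absence was flaky; asserting the cache size (apparently one sync plus one async client) is the stable check.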
