|
23 | 23 | get_current_event_loop, |
24 | 24 | send_to_queue, |
25 | 25 | ) |
| 26 | +from mellea.helpers.event_loop_helper import _run_async_in_thread |
26 | 27 | from mellea.helpers.fancy_logger import FancyLogger |
27 | 28 | from mellea.stdlib.base import ( |
28 | 29 | CBlock, |
@@ -404,28 +405,26 @@ def _generate_from_raw( |
404 | 405 | # See https://github.com/ollama/ollama/blob/main/docs/faq.md#how-does-ollama-handle-concurrent-requests. |
405 | 406 | prompts = [self.formatter.print(action) for action in actions] |
406 | 407 |
|
407 | | - async def get_response(coroutines): |
| 408 | + async def get_response(): |
| 409 | + # Run async so that we can make use of Ollama's concurrency. |
| 410 | + coroutines: list[Coroutine[Any, Any, ollama.GenerateResponse]] = [] |
| 411 | + for prompt in prompts: |
| 412 | + co = self._async_client.generate( |
| 413 | + model=self._get_ollama_model_id(), |
| 414 | + prompt=prompt, |
| 415 | + raw=True, |
| 416 | + think=model_opts.get(ModelOption.THINKING, None), |
| 417 | + format=format.model_json_schema() if format is not None else None, |
| 418 | + options=self._make_backend_specific_and_remove(model_opts), |
| 419 | + ) |
| 420 | + coroutines.append(co) |
| 421 | + |
408 | 422 | responses = await asyncio.gather(*coroutines, return_exceptions=True) |
409 | 423 | return responses |
410 | 424 |
|
411 | | - async_client = ollama.AsyncClient(self._base_url) |
412 | | - # Run async so that we can make use of Ollama's concurrency. |
413 | | - coroutines = [] |
414 | | - for prompt in prompts: |
415 | | - co = async_client.generate( |
416 | | - model=self._get_ollama_model_id(), |
417 | | - prompt=prompt, |
418 | | - raw=True, |
419 | | - think=model_opts.get(ModelOption.THINKING, None), |
420 | | - format=format.model_json_schema() if format is not None else None, |
421 | | - options=self._make_backend_specific_and_remove(model_opts), |
422 | | - ) |
423 | | - coroutines.append(co) |
424 | | - |
425 | | - # Revisit this once we start using async elsewhere. Only one asyncio event |
426 | | - # loop can be running in a given thread. |
427 | | - responses: list[ollama.GenerateResponse | BaseException] = asyncio.run( |
428 | | - get_response(coroutines) |
| 425 | + # Run in the same event_loop like other Mellea async code called from a sync function. |
| 426 | + responses: list[ollama.GenerateResponse | BaseException] = _run_async_in_thread( |
| 427 | + get_response() |
429 | 428 | ) |
430 | 429 |
|
431 | 430 | results = [] |
|
0 commit comments