|
22 | 22 | convert_tools_to_json, |
23 | 23 | ) |
24 | 24 | from mellea.backends.types import ModelOption |
25 | | -from mellea.helpers.async_helpers import send_to_queue |
| 25 | +from mellea.helpers.async_helpers import ( |
| 26 | + ClientCache, |
| 27 | + get_current_event_loop, |
| 28 | + send_to_queue, |
| 29 | +) |
26 | 30 | from mellea.helpers.fancy_logger import FancyLogger |
27 | 31 | from mellea.helpers.openai_compatible_helpers import ( |
28 | 32 | chat_completion_delta_merge, |
@@ -93,15 +97,12 @@ def __init__( |
93 | 97 | self._project_id = os.environ.get("WATSONX_PROJECT_ID") |
94 | 98 |
|
95 | 99 | self._creds = Credentials(url=base_url, api_key=api_key) |
96 | | - _client = APIClient(credentials=self._creds) |
97 | | - self._model_inference = ModelInference( |
98 | | - model_id=self._get_watsonx_model_id(), |
99 | | - api_client=_client, |
100 | | - credentials=self._creds, |
101 | | - project_id=self._project_id, |
102 | | - params=self.model_options, |
103 | | - **kwargs, |
104 | | - ) |
| 100 | + self._kwargs = kwargs |
| 101 | + |
| 102 | + self._client_cache = ClientCache(2) |
| 103 | + |
| 104 | + # Call once to set up the model inference and prepopulate the cache. |
| 105 | + _ = self._model |
105 | 106 |
|
106 | 107 | # A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent. |
107 | 108 | # These are usually values that must be extracted beforehand or that are common among backend providers. |
@@ -134,16 +135,22 @@ def __init__( |
134 | 135 |
|
135 | 136 | @property |
136 | 137 | def _model(self) -> ModelInference: |
137 | | - """Watsonx's client gets tied to a specific event loop. Reset it here.""" |
138 | | - _client = APIClient(credentials=self._creds) |
139 | | - self._model_inference = ModelInference( |
140 | | - model_id=self._get_watsonx_model_id(), |
141 | | - api_client=_client, |
142 | | - credentials=self._creds, |
143 | | - project_id=self._project_id, |
144 | | - params=self.model_options, |
145 | | - ) |
146 | | - return self._model_inference |
| 138 | + """Watsonx's client gets tied to a specific event loop. Reset it if needed here.""" |
| 139 | + key = id(get_current_event_loop()) |
| 140 | + |
| 141 | + _model_inference = self._client_cache.get(key) |
| 142 | + if _model_inference is None: |
| 143 | + _client = APIClient(credentials=self._creds) |
| 144 | + _model_inference = ModelInference( |
| 145 | + model_id=self._get_watsonx_model_id(), |
| 146 | + api_client=_client, |
| 147 | + credentials=self._creds, |
| 148 | + project_id=self._project_id, |
| 149 | + params=self.model_options, |
| 150 | + **self._kwargs, |
| 151 | + ) |
| 152 | + self._client_cache.put(key, _model_inference) |
| 153 | + return _model_inference |
147 | 154 |
|
148 | 155 | def _get_watsonx_model_id(self) -> str: |
149 | 156 | """Gets the watsonx model id from the model_id that was provided in the constructor. Raises AssertionError if the ModelIdentifier does not provide a watsonx_name.""" |
|
0 commit comments