@@ -154,10 +154,10 @@ async def update_endpoint(
154154 authm = await self ._db_reader .get_auth_material_by_provider_id (str (endpoint .id ))
155155
156156 models = await self ._find_models_for_provider (
157- endpoint , authm .auth_type , authm .auth_blob , prov
157+ endpoint . endpoint , authm .auth_type , authm .auth_blob , prov
158158 )
159159
160- await self ._update_models_for_provider (dbendpoint , endpoint , prov , models )
160+ await self ._update_models_for_provider (dbendpoint , models )
161161
162162 # a model might have been deleted, let's repopulate the cache
163163 await self ._ws_crud .repopulate_mux_cache ()
@@ -191,7 +191,7 @@ async def configure_auth_material(
191191 prov = endpoint .get_from_registry (provider_registry )
192192
193193 models = await self ._find_models_for_provider (
194- endpoint , config .auth_type , config .api_key , prov
194+ endpoint . endpoint , config .auth_type , config .api_key , prov
195195 )
196196
197197 await self ._db_writer .push_provider_auth_material (
@@ -202,35 +202,34 @@ async def configure_auth_material(
202202 )
203203 )
204204
205- await self ._update_models_for_provider (dbendpoint , endpoint , models )
205+ await self ._update_models_for_provider (dbendpoint , models )
206206
207207 # a model might have been deleted, let's repopulate the cache
208208 await self ._ws_crud .repopulate_mux_cache ()
209209
210210 async def _find_models_for_provider (
211211 self ,
212- endpoint : apimodelsv1 . ProviderEndpoint ,
212+ endpoint : str ,
213213 auth_type : apimodelsv1 .ProviderAuthType ,
214214 api_key : str ,
215215 prov : BaseProvider ,
216216 ) -> List [str ]:
217217 if auth_type != apimodelsv1 .ProviderAuthType .passthrough :
218218 try :
219- return prov .models (endpoint = endpoint . endpoint , api_key = api_key )
219+ return prov .models (endpoint = endpoint , api_key = api_key )
220220 except Exception as err :
221221 raise ProviderModelsNotFoundError (f"Unable to get models from provider: { err } " )
222222 return []
223223
224224 async def _update_models_for_provider (
225225 self ,
226226 dbendpoint : dbmodels .ProviderEndpoint ,
227- endpoint : apimodelsv1 .ProviderEndpoint ,
228227 found_models : List [str ],
229228 ) -> None :
230229 models_set = set (found_models )
231230
232231 # Get the models from the provider
233- models_in_db = await self ._db_reader .get_provider_models_by_provider_id (str (endpoint .id ))
232+ models_in_db = await self ._db_reader .get_provider_models_by_provider_id (str (dbendpoint .id ))
234233
235234 models_in_db_set = set (model .name for model in models_in_db )
236235
@@ -318,7 +317,7 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
318317 dbprovend = await db_reader .get_provider_endpoint_by_name (provend .name )
319318 if dbprovend is not None :
320319 logger .debug (
321- "Provider already in DB. Not re-adding. " ,
320+ "Provider already in DB. skipping " ,
322321 provider = provend .name ,
323322 endpoint = provend .endpoint ,
324323 )
@@ -334,6 +333,21 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
334333 continue
335334 await try_initialize_provider_endpoints (provend , pimpl , db_writer )
336335
336+ provcrud = ProviderCrud ()
337+
338+ endpoints = await provcrud .list_endpoints ()
339+ for endpoint in endpoints :
340+ dbprovend = await db_reader .get_provider_endpoint_by_name (endpoint .name )
341+ pimpl = endpoint .get_from_registry (preg )
342+ if pimpl is None :
343+ logger .warning (
344+ "Provider not found in registry" ,
345+ provider = endpoint .name ,
346+ endpoint = endpoint .endpoint ,
347+ )
348+ continue
349+ await try_update_to_provider (provcrud , pimpl , dbprovend )
350+
337351
338352async def try_initialize_provider_endpoints (
339353 provend : apimodelsv1 .ProviderEndpoint ,
@@ -376,6 +390,30 @@ async def try_initialize_provider_endpoints(
376390 await asyncio .gather (* tasks )
377391
378392
393+ async def try_update_to_provider (
394+ provcrud : ProviderCrud , prov : BaseProvider , dbprovend : dbmodels .ProviderEndpoint
395+ ):
396+
397+ authm = await provcrud ._db_reader .get_auth_material_by_provider_id (str (dbprovend .id ))
398+
399+ try :
400+ models = await provcrud ._find_models_for_provider (
401+ dbprovend .endpoint , authm .auth_type , authm .auth_blob , prov
402+ )
403+ except Exception as err :
404+ logger .error (
405+ "Unable to get models from provider. Skipping" ,
406+ provider = dbprovend .name ,
407+ err = str (err ),
408+ )
409+ return
410+
411+ await provcrud ._update_models_for_provider (dbprovend , models )
412+
413+ # a model might have been deleted, let's repopulate the cache
414+ await provcrud ._ws_crud .repopulate_mux_cache ()
415+
416+
379417def __provider_endpoint_from_cfg (
380418 provider_name : str , provider_url : str
381419) -> Optional [apimodelsv1 .ProviderEndpoint ]:
0 commit comments