@@ -114,7 +114,7 @@ async def add_endpoint(
114
114
return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
115
115
116
116
async def update_endpoint (
117
- self , endpoint : apimodelsv1 .AddProviderEndpointRequest
117
+ self , endpoint : apimodelsv1 .ProviderEndpoint
118
118
) -> apimodelsv1 .ProviderEndpoint :
119
119
"""Update an endpoint."""
120
120
@@ -134,12 +134,40 @@ async def update_endpoint(
134
134
if founddbe is None :
135
135
raise ProviderNotFoundError ("Provider not found" )
136
136
137
- models = []
138
- if endpoint .auth_type == apimodelsv1 .ProviderAuthType .api_key and not endpoint .api_key :
137
+ dbendpoint = await self ._db_writer .update_provider_endpoint (endpoint .to_db_model ())
138
+
139
+ return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
140
+
141
+ async def configure_auth_material (
142
+ self , provider_id : UUID , config : apimodelsv1 .ConfigureAuthMaterial
143
+ ):
144
+ """Add an API key."""
145
+ if config .auth_type == apimodelsv1 .ProviderAuthType .api_key and not config .api_key :
139
146
raise ValueError ("API key must be provided for API auth type" )
140
- if endpoint .auth_type != apimodelsv1 .ProviderAuthType .passthrough :
147
+ elif config .auth_type != apimodelsv1 .ProviderAuthType .api_key and config .api_key :
148
+ raise ValueError ("API key provided for non-API auth type" )
149
+
150
+ dbendpoint = await self ._db_reader .get_provider_endpoint_by_id (str (provider_id ))
151
+ if dbendpoint is None :
152
+ raise ProviderNotFoundError ("Provider not found" )
153
+
154
+ await self ._db_writer .push_provider_auth_material (
155
+ dbmodels .ProviderAuthMaterial (
156
+ provider_endpoint_id = dbendpoint .id ,
157
+ auth_type = config .auth_type ,
158
+ auth_blob = config .api_key if config .api_key else "" ,
159
+ )
160
+ )
161
+
162
+ endpoint = apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
163
+ endpoint .auth_type = config .auth_type
164
+ provider_registry = get_provider_registry ()
165
+ prov = endpoint .get_from_registry (provider_registry )
166
+
167
+ models = []
168
+ if config .auth_type != apimodelsv1 .ProviderAuthType .passthrough :
141
169
try :
142
- models = prov .models (endpoint = endpoint .endpoint , api_key = endpoint .api_key )
170
+ models = prov .models (endpoint = endpoint .endpoint , api_key = config .api_key )
143
171
except Exception as err :
144
172
raise ValueError ("Unable to get models from provider: {}" .format (str (err )))
145
173
@@ -154,56 +182,21 @@ async def update_endpoint(
154
182
for model in models_set - models_in_db_set :
155
183
await self ._db_writer .add_provider_model (
156
184
dbmodels .ProviderModel (
157
- provider_endpoint_id = founddbe .id ,
185
+ provider_endpoint_id = dbendpoint .id ,
158
186
name = model ,
159
187
)
160
188
)
161
189
162
190
# Remove the models that are in the DB but not in the provider
163
191
for model in models_in_db_set - models_set :
164
192
await self ._db_writer .delete_provider_model (
165
- founddbe .id ,
193
+ dbendpoint .id ,
166
194
model ,
167
195
)
168
196
169
- dbendpoint = await self ._db_writer .update_provider_endpoint (endpoint .to_db_model ())
170
-
171
- # If an API key was provided or we've changed the auth type, we update the auth material
172
- if endpoint .auth_type != founddbe .auth_type or endpoint .api_key :
173
- await self ._db_writer .push_provider_auth_material (
174
- dbmodels .ProviderAuthMaterial (
175
- provider_endpoint_id = dbendpoint .id ,
176
- auth_type = endpoint .auth_type ,
177
- auth_blob = endpoint .api_key if endpoint .api_key else "" ,
178
- )
179
- )
180
-
181
197
# a model might have been deleted, let's repopulate the cache
182
198
await self ._ws_crud .repopulate_mux_cache ()
183
199
184
- return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
185
-
186
- async def configure_auth_material (
187
- self , provider_id : UUID , config : apimodelsv1 .ConfigureAuthMaterial
188
- ):
189
- """Add an API key."""
190
- if config .auth_type == apimodelsv1 .ProviderAuthType .api_key and not config .api_key :
191
- raise ValueError ("API key must be provided for API auth type" )
192
- elif config .auth_type != apimodelsv1 .ProviderAuthType .api_key and config .api_key :
193
- raise ValueError ("API key provided for non-API auth type" )
194
-
195
- dbendpoint = await self ._db_reader .get_provider_endpoint_by_id (str (provider_id ))
196
- if dbendpoint is None :
197
- raise ProviderNotFoundError ("Provider not found" )
198
-
199
- await self ._db_writer .push_provider_auth_material (
200
- dbmodels .ProviderAuthMaterial (
201
- provider_endpoint_id = dbendpoint .id ,
202
- auth_type = config .auth_type ,
203
- auth_blob = config .api_key if config .api_key else "" ,
204
- )
205
- )
206
-
207
200
async def delete_endpoint (self , provider_id : UUID ):
208
201
"""Delete an endpoint."""
209
202
0 commit comments