Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit f30e07f

Browse files
Fix logic for updating provider details and provider auth type (#914)
We were trying to update the provider models when updating the provider details. Provider details do not have relevant information to update the models. Change the logic to update the provider models when updating the provider authentication details
1 parent 1bba21b commit f30e07f

File tree

2 files changed

+36
-43
lines changed

2 files changed

+36
-43
lines changed

src/codegate/api/v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ async def configure_auth_material(
161161
)
162162
async def update_provider_endpoint(
163163
provider_id: UUID,
164-
request: v1_models.AddProviderEndpointRequest,
164+
request: v1_models.ProviderEndpoint,
165165
) -> v1_models.ProviderEndpoint:
166166
"""Update a provider endpoint by ID."""
167167
try:

src/codegate/providers/crud/crud.py

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ async def add_endpoint(
114114
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
115115

116116
async def update_endpoint(
117-
self, endpoint: apimodelsv1.AddProviderEndpointRequest
117+
self, endpoint: apimodelsv1.ProviderEndpoint
118118
) -> apimodelsv1.ProviderEndpoint:
119119
"""Update an endpoint."""
120120

@@ -134,12 +134,40 @@ async def update_endpoint(
134134
if founddbe is None:
135135
raise ProviderNotFoundError("Provider not found")
136136

137-
models = []
138-
if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key:
137+
dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())
138+
139+
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
140+
141+
async def configure_auth_material(
142+
self, provider_id: UUID, config: apimodelsv1.ConfigureAuthMaterial
143+
):
144+
"""Add an API key."""
145+
if config.auth_type == apimodelsv1.ProviderAuthType.api_key and not config.api_key:
139146
raise ValueError("API key must be provided for API auth type")
140-
if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough:
147+
elif config.auth_type != apimodelsv1.ProviderAuthType.api_key and config.api_key:
148+
raise ValueError("API key provided for non-API auth type")
149+
150+
dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id))
151+
if dbendpoint is None:
152+
raise ProviderNotFoundError("Provider not found")
153+
154+
await self._db_writer.push_provider_auth_material(
155+
dbmodels.ProviderAuthMaterial(
156+
provider_endpoint_id=dbendpoint.id,
157+
auth_type=config.auth_type,
158+
auth_blob=config.api_key if config.api_key else "",
159+
)
160+
)
161+
162+
endpoint = apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
163+
endpoint.auth_type = config.auth_type
164+
provider_registry = get_provider_registry()
165+
prov = endpoint.get_from_registry(provider_registry)
166+
167+
models = []
168+
if config.auth_type != apimodelsv1.ProviderAuthType.passthrough:
141169
try:
142-
models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key)
170+
models = prov.models(endpoint=endpoint.endpoint, api_key=config.api_key)
143171
except Exception as err:
144172
raise ValueError("Unable to get models from provider: {}".format(str(err)))
145173

@@ -154,56 +182,21 @@ async def update_endpoint(
154182
for model in models_set - models_in_db_set:
155183
await self._db_writer.add_provider_model(
156184
dbmodels.ProviderModel(
157-
provider_endpoint_id=founddbe.id,
185+
provider_endpoint_id=dbendpoint.id,
158186
name=model,
159187
)
160188
)
161189

162190
# Remove the models that are in the DB but not in the provider
163191
for model in models_in_db_set - models_set:
164192
await self._db_writer.delete_provider_model(
165-
founddbe.id,
193+
dbendpoint.id,
166194
model,
167195
)
168196

169-
dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())
170-
171-
# If an API key was provided or we've changed the auth type, we update the auth material
172-
if endpoint.auth_type != founddbe.auth_type or endpoint.api_key:
173-
await self._db_writer.push_provider_auth_material(
174-
dbmodels.ProviderAuthMaterial(
175-
provider_endpoint_id=dbendpoint.id,
176-
auth_type=endpoint.auth_type,
177-
auth_blob=endpoint.api_key if endpoint.api_key else "",
178-
)
179-
)
180-
181197
# a model might have been deleted, let's repopulate the cache
182198
await self._ws_crud.repopulate_mux_cache()
183199

184-
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
185-
186-
async def configure_auth_material(
187-
self, provider_id: UUID, config: apimodelsv1.ConfigureAuthMaterial
188-
):
189-
"""Add an API key."""
190-
if config.auth_type == apimodelsv1.ProviderAuthType.api_key and not config.api_key:
191-
raise ValueError("API key must be provided for API auth type")
192-
elif config.auth_type != apimodelsv1.ProviderAuthType.api_key and config.api_key:
193-
raise ValueError("API key provided for non-API auth type")
194-
195-
dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id))
196-
if dbendpoint is None:
197-
raise ProviderNotFoundError("Provider not found")
198-
199-
await self._db_writer.push_provider_auth_material(
200-
dbmodels.ProviderAuthMaterial(
201-
provider_endpoint_id=dbendpoint.id,
202-
auth_type=config.auth_type,
203-
auth_blob=config.api_key if config.api_key else "",
204-
)
205-
)
206-
207200
async def delete_endpoint(self, provider_id: UUID):
208201
"""Delete an endpoint."""
209202

0 commit comments

Comments
 (0)