Skip to content

Commit c867740

Browse files
authored
Merge pull request BerriAI#20481 from Harshit28j/litellm_aws_rotation_fix
Fix authorization issues, same alias; verified working
2 parents af3acdd + ac4bd34 commit c867740

File tree

4 files changed

+270
-38
lines changed

4 files changed

+270
-38
lines changed

litellm/proxy/hooks/key_management_event_hooks.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,26 @@ async def async_key_rotated_hook(
150150
existing_key_row.key_alias
151151
or f"virtual-key-{existing_key_row.token}"
152152
)
153+
new_secret_name = (
154+
response.key_alias
155+
or data.key_alias
156+
or f"virtual-key-{response.token_id}"
157+
)
158+
verbose_proxy_logger.info(
159+
"Updating secret in secret manager: secret_name=%s",
160+
new_secret_name,
161+
)
153162
team_id = getattr(existing_key_row, "team_id", None)
154163
await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager(
155164
current_secret_name=initial_secret_name,
156-
new_secret_name=response.key_alias
157-
or data.key_alias
158-
or f"virtual-key-{response.token_id}",
165+
new_secret_name=new_secret_name,
159166
new_secret_value=response.key,
160167
team_id=team_id,
161168
)
169+
verbose_proxy_logger.info(
170+
"Secret updated in secret manager: secret_name=%s",
171+
new_secret_name,
172+
)
162173
except Exception as e:
163174
verbose_proxy_logger.warning(
164175
f"Failed to rotate virtual key in secret manager: {e}"

litellm/proxy/management_endpoints/key_management_endpoints.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2770,6 +2770,7 @@ async def can_modify_verification_token(
27702770
27712771
Rules:
27722772
- Proxy admin can modify any key
2773+
- Internal jobs service account can modify any key (for auto-rotation)
27732774
- For team keys: only team admin or key owner can modify
27742775
- For personal keys: only key owner can modify
27752776
@@ -2782,13 +2783,19 @@ async def can_modify_verification_token(
27822783
Returns:
27832784
True if user can modify the key, False otherwise
27842785
"""
2786+
from litellm.constants import LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME
2787+
27852788
is_team_key = _is_team_key(data=key_info)
27862789

27872790
# 1. Proxy admin can modify any key
27882791
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
27892792
return True
27902793

2791-
# 2. For team keys: only team admin or key owner can modify
2794+
# 2. Internal jobs service account can modify any key (for auto-rotation)
2795+
if user_api_key_dict.api_key == LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME:
2796+
return True
2797+
2798+
# 3. For team keys: only team admin or key owner can modify
27922799
if is_team_key and key_info.team_id is not None:
27932800
# Get team object to check if user is team admin
27942801
team_table = await get_team_object(
@@ -2818,7 +2825,7 @@ async def can_modify_verification_token(
28182825
# Not team admin and doesn't own the key
28192826
return False
28202827

2821-
# 3. For personal keys: only key owner can modify
2828+
# 4. For personal keys: only key owner can modify
28222829
if key_info.user_id is not None and key_info.user_id == user_api_key_dict.user_id:
28232830
return True
28242831

@@ -3179,7 +3186,7 @@ def get_new_token(data: Optional[RegenerateKeyRequest]) -> str:
31793186
dependencies=[Depends(user_api_key_auth)],
31803187
)
31813188
@management_endpoint_wrapper
3182-
async def regenerate_key_fn(
3189+
async def regenerate_key_fn( # noqa: PLR0915
31833190
key: Optional[str] = None,
31843191
data: Optional[RegenerateKeyRequest] = None,
31853192
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
@@ -3330,6 +3337,10 @@ async def regenerate_key_fn(
33303337
detail={"error": "You are not authorized to regenerate this key"},
33313338
)
33323339

3340+
verbose_proxy_logger.info(
3341+
"Key regeneration requested: key_alias=%s",
3342+
getattr(_key_in_db, "key_alias", None),
3343+
)
33333344
verbose_proxy_logger.debug("key_in_db: %s", _key_in_db)
33343345

33353346
new_token = get_new_token(data=data)
@@ -3380,6 +3391,10 @@ async def regenerate_key_fn(
33803391
**updated_token_dict,
33813392
)
33823393

3394+
verbose_proxy_logger.info(
3395+
"Key regeneration completed: key_alias=%s",
3396+
getattr(_key_in_db, "key_alias", None),
3397+
)
33833398
asyncio.create_task(
33843399
KeyManagementEventHooks.async_key_rotated_hook(
33853400
data=data,

litellm/secret_managers/aws_secret_manager_v2.py

Lines changed: 129 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
44
Handles Async Operations for:
55
- Read Secret
6-
- Write Secret
6+
- Write Secret (CreateSecret)
7+
- Update Secret (PutSecretValue) - for in-place rotation when alias is preserved
78
- Delete Secret
89
910
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
@@ -42,11 +43,11 @@ def __init__(
4243
aws_profile_name: Optional[str] = None,
4344
aws_web_identity_token: Optional[str] = None,
4445
aws_sts_endpoint: Optional[str] = None,
45-
**kwargs
46+
**kwargs,
4647
):
4748
BaseSecretManager.__init__(self, **kwargs)
4849
BaseAWSLLM.__init__(self, **kwargs)
49-
50+
5051
# Store AWS authentication settings
5152
self.aws_region_name = aws_region_name
5253
self.aws_role_name = aws_role_name
@@ -61,7 +62,7 @@ def validate_environment(cls):
6162
# AWS_REGION_NAME is only strictly required if not using a profile or role
6263
# When using IAM roles, the region can come from multiple sources
6364
if (
64-
"AWS_REGION_NAME" not in os.environ
65+
"AWS_REGION_NAME" not in os.environ
6566
and "AWS_REGION" not in os.environ
6667
and "AWS_DEFAULT_REGION" not in os.environ
6768
):
@@ -83,22 +84,36 @@ def load_aws_secret_manager(
8384
return
8485
try:
8586
cls.validate_environment()
86-
87+
8788
# Extract AWS settings from key_management_settings if provided
8889
aws_kwargs = {}
8990
if key_management_settings is not None:
9091
aws_kwargs = {
91-
"aws_region_name": getattr(key_management_settings, "aws_region_name", None),
92-
"aws_role_name": getattr(key_management_settings, "aws_role_name", None),
93-
"aws_session_name": getattr(key_management_settings, "aws_session_name", None),
94-
"aws_external_id": getattr(key_management_settings, "aws_external_id", None),
95-
"aws_profile_name": getattr(key_management_settings, "aws_profile_name", None),
96-
"aws_web_identity_token": getattr(key_management_settings, "aws_web_identity_token", None),
97-
"aws_sts_endpoint": getattr(key_management_settings, "aws_sts_endpoint", None),
92+
"aws_region_name": getattr(
93+
key_management_settings, "aws_region_name", None
94+
),
95+
"aws_role_name": getattr(
96+
key_management_settings, "aws_role_name", None
97+
),
98+
"aws_session_name": getattr(
99+
key_management_settings, "aws_session_name", None
100+
),
101+
"aws_external_id": getattr(
102+
key_management_settings, "aws_external_id", None
103+
),
104+
"aws_profile_name": getattr(
105+
key_management_settings, "aws_profile_name", None
106+
),
107+
"aws_web_identity_token": getattr(
108+
key_management_settings, "aws_web_identity_token", None
109+
),
110+
"aws_sts_endpoint": getattr(
111+
key_management_settings, "aws_sts_endpoint", None
112+
),
98113
}
99114
# Remove None values
100115
aws_kwargs = {k: v for k, v in aws_kwargs.items() if v is not None}
101-
116+
102117
litellm.secret_manager_client = cls(**aws_kwargs)
103118
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
104119

@@ -246,13 +261,13 @@ async def async_read_secret_from_primary_secret(
246261
return primary_secret_kv_pairs.get(secret_name)
247262

248263
async def async_write_secret(
249-
self,
250-
secret_name: str,
251-
secret_value: str,
252-
description: Optional[str] = None,
253-
optional_params: Optional[dict] = None,
254-
timeout: Optional[Union[float, httpx.Timeout]] = None,
255-
tags: Optional[Union[dict, list]] = None
264+
self,
265+
secret_name: str,
266+
secret_value: str,
267+
description: Optional[str] = None,
268+
optional_params: Optional[dict] = None,
269+
timeout: Optional[Union[float, httpx.Timeout]] = None,
270+
tags: Optional[Union[dict, list]] = None,
256271
) -> dict:
257272
"""
258273
Async function to write a secret to AWS Secrets Manager
@@ -312,6 +327,94 @@ async def async_write_secret(
312327
except httpx.TimeoutException:
313328
raise ValueError("Timeout error occurred")
314329

330+
async def async_put_secret_value(
331+
self,
332+
secret_name: str,
333+
secret_value: str,
334+
optional_params: Optional[dict] = None,
335+
timeout: Optional[Union[float, httpx.Timeout]] = None,
336+
) -> dict:
337+
"""
338+
Async function to update an existing secret's value in AWS Secrets Manager.
339+
340+
Uses PutSecretValue to update in place. Use this when rotating a secret
341+
that keeps the same name (current_secret_name == new_secret_name).
342+
343+
Args:
344+
secret_name: Name of the existing secret to update
345+
secret_value: New value to store
346+
optional_params: Additional AWS parameters
347+
timeout: Request timeout
348+
349+
Returns:
350+
dict: Response from AWS Secrets Manager containing update details
351+
"""
352+
from litellm._uuid import uuid
353+
354+
data: Dict[str, Any] = {
355+
"SecretId": secret_name,
356+
"SecretString": secret_value,
357+
"ClientRequestToken": str(uuid.uuid4()),
358+
}
359+
360+
endpoint_url, headers, body = self._prepare_request(
361+
action="PutSecretValue",
362+
secret_name=secret_name,
363+
secret_value=secret_value,
364+
optional_params=optional_params,
365+
request_data=data,
366+
)
367+
368+
async_client = get_async_httpx_client(
369+
llm_provider=httpxSpecialProvider.SecretManager,
370+
params={"timeout": timeout},
371+
)
372+
373+
try:
374+
response = await async_client.post(
375+
url=endpoint_url, headers=headers, data=body.decode("utf-8")
376+
)
377+
response.raise_for_status()
378+
return response.json()
379+
except httpx.HTTPStatusError as err:
380+
raise ValueError(f"HTTP error occurred: {err.response.text}")
381+
except httpx.TimeoutException:
382+
raise ValueError("Timeout error occurred")
383+
384+
async def async_rotate_secret(
385+
self,
386+
current_secret_name: str,
387+
new_secret_name: str,
388+
new_secret_value: str,
389+
optional_params: Optional[dict] = None,
390+
timeout: Optional[Union[float, httpx.Timeout]] = None,
391+
) -> dict:
392+
"""
393+
Rotate a secret. When current_secret_name == new_secret_name (in-place
394+
update), uses PutSecretValue instead of create+delete to avoid
395+
ResourceExistsException.
396+
"""
397+
if current_secret_name == new_secret_name:
398+
# Same alias: update in place via PutSecretValue
399+
verbose_logger.info(
400+
"Secret rotated in-place (PutSecretValue): secret_name=%s",
401+
current_secret_name,
402+
)
403+
return await self.async_put_secret_value(
404+
secret_name=current_secret_name,
405+
secret_value=new_secret_value,
406+
optional_params=optional_params,
407+
timeout=timeout,
408+
)
409+
# Different names: create new, delete old (base class logic)
410+
return await super().async_rotate_secret(
411+
current_secret_name=current_secret_name,
412+
new_secret_name=new_secret_name,
413+
new_secret_value=new_secret_value,
414+
optional_params=optional_params,
415+
timeout=timeout,
416+
)
417+
315418
async def async_delete_secret(
316419
self,
317420
secret_name: str,
@@ -375,7 +478,7 @@ def _prepare_request(
375478
except ImportError:
376479
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
377480
optional_params = optional_params or {}
378-
481+
379482
# Build optional_params from instance settings if not provided
380483
# This allows the IAM role settings to be used for Secret Manager calls
381484
if not optional_params.get("aws_role_name") and self.aws_role_name:
@@ -388,11 +491,14 @@ def _prepare_request(
388491
optional_params["aws_external_id"] = self.aws_external_id
389492
if not optional_params.get("aws_profile_name") and self.aws_profile_name:
390493
optional_params["aws_profile_name"] = self.aws_profile_name
391-
if not optional_params.get("aws_web_identity_token") and self.aws_web_identity_token:
494+
if (
495+
not optional_params.get("aws_web_identity_token")
496+
and self.aws_web_identity_token
497+
):
392498
optional_params["aws_web_identity_token"] = self.aws_web_identity_token
393499
if not optional_params.get("aws_sts_endpoint") and self.aws_sts_endpoint:
394500
optional_params["aws_sts_endpoint"] = self.aws_sts_endpoint
395-
501+
396502
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
397503
optional_params
398504
)
@@ -431,12 +537,3 @@ def _prepare_request(
431537
prepped = request.prepare()
432538

433539
return endpoint_url, prepped.headers, body
434-
435-
436-
# if __name__ == "__main__":
437-
# print("loading aws secret manager v2")
438-
# aws_secret_manager_v2 = AWSSecretsManagerV2()
439-
# import asyncio
440-
# print("writing secret to aws secret manager v2")
441-
# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2"))
442-
# print("reading secret from aws secret manager v2")

0 commit comments

Comments
 (0)