Skip to content

Commit bca1468

Browse files
committed
fix: fix comment issues and add getting role_trn from VEFAAS_IAM_CRIDENTIAL_PATH
1 parent 8f674d1 commit bca1468

File tree

4 files changed

+215
-83
lines changed

4 files changed

+215
-83
lines changed

veadk/configs/auth_configs.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,7 @@ class VeIdentityConfig(BaseSettings):
3939
If not provided, the endpoint will be auto-generated based on the region.
4040
"""
4141

42-
role_trn: str = ""
43-
"""The TRN of the role, in the format: trn:iam::${AccountId}:role/${RoleName}
44-
45-
When accessing with AK/SK (without session_token), this role will be automatically assumed to obtain temporary credentials.
46-
For example: trn:iam::2000012345:role/MyWorkloadRole
47-
"""
48-
49-
role_session_name: str = "veadk_assume_role_session"
42+
role_session_name: str = "veadk_default_assume_role_session"
5043
"""Role session name, used to distinguish different sessions in audit logs.
5144
"""
5245

veadk/integrations/ve_identity/identity_client.py

Lines changed: 136 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import json
2020
import os
21+
from pathlib import Path
2122
import uuid
2223
from functools import wraps
2324
from typing import Any, Dict, List, Literal, Optional
@@ -27,6 +28,7 @@
2728
import volcenginesdkcore
2829
import volcenginesdksts
2930

31+
from veadk.consts import VEFAAS_IAM_CRIDENTIAL_PATH
3032
from veadk.integrations.ve_identity.models import (
3133
AssumeRoleCredential,
3234
DCRRegistrationRequest,
@@ -35,7 +37,7 @@
3537
WorkloadToken,
3638
)
3739
from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
38-
from veadk.configs.auth_configs import VeIdentityConfig
40+
from veadk.config import settings
3941

4042
from veadk.utils.logger import get_logger
4143

@@ -80,19 +82,22 @@ def _refresh_creds(self: IdentityClient):
8082
except Exception as e:
8183
logger.warning(f"Failed to retrieve credentials from VeFaaS IAM: {e}")
8284

83-
# If there is no session_token and role_trn is configured, execute AssumeRole
84-
if not session_token and self._identity_config.role_trn and ak and sk:
85-
try:
86-
logger.info(
87-
f"No session token found, attempting AssumeRole with role: {self._identity_config.role_trn}"
88-
)
89-
sts_credentials = self._assume_role(ak, sk)
90-
ak = sts_credentials.access_key_id
91-
sk = sts_credentials.secret_access_key
92-
session_token = sts_credentials.session_token
93-
logger.info("Successfully assumed role and obtained STS credentials")
94-
except Exception as e:
95-
logger.warning(f"Failed to assume role: {e}")
85+
if not session_token and ak and sk:
86+
role_trn = self._get_iam_role_trn_from_vefaas_iam()
87+
if not role_trn:
88+
role_trn = os.getenv("RUNTIME_IAM_ROLE_TRN", "")
89+
# If there is no session_token and role_trn is configured, execute AssumeRole
90+
if role_trn:
91+
try:
92+
logger.info(
93+
f"No session token found, attempting AssumeRole with role: {role_trn}"
94+
)
95+
sts_credentials = self._assume_role(ak, sk, role_trn)
96+
ak = sts_credentials.access_key_id
97+
sk = sts_credentials.secret_access_key
98+
session_token = sts_credentials.session_token
99+
except Exception as e:
100+
logger.warning(f"Failed to assume role: {e}")
96101

97102
# Update configuration with the credentials
98103
self._api_client.api_client.configuration.ak = ak
@@ -132,7 +137,6 @@ def __init__(
132137
secret_key: Optional[str] = None,
133138
session_token: Optional[str] = None,
134139
region: str = "cn-beijing",
135-
identity_config: Optional[VeIdentityConfig] = None,
136140
):
137141
"""Initialize the identity client.
138142
@@ -146,7 +150,6 @@ def __init__(
146150
KeyError: If required environment variables are not set.
147151
"""
148152
self.region = region
149-
self._identity_config = identity_config or VeIdentityConfig()
150153

151154
# Store initial credentials for fallback
152155
self._initial_access_key = access_key or os.getenv("VOLCENGINE_ACCESS_KEY", "")
@@ -166,19 +169,75 @@ def __init__(
166169
volcenginesdkcore.ApiClient(configuration)
167170
)
168171

169-
def _assume_role(self, access_key: str, secret_key: str) -> AssumeRoleCredential:
172+
# STS credential cache
173+
self._cached_sts_credential: Optional[AssumeRoleCredential] = None
174+
self._sts_credential_expires_at: Optional[int] = None
175+
176+
def _get_iam_role_trn_from_vefaas_iam(self) -> Optional[str]:
177+
logger.info(
178+
f"Try to get IAM Role TRN from VeFaaS IAM file (path={VEFAAS_IAM_CRIDENTIAL_PATH})."
179+
)
180+
181+
path = Path(VEFAAS_IAM_CRIDENTIAL_PATH)
182+
183+
if not path.exists():
184+
logger.error(
185+
f"Get IAM Role TRN from IAM file failed, and VeFaaS IAM file (path={VEFAAS_IAM_CRIDENTIAL_PATH}) not exists. Please check your configuration."
186+
)
187+
return None
188+
189+
with open(VEFAAS_IAM_CRIDENTIAL_PATH, "r") as f:
190+
cred_dict = json.load(f)
191+
role_trn = cred_dict["role_trn"]
192+
193+
logger.info("Get IAM Role TRN from IAM file successfully.")
194+
195+
return role_trn
196+
197+
def _is_sts_credential_expired(self) -> bool:
198+
"""Check if cached STS credential is expired or will expire soon.
199+
200+
Returns:
201+
True if credential is expired or will expire within 5 minutes, False otherwise.
202+
"""
203+
if self._sts_credential_expires_at is None:
204+
return True
205+
206+
import time
207+
208+
current_time = int(time.time())
209+
# Refresh 5 minutes in advance to avoid expiration during use.
210+
buffer_seconds = 300
211+
return current_time >= (self._sts_credential_expires_at - buffer_seconds)
212+
213+
def _assume_role(
214+
self, access_key: str, secret_key: str, role_trn: str
215+
) -> AssumeRoleCredential:
170216
"""Execute AssumeRole to get STS temporary credentials.
171217
172218
Args:
173219
access_key: VolcEngine access key
174220
secret_key: VolcEngine secret key
221+
role_trn: The role TRN to assume
175222
176223
Returns:
177224
AssumeRoleCredential containing temporary credentials
178225
179226
Raises:
180227
Exception: If AssumeRole fails
181228
"""
229+
# Check if the cached credentials are still valid
230+
if (
231+
self._cached_sts_credential is not None
232+
and not self._is_sts_credential_expired()
233+
):
234+
logger.info("Using cached STS credentials")
235+
return self._cached_sts_credential
236+
237+
logger.info(
238+
"Cached STS credentials expired or not found, requesting new credentials..."
239+
)
240+
182241
# Create STS client configuration
183242
sts_config = volcenginesdkcore.Configuration()
184243
sts_config.region = self.region
@@ -190,13 +249,13 @@ def _assume_role(self, access_key: str, secret_key: str) -> AssumeRoleCredential
190249

191250
# Construct an AssumeRole request
192251
assume_role_request = volcenginesdksts.AssumeRoleRequest(
193-
role_trn=self._identity_config.role_trn,
194-
role_session_name=self._identity_config.role_session_name,
252+
role_trn=role_trn,
253+
role_session_name=settings.veidentity.role_session_name,
195254
)
196255

197256
logger.info(
198-
f"Executing AssumeRole for role: {self._identity_config.role_trn}, "
199-
f"session: {self._identity_config.role_session_name}"
257+
f"Executing AssumeRole for role: {role_trn}, "
258+
f"session: {settings.veidentity.role_session_name}"
200259
)
201260

202261
response: volcenginesdksts.AssumeRoleResponse = sts_client.assume_role(
@@ -206,16 +265,42 @@ def _assume_role(self, access_key: str, secret_key: str) -> AssumeRoleCredential
206265
if not response.credentials:
207266
raise Exception("AssumeRole returned no credentials")
208267

209-
access_key = response["access_key_id"]
210-
secret_key = response["secret_access_key"]
211-
session_token = response["session_token"]
268+
credentials = response.credentials
269+
270+
# Parse expiration time
271+
from datetime import datetime
272+
import calendar
273+
274+
try:
275+
# ExpiredTime format: "2021-04-12T11:57:09+08:00"
276+
dt = datetime.strptime(
277+
credentials.expired_time.replace("+08:00", ""), "%Y-%m-%dT%H:%M:%S"
278+
)
279+
expires_at_timestamp = calendar.timegm(dt.timetuple())
280+
except Exception as e:
281+
logger.warning(f"Failed to parse STS credential expiration time: {e}")
282+
# Expires in 1 hour by default
283+
import time
212284

213-
return AssumeRoleCredential(
214-
access_key_id=access_key,
215-
secret_access_key=secret_key,
216-
session_token=session_token,
285+
expires_at_timestamp = int(time.time()) + 3600
286+
287+
sts_credential = AssumeRoleCredential(
288+
access_key_id=credentials.access_key_id,
289+
secret_access_key=credentials.secret_access_key,
290+
session_token=credentials.session_token,
291+
)
292+
293+
# Cached credentials and expiration time
294+
self._cached_sts_credential = sts_credential
295+
self._sts_credential_expires_at = expires_at_timestamp
296+
297+
logger.info(
298+
f"Successfully obtained and cached STS credentials, "
299+
f"expires at {datetime.fromtimestamp(expires_at_timestamp).isoformat()}"
217300
)
218301

302+
return sts_credential
303+
219304
@refresh_credentials
220305
def create_oauth2_credential_provider(
221306
self, request_params: Dict[str, Any]
@@ -255,7 +340,7 @@ def create_api_key_credential_provider(
255340
@refresh_credentials
256341
def get_workload_access_token(
257342
self,
258-
workload_name: str,
343+
workload_name: Optional[str] = None,
259344
user_token: Optional[str] = None,
260345
user_id: Optional[str] = None,
261346
) -> WorkloadToken:
@@ -606,35 +691,47 @@ async def create_oauth2_credential_provider_with_dcr(
606691

607692
@refresh_credentials
608693
def check_permission(
609-
self, principal_id, operation, resource_id, namespace="default"
694+
self,
695+
principal: Dict[str, str],
696+
operation: Dict[str, str],
697+
resource: Dict[str, str],
698+
namespace: str = "default",
610699
) -> bool:
611700
"""Check if the principal has permission to perform the operation on the resource.
612701
613702
Args:
614-
principal_id: The ID of the principal (user or service).
615-
operation: The operation to check permission for.
616-
resource_id: The ID of the resource.
617-
namespace: The namespace of the resource. Defaults to "default".
703+
principal: Principal information, e.g., {"Type": "User", "Id": "user123"}
704+
operation: Operation to check, e.g., {"Type": "Action", "Id": "invoke"}
705+
resource: Resource information, e.g., {"Type": "Agent", "Id": "agent456"}
706+
namespace: Namespace of the resource. Defaults to "default".
618707
619708
Returns:
620709
True if the principal has permission, False otherwise.
710+
711+
Raises:
712+
ValueError: If input parameters are invalid
713+
RuntimeError: If the permission check API call fails
621714
"""
622715
logger.info(
623-
f"Checking permission for principal {principal_id} on resource {resource_id} for operation {operation}..."
716+
f"Checking permission for principal {principal['Id']} on resource {resource['Id']} for operation {operation['Id']}..."
624717
)
625718

626719
request = volcenginesdkid.CheckPermissionRequest(
627-
principal_id=principal_id,
720+
namespace_name=namespace,
628721
operation=operation,
629-
resource_id=resource_id,
630-
namespace=namespace,
722+
principal=principal,
723+
resource=resource,
631724
)
632725

633726
response: volcenginesdkid.CheckPermissionResponse = (
634727
self._api_client.check_permission(request)
635728
)
636729

730+
if not hasattr(response, "allowed"):
731+
logger.error("Permission check failed")
732+
return False
733+
637734
logger.info(
638-
f"Permission check result for principal {principal_id} on resource {resource_id}: {response.allowed}"
735+
f"Permission check result for principal {principal['Id']} on resource {resource['Id']}: {response.allowed}"
639736
)
640737
return response.allowed

veadk/integrations/ve_identity/token_manager.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,6 @@ async def get_workload_token(
132132
f"Cached workload token expired for agent '{tool_context.agent_name}', refreshing..."
133133
)
134134

135-
# Default to agent_name if workload_name not specified
136-
if not workload_name:
137-
workload_name = tool_context.agent_name
138-
139135
# Determine user_id based on authentication mode
140136
user_id = None if user_token else tool_context._invocation_context.user_id
141137

0 commit comments

Comments
 (0)