1818
1919import json
2020import os
21+ from pathlib import Path
2122import uuid
2223from functools import wraps
2324from typing import Any , Dict , List , Literal , Optional
2425
2526import aiohttp
2627import volcenginesdkid
2728import volcenginesdkcore
29+ import volcenginesdksts
2830
31+ from veadk .consts import VEFAAS_IAM_CRIDENTIAL_PATH
2932from veadk .integrations .ve_identity .models import (
33+ AssumeRoleCredential ,
3034 DCRRegistrationRequest ,
3135 DCRRegistrationResponse ,
3236 OAuth2TokenResponse ,
3337 WorkloadToken ,
3438)
3539from veadk .auth .veauth .utils import get_credential_from_vefaas_iam
40+ from veadk .config import settings
3641
3742from veadk .utils .logger import get_logger
3843
@@ -77,6 +82,23 @@ def _refresh_creds(self: IdentityClient):
7782 except Exception as e :
7883 logger .warning (f"Failed to retrieve credentials from VeFaaS IAM: { e } " )
7984
85+ if not session_token and ak and sk :
86+ role_trn = self ._get_iam_role_trn_from_vefaas_iam ()
87+ if not role_trn :
88+ role_trn = os .getenv ("RUNTIME_IAM_ROLE_TRN" , "" )
89+ # If there is no session_token and role_trn is configured, execute AssumeRole
90+ if role_trn :
91+ try :
92+ logger .info (
93+ f"No session token found, attempting AssumeRole with role: { role_trn } "
94+ )
95+ sts_credentials = self ._assume_role (ak , sk , role_trn )
96+ ak = sts_credentials .access_key_id
97+ sk = sts_credentials .secret_access_key
98+ session_token = sts_credentials .session_token
99+ except Exception as e :
100+ logger .warning (f"Failed to assume role: { e } " )
101+
80102 # Update configuration with the credentials
81103 self ._api_client .api_client .configuration .ak = ak
82104 self ._api_client .api_client .configuration .sk = sk
@@ -128,6 +150,7 @@ def __init__(
128150 KeyError: If required environment variables are not set.
129151 """
130152 self .region = region
153+
131154 # Store initial credentials for fallback
132155 self ._initial_access_key = access_key or os .getenv ("VOLCENGINE_ACCESS_KEY" , "" )
133156 self ._initial_secret_key = secret_key or os .getenv ("VOLCENGINE_SECRET_KEY" , "" )
@@ -146,6 +169,138 @@ def __init__(
146169 volcenginesdkcore .ApiClient (configuration )
147170 )
148171
172+ # STS credential cache
173+ self ._cached_sts_credential : Optional [AssumeRoleCredential ] = None
174+ self ._sts_credential_expires_at : Optional [int ] = None
175+
176+ def _get_iam_role_trn_from_vefaas_iam (self ) -> Optional [str ]:
177+ logger .info (
178+ f"Try to get IAM Role TRN from VeFaaS IAM file (path={ VEFAAS_IAM_CRIDENTIAL_PATH } )."
179+ )
180+
181+ path = Path (VEFAAS_IAM_CRIDENTIAL_PATH )
182+
183+ if not path .exists ():
184+ logger .error (
185+ f"Get IAM Role TRN from IAM file failed, and VeFaaS IAM file (path={ VEFAAS_IAM_CRIDENTIAL_PATH } ) not exists. Please check your configuration."
186+ )
187+ return None
188+
189+ with open (VEFAAS_IAM_CRIDENTIAL_PATH , "r" ) as f :
190+ cred_dict = json .load (f )
191+ role_trn = cred_dict ["role_trn" ]
192+
193+ logger .info ("Get IAM Role TRN from IAM file successfully." )
194+
195+ return role_trn
196+
197+ def _is_sts_credential_expired (self ) -> bool :
198+ """Check if cached STS credential is expired or will expire soon.
199+
200+ Returns:
201+ True if credential is expired or will expire within 5 minutes, False otherwise.
202+ """
203+ if self ._sts_credential_expires_at is None :
204+ return True
205+
206+ import time
207+
208+ current_time = int (time .time ())
209+ # Refresh 5 minutes in advance to avoid expiration during use.
210+ buffer_seconds = 300
211+ return current_time >= (self ._sts_credential_expires_at - buffer_seconds )
212+
213+ def _assume_role (
214+ self , access_key : str , secret_key : str , role_trn : str
215+ ) -> AssumeRoleCredential :
216+ """Execute AssumeRole to get STS temporary credentials.
217+
218+ Args:
219+ access_key: VolcEngine access key
220+ secret_key: VolcEngine secret key
221+ role_trn: The role TRN to assume
222+
223+ Returns:
224+ AssumeRoleCredential containing temporary credentials
225+
226+ Raises:
227+ Exception: If AssumeRole fails
228+ """
229+ # Check if the cached credentials are still valid
230+ if (
231+ self ._cached_sts_credential is not None
232+ and not self ._is_sts_credential_expired ()
233+ ):
234+ logger .info ("Using cached STS credentials" )
235+ return self ._cached_sts_credential
236+
237+ logger .info (
238+ "Cached STS credentials expired or not found, requesting new credentials..."
239+ )
240+
241+ # Create STS client configuration
242+ sts_config = volcenginesdkcore .Configuration ()
243+ sts_config .region = self .region
244+ sts_config .ak = access_key
245+ sts_config .sk = secret_key
246+
247+ # Create an STS API client
248+ sts_client = volcenginesdksts .STSApi (volcenginesdkcore .ApiClient (sts_config ))
249+
250+ # Construct an AssumeRole request
251+ assume_role_request = volcenginesdksts .AssumeRoleRequest (
252+ role_trn = role_trn ,
253+ role_session_name = settings .veidentity .role_session_name ,
254+ )
255+
256+ logger .info (
257+ f"Executing AssumeRole for role: { role_trn } , "
258+ f"session: { settings .veidentity .role_session_name } "
259+ )
260+
261+ response : volcenginesdksts .AssumeRoleResponse = sts_client .assume_role (
262+ assume_role_request
263+ )
264+
265+ if not response .credentials :
266+ raise Exception ("AssumeRole returned no credentials" )
267+
268+ credentials = response .credentials
269+
270+ # Parse expiration time
271+ from datetime import datetime
272+ import calendar
273+
274+ try :
275+ # ExpiredTime format: "2021-04-12T11:57:09+08:00"
276+ dt = datetime .strptime (
277+ credentials .expired_time .replace ("+08:00" , "" ), "%Y-%m-%dT%H:%M:%S"
278+ )
279+ expires_at_timestamp = calendar .timegm (dt .timetuple ())
280+ except Exception as e :
281+ logger .warning (f"Failed to parse STS credential expiration time: { e } " )
282+ # Expires in 1 hour by default
283+ import time
284+
285+ expires_at_timestamp = int (time .time ()) + 3600
286+
287+ sts_credential = AssumeRoleCredential (
288+ access_key_id = credentials .access_key_id ,
289+ secret_access_key = credentials .secret_access_key ,
290+ session_token = credentials .session_token ,
291+ )
292+
293+ # Cached credentials and expiration time
294+ self ._cached_sts_credential = sts_credential
295+ self ._sts_credential_expires_at = expires_at_timestamp
296+
297+ logger .info (
298+ f"Successfully obtained and cached STS credentials, "
299+ f"expires at { datetime .fromtimestamp (expires_at_timestamp ).isoformat ()} "
300+ )
301+
302+ return sts_credential
303+
149304 @refresh_credentials
150305 def create_oauth2_credential_provider (
151306 self , request_params : Dict [str , Any ]
@@ -185,7 +340,7 @@ def create_api_key_credential_provider(
185340 @refresh_credentials
186341 def get_workload_access_token (
187342 self ,
188- workload_name : str ,
343+ workload_name : Optional [ str ] = None ,
189344 user_token : Optional [str ] = None ,
190345 user_id : Optional [str ] = None ,
191346 ) -> WorkloadToken :
@@ -533,3 +688,50 @@ async def create_oauth2_credential_provider_with_dcr(
533688
534689 # Create the credential provider with updated config
535690 return self .create_oauth2_credential_provider (request_params )
691+
692+ @refresh_credentials
693+ def check_permission (
694+ self ,
695+ principal : Dict [str , str ],
696+ operation : Dict [str , str ],
697+ resource : Dict [str , str ],
698+ namespace : str = "default" ,
699+ ) -> bool :
700+ """Check if the principal has permission to perform the operation on the resource.
701+
702+ Args:
703+ principal: Principal information, e.g., {"Type": "User", "Id": "user123"}
704+ operation: Operation to check, e.g., {"Type": "Action", "Id": "invoke"}
705+ resource: Resource information, e.g., {"Type": "Agent", "Id": "agent456"}
706+ namespace: Namespace of the resource. Defaults to "default".
707+
708+ Returns:
709+ True if the principal has permission, False otherwise.
710+
711+ Raises:
712+ ValueError: If input parameters are invalid
713+ RuntimeError: If the permission check API call fails
714+ """
715+ logger .info (
716+ f"Checking permission for principal { principal ['Id' ]} on resource { resource ['Id' ]} for operation { operation ['Id' ]} ..."
717+ )
718+
719+ request = volcenginesdkid .CheckPermissionRequest (
720+ namespace_name = namespace ,
721+ operation = operation ,
722+ principal = principal ,
723+ resource = resource ,
724+ )
725+
726+ response : volcenginesdkid .CheckPermissionResponse = (
727+ self ._api_client .check_permission (request )
728+ )
729+
730+ if not hasattr (response , "allowed" ):
731+ logger .error ("Permission check failed" )
732+ return False
733+
734+ logger .info (
735+ f"Permission check result for principal { principal ['Id' ]} on resource { resource ['Id' ]} : { response .allowed } "
736+ )
737+ return response .allowed
0 commit comments