Skip to content

Commit 56087ef

Browse files
authored
feat: add agent authorization in agent callback (#303)
* feat: add agent authorization in agent callback * fix callback addition * fix: change workload id and using assumeRole * fix: fix docstrings * fix: fix comment issues and add getting role_trn from VEFAAS_IAM_CRIDENTIAL_PATH
1 parent e96cf0b commit 56087ef

File tree

6 files changed

+348
-5
lines changed

6 files changed

+348
-5
lines changed

veadk/agent.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from veadk.tracing.base_tracer import BaseTracer
4949
from veadk.utils.logger import get_logger
5050
from veadk.utils.patches import patch_asyncio, patch_tracer
51+
from veadk.tools.builtin_tools.agent_authorization import check_agent_authorization
5152
from veadk.version import VERSION
5253

5354
patch_tracer()
@@ -123,6 +124,8 @@ class Agent(LlmAgent):
123124
)
124125
"""
125126

127+
enable_authz: bool = False
128+
126129
def model_post_init(self, __context: Any) -> None:
127130
super().model_post_init(None) # for sub_agents init
128131

@@ -184,6 +187,18 @@ def model_post_init(self, __context: Any) -> None:
184187
load_memory.custom_metadata["backend"] = self.long_term_memory.backend
185188
self.tools.append(load_memory)
186189

190+
if self.enable_authz:
191+
if self.before_agent_callback:
192+
if isinstance(self.before_agent_callback, list):
193+
self.before_agent_callback.append(check_agent_authorization)
194+
else:
195+
self.before_agent_callback = [
196+
self.before_agent_callback,
197+
check_agent_authorization,
198+
]
199+
else:
200+
self.before_agent_callback = check_agent_authorization
201+
187202
logger.info(f"VeADK version: {VERSION}")
188203

189204
logger.info(f"{self.__class__.__name__} `{self.name}` init done.")

veadk/configs/auth_configs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ class VeIdentityConfig(BaseSettings):
3939
If not provided, the endpoint will be auto-generated based on the region.
4040
"""
4141

42+
role_session_name: str = "veadk_default_assume_role_session"
43+
"""Role session name, used to distinguish different sessions in audit logs.
44+
"""
45+
4246
def get_endpoint(self) -> str:
4347
"""Get the endpoint URL for Identity service.
4448

veadk/integrations/ve_identity/identity_client.py

Lines changed: 203 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,26 @@
1818

1919
import json
2020
import os
21+
from pathlib import Path
2122
import uuid
2223
from functools import wraps
2324
from typing import Any, Dict, List, Literal, Optional
2425

2526
import aiohttp
2627
import volcenginesdkid
2728
import volcenginesdkcore
29+
import volcenginesdksts
2830

31+
from veadk.consts import VEFAAS_IAM_CRIDENTIAL_PATH
2932
from veadk.integrations.ve_identity.models import (
33+
AssumeRoleCredential,
3034
DCRRegistrationRequest,
3135
DCRRegistrationResponse,
3236
OAuth2TokenResponse,
3337
WorkloadToken,
3438
)
3539
from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
40+
from veadk.config import settings
3641

3742
from veadk.utils.logger import get_logger
3843

@@ -77,6 +82,23 @@ def _refresh_creds(self: IdentityClient):
7782
except Exception as e:
7883
logger.warning(f"Failed to retrieve credentials from VeFaaS IAM: {e}")
7984

85+
if not session_token and ak and sk:
86+
role_trn = self._get_iam_role_trn_from_vefaas_iam()
87+
if not role_trn:
88+
role_trn = os.getenv("RUNTIME_IAM_ROLE_TRN", "")
89+
# If there is no session_token and role_trn is configured, execute AssumeRole
90+
if role_trn:
91+
try:
92+
logger.info(
93+
f"No session token found, attempting AssumeRole with role: {role_trn}"
94+
)
95+
sts_credentials = self._assume_role(ak, sk, role_trn)
96+
ak = sts_credentials.access_key_id
97+
sk = sts_credentials.secret_access_key
98+
session_token = sts_credentials.session_token
99+
except Exception as e:
100+
logger.warning(f"Failed to assume role: {e}")
101+
80102
# Update configuration with the credentials
81103
self._api_client.api_client.configuration.ak = ak
82104
self._api_client.api_client.configuration.sk = sk
@@ -128,6 +150,7 @@ def __init__(
128150
KeyError: If required environment variables are not set.
129151
"""
130152
self.region = region
153+
131154
# Store initial credentials for fallback
132155
self._initial_access_key = access_key or os.getenv("VOLCENGINE_ACCESS_KEY", "")
133156
self._initial_secret_key = secret_key or os.getenv("VOLCENGINE_SECRET_KEY", "")
@@ -146,6 +169,138 @@ def __init__(
146169
volcenginesdkcore.ApiClient(configuration)
147170
)
148171

172+
# STS credential cache
173+
self._cached_sts_credential: Optional[AssumeRoleCredential] = None
174+
self._sts_credential_expires_at: Optional[int] = None
175+
176+
def _get_iam_role_trn_from_vefaas_iam(self) -> Optional[str]:
177+
logger.info(
178+
f"Try to get IAM Role TRN from VeFaaS IAM file (path={VEFAAS_IAM_CRIDENTIAL_PATH})."
179+
)
180+
181+
path = Path(VEFAAS_IAM_CRIDENTIAL_PATH)
182+
183+
if not path.exists():
184+
logger.error(
185+
f"Get IAM Role TRN from IAM file failed, and VeFaaS IAM file (path={VEFAAS_IAM_CRIDENTIAL_PATH}) not exists. Please check your configuration."
186+
)
187+
return None
188+
189+
with open(VEFAAS_IAM_CRIDENTIAL_PATH, "r") as f:
190+
cred_dict = json.load(f)
191+
role_trn = cred_dict["role_trn"]
192+
193+
logger.info("Get IAM Role TRN from IAM file successfully.")
194+
195+
return role_trn
196+
197+
def _is_sts_credential_expired(self) -> bool:
198+
"""Check if cached STS credential is expired or will expire soon.
199+
200+
Returns:
201+
True if credential is expired or will expire within 5 minutes, False otherwise.
202+
"""
203+
if self._sts_credential_expires_at is None:
204+
return True
205+
206+
import time
207+
208+
current_time = int(time.time())
209+
# Refresh 5 minutes in advance to avoid expiration during use.
210+
buffer_seconds = 300
211+
return current_time >= (self._sts_credential_expires_at - buffer_seconds)
212+
213+
def _assume_role(
214+
self, access_key: str, secret_key: str, role_trn: str
215+
) -> AssumeRoleCredential:
216+
"""Execute AssumeRole to get STS temporary credentials.
217+
218+
Args:
219+
access_key: VolcEngine access key
220+
secret_key: VolcEngine secret key
221+
role_trn: The role TRN to assume
222+
223+
Returns:
224+
AssumeRoleCredential containing temporary credentials
225+
226+
Raises:
227+
Exception: If AssumeRole fails
228+
"""
229+
# Check if the cached credentials are still valid
230+
if (
231+
self._cached_sts_credential is not None
232+
and not self._is_sts_credential_expired()
233+
):
234+
logger.info("Using cached STS credentials")
235+
return self._cached_sts_credential
236+
237+
logger.info(
238+
"Cached STS credentials expired or not found, requesting new credentials..."
239+
)
240+
241+
# Create STS client configuration
242+
sts_config = volcenginesdkcore.Configuration()
243+
sts_config.region = self.region
244+
sts_config.ak = access_key
245+
sts_config.sk = secret_key
246+
247+
# Create an STS API client
248+
sts_client = volcenginesdksts.STSApi(volcenginesdkcore.ApiClient(sts_config))
249+
250+
# Construct an AssumeRole request
251+
assume_role_request = volcenginesdksts.AssumeRoleRequest(
252+
role_trn=role_trn,
253+
role_session_name=settings.veidentity.role_session_name,
254+
)
255+
256+
logger.info(
257+
f"Executing AssumeRole for role: {role_trn}, "
258+
f"session: {settings.veidentity.role_session_name}"
259+
)
260+
261+
response: volcenginesdksts.AssumeRoleResponse = sts_client.assume_role(
262+
assume_role_request
263+
)
264+
265+
if not response.credentials:
266+
raise Exception("AssumeRole returned no credentials")
267+
268+
credentials = response.credentials
269+
270+
# Parse expiration time
271+
from datetime import datetime
272+
import calendar
273+
274+
try:
275+
# ExpiredTime format: "2021-04-12T11:57:09+08:00"
276+
dt = datetime.strptime(
277+
credentials.expired_time.replace("+08:00", ""), "%Y-%m-%dT%H:%M:%S"
278+
)
279+
expires_at_timestamp = calendar.timegm(dt.timetuple())
280+
except Exception as e:
281+
logger.warning(f"Failed to parse STS credential expiration time: {e}")
282+
# Expires in 1 hour by default
283+
import time
284+
285+
expires_at_timestamp = int(time.time()) + 3600
286+
287+
sts_credential = AssumeRoleCredential(
288+
access_key_id=credentials.access_key_id,
289+
secret_access_key=credentials.secret_access_key,
290+
session_token=credentials.session_token,
291+
)
292+
293+
# Cached credentials and expiration time
294+
self._cached_sts_credential = sts_credential
295+
self._sts_credential_expires_at = expires_at_timestamp
296+
297+
logger.info(
298+
f"Successfully obtained and cached STS credentials, "
299+
f"expires at {datetime.fromtimestamp(expires_at_timestamp).isoformat()}"
300+
)
301+
302+
return sts_credential
303+
149304
@refresh_credentials
150305
def create_oauth2_credential_provider(
151306
self, request_params: Dict[str, Any]
@@ -185,7 +340,7 @@ def create_api_key_credential_provider(
185340
@refresh_credentials
186341
def get_workload_access_token(
187342
self,
188-
workload_name: str,
343+
workload_name: Optional[str] = None,
189344
user_token: Optional[str] = None,
190345
user_id: Optional[str] = None,
191346
) -> WorkloadToken:
@@ -533,3 +688,50 @@ async def create_oauth2_credential_provider_with_dcr(
533688

534689
# Create the credential provider with updated config
535690
return self.create_oauth2_credential_provider(request_params)
691+
692+
@refresh_credentials
693+
def check_permission(
694+
self,
695+
principal: Dict[str, str],
696+
operation: Dict[str, str],
697+
resource: Dict[str, str],
698+
namespace: str = "default",
699+
) -> bool:
700+
"""Check if the principal has permission to perform the operation on the resource.
701+
702+
Args:
703+
principal: Principal information, e.g., {"Type": "User", "Id": "user123"}
704+
operation: Operation to check, e.g., {"Type": "Action", "Id": "invoke"}
705+
resource: Resource information, e.g., {"Type": "Agent", "Id": "agent456"}
706+
namespace: Namespace of the resource. Defaults to "default".
707+
708+
Returns:
709+
True if the principal has permission, False otherwise.
710+
711+
Raises:
712+
ValueError: If input parameters are invalid
713+
RuntimeError: If the permission check API call fails
714+
"""
715+
logger.info(
716+
f"Checking permission for principal {principal['Id']} on resource {resource['Id']} for operation {operation['Id']}..."
717+
)
718+
719+
request = volcenginesdkid.CheckPermissionRequest(
720+
namespace_name=namespace,
721+
operation=operation,
722+
principal=principal,
723+
resource=resource,
724+
)
725+
726+
response: volcenginesdkid.CheckPermissionResponse = (
727+
self._api_client.check_permission(request)
728+
)
729+
730+
if not hasattr(response, "allowed"):
731+
logger.error("Permission check failed")
732+
return False
733+
734+
logger.info(
735+
f"Permission check result for principal {principal['Id']} on resource {resource['Id']}: {response.allowed}"
736+
)
737+
return response.allowed

veadk/integrations/ve_identity/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,9 @@ def validate_expires_at_positive(cls, v: int) -> int:
220220
if v <= 0:
221221
raise ValueError("expires_at must be a positive Unix timestamp")
222222
return v
223+
224+
225+
class AssumeRoleCredential(BaseModel):
226+
access_key_id: str
227+
secret_access_key: str
228+
session_token: str

veadk/integrations/ve_identity/token_manager.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,6 @@ async def get_workload_token(
132132
f"Cached workload token expired for agent '{tool_context.agent_name}', refreshing..."
133133
)
134134

135-
# Default to agent_name if workload_name not specified
136-
if not workload_name:
137-
workload_name = tool_context.agent_name
138-
139135
# Determine user_id based on authentication mode
140136
user_id = None if user_token else tool_context._invocation_context.user_id
141137

0 commit comments

Comments
 (0)