|
23 | 23 | from veadk.integrations.ve_identity.identity_client import IdentityClient |
24 | 24 | from veadk.integrations.ve_identity.token_manager import get_workload_token |
25 | 25 | from veadk.utils.logger import get_logger |
| 26 | +from veadk.utils.jwt import extract_delegation_chain_from_jwt |
26 | 27 |
|
27 | 28 | logger = get_logger(__name__) |
28 | 29 |
|
|
31 | 32 | identity_client = IdentityClient(region=region) |
32 | 33 |
|
33 | 34 |
|
34 | | -def _strip_bearer_prefix(token: str) -> str: |
35 | | - """Remove 'Bearer ' prefix from token if present. |
36 | | - Args: |
37 | | - token: Token string that may contain "Bearer " prefix |
38 | | - Returns: |
39 | | - Token without "Bearer " prefix |
40 | | - """ |
41 | | - return token[7:] if token.lower().startswith("bearer ") else token |
42 | | - |
43 | | - |
44 | | -def _extract_role_id_from_jwt(token: str) -> Optional[str]: |
45 | | - """Extract role_id (sub field) from JWT token. |
46 | | - Args: |
47 | | - token: JWT token string (with or without "Bearer " prefix) |
48 | | - Returns: |
49 | | - Role ID from sub field, or None if parsing fails |
50 | | - """ |
51 | | - try: |
52 | | - # Remove "Bearer " prefix if present |
53 | | - token = _strip_bearer_prefix(token) |
54 | | - |
55 | | - # JWT token has 3 parts separated by dots: header.payload.signature |
56 | | - parts = token.split(".") |
57 | | - if len(parts) != 3: |
58 | | - logger.error("Invalid JWT format: expected 3 parts") |
59 | | - return None |
60 | | - |
61 | | - # Decode payload (second part) |
62 | | - payload_part = parts[1] |
63 | | - |
64 | | - # Add padding for base64url decoding (JWT doesn't use padding) |
65 | | - missing_padding = len(payload_part) % 4 |
66 | | - if missing_padding: |
67 | | - payload_part += "=" * (4 - missing_padding) |
68 | | - |
69 | | - # Decode base64 and parse JSON |
70 | | - decoded_bytes = base64.urlsafe_b64decode(payload_part) |
71 | | - payload = json.loads(decoded_bytes.decode("utf-8")) |
72 | | - |
73 | | - # Extract sub field as role_id |
74 | | - return payload.get("act").get("sub") |
75 | | - |
76 | | - except (ValueError, json.JSONDecodeError) as e: |
77 | | - logger.error(f"Failed to parse JWT token: {e}") |
78 | | - return None |
79 | | - except Exception as e: |
80 | | - logger.error(f"Unexpected error parsing JWT: {e}") |
81 | | - return None |
82 | | - |
83 | | - |
84 | 35 | async def check_agent_authorization( |
85 | 36 | callback_context: CallbackContext, |
86 | 37 | ) -> Optional[types.Content]: |
87 | | - """Check if the agent is authorized to run using VeIdentity.""" |
88 | | - user_id = callback_context._invocation_context.user_id |
89 | | - |
| 38 | + """Check if the agent is authorized to run using Agent Identity.""" |
90 | 39 | try: |
91 | 40 | workload_token = await get_workload_token( |
92 | 41 | tool_context=callback_context, identity_client=identity_client |
93 | 42 | ) |
94 | 43 |
|
95 | | - # Parse role_id from workload_token |
96 | | - role_id = _extract_role_id_from_jwt(workload_token) |
| 44 | + # Parse user_id and actors from workload_token |
| 45 | + user_id, actors = extract_delegation_chain_from_jwt(workload_token) |
97 | 46 |
|
98 | | - principal = {"Type": "User", "Id": user_id} |
99 | | - operation = {"Type": "Action", "Id": "invoke"} |
100 | | - resource = {"Type": "Agent", "Id": role_id} |
| 47 | + if not user_id: |
| 48 | + logger.warning("Failed to extract user_id from JWT token") |
| 49 | + return types.Content( |
| 50 | + parts=[types.Part(text="Failed to verify agent authorization.")], |
| 51 | + role="model", |
| 52 | + ) |
| 53 | + |
| 54 | + if len(actors) == 0: |
| 55 | + logger.warning("Failed to extract actors from JWT token") |
| 56 | + return types.Content( |
| 57 | + parts=[types.Part(text="Failed to verify agent authorization.")], |
| 58 | + role="model", |
| 59 | + ) |
| 60 | + |
| 61 | + # The first actor in the chain is the agent itself |
| 62 | + role_id = actors[0] |
| 63 | + |
| 64 | + principal = {"Type": "user", "Id": user_id} |
| 65 | + operation = {"Type": "action", "Id": "invoke"} |
| 66 | + resource = {"Type": "agent", "Id": role_id} |
| 67 | + original_callers = [{"Type": "agent", "Id": actor} for actor in actors[1:]] |
101 | 68 |
|
102 | 69 | allowed = identity_client.check_permission( |
103 | | - principal=principal, operation=operation, resource=resource |
| 70 | + principal=principal, |
| 71 | + operation=operation, |
| 72 | + resource=resource, |
| 73 | + original_callers=original_callers, |
104 | 74 | ) |
105 | 75 |
|
106 | 76 | if allowed: |
107 | | - logger.info("Agent is authorized to run.") |
| 77 | + logger.info(f"Agent {role_id} is authorized to run by user {user_id}.") |
108 | 78 | return None |
109 | 79 | else: |
110 | | - logger.warning("Agent is not authorized to run.") |
| 80 | + logger.warning( |
| 81 | + f"Agent {role_id} is not authorized to run by user {user_id}." |
| 82 | + ) |
111 | 83 | return types.Content( |
112 | | - parts=[types.Part(text="Agent is not authorized to run.")], role="model" |
| 84 | + parts=[ |
| 85 | + types.Part( |
| 86 | + text=f"Agent {role_id} is not authorized to run by user {user_id}." |
| 87 | + ) |
| 88 | + ], |
| 89 | + role="model", |
113 | 90 | ) |
114 | 91 |
|
115 | 92 | except Exception as e: |
|
0 commit comments