Skip to content

Commit 4be00b9

Browse files
committed
Add basic IMDS credentials resolver
1 parent 40df093 commit 4be00b9

File tree

2 files changed

+236
-1
lines changed

2 files changed

+236
-1
lines changed

packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,10 @@
22
# SPDX-License-Identifier: Apache-2.0
33
from .environment import EnvironmentCredentialsResolver
44
from .static import StaticCredentialsResolver
5+
from .imds import IMDSCredentialsResolver
56

6-
__all__ = ("EnvironmentCredentialsResolver", "StaticCredentialsResolver")
7+
__all__ = (
8+
"EnvironmentCredentialsResolver",
9+
"StaticCredentialsResolver",
10+
"IMDSCredentialsResolver",
11+
)
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
import json
4+
import threading
5+
from dataclasses import dataclass
6+
from datetime import datetime, timedelta
7+
from typing import Literal
8+
9+
from smithy_core import URI
10+
from smithy_core.aio.interfaces.identity import IdentityResolver
11+
from smithy_core.exceptions import SmithyIdentityException
12+
from smithy_core.interfaces.identity import IdentityProperties
13+
from smithy_core.interfaces.retries import RetryStrategy
14+
from smithy_core.retries import SimpleRetryStrategy
15+
from smithy_http import Field, Fields
16+
from smithy_http.aio import HTTPRequest
17+
from smithy_http.aio.interfaces import HTTPClient
18+
19+
from smithy_aws_core.identity import AWSCredentialsIdentity
20+
21+
22+
class Token:
23+
"""Represents an IMDSv2 session token with a value and method for checking
24+
expiration."""
25+
26+
def __init__(self, value: bytes, ttl: int):
27+
self._value = value
28+
self._ttl = ttl
29+
self._created_time = datetime.now()
30+
31+
def is_expired(self) -> bool:
32+
"""Check if the token has expired."""
33+
return datetime.now() - self._created_time >= timedelta(seconds=self._ttl)
34+
35+
@property
36+
def value(self) -> bytes:
37+
return self._value
38+
39+
40+
class TokenCache:
41+
"""Holds the token needed to fetch instance metadata. In addition, it knows how to
42+
refresh itself.
43+
44+
:param HTTPClient http_client: The client used for making http requests.
45+
:param int token_ttl: The time in seconds before a token expires.
46+
"""
47+
48+
_MIN_TTL = 5
49+
_MAX_TTL = 21600
50+
_TOKEN_PATH = "/latest/api/token"
51+
52+
def __init__(
53+
self, http_client: HTTPClient, base_uri: URI, token_ttl: int = _MAX_TTL
54+
):
55+
self._http_client = http_client
56+
self._base_uri = base_uri
57+
self._token_ttl = self._validate_token_ttl(token_ttl)
58+
self._refresh_lock = threading.Lock()
59+
self._token = None
60+
61+
def _validate_token_ttl(self, ttl: int) -> int:
62+
"""Validates the token TTL value."""
63+
if not self._MIN_TTL <= ttl <= self._MAX_TTL:
64+
raise ValueError(
65+
f"Token TTL must be between {self._MIN_TTL} and {self._MAX_TTL} seconds."
66+
)
67+
return ttl
68+
69+
def _should_refresh(self) -> bool:
70+
"""Determines if the token should be refreshed."""
71+
return self._token is None or self._token.is_expired()
72+
73+
async def _refresh(self) -> None:
74+
"""Refreshes the token if needed, with thread safety."""
75+
with self._refresh_lock:
76+
if not self._should_refresh():
77+
return
78+
headers = Fields(
79+
[
80+
# TODO: Add user-agent
81+
Field(
82+
name="x-aws-ec2-metadata-token-ttl-seconds",
83+
values=[str(self._token_ttl)],
84+
),
85+
]
86+
)
87+
request = HTTPRequest(
88+
method="PUT",
89+
destination=URI(
90+
scheme=self._base_uri.scheme,
91+
host=self._base_uri.host,
92+
path=self._TOKEN_PATH,
93+
),
94+
fields=headers,
95+
)
96+
response = await self._http_client.send(request)
97+
token_value = await response.consume_body_async()
98+
self._token = Token(token_value, self._token_ttl)
99+
100+
async def get_token(self) -> Token:
101+
"""Get the current token, refreshing it if expired."""
102+
if self._should_refresh():
103+
await self._refresh()
104+
assert self._token is not None
105+
return self._token
106+
107+
108+
@dataclass(init=False)
109+
class Config:
110+
"""Configuration for EC2Metadata."""
111+
112+
retry_strategy: RetryStrategy
113+
endpoint_uri: URI
114+
endpoint_mode: Literal["IPv4", "IPv6"]
115+
port: int
116+
token_ttl: int
117+
118+
def __init__(
119+
self,
120+
*,
121+
retry_strategy: RetryStrategy | None = None,
122+
endpoint_uri: URI | None = None,
123+
endpoint_mode: Literal["IPv4", "IPv6"] = "IPv4",
124+
port: int = 80,
125+
token_ttl: int = 21600,
126+
ec2_instance_profile_name: str | None = None,
127+
):
128+
self.retry_strategy = retry_strategy or SimpleRetryStrategy(max_attempts=3)
129+
self.endpoint_mode = endpoint_mode
130+
self.endpoint_uri = self._resolve_endpoint(endpoint_uri, endpoint_mode)
131+
self.port = port
132+
self.token_ttl = token_ttl
133+
self.ec2_instance_profile_name = ec2_instance_profile_name
134+
135+
def _resolve_endpoint(
136+
self, endpoint_uri: URI | None, endpoint_mode: Literal["IPv4", "IPv6"]
137+
) -> URI:
138+
if endpoint_uri is not None:
139+
return endpoint_uri
140+
141+
host_mapping = {"IPv4": "169.254.169.254", "IPv6": "[fd00:ec2::254]"}
142+
143+
return URI(
144+
scheme="http", host=host_mapping.get(endpoint_mode, host_mapping["IPv4"])
145+
)
146+
147+
148+
class EC2Metadata:
149+
def __init__(self, http_client: HTTPClient, config: Config | None = None):
150+
self._http_client = http_client
151+
self._config = config or Config()
152+
self._token_cache = TokenCache(
153+
http_client=self._http_client,
154+
base_uri=self._config.endpoint_uri,
155+
token_ttl=self._config.token_ttl,
156+
)
157+
158+
async def get(self, *, path: str) -> str:
159+
token = await self._token_cache.get_token()
160+
headers = Fields(
161+
[
162+
# TODO: Add user-agent
163+
Field(
164+
name="x-aws-ec2-metadata-token",
165+
values=[token.value.decode("utf-8")],
166+
)
167+
]
168+
)
169+
request = HTTPRequest(
170+
method="GET",
171+
destination=URI(
172+
scheme=self._config.endpoint_uri.scheme,
173+
host=self._config.endpoint_uri.host,
174+
port=self._config.port,
175+
path=path,
176+
),
177+
fields=headers,
178+
)
179+
response = await self._http_client.send(request=request)
180+
body = await response.consume_body_async()
181+
return body.decode("utf-8")
182+
183+
184+
class IMDSCredentialsResolver(
185+
IdentityResolver[AWSCredentialsIdentity, IdentityProperties]
186+
):
187+
"""Resolves AWS Credentials from an EC2 Instance Metadata Service (IMDS) client."""
188+
189+
# TODO: Handle fallback to legacy path when a 404 is received.
190+
_METADATA_PATH_BASE = "/latest/meta-data/iam/security-credentials-extended/"
191+
192+
def __init__(self, http_client: HTTPClient, config: Config | None = None):
193+
self._http_client = http_client
194+
self._ec2_metadata_client = EC2Metadata(http_client=http_client, config=config)
195+
self._config = config or Config()
196+
self._credentials = None
197+
self._profile_name = self._config.ec2_instance_profile_name
198+
199+
async def get_identity(
200+
self, *, identity_properties: IdentityProperties
201+
) -> AWSCredentialsIdentity:
202+
if self._credentials is not None:
203+
return self._credentials
204+
205+
profile = self._profile_name
206+
if profile is None:
207+
profile = await self._ec2_metadata_client.get(path=self._METADATA_PATH_BASE)
208+
209+
creds_str = await self._ec2_metadata_client.get(
210+
path=f"{self._METADATA_PATH_BASE}/{profile}"
211+
)
212+
creds = json.loads(creds_str)
213+
214+
access_key_id = creds.get("AccessKeyId")
215+
secret_access_key = creds.get("SecretAccessKey")
216+
session_token = creds.get("Token")
217+
account_id = creds.get("AccountId")
218+
219+
if access_key_id is None or secret_access_key is None:
220+
raise SmithyIdentityException(
221+
"AccessKeyId and SecretAccessKey are required"
222+
)
223+
224+
self._credentials = AWSCredentialsIdentity(
225+
access_key_id=access_key_id,
226+
secret_access_key=secret_access_key,
227+
session_token=session_token,
228+
account_id=account_id,
229+
)
230+
return self._credentials

0 commit comments

Comments
 (0)