diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py index 3aead11b3..bf7d08c0a 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py @@ -2,5 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 from .environment import EnvironmentCredentialsResolver from .static import StaticCredentialsResolver +from .imds import IMDSCredentialsResolver -__all__ = ("EnvironmentCredentialsResolver", "StaticCredentialsResolver") +__all__ = ( + "EnvironmentCredentialsResolver", + "StaticCredentialsResolver", + "IMDSCredentialsResolver", +) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py new file mode 100644 index 000000000..1ba05ab4c --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -0,0 +1,234 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import json +import asyncio +import smithy_aws_core +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Literal + +from smithy_core import URI +from smithy_core.aio.interfaces.identity import IdentityResolver +from smithy_core.exceptions import SmithyIdentityException +from smithy_core.interfaces.identity import IdentityProperties +from smithy_core.interfaces.retries import RetryStrategy +from smithy_core.retries import SimpleRetryStrategy +from smithy_http import Field, Fields +from smithy_http.aio import HTTPRequest +from smithy_http.aio.interfaces import HTTPClient + +from smithy_aws_core.identity import AWSCredentialsIdentity + +_USER_AGENT_FIELD = Field( + name="User-Agent", + values=[f"aws-sdk-python-imds-client/{smithy_aws_core.__version__}"], +) + + +@dataclass(init=False) +class Config: + """Configuration for EC2Metadata.""" + + _HOST_MAPPING = {"IPv4": "169.254.169.254", "IPv6": "[fd00:ec2::254]"} + _MIN_TTL = 5 + _MAX_TTL = 21600 + + retry_strategy: RetryStrategy + endpoint_uri: URI + endpoint_mode: Literal["IPv4", "IPv6"] + token_ttl: int + + def __init__( + self, + *, + retry_strategy: RetryStrategy | None = None, + endpoint_uri: URI | None = None, + endpoint_mode: Literal["IPv4", "IPv6"] = "IPv4", + token_ttl: int = _MAX_TTL, + ec2_instance_profile_name: str | None = None, + ): + # TODO: Implement retries. + self.retry_strategy = retry_strategy or SimpleRetryStrategy(max_attempts=3) + self.endpoint_mode = endpoint_mode + self.endpoint_uri = self._resolve_endpoint(endpoint_uri, endpoint_mode) + self.token_ttl = self._validate_token_ttl(token_ttl) + self.ec2_instance_profile_name = ec2_instance_profile_name + + def _validate_token_ttl(self, ttl: int) -> int: + if not self._MIN_TTL <= ttl <= self._MAX_TTL: + raise ValueError( + f"Token TTL must be between {self._MIN_TTL} and {self._MAX_TTL} seconds." + ) + return ttl + + def _resolve_endpoint( + self, endpoint_uri: URI | None, endpoint_mode: Literal["IPv4", "IPv6"] + ) -> URI: + if endpoint_uri is not None: + return endpoint_uri + + return URI( + scheme="http", + host=self._HOST_MAPPING.get(endpoint_mode, self._HOST_MAPPING["IPv4"]), + port=80, + ) + + +class Token: + """Represents an IMDSv2 session token with a value and method for checking + expiration.""" + + def __init__(self, value: str, ttl: int): + self._value = value + self._ttl = ttl + self._created_time = datetime.now() + + def is_expired(self) -> bool: + return datetime.now() - self._created_time >= timedelta(seconds=self._ttl) + + @property + def value(self) -> str: + return self._value + + +class TokenCache: + """Holds the token needed to fetch instance metadata. + + In addition, it knows how to refresh itself. + """ + + _TOKEN_PATH = "/latest/api/token" + + def __init__(self, http_client: HTTPClient, config: Config): + self._http_client = http_client + self._config = config + self._base_uri = config.endpoint_uri + self._refresh_lock = asyncio.Lock() + self._token = None + + def _should_refresh(self) -> bool: + return self._token is None or self._token.is_expired() + + async def _refresh(self) -> None: + async with self._refresh_lock: + if not self._should_refresh(): + return + headers = Fields( + [ + _USER_AGENT_FIELD, + Field( + name="x-aws-ec2-metadata-token-ttl-seconds", + values=[str(self._config.token_ttl)], + ), + ] + ) + request = HTTPRequest( + method="PUT", + destination=URI( + scheme=self._base_uri.scheme, + host=self._base_uri.host, + port=self._base_uri.port, + path=self._TOKEN_PATH, + ), + fields=headers, + ) + response = await self._http_client.send(request) + token_value = await response.consume_body_async() + self._token = Token(token_value.decode("utf-8"), self._config.token_ttl) + + async def get_token(self) -> Token: + if self._should_refresh(): + await self._refresh() + assert self._token is not None + return self._token + + +class EC2Metadata: + def __init__(self, http_client: HTTPClient, config: Config | None = None): + self._http_client = http_client + self._config = config or Config() + self._token_cache = TokenCache( + http_client=self._http_client, config=self._config + ) + + async def get(self, *, path: str) -> str: + token = await self._token_cache.get_token() + headers = Fields( + [ + _USER_AGENT_FIELD, + Field( + name="x-aws-ec2-metadata-token", + values=[token.value], + ), + ] + ) + request = HTTPRequest( + method="GET", + destination=URI( + scheme=self._config.endpoint_uri.scheme, + host=self._config.endpoint_uri.host, + port=self._config.endpoint_uri.port, + path=path, + ), + fields=headers, + ) + response = await self._http_client.send(request=request) + body = await response.consume_body_async() + return body.decode("utf-8") + + +class IMDSCredentialsResolver( + IdentityResolver[AWSCredentialsIdentity, IdentityProperties] +): + """Resolves AWS Credentials from an EC2 Instance Metadata Service (IMDS) client.""" + + _METADATA_PATH_BASE = "/latest/meta-data/iam/security-credentials" + + def __init__(self, http_client: HTTPClient, config: Config | None = None): + # TODO: Respect IMDS specific config values from aws shared config file and environment. + self._http_client = http_client + self._ec2_metadata_client = EC2Metadata(http_client=http_client, config=config) + self._config = config or Config() + self._credentials = None + self._profile_name = self._config.ec2_instance_profile_name + + async def get_identity( + self, *, identity_properties: IdentityProperties + ) -> AWSCredentialsIdentity: + if ( + self._credentials is not None + and self._credentials.expiration + and datetime.now(timezone.utc) < self._credentials.expiration + ): + return self._credentials + + profile = self._profile_name + if profile is None: + profile = await self._ec2_metadata_client.get(path=self._METADATA_PATH_BASE) + + creds_str = await self._ec2_metadata_client.get( + path=f"{self._METADATA_PATH_BASE}/{profile}" + ) + creds = json.loads(creds_str) + + access_key_id = creds.get("AccessKeyId") + secret_access_key = creds.get("SecretAccessKey") + session_token = creds.get("Token") + account_id = creds.get("AccountId") + expiration = creds.get("Expiration") + if expiration is not None: + expiration = datetime.fromisoformat(expiration).replace(tzinfo=timezone.utc) + + if access_key_id is None or secret_access_key is None: + raise SmithyIdentityException( + "AccessKeyId and SecretAccessKey are required" + ) + + self._credentials = AWSCredentialsIdentity( + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + expiration=expiration, + account_id=account_id, + ) + return self._credentials diff --git a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py new file mode 100644 index 000000000..ebee43f17 --- /dev/null +++ b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py @@ -0,0 +1,180 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# pyright: reportPrivateUsage=false +import json +import pytest +import time +from datetime import datetime, timezone +from smithy_core.retries import SimpleRetryStrategy +from smithy_core import URI +from smithy_http.aio import HTTPRequest +from smithy_aws_core.credentials_resolvers.imds import ( + Config, + Token, + TokenCache, + EC2Metadata, + IMDSCredentialsResolver, +) +from unittest.mock import MagicMock, AsyncMock + + +def test_config_defaults(): + config = Config() + assert isinstance(config.retry_strategy, SimpleRetryStrategy) + assert config.endpoint_uri == URI( + scheme="http", host=Config._HOST_MAPPING["IPv4"], port=80 + ) + assert config.endpoint_mode == "IPv4" + assert config.token_ttl == 21600 + + +def test_endpoint_resolution(): + config_ipv4 = Config(endpoint_mode="IPv4") + config_ipv6 = Config(endpoint_mode="IPv6") + assert config_ipv4.endpoint_uri.host == Config._HOST_MAPPING["IPv4"] + assert config_ipv6.endpoint_uri.host == Config._HOST_MAPPING["IPv6"] + + +def test_config_uses_custom_endpoint(): + # The custom endpoint should take precedence over IPv4 endpoint resolution. + config = Config( + endpoint_uri=URI(scheme="https", host="test.host", port=123), + endpoint_mode="IPv4", + ) + assert config.endpoint_uri == URI(scheme="https", host="test.host", port=123) + + # The custom endpoint takes precedence over IPv6 endpoint resolution. + config = Config( + endpoint_uri=URI(scheme="https", host="test.host", port=123), + endpoint_mode="IPv6", + ) + assert config.endpoint_uri == URI(scheme="https", host="test.host", port=123) + + +def test_config_ttl_validation(): + # TTL values < _MIN_TTL should throw a ValueError + with pytest.raises(ValueError): + Config(token_ttl=Config._MIN_TTL - 1) + # TTL values > _MAX_TTL should throw a ValueError + with pytest.raises(ValueError): + Config(token_ttl=Config._MAX_TTL + 1) + + +def test_token_creation(): + token = Token(value="test-token", ttl=100) + assert token._value == "test-token" + assert token._ttl == 100 + assert not token.is_expired() + + +def test_token_expiration(): + token = Token(value="test-token", ttl=1) + assert not token.is_expired() + time.sleep(1.1) + assert token.is_expired() + + +async def test_token_cache_should_refresh(): + http_client = AsyncMock() + config = MagicMock() + # A new token cache needs a refresh + token_cache = TokenCache(http_client, config) + assert token_cache._should_refresh() + # A token cache with an unexpired token doesn't need a refresh + token_cache._token = MagicMock() + token_cache._token.is_expired.return_value = False + assert not token_cache._should_refresh() + # A token cache with an expired token needs a refresh + token_cache._token.is_expired.return_value = True + assert token_cache._should_refresh() + + +async def test_token_cache_refresh(): + # Test that TokenCache correctly refreshes the token when needed + http_client = AsyncMock() + config = MagicMock() + config.token_ttl = 100 + config.endpoint_uri.scheme = "http" + config.endpoint_uri.host = "169.254.169.254" + response_mock = AsyncMock() + response_mock.consume_body_async.return_value = b"new-token-value" + http_client.send.return_value = response_mock + token_cache = TokenCache(http_client, config) + assert token_cache._should_refresh() + await token_cache._refresh() + assert token_cache._token is not None + assert token_cache._token.value == "new-token-value" + assert token_cache._token._ttl == 100 + + +async def test_token_cache_get_token(): + # Test that TokenCache correctly returns an existing token or refreshes if expired + http_client = AsyncMock() + config = MagicMock() + token_cache = TokenCache(http_client, config) + token_cache._refresh = AsyncMock() + token_cache._token = MagicMock() + token_cache._token.is_expired.return_value = False + token = await token_cache.get_token() + assert token == token_cache._token + token_cache._refresh.assert_not_awaited() + token_cache._token.is_expired.return_value = True + await token_cache.get_token() + token_cache._refresh.assert_awaited() + + +async def test_ec2_metadata_get(): + # Test EC2Metadata.get() method to retrieve metadata from IMDS + http_client = AsyncMock() + config = Config() + response = AsyncMock() + response.consume_body_async.return_value = b"metadata-response" + http_client.send.return_value = response + + ec2_metadata = EC2Metadata(http_client, config) + ec2_metadata._token_cache.get_token = AsyncMock( + return_value=Token("mocked-token", config.token_ttl) + ) + + result = await ec2_metadata.get(path="/test-path") + assert result == "metadata-response" + + request = http_client.send.call_args.kwargs["request"] + assert isinstance(request, HTTPRequest) + assert request.destination.path == "/test-path" + assert request.method == "GET" + assert request.fields["x-aws-ec2-metadata-token"].values == ["mocked-token"] + + +async def test_imds_credentials_resolver(): + # Test IMDSCredentialsResolver retrieving credentials + http_client = AsyncMock() + config = Config() + ec2_metadata = AsyncMock() + resolver = IMDSCredentialsResolver(http_client, config) + resolver._ec2_metadata_client = ec2_metadata + + # Mock EC2Metadata client get responses + ec2_metadata.get.side_effect = [ + "test-profile", + json.dumps( + { + "AccessKeyId": "test-access-key", + "SecretAccessKey": "test-secret-key", + "Token": "test-session-token", + "AccountId": "test-account", + "Expiration": "2025-03-13T07:28:47Z", + } + ), + ] + + credentials = await resolver.get_identity(identity_properties=MagicMock()) + assert credentials.access_key_id == "test-access-key" + assert credentials.secret_access_key == "test-secret-key" + assert credentials.session_token == "test-session-token" + assert credentials.account_id == "test-account" + assert credentials.expiration == datetime( + 2025, 3, 13, 7, 28, 47, tzinfo=timezone.utc + ) + ec2_metadata.get.assert_awaited() diff --git a/packages/smithy-http/src/smithy_http/aio/crt.py b/packages/smithy-http/src/smithy_http/aio/crt.py index 0e17eba49..6183a232b 100644 --- a/packages/smithy-http/src/smithy_http/aio/crt.py +++ b/packages/smithy-http/src/smithy_http/aio/crt.py @@ -273,13 +273,15 @@ async def _get_connection( ) -> "crt_http.HttpClientConnection": # TODO: Use CRT connection pooling instead of this basic kind connection_key = (url.scheme, url.host, url.port) - if connection_key in self._connections: - return self._connections[connection_key] - else: - connection = await self._create_connection(url) - self._connections[connection_key] = connection + connection = self._connections.get(connection_key) + + if connection and connection.is_open(): return connection + connection = await self._create_connection(url) + self._connections[connection_key] = connection + return connection + def _build_new_connection( self, url: core_interfaces.URI ) -> ConcurrentFuture["crt_http.HttpClientConnection"]: