Skip to content

Commit b6255e1

Browse files
committed
Move ttl validation to Config. Make config a parameter for TokenCache.
1 parent 43c8250 commit b6255e1

File tree

1 file changed

+58
-61
lines changed
  • packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers

1 file changed

+58
-61
lines changed

packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py

Lines changed: 58 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,58 @@
1919
from smithy_aws_core.identity import AWSCredentialsIdentity
2020

2121

22+
@dataclass(init=False)
23+
class Config:
24+
"""Configuration for EC2Metadata."""
25+
26+
_HOST_MAPPING = {"IPv4": "169.254.169.254", "IPv6": "[fd00:ec2::254]"}
27+
_MIN_TTL = 5
28+
_MAX_TTL = 21600
29+
30+
retry_strategy: RetryStrategy
31+
endpoint_uri: URI
32+
endpoint_mode: Literal["IPv4", "IPv6"]
33+
port: int
34+
token_ttl: int
35+
36+
def __init__(
37+
self,
38+
*,
39+
retry_strategy: RetryStrategy | None = None,
40+
endpoint_uri: URI | None = None,
41+
endpoint_mode: Literal["IPv4", "IPv6"] = "IPv4",
42+
port: int = 80,
43+
token_ttl: int = _MAX_TTL,
44+
ec2_instance_profile_name: str | None = None,
45+
):
46+
# TODO: Implement retries.
47+
self.retry_strategy = retry_strategy or SimpleRetryStrategy(max_attempts=3)
48+
self.endpoint_mode = endpoint_mode
49+
self.endpoint_uri = self._resolve_endpoint(endpoint_uri, endpoint_mode)
50+
self.port = port
51+
self.token_ttl = self._validate_token_ttl(token_ttl)
52+
self.ec2_instance_profile_name = ec2_instance_profile_name
53+
54+
def _validate_token_ttl(self, ttl: int) -> int:
55+
"""Validates the token TTL value."""
56+
if not self._MIN_TTL <= ttl <= self._MAX_TTL:
57+
raise ValueError(
58+
f"Token TTL must be between {self._MIN_TTL} and {self._MAX_TTL} seconds."
59+
)
60+
return ttl
61+
62+
def _resolve_endpoint(
63+
self, endpoint_uri: URI | None, endpoint_mode: Literal["IPv4", "IPv6"]
64+
) -> URI:
65+
if endpoint_uri is not None:
66+
return endpoint_uri
67+
68+
return URI(
69+
scheme="http",
70+
host=self._HOST_MAPPING.get(endpoint_mode, self._HOST_MAPPING["IPv4"]),
71+
)
72+
73+
2274
class Token:
2375
"""Represents an IMDSv2 session token with a value and method for checking
2476
expiration."""
@@ -43,27 +95,15 @@ class TokenCache:
4395
In addition, it knows how to refresh itself.
4496
"""
4597

46-
_MIN_TTL = 5
47-
_MAX_TTL = 21600
4898
_TOKEN_PATH = "/latest/api/token"
4999

50-
def __init__(
51-
self, http_client: HTTPClient, base_uri: URI, token_ttl: int = _MAX_TTL
52-
):
100+
def __init__(self, http_client: HTTPClient, config: Config):
53101
self._http_client = http_client
54-
self._base_uri = base_uri
55-
self._token_ttl = self._validate_token_ttl(token_ttl)
102+
self._config = config
103+
self._base_uri = config.endpoint_uri
56104
self._refresh_lock = asyncio.Lock()
57105
self._token = None
58106

59-
def _validate_token_ttl(self, ttl: int) -> int:
60-
"""Validates the token TTL value."""
61-
if not self._MIN_TTL <= ttl <= self._MAX_TTL:
62-
raise ValueError(
63-
f"Token TTL must be between {self._MIN_TTL} and {self._MAX_TTL} seconds."
64-
)
65-
return ttl
66-
67107
def _should_refresh(self) -> bool:
68108
"""Determines if the token should be refreshed."""
69109
return self._token is None or self._token.is_expired()
@@ -78,7 +118,7 @@ async def _refresh(self) -> None:
78118
# TODO: Add user-agent
79119
Field(
80120
name="x-aws-ec2-metadata-token-ttl-seconds",
81-
values=[str(self._token_ttl)],
121+
values=[str(self._config.token_ttl)],
82122
),
83123
]
84124
)
@@ -93,7 +133,7 @@ async def _refresh(self) -> None:
93133
)
94134
response = await self._http_client.send(request)
95135
token_value = await response.consume_body_async()
96-
self._token = Token(token_value, self._token_ttl)
136+
self._token = Token(token_value, self._config.token_ttl)
97137

98138
async def get_token(self) -> Token:
99139
"""Get the current token, refreshing it if expired."""
@@ -103,55 +143,12 @@ async def get_token(self) -> Token:
103143
return self._token
104144

105145

106-
@dataclass(init=False)
107-
class Config:
108-
"""Configuration for EC2Metadata."""
109-
110-
_HOST_MAPPING = {"IPv4": "169.254.169.254", "IPv6": "[fd00:ec2::254]"}
111-
112-
retry_strategy: RetryStrategy
113-
endpoint_uri: URI
114-
endpoint_mode: Literal["IPv4", "IPv6"]
115-
port: int
116-
token_ttl: int
117-
118-
def __init__(
119-
self,
120-
*,
121-
retry_strategy: RetryStrategy | None = None,
122-
endpoint_uri: URI | None = None,
123-
endpoint_mode: Literal["IPv4", "IPv6"] = "IPv4",
124-
port: int = 80,
125-
token_ttl: int = 21600,
126-
ec2_instance_profile_name: str | None = None,
127-
):
128-
self.retry_strategy = retry_strategy or SimpleRetryStrategy(max_attempts=3)
129-
self.endpoint_mode = endpoint_mode
130-
self.endpoint_uri = self._resolve_endpoint(endpoint_uri, endpoint_mode)
131-
self.port = port
132-
self.token_ttl = token_ttl
133-
self.ec2_instance_profile_name = ec2_instance_profile_name
134-
135-
def _resolve_endpoint(
136-
self, endpoint_uri: URI | None, endpoint_mode: Literal["IPv4", "IPv6"]
137-
) -> URI:
138-
if endpoint_uri is not None:
139-
return endpoint_uri
140-
141-
return URI(
142-
scheme="http",
143-
host=self._HOST_MAPPING.get(endpoint_mode, self._HOST_MAPPING["IPv4"]),
144-
)
145-
146-
147146
class EC2Metadata:
148147
def __init__(self, http_client: HTTPClient, config: Config | None = None):
149148
self._http_client = http_client
150149
self._config = config or Config()
151150
self._token_cache = TokenCache(
152-
http_client=self._http_client,
153-
base_uri=self._config.endpoint_uri,
154-
token_ttl=self._config.token_ttl,
151+
http_client=self._http_client, config=self._config
155152
)
156153

157154
async def get(self, *, path: str) -> str:

0 commit comments

Comments
 (0)