1919from smithy_aws_core .identity import AWSCredentialsIdentity
2020
2121
22+ @dataclass (init = False )
23+ class Config :
24+ """Configuration for EC2Metadata."""
25+
26+ _HOST_MAPPING = {"IPv4" : "169.254.169.254" , "IPv6" : "[fd00:ec2::254]" }
27+ _MIN_TTL = 5
28+ _MAX_TTL = 21600
29+
30+ retry_strategy : RetryStrategy
31+ endpoint_uri : URI
32+ endpoint_mode : Literal ["IPv4" , "IPv6" ]
33+ port : int
34+ token_ttl : int
35+
36+ def __init__ (
37+ self ,
38+ * ,
39+ retry_strategy : RetryStrategy | None = None ,
40+ endpoint_uri : URI | None = None ,
41+ endpoint_mode : Literal ["IPv4" , "IPv6" ] = "IPv4" ,
42+ port : int = 80 ,
43+ token_ttl : int = _MAX_TTL ,
44+ ec2_instance_profile_name : str | None = None ,
45+ ):
46+ # TODO: Implement retries.
47+ self .retry_strategy = retry_strategy or SimpleRetryStrategy (max_attempts = 3 )
48+ self .endpoint_mode = endpoint_mode
49+ self .endpoint_uri = self ._resolve_endpoint (endpoint_uri , endpoint_mode )
50+ self .port = port
51+ self .token_ttl = self ._validate_token_ttl (token_ttl )
52+ self .ec2_instance_profile_name = ec2_instance_profile_name
53+
54+ def _validate_token_ttl (self , ttl : int ) -> int :
55+ """Validates the token TTL value."""
56+ if not self ._MIN_TTL <= ttl <= self ._MAX_TTL :
57+ raise ValueError (
58+ f"Token TTL must be between { self ._MIN_TTL } and { self ._MAX_TTL } seconds."
59+ )
60+ return ttl
61+
62+ def _resolve_endpoint (
63+ self , endpoint_uri : URI | None , endpoint_mode : Literal ["IPv4" , "IPv6" ]
64+ ) -> URI :
65+ if endpoint_uri is not None :
66+ return endpoint_uri
67+
68+ return URI (
69+ scheme = "http" ,
70+ host = self ._HOST_MAPPING .get (endpoint_mode , self ._HOST_MAPPING ["IPv4" ]),
71+ )
72+
73+
2274class Token :
2375 """Represents an IMDSv2 session token with a value and method for checking
2476 expiration."""
@@ -43,27 +95,15 @@ class TokenCache:
4395 In addition, it knows how to refresh itself.
4496 """
4597
46- _MIN_TTL = 5
47- _MAX_TTL = 21600
4898 _TOKEN_PATH = "/latest/api/token"
4999
50- def __init__ (
51- self , http_client : HTTPClient , base_uri : URI , token_ttl : int = _MAX_TTL
52- ):
100+ def __init__ (self , http_client : HTTPClient , config : Config ):
53101 self ._http_client = http_client
54- self ._base_uri = base_uri
55- self ._token_ttl = self . _validate_token_ttl ( token_ttl )
102+ self ._config = config
103+ self ._base_uri = config . endpoint_uri
56104 self ._refresh_lock = asyncio .Lock ()
57105 self ._token = None
58106
59- def _validate_token_ttl (self , ttl : int ) -> int :
60- """Validates the token TTL value."""
61- if not self ._MIN_TTL <= ttl <= self ._MAX_TTL :
62- raise ValueError (
63- f"Token TTL must be between { self ._MIN_TTL } and { self ._MAX_TTL } seconds."
64- )
65- return ttl
66-
67107 def _should_refresh (self ) -> bool :
68108 """Determines if the token should be refreshed."""
69109 return self ._token is None or self ._token .is_expired ()
@@ -78,7 +118,7 @@ async def _refresh(self) -> None:
78118 # TODO: Add user-agent
79119 Field (
80120 name = "x-aws-ec2-metadata-token-ttl-seconds" ,
81- values = [str (self ._token_ttl )],
121+ values = [str (self ._config . token_ttl )],
82122 ),
83123 ]
84124 )
@@ -93,7 +133,7 @@ async def _refresh(self) -> None:
93133 )
94134 response = await self ._http_client .send (request )
95135 token_value = await response .consume_body_async ()
96- self ._token = Token (token_value , self ._token_ttl )
136+ self ._token = Token (token_value , self ._config . token_ttl )
97137
98138 async def get_token (self ) -> Token :
99139 """Get the current token, refreshing it if expired."""
@@ -103,55 +143,12 @@ async def get_token(self) -> Token:
103143 return self ._token
104144
105145
106- @dataclass (init = False )
107- class Config :
108- """Configuration for EC2Metadata."""
109-
110- _HOST_MAPPING = {"IPv4" : "169.254.169.254" , "IPv6" : "[fd00:ec2::254]" }
111-
112- retry_strategy : RetryStrategy
113- endpoint_uri : URI
114- endpoint_mode : Literal ["IPv4" , "IPv6" ]
115- port : int
116- token_ttl : int
117-
118- def __init__ (
119- self ,
120- * ,
121- retry_strategy : RetryStrategy | None = None ,
122- endpoint_uri : URI | None = None ,
123- endpoint_mode : Literal ["IPv4" , "IPv6" ] = "IPv4" ,
124- port : int = 80 ,
125- token_ttl : int = 21600 ,
126- ec2_instance_profile_name : str | None = None ,
127- ):
128- self .retry_strategy = retry_strategy or SimpleRetryStrategy (max_attempts = 3 )
129- self .endpoint_mode = endpoint_mode
130- self .endpoint_uri = self ._resolve_endpoint (endpoint_uri , endpoint_mode )
131- self .port = port
132- self .token_ttl = token_ttl
133- self .ec2_instance_profile_name = ec2_instance_profile_name
134-
135- def _resolve_endpoint (
136- self , endpoint_uri : URI | None , endpoint_mode : Literal ["IPv4" , "IPv6" ]
137- ) -> URI :
138- if endpoint_uri is not None :
139- return endpoint_uri
140-
141- return URI (
142- scheme = "http" ,
143- host = self ._HOST_MAPPING .get (endpoint_mode , self ._HOST_MAPPING ["IPv4" ]),
144- )
145-
146-
147146class EC2Metadata :
148147 def __init__ (self , http_client : HTTPClient , config : Config | None = None ):
149148 self ._http_client = http_client
150149 self ._config = config or Config ()
151150 self ._token_cache = TokenCache (
152- http_client = self ._http_client ,
153- base_uri = self ._config .endpoint_uri ,
154- token_ttl = self ._config .token_ttl ,
151+ http_client = self ._http_client , config = self ._config
155152 )
156153
157154 async def get (self , * , path : str ) -> str :
0 commit comments