1+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+ # SPDX-License-Identifier: Apache-2.0
3+
4+ # pyright: reportPrivateUsage=false
5+ import pytest
6+ import time
7+ from smithy_core .retries import SimpleRetryStrategy
8+ from smithy_core import URI
9+ from smithy_aws_core .credentials_resolvers .imds import Config , Token , TokenCache
10+ from unittest .mock import MagicMock , AsyncMock
11+
12+
13+ def test_config_defaults ():
14+ config = Config ()
15+ assert isinstance (config .retry_strategy , SimpleRetryStrategy )
16+ assert config .endpoint_uri == URI (scheme = "http" , host = Config ._HOST_MAPPING ["IPv4" ])
17+ assert config .endpoint_mode == "IPv4"
18+ assert config .port == 80
19+ assert config .token_ttl == 21600
20+
21+
22+ def test_endpoint_resolution ():
23+ config_ipv4 = Config (endpoint_mode = "IPv4" )
24+ config_ipv6 = Config (endpoint_mode = "IPv6" )
25+ assert config_ipv4 .endpoint_uri .host == Config ._HOST_MAPPING ["IPv4" ]
26+ assert config_ipv6 .endpoint_uri .host == Config ._HOST_MAPPING ["IPv6" ]
27+
28+
29+ def test_config_uses_custom_endpoint ():
30+ # The custom endpoint should take precedence over IPv4 endpoint resolution.
31+ config = Config (
32+ endpoint_uri = URI (scheme = "http" , host = "test.host" ), endpoint_mode = "IPv4"
33+ )
34+ assert config .endpoint_uri == URI (scheme = "http" , host = "test.host" )
35+
36+ # The custom endpoint takes precedence over IPv6 endpoint resolution.
37+ config = Config (
38+ endpoint_uri = URI (scheme = "http" , host = "test.host" ), endpoint_mode = "IPv6"
39+ )
40+ assert config .endpoint_uri == URI (scheme = "http" , host = "test.host" )
41+
42+
43+ def test_config_ttl_validation ():
44+ # TTL values < _MIN_TTL should throw a ValueError
45+ with pytest .raises (ValueError ):
46+ Config (token_ttl = Config ._MIN_TTL - 1 )
47+ # TTL values > _MAX_TTL should throw a ValueError
48+ with pytest .raises (ValueError ):
49+ Config (token_ttl = Config ._MAX_TTL + 1 )
50+
51+
52+ def test_token_creation ():
53+ token = Token (value = "test-token" , ttl = 100 )
54+ assert token ._value == "test-token"
55+ assert token ._ttl == 100
56+ assert not token .is_expired ()
57+
58+
59+ def test_token_expiration ():
60+ token = Token (value = "test-token" , ttl = 1 )
61+ assert not token .is_expired ()
62+ time .sleep (1.1 )
63+ assert token .is_expired ()
64+
65+
66+ async def test_token_cache_should_refresh ():
67+ http_client = AsyncMock ()
68+ config = MagicMock ()
69+ # A new token cache needs a refresh
70+ token_cache = TokenCache (http_client , config )
71+ assert token_cache ._should_refresh ()
72+ # A token cache with an unexpired token doesn't need a refresh
73+ token_cache ._token = MagicMock ()
74+ token_cache ._token .is_expired .return_value = False
75+ assert not token_cache ._should_refresh ()
76+ # A token cache with an expired token needs a refresh
77+ token_cache ._token .is_expired .return_value = True
78+ assert token_cache ._should_refresh ()
79+
80+
81+ async def test_token_cache_refresh ():
82+ # Test that TokenCache correctly refreshes the token when needed
83+ http_client = AsyncMock ()
84+ config = MagicMock ()
85+ config .token_ttl = 100
86+ config .endpoint_uri .scheme = "http"
87+ config .endpoint_uri .host = "169.254.169.254"
88+ response_mock = AsyncMock ()
89+ response_mock .consume_body_async .return_value = b"new-token-value"
90+ http_client .send .return_value = response_mock
91+ token_cache = TokenCache (http_client , config )
92+ assert token_cache ._should_refresh ()
93+ await token_cache ._refresh ()
94+ assert token_cache ._token is not None
95+ assert token_cache ._token .value == "new-token-value"
96+ assert token_cache ._token ._ttl == 100
97+
98+
99+ async def test_token_cache_get_token ():
100+ # Test that TokenCache correctly returns an existing token or refreshes if expired
101+ http_client = AsyncMock ()
102+ config = MagicMock ()
103+ token_cache = TokenCache (http_client , config )
104+ token_cache ._refresh = AsyncMock ()
105+ token_cache ._token = MagicMock ()
106+ token_cache ._token .is_expired .return_value = False
107+ token = await token_cache .get_token ()
108+ assert token == token_cache ._token
109+ token_cache ._refresh .assert_not_awaited ()
110+ token_cache ._token .is_expired .return_value = True
111+ await token_cache .get_token ()
112+ token_cache ._refresh .assert_awaited ()
113+
114+ # TODO: Add tests for EC2Metadata and IMDSCredentialsResolver
0 commit comments