Skip to content

Commit 878e7b2

Browse files
committed
Add tests for Config, Token, and TokenCache
1 parent 28cc2d5 commit 878e7b2

File tree

1 file changed

+114
-0
lines changed
  • packages/smithy-aws-core/tests/unit/credentials_resolvers

1 file changed

+114
-0
lines changed
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# pyright: reportPrivateUsage=false
5+
import pytest
6+
import time
7+
from smithy_core.retries import SimpleRetryStrategy
8+
from smithy_core import URI
9+
from smithy_aws_core.credentials_resolvers.imds import Config, Token, TokenCache
10+
from unittest.mock import MagicMock, AsyncMock
11+
12+
13+
def test_config_defaults():
14+
config = Config()
15+
assert isinstance(config.retry_strategy, SimpleRetryStrategy)
16+
assert config.endpoint_uri == URI(scheme="http", host=Config._HOST_MAPPING["IPv4"])
17+
assert config.endpoint_mode == "IPv4"
18+
assert config.port == 80
19+
assert config.token_ttl == 21600
20+
21+
22+
def test_endpoint_resolution():
23+
config_ipv4 = Config(endpoint_mode="IPv4")
24+
config_ipv6 = Config(endpoint_mode="IPv6")
25+
assert config_ipv4.endpoint_uri.host == Config._HOST_MAPPING["IPv4"]
26+
assert config_ipv6.endpoint_uri.host == Config._HOST_MAPPING["IPv6"]
27+
28+
29+
def test_config_uses_custom_endpoint():
30+
# The custom endpoint should take precedence over IPv4 endpoint resolution.
31+
config = Config(
32+
endpoint_uri=URI(scheme="http", host="test.host"), endpoint_mode="IPv4"
33+
)
34+
assert config.endpoint_uri == URI(scheme="http", host="test.host")
35+
36+
# The custom endpoint takes precedence over IPv6 endpoint resolution.
37+
config = Config(
38+
endpoint_uri=URI(scheme="http", host="test.host"), endpoint_mode="IPv6"
39+
)
40+
assert config.endpoint_uri == URI(scheme="http", host="test.host")
41+
42+
43+
def test_config_ttl_validation():
44+
# TTL values < _MIN_TTL should throw a ValueError
45+
with pytest.raises(ValueError):
46+
Config(token_ttl=Config._MIN_TTL - 1)
47+
# TTL values > _MAX_TTL should throw a ValueError
48+
with pytest.raises(ValueError):
49+
Config(token_ttl=Config._MAX_TTL + 1)
50+
51+
52+
def test_token_creation():
53+
token = Token(value="test-token", ttl=100)
54+
assert token._value == "test-token"
55+
assert token._ttl == 100
56+
assert not token.is_expired()
57+
58+
59+
def test_token_expiration():
60+
token = Token(value="test-token", ttl=1)
61+
assert not token.is_expired()
62+
time.sleep(1.1)
63+
assert token.is_expired()
64+
65+
66+
async def test_token_cache_should_refresh():
67+
http_client = AsyncMock()
68+
config = MagicMock()
69+
# A new token cache needs a refresh
70+
token_cache = TokenCache(http_client, config)
71+
assert token_cache._should_refresh()
72+
# A token cache with an unexpired token doesn't need a refresh
73+
token_cache._token = MagicMock()
74+
token_cache._token.is_expired.return_value = False
75+
assert not token_cache._should_refresh()
76+
# A token cache with an expired token needs a refresh
77+
token_cache._token.is_expired.return_value = True
78+
assert token_cache._should_refresh()
79+
80+
81+
async def test_token_cache_refresh():
82+
# Test that TokenCache correctly refreshes the token when needed
83+
http_client = AsyncMock()
84+
config = MagicMock()
85+
config.token_ttl = 100
86+
config.endpoint_uri.scheme = "http"
87+
config.endpoint_uri.host = "169.254.169.254"
88+
response_mock = AsyncMock()
89+
response_mock.consume_body_async.return_value = b"new-token-value"
90+
http_client.send.return_value = response_mock
91+
token_cache = TokenCache(http_client, config)
92+
assert token_cache._should_refresh()
93+
await token_cache._refresh()
94+
assert token_cache._token is not None
95+
assert token_cache._token.value == "new-token-value"
96+
assert token_cache._token._ttl == 100
97+
98+
99+
async def test_token_cache_get_token():
100+
# Test that TokenCache correctly returns an existing token or refreshes if expired
101+
http_client = AsyncMock()
102+
config = MagicMock()
103+
token_cache = TokenCache(http_client, config)
104+
token_cache._refresh = AsyncMock()
105+
token_cache._token = MagicMock()
106+
token_cache._token.is_expired.return_value = False
107+
token = await token_cache.get_token()
108+
assert token == token_cache._token
109+
token_cache._refresh.assert_not_awaited()
110+
token_cache._token.is_expired.return_value = True
111+
await token_cache.get_token()
112+
token_cache._refresh.assert_awaited()
113+
114+
# TODO: Add tests for EC2Metadata and IMDSCredentialsResolver

0 commit comments

Comments
 (0)