Skip to content

Commit 086d9ce

Browse files
committed
Improve tests
1 parent d14e85e commit 086d9ce

File tree

1 file changed

+87
-3
lines changed

1 file changed

+87
-3
lines changed

packages/smithy-aws-core/tests/unit/credentials_resolvers/test_credentials_resolver_chain.py

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
from smithy_aws_core.credentials_resolvers import (
66
CredentialsResolverChain,
7+
IMDSCredentialsResolver,
78
StaticCredentialsResolver,
89
)
910
from smithy_aws_core.credentials_resolvers.environment import (
@@ -55,20 +56,103 @@ async def test_env_credentials_resolver_partial(monkeypatch: pytest.MonkeyPatch)
5556
await resolver_chain.get_identity(identity_properties=IdentityProperties())
5657

5758

58-
async def test_env_credentials_resolver_success(monkeypatch: pytest.MonkeyPatch):
59+
async def test_default_sources_env_credentials_resolver_success(
60+
monkeypatch: pytest.MonkeyPatch,
61+
):
5962
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid")
6063
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret")
61-
resolver_chain = CredentialsResolverChain(
62-
sources=[EnvironmentCredentialsSource()], config=Config()
64+
resolver_chain = CredentialsResolverChain(config=Config())
65+
66+
credentials = await resolver_chain.get_identity(
67+
identity_properties=IdentityProperties()
68+
)
69+
assert credentials.access_key_id == "akid"
70+
assert credentials.secret_access_key == "secret"
71+
72+
73+
async def test_default_sources_imds_resolver_success(monkeypatch: pytest.MonkeyPatch):
74+
monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
75+
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
76+
77+
async def mock_imds_get_identity(
78+
self: IMDSCredentialsResolver, *, identity_properties: IdentityProperties
79+
) -> AWSCredentialsIdentity:
80+
return AWSCredentialsIdentity(
81+
access_key_id="akid",
82+
secret_access_key="secret",
83+
)
84+
85+
monkeypatch.setattr(
86+
"smithy_aws_core.credentials_resolvers.IMDSCredentialsResolver.get_identity",
87+
mock_imds_get_identity,
6388
)
6489

90+
resolver_chain = CredentialsResolverChain(config=Config())
91+
6592
credentials = await resolver_chain.get_identity(
6693
identity_properties=IdentityProperties()
6794
)
6895
assert credentials.access_key_id == "akid"
6996
assert credentials.secret_access_key == "secret"
7097

7198

99+
async def test_multiple_sources_one_valid():
100+
class FailingSource(CredentialsSource):
101+
def is_available(self, config: AwsCredentialsConfig) -> bool:
102+
return False
103+
104+
def build_resolver(
105+
self, config: AwsCredentialsConfig
106+
) -> AWSCredentialsResolver:
107+
raise RuntimeError("Should not be called")
108+
109+
static_credentials = AWSCredentialsIdentity(
110+
access_key_id="valid_akid", secret_access_key="valid_secret"
111+
)
112+
static_resolver = StaticCredentialsResolver(credentials=static_credentials)
113+
114+
class ValidSource(CredentialsSource):
115+
def is_available(self, config: AwsCredentialsConfig) -> bool:
116+
return True
117+
118+
def build_resolver(
119+
self, config: AwsCredentialsConfig
120+
) -> AWSCredentialsResolver:
121+
return static_resolver
122+
123+
resolver_chain = CredentialsResolverChain(
124+
sources=[FailingSource(), ValidSource()], config=Config()
125+
)
126+
127+
credentials = await resolver_chain.get_identity(
128+
identity_properties=IdentityProperties()
129+
)
130+
assert credentials.access_key_id == "valid_akid"
131+
assert credentials.secret_access_key == "valid_secret"
132+
133+
134+
async def test_cached_resolver_used(monkeypatch: pytest.MonkeyPatch):
135+
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "cached_akid")
136+
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "cached_secret")
137+
resolver_chain = CredentialsResolverChain(
138+
sources=[EnvironmentCredentialsSource()], config=Config()
139+
)
140+
141+
credentials1 = await resolver_chain.get_identity(
142+
identity_properties=IdentityProperties()
143+
)
144+
credentials2 = await resolver_chain.get_identity(
145+
identity_properties=IdentityProperties()
146+
)
147+
148+
assert credentials1.access_key_id == credentials2.access_key_id == "cached_akid"
149+
assert (
150+
credentials1.secret_access_key
151+
== credentials2.secret_access_key
152+
== "cached_secret"
153+
)
154+
155+
72156
async def test_custom_sources_with_static_credentials():
73157
static_credentials = AWSCredentialsIdentity(
74158
access_key_id="static_akid",

0 commit comments

Comments
 (0)