|
4 | 4 | import pytest |
5 | 5 | from smithy_aws_core.credentials_resolvers import ( |
6 | 6 | CredentialsResolverChain, |
| 7 | + IMDSCredentialsResolver, |
7 | 8 | StaticCredentialsResolver, |
8 | 9 | ) |
9 | 10 | from smithy_aws_core.credentials_resolvers.environment import ( |
@@ -55,20 +56,103 @@ async def test_env_credentials_resolver_partial(monkeypatch: pytest.MonkeyPatch) |
55 | 56 | await resolver_chain.get_identity(identity_properties=IdentityProperties()) |
56 | 57 |
|
57 | 58 |
|
58 | | -async def test_env_credentials_resolver_success(monkeypatch: pytest.MonkeyPatch): |
| 59 | +async def test_default_sources_env_credentials_resolver_success( |
| 60 | + monkeypatch: pytest.MonkeyPatch, |
| 61 | +): |
59 | 62 | monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid") |
60 | 63 | monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret") |
61 | | - resolver_chain = CredentialsResolverChain( |
62 | | - sources=[EnvironmentCredentialsSource()], config=Config() |
| 64 | + resolver_chain = CredentialsResolverChain(config=Config()) |
| 65 | + |
| 66 | + credentials = await resolver_chain.get_identity( |
| 67 | + identity_properties=IdentityProperties() |
| 68 | + ) |
| 69 | + assert credentials.access_key_id == "akid" |
| 70 | + assert credentials.secret_access_key == "secret" |
| 71 | + |
| 72 | + |
| 73 | +async def test_default_sources_imds_resolver_success(monkeypatch: pytest.MonkeyPatch): |
| 74 | + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) |
| 75 | + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) |
| 76 | + |
| 77 | + async def mock_imds_get_identity( |
| 78 | + self: IMDSCredentialsResolver, *, identity_properties: IdentityProperties |
| 79 | + ) -> AWSCredentialsIdentity: |
| 80 | + return AWSCredentialsIdentity( |
| 81 | + access_key_id="akid", |
| 82 | + secret_access_key="secret", |
| 83 | + ) |
| 84 | + |
| 85 | + monkeypatch.setattr( |
| 86 | + "smithy_aws_core.credentials_resolvers.IMDSCredentialsResolver.get_identity", |
| 87 | + mock_imds_get_identity, |
63 | 88 | ) |
64 | 89 |
|
| 90 | + resolver_chain = CredentialsResolverChain(config=Config()) |
| 91 | + |
65 | 92 | credentials = await resolver_chain.get_identity( |
66 | 93 | identity_properties=IdentityProperties() |
67 | 94 | ) |
68 | 95 | assert credentials.access_key_id == "akid" |
69 | 96 | assert credentials.secret_access_key == "secret" |
70 | 97 |
|
71 | 98 |
|
| 99 | +async def test_multiple_sources_one_valid(): |
| 100 | + class FailingSource(CredentialsSource): |
| 101 | + def is_available(self, config: AwsCredentialsConfig) -> bool: |
| 102 | + return False |
| 103 | + |
| 104 | + def build_resolver( |
| 105 | + self, config: AwsCredentialsConfig |
| 106 | + ) -> AWSCredentialsResolver: |
| 107 | + raise RuntimeError("Should not be called") |
| 108 | + |
| 109 | + static_credentials = AWSCredentialsIdentity( |
| 110 | + access_key_id="valid_akid", secret_access_key="valid_secret" |
| 111 | + ) |
| 112 | + static_resolver = StaticCredentialsResolver(credentials=static_credentials) |
| 113 | + |
| 114 | + class ValidSource(CredentialsSource): |
| 115 | + def is_available(self, config: AwsCredentialsConfig) -> bool: |
| 116 | + return True |
| 117 | + |
| 118 | + def build_resolver( |
| 119 | + self, config: AwsCredentialsConfig |
| 120 | + ) -> AWSCredentialsResolver: |
| 121 | + return static_resolver |
| 122 | + |
| 123 | + resolver_chain = CredentialsResolverChain( |
| 124 | + sources=[FailingSource(), ValidSource()], config=Config() |
| 125 | + ) |
| 126 | + |
| 127 | + credentials = await resolver_chain.get_identity( |
| 128 | + identity_properties=IdentityProperties() |
| 129 | + ) |
| 130 | + assert credentials.access_key_id == "valid_akid" |
| 131 | + assert credentials.secret_access_key == "valid_secret" |
| 132 | + |
| 133 | + |
| 134 | +async def test_cached_resolver_used(monkeypatch: pytest.MonkeyPatch): |
| 135 | + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "cached_akid") |
| 136 | + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "cached_secret") |
| 137 | + resolver_chain = CredentialsResolverChain( |
| 138 | + sources=[EnvironmentCredentialsSource()], config=Config() |
| 139 | + ) |
| 140 | + |
| 141 | + credentials1 = await resolver_chain.get_identity( |
| 142 | + identity_properties=IdentityProperties() |
| 143 | + ) |
| 144 | + credentials2 = await resolver_chain.get_identity( |
| 145 | + identity_properties=IdentityProperties() |
| 146 | + ) |
| 147 | + |
| 148 | + assert credentials1.access_key_id == credentials2.access_key_id == "cached_akid" |
| 149 | + assert ( |
| 150 | + credentials1.secret_access_key |
| 151 | + == credentials2.secret_access_key |
| 152 | + == "cached_secret" |
| 153 | + ) |
| 154 | + |
| 155 | + |
72 | 156 | async def test_custom_sources_with_static_credentials(): |
73 | 157 | static_credentials = AWSCredentialsIdentity( |
74 | 158 | access_key_id="static_akid", |
|
0 commit comments