Skip to content

Commit b325e70

Browse files
committed
Refactor Sources into a class/Protocol
1 parent 4351de1 commit b325e70

File tree

6 files changed

+121
-28
lines changed

6 files changed

+121
-28
lines changed

codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public List<RuntimeClientPlugin> getClientPlugins(GenerationContext context) {
5959
.nullable(true)
6060
.initialize(writer -> {
6161
writer.addImport("smithy_aws_core.credentials_resolvers", "CredentialsResolverChain");
62-
writer.write("self.aws_credentials_identity_resolver = aws_credentials_identity_resolver or CredentialsResolverChain()");
62+
writer.write("self.aws_credentials_identity_resolver = aws_credentials_identity_resolver or CredentialsResolverChain(config=self)");
6363
})
6464
.build())
6565
.addConfigProperty(REGION)

packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/chain.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,24 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
import os
4-
from collections.abc import Callable, Sequence
3+
from collections.abc import Sequence
54

65
from smithy_core.aio.interfaces.identity import IdentityResolver
76
from smithy_core.exceptions import SmithyIdentityException
87
from smithy_core.interfaces.identity import IdentityProperties
98

10-
from smithy_aws_core.credentials_resolvers import EnvironmentCredentialsResolver
9+
from smithy_aws_core.credentials_resolvers.environment import (
10+
EnvironmentCredentialsSource,
11+
)
12+
from smithy_aws_core.credentials_resolvers.imds import IMDSCredentialsSource
13+
from smithy_aws_core.credentials_resolvers.interfaces import (
14+
AwsCredentialsConfig,
15+
CredentialsSource,
16+
)
1117
from smithy_aws_core.identity import AWSCredentialsIdentity, AWSCredentialsResolver
1218

13-
14-
def _env_creds_available() -> bool:
15-
return "AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ
16-
17-
18-
def _build_env_creds() -> AWSCredentialsResolver:
19-
return EnvironmentCredentialsResolver()
20-
21-
22-
type CredentialSource = tuple[Callable[[], bool], Callable[[], AWSCredentialsResolver]]
23-
_DEFAULT_SOURCES: Sequence[CredentialSource] = (
24-
(_env_creds_available, _build_env_creds),
19+
_DEFAULT_SOURCES: Sequence[CredentialsSource] = (
20+
EnvironmentCredentialsSource(),
21+
IMDSCredentialsSource(),
2522
)
2623

2724

@@ -30,8 +27,14 @@ class CredentialsResolverChain(
3027
):
3128
"""Resolves AWS Credentials from system environment variables."""
3229

33-
def __init__(self, *, sources: Sequence[CredentialSource] = _DEFAULT_SOURCES):
34-
self._sources: Sequence[CredentialSource] = sources
30+
def __init__(
31+
self,
32+
*,
33+
config: AwsCredentialsConfig,
34+
sources: Sequence[CredentialsSource] = _DEFAULT_SOURCES,
35+
):
36+
self._config = config
37+
self._sources: Sequence[CredentialsSource] = sources
3538
self._credentials_resolver: AWSCredentialsResolver | None = None
3639

3740
async def get_identity(
@@ -43,8 +46,8 @@ async def get_identity(
4346
)
4447

4548
for source in self._sources:
46-
if source[0]():
47-
self._credentials_resolver = source[1]()
49+
if source.is_available(config=self._config):
50+
self._credentials_resolver = source.build_resolver(config=self._config)
4851
return await self._credentials_resolver.get_identity(
4952
identity_properties=identity_properties
5053
)

packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/environment.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
from smithy_core.exceptions import SmithyIdentityException
77
from smithy_core.interfaces.identity import IdentityProperties
88

9-
from ..identity import AWSCredentialsIdentity
9+
from smithy_aws_core.credentials_resolvers.interfaces import (
10+
AwsCredentialsConfig,
11+
CredentialsSource,
12+
)
13+
14+
from ..identity import AWSCredentialsIdentity, AWSCredentialsResolver
1015

1116

1217
class EnvironmentCredentialsResolver(
@@ -41,3 +46,13 @@ async def get_identity(
4146
)
4247

4348
return self._credentials
49+
50+
51+
class EnvironmentCredentialsSource(CredentialsSource):
52+
def is_available(self, config: AwsCredentialsConfig) -> bool:
53+
return (
54+
"AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ
55+
)
56+
57+
def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver:
58+
return EnvironmentCredentialsResolver()

packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,13 @@
1717
from smithy_http.aio import HTTPRequest
1818
from smithy_http.aio.interfaces import HTTPClient
1919

20+
from smithy_aws_core.credentials_resolvers.interfaces import (
21+
AwsCredentialsConfig,
22+
CredentialsSource,
23+
)
24+
2025
from .. import __version__
21-
from ..identity import AWSCredentialsIdentity
26+
from ..identity import AWSCredentialsIdentity, AWSCredentialsResolver
2227

2328
_USER_AGENT_FIELD = Field(
2429
name="User-Agent",
@@ -235,3 +240,13 @@ async def get_identity(
235240
account_id=account_id,
236241
)
237242
return self._credentials
243+
244+
245+
class IMDSCredentialsSource(CredentialsSource):
246+
def is_available(self, config: AwsCredentialsConfig) -> bool:
247+
# IMDS credentials should always be the last in the chain
248+
# We cannot check if they available without actually making a call
249+
return True
250+
251+
def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver:
252+
return IMDSCredentialsResolver(http_client=config.http_client)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
from typing import Protocol
4+
5+
from smithy_http.aio.interfaces import HTTPClient
6+
7+
from smithy_aws_core.identity import AWSCredentialsResolver
8+
9+
10+
class AwsCredentialsConfig(Protocol):
11+
"""Configuration required for resolving credentials."""
12+
13+
http_client: HTTPClient
14+
"""A static endpoint to use for the request."""
15+
16+
17+
class CredentialsSource(Protocol):
18+
def is_available(self, config: AwsCredentialsConfig) -> bool:
19+
"""Returns True if credentials are available from this source."""
20+
...
21+
22+
def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver:
23+
"""Builds a credentials resolver for the given configuration."""
24+
...

packages/smithy-aws-core/tests/unit/credentials_resolvers/test_credentials_resolver_chain.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,44 @@
1+
from dataclasses import dataclass
2+
from unittest.mock import Mock
3+
14
import pytest
25
from smithy_aws_core.credentials_resolvers import (
36
CredentialsResolverChain,
47
StaticCredentialsResolver,
58
)
6-
from smithy_aws_core.identity import AWSCredentialsIdentity
9+
from smithy_aws_core.credentials_resolvers.environment import (
10+
EnvironmentCredentialsSource,
11+
)
12+
from smithy_aws_core.credentials_resolvers.interfaces import (
13+
AwsCredentialsConfig,
14+
CredentialsSource,
15+
)
16+
from smithy_aws_core.identity import AWSCredentialsIdentity, AWSCredentialsResolver
717
from smithy_core.exceptions import SmithyIdentityException
818
from smithy_core.interfaces.identity import IdentityProperties
19+
from smithy_http.aio.interfaces import HTTPClient
20+
21+
22+
@dataclass
23+
class Config:
24+
http_client: HTTPClient
25+
26+
def __init__(self):
27+
self.http_client = Mock(spec=HTTPClient) # type: ignore
928

1029

1130
async def test_no_sources_resolve():
12-
resolver_chain = CredentialsResolverChain(sources=[])
31+
resolver_chain = CredentialsResolverChain(sources=[], config=Config())
1332
with pytest.raises(SmithyIdentityException):
1433
await resolver_chain.get_identity(identity_properties=IdentityProperties())
1534

1635

1736
async def test_env_credentials_resolver_not_set(monkeypatch: pytest.MonkeyPatch):
1837
monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
1938
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
20-
resolver_chain = CredentialsResolverChain()
39+
resolver_chain = CredentialsResolverChain(
40+
sources=[EnvironmentCredentialsSource()], config=Config()
41+
)
2142

2243
with pytest.raises(SmithyIdentityException):
2344
await resolver_chain.get_identity(identity_properties=IdentityProperties())
@@ -26,7 +47,9 @@ async def test_env_credentials_resolver_not_set(monkeypatch: pytest.MonkeyPatch)
2647
async def test_env_credentials_resolver_partial(monkeypatch: pytest.MonkeyPatch):
2748
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid")
2849
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
29-
resolver_chain = CredentialsResolverChain()
50+
resolver_chain = CredentialsResolverChain(
51+
sources=[EnvironmentCredentialsSource()], config=Config()
52+
)
3053

3154
with pytest.raises(SmithyIdentityException):
3255
await resolver_chain.get_identity(identity_properties=IdentityProperties())
@@ -35,7 +58,9 @@ async def test_env_credentials_resolver_partial(monkeypatch: pytest.MonkeyPatch)
3558
async def test_env_credentials_resolver_success(monkeypatch: pytest.MonkeyPatch):
3659
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid")
3760
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret")
38-
resolver_chain = CredentialsResolverChain()
61+
resolver_chain = CredentialsResolverChain(
62+
sources=[EnvironmentCredentialsSource()], config=Config()
63+
)
3964

4065
credentials = await resolver_chain.get_identity(
4166
identity_properties=IdentityProperties()
@@ -50,8 +75,19 @@ async def test_custom_sources_with_static_credentials():
5075
secret_access_key="static_secret",
5176
)
5277
static_resolver = StaticCredentialsResolver(credentials=static_credentials)
78+
79+
class TestStaticSource(CredentialsSource):
80+
def is_available(self, config: AwsCredentialsConfig) -> bool:
81+
return True
82+
83+
def build_resolver(
84+
self, config: AwsCredentialsConfig
85+
) -> AWSCredentialsResolver:
86+
return static_resolver
87+
5388
resolver_chain = CredentialsResolverChain(
54-
sources=[(lambda: False, lambda: None), (lambda: True, lambda: static_resolver)] # type: ignore
89+
sources=[TestStaticSource()],
90+
config=Config(), # type: ignore
5591
)
5692

5793
credentials = await resolver_chain.get_identity(

0 commit comments

Comments
 (0)