Skip to content

Commit c94dad7

Browse files
author
Martin Guibert
committed
fix: inject account id to enumerator instead of repo
1 parent e0104c8 commit c94dad7

File tree

7 files changed

+35
-54
lines changed

7 files changed

+35
-54
lines changed

enumeration/remote/aws/init.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter
3535
repositoryCache := cache.New(100)
3636

3737
s3Repository := repository.NewS3Repository(client.NewAWSClientFactory(provider.session), repositoryCache)
38-
s3ControlRepository := repository.NewS3ControlRepository(client.NewAWSClientFactory(provider.session), provider.accountId, repositoryCache)
38+
s3ControlRepository := repository.NewS3ControlRepository(client.NewAWSClientFactory(provider.session), repositoryCache)
3939
ec2repository := repository.NewEC2Repository(provider.session, repositoryCache)
4040
elbv2Repository := repository.NewELBV2Repository(provider.session, repositoryCache)
4141
route53repository := repository.NewRoute53Repository(provider.session, repositoryCache)
@@ -72,7 +72,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter
7272
remoteLibrary.AddEnumerator(NewS3BucketAnalyticEnumerator(s3Repository, factory, provider.Config, alerter))
7373
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer))
7474
remoteLibrary.AddEnumerator(NewS3BucketPublicAccessBlockEnumerator(s3Repository, factory, provider.Config, alerter))
75-
remoteLibrary.AddEnumerator(NewS3AccountPublicAccessBlockEnumerator(s3ControlRepository, factory, provider.Config, alerter))
75+
remoteLibrary.AddEnumerator(NewS3AccountPublicAccessBlockEnumerator(s3ControlRepository, factory, provider.accountId, alerter))
7676

7777
remoteLibrary.AddEnumerator(NewEC2EbsVolumeEnumerator(ec2repository, factory))
7878
remoteLibrary.AddDetailsFetcher(aws.AwsEbsVolumeResourceType, common.NewGenericDetailsFetcher(aws.AwsEbsVolumeResourceType, provider, deserializer))

enumeration/remote/aws/repository/mock_S3ControlRepository.go

Lines changed: 7 additions & 21 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

enumeration/remote/aws/repository/s3control_repository.go

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,28 @@ import (
88
)
99

1010
type S3ControlRepository interface {
11-
DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error)
12-
GetAccountID() string
11+
DescribeAccountPublicAccessBlock(accountID string) (*s3control.PublicAccessBlockConfiguration, error)
1312
}
1413

1514
type s3ControlRepository struct {
1615
clientFactory client.AwsClientFactoryInterface
17-
accountId string
1816
cache cache.Cache
1917
}
2018

21-
func NewS3ControlRepository(factory client.AwsClientFactoryInterface, accountId string, c cache.Cache) *s3ControlRepository {
19+
func NewS3ControlRepository(factory client.AwsClientFactoryInterface, c cache.Cache) *s3ControlRepository {
2220
return &s3ControlRepository{
2321
clientFactory: factory,
24-
accountId: accountId,
2522
cache: c,
2623
}
2724
}
28-
func (s *s3ControlRepository) GetAccountID() string {
29-
return s.accountId
30-
}
3125

32-
func (s *s3ControlRepository) DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error) {
26+
func (s *s3ControlRepository) DescribeAccountPublicAccessBlock(accountID string) (*s3control.PublicAccessBlockConfiguration, error) {
3327
cacheKey := "S3DescribeAccountPublicAccessBlock"
3428
if v := s.cache.Get(cacheKey); v != nil {
3529
return v.(*s3control.PublicAccessBlockConfiguration), nil
3630
}
3731
out, err := s.clientFactory.GetS3ControlClient(nil).GetPublicAccessBlock(&s3control.GetPublicAccessBlockInput{
38-
AccountId: aws.String(s.accountId),
32+
AccountId: aws.String(accountID),
3933
})
4034

4135
if err != nil {

enumeration/remote/aws/repository/s3control_repository_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
)
1818

1919
func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
20+
accountID := "123456"
2021

2122
tests := []struct {
2223
name string
@@ -65,14 +66,14 @@ func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
6566
tt.mocks(mockedClient)
6667
factory := client.MockAwsClientFactoryInterface{}
6768
factory.On("GetS3ControlClient", (*aws.Config)(nil)).Return(mockedClient).Once()
68-
r := NewS3ControlRepository(&factory, "", store)
69-
got, err := r.DescribeAccountPublicAccessBlock()
69+
r := NewS3ControlRepository(&factory, store)
70+
got, err := r.DescribeAccountPublicAccessBlock(accountID)
7071
factory.AssertExpectations(t)
7172
assert.Equal(t, tt.wantErr, err)
7273

7374
if err == nil {
7475
// Check that results were cached
75-
cachedData, err := r.DescribeAccountPublicAccessBlock()
76+
cachedData, err := r.DescribeAccountPublicAccessBlock(accountID)
7677
assert.NoError(t, err)
7778
assert.Equal(t, got, cachedData)
7879
assert.IsType(t, &s3control.PublicAccessBlockConfiguration{}, store.Get("S3DescribeAccountPublicAccessBlock"))

enumeration/remote/aws/s3_account_public_access_block_enumerator.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,23 @@ import (
55
"github.com/snyk/driftctl/enumeration/alerter"
66
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
77
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
8-
tf "github.com/snyk/driftctl/enumeration/remote/terraform"
98
"github.com/snyk/driftctl/enumeration/resource"
109
"github.com/snyk/driftctl/enumeration/resource/aws"
1110
)
1211

1312
type S3AccountPublicAccessBlockEnumerator struct {
14-
repository repository.S3ControlRepository
15-
factory resource.ResourceFactory
16-
providerConfig tf.TerraformProviderConfig
17-
alerter alerter.AlerterInterface
13+
repository repository.S3ControlRepository
14+
factory resource.ResourceFactory
15+
accountID string
16+
alerter alerter.AlerterInterface
1817
}
1918

20-
func NewS3AccountPublicAccessBlockEnumerator(repo repository.S3ControlRepository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3AccountPublicAccessBlockEnumerator {
19+
func NewS3AccountPublicAccessBlockEnumerator(repo repository.S3ControlRepository, factory resource.ResourceFactory, accountId string, alerter alerter.AlerterInterface) *S3AccountPublicAccessBlockEnumerator {
2120
return &S3AccountPublicAccessBlockEnumerator{
22-
repository: repo,
23-
factory: factory,
24-
providerConfig: providerConfig,
25-
alerter: alerter,
21+
repository: repo,
22+
factory: factory,
23+
accountID: accountId,
24+
alerter: alerter,
2625
}
2726
}
2827

@@ -31,7 +30,7 @@ func (e *S3AccountPublicAccessBlockEnumerator) SupportedType() resource.Resource
3130
}
3231

3332
func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource, error) {
34-
accountPublicAccessBlock, err := e.repository.DescribeAccountPublicAccessBlock()
33+
accountPublicAccessBlock, err := e.repository.DescribeAccountPublicAccessBlock(e.accountID)
3534
if err != nil {
3635
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
3736
}
@@ -42,7 +41,7 @@ func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource
4241
results,
4342
e.factory.CreateAbstractResource(
4443
string(e.SupportedType()),
45-
e.repository.GetAccountID(),
44+
e.accountID,
4645
map[string]interface{}{
4746
"block_public_acls": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicAcls),
4847
"block_public_policy": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicPolicy),

enumeration/remote/aws_s3_scanner_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,7 @@ func TestS3BucketAnalytic(t *testing.T) {
10711071
func TestS3AccountPublicAccessBlock(t *testing.T) {
10721072
dummyError := errors.New("this is an error")
10731073

1074+
accountID := "123456"
10741075
tests := []struct {
10751076
test string
10761077
mocks func(*repository.MockS3ControlRepository, *mocks.AlerterInterface)
@@ -1080,8 +1081,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
10801081
{
10811082
test: "existing access block",
10821083
mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) {
1083-
repository.On("GetAccountID").Return("123456")
1084-
repository.On("DescribeAccountPublicAccessBlock").Return(&s3control.PublicAccessBlockConfiguration{
1084+
repository.On("DescribeAccountPublicAccessBlock", accountID).Return(&s3control.PublicAccessBlockConfiguration{
10851085
BlockPublicAcls: awssdk.Bool(false),
10861086
BlockPublicPolicy: awssdk.Bool(true),
10871087
IgnorePublicAcls: awssdk.Bool(false),
@@ -1090,7 +1090,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
10901090
},
10911091
assertExpected: func(t *testing.T, got []*resource.Resource) {
10921092
assert.Len(t, got, 1)
1093-
assert.Equal(t, got[0].ResourceId(), "123456")
1093+
assert.Equal(t, got[0].ResourceId(), accountID)
10941094
assert.Equal(t, got[0].ResourceType(), resourceaws.AwsS3AccountPublicAccessBlock)
10951095
assert.Equal(t, got[0].Attributes(), &resource.Attributes{
10961096
"block_public_acls": false,
@@ -1103,7 +1103,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
11031103
{
11041104
test: "cannot list access block",
11051105
mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) {
1106-
repository.On("DescribeAccountPublicAccessBlock").Return(nil, dummyError)
1106+
repository.On("DescribeAccountPublicAccessBlock", accountID).Return(nil, dummyError)
11071107
},
11081108
wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsS3AccountPublicAccessBlock),
11091109
},
@@ -1125,7 +1125,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
11251125

11261126
remoteLibrary.AddEnumerator(aws.NewS3AccountPublicAccessBlockEnumerator(
11271127
repo, factory,
1128-
tf.TerraformProviderConfig{DefaultAlias: "us-east-1"},
1128+
accountID,
11291129
alerter,
11301130
))
11311131

enumeration/resource/resource_types.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ var supportedTypes = map[string]ResourceTypeMeta{
9999
"aws_s3_bucket_metric": {},
100100
"aws_s3_bucket_notification": {},
101101
"aws_s3_bucket_policy": {},
102-
"aws_s3_bucket_public_access_block": {}, "aws_security_group": {children: []ResourceType{
102+
"aws_s3_bucket_public_access_block": {},
103+
"aws_security_group": {children: []ResourceType{
103104
"aws_security_group_rule",
104105
}},
105106
"aws_s3_account_public_access_block": {},

0 commit comments

Comments
 (0)