Skip to content

Commit cc04a88

Browse files
author
Martin Guibert
committed
feat: add support for aws_s3_account_public_access_block
1 parent 3e16c30 commit cc04a88

22 files changed

+4495
-9
lines changed

enumeration/remote/aws/client/mock_AwsClientFactoryInterface.go

Lines changed: 40 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

enumeration/remote/aws/client/s3_client_factory.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@ import (
55
"github.com/aws/aws-sdk-go/aws/client"
66
"github.com/aws/aws-sdk-go/service/s3"
77
"github.com/aws/aws-sdk-go/service/s3/s3iface"
8+
"github.com/aws/aws-sdk-go/service/s3control"
9+
"github.com/aws/aws-sdk-go/service/s3control/s3controliface"
810
)
911

1012
type AwsClientFactoryInterface interface {
1113
GetS3Client(configs ...*aws.Config) s3iface.S3API
14+
GetS3ControlClient(configs ...*aws.Config) s3controliface.S3ControlAPI
1215
}
1316

1417
type AwsClientFactory struct {
@@ -22,3 +25,7 @@ func NewAWSClientFactory(config client.ConfigProvider) *AwsClientFactory {
2225
func (s AwsClientFactory) GetS3Client(configs ...*aws.Config) s3iface.S3API {
2326
return s3.New(s.config, configs...)
2427
}
28+
29+
func (s AwsClientFactory) GetS3ControlClient(configs ...*aws.Config) s3controliface.S3ControlAPI {
30+
return s3control.New(s.config, configs...)
31+
}

enumeration/remote/aws/init.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package aws
33
import (
44
"github.com/snyk/driftctl/enumeration"
55
"github.com/snyk/driftctl/enumeration/alerter"
6-
"github.com/snyk/driftctl/enumeration/remote/aws/client"
6+
client "github.com/snyk/driftctl/enumeration/remote/aws/client"
77
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
88
"github.com/snyk/driftctl/enumeration/remote/cache"
99
"github.com/snyk/driftctl/enumeration/remote/common"
@@ -35,6 +35,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter
3535
repositoryCache := cache.New(100)
3636

3737
s3Repository := repository.NewS3Repository(client.NewAWSClientFactory(provider.session), repositoryCache)
38+
s3ControlRepository := repository.NewS3ControlRepository(client.NewAWSClientFactory(provider.session), provider.accountId, repositoryCache)
3839
ec2repository := repository.NewEC2Repository(provider.session, repositoryCache)
3940
elbv2Repository := repository.NewELBV2Repository(provider.session, repositoryCache)
4041
route53repository := repository.NewRoute53Repository(provider.session, repositoryCache)
@@ -71,6 +72,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter
7172
remoteLibrary.AddEnumerator(NewS3BucketAnalyticEnumerator(s3Repository, factory, provider.Config, alerter))
7273
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer))
7374
remoteLibrary.AddEnumerator(NewS3BucketPublicAccessBlockEnumerator(s3Repository, factory, provider.Config, alerter))
75+
remoteLibrary.AddEnumerator(NewS3AccountPublicAccessBlockEnumerator(s3ControlRepository, factory, provider.Config, alerter))
7476

7577
remoteLibrary.AddEnumerator(NewEC2EbsVolumeEnumerator(ec2repository, factory))
7678
remoteLibrary.AddDetailsFetcher(aws.AwsEbsVolumeResourceType, common.NewGenericDetailsFetcher(aws.AwsEbsVolumeResourceType, provider, deserializer))

enumeration/remote/aws/provider.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package aws
22

33
import (
4+
"github.com/aws/aws-sdk-go/aws"
45
"github.com/aws/aws-sdk-go/aws/credentials"
56
"github.com/aws/aws-sdk-go/aws/session"
67
"github.com/aws/aws-sdk-go/service/sts"
@@ -42,9 +43,10 @@ type awsConfig struct {
4243

4344
type AWSTerraformProvider struct {
4445
*terraform.TerraformProvider
45-
session *session.Session
46-
name string
47-
version string
46+
session *session.Session
47+
name string
48+
version string
49+
accountId string
4850
}
4951

5052
func NewAWSTerraformProvider(version string, progress enumeration.ProgressCounter, configDir string) (*AWSTerraformProvider, error) {
@@ -115,11 +117,13 @@ func (p *AWSTerraformProvider) CheckCredentialsExist() error {
115117
// This call is to make sure that the credentials are valid
116118
// A more complex logic exist in terraform provider, but it's probably not worth to implement it
117119
// https://github.com/hashicorp/terraform-provider-aws/blob/e3959651092864925045a6044961a73137095798/aws/auth_helpers.go#L111
118-
_, err = sts.New(p.session).GetCallerIdentity(&sts.GetCallerIdentityInput{})
120+
identity, err := sts.New(p.session).GetCallerIdentity(&sts.GetCallerIdentityInput{})
119121
if err != nil {
120122
logrus.Debug(err)
121123
return errors.New("Could not authenticate successfully on AWS with the provided credentials.\n" +
122124
"Please refer to the AWS documentation: https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html\n")
123125
}
126+
127+
p.accountId = aws.StringValue(identity.Account)
124128
return nil
125129
}

enumeration/remote/aws/repository/mock_S3ControlRepository.go

Lines changed: 65 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

enumeration/remote/aws/repository/s3_repository.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package repository
22

33
import (
44
"fmt"
5+
56
"github.com/snyk/driftctl/enumeration/remote/aws/client"
67
"github.com/snyk/driftctl/enumeration/remote/cache"
78

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package repository
2+
3+
import (
4+
"github.com/aws/aws-sdk-go/aws"
5+
"github.com/aws/aws-sdk-go/service/s3control"
6+
"github.com/snyk/driftctl/enumeration/remote/aws/client"
7+
"github.com/snyk/driftctl/enumeration/remote/cache"
8+
)
9+
10+
type S3ControlRepository interface {
11+
DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error)
12+
GetAccountID() string
13+
}
14+
15+
type s3ControlRepository struct {
16+
clientFactory client.AwsClientFactoryInterface
17+
accountId string
18+
cache cache.Cache
19+
}
20+
21+
func NewS3ControlRepository(factory client.AwsClientFactoryInterface, accountId string, c cache.Cache) *s3ControlRepository {
22+
return &s3ControlRepository{
23+
clientFactory: factory,
24+
accountId: accountId,
25+
cache: c,
26+
}
27+
}
28+
func (s *s3ControlRepository) GetAccountID() string {
29+
return s.accountId
30+
}
31+
32+
func (s *s3ControlRepository) DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error) {
33+
cacheKey := "S3DescribeAccountPublicAccessBlock"
34+
if v := s.cache.Get(cacheKey); v != nil {
35+
return v.(*s3control.PublicAccessBlockConfiguration), nil
36+
}
37+
out, err := s.clientFactory.GetS3ControlClient(nil).GetPublicAccessBlock(&s3control.GetPublicAccessBlockInput{
38+
AccountId: aws.String(s.accountId),
39+
})
40+
41+
if err != nil {
42+
return nil, err
43+
}
44+
45+
result := out.PublicAccessBlockConfiguration
46+
47+
s.cache.Put(cacheKey, result)
48+
return result, nil
49+
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package repository
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/aws/aws-sdk-go/service/s3control"
8+
"github.com/snyk/driftctl/enumeration/remote/aws/client"
9+
"github.com/snyk/driftctl/enumeration/remote/cache"
10+
"github.com/stretchr/testify/mock"
11+
12+
"github.com/aws/aws-sdk-go/aws"
13+
"github.com/aws/aws-sdk-go/aws/awserr"
14+
"github.com/r3labs/diff/v2"
15+
awstest "github.com/snyk/driftctl/test/aws"
16+
"github.com/stretchr/testify/assert"
17+
)
18+
19+
func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
20+
21+
tests := []struct {
22+
name string
23+
mocks func(client *awstest.MockFakeS3Control)
24+
want *s3control.PublicAccessBlockConfiguration
25+
wantErr error
26+
}{
27+
{
28+
name: "describe account public accessblock",
29+
mocks: func(client *awstest.MockFakeS3Control) {
30+
client.On("GetPublicAccessBlock", mock.Anything).Return(
31+
&s3control.GetPublicAccessBlockOutput{
32+
PublicAccessBlockConfiguration: &s3control.PublicAccessBlockConfiguration{
33+
BlockPublicAcls: aws.Bool(false),
34+
BlockPublicPolicy: aws.Bool(true),
35+
IgnorePublicAcls: aws.Bool(false),
36+
RestrictPublicBuckets: aws.Bool(true),
37+
},
38+
},
39+
nil,
40+
).Once()
41+
},
42+
want: &s3control.PublicAccessBlockConfiguration{
43+
BlockPublicAcls: aws.Bool(false),
44+
BlockPublicPolicy: aws.Bool(true),
45+
IgnorePublicAcls: aws.Bool(false),
46+
RestrictPublicBuckets: aws.Bool(true),
47+
},
48+
},
49+
{
50+
name: "Error detting account public accessblock",
51+
mocks: func(client *awstest.MockFakeS3Control) {
52+
client.On("GetPublicAccessBlock", mock.Anything).Return(
53+
nil,
54+
awserr.NewRequestFailure(nil, 403, ""),
55+
).Once()
56+
},
57+
want: nil,
58+
wantErr: awserr.NewRequestFailure(nil, 403, ""),
59+
},
60+
}
61+
for _, tt := range tests {
62+
t.Run(tt.name, func(t *testing.T) {
63+
store := cache.New(1)
64+
mockedClient := &awstest.MockFakeS3Control{}
65+
tt.mocks(mockedClient)
66+
factory := client.MockAwsClientFactoryInterface{}
67+
factory.On("GetS3ControlClient", (*aws.Config)(nil)).Return(mockedClient).Once()
68+
r := NewS3ControlRepository(&factory, "", store)
69+
got, err := r.DescribeAccountPublicAccessBlock()
70+
factory.AssertExpectations(t)
71+
assert.Equal(t, tt.wantErr, err)
72+
73+
if err == nil {
74+
// Check that results were cached
75+
cachedData, err := r.DescribeAccountPublicAccessBlock()
76+
assert.NoError(t, err)
77+
assert.Equal(t, got, cachedData)
78+
assert.IsType(t, &s3control.PublicAccessBlockConfiguration{}, store.Get("S3DescribeAccountPublicAccessBlock"))
79+
}
80+
81+
changelog, err := diff.Diff(got, tt.want)
82+
assert.Nil(t, err)
83+
if len(changelog) > 0 {
84+
for _, change := range changelog {
85+
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
86+
}
87+
t.Fail()
88+
}
89+
})
90+
}
91+
}

0 commit comments

Comments
 (0)