Skip to content

Commit a8aebc8

Browse files
committed
feat: Use AWS version 2 to get the AWS credentials for AWS MSK
1 parent d74b563 commit a8aebc8

File tree

2 files changed

+88
-35
lines changed

2 files changed

+88
-35
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010
github.com/aws/aws-sdk-go-v2 v1.26.1
1111
github.com/aws/aws-sdk-go-v2/config v1.27.11
1212
github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1
13+
github.com/aws/aws-sdk-go-v2/service/sts v1.28.6
1314
github.com/getsentry/sentry-go v0.12.0
1415
github.com/go-kit/kit v0.9.0
1516
github.com/google/uuid v1.6.0
@@ -61,7 +62,6 @@ require (
6162
github.com/aws/aws-sdk-go-v2/service/sqs v1.24.4 // indirect
6263
github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 // indirect
6364
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 // indirect
64-
github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 // indirect
6565
github.com/aws/smithy-go v1.20.2 // indirect
6666
github.com/cespare/xxhash/v2 v2.2.0 // indirect
6767
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect

pkg/pubsub/kafka/config.go

Lines changed: 87 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ import (
1010
"github.com/aws/aws-sdk-go/aws/session"
1111
"github.com/aws/aws-sdk-go/service/sts"
1212
"github.com/twmb/franz-go/pkg/kgo"
13+
14+
"github.com/aws/aws-sdk-go-v2/aws"
15+
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
16+
1317
awssasl "github.com/twmb/franz-go/pkg/sasl/aws"
1418
"github.com/twmb/franz-go/pkg/sasl/plain"
1519

@@ -24,10 +28,14 @@ type Config struct {
2428
// Kafka configuration provided by go-sdk
2529
KafkaConfig pubsub.Kafka
2630
// AWS session reference, it will be used in case AWS MSK IAM authentication mechanism is used
31+
//
32+
// Deprecated: Use AwsConfig instead
2733
AwsSession *session.Session
2834
// MsgHandler is a function that will be called when a message is received
2935
MsgHandler MsgHandler
30-
Logger sdklogger.Logger
36+
// AWS configuration reference, it will be used in case AWS MSK IAM authentication mechanism is used
37+
AwsConfig *aws.Config
38+
Logger sdklogger.Logger
3139
}
3240

3341
const tlsConnectionTimeout = 10 * time.Second
@@ -43,7 +51,7 @@ func newConfig(c Config, opts ...kgo.Opt) ([]kgo.Opt, error) {
4351
case pubsub.Plain:
4452
options = append(options, getPlainSaslOption(c.KafkaConfig.SASL))
4553
case pubsub.AWSMskIam:
46-
options = append(options, getAwsMskIamSaslOption(c.KafkaConfig.SASL.AWSMskIam, c.AwsSession))
54+
options = append(options, getAwsMskIamSaslOption(c.KafkaConfig.SASL.AWSMskIam, c.AwsSession, c.AwsConfig))
4755
}
4856
}
4957

@@ -101,11 +109,11 @@ func getPlainSaslOption(saslConf pubsub.SASL) kgo.Opt {
101109
}.AsMechanism())
102110
}
103111

104-
func getAwsMskIamSaslOption(iamConf pubsub.SASLAwsMskIam, s *session.Session) kgo.Opt {
112+
func getAwsMskIamSaslOption(iamConf pubsub.SASLAwsMskIam, s *session.Session, awsCfg *aws.Config) kgo.Opt {
105113
var opt kgo.Opt
106114

107-
// no AWS session provided
108-
if s == nil {
115+
// no AWS session and AWS config provided
116+
if s == nil && awsCfg == nil {
109117
opt = kgo.SASL(awssasl.Auth{
110118
AccessKey: iamConf.AccessKey,
111119
SecretKey: iamConf.SecretKey,
@@ -115,40 +123,85 @@ func getAwsMskIamSaslOption(iamConf pubsub.SASLAwsMskIam, s *session.Session) kg
115123
} else {
116124
opt = kgo.SASL(
117125
awssasl.ManagedStreamingIAM(func(ctx context.Context) (awssasl.Auth, error) {
118-
// If assumable role is not provided, we try to get credentials from the provided AWS session
119-
if iamConf.AssumableRole == "" {
120-
val, err := s.Config.Credentials.Get()
121-
if err != nil {
122-
return awssasl.Auth{}, err
123-
}
124-
125-
return awssasl.Auth{
126-
AccessKey: val.AccessKeyID,
127-
SecretKey: val.SecretAccessKey,
128-
SessionToken: val.SessionToken,
129-
UserAgent: iamConf.UserAgent,
130-
}, nil
126+
if s != nil {
127+
return getAwsSaslAuthFromSession(iamConf, s)
131128
}
132129

133-
svc := sts.New(s)
134-
135-
res, stsErr := svc.AssumeRole(&sts.AssumeRoleInput{
136-
RoleArn: &iamConf.AssumableRole,
137-
RoleSessionName: &iamConf.SessionName,
138-
})
139-
if stsErr != nil {
140-
return awssasl.Auth{}, stsErr
141-
}
142-
143-
return awssasl.Auth{
144-
AccessKey: *res.Credentials.AccessKeyId,
145-
SecretKey: *res.Credentials.SecretAccessKey,
146-
SessionToken: *res.Credentials.SessionToken,
147-
UserAgent: iamConf.UserAgent,
148-
}, nil
130+
return getAwsSaslAuthFromConfig(ctx, iamConf, awsCfg)
149131
}),
150132
)
151133
}
152134

153135
return opt
154136
}
137+
138+
func getAwsSaslAuthFromSession(iamConf pubsub.SASLAwsMskIam, s *session.Session) (awssasl.Auth, error) {
139+
// If assumable role is not provided, we try to get credentials from the provided AWS session
140+
if iamConf.AssumableRole == "" {
141+
val, err := s.Config.Credentials.Get()
142+
if err != nil {
143+
return awssasl.Auth{}, err
144+
}
145+
146+
return awssasl.Auth{
147+
AccessKey: val.AccessKeyID,
148+
SecretKey: val.SecretAccessKey,
149+
SessionToken: val.SessionToken,
150+
UserAgent: iamConf.UserAgent,
151+
}, nil
152+
}
153+
154+
svc := sts.New(s)
155+
156+
res, stsErr := svc.AssumeRole(&sts.AssumeRoleInput{
157+
RoleArn: &iamConf.AssumableRole,
158+
RoleSessionName: &iamConf.SessionName,
159+
})
160+
if stsErr != nil {
161+
return awssasl.Auth{}, stsErr
162+
}
163+
164+
return awssasl.Auth{
165+
AccessKey: *res.Credentials.AccessKeyId,
166+
SecretKey: *res.Credentials.SecretAccessKey,
167+
SessionToken: *res.Credentials.SessionToken,
168+
UserAgent: iamConf.UserAgent,
169+
}, nil
170+
}
171+
172+
func getAwsSaslAuthFromConfig(
173+
ctx context.Context,
174+
iamConf pubsub.SASLAwsMskIam,
175+
awsCfg *aws.Config) (awssasl.Auth, error) {
176+
// If assumable role is not provided, we try to get credentials from the provided AWS config
177+
if iamConf.AssumableRole == "" {
178+
val, err := awsCfg.Credentials.Retrieve(ctx)
179+
if err != nil {
180+
return awssasl.Auth{}, err
181+
}
182+
183+
return awssasl.Auth{
184+
AccessKey: val.AccessKeyID,
185+
SecretKey: val.SecretAccessKey,
186+
SessionToken: val.SessionToken,
187+
UserAgent: iamConf.UserAgent,
188+
}, nil
189+
}
190+
191+
client := stsv2.NewFromConfig(*awsCfg)
192+
193+
res, stsErr := client.AssumeRole(ctx, &stsv2.AssumeRoleInput{
194+
RoleArn: &iamConf.AssumableRole,
195+
RoleSessionName: &iamConf.SessionName,
196+
})
197+
if stsErr != nil {
198+
return awssasl.Auth{}, stsErr
199+
}
200+
201+
return awssasl.Auth{
202+
AccessKey: *res.Credentials.AccessKeyId,
203+
SecretKey: *res.Credentials.SecretAccessKey,
204+
SessionToken: *res.Credentials.SessionToken,
205+
UserAgent: iamConf.UserAgent,
206+
}, nil
207+
}

0 commit comments

Comments
 (0)