@@ -10,6 +10,10 @@ import (
10
10
"github.com/aws/aws-sdk-go/aws/session"
11
11
"github.com/aws/aws-sdk-go/service/sts"
12
12
"github.com/twmb/franz-go/pkg/kgo"
13
+
14
+ "github.com/aws/aws-sdk-go-v2/aws"
15
+ stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
16
+
13
17
awssasl "github.com/twmb/franz-go/pkg/sasl/aws"
14
18
"github.com/twmb/franz-go/pkg/sasl/plain"
15
19
@@ -24,10 +28,14 @@ type Config struct {
24
28
// Kafka configuration provided by go-sdk
25
29
KafkaConfig pubsub.Kafka
26
30
// AWS session reference, it will be used in case AWS MSK IAM authentication mechanism is used
31
+ //
32
+ // Deprecated: Use AwsConfig instead
27
33
AwsSession * session.Session
28
34
// MsgHandler is a function that will be called when a message is received
29
35
MsgHandler MsgHandler
30
- Logger sdklogger.Logger
36
+ // AWS configuration reference, it will be used in case AWS MSK IAM authentication mechanism is used
37
+ AwsConfig * aws.Config
38
+ Logger sdklogger.Logger
31
39
}
32
40
33
41
const tlsConnectionTimeout = 10 * time .Second
@@ -43,7 +51,7 @@ func newConfig(c Config, opts ...kgo.Opt) ([]kgo.Opt, error) {
43
51
case pubsub .Plain :
44
52
options = append (options , getPlainSaslOption (c .KafkaConfig .SASL ))
45
53
case pubsub .AWSMskIam :
46
- options = append (options , getAwsMskIamSaslOption (c .KafkaConfig .SASL .AWSMskIam , c .AwsSession ))
54
+ options = append (options , getAwsMskIamSaslOption (c .KafkaConfig .SASL .AWSMskIam , c .AwsSession , c . AwsConfig ))
47
55
}
48
56
}
49
57
@@ -101,11 +109,11 @@ func getPlainSaslOption(saslConf pubsub.SASL) kgo.Opt {
101
109
}.AsMechanism ())
102
110
}
103
111
104
- func getAwsMskIamSaslOption (iamConf pubsub.SASLAwsMskIam , s * session.Session ) kgo.Opt {
112
+ func getAwsMskIamSaslOption (iamConf pubsub.SASLAwsMskIam , s * session.Session , awsCfg * aws. Config ) kgo.Opt {
105
113
var opt kgo.Opt
106
114
107
- // no AWS session provided
108
- if s == nil {
115
+ // no AWS session and AWS config provided
116
+ if s == nil && awsCfg == nil {
109
117
opt = kgo .SASL (awssasl.Auth {
110
118
AccessKey : iamConf .AccessKey ,
111
119
SecretKey : iamConf .SecretKey ,
@@ -115,40 +123,85 @@ func getAwsMskIamSaslOption(iamConf pubsub.SASLAwsMskIam, s *session.Session) kg
115
123
} else {
116
124
opt = kgo .SASL (
117
125
awssasl .ManagedStreamingIAM (func (ctx context.Context ) (awssasl.Auth , error ) {
118
- // If assumable role is not provided, we try to get credentials from the provided AWS session
119
- if iamConf .AssumableRole == "" {
120
- val , err := s .Config .Credentials .Get ()
121
- if err != nil {
122
- return awssasl.Auth {}, err
123
- }
124
-
125
- return awssasl.Auth {
126
- AccessKey : val .AccessKeyID ,
127
- SecretKey : val .SecretAccessKey ,
128
- SessionToken : val .SessionToken ,
129
- UserAgent : iamConf .UserAgent ,
130
- }, nil
126
+ if s != nil {
127
+ return getAwsSaslAuthFromSession (iamConf , s )
131
128
}
132
129
133
- svc := sts .New (s )
134
-
135
- res , stsErr := svc .AssumeRole (& sts.AssumeRoleInput {
136
- RoleArn : & iamConf .AssumableRole ,
137
- RoleSessionName : & iamConf .SessionName ,
138
- })
139
- if stsErr != nil {
140
- return awssasl.Auth {}, stsErr
141
- }
142
-
143
- return awssasl.Auth {
144
- AccessKey : * res .Credentials .AccessKeyId ,
145
- SecretKey : * res .Credentials .SecretAccessKey ,
146
- SessionToken : * res .Credentials .SessionToken ,
147
- UserAgent : iamConf .UserAgent ,
148
- }, nil
130
+ return getAwsSaslAuthFromConfig (ctx , iamConf , awsCfg )
149
131
}),
150
132
)
151
133
}
152
134
153
135
return opt
154
136
}
137
+
138
+ func getAwsSaslAuthFromSession (iamConf pubsub.SASLAwsMskIam , s * session.Session ) (awssasl.Auth , error ) {
139
+ // If assumable role is not provided, we try to get credentials from the provided AWS session
140
+ if iamConf .AssumableRole == "" {
141
+ val , err := s .Config .Credentials .Get ()
142
+ if err != nil {
143
+ return awssasl.Auth {}, err
144
+ }
145
+
146
+ return awssasl.Auth {
147
+ AccessKey : val .AccessKeyID ,
148
+ SecretKey : val .SecretAccessKey ,
149
+ SessionToken : val .SessionToken ,
150
+ UserAgent : iamConf .UserAgent ,
151
+ }, nil
152
+ }
153
+
154
+ svc := sts .New (s )
155
+
156
+ res , stsErr := svc .AssumeRole (& sts.AssumeRoleInput {
157
+ RoleArn : & iamConf .AssumableRole ,
158
+ RoleSessionName : & iamConf .SessionName ,
159
+ })
160
+ if stsErr != nil {
161
+ return awssasl.Auth {}, stsErr
162
+ }
163
+
164
+ return awssasl.Auth {
165
+ AccessKey : * res .Credentials .AccessKeyId ,
166
+ SecretKey : * res .Credentials .SecretAccessKey ,
167
+ SessionToken : * res .Credentials .SessionToken ,
168
+ UserAgent : iamConf .UserAgent ,
169
+ }, nil
170
+ }
171
+
172
+ func getAwsSaslAuthFromConfig (
173
+ ctx context.Context ,
174
+ iamConf pubsub.SASLAwsMskIam ,
175
+ awsCfg * aws.Config ) (awssasl.Auth , error ) {
176
+ // If assumable role is not provided, we try to get credentials from the provided AWS config
177
+ if iamConf .AssumableRole == "" {
178
+ val , err := awsCfg .Credentials .Retrieve (ctx )
179
+ if err != nil {
180
+ return awssasl.Auth {}, err
181
+ }
182
+
183
+ return awssasl.Auth {
184
+ AccessKey : val .AccessKeyID ,
185
+ SecretKey : val .SecretAccessKey ,
186
+ SessionToken : val .SessionToken ,
187
+ UserAgent : iamConf .UserAgent ,
188
+ }, nil
189
+ }
190
+
191
+ client := stsv2 .NewFromConfig (* awsCfg )
192
+
193
+ res , stsErr := client .AssumeRole (ctx , & stsv2.AssumeRoleInput {
194
+ RoleArn : & iamConf .AssumableRole ,
195
+ RoleSessionName : & iamConf .SessionName ,
196
+ })
197
+ if stsErr != nil {
198
+ return awssasl.Auth {}, stsErr
199
+ }
200
+
201
+ return awssasl.Auth {
202
+ AccessKey : * res .Credentials .AccessKeyId ,
203
+ SecretKey : * res .Credentials .SecretAccessKey ,
204
+ SessionToken : * res .Credentials .SessionToken ,
205
+ UserAgent : iamConf .UserAgent ,
206
+ }, nil
207
+ }
0 commit comments