@@ -34,9 +34,9 @@ const (
3434
3535// KeyFlow handles auth with SA key
3636type KeyFlow struct {
37- client * http.Client
37+ rt http.RoundTripper
38+ authClient * http.Client
3839 config * KeyFlowConfig
39- doer func (req * http.Request ) (resp * http.Response , err error )
4040 key * ServiceAccountKeyResponse
4141 privateKey * rsa.PrivateKey
4242 privateKeyPEM []byte
@@ -116,15 +116,25 @@ func (c *KeyFlow) GetToken() TokenResponseBody {
116116 return * c .token
117117}
118118
119- func (c * KeyFlow ) Init (cfg * KeyFlowConfig ) error {
119+ func (c * KeyFlow ) Init (cfg * KeyFlowConfig , rt http. RoundTripper ) error {
120120 // No concurrency at this point, so no mutex check needed
121121 c .token = & TokenResponseBody {}
122122 c .config = cfg
123123
124124 if c .config .TokenUrl == "" {
125125 c .config .TokenUrl = tokenAPI
126126 }
127- c .configureHTTPClient ()
127+
128+ if rt == nil {
129+ rt = http .DefaultTransport
130+ }
131+
132+ c .rt = rt
133+ c .authClient = & http.Client {
134+ Transport : rt ,
135+ Timeout : DefaultClientTimeout ,
136+ }
137+
128138 err := c .validate ()
129139 if err != nil {
130140 return err
@@ -163,7 +173,7 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error {
163173
164174// Roundtrip performs the request
165175func (c * KeyFlow ) RoundTrip (req * http.Request ) (* http.Response , error ) {
166- if c .client == nil {
176+ if c .rt == nil {
167177 return nil , fmt .Errorf ("please run Init()" )
168178 }
169179
@@ -172,13 +182,13 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {
172182 return nil , err
173183 }
174184 req .Header .Set ("Authorization" , fmt .Sprintf ("Bearer %s" , accessToken ))
175- return c .doer (req )
185+ return c .rt . RoundTrip (req )
176186}
177187
178188// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field
179189func (c * KeyFlow ) GetAccessToken () (string , error ) {
180- if c .client == nil {
181- return "" , fmt .Errorf ("nil http client , please run Init()" )
190+ if c .rt == nil {
191+ return "" , fmt .Errorf ("nil http round tripper , please run Init()" )
182192 }
183193
184194 c .tokenMutex .RLock ()
@@ -203,14 +213,6 @@ func (c *KeyFlow) GetAccessToken() (string, error) {
203213 return accessToken , nil
204214}
205215
206- // configureHTTPClient configures the HTTP client
207- func (c * KeyFlow ) configureHTTPClient () {
208- client := & http.Client {}
209- client .Timeout = DefaultClientTimeout
210- c .client = client
211- c .doer = c .client .Do
212- }
213-
214216// validate the client is configured well
215217func (c * KeyFlow ) validate () error {
216218 if c .config .ServiceAccountKey == nil {
@@ -279,10 +281,6 @@ func (c *KeyFlow) createAccessToken() (err error) {
279281// createAccessTokenWithRefreshToken creates an access token using
280282// an existing pre-validated refresh token
281283func (c * KeyFlow ) createAccessTokenWithRefreshToken () (err error ) {
282- if c .client == nil {
283- return fmt .Errorf ("nil http client, please run Init()" )
284- }
285-
286284 c .tokenMutex .RLock ()
287285 refreshToken := c .token .RefreshToken
288286 c .tokenMutex .RUnlock ()
@@ -334,7 +332,8 @@ func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error)
334332 return nil , err
335333 }
336334 req .Header .Add ("Content-Type" , "application/x-www-form-urlencoded" )
337- return c .doer (req )
335+
336+ return c .authClient .Do (req )
338337}
339338
340339// parseTokenResponse parses the response from the server
0 commit comments