@@ -349,62 +349,89 @@ func (provider *oauth2TokenExchange) getRequestParams() (string, error) {
349349}
350350
351351func (provider * oauth2TokenExchange ) processTokenExchangeResponse (result * http.Response , now time.Time ) error {
352- var (
353- data []byte
354- err error
355- )
356- if result .Body != nil {
357- data , err = io .ReadAll (result .Body )
358- if err != nil {
359- return xerrors .WithStackTrace (err )
360- }
361- } else {
362- data = make ([]byte , 0 )
352+ data , err := readResponseBody (result )
353+ if err != nil {
354+ return err
363355 }
364356
365357 if result .StatusCode != http .StatusOK {
366- description := result .Status
358+ return provider .handleErrorResponse (result .Status , data )
359+ }
367360
368- //nolint:tagliatelle
369- type errorResponse struct {
370- ErrorName string `json:"error"`
371- ErrorDescription string `json:"error_description"`
372- ErrorURI string `json:"error_uri"`
373- }
374- var parsedErrorResponse errorResponse
375- if err := json .Unmarshal (data , & parsedErrorResponse ); err != nil {
376- description += ", could not parse response: " + err .Error ()
361+ parsedResponse , err := parseTokenResponse (data )
362+ if err != nil {
363+ return err
364+ }
377365
378- return xerrors .WithStackTrace (fmt .Errorf ("%w: %s" , errCouldNotExchangeToken , description ))
379- }
366+ if err := validateTokenResponse (parsedResponse , provider ); err != nil {
367+ return err
368+ }
380369
381- if parsedErrorResponse .ErrorName != "" {
382- description += ", error: " + parsedErrorResponse .ErrorName
383- }
370+ provider .updateToken (parsedResponse , now )
384371
385- if parsedErrorResponse .ErrorDescription != "" {
386- description += fmt .Sprintf (", description: %q" , parsedErrorResponse .ErrorDescription )
387- }
372+ return nil
373+ }
388374
389- if parsedErrorResponse .ErrorURI != "" {
390- description += ", error_uri: " + parsedErrorResponse .ErrorURI
375+ func readResponseBody (result * http.Response ) ([]byte , error ) {
376+ if result .Body != nil {
377+ data , err := io .ReadAll (result .Body )
378+ if err != nil {
379+ return nil , xerrors .WithStackTrace (err )
391380 }
392381
393- return xerrors . WithStackTrace ( fmt . Errorf ( "%w: %s" , errCouldNotExchangeToken , description ))
382+ return data , nil
394383 }
395384
385+ return make ([]byte , 0 ), nil
386+ }
387+
388+ func (provider * oauth2TokenExchange ) handleErrorResponse (status string , data []byte ) error {
389+ description := status
390+
396391 //nolint:tagliatelle
397- type response struct {
398- AccessToken string `json:"access_token"`
399- TokenType string `json:"token_type"`
400- ExpiresIn int64 `json:"expires_in"`
401- Scope string `json:"scope"`
392+ type errorResponse struct {
393+ ErrorName string `json:"error"`
394+ ErrorDescription string `json:"error_description"`
395+ ErrorURI string `json:"error_uri"`
396+ }
397+ var parsedErrorResponse errorResponse
398+ if err := json .Unmarshal (data , & parsedErrorResponse ); err != nil {
399+ description += ", could not parse response: " + err .Error ()
400+
401+ return xerrors .WithStackTrace (fmt .Errorf ("%w: %s" , errCouldNotExchangeToken , description ))
402402 }
403- var parsedResponse response
403+
404+ if parsedErrorResponse .ErrorName != "" {
405+ description += ", error: " + parsedErrorResponse .ErrorName
406+ }
407+ if parsedErrorResponse .ErrorDescription != "" {
408+ description += fmt .Sprintf (", description: %q" , parsedErrorResponse .ErrorDescription )
409+ }
410+ if parsedErrorResponse .ErrorURI != "" {
411+ description += ", error_uri: " + parsedErrorResponse .ErrorURI
412+ }
413+
414+ return xerrors .WithStackTrace (fmt .Errorf ("%w: %s" , errCouldNotExchangeToken , description ))
415+ }
416+
417+ //nolint:tagliatelle
418+ type tokenResponse struct {
419+ AccessToken string `json:"access_token"`
420+ TokenType string `json:"token_type"`
421+ ExpiresIn int64 `json:"expires_in"`
422+ Scope string `json:"scope"`
423+ }
424+
425+ func parseTokenResponse (data []byte ) (* tokenResponse , error ) {
426+ var parsedResponse tokenResponse
404427 if err := json .Unmarshal (data , & parsedResponse ); err != nil {
405- return xerrors .WithStackTrace (fmt .Errorf ("%w: %w" , errCouldNotParseResponse , err ))
428+ return nil , xerrors .WithStackTrace (fmt .Errorf ("%w: %w" , errCouldNotParseResponse , err ))
406429 }
407430
431+ return & parsedResponse , nil
432+ }
433+
434+ func validateTokenResponse (parsedResponse * tokenResponse , provider * oauth2TokenExchange ) error {
408435 if ! strings .EqualFold (parsedResponse .TokenType , "bearer" ) {
409436 return xerrors .WithStackTrace (
410437 fmt .Errorf ("%w: %q" , errUnsupportedTokenType , parsedResponse .TokenType ))
@@ -423,18 +450,17 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R
423450 }
424451 }
425452
453+ return nil
454+ }
455+
456+ func (provider * oauth2TokenExchange ) updateToken (parsedResponse * tokenResponse , now time.Time ) {
426457 provider .receivedToken = "Bearer " + parsedResponse .AccessToken
427458
428- // Expire time
429- expireDelta := time .Duration (parsedResponse .ExpiresIn )
430- expireDelta *= time .Second
459+ expireDelta := time .Duration (parsedResponse .ExpiresIn ) * time .Second
431460 provider .receivedTokenExpireTime = now .Add (expireDelta )
432461
433- updateDelta := time .Duration (parsedResponse .ExpiresIn / updateTimeDivider )
434- updateDelta *= time .Second
462+ updateDelta := time .Duration (parsedResponse .ExpiresIn / updateTimeDivider ) * time .Second
435463 provider .updateTokenTime = now .Add (updateDelta )
436-
437- return nil
438464}
439465
440466func (provider * oauth2TokenExchange ) exchangeToken (ctx context.Context , now time.Time ) error {
0 commit comments