@@ -4,9 +4,11 @@ import (
44 "context"
55 "crypto/hmac"
66 "crypto/sha1"
7+ "encoding/base64"
78 "encoding/json"
89 "fmt"
910 "hash"
11+ "io"
1012 "io/ioutil"
1113 math_rand "math/rand"
1214 "net"
@@ -26,11 +28,13 @@ const (
2628)
2729
2830var (
29- defaultCVMAuthExpire = int64 (600 )
31+ defaultTmpAuthExpire = int64 (600 )
3032 defaultCVMSchema = "http"
3133 defaultCVMMetaHost = "metadata.tencentyun.com"
3234 defaultCVMCredURI = "latest/meta-data/cam/security-credentials"
3335 internalHost = regexp .MustCompile (`^.*cos-internal\.[a-z-1]+\.tencentcos\.cn$` )
36+ defaultStsHost = "sts.tencentcloudapi.com"
37+ defaultStsSchema = "https"
3438)
3539
3640var DNSScatterDialContext = DNSScatterDialContextFunc
@@ -424,7 +428,7 @@ func (t *CVMCredentialTransport) GetRoles() ([]string, error) {
424428func (t * CVMCredentialTransport ) UpdateCredential (now int64 ) (string , string , string , error ) {
425429 t .rwLocker .Lock ()
426430 defer t .rwLocker .Unlock ()
427- if t .expiredTime > now + defaultCVMAuthExpire {
431+ if t .expiredTime > now + defaultTmpAuthExpire {
428432 return t .secretID , t .secretKey , t .sessionToken , nil
429433 }
430434 roleName := t .RoleName
@@ -460,8 +464,8 @@ func (t *CVMCredentialTransport) UpdateCredential(now int64) (string, string, st
460464func (t * CVMCredentialTransport ) GetCredential () (string , string , string , error ) {
461465 now := time .Now ().Unix ()
462466 t .rwLocker .RLock ()
463- // 提前 defaultCVMAuthExpire 获取重新获取临时密钥
464- if t .expiredTime <= now + defaultCVMAuthExpire {
467+ // 提前 defaultTmpAuthExpire 获取重新获取临时密钥
468+ if t .expiredTime <= now + defaultTmpAuthExpire {
465469 expiredTime := t .expiredTime
466470 t .rwLocker .RUnlock ()
467471 secretID , secretKey , secretToken , err := t .UpdateCredential (now )
@@ -545,3 +549,208 @@ func (c *Credential) GetSecretId() string {
545549func (c * Credential ) GetToken () string {
546550 return c .SessionToken
547551}
552+
553+ // 通过sts访问
554+ type Credentials struct {
555+ TmpSecretID string `json:"TmpSecretId,omitempty"`
556+ TmpSecretKey string `json:"TmpSecretKey,omitempty"`
557+ SessionToken string `json:"Token,omitempty"`
558+ }
559+ type CredentialError struct {
560+ Code string `json:"Code,omitempty"`
561+ Message string `json:"Message,omitempty"`
562+ RequestId string `json:"RequestId,omitempty"`
563+ }
564+
565+ func (e * CredentialError ) Error () string {
566+ return fmt .Sprintf ("Code: %v, Message: %v, RequestId: %v" , e .Code , e .Message , e .RequestId )
567+ }
568+
569+ type CredentialResult struct {
570+ Credentials * Credentials `json:"Credentials,omitempty"`
571+ ExpiredTime int64 `json:"ExpiredTime,omitempty"`
572+ RequestId string `json:"RequestId,omitempty"`
573+ Error * CredentialError `json:"Error,omitempty"`
574+ }
575+
576+ type CredentialCompleteResult struct {
577+ Response * CredentialResult `json:"Response"`
578+ }
579+
580+ type CredentialPolicyStatement struct {
581+ Action []string `json:"action,omitempty"`
582+ Effect string `json:"effect,omitempty"`
583+ Resource []string `json:"resource,omitempty"`
584+ Condition map [string ]map [string ]interface {} `json:"condition,omitempty"`
585+ }
586+
587+ type CredentialPolicy struct {
588+ Version string `json:"version,omitempty"`
589+ Statement []CredentialPolicyStatement `json:"statement,omitempty"`
590+ }
591+
592+ type StsCredentialTransport struct {
593+ Transport http.RoundTripper
594+ SecretID string
595+ SecretKey string
596+ Policy * CredentialPolicy
597+ Host string
598+ Region string
599+ expiredTime int64
600+ credential Credentials
601+ rwLocker sync.RWMutex
602+ }
603+
604+ func (t * StsCredentialTransport ) UpdateCredential (now int64 ) (string , string , string , error ) {
605+ t .rwLocker .Lock ()
606+ defer t .rwLocker .Unlock ()
607+ if t .expiredTime > now + defaultTmpAuthExpire {
608+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , nil
609+ }
610+ region := t .Region
611+ if region == "" {
612+ region = "ap-guangzhou"
613+ }
614+ policy , err := getPolicy (t .Policy )
615+ if err != nil {
616+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , err
617+ }
618+ params := map [string ]interface {}{
619+ "SecretId" : t .SecretID ,
620+ "Policy" : url .QueryEscape (policy ),
621+ "DurationSeconds" : 1800 ,
622+ "Region" : region ,
623+ "Timestamp" : time .Now ().Unix (),
624+ "Nonce" : math_rand .Int (),
625+ "Name" : "cos-sts-sdk" ,
626+ "Action" : "GetFederationToken" ,
627+ "Version" : "2018-08-13" ,
628+ }
629+ resp , err := t .sendRequest (params )
630+ if err != nil {
631+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , err
632+ }
633+ defer resp .Body .Close ()
634+ if resp .StatusCode > 299 {
635+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , fmt .Errorf ("sts StatusCode error: %v" , resp .StatusCode )
636+ }
637+ result := & CredentialCompleteResult {}
638+ err = json .NewDecoder (resp .Body ).Decode (result )
639+ if err == io .EOF {
640+ err = nil // ignore EOF errors caused by empty response body
641+ }
642+ if err != nil {
643+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , err
644+ }
645+ if result .Response != nil && result .Response .Error != nil {
646+ result .Response .Error .RequestId = result .Response .RequestId
647+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , result .Response .Error
648+ }
649+ if result .Response != nil && result .Response .Credentials != nil {
650+ t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , t .expiredTime = result .Response .Credentials .TmpSecretID , result .Response .Credentials .TmpSecretKey , result .Response .Credentials .SessionToken , result .Response .ExpiredTime
651+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , nil
652+ }
653+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , fmt .Errorf ("GetCredential failed, result: %v" , result .Response )
654+ }
655+
656+ func (t * StsCredentialTransport ) GetCredential () (string , string , string , error ) {
657+ now := time .Now ().Unix ()
658+ t .rwLocker .RLock ()
659+ // 提前 defaultTmpAuthExpire 获取重新获取临时密钥
660+ if t .expiredTime <= now + defaultTmpAuthExpire {
661+ expiredTime := t .expiredTime
662+ t .rwLocker .RUnlock ()
663+ secretID , secretKey , secretToken , err := t .UpdateCredential (now )
664+ // 获取临时密钥失败但密钥未过期
665+ if err != nil && now < expiredTime {
666+ err = nil
667+ }
668+ return secretID , secretKey , secretToken , err
669+ }
670+ defer t .rwLocker .RUnlock ()
671+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , nil
672+ }
673+
674+ func (t * StsCredentialTransport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
675+ ak , sk , token , err := t .GetCredential ()
676+ if err != nil {
677+ return nil , err
678+ }
679+ req = cloneRequest (req )
680+ // 增加 Authorization header
681+ authTime := NewAuthTime (defaultAuthExpire )
682+ AddAuthorizationHeader (ak , sk , token , req , authTime )
683+
684+ resp , err := t .transport ().RoundTrip (req )
685+ return resp , err
686+ }
687+
688+ func (t * StsCredentialTransport ) transport () http.RoundTripper {
689+ if t .Transport != nil {
690+ return t .Transport
691+ }
692+ return http .DefaultTransport
693+ }
694+
695+ func (t * StsCredentialTransport ) sendRequest (params map [string ]interface {}) (* http.Response , error ) {
696+ paramValues := url.Values {}
697+ for k , v := range params {
698+ paramValues .Add (fmt .Sprintf ("%v" , k ), fmt .Sprintf ("%v" , v ))
699+ }
700+ sign := t .signed ("POST" , params )
701+ paramValues .Add ("Signature" , sign )
702+
703+ host := defaultStsHost
704+ if t .Host != "" {
705+ host = t .Host
706+ }
707+ resp , err := http .DefaultClient .PostForm (defaultStsSchema + "://" + host , paramValues )
708+ return resp , err
709+ }
710+
711+ func (t * StsCredentialTransport ) signed (method string , params map [string ]interface {}) string {
712+ host := defaultStsHost
713+ if t .Host != "" {
714+ host = t .Host
715+ }
716+ source := method + host + "/?" + makeFlat (params )
717+
718+ hmacObj := hmac .New (sha1 .New , []byte (t .SecretKey ))
719+ hmacObj .Write ([]byte (source ))
720+
721+ sign := base64 .StdEncoding .EncodeToString (hmacObj .Sum (nil ))
722+
723+ return sign
724+ }
725+
726+ func getPolicy (policy * CredentialPolicy ) (string , error ) {
727+ if policy == nil {
728+ return "" , nil
729+ }
730+ res := policy
731+ if policy .Version == "" {
732+ res = & CredentialPolicy {
733+ Version : "2.0" ,
734+ Statement : policy .Statement ,
735+ }
736+ }
737+ bs , err := json .Marshal (res )
738+ if err != nil {
739+ return "" , err
740+ }
741+ return string (bs ), nil
742+ }
743+
744+ func makeFlat (params map [string ]interface {}) string {
745+ keys := make ([]string , 0 , len (params ))
746+ for k , _ := range params {
747+ keys = append (keys , k )
748+ }
749+ sort .Strings (keys )
750+
751+ var plainParms string
752+ for _ , k := range keys {
753+ plainParms += fmt .Sprintf ("&%v=%v" , k , params [k ])
754+ }
755+ return plainParms [1 :]
756+ }
0 commit comments