|
| 1 | +import time |
| 2 | +import uuid |
| 3 | +from datetime import datetime |
| 4 | + |
| 5 | +import dateutil.parser |
| 6 | +from urllib3 import Timeout |
| 7 | + |
| 8 | +from volcenginesdkcore import UniversalApi, UniversalInfo, ApiClient, Configuration |
| 9 | +from volcenginesdkcore.auth.providers.provider import Provider, CredentialValue |
| 10 | + |
| 11 | + |
| 12 | +class AssumeRoleCredentials: |
| 13 | + def __init__(self, ak, sk, session_token, current_time, expired_time): |
| 14 | + self.ak = ak |
| 15 | + self.sk = sk |
| 16 | + self.session_token = session_token |
| 17 | + self.current_time = current_time |
| 18 | + self.expired_time = expired_time |
| 19 | + |
| 20 | + |
| 21 | +class StsCredentialProvider(Provider): |
| 22 | + def __init__(self, ak, sk, role_name, account_id, duration_seconds=3600, scheme='https', |
| 23 | + host='sts.volcengineapi.com', region='cn-north-1', timeout=30): |
| 24 | + self.ak = ak |
| 25 | + self.sk = sk |
| 26 | + self.role_name = role_name |
| 27 | + self.account_id = account_id |
| 28 | + |
| 29 | + self.timeout = timeout |
| 30 | + self.duration_seconds = duration_seconds |
| 31 | + |
| 32 | + self.host = host |
| 33 | + self.region = region |
| 34 | + self.scheme = scheme |
| 35 | + |
| 36 | + self.expired_time = None |
| 37 | + |
| 38 | + self.credentials = None |
| 39 | + |
| 40 | + def retrieve(self): |
| 41 | + return self.credentials |
| 42 | + |
| 43 | + def is_expired(self): |
| 44 | + return self.credentials is None or (self.expired_time and self.expired_time < time.time()) |
| 45 | + |
| 46 | + def refresh(self): |
| 47 | + self._assume_role() |
| 48 | + |
| 49 | + def _assume_role(self): |
| 50 | + params = { |
| 51 | + 'DurationSeconds': self.duration_seconds, |
| 52 | + 'RoleSessionName': uuid.uuid4().hex, |
| 53 | + 'RoleTrn': 'trn:iam::' + self.account_id + ':role/' + self.role_name, |
| 54 | + } |
| 55 | + configuration = Configuration() |
| 56 | + configuration.ak = self.ak |
| 57 | + configuration.sk = self.sk |
| 58 | + configuration.host = self.host |
| 59 | + configuration.region = self.region |
| 60 | + configuration.schema = self.scheme |
| 61 | + configuration.timeout = Timeout(self.timeout) |
| 62 | + c = UniversalApi(ApiClient(configuration)) |
| 63 | + info = UniversalInfo(method='GET', service='sts', version='2018-01-01', action='AssumeRole', |
| 64 | + content_type='text/plain') |
| 65 | + |
| 66 | + resp, status_code, resp_header = c.do_call_with_http_info(info=info, body=params) |
| 67 | + resp_cred = resp['Credentials'] |
| 68 | + |
| 69 | + # Parse the ISO string |
| 70 | + dt = dateutil.parser.parse(resp_cred['ExpiredTime']) |
| 71 | + |
| 72 | + # Convert to timestamp (seconds since epoch) |
| 73 | + self.expired_time = (dt - datetime(1970, 1, 1, tzinfo=dateutil.tz.tzutc())).total_seconds() |
| 74 | + |
| 75 | + self.credentials = CredentialValue(ak=resp_cred['AccessKeyId'], |
| 76 | + sk=resp_cred['SecretAccessKey'], |
| 77 | + session_token=resp_cred['SessionToken'], |
| 78 | + provider_name='StsCredentialProvider') |
0 commit comments