Skip to content

Commit 8e3e7fe

Browse files
committed
feat(core): support p0 capabilities
1 parent fad6778 commit 8e3e7fe

23 files changed

+1142
-451
lines changed

volcenginesdkcore/api_client.py

Lines changed: 39 additions & 445 deletions
Large diffs are not rendered by default.

volcenginesdkcore/auth/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .credential import Credential
2+
from .providers import *
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# coding=utf-8
2+
class Credential(object):
3+
"""凭证管理类"""
4+
5+
def __init__(self, provider):
6+
self.provider = provider
7+
8+
def get(self):
9+
if self.provider.is_expired():
10+
# 当凭证过期时自动刷新
11+
self.provider.refresh()
12+
return self.provider.retrieve()
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .static_provider import StaticCredentialProvider
2+
from .sts_provider import StsCredentialProvider
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# coding=utf-8
2+
import abc
3+
4+
5+
class CredentialValue:
6+
def __init__(self, ak, sk, session_token=None, provider_name=None):
7+
self.ak = ak
8+
self.sk = sk
9+
self.session_token = session_token
10+
self.provider_name = provider_name
11+
12+
13+
class Provider(object):
14+
15+
@abc.abstractmethod
16+
def retrieve(self):
17+
"""获取凭证"""
18+
raise NotImplementedError()
19+
20+
@abc.abstractmethod
21+
def is_expired(self):
22+
"""判断凭证是否过期"""
23+
raise NotImplementedError()
24+
25+
@abc.abstractmethod
26+
def refresh(self):
27+
"""刷新凭证"""
28+
raise NotImplementedError()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# coding=utf-8
2+
from volcenginesdkcore.auth.providers.provider import Provider, CredentialValue
3+
4+
5+
class StaticCredentialProvider(Provider):
6+
"""静态凭证提供者"""
7+
8+
def _refresh(self):
9+
return
10+
11+
def __init__(self, access_key_id, secret_access_key, session_token=None):
12+
self.credentials = CredentialValue(
13+
access_key_id,
14+
secret_access_key,
15+
session_token,
16+
"StaticCredentialProvider"
17+
)
18+
19+
def retrieve(self):
20+
return self.credentials
21+
22+
def is_expired(self):
23+
return False
24+
25+
def refresh(self):
26+
return
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import time
2+
import uuid
3+
from datetime import datetime
4+
5+
import dateutil.parser
6+
from urllib3 import Timeout
7+
8+
from volcenginesdkcore import UniversalApi, UniversalInfo, ApiClient, Configuration
9+
from volcenginesdkcore.auth.providers.provider import Provider, CredentialValue
10+
11+
12+
class AssumeRoleCredentials:
13+
def __init__(self, ak, sk, session_token, current_time, expired_time):
14+
self.ak = ak
15+
self.sk = sk
16+
self.session_token = session_token
17+
self.current_time = current_time
18+
self.expired_time = expired_time
19+
20+
21+
class StsCredentialProvider(Provider):
22+
def __init__(self, ak, sk, role_name, account_id, duration_seconds=3600, scheme='https',
23+
host='sts.volcengineapi.com', region='cn-north-1', timeout=30):
24+
self.ak = ak
25+
self.sk = sk
26+
self.role_name = role_name
27+
self.account_id = account_id
28+
29+
self.timeout = timeout
30+
self.duration_seconds = duration_seconds
31+
32+
self.host = host
33+
self.region = region
34+
self.scheme = scheme
35+
36+
self.expired_time = None
37+
38+
self.credentials = None
39+
40+
def retrieve(self):
41+
return self.credentials
42+
43+
def is_expired(self):
44+
return self.credentials is None or (self.expired_time and self.expired_time < time.time())
45+
46+
def refresh(self):
47+
self._assume_role()
48+
49+
def _assume_role(self):
50+
params = {
51+
'DurationSeconds': self.duration_seconds,
52+
'RoleSessionName': uuid.uuid4().hex,
53+
'RoleTrn': 'trn:iam::' + self.account_id + ':role/' + self.role_name,
54+
}
55+
configuration = Configuration()
56+
configuration.ak = self.ak
57+
configuration.sk = self.sk
58+
configuration.host = self.host
59+
configuration.region = self.region
60+
configuration.schema = self.scheme
61+
configuration.timeout = Timeout(self.timeout)
62+
c = UniversalApi(ApiClient(configuration))
63+
info = UniversalInfo(method='GET', service='sts', version='2018-01-01', action='AssumeRole',
64+
content_type='text/plain')
65+
66+
resp, status_code, resp_header = c.do_call_with_http_info(info=info, body=params)
67+
resp_cred = resp['Credentials']
68+
69+
# Parse the ISO string
70+
dt = dateutil.parser.parse(resp_cred['ExpiredTime'])
71+
72+
# Convert to timestamp (seconds since epoch)
73+
self.expired_time = (dt - datetime(1970, 1, 1, tzinfo=dateutil.tz.tzutc())).total_seconds()
74+
75+
self.credentials = CredentialValue(ak=resp_cred['AccessKeyId'],
76+
sk=resp_cred['SecretAccessKey'],
77+
session_token=resp_cred['SessionToken'],
78+
provider_name='StsCredentialProvider')

volcenginesdkcore/configuration.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
import six
1111
from six.moves import http_client as httplib
12+
from urllib3 import Timeout, Retry
13+
14+
from volcenginesdkcore.endpoint.providers import DefaultEndpointProvider
1215

1316

1417
class TypeWithDefault(type):
@@ -36,7 +39,7 @@ def __init__(self):
3639
"""Constructor"""
3740

3841
# Default Base url
39-
self.host = "open.volcengineapi.com"
42+
self.host = None
4043
# Schema Support http or https
4144
self.schema = "http"
4245
# Temp file folder for downloading files
@@ -53,6 +56,7 @@ def __init__(self):
5356
# 自定义适配
5457
self.ak = ""
5558
self.sk = ""
59+
self.session_token = ""
5660
self.region = ""
5761

5862
# Logging Settings
@@ -83,6 +87,10 @@ def __init__(self):
8387
# Set this to True/False to enable/disable SSL hostname verification.
8488
self.assert_hostname = None
8589

90+
self.num_pools = 4
91+
self.timeout = Timeout(connect=3.0, read=30.0, total=30.0)
92+
self.retries = Retry(3)
93+
8694
# urllib3 connection pool's maximum number of connections saved
8795
# per pool. urllib3 uses 1 connection as default value, but this is
8896
# not the best value when you are making a lot of possibly parallel
@@ -98,6 +106,8 @@ def __init__(self):
98106
# Disable client side validation
99107
self.client_side_validation = True
100108

109+
self.endpoint_provider = DefaultEndpointProvider()
110+
101111
@property
102112
def logger_file(self):
103113
"""The logger file.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .endpoint_provider import *
2+
from .providers import *
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# coding=utf-8
2+
3+
import abc
4+
5+
6+
class ResolvedEndpoint:
7+
def __init__(self, host, scheme="https"):
8+
self.host = host
9+
self.scheme = scheme
10+
11+
@property
12+
def full_url(self):
13+
return self.scheme + '://' + self.host
14+
15+
16+
class EndpointProvider(object):
17+
"""接口类:确定服务请求的终端节点"""
18+
19+
@abc.abstractmethod
20+
def endpoint_for(self, service, region):
21+
"""返回指定服务和区域的终端节点"""
22+
raise NotImplementedError()

0 commit comments

Comments
 (0)