Skip to content

Commit 55fe801

Browse files
committed
refactor(core): fix
1 parent 4748498 commit 55fe801

File tree

12 files changed

+92
-188
lines changed

12 files changed

+92
-188
lines changed

volcenginesdkcore/api_client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ def __init__(self, configuration=None, header_name=None, header_value=None,
7171

7272
self.interceptor_chain = InterceptorChain()
7373

74-
self.interceptor_chain.append_common_request_interceptor(BuildRequestInterceptor())
75-
self.interceptor_chain.append_common_request_interceptor(RuntimeOptionsInterceptor())
76-
self.interceptor_chain.append_common_request_interceptor(ResolveEndpointInterceptor())
77-
self.interceptor_chain.append_common_request_interceptor(SignRequestInterceptor())
74+
self.interceptor_chain.append_request_interceptor(BuildRequestInterceptor())
75+
self.interceptor_chain.append_request_interceptor(RuntimeOptionsInterceptor())
76+
self.interceptor_chain.append_request_interceptor(ResolveEndpointInterceptor())
77+
self.interceptor_chain.append_request_interceptor(SignRequestInterceptor())
7878

79-
self.interceptor_chain.append_common_response_interceptor(DeserializedResponseInterceptor())
79+
self.interceptor_chain.append_response_interceptor(DeserializedResponseInterceptor())
8080

8181
def __del__(self):
8282
if self._pool is not None:

volcenginesdkcore/auth/providers/sts_provider.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from datetime import datetime
44

55
import dateutil.parser
6-
from urllib3 import Timeout
76

87
from volcenginesdkcore import UniversalApi, UniversalInfo, ApiClient, Configuration
98
from volcenginesdkcore.auth.providers.provider import Provider, CredentialValue
109

10+
import threading
1111

1212
class AssumeRoleCredentials:
1313
def __init__(self, ak, sk, session_token, current_time, expired_time):
@@ -20,7 +20,7 @@ def __init__(self, ak, sk, session_token, current_time, expired_time):
2020

2121
class StsCredentialProvider(Provider):
2222
def __init__(self, ak, sk, role_name, account_id, duration_seconds=3600, scheme='https',
23-
host='sts.volcengineapi.com', region='cn-north-1', timeout=30):
23+
host='sts.volcengineapi.com', region='cn-north-1', timeout=30, expired_buffer_seconds=60):
2424
self.ak = ak
2525
self.sk = sk
2626
self.role_name = role_name
@@ -34,17 +34,24 @@ def __init__(self, ak, sk, role_name, account_id, duration_seconds=3600, scheme=
3434
self.scheme = scheme
3535

3636
self.expired_time = None
37+
if expired_buffer_seconds > 600:
38+
raise ValueError('expired_buffer_seconds must be less than or equal to 600')
39+
self.expired_buffer_seconds = expired_buffer_seconds
3740

3841
self.credentials = None
3942

43+
self._lock = threading.Lock()
44+
4045
def retrieve(self):
4146
return self.credentials
4247

4348
def is_expired(self):
44-
return self.credentials is None or (self.expired_time and self.expired_time < time.time())
49+
return (self.credentials is None or
50+
(self.expired_time and self.expired_time < time.time() + self.expired_buffer_seconds))
4551

4652
def refresh(self):
47-
self._assume_role()
53+
with self._lock:
54+
self._assume_role()
4855

4956
def _assume_role(self):
5057
params = {
@@ -58,7 +65,7 @@ def _assume_role(self):
5865
configuration.host = self.host
5966
configuration.region = self.region
6067
configuration.schema = self.scheme
61-
configuration.timeout = Timeout(self.timeout)
68+
configuration.read_timeout = self.timeout
6269
c = UniversalApi(ApiClient(configuration))
6370
info = UniversalInfo(method='GET', service='sts', version='2018-01-01', action='AssumeRole',
6471
content_type='text/plain')

volcenginesdkcore/configuration.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ def __init__(self):
8888
self.assert_hostname = None
8989

9090
self.num_pools = 4
91-
self.timeout = Timeout(connect=3.0, read=30.0, total=30.0)
92-
self.retries = Retry(3)
91+
92+
self.connect_timeout = 60.0
93+
self.read_timeout = 60.0
9394

9495
# urllib3 connection pool's maximum number of connections saved
9596
# per pool. urllib3 uses 1 connection as default value, but this is

volcenginesdkcore/endpoint/endpoint_provider.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44

55

66
class ResolvedEndpoint:
7-
def __init__(self, host, scheme="https"):
7+
def __init__(self, host):
88
self.host = host
9-
self.scheme = scheme
109

11-
@property
12-
def full_url(self):
13-
return self.scheme + '://' + self.host
10+
def url_for(self, scheme='https'):
11+
return scheme + '://' + self.host
1412

1513

1614
class EndpointProvider(object):

volcenginesdkcore/endpoint/providers/default_provider.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
fallback_endpoint = 'open.volcengineapi.com'
55

66

7-
class DefaultEndpointConfig:
7+
class ServiceEndpointInfo:
88
def __init__(self, service, is_global, global_endpoint,
99
region_endpoint_map, fallback_endpoint=fallback_endpoint):
1010
self.service = service
@@ -13,50 +13,48 @@ def __init__(self, service, is_global, global_endpoint,
1313
self.region_endpoint_map = region_endpoint_map
1414
self.fallback_endpoint = fallback_endpoint
1515

16-
17-
def get_default_endpoint(service, region):
18-
if service in default_endpoint:
19-
e = default_endpoint[service]
20-
if e.is_global:
21-
return e.global_endpoint
22-
if region in e.region_endpoint_map:
23-
return e.region_endpoint_map[region]
24-
return e.fallback_endpoint
25-
return fallback_endpoint
26-
27-
28-
default_endpoint = {
29-
'ecs': DefaultEndpointConfig(
30-
service='ecs',
31-
is_global=False,
32-
global_endpoint='',
33-
region_endpoint_map={
34-
'cn-beijing-autodriving': 'ecs' + '.' + 'cn-beijing-autodriving' + '.volcengineapi.com'
35-
}
36-
),
37-
}
16+
def get_endpoint_for(self, region):
17+
if self.is_global:
18+
return self.global_endpoint
19+
if region in self.region_endpoint_map:
20+
return self.region_endpoint_map[region]
21+
return self.fallback_endpoint
3822

3923

4024
class DefaultEndpointProvider(EndpointProvider):
25+
default_endpoint = {
26+
'ecs': ServiceEndpointInfo(
27+
service='ecs',
28+
is_global=False,
29+
global_endpoint='',
30+
region_endpoint_map={
31+
'cn-beijing-autodriving': 'ecs' + '.' + 'cn-beijing-autodriving' + '.volcengineapi.com'
32+
}
33+
),
34+
}
4135

4236
def __init__(self, custom_endpoints=None):
43-
self.scheme = 'https'
4437
self.custom_endpoints = custom_endpoints or {}
4538

39+
def get_default_endpoint(self, service, region):
40+
if service in self.default_endpoint:
41+
e = self.default_endpoint[service]
42+
return e.get_endpoint_for(region)
43+
return fallback_endpoint
44+
4645
def endpoint_for(self, service, region):
47-
# 检查自定义终端节点配置
4846
if service in self.custom_endpoints:
49-
url = self.custom_endpoints[service]
47+
conf = self.custom_endpoints[service]
48+
host = conf.get_endpoint_for(region)
5049
else:
51-
url = get_default_endpoint(service=service, region=region)
50+
host = self.get_default_endpoint(service=service, region=region)
5251

53-
return ResolvedEndpoint(url, self.scheme)
52+
return ResolvedEndpoint(host)
5453

5554

5655
class HostEndpointProvider(EndpointProvider):
5756
def __init__(self, host):
5857
self.host = host
59-
self.scheme = 'https'
6058

6159
def endpoint_for(self, service, region):
62-
return ResolvedEndpoint(self.host, self.scheme)
60+
return ResolvedEndpoint(self.host)
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
from volcenginesdkcore.interceptor.interceptors.context import *
2-
from .chain import *
32
from .interceptors import *

volcenginesdkcore/interceptor/chain.py

Lines changed: 3 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,15 @@
11
# coding=utf-8
2-
3-
4-
def check_common_request_interceptor(interceptor):
5-
if not interceptor.is_common():
6-
raise Exception("interceptor is not common")
7-
if not interceptor.is_request():
8-
raise Exception("interceptor is not for request")
9-
if not interceptor.name():
10-
raise Exception("interceptor name is not defined")
11-
12-
13-
def check_common_response_interceptor(interceptor):
14-
if not interceptor.is_common():
15-
raise Exception("interceptor is not common")
16-
if not interceptor.is_response():
17-
raise Exception("interceptor is not for response")
18-
if not interceptor.name():
19-
raise Exception("interceptor name is not defined")
2+
from volcenginesdkcore.interceptor import RequestInterceptor, ResponseInterceptor
203

214

225
def check_request_interceptor(interceptor):
23-
if interceptor.is_common():
24-
raise Exception("interceptor is common")
25-
if not interceptor.is_request():
6+
if not issubclass(interceptor.__class__, RequestInterceptor):
267
raise Exception("interceptor is not for request")
27-
if not interceptor.name():
28-
raise Exception("interceptor name is not defined")
298

309

3110
def check_response_interceptor(interceptor):
32-
if interceptor.is_common():
33-
raise Exception("interceptor is common")
34-
if not interceptor.is_response():
11+
if not issubclass(interceptor.__class__, ResponseInterceptor):
3512
raise Exception("interceptor is not for response")
36-
if not interceptor.name():
37-
raise Exception("interceptor name is not defined")
3813

3914

4015
def insert_interceptor(chain, interceptor, after_name=''):
@@ -54,21 +29,9 @@ class InterceptorChain:
5429
"""拦截器链"""
5530

5631
def __init__(self):
57-
self.common_request_interceptors = []
5832
self.request_interceptors = []
59-
self.common_response_interceptors = []
6033
self.response_interceptors = []
6134

62-
def append_common_request_interceptor(self, interceptor):
63-
check_common_request_interceptor(interceptor)
64-
self.common_request_interceptors.append(interceptor)
65-
return self
66-
67-
def insert_common_request_interceptor(self, interceptor, after_name=''):
68-
check_common_request_interceptor(interceptor)
69-
self.common_request_interceptors = insert_interceptor(self.common_request_interceptors, interceptor, after_name)
70-
return self
71-
7235
def append_request_interceptor(self, interceptor):
7336
check_request_interceptor(interceptor)
7437
self.request_interceptors.append(interceptor)
@@ -79,16 +42,6 @@ def insert_request_interceptor(self, interceptor, after_name=''):
7942
self.request_interceptors = insert_interceptor(self.request_interceptors, interceptor, after_name)
8043
return self
8144

82-
def append_common_response_interceptor(self, interceptor):
83-
check_common_response_interceptor(interceptor)
84-
self.common_response_interceptors.append(interceptor)
85-
return self
86-
87-
def insert_common_response_interceptor(self, interceptor, after_name=''):
88-
check_common_response_interceptor(interceptor)
89-
self.common_response_interceptors = insert_interceptor(self.common_response_interceptors, interceptor,
90-
after_name)
91-
9245
def append_response_interceptor(self, interceptor):
9346
check_response_interceptor(interceptor)
9447
self.response_interceptors.append(interceptor)
@@ -99,18 +52,12 @@ def insert_response_interceptor(self, interceptor, after_name=''):
9952
self.response_interceptors = insert_interceptor(self.response_interceptors, interceptor, after_name)
10053

10154
def execute_request(self, context):
102-
for interceptor in self.common_request_interceptors:
103-
context = interceptor.intercept(context)
104-
10555
for interceptor in self.request_interceptors:
10656
context = interceptor.intercept(context)
10757

10858
return context
10959

11060
def execute_response(self, context):
111-
for interceptor in self.common_response_interceptors:
112-
context = interceptor.intercept(context)
113-
11461
for interceptor in self.response_interceptors:
11562
context = interceptor.intercept(context)
11663

0 commit comments

Comments
 (0)