Skip to content

Commit b0fe9fd

Browse files
committed
add new policy, add logic to use policy
1 parent 2510f19 commit b0fe9fd

File tree

7 files changed

+170
-32
lines changed

7 files changed

+170
-32
lines changed

sdk/cosmos/azure-cosmos/azure/cosmos/_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
116116
path: str,
117117
resource_id: Optional[str],
118118
resource_type: str,
119+
operation_type: str,
119120
options: Mapping[str, Any],
120121
partition_key_range_id: Optional[str] = None,
121122
) -> Dict[str, Any]:
@@ -323,6 +324,11 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
323324
if resource_type != 'dbs' and options.get("containerRID"):
324325
headers[http_constants.HttpHeaders.IntendedCollectionRID] = options["containerRID"]
325326

327+
if resource_type == "":
328+
resource_type = "databaseaccount"
329+
headers[http_constants.HttpHeaders.ThinClientProxyResourceType] = resource_type
330+
headers[http_constants.HttpHeaders.ThinClientProxyOperationType] = operation_type
331+
326332
return headers
327333

328334

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2038,7 +2038,8 @@ def PatchItem(
20382038
if options is None:
20392039
options = {}
20402040

2041-
headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, options)
2041+
headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type,
2042+
documents._OperationType.Patch, options)
20422043
# Patch will use WriteEndpoint since it uses PUT operation
20432044
request_params = RequestObject(resource_type, documents._OperationType.Patch)
20442045
request_data = {}
@@ -2126,7 +2127,8 @@ def _Batch(
21262127
) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]:
21272128
initial_headers = self.default_headers.copy()
21282129
base._populate_batch_headers(initial_headers)
2129-
headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", options)
2130+
headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs",
2131+
documents._OperationType.Batch, options)
21302132
request_params = RequestObject("docs", documents._OperationType.Batch)
21312133
return cast(
21322134
Tuple[List[Dict[str, Any]], CaseInsensitiveDict],
@@ -2185,7 +2187,8 @@ def DeleteAllItemsByPartitionKey(
21852187
# Specified url to perform background operation to delete all items by partition key
21862188
path = '{}{}/{}'.format(path, "operations", "partitionkeydelete")
21872189
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
2188-
headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, "partitionkey", options)
2190+
headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id,
2191+
"partitionkey", documents._OperationType.Delete, options)
21892192
request_params = RequestObject("partitionkey", documents._OperationType.Delete)
21902193
_, last_response_headers = self.__Post(
21912194
path=path,
@@ -2353,7 +2356,8 @@ def ExecuteStoredProcedure(
23532356

23542357
path = base.GetPathFromLink(sproc_link)
23552358
sproc_id = base.GetResourceIdOrFullNameFromLink(sproc_link)
2356-
headers = base.GetHeaders(self, initial_headers, "post", path, sproc_id, "sprocs", options)
2359+
headers = base.GetHeaders(self, initial_headers, "post", path, sproc_id, "sprocs",
2360+
documents._OperationType.ExecuteJavaScript, options)
23572361

23582362
# ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation
23592363
request_params = RequestObject("sprocs", documents._OperationType.ExecuteJavaScript)
@@ -2550,7 +2554,8 @@ def GetDatabaseAccount(
25502554
if url_connection is None:
25512555
url_connection = self.url_connection
25522556

2553-
headers = base.GetHeaders(self, self.default_headers, "get", "", "", "", {})
2557+
headers = base.GetHeaders(self, self.default_headers, "get", "", "", "",
2558+
documents._OperationType.Read,{})
25542559
request_params = RequestObject("databaseaccount", documents._OperationType.Read, url_connection)
25552560
result, last_response_headers = self.__Get("", request_params, headers, **kwargs)
25562561
self.last_response_headers = last_response_headers
@@ -2615,7 +2620,8 @@ def Create(
26152620
options = {}
26162621

26172622
initial_headers = initial_headers or self.default_headers
2618-
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, options)
2623+
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Create,
2624+
options)
26192625
# Create will use WriteEndpoint since it uses POST operation
26202626

26212627
request_params = RequestObject(typ, documents._OperationType.Create)
@@ -2659,7 +2665,8 @@ def Upsert(
26592665
options = {}
26602666

26612667
initial_headers = initial_headers or self.default_headers
2662-
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, options)
2668+
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Upsert,
2669+
options)
26632670
headers[http_constants.HttpHeaders.IsUpsert] = True
26642671

26652672
# Upsert will use WriteEndpoint since it uses POST operation
@@ -2703,7 +2710,8 @@ def Replace(
27032710
options = {}
27042711

27052712
initial_headers = initial_headers or self.default_headers
2706-
headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, options)
2713+
headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace,
2714+
options)
27072715
# Replace will use WriteEndpoint since it uses PUT operation
27082716
request_params = RequestObject(typ, documents._OperationType.Replace)
27092717
result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs)
@@ -2744,7 +2752,7 @@ def Read(
27442752
options = {}
27452753

27462754
initial_headers = initial_headers or self.default_headers
2747-
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, options)
2755+
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options)
27482756
# Read will use ReadEndpoint since it uses GET operation
27492757
request_params = RequestObject(typ, documents._OperationType.Read)
27502758
result, last_response_headers = self.__Get(path, request_params, headers, **kwargs)
@@ -2782,7 +2790,8 @@ def DeleteResource(
27822790
options = {}
27832791

27842792
initial_headers = initial_headers or self.default_headers
2785-
headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, options)
2793+
headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete,
2794+
options)
27862795
# Delete will use WriteEndpoint since it uses DELETE operation
27872796
request_params = RequestObject(typ, documents._OperationType.Delete)
27882797
result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs)
@@ -3027,6 +3036,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
30273036
path,
30283037
resource_id,
30293038
resource_type,
3039+
request_params.operation_type,
30303040
options,
30313041
partition_key_range_id
30323042
)
@@ -3064,6 +3074,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
30643074
path,
30653075
resource_id,
30663076
resource_type,
3077+
documents._OperationType.SqlQuery,
30673078
options,
30683079
partition_key_range_id
30693080
)

sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@
2222
"""Internal methods for executing functions in the Azure Cosmos database service.
2323
"""
2424
import json
25+
from requests.exceptions import ReadTimeout, ConnectTimeout
2526
import time
2627
from typing import Optional
2728

28-
from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError
29+
from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError
2930
from azure.core.pipeline import PipelineRequest
3031
from azure.core.pipeline.policies import RetryPolicy
3132
from azure.core.pipeline.transport._base import HttpRequest
@@ -38,6 +39,7 @@
3839
from . import _gone_retry_policy
3940
from . import _timeout_failover_retry_policy
4041
from . import _container_recreate_retry_policy
42+
from . import _service_response_retry_policy
4143
from .http_constants import HttpHeaders, StatusCodes, SubStatusCodes
4244

4345

@@ -78,6 +80,9 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):
7880
timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy(
7981
client.connection_policy, global_endpoint_manager, *args
8082
)
83+
service_response_retry_policy = _service_response_retry_policy.ServiceResponseRetryPolicy(
84+
client.connection_policy, global_endpoint_manager, *args,
85+
)
8186
# HttpRequest we would need to modify for Container Recreate Retry Policy
8287
request: Optional[HttpRequest] = None
8388
if args and len(args) > 3:
@@ -188,6 +193,15 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):
188193
if kwargs['timeout'] <= 0:
189194
raise exceptions.CosmosClientTimeoutError()
190195

196+
except ServiceResponseError as e:
197+
if _has_retryable_headers(request.http_request.headers):
198+
# we resolve the request endpoint to the next preferred region
199+
# once we are out of preferred regions we stop retrying
200+
retry_policy = service_response_retry_policy
201+
if not retry_policy.ShouldRetry():
202+
if args and args[0].should_clear_session_token_on_session_read_failure and client.session:
203+
client.session.clear_session_token(client.last_response_headers)
204+
raise
191205

192206
def ExecuteFunction(function, *args, **kwargs):
193207
"""Stub method so that it can be used for mocking purposes as well.
@@ -198,6 +212,12 @@ def ExecuteFunction(function, *args, **kwargs):
198212
"""
199213
return function(*args, **kwargs)
200214

215+
def _has_retryable_headers(request_headers):
216+
if (request_headers.get(HttpHeaders.ThinClientProxyResourceType) in ["docs"]
217+
and request_headers.get(HttpHeaders.ThinClientProxyOperationType) in ["Read", "Query", "QueryPlan",
218+
"ReadFeed", "SqlQuery"]):
219+
return True
220+
return False
201221

202222
def _configure_timeout(request: PipelineRequest, absolute: Optional[int], per_request: int) -> None:
203223
if absolute is not None:
@@ -261,6 +281,9 @@ def send(self, request):
261281
timeout_error.history = retry_settings['history']
262282
raise
263283
except ServiceRequestError as err:
284+
if _has_retryable_headers(request.http_request.headers):
285+
# raise exception immediately to be dealt with in client retry policies
286+
raise err
264287
# the request ran into a socket timeout or failed to establish a new connection
265288
# since request wasn't sent, we retry up to however many connection retries are configured (default 3)
266289
if retry_settings['connect'] > 0:
@@ -269,13 +292,16 @@ def send(self, request):
269292
self.sleep(retry_settings, request.context.transport)
270293
continue
271294
raise err
272-
except AzureError as err:
295+
except ServiceResponseError as err:
273296
retry_error = err
274-
if self._is_method_retryable(retry_settings, request.http_request):
275-
retry_active = self.increment(retry_settings, response=request, error=err)
276-
if retry_active:
277-
self.sleep(retry_settings, request.context.transport)
278-
continue
297+
if err.exc_type in [ReadTimeout, ConnectTimeout]:
298+
if _has_retryable_headers(request.http_request.headers):
299+
# raise exception immediately to be dealt with in client retry policies
300+
raise err
301+
retry_active = self.increment(retry_settings, response=request, error=err)
302+
if retry_active:
303+
self.sleep(retry_settings, request.context.transport)
304+
continue
279305
raise err
280306
finally:
281307
end_time = time.time()
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# The MIT License (MIT)
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
4+
"""Internal class for service response read errors implementation in the Azure
5+
Cosmos database service.
6+
"""
7+
8+
class ServiceResponseRetryPolicy(object):
9+
10+
def __init__(self, connection_policy, global_endpoint_manager, *args):
11+
self.args = args
12+
self.global_endpoint_manager = global_endpoint_manager
13+
self.total_retries = len(self.global_endpoint_manager.location_cache.read_endpoints)
14+
self.failover_retry_count = 0
15+
self.connection_policy = connection_policy
16+
self.request = args[0] if args else None
17+
if self.request:
18+
self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request)
19+
20+
def ShouldRetry(self):
21+
"""Returns true if the request should retry based on preferred regions and retries already done.
22+
23+
"""
24+
if not self.connection_policy.EnableEndpointDiscovery:
25+
return False
26+
if self.args[0].operation_type != 'Read' and self.args[0].resource_type != 'docs':
27+
return False
28+
29+
self.failover_retry_count += 1
30+
if self.failover_retry_count > self.total_retries:
31+
return False
32+
33+
if self.request:
34+
# clear previous location-based routing directive
35+
self.request.clear_route_to_location()
36+
37+
# set location-based routing directive based on retry count
38+
# ensuring usePreferredLocations is set to True for retry
39+
self.request.route_to_location_with_preferred_location_flag(self.failover_retry_count, True)
40+
41+
# Resolve the endpoint for the request and pin the resolution to the resolved endpoint
42+
# This enables marking the endpoint unavailability on endpoint failover/unreachability
43+
self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request)
44+
self.request.route_to_location(self.location_endpoint)
45+
return True

0 commit comments

Comments
 (0)