@@ -22,24 +22,41 @@ def __init__(self, api_token=None) -> None:
2222 self .poll_interval = float (os .environ .get ("REPLICATE_POLL_INTERVAL" , "0.5" ))
2323
2424 # TODO: make thread safe
25- self .session = requests .Session ()
26-
27- # Gracefully retry requests
28- # This is primarily for when iterating through predict(), where if an exception is thrown, the client
29- # has no way of restarting the iterator.
30- # We might just want to enable retry logic for iterators, but for now this is a blunt instrument to
31- # make this reliable.
32- retries = Retry (
25+ self .read_session = requests .Session ()
26+ read_retries = Retry (
3327 total = 5 ,
3428 backoff_factor = 2 ,
35- # TODO: Only retry on GET so we don't unintionally mutute data
36- method_whitelist = ["GET" , "POST" , "PUT" ],
29+ # Only retry 500s on GET so we don't unintionally mutute data
30+ method_whitelist = ["GET" ],
3731 # https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors
38- status_forcelist = [429 , 500 , 502 , 503 , 504 , 520 , 521 , 522 , 523 , 524 , 526 , 527 ],
32+ status_forcelist = [
33+ 429 ,
34+ 500 ,
35+ 502 ,
36+ 503 ,
37+ 504 ,
38+ 520 ,
39+ 521 ,
40+ 522 ,
41+ 523 ,
42+ 524 ,
43+ 526 ,
44+ 527 ,
45+ ],
3946 )
47+ self .read_session .mount ("http://" , HTTPAdapter (max_retries = read_retries ))
48+ self .read_session .mount ("https://" , HTTPAdapter (max_retries = read_retries ))
4049
41- self .session .mount ("http://" , HTTPAdapter (max_retries = retries ))
42- self .session .mount ("https://" , HTTPAdapter (max_retries = retries ))
50+ self .write_session = requests .Session ()
51+ write_retries = Retry (
52+ total = 5 ,
53+ backoff_factor = 2 ,
54+ method_whitelist = ["POST" , "PUT" ],
55+ # Only retry POST/PUT requests on rate limits, so we don't unintionally mutute data
56+ status_forcelist = [429 ],
57+ )
58+ self .write_session .mount ("http://" , HTTPAdapter (max_retries = write_retries ))
59+ self .write_session .mount ("https://" , HTTPAdapter (max_retries = write_retries ))
4360
4461 def _request (self , method : str , path : str , ** kwargs ):
4562 # from requests.Session
@@ -49,7 +66,10 @@ def _request(self, method: str, path: str, **kwargs):
4966 kwargs .setdefault ("allow_redirects" , False )
5067 kwargs .setdefault ("headers" , {})
5168 kwargs ["headers" ].update (self ._headers ())
52- resp = self .session .request (method , self .base_url + path , ** kwargs )
69+ session = self .read_session
70+ if method in ["POST" , "PUT" , "DELETE" , "PATCH" ]:
71+ session = self .write_session
72+ resp = session .request (method , self .base_url + path , ** kwargs )
5373 if 400 <= resp .status_code < 600 :
5474 try :
5575 raise ReplicateError (resp .json ()["detail" ])
0 commit comments