Skip to content

Commit 9929170

Browse files
nickstenningmattt
andauthored
Don't persist or send cookies to Replicate API (#102)
Replicate's API does not use cookies and even if we return cookies the client should not save and replay them. Signed-off-by: Nick Stenning <[email protected]> Signed-off-by: Mattt Zmuda <[email protected]> Co-authored-by: Mattt Zmuda <[email protected]>
1 parent 83458a3 commit 9929170

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

replicate/client.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import requests
77
from requests.adapters import HTTPAdapter, Retry
8+
from requests.cookies import RequestsCookieJar
89

910
from replicate.__about__ import __version__
1011
from replicate.exceptions import ModelError, ReplicateError
@@ -25,7 +26,7 @@ def __init__(self, api_token: Optional[str] = None) -> None:
2526
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
2627

2728
# TODO: make thread safe
28-
self.read_session = requests.Session()
29+
self.read_session = _create_session()
2930
read_retries = Retry(
3031
total=5,
3132
backoff_factor=2,
@@ -50,7 +51,7 @@ def __init__(self, api_token: Optional[str] = None) -> None:
5051
self.read_session.mount("http://", HTTPAdapter(max_retries=read_retries))
5152
self.read_session.mount("https://", HTTPAdapter(max_retries=read_retries))
5253

53-
self.write_session = requests.Session()
54+
self.write_session = _create_session()
5455
write_retries = Retry(
5556
total=5,
5657
backoff_factor=2,
@@ -138,3 +139,21 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
138139
if prediction.status == "failed":
139140
raise ModelError(prediction.error)
140141
return prediction.output
142+
143+
144+
class _NonpersistentCookieJar(RequestsCookieJar):
145+
"""
146+
A cookie jar that doesn't persist cookies between requests.
147+
"""
148+
149+
def set(self, name, value, **kwargs) -> None:
150+
return
151+
152+
def set_cookie(self, cookie, *args, **kwargs) -> None:
153+
return
154+
155+
156+
def _create_session() -> requests.Session:
157+
s = requests.Session()
158+
s.cookies = _NonpersistentCookieJar()
159+
return s

0 commit comments

Comments
 (0)