Skip to content
This repository was archived by the owner on Sep 8, 2025. It is now read-only.

Commit 1e2526e

Browse files
feat: add max_retries to async client
1 parent 6a150c0 commit 1e2526e

File tree

1 file changed

+44
-22
lines changed

1 file changed

+44
-22
lines changed

postgrest/_async/request_builder.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from json import JSONDecodeError
44
from typing import Any, Generic, Optional, TypeVar, Union
55

6-
from httpx import Headers, QueryParams
6+
from httpx import Headers, NetworkError, QueryParams, ReadError, TimeoutException
77
from pydantic import ValidationError
88

99
from ..base_request_builder import (
@@ -35,13 +35,16 @@ def __init__(
3535
headers: Headers,
3636
params: QueryParams,
3737
json: dict,
38+
max_retries: int = 0,
3839
) -> None:
3940
self.session = session
4041
self.path = path
4142
self.http_method = http_method
4243
self.headers = headers
4344
self.params = params
44-
self.json = None if http_method in {"GET", "HEAD"} else json
45+
self.json = json
46+
self.max_retries = max_retries
47+
self.attempt = 1
4548

4649
async def execute(self) -> APIResponse[_ReturnT]:
4750
"""Execute the query.
@@ -55,14 +58,14 @@ async def execute(self) -> APIResponse[_ReturnT]:
5558
Raises:
5659
:class:`APIError` If the API raised an error.
5760
"""
58-
r = await self.session.request(
59-
self.http_method,
60-
self.path,
61-
json=self.json,
62-
params=self.params,
63-
headers=self.headers,
64-
)
6561
try:
62+
r = await self.session.request(
63+
self.http_method,
64+
self.path,
65+
json=self.json,
66+
params=self.params,
67+
headers=self.headers,
68+
)
6669
if r.is_success:
6770
if self.http_method != "HEAD":
6871
body = r.text
@@ -76,6 +79,10 @@ async def execute(self) -> APIResponse[_ReturnT]:
7679
return APIResponse[_ReturnT].from_http_request_response(r)
7780
else:
7881
raise APIError(r.json())
82+
except (TimeoutException, NetworkError, ReadError) as e:
83+
if self.attempt < self.max_retries:
84+
self.attempt += 1
85+
await self.execute()
7986
except ValidationError as e:
8087
raise APIError(r.json()) from e
8188
except JSONDecodeError:
@@ -91,13 +98,16 @@ def __init__(
9198
headers: Headers,
9299
params: QueryParams,
93100
json: dict,
101+
max_retries: int = 0,
94102
) -> None:
95103
self.session = session
96104
self.path = path
97105
self.http_method = http_method
98106
self.headers = headers
99107
self.params = params
100108
self.json = json
109+
self.max_retries = max_retries
110+
self.attempt = 1
101111

102112
async def execute(self) -> SingleAPIResponse[_ReturnT]:
103113
"""Execute the query.
@@ -111,20 +121,24 @@ async def execute(self) -> SingleAPIResponse[_ReturnT]:
111121
Raises:
112122
:class:`APIError` If the API raised an error.
113123
"""
114-
r = await self.session.request(
115-
self.http_method,
116-
self.path,
117-
json=self.json,
118-
params=self.params,
119-
headers=self.headers,
120-
)
121124
try:
125+
r = await self.session.request(
126+
self.http_method,
127+
self.path,
128+
json=self.json,
129+
params=self.params,
130+
headers=self.headers,
131+
)
122132
if (
123133
200 <= r.status_code <= 299
124134
): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok)
125135
return SingleAPIResponse[_ReturnT].from_http_request_response(r)
126136
else:
127137
raise APIError(r.json())
138+
except (TimeoutException, NetworkError, ReadError) as e:
139+
if self.attempt < self.max_retries:
140+
self.attempt += 1
141+
await self.execute()
128142
except ValidationError as e:
129143
raise APIError(r.json()) from e
130144
except JSONDecodeError:
@@ -182,12 +196,13 @@ def __init__(
182196
headers: Headers,
183197
params: QueryParams,
184198
json: dict,
199+
max_retries: int = 0,
185200
) -> None:
186201
get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
187202
self, session, headers, params
188203
)
189204
get_origin_and_cast(AsyncSingleRequestBuilder[_ReturnT]).__init__(
190-
self, session, path, http_method, headers, params, json
205+
self, session, path, http_method, headers, params, json, max_retries
191206
)
192207

193208

@@ -201,12 +216,14 @@ def __init__(
201216
headers: Headers,
202217
params: QueryParams,
203218
json: dict,
219+
max_retries: int = 0,
204220
) -> None:
221+
self.max_retries = max_retries
205222
get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__(
206223
self, session, headers, params
207224
)
208225
get_origin_and_cast(AsyncQueryRequestBuilder[_ReturnT]).__init__(
209-
self, session, path, http_method, headers, params, json
226+
self, session, path, http_method, headers, params, json, self.max_retries
210227
)
211228

212229
def single(self) -> AsyncSingleRequestBuilder[_ReturnT]:
@@ -223,6 +240,7 @@ def single(self) -> AsyncSingleRequestBuilder[_ReturnT]:
223240
params=self.params,
224241
path=self.path,
225242
session=self.session, # type: ignore
243+
max_retries=self.max_retries,
226244
)
227245

228246
def maybe_single(self) -> AsyncMaybeSingleRequestBuilder[_ReturnT]:
@@ -235,6 +253,7 @@ def maybe_single(self) -> AsyncMaybeSingleRequestBuilder[_ReturnT]:
235253
params=self.params,
236254
path=self.path,
237255
session=self.session, # type: ignore
256+
max_retries=self.max_retries,
238257
)
239258

240259
def text_search(
@@ -258,6 +277,7 @@ def text_search(
258277
params=self.params,
259278
path=self.path,
260279
session=self.session, # type: ignore
280+
max_retries=self.max_retries,
261281
)
262282

263283
def csv(self) -> AsyncSingleRequestBuilder[str]:
@@ -270,13 +290,15 @@ def csv(self) -> AsyncSingleRequestBuilder[str]:
270290
headers=self.headers,
271291
params=self.params,
272292
json=self.json,
293+
max_retries=self.max_retries,
273294
)
274295

275296

276297
class AsyncRequestBuilder(Generic[_ReturnT]):
277-
def __init__(self, session: AsyncClient, path: str) -> None:
298+
def __init__(self, session: AsyncClient, path: str, max_retries: int = 0) -> None:
278299
self.session = session
279300
self.path = path
301+
self.max_retries = max_retries
280302

281303
def select(
282304
self,
@@ -294,7 +316,7 @@ def select(
294316
"""
295317
method, params, headers, json = pre_select(*columns, count=count, head=head)
296318
return AsyncSelectRequestBuilder[_ReturnT](
297-
self.session, self.path, method, headers, params, json
319+
self.session, self.path, method, headers, params, json, self.max_retries
298320
)
299321

300322
def insert(
@@ -327,7 +349,7 @@ def insert(
327349
default_to_null=default_to_null,
328350
)
329351
return AsyncQueryRequestBuilder[_ReturnT](
330-
self.session, self.path, method, headers, params, json
352+
self.session, self.path, method, headers, params, json, self.max_retries
331353
)
332354

333355
def upsert(
@@ -364,7 +386,7 @@ def upsert(
364386
default_to_null=default_to_null,
365387
)
366388
return AsyncQueryRequestBuilder[_ReturnT](
367-
self.session, self.path, method, headers, params, json
389+
self.session, self.path, method, headers, params, json, self.max_retries
368390
)
369391

370392
def update(

0 commit comments

Comments
 (0)