Skip to content

Commit 3c8bdae

Browse files
o-santisilentworks
andauthored
fix: validate JSON input for APIError (#597)
Co-authored-by: Andrew Smith <[email protected]>
1 parent 576a5b8 commit 3c8bdae

File tree

6 files changed

+87
-23
lines changed

6 files changed

+87
-23
lines changed

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ remove_pytest_asyncio_from_sync:
3737
sed -i 's/@pytest.mark.asyncio//g' tests/_sync/test_client.py
3838
sed -i 's/_async/_sync/g' tests/_sync/test_client.py
3939
sed -i 's/Async/Sync/g' tests/_sync/test_client.py
40+
sed -i 's/_client\.SyncClient/_client\.Client/g' tests/_sync/test_client.py
4041

4142
sleep:
4243
sleep 2

postgrest/_async/request_builder.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from json import JSONDecodeError
43
from typing import Any, Generic, Optional, TypeVar, Union
54

65
from httpx import Headers, QueryParams
@@ -19,7 +18,7 @@
1918
pre_update,
2019
pre_upsert,
2120
)
22-
from ..exceptions import APIError, generate_default_error_message
21+
from ..exceptions import APIError, APIErrorFromJSON, generate_default_error_message
2322
from ..types import ReturnMethod
2423
from ..utils import AsyncClient, get_origin_and_cast
2524

@@ -75,10 +74,9 @@ async def execute(self) -> APIResponse[_ReturnT]:
7574
return body
7675
return APIResponse[_ReturnT].from_http_request_response(r)
7776
else:
78-
raise APIError(r.json())
77+
json_obj = APIErrorFromJSON.model_validate_json(r.content)
78+
raise APIError(dict(json_obj))
7979
except ValidationError as e:
80-
raise APIError(r.json()) from e
81-
except JSONDecodeError:
8280
raise APIError(generate_default_error_message(r))
8381

8482

@@ -124,10 +122,9 @@ async def execute(self) -> SingleAPIResponse[_ReturnT]:
124122
): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok)
125123
return SingleAPIResponse[_ReturnT].from_http_request_response(r)
126124
else:
127-
raise APIError(r.json())
125+
json_obj = APIErrorFromJSON.model_validate_json(r.content)
126+
raise APIError(dict(json_obj))
128127
except ValidationError as e:
129-
raise APIError(r.json()) from e
130-
except JSONDecodeError:
131128
raise APIError(generate_default_error_message(r))
132129

133130

postgrest/_sync/request_builder.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from json import JSONDecodeError
43
from typing import Any, Generic, Optional, TypeVar, Union
54

65
from httpx import Headers, QueryParams
@@ -19,7 +18,7 @@
1918
pre_update,
2019
pre_upsert,
2120
)
22-
from ..exceptions import APIError, generate_default_error_message
21+
from ..exceptions import APIError, APIErrorFromJSON, generate_default_error_message
2322
from ..types import ReturnMethod
2423
from ..utils import SyncClient, get_origin_and_cast
2524

@@ -75,10 +74,9 @@ def execute(self) -> APIResponse[_ReturnT]:
7574
return body
7675
return APIResponse[_ReturnT].from_http_request_response(r)
7776
else:
78-
raise APIError(r.json())
77+
json_obj = APIErrorFromJSON.model_validate_json(r.content)
78+
raise APIError(dict(json_obj))
7979
except ValidationError as e:
80-
raise APIError(r.json()) from e
81-
except JSONDecodeError:
8280
raise APIError(generate_default_error_message(r))
8381

8482

@@ -124,10 +122,9 @@ def execute(self) -> SingleAPIResponse[_ReturnT]:
124122
): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok)
125123
return SingleAPIResponse[_ReturnT].from_http_request_response(r)
126124
else:
127-
raise APIError(r.json())
125+
json_obj = APIErrorFromJSON.model_validate_json(r.content)
126+
raise APIError(dict(json_obj))
128127
except ValidationError as e:
129-
raise APIError(r.json()) from e
130-
except JSONDecodeError:
131128
raise APIError(generate_default_error_message(r))
132129

133130

@@ -290,7 +287,7 @@ def select(
290287
*columns: The names of the columns to fetch.
291288
count: The method to use to get the count of rows returned.
292289
Returns:
293-
:class:`SyncSelectRequestBuilder`
290+
:class:`AsyncSelectRequestBuilder`
294291
"""
295292
method, params, headers, json = pre_select(*columns, count=count, head=head)
296293
return SyncSelectRequestBuilder[_ReturnT](
@@ -317,7 +314,7 @@ def insert(
317314
Otherwise, use the default value for the column.
318315
Only applies for bulk inserts.
319316
Returns:
320-
:class:`SyncQueryRequestBuilder`
317+
:class:`AsyncQueryRequestBuilder`
321318
"""
322319
method, params, headers, json = pre_insert(
323320
json,
@@ -353,7 +350,7 @@ def upsert(
353350
not when merging with existing rows under `ignoreDuplicates: false`.
354351
This also only applies when doing bulk upserts.
355352
Returns:
356-
:class:`SyncQueryRequestBuilder`
353+
:class:`AsyncQueryRequestBuilder`
357354
"""
358355
method, params, headers, json = pre_upsert(
359356
json,
@@ -381,7 +378,7 @@ def update(
381378
count: The method to use to get the count of rows returned.
382379
returning: Either 'minimal' or 'representation'
383380
Returns:
384-
:class:`SyncFilterRequestBuilder`
381+
:class:`AsyncFilterRequestBuilder`
385382
"""
386383
method, params, headers, json = pre_update(
387384
json,
@@ -404,7 +401,7 @@ def delete(
404401
count: The method to use to get the count of rows returned.
405402
returning: Either 'minimal' or 'representation'
406403
Returns:
407-
:class:`SyncFilterRequestBuilder`
404+
:class:`AsyncFilterRequestBuilder`
408405
"""
409406
method, params, headers, json = pre_delete(
410407
count=count,

postgrest/exceptions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,23 @@
11
from typing import Dict, Optional
22

3+
from pydantic import BaseModel
4+
5+
6+
class APIErrorFromJSON(BaseModel):
7+
"""
8+
A pydantic object to validate an error info object
9+
from a json string.
10+
"""
11+
12+
message: Optional[str]
13+
"""The error message."""
14+
code: Optional[str]
15+
"""The error code."""
16+
hint: Optional[str]
17+
"""The error hint."""
18+
details: Optional[str]
19+
"""The error details."""
20+
321

422
class APIError(Exception):
523
"""

tests/_async/test_client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from unittest.mock import patch
22

33
import pytest
4-
from httpx import BasicAuth, Headers
4+
from httpx import BasicAuth, Headers, Request, Response
55

66
from postgrest import AsyncPostgrestClient
77
from postgrest.exceptions import APIError
@@ -127,3 +127,28 @@ async def test_response_maybe_single(postgrest_client: AsyncPostgrestClient):
127127
exc_response = exc_info.value.json()
128128
assert isinstance(exc_response.get("message"), str)
129129
assert "code" in exc_response and int(exc_response["code"]) == 204
130+
131+
132+
# https://github.com/supabase/postgrest-py/issues/595
133+
@pytest.mark.asyncio
134+
async def test_response_client_invalid_response_but_valid_json(
135+
postgrest_client: AsyncPostgrestClient,
136+
):
137+
with patch(
138+
"httpx._client.AsyncClient.request",
139+
return_value=Response(
140+
status_code=502,
141+
text='"gateway error: Error: Network connection lost."', # quotes makes this text a valid non-dict JSON object
142+
request=Request(method="GET", url="http://example.com"),
143+
),
144+
):
145+
client = postgrest_client.from_("test").select("a", "b").eq("c", "d").single()
146+
assert "Accept" in client.headers
147+
assert client.headers.get("Accept") == "application/vnd.pgrst.object+json"
148+
with pytest.raises(APIError) as exc_info:
149+
await client.execute()
150+
assert isinstance(exc_info, pytest.ExceptionInfo)
151+
exc_response = exc_info.value.json()
152+
assert isinstance(exc_response.get("message"), str)
153+
assert exc_response.get("message") == "JSON could not be generated"
154+
assert "code" in exc_response and int(exc_response["code"]) == 502

tests/_sync/test_client.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from unittest.mock import patch
22

33
import pytest
4-
from httpx import BasicAuth, Headers
4+
from httpx import BasicAuth, Headers, Request, Response
55

66
from postgrest import SyncPostgrestClient
77
from postgrest.exceptions import APIError
@@ -123,3 +123,29 @@ def test_response_maybe_single(postgrest_client: SyncPostgrestClient):
123123
exc_response = exc_info.value.json()
124124
assert isinstance(exc_response.get("message"), str)
125125
assert "code" in exc_response and int(exc_response["code"]) == 204
126+
127+
128+
# https://github.com/supabase/postgrest-py/issues/595
129+
130+
131+
def test_response_client_invalid_response_but_valid_json(
132+
postgrest_client: SyncPostgrestClient,
133+
):
134+
with patch(
135+
"httpx._client.Client.request",
136+
return_value=Response(
137+
status_code=502,
138+
text='"gateway error: Error: Network connection lost."', # quotes makes this text a valid non-dict JSON object
139+
request=Request(method="GET", url="http://example.com"),
140+
),
141+
):
142+
client = postgrest_client.from_("test").select("a", "b").eq("c", "d").single()
143+
assert "Accept" in client.headers
144+
assert client.headers.get("Accept") == "application/vnd.pgrst.object+json"
145+
with pytest.raises(APIError) as exc_info:
146+
client.execute()
147+
assert isinstance(exc_info, pytest.ExceptionInfo)
148+
exc_response = exc_info.value.json()
149+
assert isinstance(exc_response.get("message"), str)
150+
assert exc_response.get("message") == "JSON could not be generated"
151+
assert "code" in exc_response and int(exc_response["code"]) == 502

0 commit comments

Comments
 (0)