From b070abb158ea735d1730fd4f66967fef8284361a Mon Sep 17 00:00:00 2001 From: Leonardo Santiago Date: Mon, 6 Oct 2025 15:33:58 -0300 Subject: [PATCH 1/5] fix(storage): do not mutate httpx client instance instead, use `yarl.URL` to parse and manipulate the URL in a stateless way, by returning a new url when paths are appended --- src/storage/pyproject.toml | 1 + src/storage/src/storage3/_async/bucket.py | 25 +++++++---- src/storage/src/storage3/_async/client.py | 29 ++----------- src/storage/src/storage3/_sync/bucket.py | 25 +++++++---- src/storage/src/storage3/_sync/client.py | 31 ++------------ src/storage/tests/_async/test_bucket.py | 51 ++++++++++++++--------- src/storage/tests/_sync/test_bucket.py | 49 +++++++++++++--------- src/storage/tests/test_client.py | 8 +--- uv.lock | 14 ++++--- 9 files changed, 112 insertions(+), 121 deletions(-) diff --git a/src/storage/pyproject.toml b/src/storage/pyproject.toml index 03550807..efe78b81 100644 --- a/src/storage/pyproject.toml +++ b/src/storage/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "httpx[http2] >=0.26,<0.29", "deprecation >=2.1.0", "pydantic >=2.11.7", + "yarl>=1.20.1", ] [project.urls] diff --git a/src/storage/src/storage3/_async/bucket.py b/src/storage/src/storage3/_async/bucket.py index b748b0e7..0a0ca12b 100644 --- a/src/storage/src/storage3/_async/bucket.py +++ b/src/storage/src/storage3/_async/bucket.py @@ -3,6 +3,8 @@ from typing import Any, Optional from httpx import AsyncClient, HTTPStatusError, Response +from httpx._types import HeaderTypes +from yarl import URL from ..exceptions import StorageApiError from ..types import CreateOrUpdateBucketOptions, RequestMethod @@ -14,17 +16,22 @@ class AsyncStorageBucketAPI: """This class abstracts access to the endpoint to the Get, List, Empty, and Delete operations on a bucket""" - def __init__(self, session: AsyncClient) -> None: + def __init__(self, session: AsyncClient, url: str, headers: HeaderTypes) -> None: self._client = session + self._base_url = URL(url) + self._headers = headers async def _request( self, method: RequestMethod, - url: str, + path: list[str], json: Optional[dict[Any, Any]] = None, ) -> Response: try: - response = await self._client.request(method, url, json=json) + url_path = self._base_url.joinpath(*path) + response = await self._client.request( + method, str(url_path), json=json, headers=self._headers + ) response.raise_for_status() except HTTPStatusError as exc: resp = exc.response.json() @@ -35,7 +42,7 @@ async def _request( async def list_buckets(self) -> list[AsyncBucket]: """Retrieves the details of all storage buckets within an existing product.""" # if the request doesn't error, it is assured to return a list - res = await self._request("GET", "/bucket") + res = await self._request("GET", ["bucket"]) return [AsyncBucket(**bucket) for bucket in res.json()] async def get_bucket(self, id: str) -> AsyncBucket: @@ -46,7 +53,7 @@ async def get_bucket(self, id: str) -> AsyncBucket: id The unique identifier of the bucket you would like to retrieve. """ - res = await self._request("GET", f"/bucket/{id}") + res = await self._request("GET", ["bucket", id]) json = res.json() return AsyncBucket(**json) @@ -73,7 +80,7 @@ async def create_bucket( json.update(**options) res = await self._request( "POST", - "/bucket", + ["bucket"], json=json, ) return res.json() @@ -92,7 +99,7 @@ async def update_bucket( `allowed_mime_types`. """ json = {"id": id, "name": id, **options} - res = await self._request("PUT", f"/bucket/{id}", json=json) + res = await self._request("PUT", ["bucket", id], json=json) return res.json() async def empty_bucket(self, id: str) -> dict[str, str]: @@ -103,7 +110,7 @@ async def empty_bucket(self, id: str) -> dict[str, str]: id The unique identifier of the bucket you would like to empty. """ - res = await self._request("POST", f"/bucket/{id}/empty", json={}) + res = await self._request("POST", ["bucket", id, "empty"], json={}) return res.json() async def delete_bucket(self, id: str) -> dict[str, str]: @@ -115,5 +122,5 @@ async def delete_bucket(self, id: str) -> dict[str, str]: id The unique identifier of the bucket you would like to delete. """ - res = await self._request("DELETE", f"/bucket/{id}", json={}) + res = await self._request("DELETE", ["bucket", id], json={}) return res.json() diff --git a/src/storage/src/storage3/_async/client.py b/src/storage/src/storage3/_async/client.py index b2c5b4e2..0e12b267 100644 --- a/src/storage/src/storage3/_async/client.py +++ b/src/storage/src/storage3/_async/client.py @@ -55,39 +55,16 @@ def __init__( self.verify = bool(verify) if verify is not None else True self.timeout = int(abs(timeout)) if timeout is not None else DEFAULT_TIMEOUT - self.session = self._create_session( + self.session = http_client or AsyncClient( base_url=url, headers=headers, timeout=self.timeout, - verify=self.verify, proxy=proxy, - http_client=http_client, - ) - super().__init__(self.session) - - def _create_session( - self, - base_url: str, - headers: dict[str, str], - timeout: int, - verify: bool = True, - proxy: Optional[str] = None, - http_client: Optional[AsyncClient] = None, - ) -> AsyncClient: - if http_client is not None: - http_client.base_url = base_url - http_client.headers.update({**headers}) - return http_client - - return AsyncClient( - base_url=base_url, - headers=headers, - timeout=timeout, - proxy=proxy, - verify=verify, + verify=self.verify, follow_redirects=True, http2=True, ) + super().__init__(self.session, url, headers) async def __aenter__(self) -> AsyncStorageClient: return self diff --git a/src/storage/src/storage3/_sync/bucket.py b/src/storage/src/storage3/_sync/bucket.py index b81e90b8..87a8d026 100644 --- a/src/storage/src/storage3/_sync/bucket.py +++ b/src/storage/src/storage3/_sync/bucket.py @@ -3,6 +3,8 @@ from typing import Any, Optional from httpx import Client, HTTPStatusError, Response +from httpx._types import HeaderTypes +from yarl import URL from ..exceptions import StorageApiError from ..types import CreateOrUpdateBucketOptions, RequestMethod @@ -14,17 +16,22 @@ class SyncStorageBucketAPI: """This class abstracts access to the endpoint to the Get, List, Empty, and Delete operations on a bucket""" - def __init__(self, session: Client) -> None: + def __init__(self, session: Client, url: str, headers: HeaderTypes) -> None: self._client = session + self._base_url = URL(url) + self._headers = headers def _request( self, method: RequestMethod, - url: str, + path: list[str], json: Optional[dict[Any, Any]] = None, ) -> Response: try: - response = self._client.request(method, url, json=json) + url_path = self._base_url.joinpath(*path) + response = self._client.request( + method, str(url_path), json=json, headers=self._headers + ) response.raise_for_status() except HTTPStatusError as exc: resp = exc.response.json() @@ -35,7 +42,7 @@ def _request( def list_buckets(self) -> list[SyncBucket]: """Retrieves the details of all storage buckets within an existing product.""" # if the request doesn't error, it is assured to return a list - res = self._request("GET", "/bucket") + res = self._request("GET", ["bucket"]) return [SyncBucket(**bucket) for bucket in res.json()] def get_bucket(self, id: str) -> SyncBucket: @@ -46,7 +53,7 @@ def get_bucket(self, id: str) -> SyncBucket: id The unique identifier of the bucket you would like to retrieve. """ - res = self._request("GET", f"/bucket/{id}") + res = self._request("GET", ["bucket", id]) json = res.json() return SyncBucket(**json) @@ -73,7 +80,7 @@ def create_bucket( json.update(**options) res = self._request( "POST", - "/bucket", + ["bucket"], json=json, ) return res.json() @@ -92,7 +99,7 @@ def update_bucket( `allowed_mime_types`. """ json = {"id": id, "name": id, **options} - res = self._request("PUT", f"/bucket/{id}", json=json) + res = self._request("PUT", ["bucket", id], json=json) return res.json() def empty_bucket(self, id: str) -> dict[str, str]: @@ -103,7 +110,7 @@ def empty_bucket(self, id: str) -> dict[str, str]: id The unique identifier of the bucket you would like to empty. """ - res = self._request("POST", f"/bucket/{id}/empty", json={}) + res = self._request("POST", ["bucket", id, "empty"], json={}) return res.json() def delete_bucket(self, id: str) -> dict[str, str]: @@ -115,5 +122,5 @@ def delete_bucket(self, id: str) -> dict[str, str]: id The unique identifier of the bucket you would like to delete. """ - res = self._request("DELETE", f"/bucket/{id}", json={}) + res = self._request("DELETE", ["bucket", id], json={}) return res.json() diff --git a/src/storage/src/storage3/_sync/client.py b/src/storage/src/storage3/_sync/client.py index 15876922..a5cd766b 100644 --- a/src/storage/src/storage3/_sync/client.py +++ b/src/storage/src/storage3/_sync/client.py @@ -16,7 +16,7 @@ ] -class SyncStorageClient(SyncStorageBucketAPI): +class SyncStorageClient(SyncStorageBucketAPI): # """Manage storage buckets and files.""" def __init__( @@ -55,39 +55,16 @@ def __init__( self.verify = bool(verify) if verify is not None else True self.timeout = int(abs(timeout)) if timeout is not None else DEFAULT_TIMEOUT - self.session = self._create_session( + self.session = http_client or Client( base_url=url, headers=headers, timeout=self.timeout, - verify=self.verify, proxy=proxy, - http_client=http_client, - ) - super().__init__(self.session) - - def _create_session( - self, - base_url: str, - headers: dict[str, str], - timeout: int, - verify: bool = True, - proxy: Optional[str] = None, - http_client: Optional[Client] = None, - ) -> Client: - if http_client is not None: - http_client.base_url = base_url - http_client.headers.update({**headers}) - return http_client - - return Client( - base_url=base_url, - headers=headers, - timeout=timeout, - proxy=proxy, - verify=verify, + verify=self.verify, follow_redirects=True, http2=True, ) + super().__init__(self.session, url, headers) def __enter__(self) -> SyncStorageClient: return self diff --git a/src/storage/tests/_async/test_bucket.py b/src/storage/tests/_async/test_bucket.py index fe8fa0b6..8a106c15 100644 --- a/src/storage/tests/_async/test_bucket.py +++ b/src/storage/tests/_async/test_bucket.py @@ -1,12 +1,14 @@ from unittest.mock import AsyncMock, Mock import pytest -from httpx import HTTPStatusError, Response +from httpx import AsyncClient, HTTPStatusError, Response from storage3 import AsyncBucket, AsyncStorageBucketAPI from storage3.exceptions import StorageApiError from storage3.types import CreateOrUpdateBucketOptions +from ..test_client import valid_url + @pytest.fixture def mock_client(): @@ -14,8 +16,15 @@ def mock_client(): @pytest.fixture -def storage_api(mock_client): - return AsyncStorageBucketAPI(mock_client) +def headers() -> dict[str, str]: + return {} + + +@pytest.fixture +def storage_api( + mock_client: AsyncClient, headers: dict[str, str] +) -> AsyncStorageBucketAPI: + return AsyncStorageBucketAPI(mock_client, "", headers) @pytest.fixture @@ -58,7 +67,7 @@ async def test_list_buckets(storage_api, mock_client, mock_response): assert buckets[0].id == "bucket1" assert buckets[1].id == "bucket2" - mock_client.request.assert_called_once_with("GET", "/bucket", json=None) + mock_client.request.assert_called_once_with("GET", "bucket", json=None, headers={}) async def test_get_bucket(storage_api, mock_client, mock_response): @@ -84,7 +93,7 @@ async def test_get_bucket(storage_api, mock_client, mock_response): assert bucket.owner == "test-owner" mock_client.request.assert_called_once_with( - "GET", f"/bucket/{bucket_id}", json=None + "GET", f"bucket/{bucket_id}", json=None, headers={} ) @@ -103,7 +112,7 @@ async def test_create_bucket(storage_api, mock_client, mock_response): assert result == {"message": "Bucket created successfully"} mock_client.request.assert_called_once_with( "POST", - "/bucket", + "bucket", json={ "id": bucket_id, "name": bucket_name, @@ -111,6 +120,7 @@ async def test_create_bucket(storage_api, mock_client, mock_response): "file_size_limit": 1000000, "allowed_mime_types": ["image/*"], }, + headers={}, ) @@ -123,7 +133,7 @@ async def test_create_bucket_minimal(storage_api, mock_client, mock_response): assert result == {"message": "Bucket created successfully"} mock_client.request.assert_called_once_with( - "POST", "/bucket", json={"id": bucket_id, "name": bucket_id} + "POST", "bucket", json={"id": bucket_id, "name": bucket_id}, headers={} ) @@ -139,13 +149,14 @@ async def test_update_bucket(storage_api, mock_client, mock_response): assert result == {"message": "Bucket updated successfully"} mock_client.request.assert_called_once_with( "PUT", - f"/bucket/{bucket_id}", + f"bucket/{bucket_id}", json={ "id": bucket_id, "name": bucket_id, "public": False, "file_size_limit": 2000000, }, + headers={}, ) @@ -158,7 +169,7 @@ async def test_empty_bucket(storage_api, mock_client, mock_response): assert result == {"message": "Bucket emptied successfully"} mock_client.request.assert_called_once_with( - "POST", f"/bucket/{bucket_id}/empty", json={} + "POST", f"bucket/{bucket_id}/empty", json={}, headers={} ) @@ -171,7 +182,7 @@ async def test_delete_bucket(storage_api, mock_client, mock_response): assert result == {"message": "Bucket deleted successfully"} mock_client.request.assert_called_once_with( - "DELETE", f"/bucket/{bucket_id}", json={} + "DELETE", f"bucket/{bucket_id}", json={}, headers={} ) @@ -187,23 +198,25 @@ async def test_request_error_handling(storage_api, mock_client): mock_client.request.side_effect = exc with pytest.raises(StorageApiError) as exc_info: - await storage_api._request("GET", "/test") + await storage_api._request("GET", ["test"]) assert exc_info.value.message == "Test error message" @pytest.mark.parametrize( - "method,url,json_data", + "method,path,json_data", [ - ("GET", "/test", None), - ("POST", "/test", {"key": "value"}), - ("PUT", "/test", {"id": "123"}), - ("DELETE", "/test", {}), + ("GET", "test", None), + ("POST", "test", {"key": "value"}), + ("PUT", "test", {"id": "123"}), + ("DELETE", "test", {}), ], ) async def test_request_methods( - storage_api, mock_client, mock_response, method, url, json_data + storage_api, mock_client, mock_response, method, path, json_data ): mock_client.request.return_value = mock_response - await storage_api._request(method, url, json_data) - mock_client.request.assert_called_once_with(method, url, json=json_data) + await storage_api._request(method, [path], json_data) + mock_client.request.assert_called_once_with( + method, path, json=json_data, headers={} + ) diff --git a/src/storage/tests/_sync/test_bucket.py b/src/storage/tests/_sync/test_bucket.py index a732704d..2401b23e 100644 --- a/src/storage/tests/_sync/test_bucket.py +++ b/src/storage/tests/_sync/test_bucket.py @@ -1,12 +1,14 @@ from unittest.mock import Mock import pytest -from httpx import HTTPStatusError, Response +from httpx import Client, HTTPStatusError, Response from storage3 import SyncBucket, SyncStorageBucketAPI from storage3.exceptions import StorageApiError from storage3.types import CreateOrUpdateBucketOptions +from ..test_client import valid_url + @pytest.fixture def mock_client(): @@ -14,8 +16,13 @@ def mock_client(): @pytest.fixture -def storage_api(mock_client): - return SyncStorageBucketAPI(mock_client) +def headers() -> dict[str, str]: + return {} + + +@pytest.fixture +def storage_api(mock_client: Client, headers: dict[str, str]) -> SyncStorageBucketAPI: + return SyncStorageBucketAPI(mock_client, "", headers) @pytest.fixture @@ -58,7 +65,7 @@ def test_list_buckets(storage_api, mock_client, mock_response): assert buckets[0].id == "bucket1" assert buckets[1].id == "bucket2" - mock_client.request.assert_called_once_with("GET", "/bucket", json=None) + mock_client.request.assert_called_once_with("GET", "bucket", json=None, headers={}) def test_get_bucket(storage_api, mock_client, mock_response): @@ -84,7 +91,7 @@ def test_get_bucket(storage_api, mock_client, mock_response): assert bucket.owner == "test-owner" mock_client.request.assert_called_once_with( - "GET", f"/bucket/{bucket_id}", json=None + "GET", f"bucket/{bucket_id}", json=None, headers={} ) @@ -103,7 +110,7 @@ def test_create_bucket(storage_api, mock_client, mock_response): assert result == {"message": "Bucket created successfully"} mock_client.request.assert_called_once_with( "POST", - "/bucket", + "bucket", json={ "id": bucket_id, "name": bucket_name, @@ -111,6 +118,7 @@ def test_create_bucket(storage_api, mock_client, mock_response): "file_size_limit": 1000000, "allowed_mime_types": ["image/*"], }, + headers={}, ) @@ -123,7 +131,7 @@ def test_create_bucket_minimal(storage_api, mock_client, mock_response): assert result == {"message": "Bucket created successfully"} mock_client.request.assert_called_once_with( - "POST", "/bucket", json={"id": bucket_id, "name": bucket_id} + "POST", "bucket", json={"id": bucket_id, "name": bucket_id}, headers={} ) @@ -139,13 +147,14 @@ def test_update_bucket(storage_api, mock_client, mock_response): assert result == {"message": "Bucket updated successfully"} mock_client.request.assert_called_once_with( "PUT", - f"/bucket/{bucket_id}", + f"bucket/{bucket_id}", json={ "id": bucket_id, "name": bucket_id, "public": False, "file_size_limit": 2000000, }, + headers={}, ) @@ -158,7 +167,7 @@ def test_empty_bucket(storage_api, mock_client, mock_response): assert result == {"message": "Bucket emptied successfully"} mock_client.request.assert_called_once_with( - "POST", f"/bucket/{bucket_id}/empty", json={} + "POST", f"bucket/{bucket_id}/empty", json={}, headers={} ) @@ -171,7 +180,7 @@ def test_delete_bucket(storage_api, mock_client, mock_response): assert result == {"message": "Bucket deleted successfully"} mock_client.request.assert_called_once_with( - "DELETE", f"/bucket/{bucket_id}", json={} + "DELETE", f"bucket/{bucket_id}", json={}, headers={} ) @@ -187,23 +196,25 @@ def test_request_error_handling(storage_api, mock_client): mock_client.request.side_effect = exc with pytest.raises(StorageApiError) as exc_info: - storage_api._request("GET", "/test") + storage_api._request("GET", ["test"]) assert exc_info.value.message == "Test error message" @pytest.mark.parametrize( - "method,url,json_data", + "method,path,json_data", [ - ("GET", "/test", None), - ("POST", "/test", {"key": "value"}), - ("PUT", "/test", {"id": "123"}), - ("DELETE", "/test", {}), + ("GET", "test", None), + ("POST", "test", {"key": "value"}), + ("PUT", "test", {"id": "123"}), + ("DELETE", "test", {}), ], ) def test_request_methods( - storage_api, mock_client, mock_response, method, url, json_data + storage_api, mock_client, mock_response, method, path, json_data ): mock_client.request.return_value = mock_response - storage_api._request(method, url, json_data) - mock_client.request.assert_called_once_with(method, url, json=json_data) + storage_api._request(method, [path], json_data) + mock_client.request.assert_called_once_with( + method, path, json=json_data, headers={} + ) diff --git a/src/storage/tests/test_client.py b/src/storage/tests/test_client.py index d5434d46..5fe321a8 100644 --- a/src/storage/tests/test_client.py +++ b/src/storage/tests/test_client.py @@ -45,9 +45,7 @@ def test_async_storage_client(valid_url, valid_headers): ) assert isinstance(client, AsyncStorageClient) - assert all( - client._client.headers[key] == value for key, value in valid_headers.items() - ) + assert all(client._headers[key] == value for key, value in valid_headers.items()) assert client._client.headers.get("x-user-agent") == "my-app/0.0.1" assert client._client.timeout == Timeout(5.0) @@ -60,9 +58,7 @@ def test_sync_storage_client(valid_url, valid_headers): ) assert isinstance(client, SyncStorageClient) - assert all( - client._client.headers[key] == value for key, value in valid_headers.items() - ) + assert all(client._headers[key] == value for key, value in valid_headers.items()) assert client._client.headers.get("x-user-agent") == "my-app/0.0.1" assert client._client.timeout == Timeout(5.0) diff --git a/uv.lock b/uv.lock index 40cc09cb..fef9ef94 100644 --- a/uv.lock +++ b/uv.lock @@ -1533,7 +1533,7 @@ wheels = [ [[package]] name = "postgrest" -version = "2.21.0" +version = "2.21.1" source = { editable = "src/postgrest" } dependencies = [ { name = "deprecation" }, @@ -2018,7 +2018,7 @@ wheels = [ [[package]] name = "realtime" -version = "2.20.0" +version = "2.21.1" source = { editable = "src/realtime" } dependencies = [ { name = "pydantic" }, @@ -2594,12 +2594,13 @@ wheels = [ [[package]] name = "storage3" -version = "2.21.0" +version = "2.21.1" source = { editable = "src/storage" } dependencies = [ { name = "deprecation" }, { name = "httpx", extra = ["http2"] }, { name = "pydantic" }, + { name = "yarl" }, ] [package.dev-dependencies] @@ -2645,6 +2646,7 @@ requires-dist = [ { name = "deprecation", specifier = ">=2.1.0" }, { name = "httpx", extras = ["http2"], specifier = ">=0.26,<0.29" }, { name = "pydantic", specifier = ">=2.11.7" }, + { name = "yarl", specifier = ">=1.20.1" }, ] [package.metadata.requires-dev] @@ -2692,7 +2694,7 @@ wheels = [ [[package]] name = "supabase" -version = "2.21.0" +version = "2.21.1" source = { editable = "src/supabase" } dependencies = [ { name = "httpx" }, @@ -2755,7 +2757,7 @@ tests = [ [[package]] name = "supabase-auth" -version = "2.21.0" +version = "2.21.1" source = { editable = "src/auth" } dependencies = [ { name = "httpx", extra = ["http2"] }, @@ -2824,7 +2826,7 @@ tests = [ [[package]] name = "supabase-functions" -version = "2.21.0" +version = "2.21.1" source = { editable = "src/functions" } dependencies = [ { name = "httpx", extra = ["http2"] }, From 44839e8e438b4658939edd9168a8e3f52344b8fb Mon Sep 17 00:00:00 2001 From: Leonardo Santiago Date: Tue, 7 Oct 2025 13:17:28 -0300 Subject: [PATCH 2/5] fix: do not mutate httpx client in postgrest --- src/postgrest/Makefile | 2 +- src/postgrest/pyproject.toml | 1 + src/postgrest/src/postgrest/_async/client.py | 75 +++-- .../src/postgrest/_async/request_builder.py | 265 +++++++----------- src/postgrest/src/postgrest/_sync/client.py | 75 +++-- .../src/postgrest/_sync/request_builder.py | 265 +++++++----------- src/postgrest/src/postgrest/base_client.py | 38 +-- .../src/postgrest/base_request_builder.py | 92 ++++-- src/postgrest/src/postgrest/types.py | 2 + src/postgrest/src/postgrest/utils.py | 5 +- src/postgrest/tests/_async/client.py | 2 +- src/postgrest/tests/_async/test_client.py | 40 +-- .../_async/test_filter_request_builder.py | 112 +++++--- .../_async/test_query_request_builder.py | 21 +- .../tests/_async/test_request_builder.py | 182 ++++++------ src/postgrest/tests/_sync/client.py | 2 +- src/postgrest/tests/_sync/test_client.py | 40 +-- .../_sync/test_filter_request_builder.py | 112 +++++--- .../tests/_sync/test_query_request_builder.py | 21 +- .../tests/_sync/test_request_builder.py | 182 ++++++------ src/supabase/tests/_async/test_client.py | 43 +++ uv.lock | 2 + 22 files changed, 809 insertions(+), 770 deletions(-) diff --git a/src/postgrest/Makefile b/src/postgrest/Makefile index f0d84a6e..2a9ae0a4 100644 --- a/src/postgrest/Makefile +++ b/src/postgrest/Makefile @@ -48,7 +48,7 @@ unasync: build-sync: unasync sed -i 's/@pytest.mark.asyncio//g' tests/_sync/test_client.py - sed -i 's/_async/_sync/g' tests/_sync/test_client.py + sed -i 's/_async/_sync/g' tests/_sync/test_client.py tests/_sync/test_query_request_builder.py tests/_sync/test_filter_request_builder.py sed -i 's/Async/Sync/g' src/postgrest/_sync/request_builder.py tests/_sync/test_client.py sed -i 's/_client\.SyncClient/_client\.Client/g' tests/_sync/test_client.py sed -i 's/SyncHTTPTransport/HTTPTransport/g' tests/_sync/**.py diff --git a/src/postgrest/pyproject.toml b/src/postgrest/pyproject.toml index 6a7b7efb..58358869 100644 --- a/src/postgrest/pyproject.toml +++ b/src/postgrest/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "deprecation >=2.1.0", "pydantic >=1.9,<3.0", "strenum >=0.4.9; python_version < \"3.11\"", + "yarl>=1.20.1", ] [project.urls] diff --git a/src/postgrest/src/postgrest/_async/client.py b/src/postgrest/src/postgrest/_async/client.py index 896e7494..6e9c8061 100644 --- a/src/postgrest/src/postgrest/_async/client.py +++ b/src/postgrest/src/postgrest/_async/client.py @@ -5,6 +5,8 @@ from deprecation import deprecated from httpx import AsyncClient, Headers, QueryParams, Timeout +from httpx._types import HeaderTypes +from yarl import URL from ..base_client import BasePostgrestClient from ..constants import ( @@ -13,7 +15,11 @@ ) from ..types import CountMethod from ..version import __version__ -from .request_builder import AsyncRequestBuilder, AsyncRPCFilterRequestBuilder +from .request_builder import ( + AsyncRequestBuilder, + AsyncRPCFilterRequestBuilder, + RequestConfig, +) class AsyncPostgrestClient(BasePostgrestClient): @@ -59,52 +65,32 @@ def __init__( else DEFAULT_POSTGREST_CLIENT_TIMEOUT ) ) - BasePostgrestClient.__init__( self, - base_url, + URL(base_url), schema=schema, headers=headers, timeout=self.timeout, verify=self.verify, proxy=proxy, - http_client=http_client, ) - self.session: AsyncClient = self.session - - def create_session( - self, - base_url: str, - headers: Dict[str, str], - timeout: Union[int, float, Timeout], - verify: bool = True, - proxy: Optional[str] = None, - ) -> AsyncClient: - http_client = None - if isinstance(self.http_client, AsyncClient): - http_client = self.http_client - if http_client is not None: - http_client.base_url = base_url - http_client.headers.update({**headers}) - return http_client - - return AsyncClient( + self.session = http_client or AsyncClient( base_url=base_url, - headers=headers, + headers=self.headers, timeout=timeout, - verify=verify, + verify=self.verify, proxy=proxy, follow_redirects=True, http2=True, ) - def schema(self, schema: str): + def schema(self, schema: str) -> AsyncPostgrestClient: """Switch to another schema.""" return AsyncPostgrestClient( - base_url=self.base_url, + base_url=str(self.base_url), schema=schema, - headers=self.headers, + headers=dict(self.headers), timeout=self.timeout, verify=self.verify, proxy=self.proxy, @@ -128,7 +114,9 @@ def from_(self, table: str) -> AsyncRequestBuilder: Returns: :class:`AsyncRequestBuilder` """ - return AsyncRequestBuilder(self.session, f"/{table}") + return AsyncRequestBuilder( + self.session, self.base_url.joinpath(table), self.headers, self.basic_auth + ) def table(self, table: str) -> AsyncRequestBuilder: """Alias to :meth:`from_`.""" @@ -142,7 +130,7 @@ def from_table(self, table: str) -> AsyncRequestBuilder: def rpc( self, func: str, - params: dict, + params: dict[str, str], count: Optional[CountMethod] = None, head: bool = False, get: bool = False, @@ -171,17 +159,20 @@ def rpc( method = "HEAD" if head else "GET" if get else "POST" headers = Headers({"Prefer": f"count={count}"}) if count else Headers() - - if method in ("HEAD", "GET"): - return AsyncRPCFilterRequestBuilder( - self.session, - f"/rpc/{func}", - method, - headers, - QueryParams(params), - json={}, - ) + headers.update(self.headers) # the params here are params to be sent to the RPC and not the queryparams! - return AsyncRPCFilterRequestBuilder( - self.session, f"/rpc/{func}", method, headers, QueryParams(), json=params + json, http_params = ( + ({}, QueryParams(params)) + if method in ("HEAD", "GET") + else (params, QueryParams()) + ) + request = RequestConfig( + self.session, + self.base_url.joinpath("rpc", func), + method, + headers, + http_params, + self.basic_auth, + json, ) + return AsyncRPCFilterRequestBuilder(request) diff --git a/src/postgrest/src/postgrest/_async/request_builder.py b/src/postgrest/src/postgrest/_async/request_builder.py index 106845fa..a575edd7 100644 --- a/src/postgrest/src/postgrest/_async/request_builder.py +++ b/src/postgrest/src/postgrest/_async/request_builder.py @@ -2,9 +2,10 @@ from typing import Any, Generic, Optional, TypeVar, Union -from httpx import AsyncClient, Headers, QueryParams +from httpx import AsyncClient, BasicAuth, Headers, QueryParams, Response from pydantic import ValidationError from typing_extensions import override +from yarl import URL from ..base_request_builder import ( APIResponse, @@ -12,6 +13,7 @@ BaseRPCRequestBuilder, BaseSelectRequestBuilder, CountMethod, + RequestConfig, SingleAPIResponse, pre_delete, pre_insert, @@ -23,23 +25,12 @@ from ..types import JSON, ReturnMethod from ..utils import model_validate_json +ReqConfig = RequestConfig[AsyncClient] + class AsyncQueryRequestBuilder: - def __init__( - self, - session: AsyncClient, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - self.session = session - self.path = path - self.http_method = http_method - self.headers = headers - self.params = params - self.json = None if http_method in {"GET", "HEAD"} else json + def __init__(self, request: ReqConfig): + self.request = request async def execute(self) -> APIResponse | str: """Execute the query. @@ -53,23 +44,19 @@ async def execute(self) -> APIResponse | str: Raises: :class:`APIError` If the API raised an error. """ - r = await self.session.request( - self.http_method, - self.path, - json=self.json, - params=self.params, - headers=self.headers, - ) + r = await self.request.send() try: if r.is_success: - if self.http_method != "HEAD": + if self.request.http_method != "HEAD": body = r.text - if self.headers.get("Accept") == "text/csv": + if self.request.headers.get("Accept") == "text/csv": return body - if self.headers.get( + if self.request.headers.get( "Accept" - ) and "application/vnd.pgrst.plan" in self.headers.get("Accept"): - if "+json" not in self.headers.get("Accept"): + ) and "application/vnd.pgrst.plan" in self.request.headers.get( + "Accept" + ): + if "+json" not in self.request.headers.get("Accept"): return body return APIResponse.from_http_request_response(r) else: @@ -80,21 +67,8 @@ async def execute(self) -> APIResponse | str: class AsyncSingleRequestBuilder: - def __init__( - self, - session: AsyncClient, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - self.session = session - self.path = path - self.http_method = http_method - self.headers = headers - self.params = params - self.json = json + def __init__(self, request: ReqConfig): + self.request = request async def execute(self) -> SingleAPIResponse: """Execute the query. @@ -108,13 +82,7 @@ async def execute(self) -> SingleAPIResponse: Raises: :class:`APIError` If the API raised an error. """ - r = await self.session.request( - self.http_method, - self.path, - json=self.json, - params=self.params, - headers=self.headers, - ) + r = await self.request.send() try: if ( 200 <= r.status_code <= 299 @@ -128,34 +96,13 @@ async def execute(self) -> SingleAPIResponse: class AsyncMaybeSingleRequestBuilder: - def __init__( - self, - session: AsyncClient, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - self.session = session - self.path = path - self.http_method = http_method - self.headers = headers - self.params = params - self.json = json + def __init__(self, request: ReqConfig): + self.request = request async def execute(self) -> Optional[SingleAPIResponse]: r = None try: - request = AsyncSingleRequestBuilder( - self.session, - self.path, - self.http_method, - self.headers, - self.params, - self.json, - ) - r = await request.execute() + r = await AsyncSingleRequestBuilder(self.request).execute() except APIError as e: if e.details and "The result contains 0 rows" in e.details: return None @@ -171,52 +118,26 @@ async def execute(self) -> Optional[SingleAPIResponse]: return r -class AsyncFilterRequestBuilder(BaseFilterRequestBuilder, AsyncQueryRequestBuilder): - def __init__( - self, - session: AsyncClient, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - BaseFilterRequestBuilder.__init__(self, headers, params) - AsyncQueryRequestBuilder.__init__( - self, session, path, http_method, headers, params, json - ) +class AsyncFilterRequestBuilder( + BaseFilterRequestBuilder[AsyncClient], AsyncQueryRequestBuilder +): + def __init__(self, request: ReqConfig) -> None: + BaseFilterRequestBuilder.__init__(self, request) + AsyncQueryRequestBuilder.__init__(self, request) class AsyncRPCFilterRequestBuilder(BaseRPCRequestBuilder, AsyncSingleRequestBuilder): - def __init__( - self, - session: AsyncClient, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - BaseFilterRequestBuilder.__init__(self, headers, params) - AsyncSingleRequestBuilder.__init__( - self, session, path, http_method, headers, params, json - ) + def __init__(self, request: ReqConfig) -> None: + BaseFilterRequestBuilder.__init__(self, request) + AsyncSingleRequestBuilder.__init__(self, request) -class AsyncSelectRequestBuilder(AsyncQueryRequestBuilder, BaseSelectRequestBuilder): - def __init__( - self, - session: AsyncClient, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - BaseSelectRequestBuilder.__init__(self, headers, params) - AsyncQueryRequestBuilder.__init__( - self, session, path, http_method, headers, params, json - ) +class AsyncSelectRequestBuilder( + AsyncQueryRequestBuilder, BaseSelectRequestBuilder[AsyncClient] +): + def __init__(self, request: ReqConfig) -> None: + BaseSelectRequestBuilder.__init__(self, request) + AsyncQueryRequestBuilder.__init__(self, request) def single(self) -> AsyncSingleRequestBuilder: """Specify that the query will only return a single row in response. @@ -224,27 +145,13 @@ def single(self) -> AsyncSingleRequestBuilder: .. caution:: The API will raise an error if the query returned more than one row. """ - self.headers["Accept"] = "application/vnd.pgrst.object+json" - return AsyncSingleRequestBuilder( - headers=self.headers, - http_method=self.http_method, - json=self.json, - params=self.params, - path=self.path, - session=self.session, - ) + self.request.headers["Accept"] = "application/vnd.pgrst.object+json" + return AsyncSingleRequestBuilder(self.request) def maybe_single(self) -> AsyncMaybeSingleRequestBuilder: """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error.""" - self.headers["Accept"] = "application/vnd.pgrst.object+json" - return AsyncMaybeSingleRequestBuilder( - headers=self.headers, - http_method=self.http_method, - json=self.json, - params=self.params, - path=self.path, - session=self.session, - ) + self.request.headers["Accept"] = "application/vnd.pgrst.object+json" + return AsyncMaybeSingleRequestBuilder(self.request) def text_search( self, column: str, query: str, options: dict[str, Any] = {} @@ -258,34 +165,26 @@ def text_search( elif type_ == "web_search": type_part = "w" config_part = f"({options.get('config')})" if options.get("config") else "" - self.params = self.params.add(column, f"{type_part}fts{config_part}.{query}") - - return AsyncQueryRequestBuilder( - headers=self.headers, - http_method=self.http_method, - json=self.json, - params=self.params, - path=self.path, - session=self.session, + self.request.params = self.request.params.add( + column, f"{type_part}fts{config_part}.{query}" ) + return AsyncQueryRequestBuilder(self.request) + def csv(self) -> AsyncSingleRequestBuilder: """Specify that the query must retrieve data as a single CSV string.""" - self.headers["Accept"] = "text/csv" - return AsyncSingleRequestBuilder( - session=self.session, - path=self.path, - http_method=self.http_method, - headers=self.headers, - params=self.params, - json=self.json, - ) + self.request.headers["Accept"] = "text/csv" + return AsyncSingleRequestBuilder(self.request) -class AsyncRequestBuilder: - def __init__(self, session: AsyncClient, path: str) -> None: +class AsyncRequestBuilder: # + def __init__( + self, session: AsyncClient, path: URL, headers: Headers, auth: BasicAuth | None + ) -> None: self.session = session self.path = path + self.headers = headers + self.auth = auth def select( self, @@ -302,9 +201,17 @@ def select( :class:`AsyncSelectRequestBuilder` """ method, params, headers, json = pre_select(*columns, count=count, head=head) - return AsyncSelectRequestBuilder( - self.session, self.path, method, headers, params, json + headers.update(self.headers) + request = RequestConfig( + session=self.session, + path=self.path, + auth=self.auth, + params=params, + http_method=method, + headers=headers, + json=json, ) + return AsyncSelectRequestBuilder(request) def insert( self, @@ -335,9 +242,17 @@ def insert( upsert=upsert, default_to_null=default_to_null, ) - return AsyncQueryRequestBuilder( - self.session, self.path, method, headers, params, json + headers.update(self.headers) + request = RequestConfig( + session=self.session, + path=self.path, + auth=self.auth, + params=params, + http_method=method, + headers=headers, + json=json, ) + return AsyncQueryRequestBuilder(request) def upsert( self, @@ -372,9 +287,17 @@ def upsert( on_conflict=on_conflict, default_to_null=default_to_null, ) - return AsyncQueryRequestBuilder( - self.session, self.path, method, headers, params, json + headers.update(self.headers) + request = RequestConfig( + session=self.session, + path=self.path, + auth=self.auth, + params=params, + http_method=method, + headers=headers, + json=json, ) + return AsyncQueryRequestBuilder(request) def update( self, @@ -397,9 +320,17 @@ def update( count=count, returning=returning, ) - return AsyncFilterRequestBuilder( - self.session, self.path, method, headers, params, json + headers.update(self.headers) + request = RequestConfig( + session=self.session, + path=self.path, + auth=self.auth, + params=params, + http_method=method, + headers=headers, + json=json, ) + return AsyncFilterRequestBuilder(request) def delete( self, @@ -419,6 +350,14 @@ def delete( count=count, returning=returning, ) - return AsyncFilterRequestBuilder( - self.session, self.path, method, headers, params, json + headers.update(self.headers) + request = RequestConfig( + session=self.session, + path=self.path, + auth=self.auth, + params=params, + http_method=method, + headers=headers, + json=json, ) + return AsyncFilterRequestBuilder(request) diff --git a/src/postgrest/src/postgrest/_sync/client.py b/src/postgrest/src/postgrest/_sync/client.py index ac578960..9df8ae8f 100644 --- a/src/postgrest/src/postgrest/_sync/client.py +++ b/src/postgrest/src/postgrest/_sync/client.py @@ -5,6 +5,8 @@ from deprecation import deprecated from httpx import Client, Headers, QueryParams, Timeout +from httpx._types import HeaderTypes +from yarl import URL from ..base_client import BasePostgrestClient from ..constants import ( @@ -13,7 +15,11 @@ ) from ..types import CountMethod from ..version import __version__ -from .request_builder import SyncRequestBuilder, SyncRPCFilterRequestBuilder +from .request_builder import ( + RequestConfig, + SyncRequestBuilder, + SyncRPCFilterRequestBuilder, +) class SyncPostgrestClient(BasePostgrestClient): @@ -59,52 +65,32 @@ def __init__( else DEFAULT_POSTGREST_CLIENT_TIMEOUT ) ) - BasePostgrestClient.__init__( self, - base_url, + URL(base_url), schema=schema, headers=headers, timeout=self.timeout, verify=self.verify, proxy=proxy, - http_client=http_client, ) - self.session: Client = self.session - - def create_session( - self, - base_url: str, - headers: Dict[str, str], - timeout: Union[int, float, Timeout], - verify: bool = True, - proxy: Optional[str] = None, - ) -> Client: - http_client = None - if isinstance(self.http_client, Client): - http_client = self.http_client - if http_client is not None: - http_client.base_url = base_url - http_client.headers.update({**headers}) - return http_client - - return Client( + self.session = http_client or Client( base_url=base_url, - headers=headers, + headers=self.headers, timeout=timeout, - verify=verify, + verify=self.verify, proxy=proxy, follow_redirects=True, http2=True, ) - def schema(self, schema: str): + def schema(self, schema: str) -> SyncPostgrestClient: """Switch to another schema.""" return SyncPostgrestClient( - base_url=self.base_url, + base_url=str(self.base_url), schema=schema, - headers=self.headers, + headers=dict(self.headers), timeout=self.timeout, verify=self.verify, proxy=self.proxy, @@ -128,7 +114,9 @@ def from_(self, table: str) -> SyncRequestBuilder: Returns: :class:`AsyncRequestBuilder` """ - return SyncRequestBuilder(self.session, f"/{table}") + return SyncRequestBuilder( + self.session, self.base_url.joinpath(table), self.headers, self.basic_auth + ) def table(self, table: str) -> SyncRequestBuilder: """Alias to :meth:`from_`.""" @@ -142,7 +130,7 @@ def from_table(self, table: str) -> SyncRequestBuilder: def rpc( self, func: str, - params: dict, + params: dict[str, str], count: Optional[CountMethod] = None, head: bool = False, get: bool = False, @@ -171,17 +159,20 @@ def rpc( method = "HEAD" if head else "GET" if get else "POST" headers = Headers({"Prefer": f"count={count}"}) if count else Headers() - - if method in ("HEAD", "GET"): - return SyncRPCFilterRequestBuilder( - self.session, - f"/rpc/{func}", - method, - headers, - QueryParams(params), - json={}, - ) + headers.update(self.headers) # the params here are params to be sent to the RPC and not the queryparams! - return SyncRPCFilterRequestBuilder( - self.session, f"/rpc/{func}", method, headers, QueryParams(), json=params + json, http_params = ( + ({}, QueryParams(params)) + if method in ("HEAD", "GET") + else (params, QueryParams()) + ) + request = RequestConfig( + self.session, + self.base_url.joinpath("rpc", func), + method, + headers, + http_params, + self.basic_auth, + json, ) + return SyncRPCFilterRequestBuilder(request) diff --git a/src/postgrest/src/postgrest/_sync/request_builder.py b/src/postgrest/src/postgrest/_sync/request_builder.py index aacc922d..e7e2bd9e 100644 --- a/src/postgrest/src/postgrest/_sync/request_builder.py +++ b/src/postgrest/src/postgrest/_sync/request_builder.py @@ -2,9 +2,10 @@ from typing import Any, Generic, Optional, TypeVar, Union -from httpx import Client, Headers, QueryParams +from httpx import BasicAuth, Client, Headers, QueryParams, Response from pydantic import ValidationError from typing_extensions import override +from yarl import URL from ..base_request_builder import ( APIResponse, @@ -12,6 +13,7 @@ BaseRPCRequestBuilder, BaseSelectRequestBuilder, CountMethod, + RequestConfig, SingleAPIResponse, pre_delete, pre_insert, @@ -23,23 +25,12 @@ from ..types import JSON, ReturnMethod from ..utils import model_validate_json +ReqConfig = RequestConfig[Client] + class SyncQueryRequestBuilder: - def __init__( - self, - session: Client, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - self.session = session - self.path = path - self.http_method = http_method - self.headers = headers - self.params = params - self.json = None if http_method in {"GET", "HEAD"} else json + def __init__(self, request: ReqConfig): + self.request = request def execute(self) -> APIResponse | str: """Execute the query. @@ -53,23 +44,19 @@ def execute(self) -> APIResponse | str: Raises: :class:`APIError` If the API raised an error. """ - r = self.session.request( - self.http_method, - self.path, - json=self.json, - params=self.params, - headers=self.headers, - ) + r = self.request.send() try: if r.is_success: - if self.http_method != "HEAD": + if self.request.http_method != "HEAD": body = r.text - if self.headers.get("Accept") == "text/csv": + if self.request.headers.get("Accept") == "text/csv": return body - if self.headers.get( + if self.request.headers.get( "Accept" - ) and "application/vnd.pgrst.plan" in self.headers.get("Accept"): - if "+json" not in self.headers.get("Accept"): + ) and "application/vnd.pgrst.plan" in self.request.headers.get( + "Accept" + ): + if "+json" not in self.request.headers.get("Accept"): return body return APIResponse.from_http_request_response(r) else: @@ -80,21 +67,8 @@ def execute(self) -> APIResponse | str: class SyncSingleRequestBuilder: - def __init__( - self, - session: Client, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - self.session = session - self.path = path - self.http_method = http_method - self.headers = headers - self.params = params - self.json = json + def __init__(self, request: ReqConfig): + self.request = request def execute(self) -> SingleAPIResponse: """Execute the query. @@ -108,13 +82,7 @@ def execute(self) -> SingleAPIResponse: Raises: :class:`APIError` If the API raised an error. """ - r = self.session.request( - self.http_method, - self.path, - json=self.json, - params=self.params, - headers=self.headers, - ) + r = self.request.send() try: if ( 200 <= r.status_code <= 299 @@ -128,34 +96,13 @@ def execute(self) -> SingleAPIResponse: class SyncMaybeSingleRequestBuilder: - def __init__( - self, - session: Client, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - self.session = session - self.path = path - self.http_method = http_method - self.headers = headers - self.params = params - self.json = json + def __init__(self, request: ReqConfig): + self.request = request def execute(self) -> Optional[SingleAPIResponse]: r = None try: - request = SyncSingleRequestBuilder( - self.session, - self.path, - self.http_method, - self.headers, - self.params, - self.json, - ) - r = request.execute() + r = SyncSingleRequestBuilder(self.request).execute() except APIError as e: if e.details and "The result contains 0 rows" in e.details: return None @@ -171,52 +118,26 @@ def execute(self) -> Optional[SingleAPIResponse]: return r -class SyncFilterRequestBuilder(BaseFilterRequestBuilder, SyncQueryRequestBuilder): - def __init__( - self, - session: Client, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - BaseFilterRequestBuilder.__init__(self, headers, params) - SyncQueryRequestBuilder.__init__( - self, session, path, http_method, headers, params, json - ) +class SyncFilterRequestBuilder( + BaseFilterRequestBuilder[Client], SyncQueryRequestBuilder +): + def __init__(self, request: ReqConfig) -> None: + BaseFilterRequestBuilder.__init__(self, request) + SyncQueryRequestBuilder.__init__(self, request) class SyncRPCFilterRequestBuilder(BaseRPCRequestBuilder, SyncSingleRequestBuilder): - def __init__( - self, - session: Client, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - BaseFilterRequestBuilder.__init__(self, headers, params) - SyncSingleRequestBuilder.__init__( - self, session, path, http_method, headers, params, json - ) + def __init__(self, request: ReqConfig) -> None: + BaseFilterRequestBuilder.__init__(self, request) + SyncSingleRequestBuilder.__init__(self, request) -class SyncSelectRequestBuilder(SyncQueryRequestBuilder, BaseSelectRequestBuilder): - def __init__( - self, - session: Client, - path: str, - http_method: str, - headers: Headers, - params: QueryParams, - json: JSON, - ) -> None: - BaseSelectRequestBuilder.__init__(self, headers, params) - SyncQueryRequestBuilder.__init__( - self, session, path, http_method, headers, params, json - ) +class SyncSelectRequestBuilder( + SyncQueryRequestBuilder, BaseSelectRequestBuilder[Client] +): + def __init__(self, request: ReqConfig) -> None: + BaseSelectRequestBuilder.__init__(self, request) + SyncQueryRequestBuilder.__init__(self, request) def single(self) -> SyncSingleRequestBuilder: """Specify that the query will only return a single row in response. @@ -224,27 +145,13 @@ def single(self) -> SyncSingleRequestBuilder: .. caution:: The API will raise an error if the query returned more than one row. """ - self.headers["Accept"] = "application/vnd.pgrst.object+json" - return SyncSingleRequestBuilder( - headers=self.headers, - http_method=self.http_method, - json=self.json, - params=self.params, - path=self.path, - session=self.session, - ) + self.request.headers["Accept"] = "application/vnd.pgrst.object+json" + return SyncSingleRequestBuilder(self.request) def maybe_single(self) -> SyncMaybeSingleRequestBuilder: """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error.""" - self.headers["Accept"] = "application/vnd.pgrst.object+json" - return SyncMaybeSingleRequestBuilder( - headers=self.headers, - http_method=self.http_method, - json=self.json, - params=self.params, - path=self.path, - session=self.session, - ) + self.request.headers["Accept"] = "application/vnd.pgrst.object+json" + return SyncMaybeSingleRequestBuilder(self.request) def text_search( self, column: str, query: str, options: dict[str, Any] = {} @@ -258,34 +165,26 @@ def text_search( elif type_ == "web_search": type_part = "w" config_part = f"({options.get('config')})" if options.get("config") else "" - self.params = self.params.add(column, f"{type_part}fts{config_part}.{query}") - - return SyncQueryRequestBuilder( - headers=self.headers, - http_method=self.http_method, - json=self.json, - params=self.params, - path=self.path, - session=self.session, + self.request.params = self.request.params.add( + column, f"{type_part}fts{config_part}.{query}" ) + return SyncQueryRequestBuilder(self.request) + def csv(self) -> SyncSingleRequestBuilder: """Specify that the query must retrieve data as a single CSV string.""" - self.headers["Accept"] = "text/csv" - return SyncSingleRequestBuilder( - session=self.session, - path=self.path, - http_method=self.http_method, - headers=self.headers, - params=self.params, - json=self.json, - ) + self.request.headers["Accept"] = "text/csv" + return SyncSingleRequestBuilder(self.request) -class SyncRequestBuilder: - def __init__(self, session: Client, path: str) -> None: +class SyncRequestBuilder: # + def __init__( + self, session: Client, path: URL, headers: Headers, auth: BasicAuth | None + ) -> None: self.session = session self.path = path + self.headers = headers + self.auth = auth def select( self, @@ -302,9 +201,17 @@ def select( :class:`SyncSelectRequestBuilder` """ method, params, headers, json = pre_select(*columns, count=count, head=head) - return SyncSelectRequestBuilder( - self.session, self.path, method, headers, params, json + headers.update(self.headers) + request = RequestConfig( + session=self.session, + path=self.path, + auth=self.auth, + params=params, + http_method=method, + headers=headers, + json=json, ) + return SyncSelectRequestBuilder(request) def insert( self, @@ -335,9 +242,17 @@ def insert( upsert=upsert, default_to_null=default_to_null, ) - return SyncQueryRequestBuilder( - self.session, self.path, method, headers, params, json + headers.update(self.headers) + request = RequestConfig( + session=self.session, + path=self.path, + auth=self.auth, + params=params, + http_method=method, + headers=headers, + json=json, ) + return SyncQueryRequestBuilder(request) def upsert( self, @@ -372,9 +287,17 @@ def upsert( on_conflict=on_conflict, default_to_null=default_to_null, ) - return SyncQueryRequestBuilder( - self.session, self.path, method, headers, params, json + headers.update(self.headers) + request = RequestConfig( + session=self.session, + path=self.path, + auth=self.auth, + params=params, + http_method=method, + headers=headers, + json=json, ) + return SyncQueryRequestBuilder(request) def update( self, @@ -397,9 +320,17 @@ def update( count=count, returning=returning, ) - return SyncFilterRequestBuilder( - self.session, self.path, method, headers, params, json + headers.update(self.headers) + request = RequestConfig( + session=self.session, + path=self.path, + auth=self.auth, + params=params, + http_method=method, + headers=headers, + json=json, ) + return SyncFilterRequestBuilder(request) def delete( self, @@ -419,6 +350,14 @@ def delete( count=count, returning=returning, ) - return SyncFilterRequestBuilder( - self.session, self.path, method, headers, params, json + headers.update(self.headers) + request = RequestConfig( + session=self.session, + path=self.path, + auth=self.auth, + params=params, + http_method=method, + headers=headers, + json=json, ) + return SyncFilterRequestBuilder(request) diff --git a/src/postgrest/src/postgrest/base_client.py b/src/postgrest/src/postgrest/base_client.py index f4a819d8..0f27f1a0 100644 --- a/src/postgrest/src/postgrest/base_client.py +++ b/src/postgrest/src/postgrest/base_client.py @@ -3,7 +3,8 @@ from abc import ABC, abstractmethod from typing import Dict, Optional, Union -from httpx import AsyncClient, BasicAuth, Client, Timeout +from httpx import AsyncClient, BasicAuth, Client, Headers, Timeout +from yarl import URL from .utils import is_http_url @@ -13,46 +14,25 @@ class BasePostgrestClient(ABC): def __init__( self, - base_url: str, + base_url: URL, *, schema: str, headers: Dict[str, str], timeout: Union[int, float, Timeout], verify: bool = True, proxy: Optional[str] = None, - http_client: Union[Client, AsyncClient, None] = None, ) -> None: if not is_http_url(base_url): ValueError("base_url must be a valid HTTP URL string") self.base_url = base_url - self.headers = { - **headers, - "Accept-Profile": schema, - "Content-Profile": schema, - } + self.headers = Headers(headers) + self.headers["Accept-Profile"] = schema + self.headers["Content-Profile"] = schema self.timeout = timeout self.verify = verify self.proxy = proxy - self.http_client = http_client - self.session = self.create_session( - self.base_url, - self.headers, - self.timeout, - self.verify, - self.proxy, - ) - - @abstractmethod - def create_session( - self, - base_url: str, - headers: Dict[str, str], - timeout: Union[int, float, Timeout], - verify: bool = True, - proxy: Optional[str] = None, - ) -> Union[Client, AsyncClient]: - raise NotImplementedError() + self.basic_auth: BasicAuth | None = None def auth( self, @@ -71,9 +51,9 @@ def auth( Bearer token is preferred if both ones are provided. """ if token: - self.session.headers["Authorization"] = f"Bearer {token}" + self.headers["Authorization"] = f"Bearer {token}" elif username: - self.session.auth = BasicAuth(username, password) + self.basic_auth = BasicAuth(username, password) else: raise ValueError( "Neither bearer token or basic authentication scheme is provided" diff --git a/src/postgrest/src/postgrest/base_request_builder.py b/src/postgrest/src/postgrest/base_request_builder.py index 9d79e232..f92a6158 100644 --- a/src/postgrest/src/postgrest/base_request_builder.py +++ b/src/postgrest/src/postgrest/base_request_builder.py @@ -5,6 +5,7 @@ from re import search from typing import ( Any, + Awaitable, Dict, Generic, Iterable, @@ -16,11 +17,13 @@ Type, TypeVar, Union, + overload, ) -from httpx import AsyncClient, Client, Headers, QueryParams +from httpx import AsyncClient, BasicAuth, Client, Headers, QueryParams from httpx import Response as RequestResponse from pydantic import BaseModel, ValidationError +from yarl import URL try: from typing import Self # type: ignore @@ -47,6 +50,44 @@ class QueryArgs(NamedTuple): json: JSON +C = TypeVar("C", Client, AsyncClient) + + +class RequestConfig(Generic[C]): + def __init__( + self, + session: C, + path: URL, + http_method: str, + headers: Headers, + params: QueryParams, + auth: BasicAuth | None, + json: JSON, + ) -> None: + self.session: C = session + self.path = path + self.http_method = http_method + self.headers = headers + self.params = params + self.json = None if http_method in {"GET", "HEAD"} else json + self.auth = auth + + @overload + def send(self: RequestConfig[Client]) -> RequestResponse: ... + @overload + def send(self: RequestConfig[AsyncClient]) -> Awaitable[RequestResponse]: ... + + def send(self: RequestConfig[C]): + return self.session.request( + self.http_method, + str(self.path), + json=self.json, + params=self.params, + headers=self.headers, + auth=self.auth, + ) + + def _unique_columns(json: List[Dict[str, JSON]]): unique_keys = {key for row in json for key in row.keys()} columns = ",".join([f'"{k}"' for k in unique_keys]) @@ -227,14 +268,9 @@ def from_http_request_response( return SingleAPIResponse(data=data, count=count) -class BaseFilterRequestBuilder: - def __init__( - self, - headers: Headers, - params: QueryParams, - ) -> None: - self.headers = headers - self.params = params +class BaseFilterRequestBuilder(Generic[C]): + def __init__(self, request: RequestConfig[C]) -> None: + self.request: RequestConfig[C] = request self.negate_next = False @property @@ -255,7 +291,7 @@ def filter(self: Self, column: str, operator: str, criteria: str) -> Self: self.negate_next = False operator = f"{Filters.NOT}.{operator}" key, val = sanitize_param(column), f"{operator}.{criteria}" - self.params = self.params.add(key, val) + self.request.params = self.request.params.add(key, val) return self def eq(self: Self, column: str, value: Any) -> Self: @@ -389,7 +425,7 @@ def or_(self: Self, filters: str, reference_table: Optional[str] = None) -> Self reference_table: Set this to filter on referenced tables instead of the parent table """ key = f"{sanitize_param(reference_table)}.or" if reference_table else "or" - self.params = self.params.add(key, f"({filters})") + self.request.params = self.request.params.add(key, f"({filters})") return self def fts(self: Self, column: str, query: Any) -> Self: @@ -507,7 +543,7 @@ def max_affected(self: Self, value: int) -> Self: Args: value: The maximum number of rows that can be affected """ - prefer_header = self.headers.get("Prefer", "") + prefer_header = self.request.headers.get("Prefer", "") if prefer_header: if "handling=strict" not in prefer_header: prefer_header += ",handling=strict" @@ -516,11 +552,11 @@ def max_affected(self: Self, value: int) -> Self: prefer_header += f",max-affected={value}" - self.headers["Prefer"] = prefer_header + self.request.headers["Prefer"] = prefer_header return self -class BaseSelectRequestBuilder(BaseFilterRequestBuilder): +class BaseSelectRequestBuilder(BaseFilterRequestBuilder[C]): def explain( self: Self, analyze: bool = False, @@ -536,7 +572,7 @@ def explain( if key not in ["self", "format"] and value ] options_str = "|".join(options) - self.headers["Accept"] = ( + self.request.headers["Accept"] = ( f"application/vnd.pgrst.plan+{format}; options={options_str}" ) return self @@ -560,9 +596,9 @@ def order( Allow ordering results for foreign tables with the foreign_table parameter. """ key = f"{foreign_table}.order" if foreign_table else "order" - existing_order = self.params.get(key) + existing_order = self.request.params.get(key) - self.params = self.params.set( + self.request.params = self.request.params.set( key, f"{existing_order + ',' if existing_order else ''}" + f"{column}.{'desc' if desc else 'asc'}" @@ -583,7 +619,7 @@ def limit(self: Self, size: int, *, foreign_table: Optional[str] = None) -> Self .. versionchanged:: 0.10.3 Allow limiting results returned for foreign tables with the foreign_table parameter. """ - self.params = self.params.add( + self.request.params = self.request.params.add( f"{foreign_table}.limit" if foreign_table else "limit", size, ) @@ -594,7 +630,7 @@ def offset(self: Self, size: int) -> Self: Args: size: The number of the row to start at """ - self.params = self.params.add( + self.request.params = self.request.params.add( "offset", size, ) @@ -603,10 +639,10 @@ def offset(self: Self, size: int) -> Self: def range( self: Self, start: int, end: int, foreign_table: Optional[str] = None ) -> Self: - self.params = self.params.add( + self.request.params = self.request.params.add( f"{foreign_table}.offset" if foreign_table else "offset", start ) - self.params = self.params.add( + self.request.params = self.request.params.add( f"{foreign_table}.limit" if foreign_table else "limit", end - start + 1, ) @@ -626,11 +662,11 @@ def select( :class:`BaseSelectRequestBuilder` """ method, params, headers, json = pre_select(*columns, count=None) - self.params = self.params.add("select", params.get("select")) - if self.headers.get("Prefer"): - self.headers["Prefer"] += ",return=representation" + self.request.params = self.request.params.add("select", params.get("select")) + if self.request.headers.get("Prefer"): + self.request.headers["Prefer"] += ",return=representation" else: - self.headers["Prefer"] = "return=representation" + self.request.headers["Prefer"] = "return=representation" return self @@ -640,15 +676,15 @@ def single(self) -> Self: .. caution:: The API will raise an error if the query returned more than one row. """ - self.headers["Accept"] = "application/vnd.pgrst.object+json" + self.request.headers["Accept"] = "application/vnd.pgrst.object+json" return self def maybe_single(self) -> Self: """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error.""" - self.headers["Accept"] = "application/vnd.pgrst.object+json" + self.request.headers["Accept"] = "application/vnd.pgrst.object+json" return self def csv(self) -> Self: """Specify that the query must retrieve data as a single CSV string.""" - self.headers["Accept"] = "text/csv" + self.request.headers["Accept"] = "text/csv" return self diff --git a/src/postgrest/src/postgrest/types.py b/src/postgrest/src/postgrest/types.py index 17b0804f..748f87e4 100644 --- a/src/postgrest/src/postgrest/types.py +++ b/src/postgrest/src/postgrest/types.py @@ -4,8 +4,10 @@ from collections.abc import Mapping, Sequence from typing import Union +from httpx import AsyncClient, BasicAuth, Client, Headers, QueryParams from pydantic import TypeAdapter from typing_extensions import TypeAliasType +from yarl import URL if sys.version_info >= (3, 11): from enum import StrEnum diff --git a/src/postgrest/src/postgrest/utils.py b/src/postgrest/src/postgrest/utils.py index aa8967d2..0fa2cbec 100644 --- a/src/postgrest/src/postgrest/utils.py +++ b/src/postgrest/src/postgrest/utils.py @@ -7,6 +7,7 @@ from httpx import AsyncClient # noqa: F401 from httpx import Client as BaseClient # noqa: F401 from pydantic import BaseModel +from yarl import URL from .version import __version__ @@ -40,8 +41,8 @@ def sanitize_pattern_param(pattern: str) -> str: return sanitize_param(pattern.replace("%", "*")) -def is_http_url(url: str) -> bool: - return urlparse(url).scheme in {"https", "http"} +def is_http_url(url: URL) -> bool: + return url.scheme in {"https", "http"} TBaseModel = TypeVar("TBaseModel", bound=BaseModel) diff --git a/src/postgrest/tests/_async/client.py b/src/postgrest/tests/_async/client.py index 25cdeb0e..fd585bbb 100644 --- a/src/postgrest/tests/_async/client.py +++ b/src/postgrest/tests/_async/client.py @@ -11,7 +11,7 @@ def rest_client(): ) -def rest_client_httpx(): +def rest_client_httpx() -> AsyncPostgrestClient: transport = AsyncHTTPTransport( retries=4, limits=Limits( diff --git a/src/postgrest/tests/_async/test_client.py b/src/postgrest/tests/_async/test_client.py index 24fc43cc..292a0401 100644 --- a/src/postgrest/tests/_async/test_client.py +++ b/src/postgrest/tests/_async/test_client.py @@ -35,6 +35,7 @@ def test_simple(self, postgrest_client: AsyncPostgrestClient): "Content-Profile": "public", } ) + print(session.headers) assert session.headers.items() >= headers.items() @pytest.mark.asyncio @@ -57,7 +58,7 @@ async def test_custom_headers(self): class TestHttpxClientConstructor: @pytest.mark.asyncio - async def test_custom_httpx_client(self): + async def test_custom_httpx_client(self) -> None: transport = AsyncHTTPTransport( retries=10, limits=Limits( @@ -71,40 +72,37 @@ async def test_custom_httpx_client(self): async with AsyncPostgrestClient( "https://example.com", http_client=http_client, timeout=20.0 ) as client: - session = client.session - - assert session.base_url == "https://example.com" - assert session.timeout == Timeout( + assert str(client.base_url) == "https://example.com" + assert client.session.timeout == Timeout( timeout=5.0 ) # Should be the default 5 since we use custom httpx client - assert session.headers.get("x-user-agent") == "my-app/0.0.1" - assert isinstance(session, AsyncClient) + assert client.session.headers.get("x-user-agent") == "my-app/0.0.1" + assert isinstance(client.session, AsyncClient) class TestAuth: def test_auth_token(self, postgrest_client: AsyncPostgrestClient): postgrest_client.auth("s3cr3t") - session = postgrest_client.session - - assert session.headers["Authorization"] == "Bearer s3cr3t" + assert postgrest_client.headers["Authorization"] == "Bearer s3cr3t" def test_auth_basic(self, postgrest_client: AsyncPostgrestClient): postgrest_client.auth(None, username="admin", password="s3cr3t") - session = postgrest_client.session - assert isinstance(session.auth, BasicAuth) - assert session.auth._auth_header == BasicAuth("admin", "s3cr3t")._auth_header + assert isinstance(postgrest_client.basic_auth, BasicAuth) + assert ( + postgrest_client.basic_auth._auth_header + == BasicAuth("admin", "s3cr3t")._auth_header + ) def test_schema(postgrest_client: AsyncPostgrestClient): client = postgrest_client.schema("private") - session = client.session subheaders = { "accept-profile": "private", "content-profile": "private", } - assert subheaders.items() < dict(session.headers).items() + assert subheaders.items() < client.headers.items() @pytest.mark.asyncio @@ -154,8 +152,10 @@ async def test_response_maybe_single(postgrest_client: AsyncPostgrestClient): client = ( postgrest_client.from_("test").select("a", "b").eq("c", "d").maybe_single() ) - assert "Accept" in client.headers - assert client.headers.get("Accept") == "application/vnd.pgrst.object+json" + assert "Accept" in client.request.headers + assert ( + client.request.headers.get("Accept") == "application/vnd.pgrst.object+json" + ) with pytest.raises(APIError) as exc_info: await client.execute() assert isinstance(exc_info, pytest.ExceptionInfo) @@ -178,8 +178,10 @@ async def test_response_client_invalid_response_but_valid_json( ), ): client = postgrest_client.from_("test").select("a", "b").eq("c", "d").single() - assert "Accept" in client.headers - assert client.headers.get("Accept") == "application/vnd.pgrst.object+json" + assert "Accept" in client.request.headers + assert ( + client.request.headers.get("Accept") == "application/vnd.pgrst.object+json" + ) with pytest.raises(APIError) as exc_info: await client.execute() assert isinstance(exc_info, pytest.ExceptionInfo) diff --git a/src/postgrest/tests/_async/test_filter_request_builder.py b/src/postgrest/tests/_async/test_filter_request_builder.py index 1174e1ea..e4a63eb9 100644 --- a/src/postgrest/tests/_async/test_filter_request_builder.py +++ b/src/postgrest/tests/_async/test_filter_request_builder.py @@ -1,25 +1,30 @@ +from typing import AsyncIterable + import pytest from httpx import AsyncClient, Headers, QueryParams +from yarl import URL from postgrest import AsyncFilterRequestBuilder +from postgrest._async.request_builder import RequestConfig @pytest.fixture -async def filter_request_builder(): +async def filter_request_builder() -> AsyncIterable[AsyncFilterRequestBuilder]: async with AsyncClient() as client: - yield AsyncFilterRequestBuilder( - client, "/example_table", "GET", Headers(), QueryParams(), {} + request = RequestConfig( + client, URL("/example_table"), "GET", Headers(), QueryParams(), None, {} ) + yield AsyncFilterRequestBuilder(request) def test_constructor(filter_request_builder: AsyncFilterRequestBuilder): builder = filter_request_builder - assert builder.path == "/example_table" - assert len(builder.headers) == 0 - assert len(builder.params) == 0 - assert builder.http_method == "GET" - assert builder.json is None + assert str(builder.request.path) == "/example_table" + assert len(builder.request.headers) == 0 + assert len(builder.request.params) == 0 + assert builder.request.http_method == "GET" + assert builder.request.json is None assert not builder.negate_next @@ -32,7 +37,7 @@ def test_not_(filter_request_builder): def test_filter(filter_request_builder): builder = filter_request_builder.filter(":col.name", "eq", "val") - assert builder.params['":col.name"'] == "eq.val" + assert builder.request.params['":col.name"'] == "eq.val" @pytest.mark.parametrize( @@ -47,76 +52,76 @@ def test_filter_special_characters( ): builder = filter_request_builder.filter(col_name, "eq", "val") - assert str(builder.params) == f"{expected_query_prefix}=eq.val" + assert str(builder.request.params) == f"{expected_query_prefix}=eq.val" def test_multivalued_param(filter_request_builder): builder = filter_request_builder.lte("x", "a").gte("x", "b") - assert str(builder.params) == "x=lte.a&x=gte.b" + assert str(builder.request.params) == "x=lte.a&x=gte.b" def test_match(filter_request_builder): builder = filter_request_builder.match({"id": "1", "done": "false"}) - assert str(builder.params) == "id=eq.1&done=eq.false" + assert str(builder.request.params) == "id=eq.1&done=eq.false" def test_equals(filter_request_builder): builder = filter_request_builder.eq("x", "a") - assert str(builder.params) == "x=eq.a" + assert str(builder.request.params) == "x=eq.a" def test_not_equal(filter_request_builder): builder = filter_request_builder.neq("x", "a") - assert str(builder.params) == "x=neq.a" + assert str(builder.request.params) == "x=neq.a" def test_greater_than(filter_request_builder): builder = filter_request_builder.gt("x", "a") - assert str(builder.params) == "x=gt.a" + assert str(builder.request.params) == "x=gt.a" def test_greater_than_or_equals_to(filter_request_builder): builder = filter_request_builder.gte("x", "a") - assert str(builder.params) == "x=gte.a" + assert str(builder.request.params) == "x=gte.a" def test_contains(filter_request_builder): builder = filter_request_builder.contains("x", "a") - assert str(builder.params) == "x=cs.a" + assert str(builder.request.params) == "x=cs.a" def test_contains_dictionary(filter_request_builder): builder = filter_request_builder.contains("x", {"a": "b"}) # {"a":"b"} - assert str(builder.params) == "x=cs.%7B%22a%22%3A+%22b%22%7D" + assert str(builder.request.params) == "x=cs.%7B%22a%22%3A+%22b%22%7D" def test_contains_any_item(filter_request_builder): builder = filter_request_builder.contains("x", ["a", "b"]) # {a,b} - assert str(builder.params) == "x=cs.%7Ba%2Cb%7D" + assert str(builder.request.params) == "x=cs.%7Ba%2Cb%7D" def test_contains_in_list(filter_request_builder): builder = filter_request_builder.contains("x", '[{"a": "b"}]') # [{"a":+"b"}] (the + represents the space) - assert str(builder.params) == "x=cs.%5B%7B%22a%22%3A+%22b%22%7D%5D" + assert str(builder.request.params) == "x=cs.%5B%7B%22a%22%3A+%22b%22%7D%5D" def test_contained_by_mixed_items(filter_request_builder): builder = filter_request_builder.contained_by("x", ["a", '["b", "c"]']) # {a,["b",+"c"]} - assert str(builder.params) == "x=cd.%7Ba%2C%5B%22b%22%2C+%22c%22%5D%7D" + assert str(builder.request.params) == "x=cd.%7Ba%2C%5B%22b%22%2C+%22c%22%5D%7D" def test_range_greater_than(filter_request_builder): @@ -125,7 +130,10 @@ def test_range_greater_than(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=sr.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + assert ( + str(builder.request.params) + == "x=sr.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + ) def test_range_greater_than_or_equal_to(filter_request_builder): @@ -134,7 +142,10 @@ def test_range_greater_than_or_equal_to(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=nxl.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + assert ( + str(builder.request.params) + == "x=nxl.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + ) def test_range_less_than(filter_request_builder): @@ -143,7 +154,10 @@ def test_range_less_than(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=sl.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + assert ( + str(builder.request.params) + == "x=sl.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + ) def test_range_less_than_or_equal_to(filter_request_builder): @@ -152,7 +166,10 @@ def test_range_less_than_or_equal_to(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=nxr.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + assert ( + str(builder.request.params) + == "x=nxr.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + ) def test_range_adjacent(filter_request_builder): @@ -161,14 +178,17 @@ def test_range_adjacent(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=adj.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + assert ( + str(builder.request.params) + == "x=adj.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + ) def test_overlaps(filter_request_builder): builder = filter_request_builder.overlaps("x", ["is:closed", "severity:high"]) # {a,["b",+"c"]} - assert str(builder.params) == "x=ov.%7Bis%3Aclosed%2Cseverity%3Ahigh%7D" + assert str(builder.request.params) == "x=ov.%7Bis%3Aclosed%2Cseverity%3Ahigh%7D" def test_overlaps_with_timestamp_range(filter_request_builder): @@ -177,68 +197,71 @@ def test_overlaps_with_timestamp_range(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=ov.%5B2000-01-01+12%3A45%2C+2000-01-01+13%3A15%29" + assert ( + str(builder.request.params) + == "x=ov.%5B2000-01-01+12%3A45%2C+2000-01-01+13%3A15%29" + ) def test_like(filter_request_builder): builder = filter_request_builder.like("x", "%a%") - assert str(builder.params) == "x=like.%25a%25" + assert str(builder.request.params) == "x=like.%25a%25" def test_ilike(filter_request_builder): builder = filter_request_builder.ilike("x", "%a%") - assert str(builder.params) == "x=ilike.%25a%25" + assert str(builder.request.params) == "x=ilike.%25a%25" def test_like_all_of(filter_request_builder): builder = filter_request_builder.like_all_of("x", "A*,*b") - assert str(builder.params) == "x=like%28all%29.%7BA%2A%2C%2Ab%7D" + assert str(builder.request.params) == "x=like%28all%29.%7BA%2A%2C%2Ab%7D" def test_like_any_of(filter_request_builder): builder = filter_request_builder.like_any_of("x", "a*,*b") - assert str(builder.params) == "x=like%28any%29.%7Ba%2A%2C%2Ab%7D" + assert str(builder.request.params) == "x=like%28any%29.%7Ba%2A%2C%2Ab%7D" def test_ilike_all_of(filter_request_builder): builder = filter_request_builder.ilike_all_of("x", "A*,*b") - assert str(builder.params) == "x=ilike%28all%29.%7BA%2A%2C%2Ab%7D" + assert str(builder.request.params) == "x=ilike%28all%29.%7BA%2A%2C%2Ab%7D" def test_ilike_any_of(filter_request_builder): builder = filter_request_builder.ilike_any_of("x", "A*,*b") - assert str(builder.params) == "x=ilike%28any%29.%7BA%2A%2C%2Ab%7D" + assert str(builder.request.params) == "x=ilike%28any%29.%7BA%2A%2C%2Ab%7D" def test_is_(filter_request_builder): builder = filter_request_builder.is_("x", "a") - assert str(builder.params) == "x=is.a" + assert str(builder.request.params) == "x=is.a" def test_in_(filter_request_builder): builder = filter_request_builder.in_("x", ["a", "b"]) - assert str(builder.params) == "x=in.%28a%2Cb%29" + assert str(builder.request.params) == "x=in.%28a%2Cb%29" def test_or_(filter_request_builder): builder = filter_request_builder.or_("x.eq.1") - assert str(builder.params) == "or=%28x.eq.1%29" + assert str(builder.request.params) == "or=%28x.eq.1%29" def test_or_in_contain(filter_request_builder): builder = filter_request_builder.or_("id.in.(5,6,7), arraycol.cs.{'a','b'}") assert ( - str(builder.params) + str(builder.request.params) == "or=%28id.in.%285%2C6%2C7%29%2C+arraycol.cs.%7B%27a%27%2C%27b%27%7D%29" ) @@ -246,26 +269,29 @@ def test_or_in_contain(filter_request_builder): def test_max_affected(filter_request_builder): builder = filter_request_builder.max_affected(5) - assert builder.headers["prefer"] == "handling=strict,max-affected=5" + assert builder.request.headers["prefer"] == "handling=strict,max-affected=5" def test_max_affected_with_existing_prefer_header(filter_request_builder): # Set an existing prefer header - filter_request_builder.headers["prefer"] = "return=representation" + filter_request_builder.request.headers["prefer"] = "return=representation" builder = filter_request_builder.max_affected(10) assert ( - builder.headers["prefer"] + builder.request.headers["prefer"] == "return=representation,handling=strict,max-affected=10" ) def test_max_affected_with_existing_handling_strict(filter_request_builder): # Set an existing prefer header with handling=strict - filter_request_builder.headers["prefer"] = "handling=strict,return=minimal" + filter_request_builder.request.headers["prefer"] = "handling=strict,return=minimal" builder = filter_request_builder.max_affected(3) - assert builder.headers["prefer"] == "handling=strict,return=minimal,max-affected=3" + assert ( + builder.request.headers["prefer"] + == "handling=strict,return=minimal,max-affected=3" + ) def test_max_affected_returns_self(filter_request_builder): diff --git a/src/postgrest/tests/_async/test_query_request_builder.py b/src/postgrest/tests/_async/test_query_request_builder.py index ccbf5c5b..78edc2da 100644 --- a/src/postgrest/tests/_async/test_query_request_builder.py +++ b/src/postgrest/tests/_async/test_query_request_builder.py @@ -1,22 +1,27 @@ +from typing import AsyncIterable + import pytest from httpx import AsyncClient, Headers, QueryParams +from yarl import URL from postgrest import AsyncQueryRequestBuilder +from postgrest._async.request_builder import RequestConfig @pytest.fixture -async def query_request_builder(): +async def query_request_builder() -> AsyncIterable[AsyncQueryRequestBuilder]: async with AsyncClient() as client: - yield AsyncQueryRequestBuilder( - client, "/example_table", "GET", Headers(), QueryParams(), {} + request = RequestConfig( + client, URL("/example_table"), "GET", Headers(), QueryParams(), None, {} ) + yield AsyncQueryRequestBuilder(request) def test_constructor(query_request_builder: AsyncQueryRequestBuilder): builder = query_request_builder - assert builder.path == "/example_table" - assert len(builder.headers) == 0 - assert len(builder.params) == 0 - assert builder.http_method == "GET" - assert builder.json is None + assert str(builder.request.path) == "/example_table" + assert len(builder.request.headers) == 0 + assert len(builder.request.params) == 0 + assert builder.request.http_method == "GET" + assert builder.request.json is None diff --git a/src/postgrest/tests/_async/test_request_builder.py b/src/postgrest/tests/_async/test_request_builder.py index c8e6a6f5..356be18f 100644 --- a/src/postgrest/tests/_async/test_request_builder.py +++ b/src/postgrest/tests/_async/test_request_builder.py @@ -1,52 +1,54 @@ -from typing import Any, Dict, List +from typing import Any, AsyncIterable, Dict, List import pytest -from httpx import AsyncClient, Request, Response +from httpx import AsyncClient, Headers, QueryParams, Request, Response +from yarl import URL from postgrest import AsyncRequestBuilder, AsyncSingleRequestBuilder +from postgrest._async.request_builder import RequestConfig from postgrest.base_request_builder import APIResponse, SingleAPIResponse from postgrest.types import JSON, CountMethod @pytest.fixture -async def request_builder(): +async def request_builder() -> AsyncIterable[AsyncRequestBuilder]: async with AsyncClient() as client: - yield AsyncRequestBuilder(client, "/example_table") + yield AsyncRequestBuilder(client, URL("/example_table"), Headers(), None) def test_constructor(request_builder): - assert request_builder.path == "/example_table" + assert str(request_builder.path) == "/example_table" class TestSelect: def test_select(self, request_builder: AsyncRequestBuilder): builder = request_builder.select("col1", "col2") - assert builder.params["select"] == "col1,col2" - assert builder.headers.get("prefer") is None - assert builder.http_method == "GET" - assert builder.json is None + assert builder.request.params["select"] == "col1,col2" + assert builder.request.headers.get("prefer") is None + assert builder.request.http_method == "GET" + assert builder.request.json is None def test_select_with_count(self, request_builder: AsyncRequestBuilder): builder = request_builder.select(count=CountMethod.exact) - assert builder.params["select"] == "*" - assert builder.headers["prefer"] == "count=exact" - assert builder.http_method == "GET" - assert builder.json is None + assert builder.request.params["select"] == "*" + assert builder.request.headers["prefer"] == "count=exact" + assert builder.request.http_method == "GET" + assert builder.request.json is None def test_select_with_head(self, request_builder: AsyncRequestBuilder): builder = request_builder.select("col1", "col2", head=True) - assert builder.params.get("select") == "col1,col2" - assert builder.headers.get("prefer") is None - assert builder.http_method == "HEAD" - assert builder.json is None + assert builder.request.params.get("select") == "col1,col2" + assert builder.request.headers.get("prefer") is None + assert builder.request.http_method == "HEAD" + assert builder.request.json is None def test_select_as_csv(self, request_builder: AsyncRequestBuilder): builder = request_builder.select("*").csv() - assert builder.headers["Accept"] == "text/csv" + assert builder.request.headers["Accept"] == "text/csv" assert isinstance(builder, AsyncSingleRequestBuilder) @@ -54,77 +56,85 @@ class TestInsert: def test_insert(self, request_builder: AsyncRequestBuilder): builder = request_builder.insert({"key1": "val1"}) - assert builder.headers.get_list("prefer", True) == ["return=representation"] - assert builder.http_method == "POST" - assert builder.json == {"key1": "val1"} + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation" + ] + assert builder.request.http_method == "POST" + assert builder.request.json == {"key1": "val1"} def test_insert_with_count(self, request_builder: AsyncRequestBuilder): builder = request_builder.insert({"key1": "val1"}, count=CountMethod.exact) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "count=exact", ] - assert builder.http_method == "POST" - assert builder.json == {"key1": "val1"} + assert builder.request.http_method == "POST" + assert builder.request.json == {"key1": "val1"} def test_insert_with_upsert(self, request_builder: AsyncRequestBuilder): builder = request_builder.insert({"key1": "val1"}, upsert=True) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "resolution=merge-duplicates", ] - assert builder.http_method == "POST" - assert builder.json == {"key1": "val1"} + assert builder.request.http_method == "POST" + assert builder.request.json == {"key1": "val1"} def test_upsert_with_default_single(self, request_builder: AsyncRequestBuilder): builder = request_builder.upsert([{"key1": "val1"}], default_to_null=False) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "resolution=merge-duplicates", "missing=default", ] - assert builder.http_method == "POST" - assert builder.json == [{"key1": "val1"}] - assert builder.params.get("columns") == '"key1"' + assert builder.request.http_method == "POST" + assert builder.request.json == [{"key1": "val1"}] + assert builder.request.params.get("columns") == '"key1"' def test_bulk_insert_using_default(self, request_builder: AsyncRequestBuilder): builder = request_builder.insert( [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False ) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "missing=default", ] - assert builder.http_method == "POST" - assert builder.json == [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}] - assert set(builder.params["columns"].split(",")) == set( + assert builder.request.http_method == "POST" + assert builder.request.json == [ + {"key1": "val1", "key2": "val2"}, + {"key3": "val3"}, + ] + assert set(builder.request.params["columns"].split(",")) == set( '"key1","key2","key3"'.split(",") ) def test_upsert(self, request_builder: AsyncRequestBuilder): builder = request_builder.upsert({"key1": "val1"}) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "resolution=merge-duplicates", ] - assert builder.http_method == "POST" - assert builder.json == {"key1": "val1"} + assert builder.request.http_method == "POST" + assert builder.request.json == {"key1": "val1"} def test_bulk_upsert_with_default(self, request_builder: AsyncRequestBuilder): builder = request_builder.upsert( [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False ) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "resolution=merge-duplicates", "missing=default", ] - assert builder.http_method == "POST" - assert builder.json == [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}] - assert set(builder.params["columns"].split(",")) == set( + assert builder.request.http_method == "POST" + assert builder.request.json == [ + {"key1": "val1", "key2": "val2"}, + {"key3": "val3"}, + ] + assert set(builder.request.params["columns"].split(",")) == set( '"key1","key2","key3"'.split(",") ) @@ -133,56 +143,60 @@ class TestUpdate: def test_update(self, request_builder: AsyncRequestBuilder): builder = request_builder.update({"key1": "val1"}) - assert builder.headers.get_list("prefer", True) == ["return=representation"] - assert builder.http_method == "PATCH" - assert builder.json == {"key1": "val1"} + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation" + ] + assert builder.request.http_method == "PATCH" + assert builder.request.json == {"key1": "val1"} def test_update_with_count(self, request_builder: AsyncRequestBuilder): builder = request_builder.update({"key1": "val1"}, count=CountMethod.exact) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "count=exact", ] - assert builder.http_method == "PATCH" - assert builder.json == {"key1": "val1"} + assert builder.request.http_method == "PATCH" + assert builder.request.json == {"key1": "val1"} def test_update_with_max_affected(self, request_builder: AsyncRequestBuilder): builder = request_builder.update({"key1": "val1"}).max_affected(5) - assert "handling=strict" in builder.headers["prefer"] - assert "max-affected=5" in builder.headers["prefer"] - assert "return=representation" in builder.headers["prefer"] - assert builder.http_method == "PATCH" - assert builder.json == {"key1": "val1"} + assert "handling=strict" in builder.request.headers["prefer"] + assert "max-affected=5" in builder.request.headers["prefer"] + assert "return=representation" in builder.request.headers["prefer"] + assert builder.request.http_method == "PATCH" + assert builder.request.json == {"key1": "val1"} class TestDelete: def test_delete(self, request_builder: AsyncRequestBuilder): builder = request_builder.delete() - assert builder.headers.get_list("prefer", True) == ["return=representation"] - assert builder.http_method == "DELETE" - assert builder.json == {} + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation" + ] + assert builder.request.http_method == "DELETE" + assert builder.request.json == {} def test_delete_with_count(self, request_builder: AsyncRequestBuilder): builder = request_builder.delete(count=CountMethod.exact) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "count=exact", ] - assert builder.http_method == "DELETE" - assert builder.json == {} + assert builder.request.http_method == "DELETE" + assert builder.request.json == {} def test_delete_with_max_affected(self, request_builder: AsyncRequestBuilder): builder = request_builder.delete().max_affected(10) - assert "handling=strict" in builder.headers["prefer"] - assert "max-affected=10" in builder.headers["prefer"] - assert "return=representation" in builder.headers["prefer"] - assert builder.http_method == "DELETE" - assert builder.json == {} + assert "handling=strict" in builder.request.headers["prefer"] + assert "max-affected=10" in builder.request.headers["prefer"] + assert "return=representation" in builder.request.headers["prefer"] + assert builder.request.http_method == "DELETE" + assert builder.request.json == {} class TestTextSearch: @@ -196,31 +210,35 @@ def test_text_search(self, request_builder: AsyncRequestBuilder): }, ) assert "catchphrase=plfts%28english%29.%27fat%27+%26+%27cat%27" in str( - builder.params + builder.request.params ) class TestExplain: def test_explain_plain(self, request_builder: AsyncRequestBuilder): builder = request_builder.select("*").explain() - assert builder.params["select"] == "*" - assert "application/vnd.pgrst.plan" in str(builder.headers.get("accept")) + assert builder.request.params["select"] == "*" + assert "application/vnd.pgrst.plan" in str( + builder.request.headers.get("accept") + ) def test_explain_options(self, request_builder: AsyncRequestBuilder): builder = request_builder.select("*").explain( format="json", analyze=True, verbose=True, buffers=True, wal=True ) - assert builder.params["select"] == "*" - assert "application/vnd.pgrst.plan+json;" in str(builder.headers.get("accept")) + assert builder.request.params["select"] == "*" + assert "application/vnd.pgrst.plan+json;" in str( + builder.request.headers.get("accept") + ) assert "options=analyze|verbose|buffers|wal" in str( - builder.headers.get("accept") + builder.request.headers.get("accept") ) class TestOrder: def test_order(self, request_builder: AsyncRequestBuilder): builder = request_builder.select().order("country_name", desc=True) - assert str(builder.params) == "select=%2A&order=country_name.desc" + assert str(builder.request.params) == "select=%2A&order=country_name.desc" def test_multiple_orders(self, request_builder: AsyncRequestBuilder): builder = ( @@ -228,7 +246,10 @@ def test_multiple_orders(self, request_builder: AsyncRequestBuilder): .order("country_name", desc=True) .order("iso", desc=True) ) - assert str(builder.params) == "select=%2A&order=country_name.desc%2Ciso.desc" + assert ( + str(builder.request.params) + == "select=%2A&order=country_name.desc%2Ciso.desc" + ) def test_multiple_orders_on_foreign_table( self, request_builder: AsyncRequestBuilder @@ -239,22 +260,25 @@ def test_multiple_orders_on_foreign_table( .order("city_name", desc=True, foreign_table=foreign_table) .order("id", desc=True, foreign_table=foreign_table) ) - assert str(builder.params) == "select=%2A&cities.order=city_name.desc%2Cid.desc" + assert ( + str(builder.request.params) + == "select=%2A&cities.order=city_name.desc%2Cid.desc" + ) class TestRange: def test_range_on_own_table(self, request_builder: AsyncRequestBuilder): builder = request_builder.select("*").range(0, 1) - assert builder.params["select"] == "*" - assert builder.params["limit"] == "2" - assert builder.params["offset"] == "0" + assert builder.request.params["select"] == "*" + assert builder.request.params["limit"] == "2" + assert builder.request.params["offset"] == "0" def test_range_on_foreign_table(self, request_builder: AsyncRequestBuilder): foreign_table = "cities" builder = request_builder.select("*").range(1, 2, foreign_table) - assert builder.params["select"] == "*" - assert builder.params[f"{foreign_table}.limit"] == "2" - assert builder.params[f"{foreign_table}.offset"] == "1" + assert builder.request.params["select"] == "*" + assert builder.request.params[f"{foreign_table}.limit"] == "2" + assert builder.request.params[f"{foreign_table}.offset"] == "1" @pytest.fixture diff --git a/src/postgrest/tests/_sync/client.py b/src/postgrest/tests/_sync/client.py index a4b2e132..832095aa 100644 --- a/src/postgrest/tests/_sync/client.py +++ b/src/postgrest/tests/_sync/client.py @@ -11,7 +11,7 @@ def rest_client(): ) -def rest_client_httpx(): +def rest_client_httpx() -> SyncPostgrestClient: transport = HTTPTransport( retries=4, limits=Limits( diff --git a/src/postgrest/tests/_sync/test_client.py b/src/postgrest/tests/_sync/test_client.py index ba07fd89..c2fdb771 100644 --- a/src/postgrest/tests/_sync/test_client.py +++ b/src/postgrest/tests/_sync/test_client.py @@ -35,6 +35,7 @@ def test_simple(self, postgrest_client: SyncPostgrestClient): "Content-Profile": "public", } ) + print(session.headers) assert session.headers.items() >= headers.items() def test_custom_headers(self): @@ -55,7 +56,7 @@ def test_custom_headers(self): class TestHttpxClientConstructor: - def test_custom_httpx_client(self): + def test_custom_httpx_client(self) -> None: transport = HTTPTransport( retries=10, limits=Limits( @@ -69,40 +70,37 @@ def test_custom_httpx_client(self): with SyncPostgrestClient( "https://example.com", http_client=http_client, timeout=20.0 ) as client: - session = client.session - - assert session.base_url == "https://example.com" - assert session.timeout == Timeout( + assert str(client.base_url) == "https://example.com" + assert client.session.timeout == Timeout( timeout=5.0 ) # Should be the default 5 since we use custom httpx client - assert session.headers.get("x-user-agent") == "my-app/0.0.1" - assert isinstance(session, Client) + assert client.session.headers.get("x-user-agent") == "my-app/0.0.1" + assert isinstance(client.session, Client) class TestAuth: def test_auth_token(self, postgrest_client: SyncPostgrestClient): postgrest_client.auth("s3cr3t") - session = postgrest_client.session - - assert session.headers["Authorization"] == "Bearer s3cr3t" + assert postgrest_client.headers["Authorization"] == "Bearer s3cr3t" def test_auth_basic(self, postgrest_client: SyncPostgrestClient): postgrest_client.auth(None, username="admin", password="s3cr3t") - session = postgrest_client.session - assert isinstance(session.auth, BasicAuth) - assert session.auth._auth_header == BasicAuth("admin", "s3cr3t")._auth_header + assert isinstance(postgrest_client.basic_auth, BasicAuth) + assert ( + postgrest_client.basic_auth._auth_header + == BasicAuth("admin", "s3cr3t")._auth_header + ) def test_schema(postgrest_client: SyncPostgrestClient): client = postgrest_client.schema("private") - session = client.session subheaders = { "accept-profile": "private", "content-profile": "private", } - assert subheaders.items() < dict(session.headers).items() + assert subheaders.items() < client.headers.items() def test_params_purged_after_execute(postgrest_client: SyncPostgrestClient): @@ -149,8 +147,10 @@ def test_response_maybe_single(postgrest_client: SyncPostgrestClient): client = ( postgrest_client.from_("test").select("a", "b").eq("c", "d").maybe_single() ) - assert "Accept" in client.headers - assert client.headers.get("Accept") == "application/vnd.pgrst.object+json" + assert "Accept" in client.request.headers + assert ( + client.request.headers.get("Accept") == "application/vnd.pgrst.object+json" + ) with pytest.raises(APIError) as exc_info: client.execute() assert isinstance(exc_info, pytest.ExceptionInfo) @@ -174,8 +174,10 @@ def test_response_client_invalid_response_but_valid_json( ), ): client = postgrest_client.from_("test").select("a", "b").eq("c", "d").single() - assert "Accept" in client.headers - assert client.headers.get("Accept") == "application/vnd.pgrst.object+json" + assert "Accept" in client.request.headers + assert ( + client.request.headers.get("Accept") == "application/vnd.pgrst.object+json" + ) with pytest.raises(APIError) as exc_info: client.execute() assert isinstance(exc_info, pytest.ExceptionInfo) diff --git a/src/postgrest/tests/_sync/test_filter_request_builder.py b/src/postgrest/tests/_sync/test_filter_request_builder.py index 2eae647f..c5c5c1d8 100644 --- a/src/postgrest/tests/_sync/test_filter_request_builder.py +++ b/src/postgrest/tests/_sync/test_filter_request_builder.py @@ -1,25 +1,30 @@ +from typing import Iterable + import pytest from httpx import Client, Headers, QueryParams +from yarl import URL from postgrest import SyncFilterRequestBuilder +from postgrest._sync.request_builder import RequestConfig @pytest.fixture -def filter_request_builder(): +def filter_request_builder() -> Iterable[SyncFilterRequestBuilder]: with Client() as client: - yield SyncFilterRequestBuilder( - client, "/example_table", "GET", Headers(), QueryParams(), {} + request = RequestConfig( + client, URL("/example_table"), "GET", Headers(), QueryParams(), None, {} ) + yield SyncFilterRequestBuilder(request) def test_constructor(filter_request_builder: SyncFilterRequestBuilder): builder = filter_request_builder - assert builder.path == "/example_table" - assert len(builder.headers) == 0 - assert len(builder.params) == 0 - assert builder.http_method == "GET" - assert builder.json is None + assert str(builder.request.path) == "/example_table" + assert len(builder.request.headers) == 0 + assert len(builder.request.params) == 0 + assert builder.request.http_method == "GET" + assert builder.request.json is None assert not builder.negate_next @@ -32,7 +37,7 @@ def test_not_(filter_request_builder): def test_filter(filter_request_builder): builder = filter_request_builder.filter(":col.name", "eq", "val") - assert builder.params['":col.name"'] == "eq.val" + assert builder.request.params['":col.name"'] == "eq.val" @pytest.mark.parametrize( @@ -47,76 +52,76 @@ def test_filter_special_characters( ): builder = filter_request_builder.filter(col_name, "eq", "val") - assert str(builder.params) == f"{expected_query_prefix}=eq.val" + assert str(builder.request.params) == f"{expected_query_prefix}=eq.val" def test_multivalued_param(filter_request_builder): builder = filter_request_builder.lte("x", "a").gte("x", "b") - assert str(builder.params) == "x=lte.a&x=gte.b" + assert str(builder.request.params) == "x=lte.a&x=gte.b" def test_match(filter_request_builder): builder = filter_request_builder.match({"id": "1", "done": "false"}) - assert str(builder.params) == "id=eq.1&done=eq.false" + assert str(builder.request.params) == "id=eq.1&done=eq.false" def test_equals(filter_request_builder): builder = filter_request_builder.eq("x", "a") - assert str(builder.params) == "x=eq.a" + assert str(builder.request.params) == "x=eq.a" def test_not_equal(filter_request_builder): builder = filter_request_builder.neq("x", "a") - assert str(builder.params) == "x=neq.a" + assert str(builder.request.params) == "x=neq.a" def test_greater_than(filter_request_builder): builder = filter_request_builder.gt("x", "a") - assert str(builder.params) == "x=gt.a" + assert str(builder.request.params) == "x=gt.a" def test_greater_than_or_equals_to(filter_request_builder): builder = filter_request_builder.gte("x", "a") - assert str(builder.params) == "x=gte.a" + assert str(builder.request.params) == "x=gte.a" def test_contains(filter_request_builder): builder = filter_request_builder.contains("x", "a") - assert str(builder.params) == "x=cs.a" + assert str(builder.request.params) == "x=cs.a" def test_contains_dictionary(filter_request_builder): builder = filter_request_builder.contains("x", {"a": "b"}) # {"a":"b"} - assert str(builder.params) == "x=cs.%7B%22a%22%3A+%22b%22%7D" + assert str(builder.request.params) == "x=cs.%7B%22a%22%3A+%22b%22%7D" def test_contains_any_item(filter_request_builder): builder = filter_request_builder.contains("x", ["a", "b"]) # {a,b} - assert str(builder.params) == "x=cs.%7Ba%2Cb%7D" + assert str(builder.request.params) == "x=cs.%7Ba%2Cb%7D" def test_contains_in_list(filter_request_builder): builder = filter_request_builder.contains("x", '[{"a": "b"}]') # [{"a":+"b"}] (the + represents the space) - assert str(builder.params) == "x=cs.%5B%7B%22a%22%3A+%22b%22%7D%5D" + assert str(builder.request.params) == "x=cs.%5B%7B%22a%22%3A+%22b%22%7D%5D" def test_contained_by_mixed_items(filter_request_builder): builder = filter_request_builder.contained_by("x", ["a", '["b", "c"]']) # {a,["b",+"c"]} - assert str(builder.params) == "x=cd.%7Ba%2C%5B%22b%22%2C+%22c%22%5D%7D" + assert str(builder.request.params) == "x=cd.%7Ba%2C%5B%22b%22%2C+%22c%22%5D%7D" def test_range_greater_than(filter_request_builder): @@ -125,7 +130,10 @@ def test_range_greater_than(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=sr.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + assert ( + str(builder.request.params) + == "x=sr.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + ) def test_range_greater_than_or_equal_to(filter_request_builder): @@ -134,7 +142,10 @@ def test_range_greater_than_or_equal_to(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=nxl.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + assert ( + str(builder.request.params) + == "x=nxl.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + ) def test_range_less_than(filter_request_builder): @@ -143,7 +154,10 @@ def test_range_less_than(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=sl.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + assert ( + str(builder.request.params) + == "x=sl.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + ) def test_range_less_than_or_equal_to(filter_request_builder): @@ -152,7 +166,10 @@ def test_range_less_than_or_equal_to(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=nxr.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + assert ( + str(builder.request.params) + == "x=nxr.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + ) def test_range_adjacent(filter_request_builder): @@ -161,14 +178,17 @@ def test_range_adjacent(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=adj.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + assert ( + str(builder.request.params) + == "x=adj.%282000-01-02+08%3A30%2C2000-01-02+09%3A30%29" + ) def test_overlaps(filter_request_builder): builder = filter_request_builder.overlaps("x", ["is:closed", "severity:high"]) # {a,["b",+"c"]} - assert str(builder.params) == "x=ov.%7Bis%3Aclosed%2Cseverity%3Ahigh%7D" + assert str(builder.request.params) == "x=ov.%7Bis%3Aclosed%2Cseverity%3Ahigh%7D" def test_overlaps_with_timestamp_range(filter_request_builder): @@ -177,68 +197,71 @@ def test_overlaps_with_timestamp_range(filter_request_builder): ) # {a,["b",+"c"]} - assert str(builder.params) == "x=ov.%5B2000-01-01+12%3A45%2C+2000-01-01+13%3A15%29" + assert ( + str(builder.request.params) + == "x=ov.%5B2000-01-01+12%3A45%2C+2000-01-01+13%3A15%29" + ) def test_like(filter_request_builder): builder = filter_request_builder.like("x", "%a%") - assert str(builder.params) == "x=like.%25a%25" + assert str(builder.request.params) == "x=like.%25a%25" def test_ilike(filter_request_builder): builder = filter_request_builder.ilike("x", "%a%") - assert str(builder.params) == "x=ilike.%25a%25" + assert str(builder.request.params) == "x=ilike.%25a%25" def test_like_all_of(filter_request_builder): builder = filter_request_builder.like_all_of("x", "A*,*b") - assert str(builder.params) == "x=like%28all%29.%7BA%2A%2C%2Ab%7D" + assert str(builder.request.params) == "x=like%28all%29.%7BA%2A%2C%2Ab%7D" def test_like_any_of(filter_request_builder): builder = filter_request_builder.like_any_of("x", "a*,*b") - assert str(builder.params) == "x=like%28any%29.%7Ba%2A%2C%2Ab%7D" + assert str(builder.request.params) == "x=like%28any%29.%7Ba%2A%2C%2Ab%7D" def test_ilike_all_of(filter_request_builder): builder = filter_request_builder.ilike_all_of("x", "A*,*b") - assert str(builder.params) == "x=ilike%28all%29.%7BA%2A%2C%2Ab%7D" + assert str(builder.request.params) == "x=ilike%28all%29.%7BA%2A%2C%2Ab%7D" def test_ilike_any_of(filter_request_builder): builder = filter_request_builder.ilike_any_of("x", "A*,*b") - assert str(builder.params) == "x=ilike%28any%29.%7BA%2A%2C%2Ab%7D" + assert str(builder.request.params) == "x=ilike%28any%29.%7BA%2A%2C%2Ab%7D" def test_is_(filter_request_builder): builder = filter_request_builder.is_("x", "a") - assert str(builder.params) == "x=is.a" + assert str(builder.request.params) == "x=is.a" def test_in_(filter_request_builder): builder = filter_request_builder.in_("x", ["a", "b"]) - assert str(builder.params) == "x=in.%28a%2Cb%29" + assert str(builder.request.params) == "x=in.%28a%2Cb%29" def test_or_(filter_request_builder): builder = filter_request_builder.or_("x.eq.1") - assert str(builder.params) == "or=%28x.eq.1%29" + assert str(builder.request.params) == "or=%28x.eq.1%29" def test_or_in_contain(filter_request_builder): builder = filter_request_builder.or_("id.in.(5,6,7), arraycol.cs.{'a','b'}") assert ( - str(builder.params) + str(builder.request.params) == "or=%28id.in.%285%2C6%2C7%29%2C+arraycol.cs.%7B%27a%27%2C%27b%27%7D%29" ) @@ -246,26 +269,29 @@ def test_or_in_contain(filter_request_builder): def test_max_affected(filter_request_builder): builder = filter_request_builder.max_affected(5) - assert builder.headers["prefer"] == "handling=strict,max-affected=5" + assert builder.request.headers["prefer"] == "handling=strict,max-affected=5" def test_max_affected_with_existing_prefer_header(filter_request_builder): # Set an existing prefer header - filter_request_builder.headers["prefer"] = "return=representation" + filter_request_builder.request.headers["prefer"] = "return=representation" builder = filter_request_builder.max_affected(10) assert ( - builder.headers["prefer"] + builder.request.headers["prefer"] == "return=representation,handling=strict,max-affected=10" ) def test_max_affected_with_existing_handling_strict(filter_request_builder): # Set an existing prefer header with handling=strict - filter_request_builder.headers["prefer"] = "handling=strict,return=minimal" + filter_request_builder.request.headers["prefer"] = "handling=strict,return=minimal" builder = filter_request_builder.max_affected(3) - assert builder.headers["prefer"] == "handling=strict,return=minimal,max-affected=3" + assert ( + builder.request.headers["prefer"] + == "handling=strict,return=minimal,max-affected=3" + ) def test_max_affected_returns_self(filter_request_builder): diff --git a/src/postgrest/tests/_sync/test_query_request_builder.py b/src/postgrest/tests/_sync/test_query_request_builder.py index 81b3a3ef..58fbf2fc 100644 --- a/src/postgrest/tests/_sync/test_query_request_builder.py +++ b/src/postgrest/tests/_sync/test_query_request_builder.py @@ -1,22 +1,27 @@ +from typing import Iterable + import pytest from httpx import Client, Headers, QueryParams +from yarl import URL from postgrest import SyncQueryRequestBuilder +from postgrest._sync.request_builder import RequestConfig @pytest.fixture -def query_request_builder(): +def query_request_builder() -> Iterable[SyncQueryRequestBuilder]: with Client() as client: - yield SyncQueryRequestBuilder( - client, "/example_table", "GET", Headers(), QueryParams(), {} + request = RequestConfig( + client, URL("/example_table"), "GET", Headers(), QueryParams(), None, {} ) + yield SyncQueryRequestBuilder(request) def test_constructor(query_request_builder: SyncQueryRequestBuilder): builder = query_request_builder - assert builder.path == "/example_table" - assert len(builder.headers) == 0 - assert len(builder.params) == 0 - assert builder.http_method == "GET" - assert builder.json is None + assert str(builder.request.path) == "/example_table" + assert len(builder.request.headers) == 0 + assert len(builder.request.params) == 0 + assert builder.request.http_method == "GET" + assert builder.request.json is None diff --git a/src/postgrest/tests/_sync/test_request_builder.py b/src/postgrest/tests/_sync/test_request_builder.py index 89b807dd..5217de24 100644 --- a/src/postgrest/tests/_sync/test_request_builder.py +++ b/src/postgrest/tests/_sync/test_request_builder.py @@ -1,52 +1,54 @@ -from typing import Any, Dict, List +from typing import Any, Dict, Iterable, List import pytest -from httpx import Client, Request, Response +from httpx import Client, Headers, QueryParams, Request, Response +from yarl import URL from postgrest import SyncRequestBuilder, SyncSingleRequestBuilder +from postgrest._async.request_builder import RequestConfig from postgrest.base_request_builder import APIResponse, SingleAPIResponse from postgrest.types import JSON, CountMethod @pytest.fixture -def request_builder(): +def request_builder() -> Iterable[SyncRequestBuilder]: with Client() as client: - yield SyncRequestBuilder(client, "/example_table") + yield SyncRequestBuilder(client, URL("/example_table"), Headers(), None) def test_constructor(request_builder): - assert request_builder.path == "/example_table" + assert str(request_builder.path) == "/example_table" class TestSelect: def test_select(self, request_builder: SyncRequestBuilder): builder = request_builder.select("col1", "col2") - assert builder.params["select"] == "col1,col2" - assert builder.headers.get("prefer") is None - assert builder.http_method == "GET" - assert builder.json is None + assert builder.request.params["select"] == "col1,col2" + assert builder.request.headers.get("prefer") is None + assert builder.request.http_method == "GET" + assert builder.request.json is None def test_select_with_count(self, request_builder: SyncRequestBuilder): builder = request_builder.select(count=CountMethod.exact) - assert builder.params["select"] == "*" - assert builder.headers["prefer"] == "count=exact" - assert builder.http_method == "GET" - assert builder.json is None + assert builder.request.params["select"] == "*" + assert builder.request.headers["prefer"] == "count=exact" + assert builder.request.http_method == "GET" + assert builder.request.json is None def test_select_with_head(self, request_builder: SyncRequestBuilder): builder = request_builder.select("col1", "col2", head=True) - assert builder.params.get("select") == "col1,col2" - assert builder.headers.get("prefer") is None - assert builder.http_method == "HEAD" - assert builder.json is None + assert builder.request.params.get("select") == "col1,col2" + assert builder.request.headers.get("prefer") is None + assert builder.request.http_method == "HEAD" + assert builder.request.json is None def test_select_as_csv(self, request_builder: SyncRequestBuilder): builder = request_builder.select("*").csv() - assert builder.headers["Accept"] == "text/csv" + assert builder.request.headers["Accept"] == "text/csv" assert isinstance(builder, SyncSingleRequestBuilder) @@ -54,77 +56,85 @@ class TestInsert: def test_insert(self, request_builder: SyncRequestBuilder): builder = request_builder.insert({"key1": "val1"}) - assert builder.headers.get_list("prefer", True) == ["return=representation"] - assert builder.http_method == "POST" - assert builder.json == {"key1": "val1"} + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation" + ] + assert builder.request.http_method == "POST" + assert builder.request.json == {"key1": "val1"} def test_insert_with_count(self, request_builder: SyncRequestBuilder): builder = request_builder.insert({"key1": "val1"}, count=CountMethod.exact) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "count=exact", ] - assert builder.http_method == "POST" - assert builder.json == {"key1": "val1"} + assert builder.request.http_method == "POST" + assert builder.request.json == {"key1": "val1"} def test_insert_with_upsert(self, request_builder: SyncRequestBuilder): builder = request_builder.insert({"key1": "val1"}, upsert=True) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "resolution=merge-duplicates", ] - assert builder.http_method == "POST" - assert builder.json == {"key1": "val1"} + assert builder.request.http_method == "POST" + assert builder.request.json == {"key1": "val1"} def test_upsert_with_default_single(self, request_builder: SyncRequestBuilder): builder = request_builder.upsert([{"key1": "val1"}], default_to_null=False) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "resolution=merge-duplicates", "missing=default", ] - assert builder.http_method == "POST" - assert builder.json == [{"key1": "val1"}] - assert builder.params.get("columns") == '"key1"' + assert builder.request.http_method == "POST" + assert builder.request.json == [{"key1": "val1"}] + assert builder.request.params.get("columns") == '"key1"' def test_bulk_insert_using_default(self, request_builder: SyncRequestBuilder): builder = request_builder.insert( [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False ) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "missing=default", ] - assert builder.http_method == "POST" - assert builder.json == [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}] - assert set(builder.params["columns"].split(",")) == set( + assert builder.request.http_method == "POST" + assert builder.request.json == [ + {"key1": "val1", "key2": "val2"}, + {"key3": "val3"}, + ] + assert set(builder.request.params["columns"].split(",")) == set( '"key1","key2","key3"'.split(",") ) def test_upsert(self, request_builder: SyncRequestBuilder): builder = request_builder.upsert({"key1": "val1"}) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "resolution=merge-duplicates", ] - assert builder.http_method == "POST" - assert builder.json == {"key1": "val1"} + assert builder.request.http_method == "POST" + assert builder.request.json == {"key1": "val1"} def test_bulk_upsert_with_default(self, request_builder: SyncRequestBuilder): builder = request_builder.upsert( [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False ) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "resolution=merge-duplicates", "missing=default", ] - assert builder.http_method == "POST" - assert builder.json == [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}] - assert set(builder.params["columns"].split(",")) == set( + assert builder.request.http_method == "POST" + assert builder.request.json == [ + {"key1": "val1", "key2": "val2"}, + {"key3": "val3"}, + ] + assert set(builder.request.params["columns"].split(",")) == set( '"key1","key2","key3"'.split(",") ) @@ -133,56 +143,60 @@ class TestUpdate: def test_update(self, request_builder: SyncRequestBuilder): builder = request_builder.update({"key1": "val1"}) - assert builder.headers.get_list("prefer", True) == ["return=representation"] - assert builder.http_method == "PATCH" - assert builder.json == {"key1": "val1"} + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation" + ] + assert builder.request.http_method == "PATCH" + assert builder.request.json == {"key1": "val1"} def test_update_with_count(self, request_builder: SyncRequestBuilder): builder = request_builder.update({"key1": "val1"}, count=CountMethod.exact) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "count=exact", ] - assert builder.http_method == "PATCH" - assert builder.json == {"key1": "val1"} + assert builder.request.http_method == "PATCH" + assert builder.request.json == {"key1": "val1"} def test_update_with_max_affected(self, request_builder: SyncRequestBuilder): builder = request_builder.update({"key1": "val1"}).max_affected(5) - assert "handling=strict" in builder.headers["prefer"] - assert "max-affected=5" in builder.headers["prefer"] - assert "return=representation" in builder.headers["prefer"] - assert builder.http_method == "PATCH" - assert builder.json == {"key1": "val1"} + assert "handling=strict" in builder.request.headers["prefer"] + assert "max-affected=5" in builder.request.headers["prefer"] + assert "return=representation" in builder.request.headers["prefer"] + assert builder.request.http_method == "PATCH" + assert builder.request.json == {"key1": "val1"} class TestDelete: def test_delete(self, request_builder: SyncRequestBuilder): builder = request_builder.delete() - assert builder.headers.get_list("prefer", True) == ["return=representation"] - assert builder.http_method == "DELETE" - assert builder.json == {} + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation" + ] + assert builder.request.http_method == "DELETE" + assert builder.request.json == {} def test_delete_with_count(self, request_builder: SyncRequestBuilder): builder = request_builder.delete(count=CountMethod.exact) - assert builder.headers.get_list("prefer", True) == [ + assert builder.request.headers.get_list("prefer", True) == [ "return=representation", "count=exact", ] - assert builder.http_method == "DELETE" - assert builder.json == {} + assert builder.request.http_method == "DELETE" + assert builder.request.json == {} def test_delete_with_max_affected(self, request_builder: SyncRequestBuilder): builder = request_builder.delete().max_affected(10) - assert "handling=strict" in builder.headers["prefer"] - assert "max-affected=10" in builder.headers["prefer"] - assert "return=representation" in builder.headers["prefer"] - assert builder.http_method == "DELETE" - assert builder.json == {} + assert "handling=strict" in builder.request.headers["prefer"] + assert "max-affected=10" in builder.request.headers["prefer"] + assert "return=representation" in builder.request.headers["prefer"] + assert builder.request.http_method == "DELETE" + assert builder.request.json == {} class TestTextSearch: @@ -196,31 +210,35 @@ def test_text_search(self, request_builder: SyncRequestBuilder): }, ) assert "catchphrase=plfts%28english%29.%27fat%27+%26+%27cat%27" in str( - builder.params + builder.request.params ) class TestExplain: def test_explain_plain(self, request_builder: SyncRequestBuilder): builder = request_builder.select("*").explain() - assert builder.params["select"] == "*" - assert "application/vnd.pgrst.plan" in str(builder.headers.get("accept")) + assert builder.request.params["select"] == "*" + assert "application/vnd.pgrst.plan" in str( + builder.request.headers.get("accept") + ) def test_explain_options(self, request_builder: SyncRequestBuilder): builder = request_builder.select("*").explain( format="json", analyze=True, verbose=True, buffers=True, wal=True ) - assert builder.params["select"] == "*" - assert "application/vnd.pgrst.plan+json;" in str(builder.headers.get("accept")) + assert builder.request.params["select"] == "*" + assert "application/vnd.pgrst.plan+json;" in str( + builder.request.headers.get("accept") + ) assert "options=analyze|verbose|buffers|wal" in str( - builder.headers.get("accept") + builder.request.headers.get("accept") ) class TestOrder: def test_order(self, request_builder: SyncRequestBuilder): builder = request_builder.select().order("country_name", desc=True) - assert str(builder.params) == "select=%2A&order=country_name.desc" + assert str(builder.request.params) == "select=%2A&order=country_name.desc" def test_multiple_orders(self, request_builder: SyncRequestBuilder): builder = ( @@ -228,7 +246,10 @@ def test_multiple_orders(self, request_builder: SyncRequestBuilder): .order("country_name", desc=True) .order("iso", desc=True) ) - assert str(builder.params) == "select=%2A&order=country_name.desc%2Ciso.desc" + assert ( + str(builder.request.params) + == "select=%2A&order=country_name.desc%2Ciso.desc" + ) def test_multiple_orders_on_foreign_table( self, request_builder: SyncRequestBuilder @@ -239,22 +260,25 @@ def test_multiple_orders_on_foreign_table( .order("city_name", desc=True, foreign_table=foreign_table) .order("id", desc=True, foreign_table=foreign_table) ) - assert str(builder.params) == "select=%2A&cities.order=city_name.desc%2Cid.desc" + assert ( + str(builder.request.params) + == "select=%2A&cities.order=city_name.desc%2Cid.desc" + ) class TestRange: def test_range_on_own_table(self, request_builder: SyncRequestBuilder): builder = request_builder.select("*").range(0, 1) - assert builder.params["select"] == "*" - assert builder.params["limit"] == "2" - assert builder.params["offset"] == "0" + assert builder.request.params["select"] == "*" + assert builder.request.params["limit"] == "2" + assert builder.request.params["offset"] == "0" def test_range_on_foreign_table(self, request_builder: SyncRequestBuilder): foreign_table = "cities" builder = request_builder.select("*").range(1, 2, foreign_table) - assert builder.params["select"] == "*" - assert builder.params[f"{foreign_table}.limit"] == "2" - assert builder.params[f"{foreign_table}.offset"] == "1" + assert builder.request.params["select"] == "*" + assert builder.request.params[f"{foreign_table}.limit"] == "2" + assert builder.request.params[f"{foreign_table}.offset"] == "1" @pytest.fixture diff --git a/src/supabase/tests/_async/test_client.py b/src/supabase/tests/_async/test_client.py index fc70f382..6bd28dff 100644 --- a/src/supabase/tests/_async/test_client.py +++ b/src/supabase/tests/_async/test_client.py @@ -239,3 +239,46 @@ async def test_custom_headers_immutable(): assert client1.options.headers.get("x-app-name") == "grapes" assert client1.options.headers.get("x-version") == "1.0" assert client2.options.headers.get("x-app-name") == "apple" + + +async def test_httpx_client_base_url_isolation(): + """Test that shared httpx_client doesn't cause base_url mutation between services. + This test reproduces the issue where accessing PostgREST after Storage causes + Storage requests to hit the wrong endpoint (404 errors). + See: https://github.com/supabase/supabase-py/issues/1244 + """ + url = os.environ.get("SUPABASE_TEST_URL") + key = os.environ.get("SUPABASE_TEST_KEY") + + # Create client with shared httpx instance + timeout = Timeout(10.0, read=60.0) + httpx_client = AsyncHttpxClient(timeout=timeout) + options = AsyncClientOptions(httpx_client=httpx_client) + client = await create_async_client(url, key, options) + + # Access storage and capture its base_url + storage = client.storage + storage_base_url = str(storage._base_url).rstrip("/") + assert storage_base_url.endswith("/storage/v1"), ( + f"Expected storage base_url to end with '/storage/v1', got {storage_base_url}" + ) + + # Access postgrest (this should NOT mutate storage's base_url) + postgrest = client.postgrest + postgrest_base_url = str(postgrest.base_url).rstrip("/") + assert postgrest_base_url.endswith("/rest/v1"), ( + f"Expected postgrest base_url to end with '/rest/v1', got {postgrest_base_url}" + ) + + # Verify storage still has the correct base_url + storage_base_url_after = str(storage._base_url).rstrip("/") + assert storage_base_url_after.endswith("/storage/v1"), ( + f"Storage base_url was mutated! Expected '/storage/v1', got {storage_base_url_after}" + ) + + assert str(storage._base_url).rstrip("/").endswith("/storage/v1"), ( + "Storage base_url was mutated after accessing functions" + ) + assert str(postgrest.base_url).rstrip("/").endswith("/rest/v1"), ( + "PostgREST base_url was mutated after accessing functions" + ) diff --git a/uv.lock b/uv.lock index fef9ef94..5e4ba1dd 100644 --- a/uv.lock +++ b/uv.lock @@ -1540,6 +1540,7 @@ dependencies = [ { name = "httpx", extra = ["http2"] }, { name = "pydantic" }, { name = "strenum", marker = "python_full_version < '3.11'" }, + { name = "yarl" }, ] [package.dev-dependencies] @@ -1580,6 +1581,7 @@ requires-dist = [ { name = "httpx", extras = ["http2"], specifier = ">=0.26,<0.29" }, { name = "pydantic", specifier = ">=1.9,<3.0" }, { name = "strenum", marker = "python_full_version < '3.11'", specifier = ">=0.4.9" }, + { name = "yarl", specifier = ">=1.20.1" }, ] [package.metadata.requires-dev] From 05bb433b289b78a6b26fa92a50d0f67b9ca39e6c Mon Sep 17 00:00:00 2001 From: Leonardo Santiago Date: Tue, 7 Oct 2025 13:37:25 -0300 Subject: [PATCH 3/5] chore(postgrest): remove print --- src/postgrest/tests/_async/test_client.py | 1 - src/postgrest/tests/_sync/test_client.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/postgrest/tests/_async/test_client.py b/src/postgrest/tests/_async/test_client.py index 292a0401..18a95627 100644 --- a/src/postgrest/tests/_async/test_client.py +++ b/src/postgrest/tests/_async/test_client.py @@ -35,7 +35,6 @@ def test_simple(self, postgrest_client: AsyncPostgrestClient): "Content-Profile": "public", } ) - print(session.headers) assert session.headers.items() >= headers.items() @pytest.mark.asyncio diff --git a/src/postgrest/tests/_sync/test_client.py b/src/postgrest/tests/_sync/test_client.py index c2fdb771..7f7c7610 100644 --- a/src/postgrest/tests/_sync/test_client.py +++ b/src/postgrest/tests/_sync/test_client.py @@ -35,7 +35,6 @@ def test_simple(self, postgrest_client: SyncPostgrestClient): "Content-Profile": "public", } ) - print(session.headers) assert session.headers.items() >= headers.items() def test_custom_headers(self): From 304ccbb33bed9377f63b4c19823f35d1d3ff3b8e Mon Sep 17 00:00:00 2001 From: Leonardo Santiago Date: Tue, 7 Oct 2025 13:53:14 -0300 Subject: [PATCH 4/5] fix: remove unused imports --- src/postgrest/src/postgrest/_async/client.py | 1 - src/postgrest/src/postgrest/_sync/client.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/postgrest/src/postgrest/_async/client.py b/src/postgrest/src/postgrest/_async/client.py index 6e9c8061..00438141 100644 --- a/src/postgrest/src/postgrest/_async/client.py +++ b/src/postgrest/src/postgrest/_async/client.py @@ -5,7 +5,6 @@ from deprecation import deprecated from httpx import AsyncClient, Headers, QueryParams, Timeout -from httpx._types import HeaderTypes from yarl import URL from ..base_client import BasePostgrestClient diff --git a/src/postgrest/src/postgrest/_sync/client.py b/src/postgrest/src/postgrest/_sync/client.py index 9df8ae8f..9c94f6d1 100644 --- a/src/postgrest/src/postgrest/_sync/client.py +++ b/src/postgrest/src/postgrest/_sync/client.py @@ -5,7 +5,6 @@ from deprecation import deprecated from httpx import Client, Headers, QueryParams, Timeout -from httpx._types import HeaderTypes from yarl import URL from ..base_client import BasePostgrestClient From ee2247db6a74b4b6661389c33274f5ca9935ab36 Mon Sep 17 00:00:00 2001 From: Leonardo Santiago Date: Tue, 7 Oct 2025 15:13:41 -0300 Subject: [PATCH 5/5] fix: do not mutate functions URL --- src/functions/pyproject.toml | 1 + .../_async/functions_client.py | 37 ++++++++--------- .../_sync/functions_client.py | 41 +++++++++---------- .../tests/_async/test_function_client.py | 2 +- .../tests/_sync/test_function_client.py | 2 +- src/functions/tests/test_client.py | 4 +- uv.lock | 2 + 7 files changed, 43 insertions(+), 46 deletions(-) diff --git a/src/functions/pyproject.toml b/src/functions/pyproject.toml index 75ebb249..7497bba1 100644 --- a/src/functions/pyproject.toml +++ b/src/functions/pyproject.toml @@ -15,6 +15,7 @@ requires-python = ">=3.9" dependencies = [ "httpx[http2] >=0.26,<0.29", "strenum >=0.4.15", + "yarl>=1.20.1", ] diff --git a/src/functions/src/supabase_functions/_async/functions_client.py b/src/functions/src/supabase_functions/_async/functions_client.py index ea2c502c..f35cce58 100644 --- a/src/functions/src/supabase_functions/_async/functions_client.py +++ b/src/functions/src/supabase_functions/_async/functions_client.py @@ -2,6 +2,7 @@ from warnings import warn from httpx import AsyncClient, HTTPError, Response, QueryParams +from yarl import URL from ..errors import FunctionsHttpError, FunctionsRelayError from ..utils import ( @@ -24,7 +25,7 @@ def __init__( ): if not is_http_url(url): raise ValueError("url must be a valid HTTP URL string") - self.url = url + self.url = URL(url) self.headers = { "User-Agent": f"supabase-py/functions-py v{__version__}", **headers, @@ -51,37 +52,32 @@ def __init__( self.verify = bool(verify) if verify is not None else True self.timeout = int(abs(timeout)) if timeout is not None else 60 - - if http_client is not None: - http_client.base_url = self.url - http_client.headers.update({**self.headers}) - self._client = http_client - else: - self._client = AsyncClient( - base_url=self.url, - headers=self.headers, - verify=self.verify, - timeout=self.timeout, - proxy=proxy, - follow_redirects=True, - http2=True, - ) + self._client = http_client or AsyncClient( + verify=self.verify, + timeout=self.timeout, + proxy=proxy, + follow_redirects=True, + http2=True, + ) async def _request( self, method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], - url: str, + path: list[str], headers: Optional[Dict[str, str]] = None, json: Optional[Dict[Any, Any]] = None, params: Optional[QueryParams] = None, ) -> Response: + url = self.url.joinpath(*path) + headers = headers or dict() + headers.update(self.headers) response = ( await self._client.request( - method, url, data=json, headers=headers, params=params + method, str(url), data=json, headers=headers, params=params ) if isinstance(json, str) else await self._client.request( - method, url, json=json, headers=headers, params=params + method, str(url), json=json, headers=headers, params=params ) ) try: @@ -129,7 +125,6 @@ async def invoke( params = QueryParams() body = None response_type = "text/plain" - url = f"{self.url}/{function_name}" if invoke_options is not None: headers.update(invoke_options.get("headers", {})) @@ -153,7 +148,7 @@ async def invoke( headers["Content-Type"] = "application/json" response = await self._request( - "POST", url, headers=headers, json=body, params=params + "POST", [function_name], headers=headers, json=body, params=params ) is_relay_error = response.headers.get("x-relay-header") diff --git a/src/functions/src/supabase_functions/_sync/functions_client.py b/src/functions/src/supabase_functions/_sync/functions_client.py index 8063a7d3..a2c12455 100644 --- a/src/functions/src/supabase_functions/_sync/functions_client.py +++ b/src/functions/src/supabase_functions/_sync/functions_client.py @@ -2,6 +2,7 @@ from warnings import warn from httpx import Client, HTTPError, Response, QueryParams +from yarl import URL from ..errors import FunctionsHttpError, FunctionsRelayError from ..utils import ( @@ -24,7 +25,7 @@ def __init__( ): if not is_http_url(url): raise ValueError("url must be a valid HTTP URL string") - self.url = url + self.url = URL(url) self.headers = { "User-Agent": f"supabase-py/functions-py v{__version__}", **headers, @@ -51,35 +52,32 @@ def __init__( self.verify = bool(verify) if verify is not None else True self.timeout = int(abs(timeout)) if timeout is not None else 60 - - if http_client is not None: - http_client.base_url = self.url - http_client.headers.update({**self.headers}) - self._client = http_client - else: - self._client = Client( - base_url=self.url, - headers=self.headers, - verify=self.verify, - timeout=self.timeout, - proxy=proxy, - follow_redirects=True, - http2=True, - ) + self._client = http_client or Client( + verify=self.verify, + timeout=self.timeout, + proxy=proxy, + follow_redirects=True, + http2=True, + ) def _request( self, method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], - url: str, + path: list[str], headers: Optional[Dict[str, str]] = None, json: Optional[Dict[Any, Any]] = None, params: Optional[QueryParams] = None, ) -> Response: + url = self.url.joinpath(*path) + headers = headers or dict() + headers.update(self.headers) response = ( - self._client.request(method, url, data=json, headers=headers, params=params) + self._client.request( + method, str(url), data=json, headers=headers, params=params + ) if isinstance(json, str) else self._client.request( - method, url, json=json, headers=headers, params=params + method, str(url), json=json, headers=headers, params=params ) ) try: @@ -127,7 +125,6 @@ def invoke( params = QueryParams() body = None response_type = "text/plain" - url = f"{self.url}/{function_name}" if invoke_options is not None: headers.update(invoke_options.get("headers", {})) @@ -150,7 +147,9 @@ def invoke( elif isinstance(body, dict): headers["Content-Type"] = "application/json" - response = self._request("POST", url, headers=headers, json=body, params=params) + response = self._request( + "POST", [function_name], headers=headers, json=body, params=params + ) is_relay_error = response.headers.get("x-relay-header") if is_relay_error and is_relay_error == "true": diff --git a/src/functions/tests/_async/test_function_client.py b/src/functions/tests/_async/test_function_client.py index d1f47b97..b359b4b6 100644 --- a/src/functions/tests/_async/test_function_client.py +++ b/src/functions/tests/_async/test_function_client.py @@ -31,7 +31,7 @@ async def test_init_with_valid_params(valid_url, default_headers): client = AsyncFunctionsClient( url=valid_url, headers=default_headers, timeout=10, verify=True ) - assert client.url == valid_url + assert str(client.url) == valid_url assert "User-Agent" in client.headers assert client.headers["User-Agent"] == f"supabase-py/functions-py v{__version__}" assert client._client.timeout == Timeout(10) diff --git a/src/functions/tests/_sync/test_function_client.py b/src/functions/tests/_sync/test_function_client.py index 6489e91a..1ed4cc98 100644 --- a/src/functions/tests/_sync/test_function_client.py +++ b/src/functions/tests/_sync/test_function_client.py @@ -31,7 +31,7 @@ def test_init_with_valid_params(valid_url, default_headers): client = SyncFunctionsClient( url=valid_url, headers=default_headers, timeout=10, verify=True ) - assert client.url == valid_url + assert str(client.url) == valid_url assert "User-Agent" in client.headers assert client.headers["User-Agent"] == f"supabase-py/functions-py v{__version__}" assert client._client.timeout == Timeout(10) diff --git a/src/functions/tests/test_client.py b/src/functions/tests/test_client.py index 60929098..804f3708 100644 --- a/src/functions/tests/test_client.py +++ b/src/functions/tests/test_client.py @@ -22,7 +22,7 @@ def test_create_async_client(valid_url, valid_headers): ) assert isinstance(client, AsyncFunctionsClient) - assert client.url == valid_url + assert str(client.url) == valid_url assert all(client.headers[key] == value for key, value in valid_headers.items()) @@ -33,7 +33,7 @@ def test_create_sync_client(valid_url, valid_headers): ) assert isinstance(client, SyncFunctionsClient) - assert client.url == valid_url + assert str(client.url) == valid_url assert all(client.headers[key] == value for key, value in valid_headers.items()) diff --git a/uv.lock b/uv.lock index 5e4ba1dd..af8fc7fd 100644 --- a/uv.lock +++ b/uv.lock @@ -2833,6 +2833,7 @@ source = { editable = "src/functions" } dependencies = [ { name = "httpx", extra = ["http2"] }, { name = "strenum" }, + { name = "yarl" }, ] [package.dev-dependencies] @@ -2865,6 +2866,7 @@ tests = [ requires-dist = [ { name = "httpx", extras = ["http2"], specifier = ">=0.26,<0.29" }, { name = "strenum", specifier = ">=0.4.15" }, + { name = "yarl", specifier = ">=1.20.1" }, ] [package.metadata.requires-dev]