Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/postgrest/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ unasync:

build-sync: unasync
sed -i 's/@pytest.mark.asyncio//g' tests/_sync/test_client.py
sed -i 's/_async/_sync/g' tests/_sync/test_client.py
sed -i 's/_async/_sync/g' tests/_sync/test_client.py tests/_sync/test_query_request_builder.py tests/_sync/test_filter_request_builder.py
sed -i 's/Async/Sync/g' src/postgrest/_sync/request_builder.py tests/_sync/test_client.py
sed -i 's/_client\.SyncClient/_client\.Client/g' tests/_sync/test_client.py
sed -i 's/SyncHTTPTransport/HTTPTransport/g' tests/_sync/**.py
Expand Down
1 change: 1 addition & 0 deletions src/postgrest/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"deprecation >=2.1.0",
"pydantic >=1.9,<3.0",
"strenum >=0.4.9; python_version < \"3.11\"",
"yarl>=1.20.1",
]

[project.urls]
Expand Down
75 changes: 33 additions & 42 deletions src/postgrest/src/postgrest/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from deprecation import deprecated
from httpx import AsyncClient, Headers, QueryParams, Timeout
from httpx._types import HeaderTypes
from yarl import URL

from ..base_client import BasePostgrestClient
from ..constants import (
Expand All @@ -13,7 +15,11 @@
)
from ..types import CountMethod
from ..version import __version__
from .request_builder import AsyncRequestBuilder, AsyncRPCFilterRequestBuilder
from .request_builder import (
AsyncRequestBuilder,
AsyncRPCFilterRequestBuilder,
RequestConfig,
)


class AsyncPostgrestClient(BasePostgrestClient):
Expand Down Expand Up @@ -59,52 +65,32 @@ def __init__(
else DEFAULT_POSTGREST_CLIENT_TIMEOUT
)
)

BasePostgrestClient.__init__(
self,
base_url,
URL(base_url),
schema=schema,
headers=headers,
timeout=self.timeout,
verify=self.verify,
proxy=proxy,
http_client=http_client,
)
self.session: AsyncClient = self.session

def create_session(
self,
base_url: str,
headers: Dict[str, str],
timeout: Union[int, float, Timeout],
verify: bool = True,
proxy: Optional[str] = None,
) -> AsyncClient:
http_client = None
if isinstance(self.http_client, AsyncClient):
http_client = self.http_client

if http_client is not None:
http_client.base_url = base_url
http_client.headers.update({**headers})
return http_client

return AsyncClient(
self.session = http_client or AsyncClient(
base_url=base_url,
headers=headers,
headers=self.headers,
timeout=timeout,
verify=verify,
verify=self.verify,
proxy=proxy,
follow_redirects=True,
http2=True,
)

def schema(self, schema: str):
def schema(self, schema: str) -> AsyncPostgrestClient:
"""Switch to another schema."""
return AsyncPostgrestClient(
base_url=self.base_url,
base_url=str(self.base_url),
schema=schema,
headers=self.headers,
headers=dict(self.headers),
timeout=self.timeout,
verify=self.verify,
proxy=self.proxy,
Expand All @@ -128,7 +114,9 @@ def from_(self, table: str) -> AsyncRequestBuilder:
Returns:
:class:`AsyncRequestBuilder`
"""
return AsyncRequestBuilder(self.session, f"/{table}")
return AsyncRequestBuilder(
self.session, self.base_url.joinpath(table), self.headers, self.basic_auth
)

def table(self, table: str) -> AsyncRequestBuilder:
"""Alias to :meth:`from_`."""
Expand All @@ -142,7 +130,7 @@ def from_table(self, table: str) -> AsyncRequestBuilder:
def rpc(
self,
func: str,
params: dict,
params: dict[str, str],
count: Optional[CountMethod] = None,
head: bool = False,
get: bool = False,
Expand Down Expand Up @@ -171,17 +159,20 @@ def rpc(
method = "HEAD" if head else "GET" if get else "POST"

headers = Headers({"Prefer": f"count={count}"}) if count else Headers()

if method in ("HEAD", "GET"):
return AsyncRPCFilterRequestBuilder(
self.session,
f"/rpc/{func}",
method,
headers,
QueryParams(params),
json={},
)
headers.update(self.headers)
# the params here are params to be sent to the RPC and not the queryparams!
return AsyncRPCFilterRequestBuilder(
self.session, f"/rpc/{func}", method, headers, QueryParams(), json=params
json, http_params = (
({}, QueryParams(params))
if method in ("HEAD", "GET")
else (params, QueryParams())
)
request = RequestConfig(
self.session,
self.base_url.joinpath("rpc", func),
method,
headers,
http_params,
self.basic_auth,
json,
)
return AsyncRPCFilterRequestBuilder(request)
Loading