diff --git a/src/postgrest/Makefile b/src/postgrest/Makefile index 228ed2c6..f0d84a6e 100644 --- a/src/postgrest/Makefile +++ b/src/postgrest/Makefile @@ -1,26 +1,47 @@ -tests: pytest +help:: + @echo "Available commands" + @echo " help -- (default) print this message" + +tests: mypy pytest +help:: + @echo " tests -- run all tests for postgrest package" pytest: start-infra uv run --package postgrest pytest --cov=./ --cov-report=xml -vv +help:: + @echo " pytest -- run pytest on postgrest package" + +mypy: + uv run --package supabase_functions mypy src/postgrest tests +help:: + @echo " mypy -- run mypy on postgrest package" start-infra: cd infra &&\ docker compose down &&\ docker compose up -d sleep 2 +help:: + @echo " stop-infra -- start containers for tests" + +stop-infra: + cd infra &&\ + docker compose down --remove-orphans +help:: + @echo " stop-infra -- stop containers for tests" clean-infra: cd infra &&\ docker compose down --remove-orphans &&\ docker system prune -a --volumes -f - -stop-infra: - cd infra &&\ - docker compose down --remove-orphans +help:: + @echo " clean-infra -- delete all stored information about the containers" clean: rm -rf htmlcov .pytest_cache .mypy_cache .ruff_cache rm -f .coverage coverage.xml +help:: + @echo " clean -- clean intermediary files generated by tests" unasync: uv run --package postgrest run-unasync.py @@ -33,6 +54,10 @@ build-sync: unasync sed -i 's/SyncHTTPTransport/HTTPTransport/g' tests/_sync/**.py sed -i 's/SyncClient/Client/g' src/postgrest/_sync/**.py tests/_sync/**.py sed -i 's/self\.session\.aclose/self\.session\.close/g' src/postgrest/_sync/client.py +help:: + @echo " build-sync -- generate _sync from _async implementation" build: uv build --package postgrest +help:: + @echo " build -- invoke uv build on storage3 package" diff --git a/src/postgrest/pyproject.toml b/src/postgrest/pyproject.toml index 6291cc44..a67a4cd2 100644 --- a/src/postgrest/pyproject.toml +++ b/src/postgrest/pyproject.toml @@ -43,6 +43,9 @@ test = [ lints = [ "pre-commit >=4.2.0", "ruff >=0.12.1", + "python-lsp-server (>=1.12.2,<2.0.0)", + "pylsp-mypy (>=0.7.0,<0.8.0)", + "python-lsp-ruff (>=2.2.2,<3.0.0)", ] docs = [ "sphinx >=7.1.2", @@ -84,3 +87,15 @@ filterwarnings = [ [build-system] requires = ["uv_build>=0.8.3,<0.9.0"] build-backend = "uv_build" + +[tool.mypy] +python_version = "3.9" +check_untyped_defs = true +allow_redefinition = true +follow_untyped_imports = true # for deprecation module that does not have stubs + +no_warn_no_return = true +warn_return_any = true +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true diff --git a/src/postgrest/run-unasync.py b/src/postgrest/run-unasync.py index 034cfdc4..a0b811dd 100644 --- a/src/postgrest/run-unasync.py +++ b/src/postgrest/run-unasync.py @@ -2,7 +2,7 @@ import unasync -paths = Path("src/supabase").glob("**/*.py") +paths = Path("src/postgrest").glob("**/*.py") tests = Path("tests").glob("**/*.py") rules = (unasync._DEFAULT_RULE,) diff --git a/src/postgrest/src/postgrest/_async/client.py b/src/postgrest/src/postgrest/_async/client.py index 63fb95e5..896e7494 100644 --- a/src/postgrest/src/postgrest/_async/client.py +++ b/src/postgrest/src/postgrest/_async/client.py @@ -15,8 +15,6 @@ from ..version import __version__ from .request_builder import AsyncRequestBuilder, AsyncRPCFilterRequestBuilder -_TableT = Dict[str, Any] - class AsyncPostgrestClient(BasePostgrestClient): """PostgREST client.""" @@ -72,7 +70,7 @@ def __init__( proxy=proxy, http_client=http_client, ) - self.session = cast(AsyncClient, self.session) + self.session: AsyncClient = self.session def create_session( self, @@ -122,7 +120,7 @@ async def aclose(self) -> None: """Close the underlying HTTP connections.""" await self.session.aclose() - def from_(self, table: str) -> AsyncRequestBuilder[_TableT]: + def from_(self, table: str) -> AsyncRequestBuilder: """Perform a table operation. Args: @@ -130,9 +128,9 @@ def from_(self, table: str) -> AsyncRequestBuilder[_TableT]: Returns: :class:`AsyncRequestBuilder` """ - return AsyncRequestBuilder[_TableT](self.session, f"/{table}") + return AsyncRequestBuilder(self.session, f"/{table}") - def table(self, table: str) -> AsyncRequestBuilder[_TableT]: + def table(self, table: str) -> AsyncRequestBuilder: """Alias to :meth:`from_`.""" return self.from_(table) @@ -148,7 +146,7 @@ def rpc( count: Optional[CountMethod] = None, head: bool = False, get: bool = False, - ) -> AsyncRPCFilterRequestBuilder[Any]: + ) -> AsyncRPCFilterRequestBuilder: """Perform a stored procedure call. Args: @@ -175,7 +173,7 @@ def rpc( headers = Headers({"Prefer": f"count={count}"}) if count else Headers() if method in ("HEAD", "GET"): - return AsyncRPCFilterRequestBuilder[Any]( + return AsyncRPCFilterRequestBuilder( self.session, f"/rpc/{func}", method, @@ -184,6 +182,6 @@ def rpc( json={}, ) # the params here are params to be sent to the RPC and not the queryparams! - return AsyncRPCFilterRequestBuilder[Any]( + return AsyncRPCFilterRequestBuilder( self.session, f"/rpc/{func}", method, headers, QueryParams(), json=params ) diff --git a/src/postgrest/src/postgrest/_async/request_builder.py b/src/postgrest/src/postgrest/_async/request_builder.py index 3c2f5949..106845fa 100644 --- a/src/postgrest/src/postgrest/_async/request_builder.py +++ b/src/postgrest/src/postgrest/_async/request_builder.py @@ -4,6 +4,7 @@ from httpx import AsyncClient, Headers, QueryParams from pydantic import ValidationError +from typing_extensions import override from ..base_request_builder import ( APIResponse, @@ -19,13 +20,11 @@ pre_upsert, ) from ..exceptions import APIError, APIErrorFromJSON, generate_default_error_message -from ..types import ReturnMethod -from ..utils import get_origin_and_cast, model_validate_json +from ..types import JSON, ReturnMethod +from ..utils import model_validate_json -_ReturnT = TypeVar("_ReturnT") - -class AsyncQueryRequestBuilder(Generic[_ReturnT]): +class AsyncQueryRequestBuilder: def __init__( self, session: AsyncClient, @@ -33,7 +32,7 @@ def __init__( http_method: str, headers: Headers, params: QueryParams, - json: dict, + json: JSON, ) -> None: self.session = session self.path = path @@ -42,7 +41,7 @@ def __init__( self.params = params self.json = None if http_method in {"GET", "HEAD"} else json - async def execute(self) -> APIResponse[_ReturnT]: + async def execute(self) -> APIResponse | str: """Execute the query. .. tip:: @@ -72,7 +71,7 @@ async def execute(self) -> APIResponse[_ReturnT]: ) and "application/vnd.pgrst.plan" in self.headers.get("Accept"): if "+json" not in self.headers.get("Accept"): return body - return APIResponse[_ReturnT].from_http_request_response(r) + return APIResponse.from_http_request_response(r) else: json_obj = model_validate_json(APIErrorFromJSON, r.content) raise APIError(dict(json_obj)) @@ -80,7 +79,7 @@ async def execute(self) -> APIResponse[_ReturnT]: raise APIError(generate_default_error_message(r)) -class AsyncSingleRequestBuilder(Generic[_ReturnT]): +class AsyncSingleRequestBuilder: def __init__( self, session: AsyncClient, @@ -88,7 +87,7 @@ def __init__( http_method: str, headers: Headers, params: QueryParams, - json: dict, + json: JSON, ) -> None: self.session = session self.path = path @@ -97,17 +96,17 @@ def __init__( self.params = params self.json = json - async def execute(self) -> SingleAPIResponse[_ReturnT]: + async def execute(self) -> SingleAPIResponse: """Execute the query. - .. tip:: - This is the last method called, after the query is built. + .. tip:: + This is the last method called, after the query is built. - Returns: - :class:`SingleAPIResponse` - - Raises: - :class:`APIError` If the API raised an error. + Returns: + :class:`SingleAPIResponse` + na + Raises: + :class:`APIError` If the API raised an error. """ r = await self.session.request( self.http_method, @@ -120,7 +119,7 @@ async def execute(self) -> SingleAPIResponse[_ReturnT]: if ( 200 <= r.status_code <= 299 ): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok) - return SingleAPIResponse[_ReturnT].from_http_request_response(r) + return SingleAPIResponse.from_http_request_response(r) else: json_obj = model_validate_json(APIErrorFromJSON, r.content) raise APIError(dict(json_obj)) @@ -128,11 +127,35 @@ async def execute(self) -> SingleAPIResponse[_ReturnT]: raise APIError(generate_default_error_message(r)) -class AsyncMaybeSingleRequestBuilder(AsyncSingleRequestBuilder[_ReturnT]): - async def execute(self) -> Optional[SingleAPIResponse[_ReturnT]]: +class AsyncMaybeSingleRequestBuilder: + def __init__( + self, + session: AsyncClient, + path: str, + http_method: str, + headers: Headers, + params: QueryParams, + json: JSON, + ) -> None: + self.session = session + self.path = path + self.http_method = http_method + self.headers = headers + self.params = params + self.json = json + + async def execute(self) -> Optional[SingleAPIResponse]: r = None try: - r = await AsyncSingleRequestBuilder[_ReturnT].execute(self) + request = AsyncSingleRequestBuilder( + self.session, + self.path, + self.http_method, + self.headers, + self.params, + self.json, + ) + r = await request.execute() except APIError as e: if e.details and "The result contains 0 rows" in e.details: return None @@ -148,10 +171,7 @@ async def execute(self) -> Optional[SingleAPIResponse[_ReturnT]]: return r -# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319 -class AsyncFilterRequestBuilder( - BaseFilterRequestBuilder[_ReturnT], AsyncQueryRequestBuilder[_ReturnT] -): # type: ignore +class AsyncFilterRequestBuilder(BaseFilterRequestBuilder, AsyncQueryRequestBuilder): def __init__( self, session: AsyncClient, @@ -159,20 +179,15 @@ def __init__( http_method: str, headers: Headers, params: QueryParams, - json: dict, + json: JSON, ) -> None: - get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__( - self, session, headers, params - ) - get_origin_and_cast(AsyncQueryRequestBuilder[_ReturnT]).__init__( + BaseFilterRequestBuilder.__init__(self, headers, params) + AsyncQueryRequestBuilder.__init__( self, session, path, http_method, headers, params, json ) -# this exists for type-safety. see https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf -class AsyncRPCFilterRequestBuilder( - BaseRPCRequestBuilder[_ReturnT], AsyncSingleRequestBuilder[_ReturnT] -): +class AsyncRPCFilterRequestBuilder(BaseRPCRequestBuilder, AsyncSingleRequestBuilder): def __init__( self, session: AsyncClient, @@ -180,20 +195,15 @@ def __init__( http_method: str, headers: Headers, params: QueryParams, - json: dict, + json: JSON, ) -> None: - get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__( - self, session, headers, params - ) - get_origin_and_cast(AsyncSingleRequestBuilder[_ReturnT]).__init__( + BaseFilterRequestBuilder.__init__(self, headers, params) + AsyncSingleRequestBuilder.__init__( self, session, path, http_method, headers, params, json ) -# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319 -class AsyncSelectRequestBuilder( - BaseSelectRequestBuilder[_ReturnT], AsyncQueryRequestBuilder[_ReturnT] -): # type: ignore +class AsyncSelectRequestBuilder(AsyncQueryRequestBuilder, BaseSelectRequestBuilder): def __init__( self, session: AsyncClient, @@ -201,46 +211,44 @@ def __init__( http_method: str, headers: Headers, params: QueryParams, - json: dict, + json: JSON, ) -> None: - get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__( - self, session, headers, params - ) - get_origin_and_cast(AsyncQueryRequestBuilder[_ReturnT]).__init__( + BaseSelectRequestBuilder.__init__(self, headers, params) + AsyncQueryRequestBuilder.__init__( self, session, path, http_method, headers, params, json ) - def single(self) -> AsyncSingleRequestBuilder[_ReturnT]: + def single(self) -> AsyncSingleRequestBuilder: """Specify that the query will only return a single row in response. .. caution:: The API will raise an error if the query returned more than one row. """ self.headers["Accept"] = "application/vnd.pgrst.object+json" - return AsyncSingleRequestBuilder[_ReturnT]( + return AsyncSingleRequestBuilder( headers=self.headers, http_method=self.http_method, json=self.json, params=self.params, path=self.path, - session=self.session, # type: ignore + session=self.session, ) - def maybe_single(self) -> AsyncMaybeSingleRequestBuilder[_ReturnT]: + def maybe_single(self) -> AsyncMaybeSingleRequestBuilder: """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error.""" self.headers["Accept"] = "application/vnd.pgrst.object+json" - return AsyncMaybeSingleRequestBuilder[_ReturnT]( + return AsyncMaybeSingleRequestBuilder( headers=self.headers, http_method=self.http_method, json=self.json, params=self.params, path=self.path, - session=self.session, # type: ignore + session=self.session, ) def text_search( self, column: str, query: str, options: dict[str, Any] = {} - ) -> AsyncFilterRequestBuilder[_ReturnT]: + ) -> AsyncQueryRequestBuilder: type_ = options.get("type") type_part = "" if type_ == "plain": @@ -252,20 +260,20 @@ def text_search( config_part = f"({options.get('config')})" if options.get("config") else "" self.params = self.params.add(column, f"{type_part}fts{config_part}.{query}") - return AsyncQueryRequestBuilder[_ReturnT]( + return AsyncQueryRequestBuilder( headers=self.headers, http_method=self.http_method, json=self.json, params=self.params, path=self.path, - session=self.session, # type: ignore + session=self.session, ) - def csv(self) -> AsyncSingleRequestBuilder[str]: + def csv(self) -> AsyncSingleRequestBuilder: """Specify that the query must retrieve data as a single CSV string.""" self.headers["Accept"] = "text/csv" - return AsyncSingleRequestBuilder[str]( - session=self.session, # type: ignore + return AsyncSingleRequestBuilder( + session=self.session, path=self.path, http_method=self.http_method, headers=self.headers, @@ -274,7 +282,7 @@ def csv(self) -> AsyncSingleRequestBuilder[str]: ) -class AsyncRequestBuilder(Generic[_ReturnT]): +class AsyncRequestBuilder: def __init__(self, session: AsyncClient, path: str) -> None: self.session = session self.path = path @@ -284,7 +292,7 @@ def select( *columns: str, count: Optional[CountMethod] = None, head: Optional[bool] = None, - ) -> AsyncSelectRequestBuilder[_ReturnT]: + ) -> AsyncSelectRequestBuilder: """Run a SELECT query. Args: @@ -294,19 +302,19 @@ def select( :class:`AsyncSelectRequestBuilder` """ method, params, headers, json = pre_select(*columns, count=count, head=head) - return AsyncSelectRequestBuilder[_ReturnT]( + return AsyncSelectRequestBuilder( self.session, self.path, method, headers, params, json ) def insert( self, - json: Union[dict, list], + json: JSON, *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, upsert: bool = False, default_to_null: bool = True, - ) -> AsyncQueryRequestBuilder[_ReturnT]: + ) -> AsyncQueryRequestBuilder: """Run an INSERT query. Args: @@ -327,20 +335,20 @@ def insert( upsert=upsert, default_to_null=default_to_null, ) - return AsyncQueryRequestBuilder[_ReturnT]( + return AsyncQueryRequestBuilder( self.session, self.path, method, headers, params, json ) def upsert( self, - json: Union[dict, list], + json: JSON, *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, ignore_duplicates: bool = False, on_conflict: str = "", default_to_null: bool = True, - ) -> AsyncQueryRequestBuilder[_ReturnT]: + ) -> AsyncQueryRequestBuilder: """Run an upsert (INSERT ... ON CONFLICT DO UPDATE) query. Args: @@ -364,17 +372,17 @@ def upsert( on_conflict=on_conflict, default_to_null=default_to_null, ) - return AsyncQueryRequestBuilder[_ReturnT]( + return AsyncQueryRequestBuilder( self.session, self.path, method, headers, params, json ) def update( self, - json: dict, + json: JSON, *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, - ) -> AsyncFilterRequestBuilder[_ReturnT]: + ) -> AsyncFilterRequestBuilder: """Run an UPDATE query. Args: @@ -389,7 +397,7 @@ def update( count=count, returning=returning, ) - return AsyncFilterRequestBuilder[_ReturnT]( + return AsyncFilterRequestBuilder( self.session, self.path, method, headers, params, json ) @@ -398,7 +406,7 @@ def delete( *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, - ) -> AsyncFilterRequestBuilder[_ReturnT]: + ) -> AsyncFilterRequestBuilder: """Run a DELETE query. Args: @@ -411,6 +419,6 @@ def delete( count=count, returning=returning, ) - return AsyncFilterRequestBuilder[_ReturnT]( + return AsyncFilterRequestBuilder( self.session, self.path, method, headers, params, json ) diff --git a/src/postgrest/src/postgrest/_sync/client.py b/src/postgrest/src/postgrest/_sync/client.py index 908db515..ac578960 100644 --- a/src/postgrest/src/postgrest/_sync/client.py +++ b/src/postgrest/src/postgrest/_sync/client.py @@ -15,8 +15,6 @@ from ..version import __version__ from .request_builder import SyncRequestBuilder, SyncRPCFilterRequestBuilder -_TableT = Dict[str, Any] - class SyncPostgrestClient(BasePostgrestClient): """PostgREST client.""" @@ -72,7 +70,7 @@ def __init__( proxy=proxy, http_client=http_client, ) - self.session = cast(Client, self.session) + self.session: Client = self.session def create_session( self, @@ -122,7 +120,7 @@ def aclose(self) -> None: """Close the underlying HTTP connections.""" self.session.close() - def from_(self, table: str) -> SyncRequestBuilder[_TableT]: + def from_(self, table: str) -> SyncRequestBuilder: """Perform a table operation. Args: @@ -130,9 +128,9 @@ def from_(self, table: str) -> SyncRequestBuilder[_TableT]: Returns: :class:`AsyncRequestBuilder` """ - return SyncRequestBuilder[_TableT](self.session, f"/{table}") + return SyncRequestBuilder(self.session, f"/{table}") - def table(self, table: str) -> SyncRequestBuilder[_TableT]: + def table(self, table: str) -> SyncRequestBuilder: """Alias to :meth:`from_`.""" return self.from_(table) @@ -148,7 +146,7 @@ def rpc( count: Optional[CountMethod] = None, head: bool = False, get: bool = False, - ) -> SyncRPCFilterRequestBuilder[Any]: + ) -> SyncRPCFilterRequestBuilder: """Perform a stored procedure call. Args: @@ -175,7 +173,7 @@ def rpc( headers = Headers({"Prefer": f"count={count}"}) if count else Headers() if method in ("HEAD", "GET"): - return SyncRPCFilterRequestBuilder[Any]( + return SyncRPCFilterRequestBuilder( self.session, f"/rpc/{func}", method, @@ -184,6 +182,6 @@ def rpc( json={}, ) # the params here are params to be sent to the RPC and not the queryparams! - return SyncRPCFilterRequestBuilder[Any]( + return SyncRPCFilterRequestBuilder( self.session, f"/rpc/{func}", method, headers, QueryParams(), json=params ) diff --git a/src/postgrest/src/postgrest/_sync/request_builder.py b/src/postgrest/src/postgrest/_sync/request_builder.py index efb0f7a7..aacc922d 100644 --- a/src/postgrest/src/postgrest/_sync/request_builder.py +++ b/src/postgrest/src/postgrest/_sync/request_builder.py @@ -4,6 +4,7 @@ from httpx import Client, Headers, QueryParams from pydantic import ValidationError +from typing_extensions import override from ..base_request_builder import ( APIResponse, @@ -19,13 +20,11 @@ pre_upsert, ) from ..exceptions import APIError, APIErrorFromJSON, generate_default_error_message -from ..types import ReturnMethod -from ..utils import get_origin_and_cast, model_validate_json +from ..types import JSON, ReturnMethod +from ..utils import model_validate_json -_ReturnT = TypeVar("_ReturnT") - -class SyncQueryRequestBuilder(Generic[_ReturnT]): +class SyncQueryRequestBuilder: def __init__( self, session: Client, @@ -33,7 +32,7 @@ def __init__( http_method: str, headers: Headers, params: QueryParams, - json: dict, + json: JSON, ) -> None: self.session = session self.path = path @@ -42,7 +41,7 @@ def __init__( self.params = params self.json = None if http_method in {"GET", "HEAD"} else json - def execute(self) -> APIResponse[_ReturnT]: + def execute(self) -> APIResponse | str: """Execute the query. .. tip:: @@ -72,7 +71,7 @@ def execute(self) -> APIResponse[_ReturnT]: ) and "application/vnd.pgrst.plan" in self.headers.get("Accept"): if "+json" not in self.headers.get("Accept"): return body - return APIResponse[_ReturnT].from_http_request_response(r) + return APIResponse.from_http_request_response(r) else: json_obj = model_validate_json(APIErrorFromJSON, r.content) raise APIError(dict(json_obj)) @@ -80,7 +79,7 @@ def execute(self) -> APIResponse[_ReturnT]: raise APIError(generate_default_error_message(r)) -class SyncSingleRequestBuilder(Generic[_ReturnT]): +class SyncSingleRequestBuilder: def __init__( self, session: Client, @@ -88,7 +87,7 @@ def __init__( http_method: str, headers: Headers, params: QueryParams, - json: dict, + json: JSON, ) -> None: self.session = session self.path = path @@ -97,17 +96,17 @@ def __init__( self.params = params self.json = json - def execute(self) -> SingleAPIResponse[_ReturnT]: + def execute(self) -> SingleAPIResponse: """Execute the query. - .. tip:: - This is the last method called, after the query is built. + .. tip:: + This is the last method called, after the query is built. - Returns: - :class:`SingleAPIResponse` - - Raises: - :class:`APIError` If the API raised an error. + Returns: + :class:`SingleAPIResponse` + na + Raises: + :class:`APIError` If the API raised an error. """ r = self.session.request( self.http_method, @@ -120,7 +119,7 @@ def execute(self) -> SingleAPIResponse[_ReturnT]: if ( 200 <= r.status_code <= 299 ): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok) - return SingleAPIResponse[_ReturnT].from_http_request_response(r) + return SingleAPIResponse.from_http_request_response(r) else: json_obj = model_validate_json(APIErrorFromJSON, r.content) raise APIError(dict(json_obj)) @@ -128,11 +127,35 @@ def execute(self) -> SingleAPIResponse[_ReturnT]: raise APIError(generate_default_error_message(r)) -class SyncMaybeSingleRequestBuilder(SyncSingleRequestBuilder[_ReturnT]): - def execute(self) -> Optional[SingleAPIResponse[_ReturnT]]: +class SyncMaybeSingleRequestBuilder: + def __init__( + self, + session: Client, + path: str, + http_method: str, + headers: Headers, + params: QueryParams, + json: JSON, + ) -> None: + self.session = session + self.path = path + self.http_method = http_method + self.headers = headers + self.params = params + self.json = json + + def execute(self) -> Optional[SingleAPIResponse]: r = None try: - r = SyncSingleRequestBuilder[_ReturnT].execute(self) + request = SyncSingleRequestBuilder( + self.session, + self.path, + self.http_method, + self.headers, + self.params, + self.json, + ) + r = request.execute() except APIError as e: if e.details and "The result contains 0 rows" in e.details: return None @@ -148,10 +171,7 @@ def execute(self) -> Optional[SingleAPIResponse[_ReturnT]]: return r -# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319 -class SyncFilterRequestBuilder( - BaseFilterRequestBuilder[_ReturnT], SyncQueryRequestBuilder[_ReturnT] -): # type: ignore +class SyncFilterRequestBuilder(BaseFilterRequestBuilder, SyncQueryRequestBuilder): def __init__( self, session: Client, @@ -159,20 +179,15 @@ def __init__( http_method: str, headers: Headers, params: QueryParams, - json: dict, + json: JSON, ) -> None: - get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__( - self, session, headers, params - ) - get_origin_and_cast(SyncQueryRequestBuilder[_ReturnT]).__init__( + BaseFilterRequestBuilder.__init__(self, headers, params) + SyncQueryRequestBuilder.__init__( self, session, path, http_method, headers, params, json ) -# this exists for type-safety. see https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf -class SyncRPCFilterRequestBuilder( - BaseRPCRequestBuilder[_ReturnT], SyncSingleRequestBuilder[_ReturnT] -): +class SyncRPCFilterRequestBuilder(BaseRPCRequestBuilder, SyncSingleRequestBuilder): def __init__( self, session: Client, @@ -180,20 +195,15 @@ def __init__( http_method: str, headers: Headers, params: QueryParams, - json: dict, + json: JSON, ) -> None: - get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__( - self, session, headers, params - ) - get_origin_and_cast(SyncSingleRequestBuilder[_ReturnT]).__init__( + BaseFilterRequestBuilder.__init__(self, headers, params) + SyncSingleRequestBuilder.__init__( self, session, path, http_method, headers, params, json ) -# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319 -class SyncSelectRequestBuilder( - BaseSelectRequestBuilder[_ReturnT], SyncQueryRequestBuilder[_ReturnT] -): # type: ignore +class SyncSelectRequestBuilder(SyncQueryRequestBuilder, BaseSelectRequestBuilder): def __init__( self, session: Client, @@ -201,46 +211,44 @@ def __init__( http_method: str, headers: Headers, params: QueryParams, - json: dict, + json: JSON, ) -> None: - get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__( - self, session, headers, params - ) - get_origin_and_cast(SyncQueryRequestBuilder[_ReturnT]).__init__( + BaseSelectRequestBuilder.__init__(self, headers, params) + SyncQueryRequestBuilder.__init__( self, session, path, http_method, headers, params, json ) - def single(self) -> SyncSingleRequestBuilder[_ReturnT]: + def single(self) -> SyncSingleRequestBuilder: """Specify that the query will only return a single row in response. .. caution:: The API will raise an error if the query returned more than one row. """ self.headers["Accept"] = "application/vnd.pgrst.object+json" - return SyncSingleRequestBuilder[_ReturnT]( + return SyncSingleRequestBuilder( headers=self.headers, http_method=self.http_method, json=self.json, params=self.params, path=self.path, - session=self.session, # type: ignore + session=self.session, ) - def maybe_single(self) -> SyncMaybeSingleRequestBuilder[_ReturnT]: + def maybe_single(self) -> SyncMaybeSingleRequestBuilder: """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error.""" self.headers["Accept"] = "application/vnd.pgrst.object+json" - return SyncMaybeSingleRequestBuilder[_ReturnT]( + return SyncMaybeSingleRequestBuilder( headers=self.headers, http_method=self.http_method, json=self.json, params=self.params, path=self.path, - session=self.session, # type: ignore + session=self.session, ) def text_search( self, column: str, query: str, options: dict[str, Any] = {} - ) -> SyncFilterRequestBuilder[_ReturnT]: + ) -> SyncQueryRequestBuilder: type_ = options.get("type") type_part = "" if type_ == "plain": @@ -252,20 +260,20 @@ def text_search( config_part = f"({options.get('config')})" if options.get("config") else "" self.params = self.params.add(column, f"{type_part}fts{config_part}.{query}") - return SyncQueryRequestBuilder[_ReturnT]( + return SyncQueryRequestBuilder( headers=self.headers, http_method=self.http_method, json=self.json, params=self.params, path=self.path, - session=self.session, # type: ignore + session=self.session, ) - def csv(self) -> SyncSingleRequestBuilder[str]: + def csv(self) -> SyncSingleRequestBuilder: """Specify that the query must retrieve data as a single CSV string.""" self.headers["Accept"] = "text/csv" - return SyncSingleRequestBuilder[str]( - session=self.session, # type: ignore + return SyncSingleRequestBuilder( + session=self.session, path=self.path, http_method=self.http_method, headers=self.headers, @@ -274,7 +282,7 @@ def csv(self) -> SyncSingleRequestBuilder[str]: ) -class SyncRequestBuilder(Generic[_ReturnT]): +class SyncRequestBuilder: def __init__(self, session: Client, path: str) -> None: self.session = session self.path = path @@ -284,7 +292,7 @@ def select( *columns: str, count: Optional[CountMethod] = None, head: Optional[bool] = None, - ) -> SyncSelectRequestBuilder[_ReturnT]: + ) -> SyncSelectRequestBuilder: """Run a SELECT query. Args: @@ -294,19 +302,19 @@ def select( :class:`SyncSelectRequestBuilder` """ method, params, headers, json = pre_select(*columns, count=count, head=head) - return SyncSelectRequestBuilder[_ReturnT]( + return SyncSelectRequestBuilder( self.session, self.path, method, headers, params, json ) def insert( self, - json: Union[dict, list], + json: JSON, *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, upsert: bool = False, default_to_null: bool = True, - ) -> SyncQueryRequestBuilder[_ReturnT]: + ) -> SyncQueryRequestBuilder: """Run an INSERT query. Args: @@ -327,20 +335,20 @@ def insert( upsert=upsert, default_to_null=default_to_null, ) - return SyncQueryRequestBuilder[_ReturnT]( + return SyncQueryRequestBuilder( self.session, self.path, method, headers, params, json ) def upsert( self, - json: Union[dict, list], + json: JSON, *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, ignore_duplicates: bool = False, on_conflict: str = "", default_to_null: bool = True, - ) -> SyncQueryRequestBuilder[_ReturnT]: + ) -> SyncQueryRequestBuilder: """Run an upsert (INSERT ... ON CONFLICT DO UPDATE) query. Args: @@ -364,17 +372,17 @@ def upsert( on_conflict=on_conflict, default_to_null=default_to_null, ) - return SyncQueryRequestBuilder[_ReturnT]( + return SyncQueryRequestBuilder( self.session, self.path, method, headers, params, json ) def update( self, - json: dict, + json: JSON, *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, - ) -> SyncFilterRequestBuilder[_ReturnT]: + ) -> SyncFilterRequestBuilder: """Run an UPDATE query. Args: @@ -389,7 +397,7 @@ def update( count=count, returning=returning, ) - return SyncFilterRequestBuilder[_ReturnT]( + return SyncFilterRequestBuilder( self.session, self.path, method, headers, params, json ) @@ -398,7 +406,7 @@ def delete( *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, - ) -> SyncFilterRequestBuilder[_ReturnT]: + ) -> SyncFilterRequestBuilder: """Run a DELETE query. Args: @@ -411,6 +419,6 @@ def delete( count=count, returning=returning, ) - return SyncFilterRequestBuilder[_ReturnT]( + return SyncFilterRequestBuilder( self.session, self.path, method, headers, params, json ) diff --git a/src/postgrest/src/postgrest/base_request_builder.py b/src/postgrest/src/postgrest/base_request_builder.py index 3e7ff0aa..9d79e232 100644 --- a/src/postgrest/src/postgrest/base_request_builder.py +++ b/src/postgrest/src/postgrest/base_request_builder.py @@ -20,10 +20,10 @@ from httpx import AsyncClient, Client, Headers, QueryParams from httpx import Response as RequestResponse -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError try: - from typing import Self + from typing import Self # type: ignore except ImportError: from typing_extensions import Self @@ -32,10 +32,11 @@ from pydantic import field_validator except ImportError: # < 2.0.0 - from pydantic import validator as field_validator + from pydantic import validator as field_validator # type: ignore -from .types import CountMethod, Filters, RequestMethod, ReturnMethod -from .utils import get_origin_and_cast, sanitize_param +from .base_client import BasePostgrestClient +from .types import JSON, CountMethod, Filters, JSONAdapter, RequestMethod, ReturnMethod +from .utils import sanitize_param class QueryArgs(NamedTuple): @@ -43,10 +44,10 @@ class QueryArgs(NamedTuple): method: RequestMethod params: QueryParams headers: Headers - json: Dict[Any, Any] + json: JSON -def _unique_columns(json: List[Dict]): +def _unique_columns(json: List[Dict[str, JSON]]): unique_keys = {key for row in json for key in row.keys()} columns = ",".join([f'"{k}"' for k in unique_keys]) return columns @@ -75,7 +76,7 @@ def pre_select( head: Optional[bool] = None, ) -> QueryArgs: method = RequestMethod.HEAD if head else RequestMethod.GET - cleaned_columns = _cleaned_columns(columns or "*") + cleaned_columns = _cleaned_columns(columns or ("*",)) params = QueryParams({"select": cleaned_columns}) headers = Headers({"Prefer": f"count={count}"}) if count else Headers() @@ -83,7 +84,7 @@ def pre_select( def pre_insert( - json: Union[dict, list], + json: JSON, *, count: Optional[CountMethod], returning: ReturnMethod, @@ -106,7 +107,7 @@ def pre_insert( def pre_upsert( - json: Union[dict, list], + json: JSON, *, count: Optional[CountMethod], returning: ReturnMethod, @@ -132,7 +133,7 @@ def pre_upsert( def pre_update( - json: dict, + json: JSON, *, count: Optional[CountMethod], returning: ReturnMethod, @@ -156,15 +157,8 @@ def pre_delete( return QueryArgs(RequestMethod.DELETE, QueryParams(), headers, {}) -_ReturnT = TypeVar("_ReturnT") - - -# the APIResponse.data is marked as _ReturnT instead of list[_ReturnT] -# as it is also returned in the case of rpc() calls; and rpc calls do not -# necessarily return lists. -# https://github.com/supabase-community/postgrest-py/issues/200 -class APIResponse(BaseModel, Generic[_ReturnT]): - data: List[_ReturnT] +class APIResponse(BaseModel): + data: List[JSON] """The data returned by the query.""" count: Optional[int] = None """The number of rows returned.""" @@ -188,78 +182,57 @@ def _is_count_in_prefer_header(prefer_header: str) -> bool: pattern = f"count=({'|'.join([cm.value for cm in CountMethod])})" return bool(search(pattern, prefer_header)) - @classmethod + @staticmethod def _get_count_from_http_request_response( - cls: Type[Self], request_response: RequestResponse, ) -> Optional[int]: prefer_header: Optional[str] = request_response.request.headers.get("prefer") if not prefer_header: return None - is_count_in_prefer_header = cls._is_count_in_prefer_header(prefer_header) + is_count_in_prefer_header = APIResponse._is_count_in_prefer_header( + prefer_header + ) content_range_header: Optional[str] = request_response.headers.get( "content-range" ) - return ( - cls._get_count_from_content_range_header(content_range_header) - if (is_count_in_prefer_header and content_range_header) - else None - ) + if is_count_in_prefer_header and content_range_header: + return APIResponse._get_count_from_content_range_header( + content_range_header + ) + return None - @classmethod - def from_http_request_response( - cls: Type[Self], request_response: RequestResponse - ) -> Self: - count = cls._get_count_from_http_request_response(request_response) + @staticmethod + def from_http_request_response(request_response: RequestResponse) -> APIResponse: + count = APIResponse._get_count_from_http_request_response(request_response) try: - data = request_response.json() - except JSONDecodeError: + data = JSONAdapter.validate_json(request_response.content) + except ValidationError: data = request_response.text if len(request_response.text) > 0 else [] - # the type-ignore here is as pydantic needs us to pass the type parameter - # here explicitly, but pylance already knows that cls is correctly parametrized - return cls[_ReturnT](data=data, count=count) # type: ignore + return APIResponse(data=data, count=count) - @classmethod - def from_dict(cls: Type[Self], dict: Dict[str, Any]) -> Self: - keys = dict.keys() - assert len(keys) == 3 and "data" in keys and "count" in keys and "error" in keys - return cls[_ReturnT]( # type: ignore - data=dict.get("data"), count=dict.get("count"), error=dict.get("error") - ) - -class SingleAPIResponse(APIResponse[_ReturnT], Generic[_ReturnT]): - data: _ReturnT # type: ignore +class SingleAPIResponse(APIResponse): + data: JSON # type: ignore """The data returned by the query.""" - @classmethod + @staticmethod def from_http_request_response( - cls: Type[Self], request_response: RequestResponse - ) -> Self: - count = cls._get_count_from_http_request_response(request_response) + request_response: RequestResponse, + ) -> SingleAPIResponse: + count = APIResponse._get_count_from_http_request_response(request_response) try: data = request_response.json() except JSONDecodeError: data = request_response.text if len(request_response.text) > 0 else [] - return cls[_ReturnT](data=data, count=count) # type: ignore + return SingleAPIResponse(data=data, count=count) - @classmethod - def from_dict(cls: Type[Self], dict: Dict[str, Any]) -> Self: - keys = dict.keys() - assert len(keys) == 3 and "data" in keys and "count" in keys and "error" in keys - return cls[_ReturnT]( # type: ignore - data=dict.get("data"), count=dict.get("count"), error=dict.get("error") - ) - -class BaseFilterRequestBuilder(Generic[_ReturnT]): +class BaseFilterRequestBuilder: def __init__( self, - session: Union[AsyncClient, Client], headers: Headers, params: QueryParams, ) -> None: - self.session = session self.headers = headers self.params = params self.negate_next = False @@ -547,20 +520,7 @@ def max_affected(self: Self, value: int) -> Self: return self -class BaseSelectRequestBuilder(BaseFilterRequestBuilder[_ReturnT]): - def __init__( - self, - session: Union[AsyncClient, Client], - headers: Headers, - params: QueryParams, - ) -> None: - # Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__ - # tries to call _GenericAlias.__init__ - which is the wrong method - # The __origin__ attribute of the _GenericAlias is the actual class - get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__( - self, session, headers, params - ) - +class BaseSelectRequestBuilder(BaseFilterRequestBuilder): def explain( self: Self, analyze: bool = False, @@ -653,20 +613,7 @@ def range( return self -class BaseRPCRequestBuilder(BaseSelectRequestBuilder[_ReturnT]): - def __init__( - self, - session: Union[AsyncClient, Client], - headers: Headers, - params: QueryParams, - ) -> None: - # Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__ - # tries to call _GenericAlias.__init__ - which is the wrong method - # The __origin__ attribute of the _GenericAlias is the actual class - get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__( - self, session, headers, params - ) - +class BaseRPCRequestBuilder(BaseSelectRequestBuilder): def select( self, *columns: str, diff --git a/src/postgrest/src/postgrest/types.py b/src/postgrest/src/postgrest/types.py index fa6f94ce..17b0804f 100644 --- a/src/postgrest/src/postgrest/types.py +++ b/src/postgrest/src/postgrest/types.py @@ -1,12 +1,23 @@ from __future__ import annotations import sys +from collections.abc import Mapping, Sequence +from typing import Union + +from pydantic import TypeAdapter +from typing_extensions import TypeAliasType if sys.version_info >= (3, 11): from enum import StrEnum else: from strenum import StrEnum +# https://docs.pydantic.dev/2.11/concepts/types/#named-recursive-types +JSON = TypeAliasType( + "JSON", "Union[None, bool, str, int, float, Sequence[JSON], Mapping[str, JSON]]" +) +JSONAdapter: TypeAdapter = TypeAdapter(JSON) + class CountMethod(StrEnum): exact = "exact" diff --git a/src/postgrest/src/postgrest/utils.py b/src/postgrest/src/postgrest/utils.py index 7d2da4f0..aa8967d2 100644 --- a/src/postgrest/src/postgrest/utils.py +++ b/src/postgrest/src/postgrest/utils.py @@ -40,20 +40,6 @@ def sanitize_pattern_param(pattern: str) -> str: return sanitize_param(pattern.replace("%", "*")) -_T = TypeVar("_T") - - -def get_origin_and_cast(typ: type[type[_T]]) -> type[_T]: - # Base[T] is an instance of typing._GenericAlias, so doing Base[T].__init__ - # tries to call _GenericAlias.__init__ - which is the wrong method - # get_origin(Base[T]) returns Base - # This function casts Base back to Base[T] to maintain type-safety - # while still allowing us to access the methods of `Base` at runtime - # See: definitions of request builders that use multiple-inheritance - # like AsyncFilterRequestBuilder - return cast(Type[_T], get_origin(typ)) - - def is_http_url(url: str) -> bool: return urlparse(url).scheme in {"https", "http"} diff --git a/src/postgrest/tests/_async/test_request_builder.py b/src/postgrest/tests/_async/test_request_builder.py index e3986182..c8e6a6f5 100644 --- a/src/postgrest/tests/_async/test_request_builder.py +++ b/src/postgrest/tests/_async/test_request_builder.py @@ -5,7 +5,7 @@ from postgrest import AsyncRequestBuilder, AsyncSingleRequestBuilder from postgrest.base_request_builder import APIResponse, SingleAPIResponse -from postgrest.types import CountMethod +from postgrest.types import JSON, CountMethod @pytest.fixture @@ -407,19 +407,15 @@ def request_response_with_csv_data(csv_api_response: str) -> Response: class TestApiResponse: - def test_response_raises_when_api_error( - self, api_response_with_error: Dict[str, Any] - ): + def test_response_raises_when_api_error(self, api_response_with_error: List[JSON]): with pytest.raises(ValueError): APIResponse(data=api_response_with_error) - def test_parses_valid_response_only_data(self, api_response: List[Dict[str, Any]]): + def test_parses_valid_response_only_data(self, api_response: List[JSON]): result = APIResponse(data=api_response) assert result.data == api_response - def test_parses_valid_response_data_and_count( - self, api_response: List[Dict[str, Any]] - ): + def test_parses_valid_response_data_and_count(self, api_response: List[JSON]): count = len(api_response) result = APIResponse(data=api_response, count=count) assert result.data == api_response diff --git a/src/postgrest/tests/_sync/test_request_builder.py b/src/postgrest/tests/_sync/test_request_builder.py index 865703cf..89b807dd 100644 --- a/src/postgrest/tests/_sync/test_request_builder.py +++ b/src/postgrest/tests/_sync/test_request_builder.py @@ -5,7 +5,7 @@ from postgrest import SyncRequestBuilder, SyncSingleRequestBuilder from postgrest.base_request_builder import APIResponse, SingleAPIResponse -from postgrest.types import CountMethod +from postgrest.types import JSON, CountMethod @pytest.fixture @@ -407,19 +407,15 @@ def request_response_with_csv_data(csv_api_response: str) -> Response: class TestApiResponse: - def test_response_raises_when_api_error( - self, api_response_with_error: Dict[str, Any] - ): + def test_response_raises_when_api_error(self, api_response_with_error: List[JSON]): with pytest.raises(ValueError): APIResponse(data=api_response_with_error) - def test_parses_valid_response_only_data(self, api_response: List[Dict[str, Any]]): + def test_parses_valid_response_only_data(self, api_response: List[JSON]): result = APIResponse(data=api_response) assert result.data == api_response - def test_parses_valid_response_data_and_count( - self, api_response: List[Dict[str, Any]] - ): + def test_parses_valid_response_data_and_count(self, api_response: List[JSON]): count = len(api_response) result = APIResponse(data=api_response, count=count) assert result.data == api_response diff --git a/uv.lock b/uv.lock index 0f437183..8e0656d1 100644 --- a/uv.lock +++ b/uv.lock @@ -1623,10 +1623,13 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "pre-commit" }, + { name = "pylsp-mypy" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-depends" }, + { name = "python-lsp-ruff" }, + { name = "python-lsp-server" }, { name = "ruff" }, { name = "unasync" }, ] @@ -1638,6 +1641,9 @@ docs = [ ] lints = [ { name = "pre-commit" }, + { name = "pylsp-mypy" }, + { name = "python-lsp-ruff" }, + { name = "python-lsp-server" }, { name = "ruff" }, ] test = [ @@ -1659,10 +1665,13 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "pre-commit", specifier = ">=4.2.0" }, + { name = "pylsp-mypy", specifier = ">=0.7.0,<0.8.0" }, { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-asyncio", specifier = ">=1.0.0" }, { name = "pytest-cov", specifier = ">=6.2.1" }, { name = "pytest-depends", specifier = ">=1.0.1" }, + { name = "python-lsp-ruff", specifier = ">=2.2.2,<3.0.0" }, + { name = "python-lsp-server", specifier = ">=1.12.2,<2.0.0" }, { name = "ruff", specifier = ">=0.12.1" }, { name = "unasync", specifier = ">=0.6.0" }, ] @@ -1672,6 +1681,9 @@ docs = [ ] lints = [ { name = "pre-commit", specifier = ">=4.2.0" }, + { name = "pylsp-mypy", specifier = ">=0.7.0,<0.8.0" }, + { name = "python-lsp-ruff", specifier = ">=2.2.2,<3.0.0" }, + { name = "python-lsp-server", specifier = ">=1.12.2,<2.0.0" }, { name = "ruff", specifier = ">=0.12.1" }, ] test = [