Skip to content

Commit e161f53

Browse files
committed
chore(utils): refactor endpointrequest to make more sense
rename classes with `endpoint` in their name. given that this is a client, and not a server, these might not make sense the initial idea was that these were supposed to hit a 'server endpoint', but I think that the name is confusing after trying to use it elsewhere i've also split the one request class into multiple, because it makes more sense in the context of the other packages
1 parent 6d28ba8 commit e161f53

File tree

3 files changed

+121
-67
lines changed

3 files changed

+121
-67
lines changed

src/functions/src/supabase_functions/client.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
from typing import Dict, Generic, Literal, Optional, Union, overload
22
from warnings import warn
33

4-
from httpx import AsyncClient, Client, Headers, Response
4+
from httpx import AsyncClient, Client, Headers, QueryParams, Response
55
from supabase_utils.http import (
66
AsyncExecutor,
7-
EndpointRequest,
7+
BytesRequest,
8+
EmptyRequest,
89
Executor,
910
HTTPRequestMethod,
10-
ServerEndpoint,
11+
JSONRequest,
12+
ResponseHandler,
1113
SyncExecutor,
12-
http_endpoint,
14+
TextRequest,
15+
ToHttpxRequest,
16+
http_request,
1317
)
1418
from supabase_utils.types import JSON
1519
from yarl import URL
@@ -79,42 +83,61 @@ def _invoke_options_to_request(
7983
region: Optional[FunctionRegion],
8084
headers: Optional[Dict[str, str]],
8185
method: Optional[HTTPRequestMethod],
82-
) -> EndpointRequest:
86+
) -> ToHttpxRequest:
8387
if not is_valid_str_arg(function_name):
8488
raise ValueError("function_name must a valid string value.")
8589

86-
request = EndpointRequest(
87-
method="POST",
88-
path=[function_name],
89-
headers=Headers(self.headers),
90-
)
90+
method = method or "POST"
91+
path = [function_name]
92+
new_headers = Headers(self.headers)
93+
query_params = QueryParams()
9194

92-
request.headers.update(headers or dict())
95+
if headers:
96+
new_headers.update(headers)
9397
if region and region != FunctionRegion.Any:
94-
request.headers["x-region"] = region.value
98+
new_headers["x-region"] = region.value
9599
# Add region as query parameter
96-
request.query_param("forceFunctionRegion", region.value)
97-
98-
if method:
99-
request.method = method
100+
query_params = query_params.set("forceFunctionRegion", region.value)
100101

101102
if isinstance(body, str):
102-
request.plain_text(body)
103+
return TextRequest(
104+
text=body,
105+
method=method,
106+
path=path,
107+
headers=new_headers,
108+
query_params=query_params,
109+
)
103110
elif isinstance(body, dict):
104-
request.json(body)
111+
return JSONRequest(
112+
body=body,
113+
method=method,
114+
path=path,
115+
headers=new_headers,
116+
query_params=query_params,
117+
exclude_none=False,
118+
)
105119
elif isinstance(body, bytes):
106-
request.bytes(body)
107-
return request
120+
return BytesRequest(
121+
body=body,
122+
method=method,
123+
path=path,
124+
headers=new_headers,
125+
query_params=query_params,
126+
)
127+
else:
128+
return EmptyRequest(
129+
method=method, path=path, headers=new_headers, query_params=query_params
130+
)
108131

109-
@http_endpoint
132+
@http_request
110133
def invoke(
111134
self,
112135
function_name: str,
113136
body: Union[bytes, str, Dict[str, JSON], None] = None,
114137
region: Optional[FunctionRegion] = None,
115138
headers: Optional[Dict[str, str]] = None,
116139
method: Optional[HTTPRequestMethod] = None,
117-
) -> ServerEndpoint[Response, Union[FunctionsHttpError, FunctionsRelayError]]:
140+
) -> ResponseHandler[Response, Union[FunctionsHttpError, FunctionsRelayError]]:
118141
"""Invokes a function
119142
120143
Parameters
@@ -128,7 +151,7 @@ def invoke(
128151
request = self._invoke_options_to_request(
129152
function_name, body, region, headers, method
130153
)
131-
return ServerEndpoint(
154+
return ResponseHandler(
132155
request=request,
133156
on_success=lambda response: response,
134157
on_failure=on_error_response,

src/utils/pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ documentation = "https://github.com/supabase/supabase-py/tree/main/src/utils"
2828
changelog = "https://github.com/supabase/supabase-py/tree/main/CHANGELOG.md"
2929

3030
[tool.mypy]
31-
python_version = "3.9"
3231
check_untyped_defs = true
3332
allow_redefinition = true
3433
follow_untyped_imports = true # for deprecation module that does not have stubs

src/utils/src/supabase_utils/http.py

Lines changed: 75 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
Headers,
2020
HTTPStatusError,
2121
QueryParams,
22-
Request,
2322
Response,
2423
)
24+
from httpx import (
25+
Request as HttpxRequest,
26+
)
2527
from pydantic import BaseModel, TypeAdapter
2628
from typing_extensions import Concatenate, ParamSpec
2729
from yarl import URL
@@ -32,51 +34,77 @@
3234

3335

3436
@dataclass
35-
class EndpointRequest:
36-
method: HTTPRequestMethod
37+
class EmptyRequest:
3738
path: List[str]
38-
body: Optional[bytes] = None
39-
headers: Headers = field(default_factory=Headers)
40-
query_params: QueryParams = field(default_factory=QueryParams)
39+
method: HTTPRequestMethod
40+
headers: Headers = field(default_factory=Headers, kw_only=True)
41+
query_params: QueryParams = field(default_factory=QueryParams, kw_only=True)
42+
43+
def to_request(self, base_url: URL) -> HttpxRequest:
44+
return HttpxRequest(
45+
method=self.method,
46+
url=str(base_url.joinpath(*self.path)),
47+
headers=self.headers,
48+
params=self.query_params,
49+
)
50+
51+
52+
@dataclass
53+
class BytesRequest(EmptyRequest):
54+
body: bytes
4155

42-
def bytes(self, bs: bytes) -> "EndpointRequest":
56+
def to_request(self, base_url: URL) -> HttpxRequest:
4357
self.headers["Content-Type"] = "application/octet-stream"
44-
self.body = bs
45-
return self
58+
return HttpxRequest(
59+
method=self.method,
60+
url=str(base_url.joinpath(*self.path)),
61+
headers=self.headers,
62+
params=self.query_params,
63+
content=self.body,
64+
)
4665

47-
def plain_text(self, text: str) -> "EndpointRequest":
48-
self.headers["Content-Type"] = "text/plain; charset=utf-8"
49-
self.body = text.encode("utf-8")
50-
return self
5166

52-
def model(self, model: BaseModel) -> "EndpointRequest":
67+
@dataclass
68+
class JSONRequest(EmptyRequest):
69+
body: Union[JSON, BaseModel]
70+
exclude_none: bool = True
71+
72+
def to_request(self, base_url: URL) -> HttpxRequest:
73+
if isinstance(self.body, BaseModel):
74+
content = self.body.__pydantic_serializer__.to_json(
75+
self.body, exclude_none=self.exclude_none
76+
)
77+
else:
78+
content = JSONParser.dump_json(self.body)
5379
self.headers["Content-Type"] = "application/json"
54-
self.body = model.__pydantic_serializer__.to_json(model)
55-
return self
80+
return HttpxRequest(
81+
method=self.method,
82+
url=str(base_url.joinpath(*self.path)),
83+
headers=self.headers,
84+
params=self.query_params,
85+
content=content,
86+
)
5687

57-
def json(self, json: JSON) -> "EndpointRequest":
58-
self.headers["Content-Type"] = "application/json"
59-
self.body = JSONParser.dump_json(json)
60-
return self
6188

62-
def query_param(self, key: str, value: str) -> "EndpointRequest":
63-
self.query_params = self.query_params.set(key, value)
64-
return self
89+
@dataclass
90+
class TextRequest(EmptyRequest):
91+
text: str
6592

66-
def to_request(self, base_url: URL) -> Request:
67-
return Request(
93+
def to_request(self, base_url: URL) -> HttpxRequest:
94+
self.headers["Content-Type"] = "text/plain; charset=utf-8"
95+
return HttpxRequest(
6896
method=self.method,
6997
url=str(base_url.joinpath(*self.path)),
7098
headers=self.headers,
7199
params=self.query_params,
72-
content=self.body,
100+
content=self.text.encode("utf-8"),
73101
)
74102

75103

76104
T = TypeVar("T", covariant=True)
77105

78106

79-
class FromHTTPResponse(Protocol[T]):
107+
class FromHttpxResponse(Protocol[T]):
80108
def __call__(self, response: Response) -> T: ...
81109

82110

@@ -86,7 +114,7 @@ def __call__(self, response: Response) -> T: ...
86114
Model = TypeVar("Model", bound=BaseModel)
87115

88116

89-
def validate_model(model: type[Model]) -> FromHTTPResponse[Model]:
117+
def validate_model(model: type[Model]) -> FromHttpxResponse[Model]:
90118
def from_response(response: Response) -> Model:
91119
return model.model_validate_json(response.content)
92120

@@ -96,49 +124,53 @@ def from_response(response: Response) -> Model:
96124
Inner = TypeVar("Inner")
97125

98126

99-
def validate_adapter(adapter: TypeAdapter[Inner]) -> FromHTTPResponse[Inner]:
127+
def validate_adapter(adapter: TypeAdapter[Inner]) -> FromHttpxResponse[Inner]:
100128
def from_response(response: Response) -> Inner:
101129
return adapter.validate_json(response.content)
102130

103131
return from_response
104132

105133

134+
class ToHttpxRequest(Protocol):
135+
def to_request(self, base_url: URL) -> HttpxRequest: ...
136+
137+
106138
@dataclass
107-
class ServerEndpoint(Generic[Success, Failure]):
108-
request: EndpointRequest
109-
on_success: FromHTTPResponse[Success]
110-
on_failure: FromHTTPResponse[Failure]
139+
class ResponseHandler(Generic[Success, Failure]):
140+
request: ToHttpxRequest
141+
on_success: FromHttpxResponse[Success]
142+
on_failure: FromHttpxResponse[Failure]
111143

112144

113145
class SyncExecutor:
114146
def __init__(self, session: Client) -> None:
115147
self.session = session
116148

117149
def communicate(
118-
self, base_url: URL, endpoint: ServerEndpoint[Success, Failure]
150+
self, base_url: URL, handler: ResponseHandler[Success, Failure]
119151
) -> Success:
120-
response = self.session.send(endpoint.request.to_request(base_url))
152+
response = self.session.send(handler.request.to_request(base_url))
121153
try:
122154
response.raise_for_status()
123-
return endpoint.on_success(response)
155+
return handler.on_success(response)
124156
except HTTPStatusError:
125-
raise endpoint.on_failure(response) from None
157+
raise handler.on_failure(response) from None
126158

127159

128160
class AsyncExecutor:
129161
def __init__(self, session: AsyncClient) -> None:
130162
self.session = session
131163

132164
async def communicate(
133-
self, base_url: URL, endpoint: ServerEndpoint[Success, Failure]
165+
self, base_url: URL, handler: ResponseHandler[Success, Failure]
134166
) -> Success:
135-
request = endpoint.request.to_request(base_url)
167+
request = handler.request.to_request(base_url)
136168
response = await self.session.send(request)
137169
try:
138170
response.raise_for_status()
139-
return endpoint.on_success(response)
171+
return handler.on_success(response)
140172
except HTTPStatusError:
141-
raise endpoint.on_failure(response) from None
173+
raise handler.on_failure(response) from None
142174

143175

144176
Params = ParamSpec("Params")
@@ -151,8 +183,8 @@ class HasExecutor(Protocol[Executor]):
151183

152184

153185
@dataclass
154-
class http_endpoint(Generic[Params, Success, Failure]):
155-
method: Callable[Concatenate[Any, Params], ServerEndpoint[Success, Failure]]
186+
class http_request(Generic[Params, Success, Failure]):
187+
method: Callable[Concatenate[Any, Params], ResponseHandler[Success, Failure]]
156188

157189
@overload
158190
def __get__(

0 commit comments

Comments
 (0)