Skip to content

Commit c349744

Browse files
committed
test: implement file upload interfaces for flask test client
1 parent c4ae3e4 commit c349744

File tree

1 file changed

+111
-1
lines changed

1 file changed

+111
-1
lines changed

tests/unit/restapi/lib/client.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,16 @@
2121
import structlog
2222
from flask.testing import FlaskClient
2323
from structlog.stdlib import BoundLogger
24+
from werkzeug.datastructures import FileStorage
2425
from werkzeug.test import TestResponse
2526

2627
from dioptra.client.base import (
2728
DioptraClientError,
29+
DioptraFile,
2830
DioptraRequestProtocol,
2931
DioptraResponseProtocol,
3032
DioptraSession,
33+
IllegalArgumentError,
3134
StatusCodeError,
3235
)
3336
from dioptra.restapi.routes import V1_ROOT
@@ -141,6 +144,69 @@ def is_2xx(status_code: int) -> bool:
141144
return status_code >= HTTPStatus.OK and status_code < HTTPStatus.MULTIPLE_CHOICES
142145

143146

147+
def format_file_for_request(file_: DioptraFile) -> FileStorage:
148+
"""Format the DioptraFile object into a FlaskClient-compatible data structure.
149+
150+
Returns:
151+
The file encoded as a Werkzeug FileStorage object.
152+
"""
153+
if file_.content_type is None:
154+
return FileStorage(stream=file_.stream, filename=file_.filename)
155+
156+
return FileStorage(
157+
stream=file_.stream, filename=file_.filename, content_type=file_.content_type
158+
)
159+
160+
161+
def prepare_data_and_files(
162+
data: dict[str, Any] | None,
163+
files: dict[str, DioptraFile | list[DioptraFile]] | None,
164+
) -> dict[str, Any]:
165+
"""Prepare the data and files for the request.
166+
167+
Args:
168+
data: A dictionary to send in the body of the request as part of a multipart
169+
form.
170+
files: Dictionary of "name": DioptraFile or lists of DioptraFile pairs to be
171+
uploaded.
172+
173+
Returns:
174+
A dictionary containing the prepared data and files dictionary.
175+
"""
176+
merged: dict[str, Any] = {}
177+
178+
if data is not None:
179+
merged = merged | data
180+
181+
if files is not None:
182+
for key, value in files.items():
183+
if isinstance(value, DioptraFile):
184+
merged[key] = format_file_for_request(value)
185+
186+
else:
187+
formatted_files: list[FileStorage] = []
188+
189+
try:
190+
for dioptra_file in value:
191+
if not isinstance(dioptra_file, DioptraFile):
192+
raise IllegalArgumentError(
193+
"Illegal type for files (reason: a list can only "
194+
f"contain the DioptraFile type): {type(dioptra_file)}."
195+
)
196+
197+
formatted_files.append(format_file_for_request(dioptra_file))
198+
199+
except TypeError as err:
200+
raise IllegalArgumentError(
201+
"Illegal type for files (reason: must be a DioptraFile or a "
202+
f"list of DioptraFile): {type(value)}."
203+
) from err
204+
205+
merged[key] = formatted_files
206+
207+
return merged
208+
209+
144210
class DioptraFlaskClientSession(DioptraSession[DioptraResponseProtocol]):
145211
"""
146212
The interface for communicating with the Dioptra API using the FlaskClient.
@@ -173,6 +239,8 @@ def make_request(
173239
url: str,
174240
params: dict[str, Any] | None = None,
175241
json_: dict[str, Any] | None = None,
242+
data: dict[str, Any] | None = None,
243+
files: dict[str, DioptraFile | list[DioptraFile]] | None = None,
176244
) -> DioptraResponseProtocol:
177245
"""Make a request to the API.
178246
@@ -183,6 +251,10 @@ def make_request(
183251
params: The query parameters to include in the request. Optional, defaults
184252
to None.
185253
json_: The JSON data to include in the request. Optional, defaults to None.
254+
data: A dictionary to send in the body of the request as part of a
255+
multipart form. Optional, defaults to None.
256+
files: Dictionary of "name": DioptraFile or lists of DioptraFile pairs to be
257+
uploaded. Optional, defaults to None.
186258
187259
Returns:
188260
The response from the API.
@@ -212,12 +284,42 @@ def make_request(
212284
method = methods_registry[method_name]
213285
method_kwargs: dict[str, Any] = {"follow_redirects": True}
214286

287+
if method_name != "post":
288+
if data:
289+
raise IllegalArgumentError(
290+
"Illegal value for data (reason: data is only supported for POST "
291+
f"requests): {data}."
292+
)
293+
294+
if files:
295+
raise IllegalArgumentError(
296+
"Illegal value for files (reason: files is only supported for POST "
297+
f"requests): {files}."
298+
)
299+
215300
if json_:
301+
if data:
302+
raise IllegalArgumentError(
303+
"Illegal value for json_ (reason: json_ is not supported if data "
304+
f"is not None): {json_}."
305+
)
306+
307+
if files:
308+
raise IllegalArgumentError(
309+
"Illegal value for json_ (reason: json_ is not supported if files "
310+
f"is not None): {json_}."
311+
)
312+
216313
method_kwargs["json"] = json_
217314

218315
if params:
219316
method_kwargs["query_string"] = params
220317

318+
if data or files:
319+
merged_data = prepare_data_and_files(data=data, files=files)
320+
method_kwargs["data"] = merged_data
321+
method_kwargs["content_type"] = "multipart/form-data"
322+
221323
return method(url, **method_kwargs)
222324

223325
def download(
@@ -329,6 +431,8 @@ def post(
329431
*parts,
330432
params: dict[str, Any] | None = None,
331433
json_: dict[str, Any] | None = None,
434+
data: dict[str, Any] | None = None,
435+
files: dict[str, DioptraFile | list[DioptraFile]] | None = None,
332436
) -> DioptraResponseProtocol:
333437
"""Make a POST request to the API.
334438
@@ -341,11 +445,17 @@ def post(
341445
params: The query parameters to include in the request. Optional, defaults
342446
to None.
343447
json_: The JSON data to include in the request. Optional, defaults to None.
448+
data: A dictionary to send in the body of the request as part of a
449+
multipart form. Optional, defaults to None.
450+
files: Dictionary of "name": DioptraFile or lists of DioptraFile pairs to be
451+
uploaded. Optional, defaults to None.
344452
345453
Returns:
346454
A DioptraTestResponse object.
347455
"""
348-
return self._post(endpoint, *parts, params=params, json_=json_)
456+
return self._post(
457+
endpoint, *parts, params=params, json_=json_, data=data, files=files
458+
)
349459

350460
def delete(
351461
self,

0 commit comments

Comments
 (0)