2121import structlog
2222from flask .testing import FlaskClient
2323from structlog .stdlib import BoundLogger
24+ from werkzeug .datastructures import FileStorage
2425from werkzeug .test import TestResponse
2526
2627from dioptra .client .base import (
2728 DioptraClientError ,
29+ DioptraFile ,
2830 DioptraRequestProtocol ,
2931 DioptraResponseProtocol ,
3032 DioptraSession ,
33+ IllegalArgumentError ,
3134 StatusCodeError ,
3235)
3336from dioptra .restapi .routes import V1_ROOT
@@ -141,6 +144,69 @@ def is_2xx(status_code: int) -> bool:
141144 return status_code >= HTTPStatus .OK and status_code < HTTPStatus .MULTIPLE_CHOICES
142145
143146
147+ def format_file_for_request (file_ : DioptraFile ) -> FileStorage :
148+ """Format the DioptraFile object into a FlaskClient-compatible data structure.
149+
150+ Returns:
151+ The file encoded as a Werkzeug FileStorage object.
152+ """
153+ if file_ .content_type is None :
154+ return FileStorage (stream = file_ .stream , filename = file_ .filename )
155+
156+ return FileStorage (
157+ stream = file_ .stream , filename = file_ .filename , content_type = file_ .content_type
158+ )
159+
160+
161+ def prepare_data_and_files (
162+ data : dict [str , Any ] | None ,
163+ files : dict [str , DioptraFile | list [DioptraFile ]] | None ,
164+ ) -> dict [str , Any ]:
165+ """Prepare the data and files for the request.
166+
167+ Args:
168+ data: A dictionary to send in the body of the request as part of a multipart
169+ form.
170+ files: Dictionary of "name": DioptraFile or lists of DioptraFile pairs to be
171+ uploaded.
172+
173+ Returns:
174+ A dictionary containing the prepared data and files dictionary.
175+ """
176+ merged : dict [str , Any ] = {}
177+
178+ if data is not None :
179+ merged = merged | data
180+
181+ if files is not None :
182+ for key , value in files .items ():
183+ if isinstance (value , DioptraFile ):
184+ merged [key ] = format_file_for_request (value )
185+
186+ else :
187+ formatted_files : list [FileStorage ] = []
188+
189+ try :
190+ for dioptra_file in value :
191+ if not isinstance (dioptra_file , DioptraFile ):
192+ raise IllegalArgumentError (
193+ "Illegal type for files (reason: a list can only "
194+ f"contain the DioptraFile type): { type (dioptra_file )} ."
195+ )
196+
197+ formatted_files .append (format_file_for_request (dioptra_file ))
198+
199+ except TypeError as err :
200+ raise IllegalArgumentError (
201+ "Illegal type for files (reason: must be a DioptraFile or a "
202+ f"list of DioptraFile): { type (value )} ."
203+ ) from err
204+
205+ merged [key ] = formatted_files
206+
207+ return merged
208+
209+
144210class DioptraFlaskClientSession (DioptraSession [DioptraResponseProtocol ]):
145211 """
146212 The interface for communicating with the Dioptra API using the FlaskClient.
@@ -173,6 +239,8 @@ def make_request(
173239 url : str ,
174240 params : dict [str , Any ] | None = None ,
175241 json_ : dict [str , Any ] | None = None ,
242+ data : dict [str , Any ] | None = None ,
243+ files : dict [str , DioptraFile | list [DioptraFile ]] | None = None ,
176244 ) -> DioptraResponseProtocol :
177245 """Make a request to the API.
178246
@@ -183,6 +251,10 @@ def make_request(
183251 params: The query parameters to include in the request. Optional, defaults
184252 to None.
185253 json_: The JSON data to include in the request. Optional, defaults to None.
254+ data: A dictionary to send in the body of the request as part of a
255+ multipart form. Optional, defaults to None.
256+ files: Dictionary of "name": DioptraFile or lists of DioptraFile pairs to be
257+ uploaded. Optional, defaults to None.
186258
187259 Returns:
188260 The response from the API.
@@ -212,12 +284,42 @@ def make_request(
212284 method = methods_registry [method_name ]
213285 method_kwargs : dict [str , Any ] = {"follow_redirects" : True }
214286
287+ if method_name != "post" :
288+ if data :
289+ raise IllegalArgumentError (
290+ "Illegal value for data (reason: data is only supported for POST "
291+ f"requests): { data } ."
292+ )
293+
294+ if files :
295+ raise IllegalArgumentError (
296+ "Illegal value for files (reason: files is only supported for POST "
297+ f"requests): { files } ."
298+ )
299+
215300 if json_ :
301+ if data :
302+ raise IllegalArgumentError (
303+ "Illegal value for json_ (reason: json_ is not supported if data "
304+ f"is not None): { json_ } ."
305+ )
306+
307+ if files :
308+ raise IllegalArgumentError (
309+ "Illegal value for json_ (reason: json_ is not supported if files "
310+ f"is not None): { json_ } ."
311+ )
312+
216313 method_kwargs ["json" ] = json_
217314
218315 if params :
219316 method_kwargs ["query_string" ] = params
220317
318+ if data or files :
319+ merged_data = prepare_data_and_files (data = data , files = files )
320+ method_kwargs ["data" ] = merged_data
321+ method_kwargs ["content_type" ] = "multipart/form-data"
322+
221323 return method (url , ** method_kwargs )
222324
223325 def download (
@@ -329,6 +431,8 @@ def post(
329431 * parts ,
330432 params : dict [str , Any ] | None = None ,
331433 json_ : dict [str , Any ] | None = None ,
434+ data : dict [str , Any ] | None = None ,
435+ files : dict [str , DioptraFile | list [DioptraFile ]] | None = None ,
332436 ) -> DioptraResponseProtocol :
333437 """Make a POST request to the API.
334438
@@ -341,11 +445,17 @@ def post(
341445 params: The query parameters to include in the request. Optional, defaults
342446 to None.
343447 json_: The JSON data to include in the request. Optional, defaults to None.
448+ data: A dictionary to send in the body of the request as part of a
449+ multipart form. Optional, defaults to None.
450+ files: Dictionary of "name": DioptraFile or lists of DioptraFile pairs to be
451+ uploaded. Optional, defaults to None.
344452
345453 Returns:
346454 A DioptraTestResponse object.
347455 """
348- return self ._post (endpoint , * parts , params = params , json_ = json_ )
456+ return self ._post (
457+ endpoint , * parts , params = params , json_ = json_ , data = data , files = files
458+ )
349459
350460 def delete (
351461 self ,
0 commit comments