99from pydantic .version import VERSION as PYDANTIC_VERSION
1010
1111from arrest .converters import compile_path
12- from arrest .defaults import HEADER_DEFAULTS , TIMEOUT_DEFAULT
13- from arrest .exceptions import ArrestError , ArrestHTTPException , HandlerNotFound
12+ from arrest .defaults import TIMEOUT_DEFAULT
13+ from arrest .exceptions import ArrestHTTPException , HandlerNotFound
1414from arrest .handler import HandlerKey , ResourceHandler
1515from arrest .http import Methods
1616from arrest .logging import logger
1717from arrest .params import Param , Params , ParamTypes
18- from arrest .utils import extract_model_field , join_url , jsonify
18+ from arrest .utils import extract_model_field , join_url , jsonify , validate_request_model
1919
2020
2121class Resource :
@@ -43,7 +43,7 @@ def __init__(
4343 name : Optional [str ] = None ,
4444 * ,
4545 route : Optional [str ],
46- headers : Optional [dict ] = HEADER_DEFAULTS ,
46+ headers : Optional [dict ] = None ,
4747 timeout : Optional [int ] = TIMEOUT_DEFAULT ,
4848 response_model : Optional [Type [BaseModel ]] = None ,
4949 handlers : Union [
@@ -133,11 +133,9 @@ async def request(
133133
134134 params : dict = {}
135135
136- if not (
137- match := self .get_matching_handler (method = method , path = path , ** kwargs )
138- ):
136+ if not (match := self .get_matching_handler (method = method , path = path , ** kwargs )):
139137 logger .warning ("no matching handler found for request" )
140- raise HandlerNotFound ("no matching handler found for request" )
138+ raise HandlerNotFound (message = "no matching handler found for request" )
141139
142140 handler , url = match
143141
@@ -161,9 +159,9 @@ async def request(
161159 callback_response = await handler .callback (response )
162160 else :
163161 callback_response = handler .callback (response )
164- except Exception as exc :
162+ except Exception :
165163 logger .warning ("something went wrong during callback" , exc_info = True )
166- raise ArrestError ( str ( exc )) from exc
164+ raise
167165 return callback_response
168166
169167 return response
@@ -344,27 +342,24 @@ def extract_request_params(
344342 """
345343
346344 header_params = headers or {}
347- header_params |= self .headers
345+ if self .headers :
346+ header_params |= self .headers
348347 query_params = query or {}
349348 body_params = {}
350349
351350 if request_type :
352351 # perform type validation on `request_data`
353- request_data = request_type . model_validate ( request_data )
352+ request_data = validate_request_model ( type_ = request_type , obj = request_data )
354353
355354 if isinstance (request_data , BaseModel ):
356355 # extract pydantic fields into `Query`, `Body` and `Header`
357356 model_fields : dict = (
358- request_data .__fields__
359- if PYDANTIC_VERSION .startswith ("2." )
360- else request_data .model_fields
357+ request_data .__fields__ if PYDANTIC_VERSION .startswith ("2." ) else request_data .model_fields
361358 )
362359
363360 for field , field_info in model_fields .items ():
364361 field_info = cast (Param , field_info )
365- if not hasattr (field_info , "_param_type" ) and isinstance (
366- field_info , FieldInfo
367- ):
362+ if not hasattr (field_info , "_param_type" ) and isinstance (field_info , FieldInfo ):
368363 body_params |= extract_model_field (request_data , field )
369364 elif field_info ._param_type == ParamTypes .query :
370365 query_params |= extract_model_field (request_data , field )
@@ -416,9 +411,7 @@ async def __make_request(
416411 params .body ,
417412 )
418413 try :
419- async with httpx .AsyncClient (
420- timeout = self .timeout , headers = headers
421- ) as client :
414+ async with httpx .AsyncClient (timeout = self .timeout , headers = headers ) as client :
422415 match method :
423416 case Methods .GET :
424417 response = await client .get (url = url , params = query_params )
@@ -452,19 +445,15 @@ async def __make_request(
452445 response = await client .options (url = url , params = query_params )
453446
454447 status_code = response .status_code
455- logger .debug (
456- f"{ method !s} { url } returned with status code { status_code !s} "
457- )
448+ logger .debug (f"{ method !s} { url } returned with status code { status_code !s} " )
458449 response .raise_for_status ()
459450 response_body = response .json ()
460451
461452 # parse response to pydantic model
462453 parsed_response = response_body
463454 if response_type :
464455 if isinstance (response_body , list ):
465- parsed_response = [
466- response_type (** item ) for item in response_body
467- ]
456+ parsed_response = [response_type (** item ) for item in response_body ]
468457 elif isinstance (response_body , dict ):
469458 parsed_response = response_type (** response_body )
470459 else :
@@ -477,9 +466,7 @@ async def __make_request(
477466 # exception handling
478467 except httpx .HTTPStatusError as exc :
479468 err_response_body = exc .response .json ()
480- raise ArrestHTTPException (
481- status_code = exc .response .status_code , data = err_response_body
482- ) from exc
469+ raise ArrestHTTPException (status_code = exc .response .status_code , data = err_response_body ) from exc
483470
484471 except httpx .TimeoutException :
485472 raise ArrestHTTPException (
@@ -502,9 +489,7 @@ def get_matching_handler(
502489 url = join_url (self .base_url , self .route , parsed_path )
503490 return handler , url
504491
505- def _bind_handler (
506- self , base_url : str | None = None , * , handler : ResourceHandler
507- ) -> None :
492+ def _bind_handler (self , base_url : str | None = None , * , handler : ResourceHandler ) -> None :
508493 """
509494 compose a fully-qualified url by joining base service url, resource url
510495 and handler url,
@@ -514,18 +499,14 @@ def _bind_handler(
514499 """
515500
516501 base_url = base_url or self .base_url
517- handler .path_regex , handler .path_format , handler .param_types = compile_path (
518- handler .route
519- )
502+ handler .path_regex , handler .path_format , handler .param_types = compile_path (handler .route )
520503
521504 self .routes [HandlerKey (* (handler .method , handler .path_format ))] = handler
522505
523506 def initialize_handlers (
524507 self ,
525508 base_url : str | None = None ,
526- handlers : list [ResourceHandler ]
527- | list [Mapping [str , Any ]]
528- | list [tuple [Any , ...]] = None ,
509+ handlers : list [ResourceHandler ] | list [Mapping [str , Any ]] | list [tuple [Any , ...]] = None ,
529510 ) -> None :
530511 """
531512 specifically used to inject `base_url` from a Service class to
@@ -541,18 +522,12 @@ def initialize_handlers(
541522 for _handler in handlers :
542523 try :
543524 if isinstance (_handler , dict ):
544- self ._bind_handler (
545- base_url = base_url , handler = ResourceHandler (** _handler )
546- )
525+ self ._bind_handler (base_url = base_url , handler = ResourceHandler (** _handler ))
547526 elif isinstance (_handler , tuple ):
548527 if len (_handler ) < 2 :
549- raise ValueError (
550- "Too few arguments to unpack. Expected atleast 2"
551- )
528+ raise ValueError ("Too few arguments to unpack. Expected atleast 2" )
552529 if len (_handler ) > 5 :
553- raise ValueError (
554- f"Too many arguments to unpack. Expected 5, got { len (_handler )} "
555- )
530+ raise ValueError (f"Too many arguments to unpack. Expected 5, got { len (_handler )} " )
556531
557532 method , route , rest = (
558533 _handler [0 ],
0 commit comments