11import inspect
22from collections import defaultdict
33from logging import getLogger
4- from typing import Any , Awaitable , Callable , Dict , Optional , Union
4+ from typing import Any , Awaitable , Callable , Dict , Optional , TypeVar , get_type_hints
55
6+ import pydantic
67from aiohttp import web
7- from pydantic import schema_of
8- from pydantic .utils import deep_update
8+ from deepmerge import always_merger
99from taskiq_dependencies import DependencyGraph
1010
1111from aiohttp_deps .initializer import InjectableFuncHandler , InjectableViewHandler
1212from aiohttp_deps .utils import Form , Header , Json , Path , Query
1313
14- REF_TEMPLATE = "#/components/schemas/{model}"
14+ _T = TypeVar ("_T" ) # noqa: WPS111
15+
1516SCHEMA_KEY = "openapi_schema"
1617SWAGGER_HTML_TEMPALTE = """
1718<html lang="en">
@@ -67,19 +68,14 @@ def _is_optional(annotation: Optional[inspect.Parameter]) -> bool:
6768 if annotation is None or annotation .annotation == annotation .empty :
6869 return True
6970
70- origin = getattr (annotation .annotation , "__origin__" , None )
71- if origin is None :
72- return False
71+ def dummy (_var : annotation .annotation ) -> None : # type: ignore
72+ """Dummy function to use for type resolution."""
7373
74- if origin == Union :
75- args = getattr (annotation .annotation , "__args__" , ())
76- for arg in args :
77- if arg is type (None ): # noqa: E721, WPS516
78- return True
79- return False
74+ var = get_type_hints (dummy ).get ("_var" )
75+ return var == Optional [var ]
8076
8177
82- def _add_route_def ( # noqa: C901
78+ def _add_route_def ( # noqa: C901, WPS210
8379 openapi_schema : Dict [str , Any ],
8480 route : web .ResourceRoute ,
8581 method : str ,
@@ -94,6 +90,19 @@ def _add_route_def( # noqa: C901
9490 if route .resource is None : # pragma: no cover
9591 return
9692
93+ params : Dict [tuple [str , str ], Any ] = {}
94+
95+ def _insert_in_params (data : Dict [str , Any ]) -> None :
96+ element = params .get ((data ["name" ], data ["in" ]))
97+ if element is None :
98+ params [(data ["name" ], data ["in" ])] = data
99+ return
100+ element ["required" ] = element .get ("required" ) or data .get ("required" )
101+ element ["allowEmptyValue" ] = bool (element .get ("allowEmptyValue" )) and bool (
102+ data .get ("allowEmptyValue" ),
103+ )
104+ params [(data ["name" ], data ["in" ])] = element
105+
97106 for dependency in graph .ordered_deps :
98107 if isinstance (dependency .dependency , (Json , Form )):
99108 content_type = "application/json"
@@ -103,10 +112,9 @@ def _add_route_def( # noqa: C901
103112 dependency .signature
104113 and dependency .signature .annotation != inspect .Parameter .empty
105114 ):
106- input_schema = schema_of (
115+ input_schema = pydantic . TypeAdapter (
107116 dependency .signature .annotation ,
108- ref_template = REF_TEMPLATE ,
109- )
117+ ).json_schema ()
110118 openapi_schema ["components" ]["schemas" ].update (
111119 input_schema .pop ("definitions" , {}),
112120 )
@@ -118,7 +126,7 @@ def _add_route_def( # noqa: C901
118126 "content" : {content_type : {}},
119127 }
120128 elif isinstance (dependency .dependency , Query ):
121- route_info [ "parameters" ]. append (
129+ _insert_in_params (
122130 {
123131 "name" : dependency .dependency .alias or dependency .param_name ,
124132 "in" : "query" ,
@@ -127,16 +135,17 @@ def _add_route_def( # noqa: C901
127135 },
128136 )
129137 elif isinstance (dependency .dependency , Header ):
130- route_info ["parameters" ].append (
138+ name = dependency .dependency .alias or dependency .param_name
139+ _insert_in_params (
131140 {
132- "name" : dependency . dependency . alias or dependency . param_name ,
141+ "name" : name . capitalize () ,
133142 "in" : "header" ,
134143 "description" : dependency .dependency .description ,
135144 "required" : not _is_optional (dependency .signature ),
136145 },
137146 )
138147 elif isinstance (dependency .dependency , Path ):
139- route_info [ "parameters" ]. append (
148+ _insert_in_params (
140149 {
141150 "name" : dependency .dependency .alias or dependency .param_name ,
142151 "in" : "path" ,
@@ -146,8 +155,9 @@ def _add_route_def( # noqa: C901
146155 },
147156 )
148157
158+ route_info ["parameters" ] = list (params .values ())
149159 openapi_schema ["paths" ][route .resource .canonical ].update (
150- {method .lower (): deep_update (route_info , extra_openapi )},
160+ {method .lower (): always_merger . merge (route_info , extra_openapi )},
151161 )
152162
153163
@@ -264,7 +274,7 @@ async def event_handler(app: web.Application) -> None:
264274 return event_handler
265275
266276
267- def extra_openapi (additional_schema : Dict [str , Any ]) -> Callable [..., Any ]:
277+ def extra_openapi (additional_schema : Dict [str , Any ]) -> Callable [[ _T ], _T ]:
268278 """
269279 Add extra openapi schema.
270280
@@ -275,8 +285,46 @@ def extra_openapi(additional_schema: Dict[str, Any]) -> Callable[..., Any]:
275285 :return: same function with new attributes.
276286 """
277287
278- def decorator (func : Any ) -> Any :
279- func .__extra_openapi__ = additional_schema
288+ def decorator (func : _T ) -> _T :
289+ func .__extra_openapi__ = additional_schema # type: ignore
290+ return func
291+
292+ return decorator
293+
294+
295+ def openapi_response (
296+ status : int ,
297+ model : Any ,
298+ * ,
299+ content_type : str = "application/json" ,
300+ description : Optional [str ] = None ,
301+ ) -> Callable [[_T ], _T ]:
302+ """
303+ Add response schema to the endpoint.
304+
305+ This function takes a status and model,
306+ which is going to represent the response.
307+
308+ :param status: Status of a response.
309+ :param model: Response model.
310+ :param content_type: Content-type of a response.
311+ :param description: Response's description.
312+
313+ :returns: decorator that modifies your function.
314+ """
315+
316+ def decorator (func : _T ) -> _T :
317+ openapi = getattr (func , "__extra_openapi__" , {})
318+ adapter : "pydantic.TypeAdapter[Any]" = pydantic .TypeAdapter (model )
319+ responses = openapi .get ("responses" , {})
320+ status_response = responses .get (status , {})
321+ if not status_response :
322+ status_response ["description" ] = description
323+ status_response ["content" ] = status_response .get ("content" , {})
324+ status_response ["content" ][content_type ] = {"schema" : adapter .json_schema ()}
325+ responses [status ] = status_response
326+ openapi ["responses" ] = responses
327+ func .__extra_openapi__ = openapi # type: ignore
280328 return func
281329
282330 return decorator
0 commit comments