5
5
from typing import Any , Optional , ClassVar , Union , Dict
6
6
7
7
import aiohttp
8
- from aiohttp import ClientResponse , FormData , TCPConnector
8
+ from aiohttp import ClientResponse , FormData , TCPConnector , multipart , hdrs , payload
9
9
10
10
from . import logging
11
11
from .errors import HttpErrorDict , ServerError
20
20
HTTP_OK_STATUS = [200 , 202 , 204 ]
21
21
22
22
23
+ class _FormData (FormData ):
24
+ def _gen_form_data (self ) -> multipart .MultipartWriter :
25
+ """Encode a list of fields using the multipart/form-data MIME format"""
26
+ if self ._is_processed :
27
+ return self ._writer # rewrite this part of FormData object to enable retry of request
28
+ for dispparams , headers , value in self ._fields :
29
+ try :
30
+ if hdrs .CONTENT_TYPE in headers :
31
+ part = payload .get_payload (
32
+ value ,
33
+ content_type = headers [hdrs .CONTENT_TYPE ],
34
+ headers = headers ,
35
+ encoding = self ._charset ,
36
+ )
37
+ else :
38
+ part = payload .get_payload (
39
+ value , headers = headers , encoding = self ._charset
40
+ )
41
+ except Exception as exc :
42
+ print (value )
43
+ raise TypeError (
44
+ "Can not serialize value type: %r\n "
45
+ "headers: %r\n value: %r" % (type (value ), headers , value )
46
+ ) from exc
47
+
48
+ if dispparams :
49
+ part .set_content_disposition (
50
+ "form-data" , quote_fields = self ._quote_fields , ** dispparams
51
+ )
52
+ assert part .headers is not None
53
+ part .headers .popall (hdrs .CONTENT_LENGTH , None )
54
+
55
+ self ._writer .append_payload (part )
56
+
57
+ self ._is_processed = True
58
+ return self ._writer
59
+
60
+
23
61
async def _handle_response (response : ClientResponse ) -> Union [Dict [str , Any ], str ]:
24
62
url = response .request_info .url
25
63
try :
@@ -41,8 +79,8 @@ async def _handle_response(response: ClientResponse) -> Union[Dict[str, Any], st
41
79
# type of data should be dict or str or None, so there should be a condition to check and prevent bug
42
80
message = data ["message" ] if isinstance (data , dict ) else str (data )
43
81
if not error_dict_get :
44
- raise ServerError (message )
45
- raise error_dict_get (msg = message )
82
+ raise ServerError (message ) from None # adding from None to prevent chain exception being raised
83
+ raise error_dict_get (msg = message ) from None
46
84
47
85
48
86
class Route :
@@ -91,8 +129,13 @@ def __init__(
91
129
self ._global_over : Optional [asyncio .Event ] = None
92
130
self ._headers : Optional [dict ] = None
93
131
132
+ def __del__ (self ):
133
+ if self ._session and not self ._session .closed :
134
+ _loop = asyncio .get_event_loop ()
135
+ _loop .create_task (self ._session .close ())
136
+
94
137
async def close (self ) -> None :
95
- if self ._session :
138
+ if self ._session and not self . _session . closed :
96
139
await self ._session .close ()
97
140
98
141
async def check_session (self ):
@@ -104,7 +147,7 @@ async def check_session(self):
104
147
105
148
if not self ._session or self ._session .closed :
106
149
self ._session = aiohttp .ClientSession (
107
- headers = self ._headers , connector = TCPConnector (limit = 500 , ssl = SSLContext ())
150
+ headers = self ._headers , connector = TCPConnector (limit = 500 , ssl = SSLContext (), force_close = True )
108
151
)
109
152
110
153
async def request (self , route : Route , retry_time : int = 0 , ** kwargs : Any ):
@@ -115,7 +158,7 @@ async def request(self, route: Route, retry_time: int = 0, **kwargs: Any):
115
158
json_ = kwargs ["json" ]
116
159
json__get = json_ .get ("file_image" )
117
160
if json__get and isinstance (json__get , bytes ):
118
- kwargs ["data" ] = FormData ()
161
+ kwargs ["data" ] = _FormData ()
119
162
for k , v in kwargs .pop ("json" ).items ():
120
163
if v :
121
164
if isinstance (v , dict ):
@@ -142,10 +185,9 @@ async def request(self, route: Route, retry_time: int = 0, **kwargs: Any):
142
185
_log .debug (response )
143
186
return await _handle_response (response )
144
187
except asyncio .TimeoutError :
145
- _log .debug ("session timeout retry" )
146
- self ._session = aiohttp .ClientSession (
147
- headers = self ._headers , connector = TCPConnector (limit = 500 , ssl = SSLContext ())
148
- )
188
+ _log .warning (f"请求超时,请求连接: { route .url } " )
189
+ except ConnectionResetError :
190
+ _log .debug ("session connection broken retry" )
149
191
await self .request (route , retry_time + 1 , ** kwargs )
150
192
151
193
async def login (self , token : Token ) -> robot .Robot :
0 commit comments