Skip to content

Commit 3a5745b

Browse files
GLGDLYSaucePlum
andauthored
fix & feat: 修复重试本地图片报错&避免unclosed client session&增加私信图片attachments (#116)
* fix: 重写FormData的_gen_form_data方法,以避免重试时由于已process过一次而导致的Form data has been processed already错误 * feat: 为BotHttp增加__del__方法避免出现unclosed client session的问题 * feat: ...补充 * fix: 修复RuntimeWarning问题 * fix: 改善already running的问题 * fix: 进一步改善unclosed client session的问题 * fix: http重新设置client session时调用close * fix: 优化Cannot write to closing transport * feat: 私信加入attachments字段 * fix: http去除超时重试,避免出现多次请求(特指因为重试而连续发送相同消息)的问题 fix: http的_handle_response中raise加入from None避免出现重试下的连锁错误(During handling of the above exception, another exception occurred) Co-authored-by: 小念同学 <[email protected]>
1 parent 2cc718f commit 3a5745b

File tree

4 files changed

+70
-12
lines changed

4 files changed

+70
-12
lines changed

botpy/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def ws_dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
246246
247247
解析client类的on_event事件,进行对应的事件回调
248248
"""
249-
_log.info("[botpy] 调度事件: %s", event)
249+
_log.debug("[botpy] 调度事件: %s", event)
250250
method = "on_" + event
251251

252252
try:
@@ -265,7 +265,7 @@ def _schedule_event(
265265
) -> asyncio.Task:
266266
wrapped = self._run_event(coro, event_name, *args, **kwargs)
267267
# Schedules the task
268-
return self.loop.create_task(wrapped, name=f"botpy: {event_name}")
268+
return self.loop.create_task(wrapped, name=f"[botpy] {event_name}")
269269

270270
async def _run_event(
271271
self,

botpy/http.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Any, Optional, ClassVar, Union, Dict
66

77
import aiohttp
8-
from aiohttp import ClientResponse, FormData, TCPConnector
8+
from aiohttp import ClientResponse, FormData, TCPConnector, multipart, hdrs, payload
99

1010
from . import logging
1111
from .errors import HttpErrorDict, ServerError
@@ -20,6 +20,44 @@
2020
HTTP_OK_STATUS = [200, 202, 204]
2121

2222

23+
class _FormData(FormData):
24+
def _gen_form_data(self) -> multipart.MultipartWriter:
25+
"""Encode a list of fields using the multipart/form-data MIME format"""
26+
if self._is_processed:
27+
return self._writer # rewrite this part of FormData object to enable retry of request
28+
for dispparams, headers, value in self._fields:
29+
try:
30+
if hdrs.CONTENT_TYPE in headers:
31+
part = payload.get_payload(
32+
value,
33+
content_type=headers[hdrs.CONTENT_TYPE],
34+
headers=headers,
35+
encoding=self._charset,
36+
)
37+
else:
38+
part = payload.get_payload(
39+
value, headers=headers, encoding=self._charset
40+
)
41+
except Exception as exc:
42+
print(value)
43+
raise TypeError(
44+
"Can not serialize value type: %r\n "
45+
"headers: %r\n value: %r" % (type(value), headers, value)
46+
) from exc
47+
48+
if dispparams:
49+
part.set_content_disposition(
50+
"form-data", quote_fields=self._quote_fields, **dispparams
51+
)
52+
assert part.headers is not None
53+
part.headers.popall(hdrs.CONTENT_LENGTH, None)
54+
55+
self._writer.append_payload(part)
56+
57+
self._is_processed = True
58+
return self._writer
59+
60+
2361
async def _handle_response(response: ClientResponse) -> Union[Dict[str, Any], str]:
2462
url = response.request_info.url
2563
try:
@@ -41,8 +79,8 @@ async def _handle_response(response: ClientResponse) -> Union[Dict[str, Any], st
4179
# type of data should be dict or str or None, so there should be a condition to check and prevent bug
4280
message = data["message"] if isinstance(data, dict) else str(data)
4381
if not error_dict_get:
44-
raise ServerError(message)
45-
raise error_dict_get(msg=message)
82+
raise ServerError(message) from None # adding from None to prevent chain exception being raised
83+
raise error_dict_get(msg=message) from None
4684

4785

4886
class Route:
@@ -91,8 +129,13 @@ def __init__(
91129
self._global_over: Optional[asyncio.Event] = None
92130
self._headers: Optional[dict] = None
93131

132+
def __del__(self):
133+
if self._session and not self._session.closed:
134+
_loop = asyncio.get_event_loop()
135+
_loop.create_task(self._session.close())
136+
94137
async def close(self) -> None:
95-
if self._session:
138+
if self._session and not self._session.closed:
96139
await self._session.close()
97140

98141
async def check_session(self):
@@ -104,7 +147,7 @@ async def check_session(self):
104147

105148
if not self._session or self._session.closed:
106149
self._session = aiohttp.ClientSession(
107-
headers=self._headers, connector=TCPConnector(limit=500, ssl=SSLContext())
150+
headers=self._headers, connector=TCPConnector(limit=500, ssl=SSLContext(), force_close=True)
108151
)
109152

110153
async def request(self, route: Route, retry_time: int = 0, **kwargs: Any):
@@ -115,7 +158,7 @@ async def request(self, route: Route, retry_time: int = 0, **kwargs: Any):
115158
json_ = kwargs["json"]
116159
json__get = json_.get("file_image")
117160
if json__get and isinstance(json__get, bytes):
118-
kwargs["data"] = FormData()
161+
kwargs["data"] = _FormData()
119162
for k, v in kwargs.pop("json").items():
120163
if v:
121164
if isinstance(v, dict):
@@ -142,10 +185,9 @@ async def request(self, route: Route, retry_time: int = 0, **kwargs: Any):
142185
_log.debug(response)
143186
return await _handle_response(response)
144187
except asyncio.TimeoutError:
145-
_log.debug("session timeout retry")
146-
self._session = aiohttp.ClientSession(
147-
headers=self._headers, connector=TCPConnector(limit=500, ssl=SSLContext())
148-
)
188+
_log.warning(f"请求超时,请求连接: {route.url}")
189+
except ConnectionResetError:
190+
_log.debug("session connection broken retry")
149191
await self.request(route, retry_time + 1, **kwargs)
150192

151193
async def login(self, token: Token) -> robot.Robot:

botpy/message.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class DirectMessage:
9595
"guild_id",
9696
"member",
9797
"message_reference",
98+
"attachments",
9899
"seq",
99100
"seq_in_channel",
100101
"src_guild_id",
@@ -113,6 +114,7 @@ def __init__(self, api: BotAPI, event_id, data: gateway.DirectMessagePayload):
113114
self.guild_id = data.get("guild_id", None)
114115
self.member = self._Member(data.get("member", {}))
115116
self.message_reference = self._MessageRef(data.get("message_reference", {}))
117+
self.attachments = [self._Attachments(items) for items in data.get("attachments", {})]
116118
self.seq = data.get("seq", None) # 全局消息序号
117119
self.seq_in_channel = data.get("seq_in_channel", None) # 子频道消息序号
118120
self.src_guild_id = data.get("src_guild_id", None)
@@ -145,6 +147,19 @@ def __init__(self, data):
145147
def __repr__(self):
146148
return str(self.__dict__)
147149

150+
class _Attachments:
151+
def __init__(self, data):
152+
self.content_type = data.get("content_type", None)
153+
self.filename = data.get("filename", None)
154+
self.height = data.get("height", None)
155+
self.width = data.get("width", None)
156+
self.id = data.get("id", None)
157+
self.size = data.get("size", None)
158+
self.url = data.get("url", None)
159+
160+
def __repr__(self):
161+
return str(self.__dict__)
162+
148163
async def reply(self, **kwargs):
149164
return await self._api.post_dms(guild_id=self.guild_id, msg_id=self.id, **kwargs)
150165

botpy/types/gateway.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class DirectMessagePayload(TypedDict):
6464
id: str
6565
member: Member
6666
message_reference: MessageRefPayload
67+
attachments: List[MessageAttachPayload]
6768
seq: int
6869
seq_in_channel: str
6970
src_guild_id: str

0 commit comments

Comments
 (0)