|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import typing |
4 | | -from typing import Any, List, Dict |
| 4 | +from typing import Any |
5 | 5 | import base64 |
6 | 6 | import traceback |
7 | 7 |
|
|
26 | 26 | from ..utils import constants |
27 | 27 |
|
28 | 28 |
|
29 | | -def _deserialize_message_chain(data: List[Dict[str, Any]]) -> platform_message.MessageChain: |
30 | | - """Deserialize message chain with proper handling of Forward messages. |
31 | | -
|
32 | | - The default MessageChain.model_validate doesn't properly deserialize nested |
33 | | - MessageChain in ForwardMessageNode, causing data loss. This function handles |
34 | | - that case explicitly. |
35 | | - """ |
36 | | - components = [] |
37 | | - component_types = platform_message.MessageChain._get_component_types() |
38 | | - |
39 | | - for item in data: |
40 | | - if not isinstance(item, dict) or 'type' not in item: |
41 | | - components.append(platform_message.Unknown(text=f'Invalid component: {item}')) |
42 | | - continue |
43 | | - |
44 | | - comp_type = item['type'] |
45 | | - if comp_type not in component_types: |
46 | | - components.append(platform_message.Unknown(text=f'Unknown type: {comp_type}')) |
47 | | - continue |
48 | | - |
49 | | - comp_class = component_types[comp_type] |
50 | | - |
51 | | - # Special handling for Forward messages |
52 | | - if comp_type == 'Forward': |
53 | | - node_list = [] |
54 | | - for node_data in item.get('node_list', []): |
55 | | - # Recursively deserialize message_chain in each node |
56 | | - mc_data = node_data.get('message_chain', []) |
57 | | - if isinstance(mc_data, list): |
58 | | - mc = _deserialize_message_chain(mc_data) |
59 | | - elif isinstance(mc_data, platform_message.MessageChain): |
60 | | - mc = mc_data |
61 | | - else: |
62 | | - mc = platform_message.MessageChain([]) |
63 | | - |
64 | | - node = platform_message.ForwardMessageNode( |
65 | | - sender_id=node_data.get('sender_id', ''), |
66 | | - sender_name=node_data.get('sender_name', ''), |
67 | | - message_chain=mc, |
68 | | - message_id=node_data.get('message_id', 0), |
69 | | - ) |
70 | | - node_list.append(node) |
71 | | - |
72 | | - display_data = item.get('display', {}) |
73 | | - display = platform_message.ForwardMessageDiaplay( |
74 | | - title=display_data.get('title', 'Chat history'), |
75 | | - brief=display_data.get('brief', '[Chat history]'), |
76 | | - source=display_data.get('source', 'Chat history'), |
77 | | - preview=display_data.get('preview', []), |
78 | | - summary=display_data.get('summary', 'View forwarded messages'), |
79 | | - ) |
80 | | - |
81 | | - forward = platform_message.Forward( |
82 | | - display=display, |
83 | | - node_list=node_list, |
84 | | - ) |
85 | | - components.append(forward) |
86 | | - else: |
87 | | - # For other component types, use default validation |
88 | | - # but handle Quote's nested MessageChain |
89 | | - if comp_type == 'Quote' and 'origin' in item: |
90 | | - origin_data = item['origin'] |
91 | | - if isinstance(origin_data, list): |
92 | | - item['origin'] = _deserialize_message_chain(origin_data) |
93 | | - |
94 | | - components.append(comp_class.model_validate(item)) |
95 | | - |
96 | | - return platform_message.MessageChain(root=components) |
97 | | - |
98 | | - |
99 | 29 | class RuntimeConnectionHandler(handler.Handler): |
100 | 30 | """Runtime connection handler""" |
101 | 31 |
|
@@ -350,7 +280,7 @@ async def send_message(data: dict[str, Any]) -> handler.ActionResponse: |
350 | 280 | message_chain = data['message_chain'] |
351 | 281 |
|
352 | 282 | # Use custom deserializer that properly handles Forward messages |
353 | | - message_chain_obj = _deserialize_message_chain(message_chain) |
| 283 | + message_chain_obj = platform_message.MessageChain.model_validate(message_chain) |
354 | 284 |
|
355 | 285 | bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid) |
356 | 286 | if bot is None: |
|
0 commit comments