Skip to content

Commit 7f7c3dc

Browse files
committed
feat(a2a): added ability to resubscribe the task stream
1 parent a81e118 commit 7f7c3dc

File tree

3 files changed

+105
-42
lines changed

3 files changed

+105
-42
lines changed

AgentCrew/modules/a2a/common/client/client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
TaskPushNotificationConfig,
2020
SendStreamingMessageResponse,
2121
SetTaskPushNotificationConfigResponse,
22+
TaskResubscriptionRequest,
23+
TaskIdParams,
2224
)
2325

2426
if TYPE_CHECKING:
@@ -33,7 +35,6 @@
3335
SendMessageRequest,
3436
SendStreamingMessageRequest,
3537
SetTaskPushNotificationConfigRequest,
36-
TaskIdParams,
3738
TaskQueryParams,
3839
)
3940

@@ -131,3 +132,27 @@ async def get_task_push_notification_config(
131132
)
132133
result = await self._send_request(A2ARequest(root=request))
133134
return GetTaskPushNotificationConfigResponse.model_validate(result)
135+
136+
async def resubscribe_to_task(
137+
self, task_id: str
138+
) -> AsyncGenerator[SendStreamingMessageResponse]:
139+
request = TaskResubscriptionRequest(
140+
id=str(uuid4()), params=TaskIdParams(id=task_id)
141+
)
142+
request_headers = {"Content-Type": "application/json", **self.headers}
143+
144+
async with httpx.AsyncClient(timeout=None) as client:
145+
async with aconnect_sse(
146+
client,
147+
"POST",
148+
self.url,
149+
json=request.model_dump(),
150+
headers=request_headers,
151+
) as event_source:
152+
try:
153+
async for sse in event_source.aiter_sse():
154+
yield SendStreamingMessageResponse.model_validate(
155+
json.loads(sse.data)
156+
)
157+
except json.JSONDecodeError as e:
158+
raise httpx.DecodingError(str(e)) from e

AgentCrew/modules/a2a/task_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ async def on_send_message_streaming(
257257
while True:
258258
event = await queue.get()
259259
if event is None: # End of stream
260+
await self.store.delete_task(task_id)
260261
break
261262
yield SendStreamingMessageResponse(
262263
root=SendStreamingMessageSuccessResponse(
@@ -265,7 +266,6 @@ async def on_send_message_streaming(
265266
)
266267

267268
finally:
268-
await self.store.delete_task(task_id)
269269
self.streaming_tasks.pop(task_id, None)
270270

271271
def _create_ask_tool_message(
@@ -748,6 +748,7 @@ async def on_resubscribe_to_task(
748748
while True:
749749
event = await queue.get()
750750
if event is None:
751+
await self.store.delete_task(task_id)
751752
break
752753

753754
yield SendStreamingMessageResponse(

AgentCrew/modules/agents/remote_agent.py

Lines changed: 77 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from typing import Dict, TYPE_CHECKING
45
from uuid import uuid4
56

7+
import httpx
8+
from loguru import logger
69
from pydantic import ValidationError
710
from AgentCrew.modules.a2a.adapters import (
811
convert_agent_message_to_a2a,
@@ -130,51 +133,85 @@ async def process_messages(
130133
a2a_payload = MessageSendParams(
131134
metadata={"id": str(uuid4())},
132135
message=a2a_message,
133-
# acceptedOutputModes can be set here if needed, e.g., based on agent_card.defaultOutputModes
134-
# For now, relying on server defaults or agent's capability.
135136
)
136137

137138
full_response_text = ""
138-
139-
async for stream_response in self.client.send_message_streaming(a2a_payload):
140-
if isinstance(stream_response.root, JSONRPCErrorResponse):
141-
raise Exception(
142-
f"Remote agent stream error: {stream_response.root.error.code} - {stream_response.root.error.message}"
143-
)
144-
145-
if stream_response.root.result:
146-
event = stream_response.root.result
147-
current_content_chunk_text = ""
148-
current_thinking_chunk_text = ""
149-
150-
if isinstance(event, TaskArtifactUpdateEvent):
151-
self.current_task_id = event.task_id
152-
for part in event.artifact.parts:
153-
if isinstance(part.root, TextPart):
154-
current_content_chunk_text += part.root.text
155-
if current_content_chunk_text:
156-
full_response_text += current_content_chunk_text
157-
yield (
158-
full_response_text,
159-
current_content_chunk_text,
160-
None,
139+
max_retries = 3
140+
retry_count = 0
141+
is_resubscribe = False
142+
143+
while retry_count <= max_retries:
144+
try:
145+
if is_resubscribe and self.current_task_id:
146+
logger.info(
147+
f"Resubscribing to task {self.current_task_id} "
148+
f"(attempt {retry_count})"
149+
)
150+
stream = self.client.resubscribe_to_task(self.current_task_id)
151+
else:
152+
stream = self.client.send_message_streaming(a2a_payload)
153+
154+
async for stream_response in stream:
155+
if isinstance(stream_response.root, JSONRPCErrorResponse):
156+
raise Exception(
157+
f"Remote agent stream error: "
158+
f"{stream_response.root.error.code} - "
159+
f"{stream_response.root.error.message}"
161160
)
162161

163-
elif isinstance(event, TaskStatusUpdateEvent):
164-
self.current_task_id = event.task_id
165-
if event.status.message and event.status.message.parts:
166-
for part in event.status.message.parts:
167-
if isinstance(part.root, TextPart):
168-
current_content_chunk_text += part.root.text
169-
if current_thinking_chunk_text:
170-
yield (
171-
full_response_text,
172-
None,
173-
(current_thinking_chunk_text, None),
174-
)
175-
176-
# After the loop, the generator stops. If full_response_text is empty,
177-
# it signifies no textual content was streamed as artifacts.
162+
if stream_response.root.result:
163+
event = stream_response.root.result
164+
current_content_chunk_text = ""
165+
current_thinking_chunk_text = ""
166+
167+
if isinstance(event, TaskArtifactUpdateEvent):
168+
self.current_task_id = event.task_id
169+
for part in event.artifact.parts:
170+
if isinstance(part.root, TextPart):
171+
current_content_chunk_text += part.root.text
172+
if current_content_chunk_text:
173+
full_response_text += current_content_chunk_text
174+
yield (
175+
full_response_text,
176+
current_content_chunk_text,
177+
None,
178+
)
179+
180+
elif isinstance(event, TaskStatusUpdateEvent):
181+
self.current_task_id = event.task_id
182+
if event.status.message and event.status.message.parts:
183+
for part in event.status.message.parts:
184+
if isinstance(part.root, TextPart):
185+
current_content_chunk_text += part.root.text
186+
if current_thinking_chunk_text:
187+
yield (
188+
full_response_text,
189+
None,
190+
(current_thinking_chunk_text, None),
191+
)
192+
193+
break
194+
195+
except (
196+
httpx.ReadError,
197+
httpx.RemoteProtocolError,
198+
httpx.ReadTimeout,
199+
ConnectionError,
200+
httpx.ConnectError,
201+
) as e:
202+
retry_count += 1
203+
if retry_count > max_retries:
204+
logger.error(
205+
f"Failed to reconnect after {max_retries} attempts: {e}"
206+
)
207+
raise
208+
wait_time = min(2**retry_count, 30)
209+
logger.warning(
210+
f"Stream connection lost: {e}. "
211+
f"Retrying in {wait_time}s (attempt {retry_count}/{max_retries})"
212+
)
213+
await asyncio.sleep(wait_time)
214+
is_resubscribe = True
178215

179216
def get_process_result(self) -> Tuple:
180217
"""

0 commit comments

Comments
 (0)