|
1 | | -from typing import Any, cast |
| 1 | +from typing import Any, List, cast |
2 | 2 |
|
3 | 3 | from agentex import AsyncAgentex |
4 | 4 | from agentex.lib.core.tracing.tracer import AsyncTracer |
@@ -78,7 +78,7 @@ async def message_send( |
78 | 78 | task_name: str | None = None, |
79 | 79 | trace_id: str | None = None, |
80 | 80 | parent_span_id: str | None = None, |
81 | | - ) -> TaskMessage: |
| 81 | + ) -> List[TaskMessage]: |
82 | 82 | trace = self._tracer.trace(trace_id=trace_id) |
83 | 83 | async with trace.span( |
84 | 84 | parent_id=parent_span_id, |
@@ -115,10 +115,17 @@ async def message_send( |
115 | 115 | else: |
116 | 116 | raise ValueError("Either agent_name or agent_id must be provided") |
117 | 117 |
|
118 | | - task_message = TaskMessage.model_validate(json_rpc_response.result) |
| 118 | + task_messages: List[TaskMessage] = [] |
| 119 | + if isinstance(json_rpc_response.result, list): |
| 120 | + for message in json_rpc_response.result: |
| 121 | + task_message = TaskMessage.model_validate(message) |
| 122 | + task_messages.append(task_message) |
| 123 | + else: |
| 124 | + task_messages = [TaskMessage.model_validate(json_rpc_response.result)] |
| 125 | + |
119 | 126 | if span: |
120 | | - span.output = task_message.model_dump() |
121 | | - return task_message |
| 127 | + span.output = [task_message.model_dump() for task_message in task_messages] |
| 128 | + return task_messages |
122 | 129 |
|
123 | 130 | async def event_send( |
124 | 131 | self, |
|
0 commit comments