|
7 | 7 |
|
8 | 8 | import logging
|
9 | 9 | import time
|
10 |
| -from typing import Any, Protocol, Union |
| 10 | +from typing import Any, Callable, Union |
11 | 11 |
|
12 | 12 | from opentelemetry import trace as trace_api
|
13 | 13 |
|
|
21 | 21 | logger = logging.getLogger(__name__)
|
22 | 22 |
|
23 | 23 |
|
24 |
| -class FunctionNodeCallable(Protocol): |
25 |
| - """Protocol defining the required signature for functions used in FunctionNode. |
26 |
| -
|
27 |
| - Functions must accept: |
28 |
| - - task: The input task (string or ContentBlock list) |
29 |
| - - invocation_state: Additional state/context from the calling environment |
30 |
| - - **kwargs: Additional keyword arguments for future extensibility |
31 |
| -
|
32 |
| - Functions must return: |
33 |
| - - A string result that will be converted to a Message |
34 |
| - """ |
35 |
| - |
36 |
| - def __call__( |
37 |
| - self, task: Union[str, list[ContentBlock]], invocation_state: dict[str, Any] | None = None, **kwargs: Any |
38 |
| - ) -> str: |
39 |
| - """Execute the node with the given task.""" |
40 |
| - ... |
| 24 | +FunctionNodeCallable = Callable[[Union[str, list[ContentBlock]], dict[str, Any] | None], str] |
41 | 25 |
|
42 | 26 |
|
43 | 27 | class FunctionNode(MultiAgentBase):
|
@@ -82,16 +66,19 @@ async def invoke_async(
|
82 | 66 | logger.debug("task=<%s> | starting function node execution", task)
|
83 | 67 | logger.debug("function_name=<%s> | executing function", self.name)
|
84 | 68 |
|
85 |
| - start_time = time.time() |
86 | 69 | span = self.tracer.start_multiagent_span(task, "function_node")
|
87 | 70 | with trace_api.use_span(span, end_on_exit=True):
|
88 | 71 | try:
|
| 72 | + start_time = time.time() |
89 | 73 | # Execute the wrapped function with proper parameters
|
90 | 74 | function_result = self.func(task, invocation_state, **kwargs)
|
91 |
| - logger.debug("function_result=<%s> | function executed successfully", function_result) |
92 |
| - |
93 | 75 | # Calculate execution time
|
94 | 76 | execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds
|
| 77 | + logger.debug( |
| 78 | + "function_result=<%s>, execution_time=<%dms> | function executed successfully", |
| 79 | + function_result, |
| 80 | + execution_time, |
| 81 | + ) |
95 | 82 |
|
96 | 83 | # Convert function result to Message
|
97 | 84 | message = Message(role="assistant", content=[ContentBlock(text=str(function_result))])
|
|
0 commit comments