Skip to content

Commit 241a0a9

Browse files
committed
Unit test summary, explicitly handle some edge cases
1 parent d6afb78 commit 241a0a9

File tree

2 files changed

+72
-31
lines changed

2 files changed

+72
-31
lines changed

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -64,35 +64,6 @@ async def get_response(
6464
previous_response_id: Optional[str],
6565
prompt: Optional[ResponsePromptParam],
6666
) -> ModelResponse:
67-
def get_summary(
68-
input: Union[str, list[TResponseInputItem]],
69-
) -> str:
70-
### Activity summary shown in the UI
71-
try:
72-
max_size = 100
73-
if isinstance(input, str):
74-
return input[:max_size]
75-
elif isinstance(input, list):
76-
content: Any = [
77-
item
78-
for item in input
79-
if (item.get("type") or "message") == "message"
80-
][-1]
81-
if isinstance(content, dict):
82-
content = content.get("content", "")
83-
elif hasattr(content, "content"):
84-
content = getattr(content, "content")
85-
86-
if isinstance(content, list):
87-
content = content[-1]
88-
89-
if isinstance(content, dict) and content.get("text") is not None:
90-
content = content.get("text")
91-
return str(content)[:max_size]
92-
except Exception as e:
93-
logger.error(f"Error getting summary: {e}")
94-
return ""
95-
9667
def make_tool_info(tool: Tool) -> ToolInput:
9768
if isinstance(tool, (FileSearchTool, WebSearchTool)):
9869
return tool
@@ -157,7 +128,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
157128
return await workflow.execute_activity_method(
158129
ModelActivity.invoke_model_activity,
159130
activity_input,
160-
summary=self.model_params.summary_override or get_summary(input),
131+
summary=self.model_params.summary_override or _extract_summary(input),
161132
task_queue=self.model_params.task_queue,
162133
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
163134
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
@@ -183,3 +154,34 @@ def stream_response(
183154
prompt: ResponsePromptParam | None,
184155
) -> AsyncIterator[TResponseStreamEvent]:
185156
raise NotImplementedError("Temporal model doesn't support streams yet")
157+
158+
159+
def _extract_summary(input: Union[str, list[TResponseInputItem]]) -> str:
160+
### Activity summary shown in the UI
161+
try:
162+
max_size = 100
163+
if isinstance(input, str):
164+
return input[:max_size]
165+
elif isinstance(input, list):
166+
# Find all message inputs, which are reasonably summarizable
167+
messages: list[TResponseInputItem] = [
168+
item for item in input if (item.get("type") or "message") == "message"
169+
]
170+
if not messages:
171+
return ""
172+
173+
content: Any = messages[-1].get("content", "")
174+
175+
# In the case of multiple contents, take the last one
176+
if isinstance(content, list):
177+
if not content:
178+
return ""
179+
content = content[-1]
180+
181+
# Take the text field from the content if present
182+
if isinstance(content, dict) and content.get("text") is not None:
183+
content = content.get("text")
184+
return str(content)[:max_size]
185+
except Exception as e:
186+
logger.error(f"Error getting summary: {e}")
187+
return ""

tests/contrib/openai_agents/test_openai.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,16 @@
4242
)
4343
from openai import APIStatusError, AsyncOpenAI, BaseModel
4444
from openai.types.responses import (
45+
EasyInputMessageParam,
4546
ResponseFunctionToolCall,
47+
ResponseFunctionToolCallParam,
4648
ResponseFunctionWebSearch,
49+
ResponseInputTextParam,
4750
ResponseOutputMessage,
4851
ResponseOutputText,
4952
)
5053
from openai.types.responses.response_function_web_search import ActionSearch
54+
from openai.types.responses.response_input_item_param import Message
5155
from openai.types.responses.response_prompt_param import ResponsePromptParam
5256
from pydantic import ConfigDict, Field, TypeAdapter
5357

@@ -61,6 +65,7 @@
6165
TestModel,
6266
TestModelProvider,
6367
)
68+
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
6469
from temporalio.contrib.pydantic import pydantic_data_converter
6570
from temporalio.exceptions import ApplicationError, CancelledError
6671
from temporalio.testing import WorkflowEnvironment
@@ -1706,7 +1711,7 @@ class WorkflowToolModel(StaticTestModel):
17061711
id="",
17071712
content=[
17081713
ResponseOutputText(
1709-
text="",
1714+
text="Workflow tool was used",
17101715
annotations=[],
17111716
type="output_text",
17121717
)
@@ -1877,3 +1882,37 @@ async def test_chat_completions_model(client: Client):
18771882
execution_timeout=timedelta(seconds=10),
18781883
)
18791884
await workflow_handle.result()
1885+
1886+
1887+
def test_summary_extraction():
1888+
input: list[TResponseInputItem] = [
1889+
EasyInputMessageParam(
1890+
content="First message",
1891+
role="user",
1892+
)
1893+
]
1894+
1895+
assert _extract_summary(input) == "First message"
1896+
1897+
input.append(
1898+
Message(
1899+
content=[
1900+
ResponseInputTextParam(
1901+
text="Second message",
1902+
type="input_text",
1903+
)
1904+
],
1905+
role="user",
1906+
)
1907+
)
1908+
assert _extract_summary(input) == "Second message"
1909+
1910+
input.append(
1911+
ResponseFunctionToolCallParam(
1912+
arguments="",
1913+
call_id="",
1914+
name="",
1915+
type="function_call",
1916+
)
1917+
)
1918+
assert _extract_summary(input) == "Second message"

0 commit comments

Comments
 (0)