Skip to content

Commit 7eb03ed

Browse files
committed
Merge remote-tracking branch 'origin/main' into openai/block_sqlite_session
2 parents f36fd8f + 79f2900 commit 7eb03ed

File tree

3 files changed

+77
-28
lines changed

3 files changed

+77
-28
lines changed

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
logger = logging.getLogger(__name__)
1010

11-
from typing import Any, AsyncIterator, Sequence, Union, cast
11+
from typing import Any, AsyncIterator, Union, cast
1212

1313
from agents import (
1414
AgentOutputSchema,
@@ -54,7 +54,7 @@ def __init__(
5454
async def get_response(
5555
self,
5656
system_instructions: Optional[str],
57-
input: Union[str, list[TResponseInputItem], dict[str, str]],
57+
input: Union[str, list[TResponseInputItem]],
5858
model_settings: ModelSettings,
5959
tools: list[Tool],
6060
output_schema: Optional[AgentOutputSchemaBase],
@@ -64,28 +64,6 @@ async def get_response(
6464
previous_response_id: Optional[str],
6565
prompt: Optional[ResponsePromptParam],
6666
) -> ModelResponse:
67-
def get_summary(
68-
input: Union[str, list[TResponseInputItem], dict[str, str]],
69-
) -> str:
70-
### Activity summary shown in the UI
71-
try:
72-
max_size = 100
73-
if isinstance(input, str):
74-
return input[:max_size]
75-
elif isinstance(input, list):
76-
seq_input = cast(Sequence[Any], input)
77-
last_item = seq_input[-1]
78-
if isinstance(last_item, dict):
79-
return last_item.get("content", "")[:max_size]
80-
elif hasattr(last_item, "content"):
81-
return str(getattr(last_item, "content"))[:max_size]
82-
return str(last_item)[:max_size]
83-
elif isinstance(input, dict):
84-
return input.get("content", "")[:max_size]
85-
except Exception as e:
86-
logger.error(f"Error getting summary: {e}")
87-
return ""
88-
8967
def make_tool_info(tool: Tool) -> ToolInput:
9068
if isinstance(tool, (FileSearchTool, WebSearchTool)):
9169
return tool
@@ -150,7 +128,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
150128
return await workflow.execute_activity_method(
151129
ModelActivity.invoke_model_activity,
152130
activity_input,
153-
summary=self.model_params.summary_override or get_summary(input),
131+
summary=self.model_params.summary_override or _extract_summary(input),
154132
task_queue=self.model_params.task_queue,
155133
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
156134
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
@@ -176,3 +154,34 @@ def stream_response(
176154
prompt: ResponsePromptParam | None,
177155
) -> AsyncIterator[TResponseStreamEvent]:
178156
raise NotImplementedError("Temporal model doesn't support streams yet")
157+
158+
159+
def _extract_summary(input: Union[str, list[TResponseInputItem]]) -> str:
160+
### Activity summary shown in the UI
161+
try:
162+
max_size = 100
163+
if isinstance(input, str):
164+
return input[:max_size]
165+
elif isinstance(input, list):
166+
# Find all message inputs, which are reasonably summarizable
167+
messages: list[TResponseInputItem] = [
168+
item for item in input if item.get("type", "message") == "message"
169+
]
170+
if not messages:
171+
return ""
172+
173+
content: Any = messages[-1].get("content", "")
174+
175+
# In the case of multiple contents, take the last one
176+
if isinstance(content, list):
177+
if not content:
178+
return ""
179+
content = content[-1]
180+
181+
# Take the text field from the content if present
182+
if isinstance(content, dict) and content.get("text") is not None:
183+
content = content.get("text")
184+
return str(content)[:max_size]
185+
except Exception as e:
186+
logger.error(f"Error getting summary: {e}")
187+
return ""

temporalio/contrib/openai_agents/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
134134
cancellation_type=cancellation_type,
135135
activity_id=activity_id,
136136
versioning_intent=versioning_intent,
137-
summary=summary,
137+
summary=summary or schema.description,
138138
priority=priority,
139139
)
140140
try:

tests/contrib/openai_agents/test_openai.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,16 @@
4545
)
4646
from openai import APIStatusError, AsyncOpenAI, BaseModel
4747
from openai.types.responses import (
48+
EasyInputMessageParam,
4849
ResponseFunctionToolCall,
50+
ResponseFunctionToolCallParam,
4951
ResponseFunctionWebSearch,
52+
ResponseInputTextParam,
5053
ResponseOutputMessage,
5154
ResponseOutputText,
5255
)
5356
from openai.types.responses.response_function_web_search import ActionSearch
57+
from openai.types.responses.response_input_item_param import Message
5458
from openai.types.responses.response_prompt_param import ResponsePromptParam
5559
from pydantic import ConfigDict, Field, TypeAdapter
5660

@@ -64,6 +68,7 @@
6468
TestModel,
6569
TestModelProvider,
6670
)
71+
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
6772
from temporalio.contrib.pydantic import pydantic_data_converter
6873
from temporalio.exceptions import ApplicationError, CancelledError
6974
from temporalio.testing import WorkflowEnvironment
@@ -681,7 +686,8 @@ async def test_research_workflow(client: Client, use_local_model: bool):
681686
new_config["plugins"] = [
682687
openai_agents.OpenAIAgentsPlugin(
683688
model_params=ModelActivityParameters(
684-
start_to_close_timeout=timedelta(seconds=30)
689+
start_to_close_timeout=timedelta(seconds=120),
690+
schedule_to_close_timeout=timedelta(seconds=120),
685691
),
686692
model_provider=TestModelProvider(TestResearchModel())
687693
if use_local_model
@@ -1688,7 +1694,7 @@ class WorkflowToolModel(StaticTestModel):
16881694
id="",
16891695
content=[
16901696
ResponseOutputText(
1691-
text="",
1697+
text="Workflow tool was used",
16921698
annotations=[],
16931699
type="output_text",
16941700
)
@@ -1941,6 +1947,40 @@ async def test_heartbeat(client: Client, env: WorkflowEnvironment):
19411947
await workflow_handle.result()
19421948

19431949

1950+
def test_summary_extraction():
1951+
input: list[TResponseInputItem] = [
1952+
EasyInputMessageParam(
1953+
content="First message",
1954+
role="user",
1955+
)
1956+
]
1957+
1958+
assert _extract_summary(input) == "First message"
1959+
1960+
input.append(
1961+
Message(
1962+
content=[
1963+
ResponseInputTextParam(
1964+
text="Second message",
1965+
type="input_text",
1966+
)
1967+
],
1968+
role="user",
1969+
)
1970+
)
1971+
assert _extract_summary(input) == "Second message"
1972+
1973+
input.append(
1974+
ResponseFunctionToolCallParam(
1975+
arguments="",
1976+
call_id="",
1977+
name="",
1978+
type="function_call",
1979+
)
1980+
)
1981+
assert _extract_summary(input) == "Second message"
1982+
1983+
19441984
@workflow.defn
19451985
class SessionWorkflow:
19461986
@workflow.run

0 commit comments

Comments
 (0)