Skip to content

Commit 84565a7

Browse files
committed
Recursively rebuild models in openai.types to ensure they are available for encoding
1 parent 808a5f4 commit 84565a7

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

temporalio/contrib/openai_agents/__init__.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
Use with caution in production environments.
99
"""
1010

11+
import importlib
12+
import inspect
13+
import pkgutil
14+
15+
from pydantic import BaseModel
16+
1117
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
1218
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
1319
from temporalio.contrib.openai_agents._trace_interceptor import (
@@ -30,3 +36,23 @@
3036
"TestModel",
3137
"TestModelProvider",
3238
]
39+
40+
41+
def _reload_models(module_name: str) -> None:
42+
"""Recursively walk through modules and rebuild BaseModel classes."""
43+
module = importlib.import_module(module_name)
44+
45+
# Process classes in the current module
46+
for _, obj in inspect.getmembers(module, inspect.isclass):
47+
if issubclass(obj, BaseModel) and obj is not BaseModel:
48+
obj.model_rebuild()
49+
50+
# Recursively process submodules
51+
if hasattr(module, "__path__"):
52+
for _, submodule_name, _ in pkgutil.iter_modules(module.__path__):
53+
full_submodule_name = f"{module_name}.{submodule_name}"
54+
_reload_models(full_submodule_name)
55+
56+
57+
# Recursively call model_rebuild() on all BaseModel classes in openai.types
58+
_reload_models("openai.types")

tests/contrib/openai_agents/test_openai.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
)
4646
from openai.types.responses.response_function_web_search import ActionSearch
4747
from openai.types.responses.response_prompt_param import ResponsePromptParam
48-
from pydantic import ConfigDict, Field
48+
from pydantic import ConfigDict, Field, TypeAdapter
4949

5050
from temporalio import activity, workflow
5151
from temporalio.client import Client, WorkflowFailureError, WorkflowHandle
@@ -1736,3 +1736,15 @@ async def test_workflow_method_tools(client: Client):
17361736
execution_timeout=timedelta(seconds=10),
17371737
)
17381738
await workflow_handle.result()
1739+
1740+
1741+
async def test_response_serialization():
1742+
import json
1743+
1744+
from openai.types.responses.response_output_item import ImageGenerationCall
1745+
1746+
data = json.loads(
1747+
b'{"id": "msg_68757ec43348819d86709f0fcb70316301a1194a3e05b38c","type": "image_generation_call","status": "completed"}'
1748+
)
1749+
call = TypeAdapter(ImageGenerationCall).validate_python(data)
1750+
encoded = await pydantic_data_converter.encode([call])

0 commit comments

Comments
 (0)