Skip to content

Commit 37db356

Browse files
committed
Put rebuild back
1 parent 703fc32 commit 37db356

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

temporalio/contrib/openai_agents/__init__.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
This module is experimental and may change in future versions.
88
Use with caution in production environments.
99
"""
10+
11+
import importlib
12+
import inspect
13+
import pkgutil
14+
15+
from pydantic import BaseModel
16+
1017
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
1118
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
1219
from temporalio.contrib.openai_agents._trace_interceptor import (
@@ -30,3 +37,22 @@
3037
"TestModelProvider",
3138
]
3239

40+
41+
def _reload_models(module_name: str) -> None:
42+
"""Recursively walk through modules and rebuild BaseModel classes."""
43+
module = importlib.import_module(module_name)
44+
45+
# Process classes in the current module
46+
for _, obj in inspect.getmembers(module, inspect.isclass):
47+
if issubclass(obj, BaseModel) and obj is not BaseModel:
48+
obj.model_rebuild()
49+
50+
# Recursively process submodules
51+
if hasattr(module, "__path__"):
52+
for _, submodule_name, _ in pkgutil.iter_modules(module.__path__):
53+
full_submodule_name = f"{module_name}.{submodule_name}"
54+
_reload_models(full_submodule_name)
55+
56+
57+
# Recursively call model_rebuild() on all BaseModel classes in openai.types
58+
_reload_models("openai.types")

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import enum
77
import json
8-
from pydantic.dataclasses import dataclass
98
from typing import Any, Optional, Union, cast
109

1110
from agents import (
@@ -24,6 +23,7 @@
2423
WebSearchTool,
2524
)
2625
from agents.models.multi_provider import MultiProvider
26+
from pydantic.dataclasses import dataclass
2727
from typing_extensions import Required, TypedDict
2828

2929
from temporalio import activity

tests/contrib/openai_agents/test_openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1754,4 +1754,4 @@ async def test_response_serialization():
17541754
usage=Usage(),
17551755
response_id="",
17561756
)
1757-
encoded = await pydantic_data_converter.encode([model_response])
1757+
encoded = await pydantic_data_converter.encode([model_response])

0 commit comments

Comments
 (0)