Skip to content

Commit 3549fa7

Browse files
committed
PR Feedback
1 parent 9d3890e commit 3549fa7

File tree

4 files changed

+34
-11
lines changed

4 files changed

+34
-11
lines changed

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,9 @@ class ModelActivity:
155155

156156
def __init__(self, model_provider: Optional[ModelProvider] = None):
157157
"""Initialize the activity with a model provider."""
158-
self._model_provider = model_provider
159-
if model_provider is None and not workflow.in_workflow():
158+
if model_provider:
159+
self._model_provider = model_provider
160+
else:
160161
self._model_provider = OpenAIProvider(
161162
openai_client=AsyncOpenAI(max_retries=0)
162163
)
@@ -165,10 +166,6 @@ def __init__(self, model_provider: Optional[ModelProvider] = None):
165166
@_auto_heartbeater
166167
async def invoke_model_activity(self, input: ActivityModelInput) -> ModelResponse:
167168
"""Activity that invokes a model with the given input."""
168-
if not self._model_provider:
169-
self._model_provider = OpenAIProvider(
170-
openai_client=AsyncOpenAI(max_retries=0)
171-
)
172169
model = self._model_provider.get_model(input.get("model_name"))
173170

174171
async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:

temporalio/plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def __init__(
8686
Applied to the Worker and Replayer.
8787
workflow_failure_exception_types: Exception types for workflow failures to append,
8888
or callable to customize existing ones. Applied to the Worker and Replayer.
89-
run_context: Optional async context manager producer to wrap worker/replayer execution.
90-
Applied to the Worker and Replayer.
89+
run_context: A place to run custom code to wrap around the Worker (or Replayer) execution.
90+
Specifically, it's an async context manager producer. Applied to the Worker and Replayer.
9191
9292
Returns:
9393
A configured Plugin instance.

tests/contrib/openai_agents/test_openai_replay.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from pathlib import Path
22

33
import pytest
4+
from agents import OpenAIProvider
5+
from openai import AsyncOpenAI
46

57
from temporalio.client import WorkflowHistory
68
from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin
@@ -42,5 +44,11 @@ async def test_replay(file_name: str) -> None:
4244
InputGuardrailWorkflow,
4345
OutputGuardrailWorkflow,
4446
],
45-
plugins=[OpenAIAgentsPlugin()],
47+
plugins=[
48+
OpenAIAgentsPlugin(
49+
model_provider=OpenAIProvider(
50+
openai_client=AsyncOpenAI(max_retries=0, api_key="PLACEHOLDER")
51+
)
52+
)
53+
],
4654
).replay_workflow(WorkflowHistory.from_json("fake", history_json))

tests/test_plugins.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ async def test_replay(client: Client) -> None:
275275
await replayer.replay_workflow(await handle.fetch_history())
276276

277277

278-
async def test_static_plugins(client: Client) -> None:
278+
async def test_simple_plugins(client: Client) -> None:
279279
plugin = SimplePlugin(
280280
"MyPlugin",
281281
data_converter=pydantic_data_converter,
@@ -311,7 +311,7 @@ async def test_static_plugins(client: Client) -> None:
311311
assert replayer.config().get("workflows") == [HelloWorkflow, HelloWorkflow2]
312312

313313

314-
async def test_static_plugins_callables(client: Client) -> None:
314+
async def test_simple_plugins_callables(client: Client) -> None:
315315
def converter(old: Optional[DataConverter]):
316316
if old != temporalio.converter.default():
317317
raise ValueError("Can't override non-default converter")
@@ -344,3 +344,21 @@ def converter(old: Optional[DataConverter]):
344344
plugins=[plugin],
345345
)
346346
assert worker.config().get("workflows") == []
347+
348+
349+
class MediumPlugin(SimplePlugin):
350+
def __init__(self):
351+
super().__init__("MediumPlugin", data_converter=pydantic_data_converter)
352+
353+
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
354+
config = super().configure_worker(config)
355+
config["task_queue"] = "override"
356+
return config
357+
358+
359+
async def test_medium_plugin(client: Client) -> None:
360+
plugin = MediumPlugin()
361+
worker = Worker(
362+
client, task_queue="queue", plugins=[plugin], workflows=[HelloWorkflow]
363+
)
364+
assert worker.config().get("task_queue") == "override"

0 commit comments

Comments
 (0)