|
| 1 | +import asyncio |
1 | 2 | import json |
2 | 3 | import os |
3 | 4 | import uuid |
4 | 5 | from dataclasses import dataclass |
5 | 6 | from datetime import timedelta |
6 | | -from typing import Any, Optional, Union, no_type_check |
| 7 | +from typing import Any, AsyncIterator, Optional, Union, no_type_check |
7 | 8 |
|
8 | 9 | import nexusrpc |
9 | 10 | import pytest |
|
39 | 40 | HandoffOutputItem, |
40 | 41 | ToolCallItem, |
41 | 42 | ToolCallOutputItem, |
| 43 | + TResponseStreamEvent, |
42 | 44 | ) |
43 | 45 | from openai import APIStatusError, AsyncOpenAI, BaseModel |
44 | 46 | from openai.types.responses import ( |
@@ -1884,6 +1886,86 @@ async def test_chat_completions_model(client: Client): |
1884 | 1886 | await workflow_handle.result() |
1885 | 1887 |
|
1886 | 1888 |
|
| 1889 | +class WaitModel(Model): |
| 1890 | + async def get_response( |
| 1891 | + self, |
| 1892 | + system_instructions: Union[str, None], |
| 1893 | + input: Union[str, list[TResponseInputItem]], |
| 1894 | + model_settings: ModelSettings, |
| 1895 | + tools: list[Tool], |
| 1896 | + output_schema: Union[AgentOutputSchemaBase, None], |
| 1897 | + handoffs: list[Handoff], |
| 1898 | + tracing: ModelTracing, |
| 1899 | + *, |
| 1900 | + previous_response_id: Union[str, None], |
| 1901 | + prompt: Union[ResponsePromptParam, None] = None, |
| 1902 | + ) -> ModelResponse: |
| 1903 | + activity.logger.info("Waiting") |
| 1904 | + await asyncio.sleep(1.0) |
| 1905 | + activity.logger.info("Returning") |
| 1906 | + return ModelResponse( |
| 1907 | + output=[ |
| 1908 | + ResponseOutputMessage( |
| 1909 | + id="", |
| 1910 | + content=[ |
| 1911 | + ResponseOutputText( |
| 1912 | + text="test", annotations=[], type="output_text" |
| 1913 | + ) |
| 1914 | + ], |
| 1915 | + role="assistant", |
| 1916 | + status="completed", |
| 1917 | + type="message", |
| 1918 | + ) |
| 1919 | + ], |
| 1920 | + usage=Usage(), |
| 1921 | + response_id=None, |
| 1922 | + ) |
| 1923 | + |
| 1924 | + def stream_response( |
| 1925 | + self, |
| 1926 | + system_instructions: Optional[str], |
| 1927 | + input: Union[str, list[TResponseInputItem]], |
| 1928 | + model_settings: ModelSettings, |
| 1929 | + tools: list[Tool], |
| 1930 | + output_schema: Optional[AgentOutputSchemaBase], |
| 1931 | + handoffs: list[Handoff], |
| 1932 | + tracing: ModelTracing, |
| 1933 | + *, |
| 1934 | + previous_response_id: Optional[str], |
| 1935 | + prompt: Optional[ResponsePromptParam], |
| 1936 | + ) -> AsyncIterator[TResponseStreamEvent]: |
| 1937 | + raise NotImplementedError() |
| 1938 | + |
| 1939 | + |
| 1940 | +async def test_heartbeat(client: Client, env: WorkflowEnvironment): |
| 1941 | + if env.supports_time_skipping: |
| 1942 | + pytest.skip("Relies on real timing, skip.") |
| 1943 | + |
| 1944 | + new_config = client.config() |
| 1945 | + new_config["plugins"] = [ |
| 1946 | + openai_agents.OpenAIAgentsPlugin( |
| 1947 | + model_params=ModelActivityParameters( |
| 1948 | + heartbeat_timeout=timedelta(seconds=0.5), |
| 1949 | + ), |
| 1950 | + model_provider=TestModelProvider(WaitModel()), |
| 1951 | + ) |
| 1952 | + ] |
| 1953 | + client = Client(**new_config) |
| 1954 | + |
| 1955 | + async with new_worker( |
| 1956 | + client, |
| 1957 | + HelloWorldAgent, |
| 1958 | + ) as worker: |
| 1959 | + workflow_handle = await client.start_workflow( |
| 1960 | + HelloWorldAgent.run, |
| 1961 | + "Tell me about recursion in programming.", |
| 1962 | + id=f"workflow-tool-{uuid.uuid4()}", |
| 1963 | + task_queue=worker.task_queue, |
| 1964 | + execution_timeout=timedelta(seconds=5.0), |
| 1965 | + ) |
| 1966 | + await workflow_handle.result() |
| 1967 | + |
| 1968 | + |
1887 | 1969 | def test_summary_extraction(): |
1888 | 1970 | input: list[TResponseInputItem] = [ |
1889 | 1971 | EasyInputMessageParam( |
|
0 commit comments