Skip to content

Commit 6060b35

Browse files
committed
use 3.11 create_task when possible to avoid an extra context copy
1 parent 6ca781c commit 6060b35

File tree

3 files changed

+37
-11
lines changed

3 files changed

+37
-11
lines changed

temporalio/contrib/openai_agents/_trace_interceptor.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
import contextvars
77
import random
8+
import sys
89
import uuid
910
from contextlib import contextmanager
1011
from typing import Any, Mapping, Optional, Protocol, Type
@@ -424,9 +425,14 @@ async def start_child_workflow(
424425
data={"workflow": input.workflow},
425426
input=input,
426427
)
427-
handle = await ctx.run(
428-
asyncio.create_task, self.next.start_child_workflow(input)
429-
)
428+
if sys.version_info >= (3, 11):
429+
handle = await asyncio.create_task(
430+
self.next.start_child_workflow(input), context=ctx
431+
)
432+
else:
433+
handle = await ctx.run(
434+
asyncio.create_task, self.next.start_child_workflow(input)
435+
)
430436
if span:
431437
handle.add_done_callback(lambda _: ctx.run(span.finish)) # type: ignore
432438
return handle

temporalio/contrib/opentelemetry.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
contextmanager,
1010
)
1111
from dataclasses import dataclass
12+
import sys
1213
from types import TracebackType
1314
from typing import (
1415
Any,
@@ -394,7 +395,12 @@ async def execute_workflow(
394395
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.execute_workflow`.
395396
"""
396397
with self._top_level_workflow_context(success_is_complete=True) as ctx:
397-
return await ctx.run(asyncio.create_task, self._execute_workflow(input))
398+
if sys.version_info >= (3, 11):
399+
return await asyncio.create_task(
400+
self._execute_workflow(input), context=ctx
401+
)
402+
else:
403+
return await ctx.run(asyncio.create_task, self._execute_workflow(input))
398404

399405
async def _execute_workflow(
400406
self, input: temporalio.worker.ExecuteWorkflowInput
@@ -414,7 +420,12 @@ async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> Non
414420
# Create a span in the current context for the signal and link any
415421
# header given
416422
with self._top_level_workflow_context(success_is_complete=False) as ctx:
417-
return await ctx.run(asyncio.create_task, self._handle_signal(input))
423+
if sys.version_info >= (3, 11):
424+
return await asyncio.create_task(
425+
self._handle_signal(input), context=ctx
426+
)
427+
else:
428+
return await ctx.run(asyncio.create_task, self._handle_signal(input))
418429

419430
async def _handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None:
420431
"""Implementation of
@@ -511,9 +522,14 @@ async def handle_update_handler(
511522
:py:meth:`temporalio.worker.WorkflowInboundInterceptor.handle_update_handler`.
512523
"""
513524
with self._top_level_workflow_context(success_is_complete=False) as ctx:
514-
return await ctx.run(
515-
asyncio.create_task, self._handle_update_handler(input)
516-
)
525+
if sys.version_info >= (3, 11):
526+
return await asyncio.create_task(
527+
self._handle_update_handler(input), context=ctx
528+
)
529+
else:
530+
return await ctx.run(
531+
asyncio.create_task, self._handle_update_handler(input)
532+
)
517533

518534
async def _handle_update_handler(
519535
self, input: temporalio.worker.HandleUpdateInput

tests/nexus/test_handler.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import uuid
2121
from collections.abc import Mapping
2222
from concurrent.futures.thread import ThreadPoolExecutor
23-
from dataclasses import dataclass
23+
from dataclasses import dataclass, field
2424
from types import MappingProxyType
2525
from typing import Any, Callable, Optional, Union
2626

@@ -313,7 +313,9 @@ async def non_serializable_output(
313313
class SuccessfulResponse:
314314
status_code: int
315315
body_json: Optional[Union[dict[str, Any], Callable[[dict[str, Any]], bool]]] = None
316-
headers: Mapping[str, str] = SUCCESSFUL_RESPONSE_HEADERS
316+
headers: Mapping[str, str] = field(
317+
default_factory=lambda: SUCCESSFUL_RESPONSE_HEADERS
318+
)
317319

318320

319321
@dataclass
@@ -325,7 +327,9 @@ class UnsuccessfulResponse:
325327
# Expected value of inverse of non_retryable attribute of exception.
326328
retryable_exception: bool = True
327329
body_json: Optional[Callable[[dict[str, Any]], bool]] = None
328-
headers: Mapping[str, str] = UNSUCCESSFUL_RESPONSE_HEADERS
330+
headers: Mapping[str, str] = field(
331+
default_factory=lambda: UNSUCCESSFUL_RESPONSE_HEADERS
332+
)
329333

330334

331335
class _TestCase:

0 commit comments

Comments
 (0)