Skip to content

Commit 598e862

Browse files
committed
Refine decorator to preserve sync/async semantics
1 parent 2235325 commit 598e862

File tree

5 files changed

+201
-78
lines changed

5 files changed

+201
-78
lines changed

src/datastar_py/django.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
from collections.abc import Awaitable, Mapping
4-
from functools import wraps
5-
from inspect import isawaitable
4+
from functools import partial, wraps
5+
from inspect import isasyncgenfunction, iscoroutinefunction
66
from typing import Any, Callable, ParamSpec
77

88
from django.http import HttpRequest
@@ -46,30 +46,44 @@ def __init__(
4646

4747
def datastar_response(
4848
func: Callable[P, Awaitable[DatastarEvents] | DatastarEvents],
49-
) -> Callable[P, Awaitable[DatastarResponse]]:
49+
) -> Callable[P, Awaitable[DatastarResponse] | DatastarResponse]:
5050
"""A decorator which wraps a function result in DatastarResponse.
5151
5252
Can be used on a sync or async function or generator function.
53+
Preserves the sync/async nature of the decorated function.
5354
"""
54-
55-
@wraps(func)
56-
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
57-
r = func(*args, **kwargs)
58-
59-
if hasattr(r, "__aiter__"):
60-
raise NotImplementedError(
61-
"Async generators/iterables are not yet supported by the Django adapter; "
62-
"use a sync generator or return a single value/awaitable instead."
63-
)
64-
65-
if hasattr(r, "__iter__") and not isinstance(r, (str, bytes)):
66-
return DatastarResponse(r)
67-
68-
if isawaitable(r):
69-
return DatastarResponse(await r)
70-
return DatastarResponse(r)
71-
72-
return wrapper
55+
# Unwrap partials to inspect the actual underlying function
56+
actual_func = func
57+
while isinstance(actual_func, partial):
58+
actual_func = actual_func.func
59+
60+
# Async generators not supported by Django
61+
if isasyncgenfunction(actual_func):
62+
raise NotImplementedError(
63+
"Async generators are not yet supported by the Django adapter; "
64+
"use a sync generator or return a single value/awaitable instead."
65+
)
66+
67+
# Coroutine (async def + return)
68+
if iscoroutinefunction(actual_func):
69+
70+
@wraps(actual_func)
71+
async def async_coro_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
72+
result = await func(*args, **kwargs)
73+
return DatastarResponse(result)
74+
75+
async_coro_wrapper.__annotations__["return"] = DatastarResponse
76+
return async_coro_wrapper
77+
78+
# Sync Function (def) - includes sync generators
79+
else:
80+
81+
@wraps(actual_func)
82+
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
83+
return DatastarResponse(func(*args, **kwargs))
84+
85+
sync_wrapper.__annotations__["return"] = DatastarResponse
86+
return sync_wrapper
7387

7488

7589
def read_signals(request: HttpRequest) -> dict[str, Any] | None:

src/datastar_py/litestar.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
from collections.abc import Awaitable, Mapping
4-
from functools import wraps
5-
from inspect import isawaitable
4+
from functools import partial, wraps
5+
from inspect import isasyncgenfunction, iscoroutinefunction
66
from typing import (
77
TYPE_CHECKING,
88
Any,
@@ -65,32 +65,47 @@ def __init__(
6565

6666
def datastar_response(
6767
func: Callable[P, Awaitable[DatastarEvents] | DatastarEvents],
68-
) -> Callable[P, DatastarResponse]:
68+
) -> Callable[P, Awaitable[DatastarResponse] | DatastarResponse]:
6969
"""A decorator which wraps a function result in DatastarResponse.
7070
7171
Can be used on a sync or async function or generator function.
72+
Preserves the sync/async nature of the decorated function.
7273
"""
74+
# Unwrap partials to inspect the actual underlying function
75+
actual_func = func
76+
while isinstance(actual_func, partial):
77+
actual_func = actual_func.func
7378

74-
@wraps(func)
75-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
76-
r = func(*args, **kwargs)
79+
# Case A: Async Generator (async def + yield)
80+
if isasyncgenfunction(actual_func):
7781

78-
if hasattr(r, "__aiter__"):
79-
return DatastarResponse(r)
82+
@wraps(actual_func)
83+
async def async_gen_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
84+
return DatastarResponse(func(*args, **kwargs))
8085

81-
if hasattr(r, "__iter__") and not isinstance(r, (str, bytes)):
82-
return DatastarResponse(r)
86+
async_gen_wrapper.__annotations__["return"] = DatastarResponse
87+
return async_gen_wrapper
8388

84-
if isawaitable(r):
85-
async def await_and_yield():
86-
yield await r
89+
# Case B: Standard Coroutine (async def + return)
90+
elif iscoroutinefunction(actual_func):
8791

88-
return DatastarResponse(await_and_yield())
92+
@wraps(actual_func)
93+
async def async_coro_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
94+
result = await func(*args, **kwargs)
95+
return DatastarResponse(result)
8996

90-
return DatastarResponse(r)
97+
async_coro_wrapper.__annotations__["return"] = DatastarResponse
98+
return async_coro_wrapper
9199

92-
wrapper.__annotations__["return"] = DatastarResponse
93-
return wrapper
100+
# Case C: Sync Function (def) - includes sync generators
101+
else:
102+
103+
@wraps(actual_func)
104+
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
105+
return DatastarResponse(func(*args, **kwargs))
106+
107+
sync_wrapper.__annotations__["return"] = DatastarResponse
108+
return sync_wrapper
94109

95110

96111
async def read_signals(request: Request) -> dict[str, Any] | None:

src/datastar_py/quart.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from __future__ import annotations
22

33
from collections.abc import Awaitable, Mapping
4-
from functools import wraps
5-
from inspect import isasyncgen, isasyncgenfunction, isgenerator
4+
from functools import partial, wraps
5+
from inspect import isasyncgen, isasyncgenfunction, iscoroutinefunction, isgenerator
66
from typing import Any, Callable, ParamSpec
77

8-
from quart import Response, copy_current_request_context, request, stream_with_context
8+
from quart import Response, request, stream_with_context
99

1010
from . import _read_signals
1111
from .sse import SSE_HEADERS, DatastarEvents, ServerSentEventGenerator
@@ -43,20 +43,47 @@ def __init__(
4343

4444
def datastar_response(
4545
func: Callable[P, Awaitable[DatastarEvents] | DatastarEvents],
46-
) -> Callable[P, Awaitable[DatastarResponse]]:
46+
) -> Callable[P, Awaitable[DatastarResponse] | DatastarResponse]:
4747
"""A decorator which wraps a function result in DatastarResponse.
4848
4949
Can be used on a sync or async function or generator function.
50+
Preserves the sync/async nature of the decorated function.
5051
"""
52+
# Unwrap partials to inspect the actual underlying function
53+
actual_func = func
54+
while isinstance(actual_func, partial):
55+
actual_func = actual_func.func
5156

52-
@wraps(func)
53-
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
54-
if isasyncgenfunction(func):
57+
# Case A: Async Generator (async def + yield)
58+
if isasyncgenfunction(actual_func):
59+
60+
@wraps(actual_func)
61+
async def async_gen_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
5562
return DatastarResponse(stream_with_context(func)(*args, **kwargs))
56-
return DatastarResponse(await copy_current_request_context(func)(*args, **kwargs))
5763

58-
wrapper.__annotations__["return"] = DatastarResponse
59-
return wrapper
64+
async_gen_wrapper.__annotations__["return"] = DatastarResponse
65+
return async_gen_wrapper
66+
67+
# Case B: Standard Coroutine (async def + return)
68+
elif iscoroutinefunction(actual_func):
69+
70+
@wraps(actual_func)
71+
async def async_coro_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
72+
result = await func(*args, **kwargs)
73+
return DatastarResponse(result)
74+
75+
async_coro_wrapper.__annotations__["return"] = DatastarResponse
76+
return async_coro_wrapper
77+
78+
# Case C: Sync Function (def) - includes sync generators
79+
else:
80+
81+
@wraps(actual_func)
82+
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
83+
return DatastarResponse(func(*args, **kwargs))
84+
85+
sync_wrapper.__annotations__["return"] = DatastarResponse
86+
return sync_wrapper
6087

6188

6289
async def read_signals() -> dict[str, Any] | None:

src/datastar_py/starlette.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
from collections.abc import Awaitable, Mapping
4-
from functools import wraps
5-
from inspect import isawaitable
4+
from functools import partial, wraps
5+
from inspect import isasyncgenfunction, iscoroutinefunction
66
from typing import (
77
TYPE_CHECKING,
88
Any,
@@ -55,37 +55,47 @@ def __init__(
5555

5656
def datastar_response(
5757
func: Callable[P, Awaitable[DatastarEvents] | DatastarEvents],
58-
) -> Callable[P, DatastarResponse]:
58+
) -> Callable[P, Awaitable[DatastarResponse] | DatastarResponse]:
5959
"""A decorator which wraps a function result in DatastarResponse.
6060
6161
Can be used on a sync or async function or generator function.
62+
Preserves the sync/async nature of the decorated function.
6263
"""
64+
# Unwrap partials to inspect the actual underlying function
65+
actual_func = func
66+
while isinstance(actual_func, partial):
67+
actual_func = actual_func.func
6368

64-
@wraps(func)
65-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
66-
r = func(*args, **kwargs)
69+
# Case A: Async Generator (async def + yield)
70+
if isasyncgenfunction(actual_func):
6771

68-
# Check for async generator/iterator first (most specific case)
69-
if hasattr(r, "__aiter__"):
70-
return DatastarResponse(r)
72+
@wraps(actual_func)
73+
async def async_gen_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
74+
return DatastarResponse(func(*args, **kwargs))
7175

72-
# Check for sync generator/iterator (before Awaitable to avoid false positives)
73-
if hasattr(r, "__iter__") and not isinstance(r, (str, bytes)):
74-
return DatastarResponse(r)
76+
async_gen_wrapper.__annotations__["return"] = DatastarResponse
77+
return async_gen_wrapper
7578

76-
# Check for coroutines/tasks (but NOT async generators, already handled above)
77-
if isawaitable(r):
78-
# Wrap awaitable in an async generator that yields the result
79-
async def await_and_yield():
80-
yield await r
79+
# Case B: Standard Coroutine (async def + return)
80+
elif iscoroutinefunction(actual_func):
8181

82-
return DatastarResponse(await_and_yield())
82+
@wraps(actual_func)
83+
async def async_coro_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
84+
result = await func(*args, **kwargs)
85+
return DatastarResponse(result)
8386

84-
# Default case: single value or unknown type
85-
return DatastarResponse(r)
87+
async_coro_wrapper.__annotations__["return"] = DatastarResponse
88+
return async_coro_wrapper
8689

87-
wrapper.__annotations__["return"] = DatastarResponse
88-
return wrapper
90+
# Case C: Sync Function (def) - includes sync generators
91+
else:
92+
93+
@wraps(actual_func)
94+
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse:
95+
return DatastarResponse(func(*args, **kwargs))
96+
97+
sync_wrapper.__annotations__["return"] = DatastarResponse
98+
return sync_wrapper
8999

90100

91101
async def read_signals(request: Request) -> dict[str, Any] | None:

tests/test_datastar_decorator_runtime.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def anyio_backend() -> str:
3535
"async_generator",
3636
],
3737
)
38-
def test_decorator_returns_response_objects(module_path: str, variant: str) -> None:
39-
"""Decorated handlers should stay sync-callable and return DatastarResponse immediately."""
38+
def test_decorator_preserves_sync_async_semantics(module_path: str, variant: str) -> None:
39+
"""Decorated handlers should preserve sync/async nature of the original function."""
4040

4141
mod = importlib.import_module(module_path)
4242
datastar_response = mod.datastar_response
@@ -59,12 +59,18 @@ async def handler() -> Any:
5959
async def handler() -> Any:
6060
yield SSE.patch_signals({"ok": True})
6161

62-
result = handler()
63-
if inspect.iscoroutine(result):
64-
result.close() # avoid "coroutine was never awaited" warnings
62+
is_async_variant = variant.startswith("async_")
6563

66-
assert not inspect.iscoroutinefunction(handler), "Decorator should preserve sync callable semantics"
67-
assert isinstance(result, DatastarResponse)
64+
# Verify the wrapper preserves sync/async nature
65+
if is_async_variant:
66+
assert inspect.iscoroutinefunction(handler), "Async handlers should remain async"
67+
# Call and close coroutine to avoid warnings (we can't await in sync test)
68+
coro = handler()
69+
coro.close()
70+
else:
71+
assert not inspect.iscoroutinefunction(handler), "Sync handlers should remain sync"
72+
result = handler()
73+
assert isinstance(result, DatastarResponse), "Sync handlers should return DatastarResponse directly"
6874

6975

7076
async def _fetch(
@@ -125,3 +131,54 @@ async def ping(request) -> PlainTextResponse: # noqa: ANN001
125131
finally:
126132
server.should_exit = True
127133
thread.join(timeout=2)
134+
135+
136+
def test_async_generator_iterates_on_event_loop() -> None:
137+
"""Async generators should iterate on the event loop, not spawn a thread.
138+
139+
This addresses the concern that a sync wrapper might cause async handlers
140+
to run in the threadpool. The wrapper being sync only affects where the
141+
generator object is created (trivial); iteration happens based on iterator
142+
type - Starlette's StreamingResponse detects __aiter__ and iterates async.
143+
144+
This test uses Starlette, but the same principle applies to Litestar which
145+
also uses a sync wrapper. Litestar's Stream response similarly detects
146+
async iterators and iterates them on the event loop.
147+
"""
148+
from starlette.testclient import TestClient
149+
150+
from datastar_py.starlette import datastar_response
151+
152+
execution_threads: dict[str, str] = {}
153+
154+
@datastar_response
155+
async def async_gen_handler(request) -> Any: # noqa: ANN001
156+
execution_threads["async_gen"] = threading.current_thread().name
157+
yield SSE.patch_signals({"async": True})
158+
159+
@datastar_response
160+
def sync_gen_handler(request) -> Any: # noqa: ANN001
161+
execution_threads["sync_gen"] = threading.current_thread().name
162+
yield SSE.patch_signals({"sync": True})
163+
164+
app = Starlette(routes=[
165+
Route("/async", async_gen_handler),
166+
Route("/sync", sync_gen_handler),
167+
])
168+
169+
with TestClient(app) as client:
170+
client.get("/async")
171+
client.get("/sync")
172+
173+
# Async generator runs on the asyncio portal thread (event loop context)
174+
# Sync generator runs in a separate threadpool worker
175+
# The key assertion: they run in DIFFERENT thread contexts
176+
assert execution_threads["async_gen"] != execution_threads["sync_gen"], (
177+
f"Async and sync generators should run in different thread contexts. "
178+
f"Async ran on: {execution_threads['async_gen']}, Sync ran on: {execution_threads['sync_gen']}"
179+
)
180+
181+
# Async generator should be on the event loop thread (asyncio-portal-* or MainThread)
182+
assert "asyncio" in execution_threads["async_gen"] or execution_threads["async_gen"] == "MainThread", (
183+
f"Async generator should run on event loop, but ran on {execution_threads['async_gen']}"
184+
)

0 commit comments

Comments
 (0)