Skip to content

Commit 673c671

Browse files
committed
Add context manager support
1 parent eb8e2fd commit 673c671

File tree

8 files changed

+140
-4
lines changed

8 files changed

+140
-4
lines changed

python/restate/context_managers.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#
2+
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
3+
#
4+
# This file is part of the Restate SDK for Python,
5+
# which is released under the MIT license.
6+
#
7+
# You can find a copy of the license in file LICENSE in the root
8+
# directory of this repository or package, or at
9+
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10+
#
11+
"""
12+
contextvar utility for async context managers.
13+
"""
14+
15+
import contextvars
16+
from contextlib import asynccontextmanager
17+
from typing import (
18+
Any,
19+
AsyncContextManager,
20+
AsyncGenerator,
21+
Callable,
22+
Generic,
23+
ParamSpec,
24+
TypeVar,
25+
)
26+
27+
P = ParamSpec("P")
28+
T = TypeVar("T")
29+
30+
31+
class contextvar(Generic[P, T]):
32+
"""
33+
A type-safe decorator for asynccontextmanager functions that captures the yielded value in a ContextVar.
34+
This is useful when integrating with frameworks that only support None yielded values from context managers.
35+
36+
Example usage:
37+
```python
38+
@contextvar
39+
@asynccontextmanager
40+
async def my_resource() -> AsyncIterator[str]:
41+
yield "hi"
42+
43+
async def usage_example():
44+
async with my_resource():
45+
print(my_resource.value) # prints "hi"
46+
```
47+
48+
49+
"""
50+
51+
def __init__(self, func: Callable[P, AsyncContextManager[T]]):
52+
self.func = func
53+
self._value_var: contextvars.ContextVar[T | None] = contextvars.ContextVar("value")
54+
55+
@property
56+
def value(self) -> T:
57+
"""Return the value yielded by the wrapped context manager."""
58+
val = self._value_var.get()
59+
if val is None:
60+
raise LookupError("No value set in contextvar")
61+
return val
62+
63+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> AsyncContextManager[None]:
64+
@asynccontextmanager
65+
async def wrapper() -> AsyncGenerator[None, Any]:
66+
async with self.func(*args, **kwargs) as value:
67+
token = self._value_var.set(value)
68+
try:
69+
yield # we make it yield None, as the value is accessible via .value()
70+
finally:
71+
self._value_var.reset(token)
72+
73+
return wrapper()

python/restate/extensions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
"""This module contains internal extensions apis"""
1212

1313
from .server_context import current_context
14+
from .context_managers import contextvar
1415

15-
__all__ = ["current_context"]
16+
__all__ = ["current_context", "contextvar"]

python/restate/handler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from dataclasses import dataclass
1919
from datetime import timedelta
2020
from inspect import Signature
21-
from typing import Any, Callable, Awaitable, Dict, Generic, Literal, Optional, TypeVar
21+
from typing import Any, AsyncContextManager, Callable, Awaitable, Dict, Generic, List, Literal, Optional, TypeVar
2222

2323
from restate.retry_policy import InvocationRetryPolicy
2424

@@ -150,6 +150,7 @@ class Handler(Generic[I, O]):
150150
enable_lazy_state: Optional[bool] = None
151151
ingress_private: Optional[bool] = None
152152
invocation_retry_policy: Optional[InvocationRetryPolicy] = None
153+
context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None
153154

154155

155156
# disable too many arguments warning
@@ -172,6 +173,7 @@ def make_handler(
172173
enable_lazy_state: Optional[bool] = None,
173174
ingress_private: Optional[bool] = None,
174175
invocation_retry_policy: Optional[InvocationRetryPolicy] = None,
176+
context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None,
175177
) -> Handler[I, O]:
176178
"""
177179
Factory function to create a handler.
@@ -225,6 +227,7 @@ def make_handler(
225227
enable_lazy_state=enable_lazy_state,
226228
ingress_private=ingress_private,
227229
invocation_retry_policy=invocation_retry_policy,
230+
context_managers=context_managers,
228231
)
229232

230233
vars(wrapped)[RESTATE_UNIQUE_HANDLER_SYMBOL] = handler

python/restate/object.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
enable_lazy_state: typing.Optional[bool] = None,
8888
ingress_private: typing.Optional[bool] = None,
8989
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
90+
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
9091
):
9192
self.service_tag = ServiceTag("object", name, description, metadata)
9293
self.handlers = {}
@@ -97,6 +98,7 @@ def __init__(
9798
self.enable_lazy_state = enable_lazy_state
9899
self.ingress_private = ingress_private
99100
self.invocation_retry_policy = invocation_retry_policy
101+
self.context_managers = context_managers
100102

101103
@property
102104
def name(self):
@@ -122,6 +124,7 @@ def handler(
122124
enable_lazy_state: typing.Optional[bool] = None,
123125
ingress_private: typing.Optional[bool] = None,
124126
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
127+
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
125128
) -> typing.Callable[[T], T]:
126129
"""
127130
Decorator for defining a handler function.
@@ -184,6 +187,11 @@ def wrapped(*args, **kwargs):
184187
return fn(*args, **kwargs)
185188

186189
signature = inspect.signature(fn, eval_str=True)
190+
combined_context_managers = (
191+
(self.context_managers or []) + (context_managers or [])
192+
if self.context_managers or context_managers
193+
else None
194+
)
187195
handler = make_handler(
188196
self.service_tag,
189197
handler_io,
@@ -201,6 +209,7 @@ def wrapped(*args, **kwargs):
201209
enable_lazy_state,
202210
ingress_private,
203211
invocation_retry_policy,
212+
combined_context_managers,
204213
)
205214
self.handlers[handler.name] = handler
206215
return wrapped

python/restate/server_context.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""This module contains the restate context implementation based on the server"""
1818

1919
import asyncio
20+
from contextlib import AsyncExitStack
2021
import contextvars
2122
import copy
2223
from random import Random
@@ -342,7 +343,10 @@ async def enter(self):
342343
token = _restate_context_var.set(self)
343344
try:
344345
in_buffer = self.invocation.input_buffer
345-
out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer)
346+
async with AsyncExitStack() as stack:
347+
for manager in self.handler.context_managers or []:
348+
await stack.enter_async_context(manager())
349+
out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer)
346350
restate_context_is_replaying.set(False)
347351
self.vm.sys_write_output_success(bytes(out_buffer))
348352
self.vm.sys_end()

python/restate/service.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
idempotency_retention: typing.Optional[timedelta] = None,
8282
ingress_private: typing.Optional[bool] = None,
8383
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
84+
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
8485
) -> None:
8586
self.service_tag = ServiceTag("service", name, description, metadata)
8687
self.handlers: typing.Dict[str, Handler] = {}
@@ -90,6 +91,7 @@ def __init__(
9091
self.idempotency_retention = idempotency_retention
9192
self.ingress_private = ingress_private
9293
self.invocation_retry_policy = invocation_retry_policy
94+
self.context_managers = context_managers
9395

9496
@property
9597
def name(self):
@@ -112,6 +114,7 @@ def handler(
112114
idempotency_retention: typing.Optional[timedelta] = None,
113115
ingress_private: typing.Optional[bool] = None,
114116
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
117+
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
115118
) -> typing.Callable[[T], T]:
116119
"""
117120
Decorator for defining a handler function.
@@ -170,6 +173,14 @@ def wrapped(*args, **kwargs):
170173
return fn(*args, **kwargs)
171174

172175
signature = inspect.signature(fn, eval_str=True)
176+
177+
# combine context managers or leave None if both are None
178+
combined_context_managers = (
179+
(self.context_managers or []) + (context_managers or [])
180+
if self.context_managers or context_managers
181+
else None
182+
)
183+
173184
handler = make_handler(
174185
self.service_tag,
175186
handler_io,
@@ -187,6 +198,7 @@ def wrapped(*args, **kwargs):
187198
None,
188199
ingress_private,
189200
invocation_retry_policy,
201+
combined_context_managers,
190202
)
191203
self.handlers[handler.name] = handler
192204
return wrapped

python/restate/workflow.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(
9292
enable_lazy_state: typing.Optional[bool] = None,
9393
ingress_private: typing.Optional[bool] = None,
9494
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
95+
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
9596
):
9697
self.service_tag = ServiceTag("workflow", name, description, metadata)
9798
self.handlers = {}
@@ -102,6 +103,7 @@ def __init__(
102103
self.enable_lazy_state = enable_lazy_state
103104
self.ingress_private = ingress_private
104105
self.invocation_retry_policy = invocation_retry_policy
106+
self.context_managers = context_managers
105107

106108
@property
107109
def name(self):
@@ -125,6 +127,7 @@ def main(
125127
enable_lazy_state: typing.Optional[bool] = None,
126128
ingress_private: typing.Optional[bool] = None,
127129
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
130+
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
128131
) -> typing.Callable[[T], T]:
129132
"""
130133
Mark this handler as a workflow entry point.
@@ -182,6 +185,7 @@ def main(
182185
enable_lazy_state=enable_lazy_state,
183186
ingress_private=ingress_private,
184187
invocation_retry_policy=invocation_retry_policy,
188+
context_managers=context_managers,
185189
)
186190

187191
def handler(
@@ -199,6 +203,7 @@ def handler(
199203
enable_lazy_state: typing.Optional[bool] = None,
200204
ingress_private: typing.Optional[bool] = None,
201205
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
206+
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
202207
) -> typing.Callable[[T], T]:
203208
"""
204209
Decorator for defining a handler function.
@@ -256,6 +261,7 @@ def handler(
256261
enable_lazy_state,
257262
ingress_private,
258263
invocation_retry_policy,
264+
context_managers,
259265
)
260266

261267
# pylint: disable=R0914
@@ -276,6 +282,7 @@ def _add_handler(
276282
enable_lazy_state: typing.Optional[bool] = None,
277283
ingress_private: typing.Optional[bool] = None,
278284
invocation_retry_policy: typing.Optional["InvocationRetryPolicy"] = None,
285+
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
279286
) -> typing.Callable[[T], T]:
280287
"""
281288
Decorator for defining a handler function.
@@ -342,6 +349,11 @@ def wrapped(*args, **kwargs):
342349

343350
signature = inspect.signature(fn, eval_str=True)
344351
description = inspect.getdoc(fn)
352+
combined_context_managers = (
353+
(self.context_managers or []) + (context_managers or [])
354+
if self.context_managers or context_managers
355+
else None
356+
)
345357
handler = make_handler(
346358
service_tag=self.service_tag,
347359
handler_io=handler_io,
@@ -359,6 +371,7 @@ def wrapped(*args, **kwargs):
359371
enable_lazy_state=enable_lazy_state,
360372
ingress_private=ingress_private,
361373
invocation_retry_policy=invocation_retry_policy,
374+
context_managers=combined_context_managers,
362375
)
363376
self.handlers[handler.name] = handler
364377
return wrapped

tests/ext.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def anyio_backend():
3535

3636

3737
def magic_function():
38-
from restate.extensions import current_context
3938

4039
ctx = current_context()
4140
assert ctx is not None
@@ -48,6 +47,23 @@ async def greet(ctx: Context, name: str) -> str:
4847
return f"Hello {id}!"
4948

5049

50+
# -- context manager
51+
52+
from contextlib import asynccontextmanager
53+
from restate.extensions import contextvar
54+
55+
56+
@contextvar
57+
@asynccontextmanager
58+
async def my_resource_manager():
59+
yield "hello"
60+
61+
62+
@greeter.handler(context_managers=[my_resource_manager])
63+
async def greet_with_cm(ctx: Context, name: str) -> str:
64+
return my_resource_manager.value
65+
66+
5167
@pytest.fixture(scope="session")
5268
async def restate_test_harness():
5369
async with restate.create_test_harness(
@@ -62,3 +78,8 @@ async def restate_test_harness():
6278
async def test_greeter(restate_test_harness: HarnessEnvironment):
6379
greeting = await restate_test_harness.client.service_call(greet, arg="bob")
6480
assert greeting.startswith("Hello ")
81+
82+
83+
async def test_greeter_with_cm(restate_test_harness: HarnessEnvironment):
84+
greeting = await restate_test_harness.client.service_call(greet_with_cm, arg="bob")
85+
assert greeting == "hello"

0 commit comments

Comments
 (0)