Skip to content

Commit d7eb874

Browse files
authored
Support handler/run_typed Optional return types (#158)
This commit adds support for handler *return types* that are `Union[T, None].` With this commit, the correct serializer will be detected (and the correct json schema) for the following situations ```python async def greeter(ctx: Context, name: str) -> MyResponse | None: return MyResponse(..) ``` With this commit the correct serializer will be used (i.e. the serializer for MyResponse) instead of a generic JSON one. This also applies for `ctx.run_typed` where the provided function had the return type of `T | None`. There is no need to pass a `type_hint`in this situation. Previously: ``` async def my_effect() -> PydanticMessage | None res = ctx.run_typed(... , type_hint=PydanticMessage) ```
1 parent 4827509 commit d7eb874

File tree

5 files changed

+201
-30
lines changed

5 files changed

+201
-30
lines changed

python/restate/handler.py

Lines changed: 120 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,25 @@
1818
from dataclasses import dataclass
1919
from datetime import timedelta
2020
from inspect import Signature
21-
from typing import Any, AsyncContextManager, Callable, Awaitable, Dict, Generic, List, Literal, Optional, TypeVar
21+
from typing import (
22+
Any,
23+
AsyncContextManager,
24+
Callable,
25+
Awaitable,
26+
Dict,
27+
Generic,
28+
List,
29+
Literal,
30+
Optional,
31+
TypeVar,
32+
)
2233

2334
from restate.retry_policy import InvocationRetryPolicy
2435

2536
from restate.context import HandlerType
2637
from restate.exceptions import TerminalError
2738
from restate.serde import DefaultSerde, PydanticJsonSerde, MsgspecJsonSerde, Serde, is_pydantic, Msgspec
39+
from restate.types import extract_core_type
2840

2941
I = TypeVar("I")
3042
O = TypeVar("O")
@@ -78,48 +90,129 @@ class HandlerIO(Generic[I, O]):
7890
output_type: Optional[TypeHint[O]] = None
7991

8092

81-
def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
93+
def _json_schema_wrap_as_optional(schema: Dict[str, Any]) -> Dict[str, Any]:
8294
"""
83-
Augment handler_io with additional information about the input and output types.
95+
modify the given JSON schema with its type wrapped as optional (nullable).
96+
"""
97+
t = schema.get("type")
98+
99+
if t is None:
100+
# If type is unspecified, leave it open by only adding "null"
101+
schema["type"] = ["null"]
102+
return schema
103+
104+
if isinstance(t, list):
105+
if "null" not in t:
106+
t.append("null")
107+
else:
108+
if t != "null":
109+
schema["type"] = [t, "null"]
110+
111+
return schema
112+
113+
114+
def _make_json_schema_generator(
115+
original: Callable[[], Dict[str, Any]], type: Literal["optional", "simple"]
116+
) -> Callable[[], Dict[str, Any]]:
117+
"""
118+
Create a JSON schema generator that handles optional types.
119+
120+
If the type is optional, the generated schema will include "null" in the type.
121+
"""
122+
if type == "simple":
123+
return original
124+
125+
def generator() -> Dict[str, Any]:
126+
schema = original()
127+
if type == "optional":
128+
return _json_schema_wrap_as_optional(schema)
129+
130+
assert False, "unreachable"
131+
132+
return generator
133+
134+
135+
def update_handler_io_with_input_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
136+
"""
137+
Augment handler_io with additional information about the input type.
84138
85139
This function has a special check for msgspec Structs and Pydantic models when these are provided.
86140
This method will inspect the signature of an handler and will look for
87-
the input and the return types of a function, and will:
141+
the input type of a function, and will:
88142
* capture any msgspec Structs or Pydantic models (to be used later at discovery)
89143
* replace the default json serializer (is unchanged by a user) with the appropriate serde
90144
"""
91145
params = list(signature.parameters.values())
92146
if len(params) == 1:
93147
# if there is only one parameter, it is the context.
94148
handler_io.input_type = TypeHint(is_void=True)
95-
else:
96-
annotation = params[-1].annotation
97-
handler_io.input_type = TypeHint(annotation=annotation)
98-
if Msgspec.is_struct(annotation):
99-
handler_io.input_type.generate_json_schema = lambda: Msgspec.json_schema(annotation)
100-
if isinstance(handler_io.input_serde, DefaultSerde):
101-
handler_io.input_serde = MsgspecJsonSerde(annotation)
102-
elif is_pydantic(annotation):
103-
handler_io.input_type.generate_json_schema = lambda: annotation.model_json_schema(mode="serialization")
104-
if isinstance(handler_io.input_serde, DefaultSerde):
105-
handler_io.input_serde = PydanticJsonSerde(annotation)
149+
return
150+
151+
annotation = params[-1].annotation
152+
core_kind, core_type = extract_core_type(annotation)
153+
handler_io.input_type = TypeHint(annotation=core_type)
154+
if Msgspec.is_struct(core_type):
155+
handler_io.input_type.generate_json_schema = _make_json_schema_generator(
156+
lambda: Msgspec.json_schema(core_type), core_kind
157+
)
158+
if isinstance(handler_io.input_serde, DefaultSerde):
159+
handler_io.input_serde = MsgspecJsonSerde(core_type)
160+
return
161+
162+
if is_pydantic(core_type):
163+
handler_io.input_type.generate_json_schema = _make_json_schema_generator(
164+
lambda: core_type.model_json_schema(mode="serialization"), core_kind
165+
)
166+
if isinstance(handler_io.input_serde, DefaultSerde):
167+
handler_io.input_serde = PydanticJsonSerde(core_type)
168+
169+
170+
def update_handler_io_with_return_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
171+
"""
172+
Augment handler_io with additional information about the output type.
106173
174+
This function has a special check for msgspec Structs and Pydantic models when these are provided.
175+
This method will inspect the signature of an handler and will look for
176+
the return type of a function, and will:
177+
* capture any msgspec Structs or Pydantic models (to be used later at discovery)
178+
* replace the default json serializer (is unchanged by a user) with the appropriate serde
179+
"""
107180
return_annotation = signature.return_annotation
108181
if return_annotation is None or return_annotation is Signature.empty:
109182
# if there is no return annotation, we assume it is void
110183
handler_io.output_type = TypeHint(is_void=True)
111-
else:
112-
handler_io.output_type = TypeHint(annotation=return_annotation)
113-
if Msgspec.is_struct(return_annotation):
114-
handler_io.output_type.generate_json_schema = lambda: Msgspec.json_schema(return_annotation)
115-
if isinstance(handler_io.output_serde, DefaultSerde):
116-
handler_io.output_serde = MsgspecJsonSerde(return_annotation)
117-
elif is_pydantic(return_annotation):
118-
handler_io.output_type.generate_json_schema = lambda: return_annotation.model_json_schema(
119-
mode="serialization"
120-
)
121-
if isinstance(handler_io.output_serde, DefaultSerde):
122-
handler_io.output_serde = PydanticJsonSerde(return_annotation)
184+
return
185+
186+
core_kind, return_core_type = extract_core_type(return_annotation)
187+
handler_io.output_type = TypeHint(annotation=return_core_type)
188+
if Msgspec.is_struct(return_core_type):
189+
handler_io.output_type.generate_json_schema = _make_json_schema_generator(
190+
lambda: Msgspec.json_schema(return_core_type), core_kind
191+
)
192+
if isinstance(handler_io.output_serde, DefaultSerde):
193+
handler_io.output_serde = MsgspecJsonSerde(return_core_type)
194+
return
195+
196+
if is_pydantic(return_core_type):
197+
handler_io.output_type.generate_json_schema = _make_json_schema_generator(
198+
lambda: return_core_type.model_json_schema(mode="serialization"), core_kind
199+
)
200+
if isinstance(handler_io.output_serde, DefaultSerde):
201+
handler_io.output_serde = PydanticJsonSerde(return_core_type)
202+
203+
204+
def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
205+
"""
206+
Augment handler_io with additional information about the input and output types.
207+
208+
This function has a special check for msgspec Structs and Pydantic models when these are provided.
209+
This method will inspect the signature of an handler and will look for
210+
the input and the return types of a function, and will:
211+
* capture any msgspec Structs or Pydantic models (to be used later at discovery)
212+
* replace the default json serializer (is unchanged by a user) with the appropriate serde
213+
"""
214+
update_handler_io_with_input_type_hints(handler_io, signature)
215+
update_handler_io_with_return_type_hints(handler_io, signature)
123216

124217

125218
# pylint: disable=R0902

python/restate/server_context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from restate.handler import Handler, handler_from_callable, invoke_handler
5050
from restate.serde import BytesSerde, DefaultSerde, Serde
5151
from restate.server_types import ReceiveChannel, Send
52+
from restate.types import extract_core_type
5253
from restate.vm import Failure, Invocation, NotReady, VMWrapper, RunRetryConfig, Suspended # pylint: disable=line-too-long
5354
from restate.vm import (
5455
DoProgressAnyCompleted,
@@ -697,6 +698,10 @@ def run_typed(
697698
if options.type_hint is None:
698699
signature = inspect.signature(action, eval_str=True)
699700
options.type_hint = signature.return_annotation
701+
core_type_kind, core_type = extract_core_type(options.type_hint)
702+
if core_type_kind == "simple" or core_type_kind == "optional":
703+
# use core type as it is more specific. E.g. Optional[T] -> T
704+
options.type_hint = core_type
700705
options.serde = typing.cast(DefaultSerde, options.serde).with_maybe_type(options.type_hint)
701706
handle = self.vm.sys_run(name)
702707
update_restate_context_is_replaying(self.vm)

python/restate/types.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
"""
1717

1818
from dataclasses import dataclass
19+
from types import UnionType
20+
from typing import Any, Tuple, Literal, Union, get_args, get_origin
1921

2022
from restate.client_types import RestateClient
2123

@@ -32,3 +34,22 @@ class HarnessEnvironment:
3234

3335
client: RestateClient
3436
"""The Restate client connected to the ingress URL"""
37+
38+
39+
def extract_core_type(annotation: Any) -> Tuple[Literal["optional", "simple"], Any]:
40+
"""
41+
Extract the core type from a type annotation.
42+
43+
Currently only supports Optional[T] types.
44+
"""
45+
if annotation is None:
46+
return "simple", annotation
47+
48+
origin = get_origin(annotation)
49+
args = get_args(annotation)
50+
51+
if (origin is UnionType or Union) and len(args) == 2 and type(None) in args:
52+
non_none_type = args[0] if args[1] is type(None) else args[1]
53+
return "optional", non_none_type
54+
55+
return "simple", annotation

tests/serde.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,30 @@
44
def test_bytes_serde():
55
s = BytesSerde()
66
assert bytes(range(20)) == s.serialize(bytes(range(20)))
7+
8+
9+
def extract_core_type_optional():
10+
from restate.types import extract_core_type
11+
12+
from typing import Optional, Union
13+
14+
kind, tpe = extract_core_type(Optional[int])
15+
assert kind == "optional"
16+
assert tpe is int
17+
18+
kind, tpe = extract_core_type(Union[int, None])
19+
assert kind == "optional"
20+
assert tpe is int
21+
22+
kind, tpe = extract_core_type(str | None)
23+
assert kind == "optional"
24+
assert tpe is str
25+
26+
kind, tpe = extract_core_type(None | str)
27+
assert kind == "optional"
28+
assert tpe is str
29+
30+
for t in [int, str, bytes, dict, list, None]:
31+
kind, tpe = extract_core_type(t)
32+
assert kind == "simple"
33+
assert tpe is t

tests/servercontext.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,38 @@ async def test_promise_default_serde():
8585
async def run(ctx: WorkflowContext) -> str:
8686
promise = ctx.promise("test.promise", type_hint=str)
8787

88-
assert isinstance(promise.serde, DefaultSerde), \
89-
f"Expected DefaultSerde but got {type(promise.serde).__name__}"
88+
assert isinstance(promise.serde, DefaultSerde), f"Expected DefaultSerde but got {type(promise.serde).__name__}"
9089

9190
await promise.resolve("success")
9291
return await promise.value()
9392

94-
9593
async with simple_harness(workflow) as client:
9694
result = await client.workflow_call(run, key="test-key", arg=None)
9795
assert result == "success"
96+
97+
98+
async def test_handler_with_union_none():
99+
greeter = Service("greeter")
100+
101+
@greeter.handler()
102+
async def greet(ctx: Context, name: str) -> str | None:
103+
return "hi"
104+
105+
async with simple_harness(greeter) as client:
106+
res = await client.service_call(greet, arg="bob")
107+
assert res == "hi"
108+
109+
110+
async def test_handler_with_ctx_none():
111+
greeter = Service("greeter")
112+
113+
async def maybe_something() -> str | None:
114+
return "hi"
115+
116+
@greeter.handler()
117+
async def greet(ctx: Context, name: str) -> str | None:
118+
return await ctx.run_typed("foo", maybe_something)
119+
120+
async with simple_harness(greeter) as client:
121+
res = await client.service_call(greet, arg="bob")
122+
assert res == "hi"

0 commit comments

Comments
 (0)