|
18 | 18 | from dataclasses import dataclass |
19 | 19 | from datetime import timedelta |
20 | 20 | from inspect import Signature |
21 | | -from typing import Any, AsyncContextManager, Callable, Awaitable, Dict, Generic, List, Literal, Optional, TypeVar |
| 21 | +from typing import ( |
| 22 | + Any, |
| 23 | + AsyncContextManager, |
| 24 | + Callable, |
| 25 | + Awaitable, |
| 26 | + Dict, |
| 27 | + Generic, |
| 28 | + List, |
| 29 | + Literal, |
| 30 | + Optional, |
| 31 | + TypeVar, |
| 32 | +) |
22 | 33 |
|
23 | 34 | from restate.retry_policy import InvocationRetryPolicy |
24 | 35 |
|
25 | 36 | from restate.context import HandlerType |
26 | 37 | from restate.exceptions import TerminalError |
27 | 38 | from restate.serde import DefaultSerde, PydanticJsonSerde, MsgspecJsonSerde, Serde, is_pydantic, Msgspec |
| 39 | +from restate.types import extract_core_type |
28 | 40 |
|
29 | 41 | I = TypeVar("I") |
30 | 42 | O = TypeVar("O") |
@@ -78,48 +90,129 @@ class HandlerIO(Generic[I, O]): |
78 | 90 | output_type: Optional[TypeHint[O]] = None |
79 | 91 |
|
80 | 92 |
|
81 | | -def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Signature): |
| 93 | +def _json_schema_wrap_as_optional(schema: Dict[str, Any]) -> Dict[str, Any]: |
82 | 94 | """ |
83 | | - Augment handler_io with additional information about the input and output types. |
| 95 | + modify the given JSON schema with its type wrapped as optional (nullable). |
| 96 | + """ |
| 97 | + t = schema.get("type") |
| 98 | + |
| 99 | + if t is None: |
| 100 | + # If type is unspecified, leave it open by only adding "null" |
| 101 | + schema["type"] = ["null"] |
| 102 | + return schema |
| 103 | + |
| 104 | + if isinstance(t, list): |
| 105 | + if "null" not in t: |
| 106 | + t.append("null") |
| 107 | + else: |
| 108 | + if t != "null": |
| 109 | + schema["type"] = [t, "null"] |
| 110 | + |
| 111 | + return schema |
| 112 | + |
| 113 | + |
| 114 | +def _make_json_schema_generator( |
| 115 | + original: Callable[[], Dict[str, Any]], type: Literal["optional", "simple"] |
| 116 | +) -> Callable[[], Dict[str, Any]]: |
| 117 | + """ |
| 118 | + Create a JSON schema generator that handles optional types. |
| 119 | +
|
| 120 | + If the type is optional, the generated schema will include "null" in the type. |
| 121 | + """ |
| 122 | + if type == "simple": |
| 123 | + return original |
| 124 | + |
| 125 | + def generator() -> Dict[str, Any]: |
| 126 | + schema = original() |
| 127 | + if type == "optional": |
| 128 | + return _json_schema_wrap_as_optional(schema) |
| 129 | + |
| 130 | + assert False, "unreachable" |
| 131 | + |
| 132 | + return generator |
| 133 | + |
| 134 | + |
| 135 | +def update_handler_io_with_input_type_hints(handler_io: HandlerIO[I, O], signature: Signature): |
| 136 | + """ |
| 137 | + Augment handler_io with additional information about the input type. |
84 | 138 |
|
85 | 139 | This function has a special check for msgspec Structs and Pydantic models when these are provided. |
86 | 140 | This method will inspect the signature of an handler and will look for |
87 | | - the input and the return types of a function, and will: |
| 141 | + the input type of a function, and will: |
88 | 142 | * capture any msgspec Structs or Pydantic models (to be used later at discovery) |
89 | 143 | * replace the default json serializer (is unchanged by a user) with the appropriate serde |
90 | 144 | """ |
91 | 145 | params = list(signature.parameters.values()) |
92 | 146 | if len(params) == 1: |
93 | 147 | # if there is only one parameter, it is the context. |
94 | 148 | handler_io.input_type = TypeHint(is_void=True) |
95 | | - else: |
96 | | - annotation = params[-1].annotation |
97 | | - handler_io.input_type = TypeHint(annotation=annotation) |
98 | | - if Msgspec.is_struct(annotation): |
99 | | - handler_io.input_type.generate_json_schema = lambda: Msgspec.json_schema(annotation) |
100 | | - if isinstance(handler_io.input_serde, DefaultSerde): |
101 | | - handler_io.input_serde = MsgspecJsonSerde(annotation) |
102 | | - elif is_pydantic(annotation): |
103 | | - handler_io.input_type.generate_json_schema = lambda: annotation.model_json_schema(mode="serialization") |
104 | | - if isinstance(handler_io.input_serde, DefaultSerde): |
105 | | - handler_io.input_serde = PydanticJsonSerde(annotation) |
| 149 | + return |
| 150 | + |
| 151 | + annotation = params[-1].annotation |
| 152 | + core_kind, core_type = extract_core_type(annotation) |
| 153 | + handler_io.input_type = TypeHint(annotation=core_type) |
| 154 | + if Msgspec.is_struct(core_type): |
| 155 | + handler_io.input_type.generate_json_schema = _make_json_schema_generator( |
| 156 | + lambda: Msgspec.json_schema(core_type), core_kind |
| 157 | + ) |
| 158 | + if isinstance(handler_io.input_serde, DefaultSerde): |
| 159 | + handler_io.input_serde = MsgspecJsonSerde(core_type) |
| 160 | + return |
| 161 | + |
| 162 | + if is_pydantic(core_type): |
| 163 | + handler_io.input_type.generate_json_schema = _make_json_schema_generator( |
| 164 | + lambda: core_type.model_json_schema(mode="serialization"), core_kind |
| 165 | + ) |
| 166 | + if isinstance(handler_io.input_serde, DefaultSerde): |
| 167 | + handler_io.input_serde = PydanticJsonSerde(core_type) |
| 168 | + |
| 169 | + |
| 170 | +def update_handler_io_with_return_type_hints(handler_io: HandlerIO[I, O], signature: Signature): |
| 171 | + """ |
| 172 | + Augment handler_io with additional information about the output type. |
106 | 173 |
|
| 174 | + This function has a special check for msgspec Structs and Pydantic models when these are provided. |
| 175 | + This method will inspect the signature of an handler and will look for |
| 176 | + the return type of a function, and will: |
| 177 | + * capture any msgspec Structs or Pydantic models (to be used later at discovery) |
| 178 | + * replace the default json serializer (is unchanged by a user) with the appropriate serde |
| 179 | + """ |
107 | 180 | return_annotation = signature.return_annotation |
108 | 181 | if return_annotation is None or return_annotation is Signature.empty: |
109 | 182 | # if there is no return annotation, we assume it is void |
110 | 183 | handler_io.output_type = TypeHint(is_void=True) |
111 | | - else: |
112 | | - handler_io.output_type = TypeHint(annotation=return_annotation) |
113 | | - if Msgspec.is_struct(return_annotation): |
114 | | - handler_io.output_type.generate_json_schema = lambda: Msgspec.json_schema(return_annotation) |
115 | | - if isinstance(handler_io.output_serde, DefaultSerde): |
116 | | - handler_io.output_serde = MsgspecJsonSerde(return_annotation) |
117 | | - elif is_pydantic(return_annotation): |
118 | | - handler_io.output_type.generate_json_schema = lambda: return_annotation.model_json_schema( |
119 | | - mode="serialization" |
120 | | - ) |
121 | | - if isinstance(handler_io.output_serde, DefaultSerde): |
122 | | - handler_io.output_serde = PydanticJsonSerde(return_annotation) |
| 184 | + return |
| 185 | + |
| 186 | + core_kind, return_core_type = extract_core_type(return_annotation) |
| 187 | + handler_io.output_type = TypeHint(annotation=return_core_type) |
| 188 | + if Msgspec.is_struct(return_core_type): |
| 189 | + handler_io.output_type.generate_json_schema = _make_json_schema_generator( |
| 190 | + lambda: Msgspec.json_schema(return_core_type), core_kind |
| 191 | + ) |
| 192 | + if isinstance(handler_io.output_serde, DefaultSerde): |
| 193 | + handler_io.output_serde = MsgspecJsonSerde(return_core_type) |
| 194 | + return |
| 195 | + |
| 196 | + if is_pydantic(return_core_type): |
| 197 | + handler_io.output_type.generate_json_schema = _make_json_schema_generator( |
| 198 | + lambda: return_core_type.model_json_schema(mode="serialization"), core_kind |
| 199 | + ) |
| 200 | + if isinstance(handler_io.output_serde, DefaultSerde): |
| 201 | + handler_io.output_serde = PydanticJsonSerde(return_core_type) |
| 202 | + |
| 203 | + |
| 204 | +def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Signature): |
| 205 | + """ |
| 206 | + Augment handler_io with additional information about the input and output types. |
| 207 | +
|
| 208 | + This function has a special check for msgspec Structs and Pydantic models when these are provided. |
| 209 | + This method will inspect the signature of an handler and will look for |
| 210 | + the input and the return types of a function, and will: |
| 211 | + * capture any msgspec Structs or Pydantic models (to be used later at discovery) |
| 212 | + * replace the default json serializer (is unchanged by a user) with the appropriate serde |
| 213 | + """ |
| 214 | + update_handler_io_with_input_type_hints(handler_io, signature) |
| 215 | + update_handler_io_with_return_type_hints(handler_io, signature) |
123 | 216 |
|
124 | 217 |
|
125 | 218 | # pylint: disable=R0902 |
|
0 commit comments