Skip to content

Commit f59030b

Browse files
authored
feat: support Annotated pattern for enum registration (#4293)
* feat: support Annotated pattern for enum registration * fix: address review feedback on enum docs and tests * refactor: replace EnumDefinition with functools.partial * refactor: replace functools.partial with EnumAnnotation class
1 parent ccca075 commit f59030b

File tree

6 files changed

+210
-14
lines changed

6 files changed

+210
-14
lines changed

RELEASE.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
Release type: minor
2+
3+
Enums can now be registered via `Annotated`. The preferred way is still using
4+
`@strawberry.enum` as a decorator, but when you need to expose an existing enum
5+
under a different name or alias, `Annotated` works as a proper type alias in all
6+
type checkers:
7+
8+
```python
9+
from typing import Annotated
10+
from enum import Enum
11+
import strawberry
12+
13+
14+
class IceCreamFlavour(Enum):
15+
VANILLA = "vanilla"
16+
STRAWBERRY = "strawberry"
17+
CHOCOLATE = "chocolate"
18+
19+
20+
MyIceCreamFlavour = Annotated[
21+
IceCreamFlavour, strawberry.enum(description="Ice cream flavours")
22+
]
23+
24+
25+
@strawberry.type
26+
class Query:
27+
@strawberry.field
28+
def flavour(self) -> MyIceCreamFlavour:
29+
return IceCreamFlavour.VANILLA
30+
```

docs/editors/mypy.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,33 @@ plugins = ["pydantic.mypy", "strawberry.ext.mypy_plugin"]
2929

3030
The strawberry plugin synthesises `__init__`, `to_pydantic()` and
3131
`from_pydantic()` on pydantic-decorated classes so that mypy can see them.
32+
33+
## Enums
34+
35+
The preferred way to register an enum is with the decorator:
36+
37+
```python
38+
from enum import Enum
39+
import strawberry
40+
41+
42+
@strawberry.enum
43+
class IceCreamFlavour(Enum):
44+
VANILLA = "vanilla"
45+
STRAWBERRY = "strawberry"
46+
```
47+
48+
If you need to expose an existing enum under a different name or alias, use
49+
`Annotated` instead of assigning `strawberry.enum(IceCreamFlavour)` to a
50+
variable — mypy treats the latter as a value, not a type, so it cannot be used
51+
in annotations.
52+
53+
```python
54+
from typing import Annotated
55+
import strawberry
56+
57+
MyIceCreamFlavour = Annotated[IceCreamFlavour, strawberry.enum(description="...")]
58+
```
59+
60+
`MyIceCreamFlavour` is a proper type alias that mypy and Pyright accept in
61+
annotations.

strawberry/annotation.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
get_object_definition,
3030
has_object_definition,
3131
)
32-
from strawberry.types.enum import StrawberryEnumDefinition
32+
from strawberry.types.enum import EnumAnnotation, StrawberryEnumDefinition
3333
from strawberry.types.enum import enum as strawberry_enum
3434
from strawberry.types.lazy_type import LazyType
3535
from strawberry.types.maybe import _annotation_is_maybe
@@ -189,7 +189,7 @@ def _resolve_evaled_type(self, evaled_type: Any) -> StrawberryType | type:
189189
# Everything remaining should be a raw annotation that needs to be turned into
190190
# a StrawberryType
191191
if self._is_enum(evaled_type):
192-
return self.create_enum(evaled_type)
192+
return self.create_enum(evaled_type, args)
193193
if self._is_optional(evaled_type, args):
194194
return self.create_optional(evaled_type)
195195
if self._is_union(evaled_type, args):
@@ -215,10 +215,20 @@ def create_concrete_type(self, evaled_type: type) -> type:
215215
return evaled_type.__strawberry_definition__.resolve_generic(evaled_type)
216216
raise ValueError(f"Not supported {evaled_type}")
217217

218-
def create_enum(self, evaled_type: Any) -> StrawberryEnumDefinition:
218+
def create_enum(
219+
self, evaled_type: Any, args: list[Any] | None = None
220+
) -> StrawberryEnumDefinition:
221+
enum_annotation: EnumAnnotation | None = None
222+
if args:
223+
enum_annotation = next(
224+
(a for a in args if isinstance(a, EnumAnnotation)), None
225+
)
226+
219227
try:
220228
return evaled_type.__strawberry_definition__
221229
except AttributeError:
230+
if enum_annotation is not None:
231+
return enum_annotation(evaled_type).__strawberry_definition__
222232
return strawberry_enum(evaled_type).__strawberry_definition__
223233

224234
def create_list(self, evaled_type: Any) -> StrawberryList:

strawberry/types/enum.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,23 @@ class MyEnum(Enum):
105105
GraphqlEnumNameFrom = Literal["key", "value"]
106106

107107

108+
@dataclasses.dataclass
109+
class EnumAnnotation:
110+
name: str | None = None
111+
description: str | None = None
112+
directives: Iterable[object] = ()
113+
graphql_name_from: GraphqlEnumNameFrom = "key"
114+
115+
def __call__(self, cls: EnumType) -> EnumType:
116+
return _process_enum(
117+
cls,
118+
self.name,
119+
self.description,
120+
directives=self.directives,
121+
graphql_name_from=self.graphql_name_from,
122+
)
123+
124+
108125
def _process_enum(
109126
cls: EnumType,
110127
name: str | None = None,
@@ -238,20 +255,17 @@ class MyEnum(Enum):
238255
If name is passed, the name of the GraphQL type will be
239256
the value passed of name instead of the Enum class name.
240257
"""
241-
242-
def wrap(cls: EnumType) -> EnumType:
243-
return _process_enum(
244-
cls,
245-
name,
246-
description,
247-
directives=directives,
248-
graphql_name_from=graphql_name_from,
249-
)
258+
wrapper = EnumAnnotation(
259+
name=name,
260+
description=description,
261+
directives=directives,
262+
graphql_name_from=graphql_name_from,
263+
)
250264

251265
if not cls:
252-
return wrap
266+
return wrapper
253267

254-
return wrap(cls)
268+
return wrapper(cls)
255269

256270

257271
WithStrawberryEnumDefinition = WithStrawberryDefinition["StrawberryEnumDefinition"]
@@ -265,6 +279,7 @@ def has_enum_definition(obj: Any) -> TypeGuard[type[WithStrawberryEnumDefinition
265279

266280

267281
__all__ = [
282+
"EnumAnnotation",
268283
"EnumValue",
269284
"EnumValueDefinition",
270285
"StrawberryEnumDefinition",

tests/enums/test_enum.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from enum import Enum, IntEnum
2+
from typing import Annotated
23

34
import pytest
45

@@ -236,3 +237,60 @@ class Query:
236237

237238
with pytest.raises(TypeError, match=r"Expected name to be a string."):
238239
strawberry.Schema(Query)
240+
241+
242+
def test_annotated_enum_with_description():
243+
class IceCreamFlavour(Enum):
244+
VANILLA = "vanilla"
245+
STRAWBERRY = "strawberry"
246+
CHOCOLATE = "chocolate"
247+
248+
MyIceCreamFlavour = Annotated[
249+
IceCreamFlavour, strawberry.enum(description="Ice cream flavours")
250+
]
251+
252+
@strawberry.type
253+
class Query:
254+
@strawberry.field
255+
def flavour(self) -> MyIceCreamFlavour:
256+
return IceCreamFlavour.VANILLA
257+
258+
schema = strawberry.Schema(Query)
259+
schema_str = str(schema)
260+
assert '"""Ice cream flavours"""' in schema_str
261+
assert "enum IceCreamFlavour" in schema_str
262+
263+
264+
def test_annotated_enum_with_name():
265+
class Flavour(Enum):
266+
VANILLA = "vanilla"
267+
STRAWBERRY = "strawberry"
268+
269+
MyIceCreamFlavour = Annotated[Flavour, strawberry.enum(name="IceCreamFlavour")]
270+
271+
@strawberry.type
272+
class Query:
273+
@strawberry.field
274+
def flavour(self) -> MyIceCreamFlavour:
275+
return Flavour.VANILLA
276+
277+
schema = strawberry.Schema(Query)
278+
assert "IceCreamFlavour" in str(schema)
279+
280+
281+
def test_annotated_enum_on_already_decorated():
282+
@strawberry.enum
283+
class Flavour(Enum):
284+
VANILLA = "vanilla"
285+
STRAWBERRY = "strawberry"
286+
287+
MyFlavour = Annotated[Flavour, strawberry.enum(description="Flavours")]
288+
289+
@strawberry.type
290+
class Query:
291+
@strawberry.field
292+
def flavour(self) -> MyFlavour:
293+
return Flavour.VANILLA
294+
295+
schema = strawberry.Schema(Query)
296+
assert "enum Flavour" in str(schema)

tests/typecheckers/test_enum.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,59 @@ def test_enum_with_manual_decorator_and_name():
279279
)
280280

281281

282+
CODE_WITH_ANNOTATED = """
283+
from enum import Enum
284+
from typing import Annotated
285+
286+
import strawberry
287+
288+
class IceCreamFlavour(Enum):
289+
VANILLA = "vanilla"
290+
STRAWBERRY = "strawberry"
291+
CHOCOLATE = "chocolate"
292+
293+
MyIceCreamFlavour = Annotated[IceCreamFlavour, strawberry.enum(description="Flavours")]
294+
295+
x: MyIceCreamFlavour = IceCreamFlavour.VANILLA
296+
reveal_type(x)
297+
"""
298+
299+
300+
def test_enum_with_annotated():
301+
results = typecheck(CODE_WITH_ANNOTATED)
302+
303+
assert results.pyright == snapshot(
304+
[
305+
Result(
306+
type="information",
307+
message='Type of "x" is "Literal[IceCreamFlavour.VANILLA]"',
308+
line=15,
309+
column=13,
310+
),
311+
]
312+
)
313+
assert results.mypy == snapshot(
314+
[
315+
Result(
316+
type="note",
317+
message='Revealed type is "mypy_test.IceCreamFlavour"',
318+
line=15,
319+
column=13,
320+
),
321+
]
322+
)
323+
assert results.ty == snapshot(
324+
[
325+
Result(
326+
type="information",
327+
message="Revealed type: `Literal[IceCreamFlavour.VANILLA]`",
328+
line=15,
329+
column=13,
330+
),
331+
]
332+
)
333+
334+
282335
CODE_WITH_DEPRECATION_REASON = """
283336
from enum import Enum
284337

0 commit comments

Comments
 (0)