Skip to content

Commit 5562692

Browse files
authored
fix: fix generics resolution when using Annotated unions (#3950)
* fix: fix generics resolution when using Annotated unions Fix #3289 * fix mypy
1 parent 741b0bc commit 5562692

File tree

4 files changed

+151
-27
lines changed

4 files changed

+151
-27
lines changed

RELEASE.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
Release type: patch
2+
3+
This release fixes the resolution of `Generics` when specializing using a union
4+
defined with `Annotated`, like in the example below:
5+
6+
```python
7+
from typing import Annotated, Generic, TypeVar, Union
8+
import strawberry
9+
10+
T = TypeVar("T")
11+
12+
13+
@strawberry.type
14+
class User:
15+
name: str
16+
age: int
17+
18+
19+
@strawberry.type
20+
class ProUser:
21+
name: str
22+
age: float
23+
24+
25+
@strawberry.type
26+
class GenType(Generic[T]):
27+
data: T
28+
29+
30+
GeneralUser = Annotated[Union[User, ProUser], strawberry.union("GeneralUser")]
31+
32+
33+
@strawberry.type
34+
class Response(GenType[GeneralUser]): ...
35+
36+
37+
@strawberry.type
38+
class Query:
39+
@strawberry.field
40+
def user(self) -> Response: ...
41+
42+
43+
schema = strawberry.Schema(query=Query)
44+
```
45+
46+
Before this would raise a `TypeError`, now it works as expected.

strawberry/annotation.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,44 @@ def _get_type_with_args(
125125

126126
return evaled_type, []
127127

128-
def resolve(self) -> Union[StrawberryType, type]:
128+
def resolve(
129+
self,
130+
*,
131+
type_definition: Optional[StrawberryObjectDefinition] = None,
132+
) -> Union[StrawberryType, type]:
129133
"""Return resolved (transformed) annotation."""
130-
if self.__resolve_cache__ is None:
131-
self.__resolve_cache__ = self._resolve()
134+
if (resolved := self.__resolve_cache__) is None:
135+
resolved = self._resolve()
136+
self.__resolve_cache__ = resolved
137+
138+
# If this is a generic field, try to resolve it using its origin's
139+
# specialized type_var_map
140+
if self._is_type_generic(resolved) and type_definition is not None:
141+
from strawberry.types.base import StrawberryType
142+
143+
specialized_type_var_map = type_definition.specialized_type_var_map
144+
if specialized_type_var_map and isinstance(resolved, StrawberryType):
145+
resolved = resolved.copy_with(specialized_type_var_map)
146+
147+
# If the field is still generic, try to resolve it from the type_definition
148+
# that is asking for it.
149+
if (
150+
self._is_type_generic(resolved)
151+
and type_definition.type_var_map
152+
and isinstance(resolved, StrawberryType)
153+
):
154+
resolved = resolved.copy_with(type_definition.type_var_map)
155+
156+
# Resolve the type again to resolve any `Annotated` types
157+
resolved = self._resolve_evaled_type(resolved)
132158

133-
return self.__resolve_cache__
159+
return resolved
134160

135161
def _resolve(self) -> Union[StrawberryType, type]:
136162
evaled_type = cast("Any", self.evaluate())
163+
return self._resolve_evaled_type(evaled_type)
137164

165+
def _resolve_evaled_type(self, evaled_type: Any) -> Union[StrawberryType, type]:
138166
if is_private(evaled_type):
139167
return evaled_type
140168

@@ -145,7 +173,7 @@ def _resolve(self) -> Union[StrawberryType, type]:
145173
if self._is_lazy_type(evaled_type):
146174
return evaled_type
147175
if self._is_streamable(evaled_type, args):
148-
return self.create_list(list[evaled_type]) # type: ignore[valid-type]
176+
return self.create_list(list[evaled_type])
149177
if self._is_list(evaled_type):
150178
return self.create_list(evaled_type)
151179
if self._is_maybe(evaled_type):
@@ -292,6 +320,20 @@ def _is_enum(cls, annotation: Any) -> bool:
292320
return False
293321
return issubclass(annotation, Enum)
294322

323+
@classmethod
324+
def _is_type_generic(cls, type_: Union[StrawberryType, type]) -> bool:
325+
"""Returns True if `resolver_type` is generic else False."""
326+
from strawberry.types.base import StrawberryType
327+
328+
if isinstance(type_, StrawberryType):
329+
return type_.is_graphql_generic
330+
331+
# solves the Generic subclass case
332+
if has_object_definition(type_):
333+
return type_.__strawberry_definition__.is_graphql_generic
334+
335+
return False
336+
295337
@classmethod
296338
def _is_graphql_generic(cls, annotation: Any) -> bool:
297339
if hasattr(annotation, "__origin__"):

strawberry/types/field.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
Optional,
1414
TypeVar,
1515
Union,
16-
cast,
1716
overload,
1817
)
1918

@@ -348,34 +347,14 @@ def resolve_type(
348347
with contextlib.suppress(NameError):
349348
# Prioritise the field type over the resolver return type
350349
if self.type_annotation is not None:
351-
resolved = self.type_annotation.resolve()
350+
resolved = self.type_annotation.resolve(type_definition=type_definition)
352351
elif self.base_resolver is not None and self.base_resolver.type is not None:
353352
# Handle unannotated functions (such as lambdas)
354353
# Generics will raise MissingTypesForGenericError later
355354
# on if we let it be returned. So use `type_annotation` instead
356355
# which is the same behaviour as having no type information.
357356
resolved = self.base_resolver.type
358357

359-
# If this is a generic field, try to resolve it using its origin's
360-
# specialized type_var_map
361-
# TODO: should we check arguments here too?
362-
if _is_generic(resolved): # type: ignore
363-
specialized_type_var_map = (
364-
type_definition and type_definition.specialized_type_var_map
365-
)
366-
if specialized_type_var_map and isinstance(resolved, StrawberryType):
367-
resolved = resolved.copy_with(specialized_type_var_map)
368-
369-
# If the field is still generic, try to resolve it from the type_definition
370-
# that is asking for it.
371-
if (
372-
_is_generic(cast("Union[StrawberryType, type]", resolved))
373-
and type_definition is not None
374-
and type_definition.type_var_map
375-
and isinstance(resolved, StrawberryType)
376-
):
377-
resolved = resolved.copy_with(type_definition.type_var_map)
378-
379358
return resolved
380359

381360
def copy_with(

tests/schema/test_union.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,3 +1194,60 @@ class Query:
11941194
"""
11951195
).strip()
11961196
)
1197+
1198+
1199+
def test_union_used_inside_generic():
1200+
T = TypeVar("T")
1201+
1202+
@strawberry.type
1203+
class User:
1204+
name: str
1205+
age: int
1206+
1207+
@strawberry.type
1208+
class ProUser:
1209+
name: str
1210+
age: float
1211+
1212+
@strawberry.type
1213+
class GenType(Generic[T]):
1214+
data: T
1215+
1216+
GeneralUser = Annotated[Union[User, ProUser], strawberry.union("GeneralUser")]
1217+
1218+
@strawberry.type
1219+
class Response(GenType[GeneralUser]): ...
1220+
1221+
@strawberry.type
1222+
class Query:
1223+
@strawberry.field
1224+
def user(self) -> Response: ...
1225+
1226+
schema = strawberry.Schema(query=Query)
1227+
1228+
assert (
1229+
str(schema)
1230+
== textwrap.dedent(
1231+
"""
1232+
union GeneralUser = User | ProUser
1233+
1234+
type ProUser {
1235+
name: String!
1236+
age: Float!
1237+
}
1238+
1239+
type Query {
1240+
user: Response!
1241+
}
1242+
1243+
type Response {
1244+
data: GeneralUser!
1245+
}
1246+
1247+
type User {
1248+
name: String!
1249+
age: Int!
1250+
}
1251+
"""
1252+
).strip()
1253+
)

0 commit comments

Comments
 (0)