|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import copy |
3 | 4 | import dataclasses |
4 | 5 | import sys |
5 | | -from typing import Any |
| 6 | +from typing import Annotated, Any, get_args, get_origin |
6 | 7 |
|
7 | 8 | from strawberry.annotation import StrawberryAnnotation |
8 | 9 | from strawberry.exceptions import ( |
9 | 10 | FieldWithResolverAndDefaultFactoryError, |
10 | 11 | FieldWithResolverAndDefaultValueError, |
| 12 | + MultipleStrawberryFieldsError, |
11 | 13 | PrivateStrawberryFieldError, |
12 | 14 | ) |
13 | 15 | from strawberry.types.base import has_object_definition |
|
16 | 18 | from strawberry.types.unset import UNSET |
17 | 19 |
|
18 | 20 |
|
| 21 | +def _get_field_from_annotated( |
| 22 | + field: dataclasses.Field, |
| 23 | + origin: type, |
| 24 | + module_namespace: dict[str, Any], |
| 25 | + cls: type, |
| 26 | +) -> StrawberryField | None: |
| 27 | + """Extract a StrawberryField from an Annotated type annotation. |
| 28 | +
|
| 29 | + Returns a configured StrawberryField if the annotation contains one, |
| 30 | + or None if no StrawberryField is found in the Annotated args. |
| 31 | + Raises MultipleStrawberryFieldsError if more than one is found. |
| 32 | + """ |
| 33 | + field_type = field.type |
| 34 | + |
| 35 | + if get_origin(field_type) is not Annotated: |
| 36 | + return None |
| 37 | + |
| 38 | + first, *rest = get_args(field_type) |
| 39 | + |
| 40 | + strawberry_fields = [arg for arg in rest if isinstance(arg, StrawberryField)] |
| 41 | + |
| 42 | + if len(strawberry_fields) > 1: |
| 43 | + raise MultipleStrawberryFieldsError(field_name=field.name, cls=cls) |
| 44 | + |
| 45 | + if not strawberry_fields: |
| 46 | + return None |
| 47 | + |
| 48 | + result = copy.copy(strawberry_fields[0]) |
| 49 | + result.python_name = field.name |
| 50 | + result.type_annotation = StrawberryAnnotation( |
| 51 | + annotation=first, |
| 52 | + namespace=module_namespace, |
| 53 | + ) |
| 54 | + result.origin = origin |
| 55 | + |
| 56 | + # Transfer default from dataclass field if not set in strawberry.field() |
| 57 | + if result.default is dataclasses.MISSING: |
| 58 | + result.default = field.default |
| 59 | + result.default_value = field.default |
| 60 | + |
| 61 | + return result |
| 62 | + |
| 63 | + |
19 | 64 | def _get_fields( |
20 | 65 | cls: type[Any], original_type_annotations: dict[str, type[Any]] |
21 | 66 | ) -> list[StrawberryField]: |
@@ -82,6 +127,19 @@ class if one is not set by either using an explicit strawberry.field(name=...) o |
82 | 127 | # then we can proceed with finding the fields for the current class |
83 | 128 | for field in dataclasses.fields(cls): # type: ignore |
84 | 129 | if isinstance(field, StrawberryField): |
| 130 | + # Check for conflict: strawberry.field in both Annotated and assignment |
| 131 | + annotation = ( |
| 132 | + field.type_annotation.annotation |
| 133 | + if isinstance(field.type_annotation, StrawberryAnnotation) |
| 134 | + else field.type |
| 135 | + ) |
| 136 | + if get_origin(annotation) is Annotated: |
| 137 | + annotated_args = get_args(annotation) |
| 138 | + if any(isinstance(arg, StrawberryField) for arg in annotated_args[1:]): |
| 139 | + raise MultipleStrawberryFieldsError( |
| 140 | + field_name=field.python_name or field.name, cls=cls |
| 141 | + ) |
| 142 | + |
85 | 143 | # Check that the field type is not Private |
86 | 144 | if is_private(field.type): |
87 | 145 | raise PrivateStrawberryFieldError(field.python_name, cls) |
@@ -140,18 +198,24 @@ class if one is not set by either using an explicit strawberry.field(name=...) o |
140 | 198 | origin = origins.get(field.name, cls) |
141 | 199 | module = sys.modules[origin.__module__] |
142 | 200 |
|
143 | | - # Create a StrawberryField, for fields of Types #1 and #2a |
144 | | - field = StrawberryField( # noqa: PLW2901 |
145 | | - python_name=field.name, |
146 | | - graphql_name=None, |
147 | | - type_annotation=StrawberryAnnotation( |
148 | | - annotation=field.type, |
149 | | - namespace=module.__dict__, |
150 | | - ), |
151 | | - origin=origin, |
152 | | - default=getattr(cls, field.name, dataclasses.MISSING), |
| 201 | + annotated_field = _get_field_from_annotated( |
| 202 | + field, origin, module.__dict__, cls |
153 | 203 | ) |
154 | 204 |
|
| 205 | + if annotated_field is not None: |
| 206 | + field = annotated_field # noqa: PLW2901 |
| 207 | + else: |
| 208 | + field = StrawberryField( # noqa: PLW2901 |
| 209 | + python_name=field.name, |
| 210 | + graphql_name=None, |
| 211 | + type_annotation=StrawberryAnnotation( |
| 212 | + annotation=field.type, |
| 213 | + namespace=module.__dict__, |
| 214 | + ), |
| 215 | + origin=origin, |
| 216 | + default=getattr(cls, field.name, dataclasses.MISSING), |
| 217 | + ) |
| 218 | + |
155 | 219 | field_name = field.python_name |
156 | 220 |
|
157 | 221 | assert_message = "Field must have a name by the time the schema is generated" |
|
0 commit comments