Skip to content

Commit 2097b84

Browse files
fix: handle lazy filters and ordering in strawberry_django.connection (#773)
Co-authored-by: Thiago Bellini Ribeiro <[email protected]>
1 parent 2b78b4c commit 2097b84

File tree

6 files changed

+116
-36
lines changed

6 files changed

+116
-36
lines changed

strawberry_django/fields/field.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Any,
1919
TypeVar,
2020
Union,
21+
_AnnotatedAlias, # type: ignore
2122
cast,
2223
overload,
2324
)
@@ -574,9 +575,9 @@ def field(
574575
graphql_type: Any | None = None,
575576
extensions: Sequence[FieldExtension] = (),
576577
pagination: bool | UnsetType = UNSET,
577-
filters: type | UnsetType | None = UNSET,
578-
order: type | UnsetType | None = UNSET,
579-
ordering: type | UnsetType | None = UNSET,
578+
filters: _AnnotatedAlias | type | UnsetType | None = UNSET,
579+
order: _AnnotatedAlias | type | UnsetType | None = UNSET,
580+
ordering: _AnnotatedAlias | type | UnsetType | None = UNSET,
580581
only: TypeOrSequence[str] | None = None,
581582
select_related: TypeOrSequence[str] | None = None,
582583
prefetch_related: TypeOrSequence[PrefetchType] | None = None,
@@ -603,9 +604,9 @@ def field(
603604
graphql_type: Any | None = None,
604605
extensions: Sequence[FieldExtension] = (),
605606
pagination: bool | UnsetType = UNSET,
606-
filters: type | UnsetType | None = UNSET,
607-
order: type | UnsetType | None = UNSET,
608-
ordering: type | UnsetType | None = UNSET,
607+
filters: _AnnotatedAlias | type | UnsetType | None = UNSET,
608+
order: _AnnotatedAlias | type | UnsetType | None = UNSET,
609+
ordering: _AnnotatedAlias | type | UnsetType | None = UNSET,
609610
only: TypeOrSequence[str] | None = None,
610611
select_related: TypeOrSequence[str] | None = None,
611612
prefetch_related: TypeOrSequence[PrefetchType] | None = None,
@@ -632,9 +633,9 @@ def field(
632633
graphql_type: Any | None = None,
633634
extensions: Sequence[FieldExtension] = (),
634635
pagination: bool | UnsetType = UNSET,
635-
filters: type | UnsetType | None = UNSET,
636-
order: type | UnsetType | None = UNSET,
637-
ordering: type | UnsetType | None = UNSET,
636+
filters: _AnnotatedAlias | type | UnsetType | None = UNSET,
637+
order: _AnnotatedAlias | type | UnsetType | None = UNSET,
638+
ordering: _AnnotatedAlias | type | UnsetType | None = UNSET,
638639
only: TypeOrSequence[str] | None = None,
639640
select_related: TypeOrSequence[str] | None = None,
640641
prefetch_related: TypeOrSequence[PrefetchType] | None = None,
@@ -660,9 +661,9 @@ def field(
660661
graphql_type: Any | None = None,
661662
extensions: Sequence[FieldExtension] = (),
662663
pagination: bool | UnsetType = UNSET,
663-
filters: type | UnsetType | None = UNSET,
664-
order: type | UnsetType | None = UNSET,
665-
ordering: type | UnsetType | None = UNSET,
664+
filters: _AnnotatedAlias | type | UnsetType | None = UNSET,
665+
order: _AnnotatedAlias | type | UnsetType | None = UNSET,
666+
ordering: _AnnotatedAlias | type | UnsetType | None = UNSET,
666667
only: TypeOrSequence[str] | None = None,
667668
select_related: TypeOrSequence[str] | None = None,
668669
prefetch_related: TypeOrSequence[PrefetchType] | None = None,
@@ -808,9 +809,9 @@ def connection(
808809
directives: Sequence[object] | None = (),
809810
extensions: Sequence[FieldExtension] = (),
810811
max_results: int | None = None,
811-
filters: type | None = UNSET,
812-
order: type | None = UNSET,
813-
ordering: type | None = UNSET,
812+
filters: _AnnotatedAlias | type | UnsetType | None = UNSET,
813+
order: _AnnotatedAlias | type | UnsetType | None = UNSET,
814+
ordering: _AnnotatedAlias | type | UnsetType | None = UNSET,
814815
only: TypeOrSequence[str] | None = None,
815816
select_related: TypeOrSequence[str] | None = None,
816817
prefetch_related: TypeOrSequence[PrefetchType] | None = None,
@@ -838,9 +839,9 @@ def connection(
838839
directives: Sequence[object] | None = (),
839840
extensions: Sequence[FieldExtension] = (),
840841
max_results: int | None = None,
841-
filters: type | None = UNSET,
842-
order: type | None = UNSET,
843-
ordering: type | None = UNSET,
842+
filters: _AnnotatedAlias | type | UnsetType | None = UNSET,
843+
order: _AnnotatedAlias | type | UnsetType | None = UNSET,
844+
ordering: _AnnotatedAlias | type | UnsetType | None = UNSET,
844845
only: TypeOrSequence[str] | None = None,
845846
select_related: TypeOrSequence[str] | None = None,
846847
prefetch_related: TypeOrSequence[PrefetchType] | None = None,
@@ -866,9 +867,9 @@ def connection(
866867
directives: Sequence[object] | None = (),
867868
extensions: Sequence[FieldExtension] = (),
868869
max_results: int | None = None,
869-
filters: type | None = UNSET,
870-
order: type | None = UNSET,
871-
ordering: type | None = UNSET,
870+
filters: _AnnotatedAlias | type | UnsetType | None = UNSET,
871+
order: _AnnotatedAlias | type | UnsetType | None = UNSET,
872+
ordering: _AnnotatedAlias | type | UnsetType | None = UNSET,
872873
only: TypeOrSequence[str] | None = None,
873874
select_related: TypeOrSequence[str] | None = None,
874875
prefetch_related: TypeOrSequence[PrefetchType] | None = None,
@@ -998,9 +999,9 @@ def offset_paginated(
998999
metadata: Mapping[Any, Any] | None = None,
9991000
directives: Sequence[object] | None = (),
10001001
extensions: Sequence[FieldExtension] = (),
1001-
filters: type | None = UNSET,
1002-
order: type | None = UNSET,
1003-
ordering: type | None = UNSET,
1002+
filters: _AnnotatedAlias | type | UnsetType | None = UNSET,
1003+
order: _AnnotatedAlias | type | UnsetType | None = UNSET,
1004+
ordering: _AnnotatedAlias | type | UnsetType | None = UNSET,
10041005
only: TypeOrSequence[str] | None = None,
10051006
select_related: TypeOrSequence[str] | None = None,
10061007
prefetch_related: TypeOrSequence[PrefetchType] | None = None,
@@ -1027,9 +1028,9 @@ def offset_paginated(
10271028
metadata: Mapping[Any, Any] | None = None,
10281029
directives: Sequence[object] | None = (),
10291030
extensions: Sequence[FieldExtension] = (),
1030-
filters: type | None = UNSET,
1031-
order: type | None = UNSET,
1032-
ordering: type | None = UNSET,
1031+
filters: _AnnotatedAlias | type | UnsetType | None = UNSET,
1032+
order: _AnnotatedAlias | type | UnsetType | None = UNSET,
1033+
ordering: _AnnotatedAlias | type | UnsetType | None = UNSET,
10331034
only: TypeOrSequence[str] | None = None,
10341035
select_related: TypeOrSequence[str] | None = None,
10351036
prefetch_related: TypeOrSequence[PrefetchType] | None = None,
@@ -1054,9 +1055,9 @@ def offset_paginated(
10541055
metadata: Mapping[Any, Any] | None = None,
10551056
directives: Sequence[object] | None = (),
10561057
extensions: Sequence[FieldExtension] = (),
1057-
filters: type | None = UNSET,
1058-
order: type | None = UNSET,
1059-
ordering: type | None = UNSET,
1058+
filters: _AnnotatedAlias | type | UnsetType | None = UNSET,
1059+
order: _AnnotatedAlias | type | UnsetType | None = UNSET,
1060+
ordering: _AnnotatedAlias | type | UnsetType | None = UNSET,
10601061
only: TypeOrSequence[str] | None = None,
10611062
select_related: TypeOrSequence[str] | None = None,
10621063
prefetch_related: TypeOrSequence[PrefetchType] | None = None,

strawberry_django/filters.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from types import FunctionType
99
from typing import (
1010
TYPE_CHECKING,
11+
Annotated,
1112
Any,
1213
Generic,
1314
TypeVar,
1415
cast,
16+
get_origin,
1517
)
1618

1719
import strawberry
@@ -32,6 +34,7 @@
3234
)
3335
from strawberry_django.utils.typing import (
3436
WithStrawberryDjangoObjectDefinition,
37+
get_type_from_lazy_annotation,
3538
has_django_definition,
3639
)
3740

@@ -287,6 +290,9 @@ def apply(
287290

288291
class StrawberryDjangoFieldFilters(StrawberryDjangoFieldBase):
289292
def __init__(self, filters: type | UnsetType | None = UNSET, **kwargs):
293+
if filters and get_origin(filters) is Annotated:
294+
filters = get_type_from_lazy_annotation(filters) or filters
295+
290296
if filters and not has_object_definition(filters):
291297
raise TypeError("filters needs to be a strawberry type")
292298

strawberry_django/ordering.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import enum
55
from typing import (
66
TYPE_CHECKING,
7+
Annotated,
78
Any,
89
Optional,
910
TypeVar,
@@ -19,15 +20,19 @@
1920
from strawberry.types.field import StrawberryField, field
2021
from strawberry.types.unset import UnsetType
2122
from strawberry.utils.str_converters import to_camel_case
22-
from typing_extensions import Self, dataclass_transform, deprecated
23+
from typing_extensions import Self, dataclass_transform, deprecated, get_origin
2324

2425
from strawberry_django.fields.base import StrawberryDjangoFieldBase
2526
from strawberry_django.fields.filter_order import (
2627
WITH_NONE_META,
2728
FilterOrderField,
2829
FilterOrderFieldResolver,
2930
)
30-
from strawberry_django.utils.typing import is_auto, unwrap_type
31+
from strawberry_django.utils.typing import (
32+
get_type_from_lazy_annotation,
33+
is_auto,
34+
unwrap_type,
35+
)
3136

3237
from .arguments import argument
3338

@@ -282,6 +287,12 @@ def __init__(
282287
ordering: type | UnsetType | None = UNSET,
283288
**kwargs,
284289
):
290+
if order and get_origin(order) is Annotated:
291+
order = get_type_from_lazy_annotation(order) or order
292+
293+
if ordering and get_origin(ordering) is Annotated:
294+
ordering = get_type_from_lazy_annotation(ordering) or ordering
295+
285296
if order and not has_object_definition(order):
286297
raise TypeError("order needs to be a strawberry type")
287298
if ordering and not has_object_definition(ordering):

strawberry_django/utils/typing.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
ClassVar,
1010
TypeVar,
1111
Union,
12+
_AnnotatedAlias, # type: ignore
1213
cast,
14+
get_args,
1315
overload,
1416
)
1517

@@ -22,7 +24,7 @@
2224
StrawberryType,
2325
WithStrawberryObjectDefinition,
2426
)
25-
from strawberry.types.lazy_type import LazyType
27+
from strawberry.types.lazy_type import LazyType, StrawberryLazyReference
2628
from strawberry.utils.typing import is_classvar
2729
from typing_extensions import Protocol
2830

@@ -112,11 +114,11 @@ def get_annotations(cls) -> dict[str, StrawberryAnnotation]:
112114

113115

114116
@overload
115-
def unwrap_type(type_: StrawberryContainer) -> StrawberryType | type: ...
117+
def unwrap_type(type_: StrawberryContainer) -> type: ...
116118

117119

118120
@overload
119-
def unwrap_type(type_: LazyType) -> StrawberryType | type: ...
121+
def unwrap_type(type_: LazyType) -> type: ...
120122

121123

122124
@overload
@@ -137,3 +139,12 @@ def unwrap_type(type_):
137139
break
138140

139141
return type_
142+
143+
144+
def get_type_from_lazy_annotation(type_: _AnnotatedAlias) -> type | None:
145+
first, *rest = get_args(type_)
146+
for arg in rest:
147+
if isinstance(arg, StrawberryLazyReference):
148+
return unwrap_type(arg.resolve_forward_ref(first))
149+
150+
return None

tests/filters/test_filters.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import textwrap
22
from enum import Enum
3-
from typing import Generic, Optional, TypeVar, cast
3+
from typing import Annotated, Generic, Optional, TypeVar, cast
44

55
import pytest
66
import strawberry
@@ -96,6 +96,11 @@ class Query:
9696
fruits: list[Fruit] = strawberry_django.field()
9797
field_filter: list[Fruit] = strawberry_django.field(filters=FieldFilter)
9898
type_filter: list[Fruit] = strawberry_django.field(filters=TypeFilter)
99+
type_lazy_filter: list[Fruit] = strawberry_django.field(
100+
filters=Annotated[
101+
"TypeFilter", strawberry.lazy("tests.filters.test_filters")
102+
]
103+
)
99104
enum_filter: list[Fruit] = strawberry_django.field(filters=EnumFilter)
100105
enum_lookup_filter: list[Fruit] = strawberry_django.field(
101106
filters=EnumLookupFilter
@@ -238,6 +243,14 @@ def test_type_filter_method(query, fruits):
238243
]
239244

240245

246+
def test_type_lazy_filter_method(query, fruits):
247+
result = query('{ fruits: typeLazyFilter(filters: { name: "anana" }) { id name } }')
248+
assert not result.errors
249+
assert result.data["fruits"] == [
250+
{"id": "3", "name": "banana"},
251+
]
252+
253+
241254
def test_resolver_filter(fruits):
242255
@strawberry.type
243256
class Query:

tests/test_ordering.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,16 @@ class Query:
101101
fruits_with_order_connection: DjangoListConnection[FruitWithOrderNode] = (
102102
strawberry_django.connection()
103103
)
104+
fruits_with_lazy_order_connection: DjangoListConnection[FruitWithOrderNode] = (
105+
strawberry_django.connection(
106+
order=Annotated["FruitOrder", strawberry.lazy("tests.test_ordering")]
107+
)
108+
)
109+
fruits_with_lazy_ordering_connection: DjangoListConnection[FruitWithOrderNode] = (
110+
strawberry_django.connection(
111+
ordering=Annotated["FruitOrder", strawberry.lazy("tests.test_ordering")]
112+
)
113+
)
104114
fruits_with_order_paginated: OffsetPaginated[FruitWithOrder] = (
105115
strawberry_django.offset_paginated()
106116
)
@@ -201,6 +211,34 @@ def test_type_ordering_connection(query, fruits):
201211
}
202212

203213

214+
def test_type_lazy_ordering_connection(query, fruits):
215+
result = query(
216+
"{ fruitsWithLazyOrderingConnection(ordering: [{ name: ASC }]) { edges { node { name } } } }"
217+
)
218+
assert not result.errors
219+
assert result.data["fruitsWithLazyOrderingConnection"] == {
220+
"edges": [
221+
{"node": {"name": "banana"}},
222+
{"node": {"name": "raspberry"}},
223+
{"node": {"name": "strawberry"}},
224+
]
225+
}
226+
227+
228+
def test_type_lazy_order_connection(query, fruits):
229+
result = query(
230+
"{ fruitsWithLazyOrderConnection(ordering: [{ name: ASC }]) { edges { node { name } } } }"
231+
)
232+
assert not result.errors
233+
assert result.data["fruitsWithLazyOrderConnection"] == {
234+
"edges": [
235+
{"node": {"name": "banana"}},
236+
{"node": {"name": "raspberry"}},
237+
{"node": {"name": "strawberry"}},
238+
]
239+
}
240+
241+
204242
def test_type_ordering_paginated(query, fruits):
205243
result = query(
206244
"{ fruitsWithOrderPaginated(ordering: [{ name: ASC }]) { results { id name } } }"

0 commit comments

Comments
 (0)