Skip to content

Commit 74c41b7

Browse files
Ckk3tylernisonoff
andauthored
add missing page_info / cursors for Association Proxy logic (#241)
* add missing page_info / cursors for relay * add RELEASE.md * add missing cursor * run formatter and update release * added tests to proxy_association, still missing a test case * Add a better messages in the exception, add new test with keysetconnection * add _resolve_connection_edges to a staticmethod and run pre-commit * update Release md * Run pre-commit in exc.py * run pre-commit in conftest.py * run pre-commit in test_association_proxy.py * run pre-commit in test_mapper * Fix tests on windows error * add a new test to null relationship * refactor test_association_proxy for better understanding * update release.md to add tylernisonoff credits --------- Co-authored-by: Tyler Nisonoff <[email protected]>
1 parent e19bdb8 commit 74c41b7

File tree

6 files changed

+972
-53
lines changed

6 files changed

+972
-53
lines changed

RELEASE.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Release type: patch
2+
3+
Ensure association proxy resolvers return valid relay connections, including `page_info` and edge `cursor` details, even for empty results.
4+
5+
Thanks to https://github.com/tylernisonoff for the original PR.

src/strawberry_sqlalchemy_mapper/exc.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,38 @@ class UnsupportedColumnType(Exception):
22
def __init__(self, key, type):
33
super().__init__(
44
f"Unsupported column type: `{type}` on column: `{key}`. "
5-
+ "Possible fix: exclude this column"
5+
"Possible fix: exclude this column"
66
)
77

88

99
class UnsupportedAssociationProxyTarget(Exception):
1010
def __init__(self, key):
1111
super().__init__(
1212
f"Association proxy `{key}` is expected to be of form "
13-
+ "association_proxy(relationship_name, other relationship name)"
13+
"association_proxy(relationship_name, other relationship name). "
14+
"Ensure it matches the expected form or add this association proxy to __exclude__."
1415
)
1516

1617

1718
class HybridPropertyNotAnnotated(Exception):
1819
def __init__(self, key):
1920
super().__init__(
2021
f"Descriptor `{key}` is a hybrid property, but does not have an "
21-
+ "annotated return type"
22+
"annotated return type"
2223
)
2324

2425

2526
class UnsupportedDescriptorType(Exception):
2627
def __init__(self, key):
2728
super().__init__(
2829
f"Descriptor `{key}` is expected to be a column, relationship, "
29-
+ "or association proxy."
30+
"or association proxy."
3031
)
3132

3233

3334
class InterfaceModelNotPolymorphic(Exception):
3435
def __init__(self, model):
3536
super().__init__(
3637
f"Model `{model}` is not polymorphic or is not the base model of its "
37-
+ "inheritance chain, and thus cannot be used as an interface."
38+
"inheritance chain, and thus cannot be used as an interface."
3839
)

src/strawberry_sqlalchemy_mapper/mapper.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,12 @@
6969
Mapper,
7070
RelationshipProperty,
7171
)
72-
from sqlalchemy.orm.state import InstanceState
7372
from sqlalchemy.sql.type_api import TypeEngine
7473
from strawberry import relay
7574
from strawberry.annotation import StrawberryAnnotation
7675
from strawberry.scalars import JSON as StrawberryJSON
7776
from strawberry.types import Info
7877
from strawberry.types.base import WithStrawberryObjectDefinition, get_object_definition
79-
from strawberry.types.field import StrawberryField
8078
from strawberry.types.lazy_type import LazyType
8179
from strawberry.types.private import is_private
8280

@@ -97,13 +95,14 @@
9795
from strawberry_sqlalchemy_mapper.scalars import BigInt
9896

9997
if TYPE_CHECKING:
98+
from sqlalchemy.orm.state import InstanceState
10099
from sqlalchemy.sql.expression import ColumnElement
100+
from strawberry.types.field import StrawberryField
101101

102102
BaseModelType = TypeVar("BaseModelType")
103103

104104
SkipTypeSentinelT = NewType("SkipTypeSentinelT", object)
105-
SkipTypeSentinel = cast(SkipTypeSentinelT, sentinel.create("SkipTypeSentinel"))
106-
105+
SkipTypeSentinel = cast("SkipTypeSentinelT", sentinel.create("SkipTypeSentinel"))
107106

108107
#: Set on generated types to the original type handed to the decorator
109108
_ORIGINAL_TYPE_KEY = "_original_type"
@@ -151,13 +150,11 @@ class StrawberrySQLAlchemyType(Generic[BaseModelType]):
151150

152151
@overload
153152
@classmethod
154-
def from_type(cls, type_: type, *, strict: Literal[True]) -> Self:
155-
...
153+
def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...
156154

157155
@overload
158156
@classmethod
159-
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]:
160-
...
157+
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]: ...
161158

162159
@classmethod
163160
def from_type(
@@ -468,28 +465,34 @@ def make_connection_wrapper_resolver(
468465
edge_type = self._edge_type_for(type_name)
469466

470467
async def wrapper(self, info: Info):
471-
# TODO: Add pagination support to dataloader resolvers
472-
edges = [
473-
edge_type.resolve_edge(
474-
related_object,
475-
cursor=i,
476-
)
477-
for i, related_object in enumerate(await resolver(self, info))
478-
]
479-
return connection_type(
480-
edges=edges,
481-
page_info=relay.PageInfo(
482-
has_next_page=False,
483-
has_previous_page=False,
484-
start_cursor=edges[0].cursor if edges else None,
485-
end_cursor=edges[-1].cursor if edges else None,
486-
),
468+
return StrawberrySQLAlchemyMapper._resolve_connection_edges(
469+
await resolver(self, info), edge_type, connection_type
487470
)
488471

489472
setattr(wrapper, _IS_GENERATED_RESOLVER_KEY, True)
490473

491474
return wrapper
492475

476+
@staticmethod
477+
def _resolve_connection_edges(related_objects, edge_type, connection_type):
478+
# TODO: Add pagination support to dataloader resolvers
479+
edges = [
480+
edge_type.resolve_edge(
481+
related_object,
482+
cursor=i,
483+
)
484+
for i, related_object in enumerate(related_objects)
485+
]
486+
return connection_type(
487+
edges=edges,
488+
page_info=relay.PageInfo(
489+
has_next_page=False,
490+
has_previous_page=False,
491+
start_cursor=edges[0].cursor if edges else None,
492+
end_cursor=edges[-1].cursor if edges else None,
493+
),
494+
)
495+
493496
def relationship_resolver_for(
494497
self, relationship: RelationshipProperty
495498
) -> Callable[..., Awaitable[Any]]:
@@ -499,7 +502,7 @@ def relationship_resolver_for(
499502
"""
500503

501504
async def resolve(self, info: Info):
502-
instance_state = cast(InstanceState, inspect(self))
505+
instance_state = cast("InstanceState", inspect(self))
503506
if relationship.key not in instance_state.unloaded:
504507
related_objects = getattr(self, relationship.key)
505508
else:
@@ -575,7 +578,15 @@ async def resolve(self, info: Info):
575578
in_between_objects = await in_between_resolver(self, info)
576579
if in_between_objects is None:
577580
if is_multiple:
578-
return connection_type(edges=[])
581+
return connection_type(
582+
edges=[],
583+
page_info=relay.PageInfo(
584+
has_next_page=False,
585+
has_previous_page=False,
586+
start_cursor=None,
587+
end_cursor=None,
588+
),
589+
)
579590
else:
580591
return None
581592
if descriptor.value_attr in in_between_mapper.relationships:
@@ -595,7 +606,9 @@ async def resolve(self, info: Info):
595606
outputs = await end_relationship_resolver(in_between_objects, info)
596607
if not isinstance(outputs, collections.abc.Iterable):
597608
return outputs
598-
return connection_type(edges=[edge_type(node=obj) for obj in outputs])
609+
return StrawberrySQLAlchemyMapper._resolve_connection_edges(
610+
outputs, edge_type, connection_type
611+
)
599612
else:
600613
assert descriptor.value_attr in in_between_mapper.columns
601614
if isinstance(in_between_objects, collections.abc.Iterable):
@@ -668,7 +681,7 @@ def convert(type_: Any) -> Any:
668681
type_.__annotations__ = {
669682
k: v for k, v in old_annotations.items() if is_private(v)
670683
}
671-
mapper: Mapper = cast(Mapper, inspect(model))
684+
mapper: Mapper = cast("Mapper", inspect(model))
672685
generated_field_keys = []
673686

674687
excluded_keys = getattr(type_, "__exclude__", [])
@@ -707,7 +720,7 @@ def convert(type_: Any) -> Any:
707720
generated_field_keys,
708721
)
709722
sqlalchemy_field = cast(
710-
StrawberryField,
723+
"StrawberryField",
711724
field(
712725
resolver=self.connection_resolver_for(
713726
relationship,
@@ -751,7 +764,7 @@ def convert(type_: Any) -> Any:
751764
type_, key, strawberry_type, generated_field_keys
752765
)
753766
sqlalchemy_field = cast(
754-
StrawberryField,
767+
"StrawberryField",
755768
field(
756769
resolver=self.association_proxy_resolver_for(
757770
mapper,
@@ -789,7 +802,8 @@ def convert(type_: Any) -> Any:
789802
# ignore inherited `is_type_of`
790803
if "is_type_of" not in type_.__dict__:
791804
type_.is_type_of = (
792-
lambda obj, info: type(obj) == model or type(obj) == type_
805+
lambda obj, info: type(obj) == model # noqa: E721
806+
or type(obj) == type_ # noqa: E721
793807
)
794808

795809
# Default querying methods for relay
@@ -822,7 +836,7 @@ def convert(type_: Any) -> Any:
822836
setattr(
823837
type_,
824838
attr,
825-
types.MethodType(cast(classmethod, meth).__func__, type_),
839+
types.MethodType(cast("classmethod", meth).__func__, type_),
826840
)
827841

828842
# need to make fields that are already in the type

tests/conftest.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@
2020
from sqlalchemy.ext import asyncio
2121
from sqlalchemy.ext.asyncio import create_async_engine
2222
from sqlalchemy.ext.asyncio.engine import AsyncEngine
23+
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
2324
from testing.postgresql import Postgresql, PostgresqlFactory
2425

2526
SQLA_VERSION = version.parse(sqlalchemy.__version__)
2627
SQLA2 = SQLA_VERSION >= version.parse("2.0")
2728

2829

2930
logging.basicConfig()
30-
logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)
31+
log = logging.getLogger("sqlalchemy.engine")
32+
log.setLevel(logging.INFO)
3133

3234

3335
def _pick_unused_port():
@@ -55,7 +57,7 @@ def postgresql(postgresql_factory) -> Postgresql:
5557
# Our windows test pipeline doesn't play nice with postgres because
5658
# Github Actions doesn't support containers on windows.
5759
# It would probably be nicer if we chcked if postgres is installed
58-
logging.info("Skipping postgresql tests on Windows OS")
60+
log.info("Skipping postgresql tests on Windows OS")
5961
SUPPORTED_DBS = []
6062
else:
6163
SUPPORTED_DBS = ["postgresql"] # TODO: Add sqlite and mysql.
@@ -111,3 +113,8 @@ def async_sessionmaker(async_engine):
111113
@pytest.fixture
112114
def base():
113115
return orm.declarative_base()
116+
117+
118+
@pytest.fixture
119+
def mapper():
120+
return StrawberrySQLAlchemyMapper()

0 commit comments

Comments
 (0)