Skip to content

Commit 2e64aeb

Browse files
authored
fix: Use lazy type to specialize Edge/Connection (#106)
Use a lazy type to properly resolve and specialize the `Edge`/`Connection` types that are created from the model. The reason we are using a subclass of `LazyType` is because the original one expects a path to be passed to be used to import it later. We don't have one here, but we do have a mapper containing all the created types that can be retrieved later. Fix #97
1 parent 4fa950f commit 2e64aeb

File tree

3 files changed

+94
-2
lines changed

3 files changed

+94
-2
lines changed

RELEASE.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Release type: patch
2+
3+
Fix a regression from 0.4.0 which was [raising an issue](https://github.com/strawberry-graphql/strawberry-sqlalchemy/issues/97)
4+
when trying to create a connection from the model's relationships.

src/strawberry_sqlalchemy_mapper/mapper.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
from strawberry import relay
7474
from strawberry.annotation import StrawberryAnnotation
7575
from strawberry.field import StrawberryField
76+
from strawberry.lazy_type import LazyType
7677
from strawberry.private import is_private
7778
from strawberry.scalars import JSON as StrawberryJSON
7879
from strawberry.type import WithStrawberryObjectDefinition, get_object_definition
@@ -112,6 +113,19 @@
112113
_IS_GENERATED_RESOLVER_KEY = "_is_generated_resolver"
113114

114115

116+
@dataclasses.dataclass(frozen=True)
117+
class StrawberrySQLAlchemyLazy(LazyType):
118+
# We don't actually want `Optional` here but we can't use dataclasse's
119+
# kw_only untul python 3.10, and we do have fields with default values
120+
# in strawberry's LazyType
121+
mapper: Optional["StrawberrySQLAlchemyMapper[Any]"] = None
122+
module: str = "__not_used__"
123+
124+
def resolve_type(self) -> Type[Any]:
125+
assert self.mapper is not None
126+
return self.mapper.mapped_types[self.type_name]
127+
128+
115129
class WithStrawberrySQLAlchemyObjectDefinition(
116130
WithStrawberryObjectDefinition,
117131
Protocol,
@@ -282,11 +296,12 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
282296
"""
283297
edge_name = f"{type_name}Edge"
284298
if edge_name not in self.edge_types:
299+
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
285300
self.edge_types[edge_name] = edge_type = strawberry.type(
286301
dataclasses.make_dataclass(
287302
edge_name,
288303
[],
289-
bases=(relay.Edge[type_name],), # type: ignore[valid-type]
304+
bases=(relay.Edge[lazy_type],), # type: ignore[valid-type]
290305
)
291306
)
292307
setattr(edge_type, _GENERATED_FIELD_KEYS_KEY, ["node"])
@@ -300,13 +315,14 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
300315
connection_name = f"{type_name}Connection"
301316
if connection_name not in self.connection_types:
302317
edge_type = self._edge_type_for(type_name)
318+
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
303319
self.connection_types[connection_name] = connection_type = strawberry.type(
304320
dataclasses.make_dataclass(
305321
connection_name,
306322
[
307323
("edges", List[edge_type]), # type: ignore[valid-type]
308324
],
309-
bases=(relay.ListConnection[type_name],), # type: ignore[valid-type]
325+
bases=(relay.ListConnection[lazy_type],), # type: ignore[valid-type]
310326
)
311327
)
312328
setattr(connection_type, _GENERATED_FIELD_KEYS_KEY, ["edges"])

tests/test_mapper.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import enum
2+
import textwrap
23
from typing import List, Optional
34

45
import pytest
6+
import strawberry
57
from sqlalchemy import JSON, Column, Enum, ForeignKey, Integer, String
68
from sqlalchemy.dialects.postgresql.array import ARRAY
79
from sqlalchemy.orm import relationship
@@ -279,3 +281,73 @@ class Employee:
279281
assert type(name.type) == StrawberryOptional
280282
id = next(iter(filter(lambda f: f.name == "department", employee_type_fields)))
281283
assert type(id.type) == StrawberryOptional
284+
285+
286+
def test_relationships_schema(employee_and_department_tables, mapper):
287+
EmployeeModel, DepartmentModel = employee_and_department_tables
288+
289+
@mapper.type(EmployeeModel)
290+
class Employee:
291+
__exclude__ = ["password_hash"]
292+
293+
@mapper.type(DepartmentModel)
294+
class Department:
295+
pass
296+
297+
@strawberry.type
298+
class Query:
299+
@strawberry.field
300+
def departments(self) -> Department:
301+
...
302+
303+
mapper.finalize()
304+
schema = strawberry.Schema(query=Query)
305+
306+
expected = '''
307+
type Department {
308+
id: Int!
309+
name: String!
310+
employees: EmployeeConnection!
311+
}
312+
313+
type Employee {
314+
id: Int!
315+
name: String!
316+
departmentId: Int
317+
department: Department
318+
}
319+
320+
type EmployeeConnection {
321+
"""Pagination data for this connection"""
322+
pageInfo: PageInfo!
323+
edges: [EmployeeEdge!]!
324+
}
325+
326+
type EmployeeEdge {
327+
"""A cursor for use in pagination"""
328+
cursor: String!
329+
330+
"""The item at the end of the edge"""
331+
node: Employee!
332+
}
333+
334+
"""Information to aid in pagination."""
335+
type PageInfo {
336+
"""When paginating forwards, are there more items?"""
337+
hasNextPage: Boolean!
338+
339+
"""When paginating backwards, are there more items?"""
340+
hasPreviousPage: Boolean!
341+
342+
"""When paginating backwards, the cursor to continue."""
343+
startCursor: String
344+
345+
"""When paginating forwards, the cursor to continue."""
346+
endCursor: String
347+
}
348+
349+
type Query {
350+
departments: Department!
351+
}
352+
'''
353+
assert str(schema) == textwrap.dedent(expected).strip()

0 commit comments

Comments
 (0)