Skip to content

Commit 6befcf8

Browse files
authored
Fix lint errors (#38)
1 parent e4b3766 commit 6befcf8

File tree

4 files changed

+53
-36
lines changed

4 files changed

+53
-36
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Release type: patch
2+
3+
Makes a series of minor changes to fix lint errors between MyPy and Ruff.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ ignore = [
105105
# maybe we can enable this in future
106106
# we'd want to have consistent docstrings in future
107107
"D",
108+
"ANN001", # missing annotation for function argument self.
108109
"ANN101", # missing annotation for self?
109110
# definitely enable these, maybe not in tests
110111
"ANN102",
@@ -169,6 +170,8 @@ ignore = [
169170
"B905",
170171
"ISC001",
171172

173+
"E501", # Line too long (we enforce black on pre-commit anyway).
174+
172175
# same?
173176
"S105",
174177
"S106",

src/strawberry_sqlalchemy_mapper/loader.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from collections import defaultdict
2-
from typing import Any, Dict, List, Mapping, Tuple
2+
from typing import Any, Dict, List, Mapping, Tuple, Union
33

44
from sqlalchemy import select, tuple_
5-
from sqlalchemy.orm import RelationshipProperty
5+
from sqlalchemy.engine.base import Connection
6+
from sqlalchemy.orm import RelationshipProperty, Session
67
from strawberry.dataloader import DataLoader
78

89

@@ -13,7 +14,7 @@ class StrawberrySQLAlchemyLoader:
1314

1415
_loaders: Dict[RelationshipProperty, DataLoader]
1516

16-
def __init__(self, bind) -> None:
17+
def __init__(self, bind: Union[Session, Connection]) -> None:
1718
self._loaders = {}
1819
self.bind = bind
1920

@@ -29,7 +30,7 @@ def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
2930
async def load_fn(keys: List[Tuple]) -> List[Any]:
3031
query = select(related_model).filter(
3132
tuple_(
32-
*[remote for _, remote in relationship.local_remote_pairs]
33+
*[remote for _, remote in relationship.local_remote_pairs or []]
3334
).in_(keys)
3435
)
3536
if relationship.order_by:
@@ -40,7 +41,8 @@ def group_by_remote_key(row: Any) -> Tuple:
4041
return tuple(
4142
[
4243
getattr(row, remote.key)
43-
for _, remote in relationship.local_remote_pairs
44+
for _, remote in relationship.local_remote_pairs or []
45+
if remote.key
4446
]
4547
)
4648

src/strawberry_sqlalchemy_mapper/mapper.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
import asyncio
23
import collections.abc
34
import dataclasses
@@ -7,6 +8,7 @@
78
from decimal import Decimal
89
from itertools import chain
910
from typing import (
11+
TYPE_CHECKING,
1012
Any,
1113
Awaitable,
1214
Callable,
@@ -62,6 +64,7 @@
6264
from sqlalchemy.orm.state import InstanceState
6365
from sqlalchemy.sql.type_api import TypeEngine
6466
from strawberry.annotation import StrawberryAnnotation
67+
from strawberry.field import StrawberryField
6568
from strawberry.types import Info
6669

6770
from strawberry_sqlalchemy_mapper.exc import (
@@ -72,9 +75,12 @@
7275
UnsupportedDescriptorType,
7376
)
7477

78+
if TYPE_CHECKING:
79+
from sqlalchemy.orm.sql.elements import ColumnElement
80+
7581
BaseModelType = TypeVar("BaseModelType")
7682

77-
SkipTypeSentinelT = NewType("SkipType", object)
83+
SkipTypeSentinelT = NewType("SkipTypeSentinelT", object)
7884
SkipTypeSentinel = cast(SkipTypeSentinelT, sentinel.create("SkipTypeSentinel"))
7985

8086

@@ -204,7 +210,7 @@ def _is_model_polymorphic(self, model: Type[BaseModelType]) -> bool:
204210
"""
205211
Whether a model is part of a polymorphic hierarchy
206212
"""
207-
return inspect(model).polymorphic_on is not None
213+
return inspect(model).polymorphic_on is not None # type: ignore[union-attr]
208214

209215
def _edge_type_for(self, type_name: str) -> Type[Any]:
210216
"""
@@ -249,7 +255,7 @@ def _get_polymorphic_base_model(
249255
"""
250256
Given a model, return the base of its inheritance tree (which may be itself).
251257
"""
252-
return inspect(model).base_mapper.entity
258+
return inspect(model).base_mapper.entity # type: ignore[union-attr]
253259

254260
def _convert_column_to_strawberry_type(
255261
self, column: Column
@@ -280,7 +286,7 @@ def _convert_column_to_strawberry_type(
280286
if type_annotation is SkipTypeSentinel:
281287
return type_annotation
282288
if column.nullable:
283-
type_annotation = Optional[type_annotation] # type: ignore
289+
type_annotation = Optional[type_annotation]
284290
assert type_annotation is not None
285291
return type_annotation
286292

@@ -291,7 +297,7 @@ def _convert_relationship_to_strawberry_type(
291297
Given a SQLAlchemy relationship, return the type annotation for the field in the
292298
corresponding strawberry type.
293299
"""
294-
relationship_model: Type[BaseModelType] = relationship.entity.entity
300+
relationship_model: Type[BaseModelType] = relationship.entity.entity # type: ignore[assignment]
295301
type_name = self.model_to_type_or_interface_name(relationship_model)
296302
if self.model_is_interface(relationship_model):
297303
self._related_interface_models.add(relationship_model)
@@ -315,8 +321,8 @@ def _get_relationship_is_optional(self, relationship: RelationshipProperty) -> b
315321
else:
316322
assert relationship.direction == MANYTOONE
317323
# this model is the one with the FK
318-
for local_col, _ in relationship.local_remote_pairs:
319-
local_col: Column
324+
local_col: ColumnElement
325+
for local_col, _ in relationship.local_remote_pairs or []:
320326
if local_col.nullable:
321327
return True
322328
return False
@@ -339,7 +345,7 @@ def _get_association_proxy_annotation(
339345
are of the form (relationship, relationship).
340346
"""
341347
is_multiple = mapper.relationships[descriptor.target_collection].uselist
342-
in_between_mapper: Mapper = mapper.relationships[
348+
in_between_mapper: Mapper = mapper.relationships[ # type: ignore[assignment]
343349
descriptor.target_collection
344350
].entity
345351
if descriptor.value_attr in in_between_mapper.relationships:
@@ -352,17 +358,13 @@ def _get_association_proxy_annotation(
352358
raise UnsupportedAssociationProxyTarget(key)
353359
if strawberry_type is SkipTypeSentinel:
354360
return strawberry_type
355-
if is_multiple and not self._is_connection_type(
356-
cast(Union[Type[Any], ForwardRef], strawberry_type)
357-
):
361+
if is_multiple and not self._is_connection_type(strawberry_type):
358362
if isinstance(strawberry_type, ForwardRef):
359363
strawberry_type = self._connection_type_for(
360364
strawberry_type.__forward_arg__
361365
)
362366
else:
363-
strawberry_type = self._connection_type_for(
364-
cast(Type[Any], strawberry_type).__name__
365-
)
367+
strawberry_type = self._connection_type_for(strawberry_type.__name__)
366368
return strawberry_type
367369

368370
def make_connection_wrapper_resolver(
@@ -405,7 +407,8 @@ async def resolve(self, info: Info):
405407
relationship_key = tuple(
406408
[
407409
getattr(self, local.key)
408-
for local, _ in relationship.local_remote_pairs
410+
for local, _ in relationship.local_remote_pairs or []
411+
if local.key
409412
]
410413
)
411414
if any(item is None for item in relationship_key):
@@ -437,7 +440,7 @@ def connection_resolver_for(
437440
if relationship.uselist:
438441
return self.make_connection_wrapper_resolver(
439442
relationship_resolver,
440-
self.model_to_type_or_interface_name(relationship.entity.entity),
443+
self.model_to_type_or_interface_name(relationship.entity.entity), # type: ignore[arg-type]
441444
)
442445
else:
443446
return relationship_resolver
@@ -456,14 +459,14 @@ def association_proxy_resolver_for(
456459
"""
457460
in_between_relationship = mapper.relationships[descriptor.target_collection]
458461
in_between_resolver = self.relationship_resolver_for(in_between_relationship)
459-
in_between_mapper: Mapper = mapper.relationships[
462+
in_between_mapper: Mapper = mapper.relationships[ # type: ignore[assignment]
460463
descriptor.target_collection
461464
].entity
462465
assert descriptor.value_attr in in_between_mapper.relationships
463466
end_relationship = in_between_mapper.relationships[descriptor.value_attr]
464467
end_relationship_resolver = self.relationship_resolver_for(end_relationship)
465468
end_type_name = self.model_to_type_or_interface_name(
466-
end_relationship.entity.entity
469+
end_relationship.entity.entity # type: ignore[arg-type]
467470
)
468471
connection_type = self._connection_type_for(end_type_name)
469472
edge_type = self._edge_type_for(end_type_name)
@@ -517,8 +520,8 @@ def _handle_columns(
517520
"""
518521
Add annotations for the columns of the given mapper.
519522
"""
523+
column: Column
520524
for key, column in mapper.columns.items():
521-
column: Column
522525
if (
523526
key in excluded_keys
524527
or key in type_.__annotations__
@@ -563,7 +566,7 @@ class Employee:
563566
def convert(type_: Any) -> Any:
564567
old_annotations = getattr(type_, "__annotations__", {})
565568
type_.__annotations__ = {}
566-
mapper: Mapper = inspect(model)
569+
mapper: Mapper = cast(Mapper, inspect(model))
567570
generated_field_keys = []
568571

569572
excluded_keys = getattr(type_, "__exclude__", [])
@@ -582,8 +585,8 @@ def convert(type_: Any) -> Any:
582585
generated_field_keys.append(key)
583586

584587
self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
588+
relationship: RelationshipProperty
585589
for key, relationship in mapper.relationships.items():
586-
relationship: RelationshipProperty
587590
if (
588591
key in excluded_keys
589592
or key in type_.__annotations__
@@ -599,8 +602,11 @@ def convert(type_: Any) -> Any:
599602
strawberry_type,
600603
generated_field_keys,
601604
)
602-
field = strawberry.field(
603-
resolver=self.connection_resolver_for(relationship)
605+
field = cast(
606+
StrawberryField,
607+
strawberry.field(
608+
resolver=self.connection_resolver_for(relationship)
609+
),
604610
)
605611
assert not field.init
606612
setattr(
@@ -618,8 +624,8 @@ def convert(type_: Any) -> Any:
618624
if key in mapper.columns or key in mapper.relationships:
619625
continue
620626
if key in model.__annotations__:
621-
annotation = eval(model.__annotations__[key])
622-
for (
627+
annotation = ast.literal_eval(model.__annotations__[key])
628+
for ( # type: ignore[assignment]
623629
sqlalchemy_type,
624630
strawberry_type,
625631
) in self.sqlalchemy_type_to_strawberry_type_map.items():
@@ -629,18 +635,21 @@ def convert(type_: Any) -> Any:
629635
)
630636
break
631637
elif isinstance(descriptor, AssociationProxy):
632-
strawberry_type = self._get_association_proxy_annotation(
638+
strawberry_type = self._get_association_proxy_annotation( # type: ignore[assignment]
633639
mapper, key, descriptor
634640
)
635641
if strawberry_type is SkipTypeSentinel:
636642
continue
637643
self._add_annotation(
638644
type_, key, strawberry_type, generated_field_keys
639645
)
640-
field = strawberry.field(
641-
resolver=self.association_proxy_resolver_for(
642-
mapper, descriptor, strawberry_type
643-
)
646+
field = cast(
647+
StrawberryField,
648+
strawberry.field(
649+
resolver=self.association_proxy_resolver_for(
650+
mapper, descriptor, strawberry_type # type: ignore[arg-type]
651+
)
652+
),
644653
)
645654
assert not field.init
646655
setattr(type_, key, field)
@@ -656,7 +665,7 @@ def convert(type_: Any) -> Any:
656665
if "typing" in annotation:
657666
# Try to evaluate from existing typing imports
658667
annotation = annotation[7:]
659-
annotation = eval(annotation)
668+
annotation = ast.literal_eval(annotation)
660669
except NameError:
661670
raise UnsupportedDescriptorType(key)
662671
self._add_annotation(

0 commit comments

Comments
 (0)