Skip to content

Implement Standalone Convert; Enum Handling; User Edge/Connection Classes; Column Alias Handling #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@ repos:
- id: black
exclude: ^tests/\w+/snapshots/

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.289
hooks:
- id: ruff
exclude: ^tests/\w+/snapshots/

- repo: https://github.com/patrick91/pre-commit-alex
rev: aa5da9e54b92ab7284feddeaf52edf14b1690de3
hooks:
Expand Down
4 changes: 2 additions & 2 deletions src/strawberry_sqlalchemy_mapper/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def __init__(
async def _scalars_all(self, *args, **kwargs):
if self._async_bind_factory:
async with self._async_bind_factory() as bind:
return (await bind.scalars(*args, **kwargs)).all()
return (await bind.scalars(*args, **kwargs)).unique().all()
else:
assert self._bind is not None
return self._bind.scalars(*args, **kwargs).all()
return self._bind.scalars(*args, **kwargs).unique().all()

def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
"""
Expand Down
41 changes: 31 additions & 10 deletions src/strawberry_sqlalchemy_mapper/mapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ast
import asyncio
import collections.abc
import dataclasses
Expand Down Expand Up @@ -160,7 +159,11 @@ def __init__(
extra_sqlalchemy_type_to_strawberry_type_map: Optional[
Mapping[Type[TypeEngine], Type[Any]]
] = None,
edge_type: Optional[Type] = None,
connection_type: Optional[Type] = None,
use_relay: bool=False,
) -> None:
self.use_relay = use_relay
if model_to_type_name is None:
model_to_type_name = self._default_model_to_type_name
self.model_to_type_name = model_to_type_name
Expand All @@ -181,6 +184,9 @@ def __init__(
self._related_type_models = set()
self._related_interface_models = set()

self.edge_type = edge_type
self.connection_type = connection_type

@staticmethod
def _default_model_to_type_name(model: Type[BaseModelType]) -> str:
return model.__name__
Expand Down Expand Up @@ -220,6 +226,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
Get or create a corresponding Edge model for the given type
(to support future pagination)
"""
if self.edge_type is not None:
return self.edge_type
edge_name = f"{type_name}Edge"
if edge_name not in self.edge_types:
self.edge_types[edge_name] = edge_type = strawberry.type(
Expand All @@ -238,6 +246,10 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
Get or create a corresponding Connection model for the given type
(to support future pagination)
"""
if not self.use_relay:
return List[ForwardRef(type_name)]
if self.connection_type is not None:
return self.connection_type[ForwardRef(type_name)]
connection_name = f"{type_name}Connection"
if connection_name not in self.connection_types:
self.connection_types[connection_name] = connection_type = strawberry.type(
Expand Down Expand Up @@ -269,6 +281,8 @@ def _convert_column_to_strawberry_type(
"""
if isinstance(column.type, Enum):
type_annotation = column.type.python_type
if not hasattr(column.type, "_enum_definition"):
type_annotation = strawberry.enum(type_annotation)
elif isinstance(column.type, ARRAY):
item_type = self._convert_column_to_strawberry_type(
Column(column.type.item_type, nullable=False)
Expand Down Expand Up @@ -439,6 +453,8 @@ def connection_resolver_for(
Return an async field resolver for the given relationship that
returns a Connection instead of an array of objects.
"""
if not self.use_relay:
return self.relationship_resolver_for(relationship)
relationship_resolver = self.relationship_resolver_for(relationship)
if relationship.uselist:
return self.make_connection_wrapper_resolver(
Expand Down Expand Up @@ -545,6 +561,7 @@ def type(
model: Type[BaseModelType],
make_interface=False,
use_federation=False,
**kwargs,
) -> Callable[[Type[object]], Any]:
"""
Decorate a type with this to register it as a strawberry type
Expand Down Expand Up @@ -627,12 +644,12 @@ def convert(type_: Any) -> Any:
if key in mapper.columns or key in mapper.relationships:
continue
if key in model.__annotations__:
annotation = ast.literal_eval(model.__annotations__[key])
annotation = eval(model.__annotations__[key])
for ( # type: ignore[assignment]
sqlalchemy_type,
strawberry_type,
) in self.sqlalchemy_type_to_strawberry_type_map.items():
if isinstance(annotation, sqlalchemy_type):
if annotation == sqlalchemy_type:
self._add_annotation(
type_, key, strawberry_type, generated_field_keys
)
Expand Down Expand Up @@ -668,7 +685,7 @@ def convert(type_: Any) -> Any:
if "typing" in annotation:
# Try to evaluate from existing typing imports
annotation = annotation[7:]
annotation = ast.literal_eval(annotation)
annotation = eval(annotation)
except NameError:
raise UnsupportedDescriptorType(key)
self._add_annotation(
Expand All @@ -693,15 +710,19 @@ def convert(type_: Any) -> Any:
# (because they may not have default values)
type_.__annotations__.update(old_annotations)

type_name = type_.__name__
if "name" in kwargs:
type_name = kwargs["name"]

if make_interface:
mapped_type = strawberry.interface(type_)
self.mapped_interfaces[type_.__name__] = mapped_type
mapped_type = strawberry.interface(type_, **kwargs)
self.mapped_interfaces[type_name] = mapped_type
elif use_federation:
mapped_type = strawberry.federation.type(type_)
self.mapped_types[type_.__name__] = mapped_type
mapped_type = strawberry.federation.type(type_, **kwargs)
self.mapped_types[type_name] = mapped_type
else:
mapped_type = strawberry.type(type_)
self.mapped_types[type_.__name__] = mapped_type
mapped_type = strawberry.type(type_, **kwargs)
self.mapped_types[type_name] = mapped_type
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
return mapped_type
Expand Down