diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e39d61e..10a0ad6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,12 +5,6 @@ repos: - id: black exclude: ^tests/\w+/snapshots/ - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.289 - hooks: - - id: ruff - exclude: ^tests/\w+/snapshots/ - - repo: https://github.com/patrick91/pre-commit-alex rev: aa5da9e54b92ab7284feddeaf52edf14b1690de3 hooks: diff --git a/src/strawberry_sqlalchemy_mapper/loader.py b/src/strawberry_sqlalchemy_mapper/loader.py index 40047e0..f45100e 100644 --- a/src/strawberry_sqlalchemy_mapper/loader.py +++ b/src/strawberry_sqlalchemy_mapper/loader.py @@ -48,10 +48,10 @@ def __init__( async def _scalars_all(self, *args, **kwargs): if self._async_bind_factory: async with self._async_bind_factory() as bind: - return (await bind.scalars(*args, **kwargs)).all() + return (await bind.scalars(*args, **kwargs)).unique().all() else: assert self._bind is not None - return self._bind.scalars(*args, **kwargs).all() + return self._bind.scalars(*args, **kwargs).unique().all() def loader_for(self, relationship: RelationshipProperty) -> DataLoader: """ diff --git a/src/strawberry_sqlalchemy_mapper/mapper.py b/src/strawberry_sqlalchemy_mapper/mapper.py index 11b098e..4ac8f95 100644 --- a/src/strawberry_sqlalchemy_mapper/mapper.py +++ b/src/strawberry_sqlalchemy_mapper/mapper.py @@ -1,4 +1,3 @@ -import ast import asyncio import collections.abc import dataclasses @@ -160,7 +159,11 @@ def __init__( extra_sqlalchemy_type_to_strawberry_type_map: Optional[ Mapping[Type[TypeEngine], Type[Any]] ] = None, + edge_type: Optional[Type] = None, + connection_type: Optional[Type] = None, + use_relay: bool=False, ) -> None: + self.use_relay = use_relay if model_to_type_name is None: model_to_type_name = self._default_model_to_type_name self.model_to_type_name = model_to_type_name @@ -181,6 +184,9 @@ def __init__( self._related_type_models = set() self._related_interface_models = set() + self.edge_type = edge_type + self.connection_type = connection_type + @staticmethod def _default_model_to_type_name(model: Type[BaseModelType]) -> str: return model.__name__ @@ -220,6 +226,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]: Get or create a corresponding Edge model for the given type (to support future pagination) """ + if self.edge_type is not None: + return self.edge_type edge_name = f"{type_name}Edge" if edge_name not in self.edge_types: self.edge_types[edge_name] = edge_type = strawberry.type( @@ -238,6 +246,10 @@ def _connection_type_for(self, type_name: str) -> Type[Any]: Get or create a corresponding Connection model for the given type (to support future pagination) """ + if not self.use_relay: + return List[ForwardRef(type_name)] + if self.connection_type is not None: + return self.connection_type[ForwardRef(type_name)] connection_name = f"{type_name}Connection" if connection_name not in self.connection_types: self.connection_types[connection_name] = connection_type = strawberry.type( @@ -269,6 +281,8 @@ def _convert_column_to_strawberry_type( """ if isinstance(column.type, Enum): type_annotation = column.type.python_type + if not hasattr(column.type, "_enum_definition"): + type_annotation = strawberry.enum(type_annotation) elif isinstance(column.type, ARRAY): item_type = self._convert_column_to_strawberry_type( Column(column.type.item_type, nullable=False) @@ -439,6 +453,8 @@ def connection_resolver_for( Return an async field resolver for the given relationship that returns a Connection instead of an array of objects. """ + if not self.use_relay: + return self.relationship_resolver_for(relationship) relationship_resolver = self.relationship_resolver_for(relationship) if relationship.uselist: return self.make_connection_wrapper_resolver( @@ -545,6 +561,7 @@ def type( model: Type[BaseModelType], make_interface=False, use_federation=False, + **kwargs, ) -> Callable[[Type[object]], Any]: """ Decorate a type with this to register it as a strawberry type @@ -627,12 +644,12 @@ def convert(type_: Any) -> Any: if key in mapper.columns or key in mapper.relationships: continue if key in model.__annotations__: - annotation = ast.literal_eval(model.__annotations__[key]) + annotation = eval(model.__annotations__[key]) for ( # type: ignore[assignment] sqlalchemy_type, strawberry_type, ) in self.sqlalchemy_type_to_strawberry_type_map.items(): - if isinstance(annotation, sqlalchemy_type): + if annotation == sqlalchemy_type: self._add_annotation( type_, key, strawberry_type, generated_field_keys ) @@ -668,7 +685,7 @@ def convert(type_: Any) -> Any: if "typing" in annotation: # Try to evaluate from existing typing imports annotation = annotation[7:] - annotation = ast.literal_eval(annotation) + annotation = eval(annotation) except NameError: raise UnsupportedDescriptorType(key) self._add_annotation( @@ -693,15 +710,19 @@ def convert(type_: Any) -> Any: # (because they may not have default values) type_.__annotations__.update(old_annotations) + type_name = type_.__name__ + if "name" in kwargs: + type_name = kwargs["name"] + if make_interface: - mapped_type = strawberry.interface(type_) - self.mapped_interfaces[type_.__name__] = mapped_type + mapped_type = strawberry.interface(type_, **kwargs) + self.mapped_interfaces[type_name] = mapped_type elif use_federation: - mapped_type = strawberry.federation.type(type_) - self.mapped_types[type_.__name__] = mapped_type + mapped_type = strawberry.federation.type(type_, **kwargs) + self.mapped_types[type_name] = mapped_type else: - mapped_type = strawberry.type(type_) - self.mapped_types[type_.__name__] = mapped_type + mapped_type = strawberry.type(type_, **kwargs) + self.mapped_types[type_name] = mapped_type setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys) setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_) return mapped_type