Skip to content

Commit fbdfa54

Browse files
csechristCameron SechristCkk3
authored
allow for directives to be added to the types (#204)
* feat: allow for directives to be added to the types * Add pre-commit updates that work with python 3.8 * add tests and run black * add release.md * fix? * added a Release.md with any \n * add release minor --------- Co-authored-by: Cameron Sechrist <[email protected]> Co-authored-by: Ckk3 <[email protected]>
1 parent 9f98569 commit fbdfa54

File tree

4 files changed

+107
-10
lines changed

4 files changed

+107
-10
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
repos:
22
- repo: https://github.com/psf/black
3-
rev: 24.4.2
3+
rev: 24.4.0
44
hooks:
55
- id: black
66
exclude: ^tests/\w+/snapshots/
77

88
- repo: https://github.com/astral-sh/ruff-pre-commit
9-
rev: v0.4.5
9+
rev: v0.11.7
1010
hooks:
1111
- id: ruff
1212
exclude: ^tests/\w+/snapshots/
@@ -24,7 +24,7 @@ repos:
2424
files: '^docs/.*\.mdx?$'
2525

2626
- repo: https://github.com/pre-commit/pre-commit-hooks
27-
rev: v4.6.0
27+
rev: v5.0.0
2828
hooks:
2929
- id: trailing-whitespace
3030
- id: check-merge-conflict
@@ -33,7 +33,7 @@ repos:
3333
- id: check-toml
3434

3535
- repo: https://github.com/adamchainz/blacken-docs
36-
rev: 1.16.0
36+
rev: 1.18.0
3737
hooks:
3838
- id: blacken-docs
3939
args: [--skip-errors]

RELEASE.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Release type: minor
2+
3+
**Added support for GraphQL directives** in the SQLAlchemy type mapper, enabling better integration with GraphQL federation.
4+
5+
**Example usage:**
6+
```python
7+
@mapper.type(Employee, directives=["@deprecated(reason: 'Use newEmployee instead')"])
8+
class Employee:
9+
pass
10+
```

src/strawberry_sqlalchemy_mapper/mapper.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
NewType,
2626
Optional,
2727
Protocol,
28+
Sequence,
2829
Set,
2930
Type,
3031
TypeVar,
@@ -150,11 +151,13 @@ class StrawberrySQLAlchemyType(Generic[BaseModelType]):
150151

151152
@overload
152153
@classmethod
153-
def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...
154+
def from_type(cls, type_: type, *, strict: Literal[True]) -> Self:
155+
...
154156

155157
@overload
156158
@classmethod
157-
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]: ...
159+
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]:
160+
...
158161

159162
@classmethod
160163
def from_type(
@@ -387,7 +390,7 @@ def _convert_relationship_to_strawberry_type(
387390
if relationship.uselist:
388391
# Use list if excluding relay pagination
389392
if use_list:
390-
return List[ForwardRef(type_name)] # type: ignore
393+
return List[ForwardRef(type_name)] # type: ignore
391394

392395
return self._connection_type_for(type_name)
393396
else:
@@ -638,6 +641,7 @@ def type(
638641
model: Type[BaseModelType],
639642
make_interface=False,
640643
use_federation=False,
644+
directives: Union[Sequence[object], None] = (),
641645
) -> Callable[[Type[object]], Any]:
642646
"""
643647
Decorate a type with this to register it as a strawberry type
@@ -832,10 +836,12 @@ def convert(type_: Any) -> Any:
832836
mapped_type = strawberry.interface(type_)
833837
self.mapped_interfaces[type_.__name__] = mapped_type
834838
elif use_federation:
835-
mapped_type = strawberry.federation.type(type_)
839+
mapped_type = strawberry.federation.type(
840+
type_, directives=directives if directives else ()
841+
)
836842
self.mapped_types[type_.__name__] = mapped_type
837843
else:
838-
mapped_type = strawberry.type(type_)
844+
mapped_type = strawberry.type(type_, directives=directives)
839845
self.mapped_types[type_.__name__] = mapped_type
840846

841847
setattr(

tests/test_mapper.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,8 @@ class Department:
326326
@strawberry.type
327327
class Query:
328328
@strawberry.field
329-
def departments(self) -> Department: ...
329+
def departments(self) -> Department:
330+
...
330331

331332
mapper.finalize()
332333
schema = strawberry.Schema(query=Query)
@@ -379,3 +380,83 @@ def departments(self) -> Department: ...
379380
}
380381
'''
381382
assert str(schema) == textwrap.dedent(expected).strip()
383+
384+
385+
@pytest.mark.parametrize(
386+
"directives",
387+
[
388+
(["@deprecated(reason: 'Use newEmployee instead')"]),
389+
(
390+
[
391+
"@deprecated(reason: 'Use newEmployee instead')",
392+
"@customDirective(value: 'example')",
393+
]
394+
),
395+
],
396+
)
397+
def test_type_with_directives(mapper, employee_table, directives):
398+
Employee = employee_table
399+
400+
@mapper.type(Employee, directives=directives)
401+
class Employee:
402+
pass
403+
404+
mapper.finalize()
405+
additional_types = list(mapper.mapped_types.values())
406+
assert len(additional_types) == 1
407+
mapped_employee_type = additional_types[0]
408+
assert mapped_employee_type.__name__ == "Employee"
409+
assert len(mapped_employee_type.__strawberry_definition__.fields) == 2
410+
assert mapped_employee_type.__strawberry_definition__.directives == directives
411+
412+
413+
@pytest.mark.parametrize(
414+
"directives",
415+
[
416+
(["@deprecated(reason: 'Use newEmployee instead')"]),
417+
(
418+
[
419+
"@deprecated(reason: 'Use newEmployee instead')",
420+
"@customDirective(value: 'example')",
421+
]
422+
),
423+
],
424+
)
425+
def test_type_with_directives_and_federation(mapper, employee_table, directives):
426+
Employee = employee_table
427+
428+
@mapper.type(Employee, directives=directives, use_federation=True)
429+
class Employee:
430+
pass
431+
432+
mapper.finalize()
433+
additional_types = list(mapper.mapped_types.values())
434+
assert len(additional_types) == 1
435+
mapped_employee_type = additional_types[0]
436+
assert mapped_employee_type.__name__ == "Employee"
437+
assert len(mapped_employee_type.__strawberry_definition__.fields) == 2
438+
assert mapped_employee_type.__strawberry_definition__.directives == directives
439+
440+
441+
@pytest.mark.parametrize(
442+
"use_federation_value, expected_directives",
443+
[(True, []), (False, ())],
444+
)
445+
def test_type_with_default_directives(
446+
mapper, employee_table, use_federation_value, expected_directives
447+
):
448+
Employee = employee_table
449+
450+
@mapper.type(Employee, use_federation=use_federation_value)
451+
class Employee:
452+
pass
453+
454+
mapper.finalize()
455+
additional_types = list(mapper.mapped_types.values())
456+
assert len(additional_types) == 1
457+
mapped_employee_type = additional_types[0]
458+
assert mapped_employee_type.__name__ == "Employee"
459+
assert len(mapped_employee_type.__strawberry_definition__.fields) == 2
460+
assert (
461+
mapped_employee_type.__strawberry_definition__.directives == expected_directives
462+
)

0 commit comments

Comments
 (0)