Skip to content

Commit 3829267

Browse files
Add type validation for foreign key and one to one model consistency (#1792)
1 parent a5bb80f commit 3829267

File tree

4 files changed

+98
-2
lines changed

4 files changed

+98
-2
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Changelog
1414
Added
1515
^^^^^
1616
- Implement savepoints for transactions (#1816)
17+
- Added type validation for foreign key fields to ensure type safety. Now raises `ValidationError` when assigning foreign key values with incorrect model types (#1792)
1718

1819
Fixed
1920
^^^^^
@@ -1498,4 +1499,4 @@ Docs/examples:
14981499
14991500
await Tournament.filter(
15001501
events__name__in=['1', '3']
1501-
).order_by('-events__participants__name').distinct()
1502+
).order_by('-events__participants__name').distinct()

tests/fields/test_fk.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
11
from tests import testmodels
22
from tortoise.contrib import test
3-
from tortoise.exceptions import IntegrityError, NoValuesFetched, OperationalError
3+
from tortoise.exceptions import (
4+
IntegrityError,
5+
NoValuesFetched,
6+
OperationalError,
7+
ValidationError,
8+
)
49
from tortoise.queryset import QuerySet
510

611

712
class TestForeignKeyField(test.TestCase):
13+
def assertRaisesWrongTypeException(self, relation_name: str):
14+
return self.assertRaisesRegex(
15+
ValidationError, f"Invalid type for relationship field '{relation_name}'"
16+
)
17+
818
async def test_empty(self):
919
with self.assertRaises(IntegrityError):
1020
await testmodels.MinRelation.create()
@@ -151,6 +161,11 @@ async def test_minimal__instantiated_create(self):
151161
tour = await testmodels.Tournament.create(name="Team1")
152162
await testmodels.MinRelation.create(tournament=tour)
153163

164+
async def test_minimal__instantiated_create_wrong_type(self):
165+
author = await testmodels.Author.create(name="Author1")
166+
with self.assertRaisesWrongTypeException("tournament"):
167+
await testmodels.MinRelation.create(tournament=author)
168+
154169
async def test_minimal__instantiated_iterate(self):
155170
tour = await testmodels.Tournament.create(name="Team1")
156171
async for _ in tour.minrelations:
@@ -229,3 +244,57 @@ async def test_event__offset(self):
229244
event2 = await testmodels.Event.create(name="Event2", tournament=tour)
230245
event3 = await testmodels.Event.create(name="Event3", tournament=tour)
231246
self.assertEqual(await tour.events.offset(1).order_by("name"), [event2, event3])
247+
248+
async def test_fk_correct_type_assignment(self):
249+
tour1 = await testmodels.Tournament.create(name="Team1")
250+
tour2 = await testmodels.Tournament.create(name="Team2")
251+
event = await testmodels.Event(name="Event1", tournament=tour1)
252+
253+
event.tournament = tour2
254+
await event.save()
255+
self.assertEqual(event.tournament_id, tour2.id)
256+
257+
async def test_fk_wrong_type_assignment(self):
258+
tour = await testmodels.Tournament.create(name="Team1")
259+
author = await testmodels.Author.create(name="Author")
260+
rel = await testmodels.MinRelation.create(tournament=tour)
261+
262+
with self.assertRaisesWrongTypeException("tournament"):
263+
rel.tournament = author
264+
265+
async def test_fk_none_assignment(self):
266+
manager = await testmodels.Employee.create(name="Manager")
267+
employee = await testmodels.Employee.create(name="Employee", manager=manager)
268+
269+
employee.manager = None
270+
await employee.save()
271+
self.assertIsNone(employee.manager)
272+
273+
async def test_fk_update_wrong_type(self):
274+
tour = await testmodels.Tournament.create(name="Team1")
275+
rel = await testmodels.MinRelation.create(tournament=tour)
276+
author = await testmodels.Author.create(name="Author1")
277+
278+
with self.assertRaisesWrongTypeException("tournament"):
279+
await testmodels.MinRelation.filter(id=rel.id).update(tournament=author)
280+
281+
async def test_fk_bulk_create_wrong_type(self):
282+
author = await testmodels.Author.create(name="Author")
283+
with self.assertRaisesWrongTypeException("tournament"):
284+
await testmodels.MinRelation.bulk_create(
285+
[testmodels.MinRelation(tournament=author) for _ in range(10)]
286+
)
287+
288+
async def test_fk_bulk_update_wrong_type(self):
289+
tour = await testmodels.Tournament.create(name="Team1")
290+
await testmodels.MinRelation.bulk_create(
291+
[testmodels.MinRelation(tournament=tour) for _ in range(1, 10)]
292+
)
293+
author = await testmodels.Author.create(name="Author")
294+
295+
with self.assertRaisesWrongTypeException("tournament"):
296+
relations = await testmodels.MinRelation.all()
297+
await testmodels.MinRelation.bulk_update(
298+
[testmodels.MinRelation(id=rel.id, tournament=author) for rel in relations],
299+
fields=["tournament"],
300+
)

tortoise/models.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@
3030
from tortoise.exceptions import (
3131
ConfigurationError,
3232
DoesNotExist,
33+
FieldError,
3334
IncompleteInstanceError,
3435
IntegrityError,
3536
ObjectDoesNotExistError,
3637
OperationalError,
3738
ParamsError,
39+
ValidationError,
3840
)
3941
from tortoise.expressions import Expression
4042
from tortoise.fields.base import Field
@@ -685,6 +687,8 @@ def __setattr__(self, key, value) -> None:
685687
# set field value override async default function
686688
if hasattr(self, "_await_when_save"):
687689
self._await_when_save.pop(key, None)
690+
if key in self._meta.fk_fields or key in self._meta.o2o_fields:
691+
self._validate_relation_type(key, value)
688692
super().__setattr__(key, value)
689693

690694
def _set_kwargs(self, kwargs: dict) -> Set[str]:
@@ -806,6 +810,27 @@ def _set_pk_val(self, value: Any) -> None:
806810
Can be used as a field name when doing filtering e.g. ``.filter(pk=...)`` etc...
807811
"""
808812

813+
@classmethod
814+
def _validate_relation_type(cls, field_key: str, value: Optional["Model"]) -> None:
815+
if value is None:
816+
return
817+
818+
field = cls._meta.fields_map[field_key]
819+
if not isinstance(field, (OneToOneFieldInstance, ForeignKeyFieldInstance)):
820+
raise FieldError(
821+
f"Field '{field_key}' must be a OneToOne or ForeignKey relation, "
822+
f"got {type(field).__name__}"
823+
)
824+
825+
expected_model = field.related_model
826+
received_model = type(value)
827+
if received_model is not expected_model:
828+
raise ValidationError(
829+
f"Invalid type for relationship field '{field_key}'. "
830+
f"Expected model type '{expected_model.__name__}', but got '{received_model.__name__}'. "
831+
"Make sure you're using the correct model class for this relationship."
832+
)
833+
809834
@classmethod
810835
async def _getbypk(cls: Type[MODEL], key: Any) -> MODEL:
811836
try:

tortoise/queryset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,7 @@ def _make_query(self) -> None:
12041204
if field_object.pk:
12051205
raise IntegrityError(f"Field {key} is PK and can not be updated")
12061206
if isinstance(field_object, (ForeignKeyFieldInstance, OneToOneFieldInstance)):
1207+
self.model._validate_relation_type(key, value)
12071208
fk_field: str = field_object.source_field # type: ignore
12081209
db_field = self.model._meta.fields_map[fk_field].source_field
12091210
value = executor.column_map[fk_field](

0 commit comments

Comments
 (0)