Skip to content

Commit 5b0e1e4

Browse files
authored
Merge pull request #160 from rsinger86/refactor/extract-model-state-methods
Extract model state methods
2 parents 92de9d3 + 6d96c8e commit 5b0e1e4

File tree

3 files changed

+114
-72
lines changed

3 files changed

+114
-72
lines changed

django_lifecycle/mixins.py

Lines changed: 13 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from __future__ import annotations
2-
from copy import deepcopy
3-
from functools import partial, reduce, lru_cache
2+
from functools import partial, lru_cache
43
from inspect import isfunction
54
from typing import Any, List
65
from typing import Iterable
76

8-
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
97
from django.db import transaction
108
from django.db.models.fields.related_descriptors import (
119
ForwardManyToOneDescriptor,
@@ -28,7 +26,9 @@
2826
AFTER_SAVE,
2927
AFTER_DELETE,
3028
)
31-
29+
from .model_state import ModelState
30+
from .utils import get_value
31+
from .utils import sanitize_field_name
3232

3333
DJANGO_RELATED_FIELD_DESCRIPTOR_CLASSES = (
3434
ForwardManyToOneDescriptor,
@@ -80,91 +80,32 @@ def instantiate_hooked_method(
8080
class LifecycleModelMixin(object):
8181
def __init__(self, *args, **kwargs):
8282
super().__init__(*args, **kwargs)
83-
self._initial_state = self._snapshot_state()
84-
85-
def _snapshot_state(self):
86-
state = deepcopy(self.__dict__)
87-
88-
for watched_related_field in self._watched_fk_model_fields():
89-
state[watched_related_field] = self._current_value(watched_related_field)
90-
91-
if "_state" in state:
92-
del state["_state"]
93-
94-
if "_potentially_hooked_methods" in state:
95-
del state["_potentially_hooked_methods"]
83+
self._initial_state = ModelState.from_instance(self)
9684

97-
if "_initial_state" in state:
98-
del state["_initial_state"]
99-
100-
if "_watched_fk_model_fields" in state:
101-
del state["_watched_fk_model_fields"]
102-
103-
return state
85+
def _snapshot_state(self) -> dict:
86+
return ModelState.from_instance(self).initial_state
10487

10588
@property
10689
def _diff_with_initial(self) -> dict:
107-
initial = self._initial_state
108-
current = self._snapshot_state()
109-
diffs = []
110-
111-
for k, v in initial.items():
112-
if k in current and v != current[k]:
113-
diffs.append((k, (v, current[k])))
114-
115-
return dict(diffs)
90+
return self._initial_state.get_diff(self)
11691

11792
def _sanitize_field_name(self, field_name: str) -> str:
118-
try:
119-
field = self._meta.get_field(field_name)
120-
121-
try:
122-
internal_type = field.get_internal_type()
123-
except AttributeError:
124-
return field
125-
if internal_type == "ForeignKey" or internal_type == "OneToOneField":
126-
if not field_name.endswith("_id"):
127-
return field_name + "_id"
128-
except FieldDoesNotExist:
129-
pass
130-
131-
return field_name
93+
return sanitize_field_name(self, field_name)
13294

13395
def _current_value(self, field_name: str) -> Any:
134-
if "." in field_name:
135-
136-
def getitem(obj, field_name: str):
137-
try:
138-
return getattr(obj, field_name)
139-
except (AttributeError, ObjectDoesNotExist):
140-
return None
141-
142-
return reduce(getitem, field_name.split("."), self)
143-
else:
144-
return getattr(self, self._sanitize_field_name(field_name))
96+
return get_value(self, field_name)
14597

14698
def initial_value(self, field_name: str) -> Any:
14799
"""
148100
Get initial value of field when model value instantiated.
149101
"""
150-
field_name = self._sanitize_field_name(field_name)
151-
152-
if field_name in self._initial_state:
153-
return self._initial_state[field_name]
154-
155-
return None
102+
return self._initial_state.get_value(self, field_name)
156103

157104
def has_changed(self, field_name: str) -> bool:
158105
"""
159106
Check if a field has changed since the model value instantiated.
160107
"""
161-
changed = self._diff_with_initial.keys()
162-
field_name = self._sanitize_field_name(field_name)
163-
164-
if field_name in changed:
165-
return True
166-
167-
return False
108+
return self._initial_state.has_changed(self, field_name)
168109

169110
def _clear_watched_fk_model_cache(self):
170111
""" """
@@ -175,7 +116,7 @@ def _clear_watched_fk_model_cache(self):
175116
field.delete_cached_value(self)
176117

177118
def _reset_initial_state(self):
178-
self._initial_state = self._snapshot_state()
119+
self._initial_state = ModelState.from_instance(self)
179120

180121
@transaction.atomic
181122
def save(self, *args, **kwargs):

django_lifecycle/model_state.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from __future__ import annotations
2+
3+
from copy import deepcopy
4+
from typing import Any
5+
from typing import Dict
6+
from typing import TYPE_CHECKING
7+
8+
from django_lifecycle.utils import get_value
9+
from django_lifecycle.utils import sanitize_field_name
10+
11+
if TYPE_CHECKING:
12+
from django_lifecycle import LifecycleModelMixin
13+
14+
15+
class ModelState:
16+
def __init__(self, initial_state: Dict[str, Any]):
17+
self.initial_state = initial_state
18+
19+
@classmethod
20+
def from_instance(cls, instance: "LifecycleModelMixin") -> ModelState:
21+
state = deepcopy(instance.__dict__)
22+
23+
for watched_related_field in instance._watched_fk_model_fields():
24+
state[watched_related_field] = get_value(instance, watched_related_field)
25+
26+
fields_to_remove = (
27+
"_state",
28+
"_potentially_hooked_methods",
29+
"_initial_state",
30+
"_watched_fk_model_fields",
31+
)
32+
for field in fields_to_remove:
33+
state.pop(field, None)
34+
35+
return ModelState(state)
36+
37+
def get_diff(self, instance: "LifecycleModelMixin") -> dict:
38+
current = ModelState.from_instance(instance).initial_state
39+
diffs = {}
40+
41+
for key, initial_value in self.initial_state.items():
42+
try:
43+
current_value = current[key]
44+
except KeyError:
45+
continue
46+
47+
if initial_value != current_value:
48+
diffs[key] = (key, current_value)
49+
50+
return diffs
51+
52+
def get_value(self, instance: "LifecycleModelMixin", field_name: str) -> Any:
53+
"""
54+
Get initial value of field when model was instantiated.
55+
"""
56+
field_name = sanitize_field_name(instance, field_name)
57+
return self.initial_state.get(field_name)
58+
59+
def has_changed(self, instance: "LifecycleModelMixin", field_name: str) -> bool:
60+
"""
61+
Check if a field has changed since the model was instantiated.
62+
"""
63+
field_name = sanitize_field_name(instance, field_name)
64+
return field_name in self.get_diff(instance)

django_lifecycle/utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from functools import reduce
2+
from typing import Any
3+
4+
from django.core.exceptions import FieldDoesNotExist
5+
from django.core.exceptions import ObjectDoesNotExist
6+
from django.db import models
7+
8+
9+
def sanitize_field_name(instance: models.Model, field_name: str) -> str:
10+
try:
11+
field = instance._meta.get_field(field_name)
12+
13+
try:
14+
internal_type = field.get_internal_type()
15+
except AttributeError:
16+
return field
17+
if internal_type == "ForeignKey" or internal_type == "OneToOneField":
18+
if not field_name.endswith("_id"):
19+
return field_name + "_id"
20+
except FieldDoesNotExist:
21+
pass
22+
23+
return field_name
24+
25+
26+
def get_value(instance, sanitized_field_name: str) -> Any:
27+
if "." in sanitized_field_name:
28+
29+
def getitem(obj, field_name: str):
30+
try:
31+
return getattr(obj, field_name)
32+
except (AttributeError, ObjectDoesNotExist):
33+
return None
34+
35+
return reduce(getitem, sanitized_field_name.split("."), instance)
36+
else:
37+
return getattr(instance, sanitize_field_name(instance, sanitized_field_name))

0 commit comments

Comments
 (0)