11from __future__ import annotations
2- from copy import deepcopy
3- from functools import partial , reduce , lru_cache
2+ from functools import partial , lru_cache
43from inspect import isfunction
54from typing import Any , List
65from typing import Iterable
76
8- from django .core .exceptions import FieldDoesNotExist , ObjectDoesNotExist
97from django .db import transaction
108from django .db .models .fields .related_descriptors import (
119 ForwardManyToOneDescriptor ,
2826 AFTER_SAVE ,
2927 AFTER_DELETE ,
3028)
31-
29+ from .model_state import ModelState
30+ from .utils import get_value
31+ from .utils import sanitize_field_name
3232
3333DJANGO_RELATED_FIELD_DESCRIPTOR_CLASSES = (
3434 ForwardManyToOneDescriptor ,
@@ -80,91 +80,32 @@ def instantiate_hooked_method(
8080class LifecycleModelMixin (object ):
8181 def __init__ (self , * args , ** kwargs ):
8282 super ().__init__ (* args , ** kwargs )
83- self ._initial_state = self ._snapshot_state ()
84-
85- def _snapshot_state (self ):
86- state = deepcopy (self .__dict__ )
87-
88- for watched_related_field in self ._watched_fk_model_fields ():
89- state [watched_related_field ] = self ._current_value (watched_related_field )
90-
91- if "_state" in state :
92- del state ["_state" ]
93-
94- if "_potentially_hooked_methods" in state :
95- del state ["_potentially_hooked_methods" ]
83+ self ._initial_state = ModelState .from_instance (self )
9684
97- if "_initial_state" in state :
98- del state ["_initial_state" ]
99-
100- if "_watched_fk_model_fields" in state :
101- del state ["_watched_fk_model_fields" ]
102-
103- return state
85+ def _snapshot_state (self ) -> dict :
86+ return ModelState .from_instance (self ).initial_state
10487
10588 @property
10689 def _diff_with_initial (self ) -> dict :
107- initial = self ._initial_state
108- current = self ._snapshot_state ()
109- diffs = []
110-
111- for k , v in initial .items ():
112- if k in current and v != current [k ]:
113- diffs .append ((k , (v , current [k ])))
114-
115- return dict (diffs )
90+ return self ._initial_state .get_diff (self )
11691
11792 def _sanitize_field_name (self , field_name : str ) -> str :
118- try :
119- field = self ._meta .get_field (field_name )
120-
121- try :
122- internal_type = field .get_internal_type ()
123- except AttributeError :
124- return field
125- if internal_type == "ForeignKey" or internal_type == "OneToOneField" :
126- if not field_name .endswith ("_id" ):
127- return field_name + "_id"
128- except FieldDoesNotExist :
129- pass
130-
131- return field_name
93+ return sanitize_field_name (self , field_name )
13294
13395 def _current_value (self , field_name : str ) -> Any :
134- if "." in field_name :
135-
136- def getitem (obj , field_name : str ):
137- try :
138- return getattr (obj , field_name )
139- except (AttributeError , ObjectDoesNotExist ):
140- return None
141-
142- return reduce (getitem , field_name .split ("." ), self )
143- else :
144- return getattr (self , self ._sanitize_field_name (field_name ))
96+ return get_value (self , field_name )
14597
14698 def initial_value (self , field_name : str ) -> Any :
14799 """
148100 Get initial value of field when model value instantiated.
149101 """
150- field_name = self ._sanitize_field_name (field_name )
151-
152- if field_name in self ._initial_state :
153- return self ._initial_state [field_name ]
154-
155- return None
102+ return self ._initial_state .get_value (self , field_name )
156103
157104 def has_changed (self , field_name : str ) -> bool :
158105 """
159106 Check if a field has changed since the model value instantiated.
160107 """
161- changed = self ._diff_with_initial .keys ()
162- field_name = self ._sanitize_field_name (field_name )
163-
164- if field_name in changed :
165- return True
166-
167- return False
108+ return self ._initial_state .has_changed (self , field_name )
168109
169110 def _clear_watched_fk_model_cache (self ):
170111 """ """
@@ -175,7 +116,7 @@ def _clear_watched_fk_model_cache(self):
175116 field .delete_cached_value (self )
176117
177118 def _reset_initial_state (self ):
178- self ._initial_state = self . _snapshot_state ( )
119+ self ._initial_state = ModelState . from_instance ( self )
179120
180121 @transaction .atomic
181122 def save (self , * args , ** kwargs ):
0 commit comments