Skip to content

Commit f096de7

Browse files
committed
refactor: Extract conditions
1 parent f0d4b0a commit f096de7

File tree

4 files changed

+223
-117
lines changed

4 files changed

+223
-117
lines changed

django_lifecycle/conditions.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any
5+
from typing import List
6+
from typing import Optional
7+
8+
9+
from django_lifecycle import NotSet
10+
11+
12+
@dataclass
13+
class WhenFieldWas:
14+
field_name: str
15+
was: Any = "*"
16+
17+
def __call__(self, instance: Any, update_fields=None) -> bool:
18+
return self.was in (instance.initial_value(self.field_name), "*")
19+
20+
21+
@dataclass
22+
class WhenFieldIsNow:
23+
field_name: str
24+
is_now: Any = "*"
25+
26+
def __call__(self, instance: Any, update_fields=None) -> bool:
27+
return self.is_now in (instance._current_value(self.field_name), "*")
28+
29+
30+
@dataclass
31+
class WhenFieldHasChanged:
32+
field_name: str
33+
has_changed: bool | None = None
34+
35+
def __call__(self, instance: Any, update_fields=None) -> bool:
36+
is_partial_fields_update = update_fields is not None
37+
is_synced = (
38+
is_partial_fields_update is False or self.field_name in update_fields
39+
)
40+
if not is_synced:
41+
return False
42+
43+
return self.has_changed is None or self.has_changed == instance.has_changed(
44+
self.field_name
45+
)
46+
47+
48+
@dataclass
49+
class WhenFieldIsNot:
50+
field_name: str
51+
is_not: Any = NotSet
52+
53+
def __call__(self, instance: Any, update_fields=None) -> bool:
54+
return (
55+
self.is_not is NotSet
56+
or instance._current_value(self.field_name) != self.is_not
57+
)
58+
59+
60+
@dataclass
61+
class WhenFieldWasNot:
62+
field_name: str
63+
was_not: Any = NotSet
64+
65+
def __call__(self, instance: Any, update_fields=None) -> bool:
66+
return (
67+
self.was_not is NotSet
68+
or instance.initial_value(self.field_name) != self.was_not
69+
)
70+
71+
72+
@dataclass
73+
class WhenFieldChangesTo:
74+
field_name: str
75+
changes_to: Any = NotSet
76+
77+
def __call__(self, instance: Any, update_fields=None) -> bool:
78+
is_partial_fields_update = update_fields is not None
79+
is_synced = (
80+
is_partial_fields_update is False or self.field_name in update_fields
81+
)
82+
if not is_synced:
83+
return False
84+
85+
return any(
86+
[
87+
self.changes_to is NotSet,
88+
(
89+
instance.initial_value(self.field_name) != self.changes_to
90+
and instance._current_value(self.field_name) == self.changes_to
91+
),
92+
]
93+
)
94+
95+
96+
@dataclass
97+
class WhenAny:
98+
when_any: Optional[List[str]] = None
99+
was: Any = "*"
100+
is_now: Any = "*"
101+
has_changed: Optional[bool] = None
102+
is_not: Any = NotSet
103+
was_not: Any = NotSet
104+
changes_to: Any = NotSet
105+
106+
def __call__(self, instance: Any, update_fields=None) -> bool:
107+
conditions = (
108+
WhenCondition(
109+
when=field,
110+
was=self.was,
111+
is_now=self.is_now,
112+
has_changed=self.has_changed,
113+
is_not=self.is_not,
114+
was_not=self.was_not,
115+
changes_to=self.changes_to,
116+
)
117+
for field in self.when_any
118+
)
119+
return any(
120+
condition(instance, update_fields=update_fields) for condition in conditions
121+
)
122+
123+
124+
@dataclass
125+
class WhenCondition:
126+
when: Optional[str] = None
127+
was: Any = "*"
128+
is_now: Any = "*"
129+
has_changed: Optional[bool] = None
130+
is_not: Any = NotSet
131+
was_not: Any = NotSet
132+
changes_to: Any = NotSet
133+
134+
def __call__(self, instance: Any, update_fields=None) -> bool:
135+
has_changed_condition = WhenFieldHasChanged(
136+
self.when, has_changed=self.has_changed
137+
)
138+
if not has_changed_condition(instance, update_fields=update_fields):
139+
return False
140+
141+
changes_to_condition = WhenFieldChangesTo(self.when, changes_to=self.changes_to)
142+
if not changes_to_condition(instance, self.when):
143+
return False
144+
145+
is_now_condition = WhenFieldIsNow(self.when, is_now=self.is_now)
146+
if not is_now_condition(instance, self.when):
147+
return False
148+
149+
was_condition = WhenFieldWas(self.when, was=self.was)
150+
if not was_condition(instance, self.when):
151+
return False
152+
153+
was_not_condition = WhenFieldWasNot(self.when, was_not=self.was_not)
154+
if not was_not_condition(instance, self.when):
155+
return False
156+
157+
is_not_condition = WhenFieldIsNot(self.when, is_not=self.is_not)
158+
if not is_not_condition(instance, self.when):
159+
return False
160+
161+
return True

django_lifecycle/decorators.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from __future__ import annotations
2+
13
from dataclasses import dataclass
24
from functools import wraps
35
from typing import Any
6+
from typing import Callable
47
from typing import List, Optional
58

69
from django_lifecycle import NotSet
10+
from .conditions import WhenCondition
711
from .dataclass_validation import Validations
812
from .hooks import VALID_HOOKS
913
from .priority import DEFAULT_PRIORITY
@@ -27,6 +31,36 @@ class HookConfig(Validations):
2731
on_commit: bool = False
2832
priority: int = DEFAULT_PRIORITY
2933

34+
@property
35+
def conditions(self) -> list[Callable]:
36+
if self.when:
37+
return [
38+
WhenCondition(
39+
when=self.when,
40+
was=self.was,
41+
is_now=self.is_now,
42+
has_changed=self.has_changed,
43+
is_not=self.is_not,
44+
was_not=self.was_not,
45+
changes_to=self.changes_to,
46+
)
47+
]
48+
elif self.when_any:
49+
return [
50+
WhenCondition(
51+
when=field,
52+
was=self.was,
53+
is_now=self.is_now,
54+
has_changed=self.has_changed,
55+
is_not=self.is_not,
56+
was_not=self.was_not,
57+
changes_to=self.changes_to,
58+
)
59+
for field in self.when_any
60+
]
61+
else:
62+
return [lambda *_, **__: True]
63+
3064
def validate_hook(self, value, **kwargs):
3165
if value not in VALID_HOOKS:
3266
raise DjangoLifeCycleException(

django_lifecycle/mixins.py

Lines changed: 14 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from __future__ import annotations
12
from functools import partial, reduce, lru_cache
23
from inspect import isfunction
34
from typing import Any, List
5+
from typing import Iterable
46

57
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
68
from django.db import transaction
@@ -13,7 +15,6 @@
1315
)
1416
from django.utils.functional import cached_property
1517

16-
from . import NotSet
1718
from .abstract import AbstractHookedMethod
1819
from .decorators import HookConfig
1920
from .hooks import (
@@ -250,7 +251,9 @@ def _watched_fk_model_fields(cls) -> List[str]:
250251
def _watched_fk_models(cls) -> List[str]:
251252
return [_.split(".")[0] for _ in cls._watched_fk_model_fields()]
252253

253-
def _get_hooked_methods(self, hook: str, **kwargs) -> List[AbstractHookedMethod]:
254+
def _get_hooked_methods(
255+
self, hook: str, update_fields: Iterable[str] | None = None, **kwargs
256+
) -> List[AbstractHookedMethod]:
254257
"""
255258
Iterate through decorated methods to find those that should be
256259
triggered by the current hook. If conditions exist, check them before
@@ -266,44 +269,15 @@ def _get_hooked_methods(self, hook: str, **kwargs) -> List[AbstractHookedMethod]
266269
if callback_specs.hook != hook:
267270
continue
268271

269-
when_field = callback_specs.when
270-
when_any_field = callback_specs.when_any
271-
update_fields = kwargs.get("update_fields", None)
272-
is_partial_fields_update = update_fields is not None
273-
274-
if when_field:
275-
if not self._check_callback_conditions(
276-
when_field,
277-
callback_specs,
278-
is_synced=(
279-
is_partial_fields_update is False
280-
or when_field in update_fields
281-
),
282-
):
283-
continue
284-
elif when_any_field:
285-
any_condition_matched = False
286-
287-
for field_name in when_any_field:
288-
if self._check_callback_conditions(
289-
field_name,
290-
callback_specs,
291-
is_synced=(
292-
is_partial_fields_update is False
293-
or field_name in update_fields
294-
),
295-
):
296-
any_condition_matched = True
297-
break
298-
299-
if not any_condition_matched:
300-
continue
301-
302-
hooked_method = instantiate_hooked_method(method, callback_specs)
303-
hooked_methods.append(hooked_method)
304-
305-
# Only store the method once per hook
306-
break
272+
if any(
273+
condition(self, update_fields=update_fields)
274+
for condition in callback_specs.conditions
275+
):
276+
hooked_method = instantiate_hooked_method(method, callback_specs)
277+
hooked_methods.append(hooked_method)
278+
279+
# Only store the method once per hook
280+
break
307281

308282
return sorted(hooked_methods)
309283

@@ -317,69 +291,6 @@ def _run_hooked_methods(self, hook: str, **kwargs) -> List[str]:
317291

318292
return fired
319293

320-
def _check_callback_conditions(
321-
self, field_name: str, specs: dict, is_synced: bool
322-
) -> bool:
323-
if not self._check_has_changed(field_name, specs, is_synced):
324-
return False
325-
326-
if not self._check_changes_to_condition(field_name, specs, is_synced):
327-
return False
328-
329-
if not self._check_is_now_condition(field_name, specs):
330-
return False
331-
332-
if not self._check_was_condition(field_name, specs):
333-
return False
334-
335-
if not self._check_was_not_condition(field_name, specs):
336-
return False
337-
338-
if not self._check_is_not_condition(field_name, specs):
339-
return False
340-
341-
return True
342-
343-
def _check_has_changed(
344-
self, field_name: str, specs: HookConfig, is_synced: bool
345-
) -> bool:
346-
if not is_synced:
347-
return False
348-
349-
has_changed = specs.has_changed
350-
return has_changed is None or has_changed == self.has_changed(field_name)
351-
352-
def _check_is_now_condition(self, field_name: str, specs: HookConfig) -> bool:
353-
return specs.is_now in (self._current_value(field_name), "*")
354-
355-
def _check_is_not_condition(self, field_name: str, specs: HookConfig) -> bool:
356-
is_not = specs.is_not
357-
return is_not is NotSet or self._current_value(field_name) != is_not
358-
359-
def _check_was_condition(self, field_name: str, specs: HookConfig) -> bool:
360-
return specs.was in (self.initial_value(field_name), "*")
361-
362-
def _check_was_not_condition(self, field_name: str, specs: HookConfig) -> bool:
363-
was_not = specs.was_not
364-
return was_not is NotSet or self.initial_value(field_name) != was_not
365-
366-
def _check_changes_to_condition(
367-
self, field_name: str, specs: HookConfig, is_synced: bool
368-
) -> bool:
369-
if not is_synced:
370-
return False
371-
372-
changes_to = specs.changes_to
373-
return any(
374-
[
375-
changes_to is NotSet,
376-
(
377-
self.initial_value(field_name) != changes_to
378-
and self._current_value(field_name) == changes_to
379-
),
380-
]
381-
)
382-
383294
@classmethod
384295
def _get_model_property_names(cls) -> List[str]:
385296
"""

0 commit comments

Comments
 (0)