Skip to content

Commit 03be703

Browse files
authored
Merge pull request #98 from amcclosky/add-on-commit-option
Add optional `@hook` `on_commit` argument to for executing hooks `on_commit`
2 parents 756be52 + a69e64b commit 03be703

File tree

8 files changed

+153
-18
lines changed

8 files changed

+153
-18
lines changed

django_lifecycle/decorators.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import wraps
2-
from typing import List
2+
from typing import List, Optional
33

44
from django_lifecycle import NotSet
55

@@ -10,7 +10,7 @@ class DjangoLifeCycleException(Exception):
1010
pass
1111

1212

13-
def _validate_hook_params(hook, when, when_any, has_changed):
13+
def _validate_hook_params(hook, when, when_any, has_changed, on_commit):
1414
if hook not in VALID_HOOKS:
1515
raise DjangoLifeCycleException(
1616
"%s is not a valid hook; must be one of %s" % (hook, VALID_HOOKS)
@@ -46,6 +46,13 @@ def _validate_hook_params(hook, when, when_any, has_changed):
4646
raise DjangoLifeCycleException(
4747
"Can pass either 'when' or 'when_any' but not both"
4848
)
49+
50+
if on_commit is not None:
51+
if not hook.startswith("after_"):
52+
raise DjangoLifeCycleException("'on_commit' hook param is only valid with AFTER_* hooks")
53+
54+
if not isinstance(on_commit, bool):
55+
raise DjangoLifeCycleException("'on_commit' hook param must be a boolean")
4956

5057

5158
def hook(
@@ -58,8 +65,9 @@ def hook(
5865
is_not=NotSet,
5966
was_not=NotSet,
6067
changes_to=NotSet,
68+
on_commit: Optional[bool] = None
6169
):
62-
_validate_hook_params(hook, when, when_any, has_changed)
70+
_validate_hook_params(hook, when, when_any, has_changed, on_commit)
6371

6472
def decorator(hooked_method):
6573
if not hasattr(hooked_method, "_hooked"):
@@ -83,6 +91,7 @@ def func(*args, **kwargs):
8391
"was": was,
8492
"was_not": was_not,
8593
"changes_to": changes_to,
94+
"on_commit": on_commit
8695
}
8796
)
8897

django_lifecycle/mixins.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from functools import reduce, lru_cache
1+
from functools import partial, reduce, lru_cache
22
from inspect import isfunction
33
from typing import Any, List
44

@@ -200,6 +200,8 @@ def _run_hooked_methods(self, hook: str, **kwargs) -> List[str]:
200200
for callback_specs in method._hooked:
201201
if callback_specs["hook"] != hook:
202202
continue
203+
204+
on_commit = callback_specs.get("on_commit", False)
203205

204206
when_field = callback_specs.get("when")
205207
when_any_field = callback_specs.get("when_any")
@@ -225,10 +227,27 @@ def _run_hooked_methods(self, hook: str, **kwargs) -> List[str]:
225227
]
226228
):
227229
continue
230+
231+
# Save method name before potentially wrapping with `on_commit`
232+
method_name = method.__name__
233+
234+
# Apply `on_commit` after saving the method as `fired` to preserve
235+
# the non-anonymous name
236+
if on_commit:
237+
# Append `_on_commit` to the existing method name to allow for firing
238+
# the same hook within the atomic transaction and on_commit
239+
method_name = method_name + "_on_commit"
240+
241+
# Use partial to create a function closure that binds `self`
242+
# to ensure its available to execute later.
243+
_on_commit_func = partial(method, self)
244+
_on_commit_func.__name__ = method_name
245+
transaction.on_commit(_on_commit_func)
246+
else:
247+
method(self)
228248

229249
# Only call the method once per hook
230-
fired.append(method.__name__)
231-
method(self)
250+
fired.append(method_name)
232251
break
233252

234253
return fired

docs/examples.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ Or you want to email a user when their account is deleted. You could add the dec
2828
)
2929
```
3030

31+
Or if you want to enqueue a background job that depends on state being committed to your database
32+
33+
```python
34+
@hook(AFTER_CREATE, on_commit=True)
35+
def do_after_create_jobs(self):
36+
enqueue_job(send_item_shipped_notication, self.item_id)
37+
```
38+
3139
Read on to see how to only fire the hooked method if certain conditions about the model's current and previous state are met.
3240

3341
## Transitions between specific values

docs/hooks_and_conditions.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ You can hook into one or more lifecycle moments by adding the `@hook` decorator
1717
was: Any = '*',
1818
was_not: Any = None,
1919
changes_to: Any = None,
20+
on_commit: Optional[bool] = None
2021
):
2122
```
2223
## Lifecycle Moments
@@ -52,3 +53,4 @@ If you do not use any conditional parameters, the hook will fire every time the
5253
| was | Any | Only fire the hooked method if the value of the `when` field was equal to this value when first initialized; defaults to `*`. |
5354
| was_not | Any | Only fire the hooked method if the value of the `when` field was NOT equal to this value when first initialized. |
5455
| changes_to | Any | Only fire the hooked method if the value of the `when` field was NOT equal to this value when first initialized but is currently equal to this value. |
56+
| on_commit | bool | When `True` only fire the hooked method after the current database transaction has been commited or not at all. (Only applies to `AFTER_*` hooks) |

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
asgiref==3.4.1
22
Click==7.0
33
Django==3.2.8
4+
django-capture-on-commit-callbacks==1.10.0
45
djangorestframework==3.11.2
56
ghp-import==2.0.2
67
importlib-metadata==4.8.1

tests/testapp/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def lowercase_email(self):
5151
def timestamp_joined_at(self):
5252
self.joined_at = timezone.now()
5353

54-
@hook("after_create")
54+
@hook("after_create", on_commit=True)
5555
def do_after_create_jobs(self):
5656
# queue background job to process thumbnail image...
5757
mail.send_mail(
@@ -75,7 +75,7 @@ def ensure_trial_not_active(self):
7575
def ensure_last_name_is_not_changed_to_flanders(self):
7676
raise CannotRename("Oh, not Flanders. Anybody but Flanders.")
7777

78-
@hook("after_update", when="organization.name", has_changed=True)
78+
@hook("after_update", when="organization.name", has_changed=True, on_commit=True)
7979
def notify_org_name_change(self):
8080
mail.send_mail(
8181
"The name of your organization has changed!",

tests/testapp/tests/test_mixin.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,85 @@ def test_comparison_state_should_reset_after_save(self):
366366
self.assertTrue(account.has_changed("first_name"))
367367
account.save()
368368
self.assertFalse(account.has_changed("first_name"))
369+
370+
def test_run_hooked_methods_for_on_commit(self):
371+
instance = UserAccount(first_name="Bob")
372+
373+
instance._potentially_hooked_methods = MagicMock(
374+
return_value = [
375+
MagicMock(
376+
__name__="method_that_fires_on_commit",
377+
_hooked=[
378+
{
379+
"hook": "after_create",
380+
"when": None,
381+
"when_any": None,
382+
"has_changed": None,
383+
"is_now": "*",
384+
"is_not": NotSet,
385+
"was": "*",
386+
"was_not": NotSet,
387+
"changes_to": NotSet,
388+
"on_commit": True
389+
}
390+
],
391+
),
392+
MagicMock(
393+
__name__="method_that_fires_in_transaction",
394+
_hooked=[
395+
{
396+
"hook": "after_create",
397+
"when": None,
398+
"when_any": None,
399+
"has_changed": None,
400+
"is_now": "*",
401+
"is_not": NotSet,
402+
"was": "*",
403+
"was_not": NotSet,
404+
"changes_to": NotSet,
405+
"on_commit": False
406+
}
407+
],
408+
),
409+
MagicMock(
410+
__name__="method_that_fires_in_default",
411+
_hooked=[
412+
{
413+
"hook": "after_create",
414+
"when": None,
415+
"when_any": None,
416+
"has_changed": None,
417+
"is_now": "*",
418+
"is_not": NotSet,
419+
"was": "*",
420+
"was_not": NotSet,
421+
"changes_to": NotSet,
422+
"on_commit": None
423+
}
424+
],
425+
),
426+
MagicMock(
427+
__name__="after_save_method_that_fires_on_commit",
428+
_hooked=[
429+
{
430+
"hook": "after_save",
431+
"when": None,
432+
"when_any": None,
433+
"has_changed": None,
434+
"is_now": "*",
435+
"is_not": NotSet,
436+
"was": "*",
437+
"was_not": NotSet,
438+
"changes_to": NotSet,
439+
"on_commit": True
440+
}
441+
],
442+
),
443+
]
444+
)
445+
446+
fired_methods = instance._run_hooked_methods("after_create")
447+
self.assertEqual(fired_methods, ["method_that_fires_on_commit_on_commit", "method_that_fires_in_transaction", "method_that_fires_in_default"])
448+
449+
fired_methods = instance._run_hooked_methods("after_save")
450+
self.assertEqual(fired_methods, ["after_save_method_that_fires_on_commit_on_commit"])

tests/testapp/tests/test_user_account.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from django.core import mail
44
from django.test import TestCase
55

6+
from django_capture_on_commit_callbacks import capture_on_commit_callbacks
7+
68
from tests.testapp.models import CannotDeleteActiveTrial, Organization, UserAccount
79

810

@@ -22,7 +24,10 @@ def test_update_joined_at_before_create(self):
2224
self.assertTrue(isinstance(account.joined_at, datetime.datetime))
2325

2426
def test_send_welcome_email_after_create(self):
25-
UserAccount.objects.create(**self.stub_data)
27+
with capture_on_commit_callbacks(execute=True) as callbacks:
28+
UserAccount.objects.create(**self.stub_data)
29+
30+
self.assertEquals(len(callbacks), 1, msg=f"{callbacks}")
2631
self.assertEqual(len(mail.outbox), 1)
2732
self.assertEqual(mail.outbox[0].subject, "Welcome!")
2833

@@ -73,12 +78,16 @@ def test_notify_org_name_change(self):
7378
org = Organization.objects.create(name="Hogwarts")
7479
UserAccount.objects.create(**self.stub_data, organization=org)
7580
mail.outbox = []
81+
7682
account = UserAccount.objects.get()
7783

78-
org.name = "Coursera Wizardry"
79-
org.save()
84+
with capture_on_commit_callbacks(execute=True) as callbacks:
85+
org.name = "Coursera Wizardry"
86+
org.save()
8087

81-
account.save()
88+
account.save()
89+
90+
self.assertEquals(len(callbacks), 1)
8291
self.assertEqual(len(mail.outbox), 1)
8392
self.assertEqual(
8493
mail.outbox[0].subject, "The name of your organization has changed!"
@@ -95,18 +104,23 @@ def test_no_notify_sent_if_org_name_has_not_changed(self):
95104
def test_additional_notify_sent_for_specific_org_name_change(self):
96105
org = Organization.objects.create(name="Hogwarts")
97106
UserAccount.objects.create(**self.stub_data, organization=org)
107+
98108
mail.outbox = []
99-
account = UserAccount.objects.get()
100109

101-
org.name = "Hogwarts Online"
102-
org.save()
110+
with capture_on_commit_callbacks(execute=True) as callbacks:
111+
account = UserAccount.objects.get()
103112

104-
account.save()
113+
org.name = "Hogwarts Online"
114+
org.save()
115+
116+
account.save()
117+
118+
self.assertEquals(len(callbacks), 1, msg="Only one hook should be an on_commit callback")
105119
self.assertEqual(len(mail.outbox), 2)
106120
self.assertEqual(
107-
mail.outbox[0].subject, "The name of your organization has changed!"
121+
mail.outbox[1].subject, "The name of your organization has changed!"
108122
)
109-
self.assertEqual(mail.outbox[1].subject, "You were moved to our online school!")
123+
self.assertEqual(mail.outbox[0].subject, "You were moved to our online school!")
110124

111125
def test_email_user_about_name_change(self):
112126
account = UserAccount.objects.create(**self.stub_data)

0 commit comments

Comments
 (0)