Skip to content

Commit 613533d

Browse files
committed
Merge branch 'bulk-enqueue'
2 parents 2c67eb1 + a5d0f5d commit 613533d

File tree

5 files changed

+279
-10
lines changed

5 files changed

+279
-10
lines changed

django_lightweight_queue/backends/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABCMeta, abstractmethod
2-
from typing import Tuple, TypeVar, Optional
2+
from typing import Tuple, TypeVar, Optional, Collection
33

44
from ..job import Job
55
from ..types import QueueName, WorkerNumber
@@ -18,6 +18,16 @@ def startup(self, queue: QueueName) -> None:
1818
def enqueue(self, job: Job, queue: QueueName) -> None:
1919
raise NotImplementedError()
2020

21+
def bulk_enqueue(self, jobs: Collection[Job], queue: QueueName) -> None:
22+
"""
23+
Enqueue a number of tasks in one pass.
24+
25+
Backends are strongly encouraged to override this with a more efficient
26+
implementation if they can.
27+
"""
28+
for job in jobs:
29+
self.enqueue(job, queue)
30+
2131
@abstractmethod
2232
def dequeue(self, queue: QueueName, worker_num: WorkerNumber, timeout: int) -> Optional[Job]:
2333
raise NotImplementedError()

django_lightweight_queue/backends/redis.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Collection
22

33
import redis
44

@@ -21,7 +21,13 @@ def __init__(self) -> None:
2121
)
2222

2323
def enqueue(self, job: Job, queue: QueueName) -> None:
24-
self.client.lpush(self._key(queue), job.to_json().encode('utf-8'))
24+
return self.bulk_enqueue([job], queue)
25+
26+
def bulk_enqueue(self, jobs: Collection[Job], queue: QueueName) -> None:
27+
self.client.lpush(
28+
self._key(queue),
29+
*(job.to_json().encode('utf-8') for job in jobs),
30+
)
2531

2632
def dequeue(self, queue: QueueName, worker_num: WorkerNumber, timeout: int) -> Optional[Job]:
2733
raw = self.client.brpop(self._key(queue), timeout)

django_lightweight_queue/backends/reliable_redis.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Tuple, TypeVar, Optional
1+
from typing import Dict, List, Tuple, TypeVar, Optional, Collection
22

33
import redis
44

@@ -88,7 +88,13 @@ def move_processing_jobs_to_main(pipe: redis.client.Pipeline) -> None:
8888
)
8989

9090
def enqueue(self, job: Job, queue: QueueName) -> None:
91-
self.client.lpush(self._key(queue), job.to_json().encode('utf-8'))
91+
return self.bulk_enqueue([job], queue)
92+
93+
def bulk_enqueue(self, jobs: Collection[Job], queue: QueueName) -> None:
94+
self.client.lpush(
95+
self._key(queue),
96+
*(job.to_json().encode('utf-8') for job in jobs),
97+
)
9298

9399
def dequeue(self, queue: QueueName, worker_number: WorkerNumber, timeout: int) -> Optional[Job]:
94100
main_queue_key = self._key(queue)

django_lightweight_queue/task.py

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
1-
from typing import Any, Generic, TypeVar, Callable, Optional
1+
from types import TracebackType
2+
from typing import (
3+
Any,
4+
cast,
5+
Dict,
6+
List,
7+
Type,
8+
Tuple,
9+
Generic,
10+
TypeVar,
11+
Callable,
12+
Optional,
13+
)
214

315
from . import app_settings
416
from .job import Job
@@ -83,6 +95,48 @@ def __call__(self, fn: TCallable) -> 'TaskWrapper[TCallable]':
8395
return TaskWrapper(fn, self.queue, self.timeout, self.sigkill_on_stop, self.atomic)
8496

8597

98+
class BulkEnqueueHelper(Generic[TCallable]):
99+
def __init__(
100+
self,
101+
task_wrapper: 'TaskWrapper[TCallable]',
102+
batch_size: int,
103+
queue_override: Optional[QueueName],
104+
) -> None:
105+
self._to_create: List[Job] = []
106+
self._task_wrapper = task_wrapper
107+
self.batch_size = batch_size
108+
self.queue_override = queue_override
109+
110+
def __enter__(self) -> TCallable:
111+
return cast(TCallable, self._create)
112+
113+
def __exit__(
114+
self,
115+
exc_type: Optional[Type[BaseException]],
116+
exc_val: Optional[BaseException],
117+
exc_tb: Optional[TracebackType],
118+
) -> None:
119+
self.flush()
120+
121+
def _create(self, *args: Any, **kwargs: Any) -> None:
122+
self._to_create.append(
123+
self._task_wrapper._build_job(args, kwargs),
124+
)
125+
if len(self._to_create) >= self.batch_size:
126+
self.flush()
127+
128+
def flush(self) -> None:
129+
if not self._to_create:
130+
return
131+
132+
self._task_wrapper._enqueue_job_instances(
133+
self._to_create,
134+
queue_override=self.queue_override,
135+
)
136+
137+
self._to_create = []
138+
139+
86140
class TaskWrapper(Generic[TCallable]):
87141
def __init__(
88142
self,
@@ -103,18 +157,55 @@ def __init__(
103157
def __repr__(self) -> str:
104158
return "<TaskWrapper: {}>".format(self.path)
105159

106-
def __call__(self, *args: Any, **kwargs: Any) -> None:
160+
def _build_job(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Job:
107161
# Allow us to override the default values dynamically
108162
timeout = kwargs.pop('django_lightweight_queue_timeout', self.timeout)
109163
sigkill_on_stop = kwargs.pop(
110164
'django_lightweight_queue_sigkill_on_stop',
111165
self.sigkill_on_stop,
112166
)
113167

114-
# Allow queue overrides, but you must ensure that this queue will exist
115-
queue = kwargs.pop('django_lightweight_queue_queue', self.queue)
116-
117168
job = Job(self.path, args, kwargs, timeout, sigkill_on_stop)
118169
job.validate()
119170

171+
return job
172+
173+
def _enqueue_job_instances(
174+
self,
175+
new_jobs: List[Job],
176+
queue_override: Optional[QueueName],
177+
) -> None:
178+
queue = queue_override if queue_override is not None else self.queue
179+
get_backend(queue).bulk_enqueue(new_jobs, queue)
180+
181+
def __call__(self, *args: Any, **kwargs: Any) -> None:
182+
job = self._build_job(args, kwargs)
183+
184+
# Allow queue overrides, but you must ensure that this queue will exist
185+
queue = kwargs.pop('django_lightweight_queue_queue', self.queue)
186+
120187
get_backend(queue).enqueue(job, queue)
188+
189+
def bulk_enqueue(
190+
self,
191+
batch_size: int = 1000,
192+
queue_override: Optional[QueueName] = None,
193+
) -> BulkEnqueueHelper[TCallable]:
194+
"""
195+
Enqueue jobs in bulk.
196+
197+
Use like:
198+
199+
with my_task.bulk_enqueue() as enqueue:
200+
enqueue(the_ids=[42, 43])
201+
enqueue(the_ids=[45, 46])
202+
203+
This is equivalent to:
204+
205+
my_task(the_ids=[42, 43])
206+
my_task(the_ids=[45, 46])
207+
208+
The target queue for the whole batch may be overridden, however the
209+
caller must ensure that the queue actually exists (i.e: has workers).
210+
"""
211+
return BulkEnqueueHelper(self, batch_size, queue_override)

tests/test_task.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import unittest
2+
import contextlib
3+
from typing import Any, Mapping, Iterator
4+
from unittest import mock
5+
6+
import fakeredis
7+
from django_lightweight_queue import task
8+
from django_lightweight_queue.types import QueueName, WorkerNumber
9+
from django_lightweight_queue.utils import get_path, get_backend
10+
from django_lightweight_queue.backends.redis import RedisBackend
11+
12+
from . import settings
13+
14+
QUEUE = QueueName('dummy-queue')
15+
16+
17+
@task(str(QUEUE))
18+
def dummy_task(num: int) -> None:
19+
pass
20+
21+
22+
class TaskTests(unittest.TestCase):
23+
longMessage = True
24+
prefix = settings.LIGHTWEIGHT_QUEUE_REDIS_PREFIX
25+
26+
@contextlib.contextmanager
27+
def mock_workers(self, workers: Mapping[str, int]) -> Iterator[None]:
28+
with unittest.mock.patch(
29+
'django_lightweight_queue.utils._accepting_implied_queues',
30+
new=False,
31+
), unittest.mock.patch.dict(
32+
'django_lightweight_queue.app_settings.WORKERS',
33+
workers,
34+
):
35+
yield
36+
37+
def setUp(self) -> None:
38+
super().setUp()
39+
40+
get_backend.cache_clear()
41+
42+
with mock.patch('redis.StrictRedis', fakeredis.FakeStrictRedis):
43+
self.backend = RedisBackend()
44+
45+
# Mock get_backend. Unfortunately, due to the naming of the 'task'
46+
# decorator class being the same as its containing module and it being
47+
# exposed as the symbol at django_lightweight_queue.task, we cannot mock
48+
# this in the normal way. Instead we mock get_path (which get_backend
49+
# calls) and intercept our dummy value.
50+
def mocked_get_path(path: str) -> Any:
51+
if path == 'test-backend':
52+
return lambda: self.backend
53+
return get_path(path)
54+
55+
patch = mock.patch(
56+
'django_lightweight_queue.app_settings.BACKEND',
57+
new='test-backend',
58+
)
59+
patch.start()
60+
self.addCleanup(patch.stop)
61+
patch = mock.patch(
62+
'django_lightweight_queue.utils.get_path',
63+
side_effect=mocked_get_path,
64+
)
65+
patch.start()
66+
self.addCleanup(patch.stop)
67+
68+
def tearDown(self) -> None:
69+
super().tearDown()
70+
get_backend.cache_clear()
71+
72+
def test_enqueues_job(self) -> None:
73+
self.assertEqual(0, self.backend.length(QUEUE))
74+
75+
dummy_task(42)
76+
77+
job = self.backend.dequeue(QUEUE, WorkerNumber(0), 5)
78+
# Plain assert to placate mypy
79+
assert job is not None, "Failed to get a job after enqueuing one"
80+
81+
self.assertEqual(
82+
{
83+
'path': 'tests.test_task.dummy_task',
84+
'args': [42],
85+
'kwargs': {},
86+
'timeout': None,
87+
'sigkill_on_stop': False,
88+
'created_time': mock.ANY,
89+
},
90+
job.as_dict(),
91+
)
92+
93+
def test_bulk_enqueues_jobs(self) -> None:
94+
self.assertEqual(0, self.backend.length(QUEUE))
95+
96+
with dummy_task.bulk_enqueue() as enqueue:
97+
enqueue(13)
98+
enqueue(num=42)
99+
100+
job = self.backend.dequeue(QUEUE, WorkerNumber(0), 5)
101+
# Plain assert to placate mypy
102+
assert job is not None, "Failed to get a job after enqueuing one"
103+
104+
self.assertEqual(
105+
{
106+
'path': 'tests.test_task.dummy_task',
107+
'args': [13],
108+
'kwargs': {},
109+
'timeout': None,
110+
'sigkill_on_stop': False,
111+
'created_time': mock.ANY,
112+
},
113+
job.as_dict(),
114+
"First job",
115+
)
116+
117+
job = self.backend.dequeue(QUEUE, WorkerNumber(0), 5)
118+
# Plain assert to placate mypy
119+
assert job is not None, "Failed to get a job after enqueuing one"
120+
121+
self.assertEqual(
122+
{
123+
'path': 'tests.test_task.dummy_task',
124+
'args': [],
125+
'kwargs': {'num': 42},
126+
'timeout': None,
127+
'sigkill_on_stop': False,
128+
'created_time': mock.ANY,
129+
},
130+
job.as_dict(),
131+
"Second job",
132+
)
133+
134+
def test_bulk_enqueues_jobs_batch_size_boundary(self) -> None:
135+
self.assertEqual(0, self.backend.length(QUEUE), "Should initially be empty")
136+
137+
with dummy_task.bulk_enqueue(batch_size=3) as enqueue:
138+
enqueue(1)
139+
enqueue(2)
140+
enqueue(3)
141+
enqueue(4)
142+
143+
jobs = [
144+
self.backend.dequeue(QUEUE, WorkerNumber(0), 5)
145+
for _ in range(4)
146+
]
147+
148+
self.assertEqual(0, self.backend.length(QUEUE), "Should be empty after dequeuing all jobs")
149+
150+
args = [x.args for x in jobs if x is not None]
151+
152+
self.assertEqual(
153+
[[1], [2], [3], [4]],
154+
args,
155+
"Wrong jobs bulk enqueued",
156+
)

0 commit comments

Comments
 (0)