Skip to content

Commit 44ffb3e

Browse files
authored
Add test for Unbatchify util (#8076)
1 parent 6ee8cdc commit 44ffb3e

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

tests/utils/test_unbatchify.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import time
2+
from unittest.mock import MagicMock
3+
from concurrent.futures import Future
4+
5+
from dspy.utils.unbatchify import Unbatchify
6+
7+
8+
def simple_batch_processor(batch):
9+
"""A simple batch function that adds 1 to each item."""
10+
return [item + 1 for item in batch]
11+
12+
13+
def submit(self, input_item: any) -> Future:
14+
"""Submits an item for processing and returns a Future."""
15+
future = Future()
16+
self.input_queue.put((input_item, future))
17+
return future
18+
19+
20+
Unbatchify.submit = submit
21+
22+
23+
def test_unbatchify_batch_size_trigger():
24+
"""Test that the batch processes exactly when max_batch_size is reached."""
25+
batch_fn_mock = MagicMock(wraps=simple_batch_processor)
26+
unbatcher = Unbatchify(batch_fn=batch_fn_mock, max_batch_size=2, max_wait_time=5.0)
27+
28+
futures = []
29+
futures.append(unbatcher.submit(10))
30+
time.sleep(0.02)
31+
assert batch_fn_mock.call_count == 0
32+
33+
futures.append(unbatcher.submit(20))
34+
35+
results_1_2 = [f.result() for f in futures]
36+
assert batch_fn_mock.call_count == 1
37+
batch_fn_mock.assert_called_once_with([10, 20])
38+
assert results_1_2 == [11, 21]
39+
40+
futures_3_4 = []
41+
futures_3_4.append(unbatcher.submit(30))
42+
futures_3_4.append(unbatcher.submit(40))
43+
44+
results_3_4 = [f.result() for f in futures_3_4]
45+
time.sleep(0.01)
46+
assert batch_fn_mock.call_count == 2
47+
assert batch_fn_mock.call_args_list[1].args[0] == [30, 40]
48+
assert results_3_4 == [31, 41]
49+
50+
unbatcher.close()
51+
52+
53+
54+
def test_unbatchify_timeout_trigger():
55+
"""Test that the batch processes after max_wait_time."""
56+
batch_fn_mock = MagicMock(wraps=simple_batch_processor)
57+
wait_time = 0.15
58+
unbatcher = Unbatchify(batch_fn=batch_fn_mock, max_batch_size=5, max_wait_time=wait_time)
59+
60+
futures = []
61+
futures.append(unbatcher.submit(100))
62+
futures.append(unbatcher.submit(200))
63+
64+
time.sleep(wait_time / 2)
65+
assert batch_fn_mock.call_count == 0
66+
67+
results = [f.result() for f in futures]
68+
69+
assert batch_fn_mock.call_count == 1
70+
batch_fn_mock.assert_called_once_with([100, 200])
71+
assert results == [101, 201]
72+
73+
unbatcher.close()
74+
75+

0 commit comments

Comments
 (0)