Skip to content

Commit 73dbf14

Browse files
committed
Refactor RequestGenerator to use threading and update test suite
Refactored RequestGenerator class: - Replaced asyncio.Queue with Queue from the queue module for thread safety. - Utilized threading for background queue population to ensure non-blocking request generation. - Removed the start method, since the constructor now starts the background thread when running in async mode. - Ensured that the _populate_queue method runs in a background thread to keep the queue populated. - Implemented clean shutdown with the stop method joining the background thread. Updated unit tests: - Added test_request_generator_sync_constructor and test_request_generator_async_constructor to verify constructor behavior. - Added tests for __repr__ and __iter__ methods. - Added tests to ensure create_item raises NotImplementedError if not overridden. - Added tests to verify __iter__ calls create_item the expected number of times. Separated test files: - Created tests/unit/request/test_base.py for unit tests. - Created tests/integration/request/test_base.py for integration tests. Unit tests: - Verified the construction of the class with different input parameters. - Mocked AutoTokenizer for testing tokenizer initialization with both a class implementation and a string alias. - Ensured that the __iter__ method works correctly in both sync and async modes. - Verified that create_item is called the expected number of times. Integration tests: - Tested tokenizer construction with both a Hugging Face tokenizer and a string alias, ensuring the correct tokenizer is created.
1 parent 2287517 commit 73dbf14

File tree

4 files changed

+138
-50
lines changed

4 files changed

+138
-50
lines changed

src/guidellm/request/base.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
import asyncio
1+
import threading
2+
import time
23
from abc import ABC, abstractmethod
4+
from queue import Empty, Full, Queue
35
from typing import Iterator, Optional, Union
46

57
from loguru import logger
@@ -31,20 +33,29 @@ def __init__(
3133
):
3234
self._async_queue_size = async_queue_size
3335
self._mode = mode
34-
self._queue = asyncio.Queue(maxsize=async_queue_size)
35-
self._stop_event = asyncio.Event()
36+
self._queue = Queue(maxsize=async_queue_size)
37+
self._stop_event = threading.Event()
3638

3739
if tokenizer is not None:
3840
self._tokenizer = (
3941
AutoTokenizer.from_pretrained(tokenizer)
4042
if isinstance(tokenizer, str)
4143
else tokenizer
4244
)
43-
logger.info(f"Tokenizer initialized: {self._tokenizer}")
45+
logger.info("Tokenizer initialized: {}", self._tokenizer)
4446
else:
4547
self._tokenizer = None
4648
logger.debug("No tokenizer provided")
4749

50+
if self._mode == "async":
51+
self._thread = threading.Thread(target=self._populate_queue)
52+
self._thread.daemon = True
53+
self._thread.start()
54+
logger.info(
55+
"RequestGenerator started in async mode with queue size: {}",
56+
self._async_queue_size,
57+
)
58+
4859
def __repr__(self) -> str:
4960
"""
5061
Return a string representation of the RequestGenerator.
@@ -72,7 +83,7 @@ def __iter__(self) -> Iterator[TextGenerationRequest]:
7283
item = self._queue.get_nowait()
7384
self._queue.task_done()
7485
yield item
75-
except asyncio.QueueEmpty:
86+
except Empty:
7687
continue
7788
else:
7889
while not self._stop_event.is_set():
@@ -118,46 +129,31 @@ def create_item(self) -> TextGenerationRequest:
118129
"""
119130
raise NotImplementedError()
120131

121-
def start(self):
122-
"""
123-
Start the background task that populates the queue.
124-
"""
125-
if self.mode == "async":
126-
try:
127-
loop = asyncio.get_running_loop()
128-
logger.info("Using existing event loop")
129-
except RuntimeError:
130-
raise RuntimeError("No running event loop found for async mode")
131-
132-
loop.call_soon_threadsafe(
133-
lambda: asyncio.create_task(self._populate_queue())
134-
)
135-
logger.info(
136-
f"RequestGenerator started in async mode with queue size: "
137-
f"{self._async_queue_size}"
138-
)
139-
else:
140-
logger.info("RequestGenerator started in sync mode")
141-
142132
def stop(self):
    """
    Signal the generator to stop and wait for the background worker.

    Sets the stop event; in async mode this also blocks until the
    queue-populating thread has exited, guaranteeing a clean shutdown.
    """
    logger.info("Stopping RequestGenerator...")
    self._stop_event.set()
    # Only async mode owns a background thread that needs joining.
    if self._mode == "async":
        self._thread.join()
    logger.info("RequestGenerator stopped")
149141

150-
def _populate_queue(self):
    """
    Populate the request queue in the background.

    Runs in the daemon worker thread started by __init__ when mode is
    "async". Repeatedly creates items via create_item() and pushes them
    onto the bounded queue until the stop event is set.

    BUG FIX: the previous version created an item and then discarded it
    whenever ``put(item, timeout=0.1)`` raised ``queue.Full`` (a consumer
    race could fill the queue between the qsize check and the put). The
    pending item is now carried over and retried so no request is lost.
    """
    pending = None  # item created but not yet placed on the queue
    while not self._stop_event.is_set():
        try:
            if pending is None and self._queue.qsize() < self._async_queue_size:
                pending = self.create_item()
            if pending is not None:
                self._queue.put(pending, timeout=0.1)
                pending = None
                logger.debug(
                    "Item added to queue. Current queue size: {}",
                    self._queue.qsize(),
                )
            else:
                # Queue appears full; back off briefly before rechecking.
                time.sleep(0.1)
        except Full:
            # Put timed out; keep the pending item and retry next loop.
            continue
    logger.info("RequestGenerator stopped populating queue")
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from guidellm.core.request import TextGenerationRequest
from guidellm.request.base import RequestGenerator


class TestRequestGenerator(RequestGenerator):
    """Minimal concrete generator used to exercise the base class."""

    def create_item(self) -> TextGenerationRequest:
        return TextGenerationRequest(prompt="Test prompt")


@pytest.mark.smoke
def test_request_generator_with_hf_tokenizer():
    # A tokenizer instance passed in should be stored as-is.
    hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    generator = TestRequestGenerator(tokenizer=hf_tokenizer)
    assert generator.tokenizer == hf_tokenizer


@pytest.mark.smoke
def test_request_generator_with_string_tokenizer():
    # A string alias should be resolved via AutoTokenizer.from_pretrained.
    generator = TestRequestGenerator(tokenizer="bert-base-uncased")
    assert isinstance(generator.tokenizer, PreTrainedTokenizerBase)
    assert generator.tokenizer.name_or_path == "bert-base-uncased"

tests/unit/request/test_base.py

Lines changed: 84 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import Mock, patch
2+
13
import pytest
24

35
from guidellm.core.request import TextGenerationRequest
@@ -10,15 +12,28 @@ def create_item(self) -> TextGenerationRequest:
1012

1113

1214
@pytest.mark.smoke
def test_request_generator_sync_constructor():
    """Sync construction applies defaults and stores no tokenizer."""
    generator = TestRequestGenerator(mode="sync")
    assert generator.mode == "sync"
    # 50 is the constructor's default async_queue_size.
    assert generator.async_queue_size == 50
    assert generator.tokenizer is None
1720

21+
22+
@pytest.mark.smoke
def test_request_generator_async_constructor():
    """Async construction honors an explicit queue size.

    BUG FIX: stop() now runs in a finally block — previously a failing
    assertion skipped it and leaked the daemon worker thread started by
    the async constructor.
    """
    generator = TestRequestGenerator(mode="async", async_queue_size=10)
    try:
        assert generator.mode == "async"
        assert generator.async_queue_size == 10
        assert generator.tokenizer is None
    finally:
        generator.stop()
29+
30+
31+
@pytest.mark.smoke
32+
def test_request_generator_sync_iter():
33+
generator = TestRequestGenerator(mode="sync")
1834
items = []
1935
for item in generator:
2036
items.append(item)
21-
2237
if len(items) == 5:
2338
break
2439

@@ -27,28 +42,30 @@ def test_request_generator_sync():
2742

2843

2944
@pytest.mark.smoke
def test_request_generator_async_iter():
    """Iterating in async mode yields items produced by the worker thread."""
    generator = TestRequestGenerator(mode="async")
    collected = []
    iterator = iter(generator)
    while len(collected) < 5:
        collected.append(next(iterator))

    generator.stop()
    assert len(collected) == 5
    assert collected[0].prompt == "Test prompt"
51-
assert items[-1].prompt == "Test prompt"
56+
57+
58+
@pytest.mark.regression
def test_request_generator_with_mock_tokenizer():
    """Tokenizer instances pass through; string aliases resolve via AutoTokenizer."""
    mock_tokenizer = Mock()

    # An instance is stored unchanged.
    generator = TestRequestGenerator(tokenizer=mock_tokenizer)
    assert generator.tokenizer == mock_tokenizer

    # A string alias goes through AutoTokenizer.from_pretrained.
    with patch("guidellm.request.base.AutoTokenizer") as MockAutoTokenizer:
        MockAutoTokenizer.from_pretrained.return_value = mock_tokenizer
        generator = TestRequestGenerator(tokenizer="mock-tokenizer")
        assert generator.tokenizer == mock_tokenizer
        MockAutoTokenizer.from_pretrained.assert_called_with("mock-tokenizer")
5269

5370

5471
@pytest.mark.regression
@@ -57,3 +74,55 @@ def test_request_generator_repr():
5774
assert repr(generator) == (
5875
"RequestGenerator(mode=sync, async_queue_size=100, tokenizer=None)"
5976
)
77+
78+
79+
@pytest.mark.regression
def test_request_generator_create_item_not_implemented():
    """The abstract contract is enforced at both class and method level."""
    # A subclass that never defines create_item cannot be instantiated.
    with pytest.raises(TypeError):
        class IncompleteRequestGenerator(RequestGenerator):
            pass

        IncompleteRequestGenerator()

    # Delegating to the base implementation raises NotImplementedError.
    class IncompleteCreateItemGenerator(RequestGenerator):
        def create_item(self):
            super().create_item()

    generator = IncompleteCreateItemGenerator()
    with pytest.raises(NotImplementedError):
        generator.create_item()
94+
95+
96+
@pytest.mark.regression
def test_request_generator_iter_calls_create_item():
    """Sync __iter__ must delegate item production to create_item."""
    generator = TestRequestGenerator(mode="sync")
    generator.create_item = Mock(
        return_value=TextGenerationRequest(prompt="Mock prompt")
    )

    produced = []
    for request in generator:
        produced.append(request)
        if len(produced) == 5:
            break

    assert len(produced) == 5
    generator.create_item.assert_called()
111+
112+
113+
@pytest.mark.regression
def test_request_generator_async_iter_calls_create_item():
    """Async __iter__ must also delegate item production to create_item.

    BUG FIX: this test previously constructed the generator with
    mode="sync" (a copy-paste of the sync test), so the async code path
    was never exercised. It also assigned the Mock after construction,
    racing the already-running worker thread; patching create_item at
    class level BEFORE construction guarantees the worker uses the mock.
    """
    with patch.object(
        TestRequestGenerator,
        "create_item",
        return_value=TextGenerationRequest(prompt="Mock prompt"),
    ) as mock_create_item:
        generator = TestRequestGenerator(mode="async")
        items = []
        for item in generator:
            items.append(item)
            if len(items) == 5:
                break
        generator.stop()

    assert len(items) == 5
    mock_create_item.assert_called()
    assert items[0].prompt == "Mock prompt"

0 commit comments

Comments
 (0)