
Commit 9d4d2e8

Prep scheduler for multiturn
Revert "Revert loop logic changes" This reverts commit bcc2f8c. Revert "Strip out multiturn features" This reverts commit a524469.
1 parent 89b501f commit 9d4d2e8

File tree: 6 files changed, +122 −29 lines


src/guidellm/request/__init__.py
Lines changed: 3 additions & 0 deletions

@@ -5,14 +5,17 @@
     RequestLoaderDescription,
 )
 from .request import GenerationRequest
+from .session import GenerativeRequestSession, RequestSession
 from .types import RequestT, ResponseT

 __all__ = [
     "GenerationRequest",
     "GenerativeRequestLoader",
     "GenerativeRequestLoaderDescription",
+    "GenerativeRequestSession",
     "RequestLoader",
     "RequestLoaderDescription",
+    "RequestSession",
     "RequestT",
     "ResponseT",
 ]

src/guidellm/request/loader.py
Lines changed: 3 additions & 2 deletions

@@ -15,6 +15,7 @@
 from guidellm.dataset import ColumnInputTypes, load_dataset
 from guidellm.objects import StandardBaseModel
 from guidellm.request.request import GenerationRequest
+from guidellm.request.session import GenerativeRequestSession

 __all__ = [
     "GenerativeRequestLoader",
@@ -105,14 +106,14 @@ def __init__(
         self.preserve_iter_state = iter_type == "infinite"  # ensure no caching requests
         self._preserved_iter = None

-    def __iter__(self) -> Iterator[GenerationRequest]:
+    def __iter__(self) -> Iterator[GenerativeRequestSession]:
         scope_create_count = 0

         while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None:
             scope_create_count += 1

             for item in dataset_iter:
-                yield self._create_request(item)
+                yield GenerativeRequestSession(self._create_request(item))

         self._preserved_iter = None

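With this change, iterating the loader yields GenerativeRequestSession objects rather than bare GenerationRequests. A minimal sketch of how a consumer could drain each session; `loader` and `send` are placeholders for an existing loader instance and a backend call, not part of this commit:

# Minimal sketch (not from this commit): draining the sessions the loader now yields.
# `loader` is an existing GenerativeRequestLoader; `send` is a hypothetical backend call.
for session in loader:
    while not session.complete:
        request = session.get_next_request()
        response = send(request)          # hypothetical backend call
        session.push_response(response)   # a single-turn session completes after one response
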
src/guidellm/request/session.py
Lines changed: 55 additions & 0 deletions (new file)

@@ -0,0 +1,55 @@
+from abc import ABC, abstractmethod
+from typing import Generic, TypeVar
+
+from guidellm.backend.response import ResponseSummary
+from guidellm.request.request import GenerationRequest
+
+__all__ = ["GenerativeRequestSession", "RequestSession"]
+
+RequestT = TypeVar("RequestT")
+ResponseT = TypeVar("ResponseT")
+
+
+class RequestSession(ABC, Generic[RequestT, ResponseT]):
+    """
+    A series of requests that build upon each other to
+    form a conversation between the user and the model.
+    """
+
+    @abstractmethod
+    def __len__(self) -> int: ...
+
+    @abstractmethod
+    def get_next_request(self) -> RequestT: ...
+
+    @abstractmethod
+    def get_next_delay(self) -> float: ...
+
+    @abstractmethod
+    def push_response(self, response: ResponseT) -> None: ...
+
+    @property
+    @abstractmethod
+    def complete(self) -> bool: ...
+
+
+class GenerativeRequestSession(RequestSession[GenerationRequest, ResponseSummary]):
+    def __init__(self, request: GenerationRequest) -> None:
+        self.request = request
+        self._complete = False
+
+    def __len__(self) -> int:
+        return 1
+
+    def get_next_request(self) -> GenerationRequest:
+        return self.request
+
+    def get_next_delay(self) -> float:
+        return 0.0
+
+    def push_response(self, response: ResponseSummary) -> None:  # noqa: ARG002
+        self._complete = True
+
+    @property
+    def complete(self) -> bool:
+        return self._complete

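The RequestSession interface above is what makes multiturn possible; this commit ships only the single-turn GenerativeRequestSession. The class below is a hypothetical sketch of how a multi-turn session could implement the same interface. The name MultiTurnRequestSession, its constructor, and the per-turn delay are assumptions, not part of the commit:

# Hypothetical sketch only (not in this commit): a multi-turn session built on
# the RequestSession ABC above.
from typing import Sequence

from guidellm.backend.response import ResponseSummary
from guidellm.request.request import GenerationRequest
from guidellm.request.session import RequestSession


class MultiTurnRequestSession(RequestSession[GenerationRequest, ResponseSummary]):
    def __init__(self, turns: Sequence[GenerationRequest], delay: float = 0.0) -> None:
        self._turns = list(turns)   # one pre-built request per conversation turn
        self._responses: list[ResponseSummary] = []
        self._delay = delay         # pause between turns, surfaced via get_next_delay

    def __len__(self) -> int:
        return len(self._turns)

    def get_next_request(self) -> GenerationRequest:
        # Next unanswered turn; a real implementation would fold prior
        # responses back into the prompt before returning it.
        return self._turns[len(self._responses)]

    def get_next_delay(self) -> float:
        return self._delay

    def push_response(self, response: ResponseSummary) -> None:
        self._responses.append(response)

    @property
    def complete(self) -> bool:
        return len(self._responses) >= len(self._turns)
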
src/guidellm/scheduler/result.py
Lines changed: 2 additions & 1 deletion

@@ -6,6 +6,7 @@
 )

 from guidellm.objects import StandardBaseModel
+from guidellm.request.session import RequestSession
 from guidellm.request.types import RequestT, ResponseT
 from guidellm.scheduler.strategy import SchedulingStrategy

@@ -142,7 +143,7 @@ class SchedulerRequestResult(

 @dataclass
 class WorkerProcessRequest(Generic[RequestT, ResponseT]):
-    request: RequestT
+    session: RequestSession[RequestT, ResponseT]
     timeout_time: float
     queued_time: float

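WorkerProcessRequest now carries the whole session instead of a single request. A brief sketch of constructing one, assuming `request` is an existing GenerationRequest and `run_end_time` is the absolute deadline taken from the run info (both placeholders, not defined here):

# Sketch: the scheduler wraps a whole session, not one request.
import time

from guidellm.request.session import GenerativeRequestSession
from guidellm.scheduler.result import WorkerProcessRequest

work_req = WorkerProcessRequest(
    session=GenerativeRequestSession(request),  # wrap the single request in a session
    timeout_time=run_end_time,                  # absolute deadline for the whole session
    queued_time=time.time(),
)
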
src/guidellm/scheduler/scheduler.py
Lines changed: 15 additions & 21 deletions

@@ -130,17 +130,15 @@ async def run(
             futures, queues, stop_event = await self._start_processes(
                 manager, executor, scheduling_strategy
             )
-            run_info, requests_iter, times_iter = self._run_setup(
+            run_info, requests_iter = self._run_setup(
                 futures, scheduling_strategy, max_number, max_duration
             )

             # Add some initial requests to the queue
             requests_iter = self._add_requests(
                 requests_iter,
                 queues.requests,
-                times_iter,
                 run_info,
-                loop_limit=run_info.strategy.queued_requests_limit,
             )
             # Wait for the test to start
             await asyncio.sleep(time.time() - scheduling_strategy.start_time)
@@ -171,7 +169,6 @@ async def run(
                 requests_iter = self._add_requests(
                     requests_iter,
                     queues.requests,
-                    times_iter,
                     run_info,
                 )
                 await asyncio.sleep(0)  # enable requests to start
@@ -244,6 +241,7 @@ async def _start_processes(
                     queues,
                     scheduling_strategy,
                     stop_event,
+                    False,  # TODO: Make configurable
                     requests_limit,
                     id_,
                     num_processes,
@@ -260,9 +258,8 @@ def _run_setup(
         scheduling_strategy: SchedulingStrategy,
         max_number: Optional[int],
         max_duration: Optional[float],
-    ) -> tuple[SchedulerRunInfo, Iterator[Any], Iterator[float]]:
+    ) -> tuple[SchedulerRunInfo, Iterator[Any]]:
         requests_iter = iter(self.request_loader)
-        times_iter = iter(scheduling_strategy.request_times())
         end_time = scheduling_strategy.start_time + (max_duration or math.inf)
         end_number = max_number or math.inf

@@ -288,42 +285,39 @@ def _run_setup(
             strategy=scheduling_strategy,
         )

-        return info, requests_iter, times_iter
+        return info, requests_iter

     def _add_requests(
         self,
         requests_iter: Optional[Iterator[Any]],
         requests_queue: Queue[WorkerProcessRequest[RequestT, ResponseT]],
-        times_iter: Iterator[float],
         run_info: SchedulerRunInfo,
-        loop_limit: Optional[int] = None,
     ) -> Optional[Iterator[Any]]:
         if requests_iter is not None:
             try:
                 added_count = 0

+                if time.time() >= run_info.end_time:
+                    raise StopIteration
+
                 while not requests_queue.full() and added_count < (
-                    loop_limit or settings.max_add_requests_per_loop
+                    run_info.strategy.queued_requests_limit
+                    or settings.max_add_requests_per_loop
                 ):
                     if run_info.created_requests >= run_info.end_number:
                         raise StopIteration

-                    if (
-                        next(times_iter) >= run_info.end_time
-                        or time.time() >= run_info.end_time
-                    ):
-                        raise StopIteration
-
-                    work_req = WorkerProcessRequest[RequestT, ResponseT](
-                        request=next(requests_iter),
+                    session = next(requests_iter)
+                    work_req = WorkerProcessRequest(
+                        session=session,
                         timeout_time=run_info.end_time,
                         queued_time=time.time(),
                     )
                     requests_queue.put(work_req)

-                    run_info.created_requests += 1
-                    run_info.queued_requests += 1
-                    added_count += 1
+                    run_info.created_requests += len(session)
+                    run_info.queued_requests += len(session)
+                    added_count += len(session)
             except StopIteration:
                 # we've reached the limit number, limit time, or exhausted the requests
                 # set to None to stop adding more and tell the loop no more requests

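One consequence of the new accounting in _add_requests: the request budget is charged per request rather than per queued work item, so a session with several turns consumes several slots of the max_number budget at enqueue time. A small worked example with hypothetical numbers (nothing below is guidellm code):

# Hypothetical numbers: max_number=10 and every session has three turns.
# The loop mirrors run_info.created_requests += len(session) above.
max_number = 10
created_requests = 0
queued_sessions = 0
while created_requests < max_number:
    session_len = 3                  # len(session) for a three-turn session
    created_requests += session_len
    queued_sessions += 1
print(queued_sessions, created_requests)  # -> 4 12: the budget can be slightly overshot
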
src/guidellm/scheduler/worker.py
Lines changed: 44 additions & 5 deletions

@@ -122,8 +122,8 @@ async def resolve_scheduler_request(
         start_time: float,
         results_queue: Queue[WorkerProcessResult[RequestT, ResponseT]],
         process_id: int,
-    ):
-        request = process_request.request
+    ) -> WorkerProcessRequest[RequestT, ResponseT]:
+        request = process_request.session.get_next_request()
         timeout_time = process_request.timeout_time
         queued_time = process_request.queued_time

@@ -170,17 +170,22 @@ async def resolve_scheduler_request(
         )
         asyncio.create_task(self.send_result(results_queue, result))

+        process_request.session.push_response(response)
+        return process_request
+
     def process_loop_asynchronous(
         self,
         queues: MPQueues[RequestT, ResponseT],
         strategy: SchedulingStrategy,
         stop_event: Event,
+        prioritize_sessions: bool,
         max_concurrency: int,
         process_id: int,
         num_processes: int,
     ):
         async def _process_runner():
             lock = asyncio.Semaphore(max_concurrency)
+            pending_requests: list[WorkerProcessRequest[RequestT, ResponseT]] = []
             times_iter = islice(
                 strategy.request_times(),
                 process_id,
@@ -197,18 +202,50 @@ async def _process_runner():
                 await asyncio.sleep(start_time - time.time() - 1)
                 await lock.acquire()

+                process_request = None
                 try:
-                    process_request = queues.requests.get_nowait()
+                    process_request = (
+                        pending_requests.pop()
+                        if pending_requests
+                        else queues.requests.get_nowait()
+                    )
                     dequeued_time = time.time()
                 except QueueEmpty:
                     lock.release()
                     continue

+                async def wait_then_requeue(
+                    process_request: WorkerProcessRequest[RequestT, ResponseT],
+                ):
+                    # Wait to requeue the request session if it specifies a delay
+                    if delay := process_request.session.get_next_delay():
+                        await asyncio.sleep(delay)
+
+                    # Push session to the stack
+                    process_request.queued_time = time.time()
+                    pending_requests.append(process_request)
+                    if prioritize_sessions:
+                        # Release the lock with the session on top of the stack
+                        lock.release()
+
                 def _request_callback(
-                    _: asyncio.Future[WorkerProcessRequest[RequestT, ResponseT]],
+                    future: asyncio.Future[WorkerProcessRequest[RequestT, ResponseT]],
                 ):
+                    # If we are prioritizing sessions, hold
+                    # the lock until the session is done
                     nonlocal lock
-                    lock.release()
+                    if not prioritize_sessions:
+                        lock.release()
+
+                    try:
+                        process_request = future.result()
+                    except asyncio.CancelledError:
+                        return
+                    if not process_request.session.complete:
+                        asyncio.create_task(wait_then_requeue(process_request))
+                    elif prioritize_sessions:
+                        # no more requests in this session, release the lock
+                        lock.release()

                 task = asyncio.create_task(
                     self.resolve_scheduler_request(
@@ -282,6 +319,7 @@ def process_loop_asynchronous(
         queues: MPQueues[GenerationRequest, ResponseSummary],
         strategy: SchedulingStrategy,
         stop_event: Event,
+        prioritize_sessions: bool,
         max_concurrency: int,
         process_id: int,
         num_processes: int,
@@ -291,6 +329,7 @@ def process_loop_asynchronous(
             queues=queues,
             strategy=strategy,
             stop_event=stop_event,
+            prioritize_sessions=prioritize_sessions,
             max_concurrency=max_concurrency,
             process_id=process_id,
             num_processes=num_processes,

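The prioritize_sessions flow is the subtle part of this diff: when enabled, the semaphore slot acquired for a session is held until session.complete, and follow-up turns are requeued onto the worker's local pending_requests stack so they run next on the same slot. The toy below is a standalone illustration of that hold-the-slot idea under simplified assumptions; it is not the worker's actual control flow:

# Toy sketch (not guidellm code): with prioritization on, a session keeps its
# concurrency slot across turns; with it off, other sessions can interleave.
import asyncio


async def run_session(name: str, turns: int, prioritize: bool, slots: asyncio.Semaphore):
    await slots.acquire()  # one concurrency slot, as the worker acquires before dispatching
    try:
        for turn in range(turns):
            await asyncio.sleep(0.01)  # stand-in for resolving one request/turn
            print(f"{name} finished turn {turn}")
            if not prioritize and turn < turns - 1:
                # Not prioritized: give the slot back between turns so other
                # sessions' requests can run, then reacquire for the next turn.
                slots.release()
                await slots.acquire()
    finally:
        slots.release()


async def main():
    slots = asyncio.Semaphore(1)
    # With prioritize=True, session "a" finishes both turns before "b" starts.
    await asyncio.gather(
        run_session("a", 2, prioritize=True, slots=slots),
        run_session("b", 2, prioritize=True, slots=slots),
    )


asyncio.run(main())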