22import math
33import time
44from dataclasses import dataclass
5- from typing import AsyncGenerator , Literal , Optional , Union , get_args
5+ from typing import AsyncGenerator , List , Literal , Optional , Tuple , Union , get_args
66
77from loguru import logger
88
@@ -35,13 +35,18 @@ class SchedulerResult:
3535 :type benchmark: TextGenerationBenchmark
3636 :param current_result: The result of the current request, if any.
3737 :type current_result: Optional[Union[TextGenerationResult, TextGenerationError]]
38+ :param batch_results: The results of the current batch of requests, if any.
39+ :type batch_results: Optional[List[Union[TextGenerationResult, TextGenerationError]]]
3840 """
3941
4042 completed : bool
4143 count_total : int
4244 count_completed : int
4345 benchmark : TextGenerationBenchmark
4446 current_result : Optional [Union [TextGenerationResult , TextGenerationError ]] = None
47+ batch_results : Optional [List [Union [TextGenerationResult , TextGenerationError ]]] = (
48+ None
49+ )
4550
4651
4752class Scheduler :
@@ -74,6 +79,7 @@ def __init__(
7479 rate : Optional [float ] = None ,
7580 max_number : Optional [int ] = None ,
7681 max_duration : Optional [float ] = None ,
82+ batch_size : Optional [int ] = None ,
7783 ):
7884 logger .info (
7985 "Scheduler initialized with params: generator={}, worker={}, mode={}, "
@@ -114,13 +120,20 @@ def __init__(
114120 logger .error (err )
115121 raise err
116122
123+ if batch_size and batch_size <= 0 :
124+ err = ValueError (f"batch_size must be > 0, given: { batch_size } " )
125+ logger .error (err )
126+ raise err
127+
117128 self ._generator = generator
118129 self ._worker = worker
119130 self ._mode = mode
120131 self ._rate = rate
121132 self ._max_number = max_number
122133 self ._max_duration = max_duration
123134
135+ self ._batch_size = batch_size
136+
124137 self ._load_generator = LoadGenerator (mode , rate )
125138
126139 @property
@@ -209,6 +222,20 @@ def benchmark_mode(self) -> Literal["asynchronous", "synchronous", "throughput"]
209222
210223 return "asynchronous"
211224
225+ @property
226+ def batch_size (self ) -> Optional [int ]:
227+ """
228+ Returns the maximum number of requests to generate.
229+
230+ :return: Maximum number of requests or None.
231+ :rtype: Optional[int]
232+ """
233+
234+ return self ._batch_size
235+
    async def _run_batch(self):
        # NOTE(review): dead placeholder — a later definition of _run_batch in
        # this class (the full batching implementation) rebinds the name, so
        # this stub never runs. It should be removed.
        pass
238+
212239 async def run (self ) -> AsyncGenerator [SchedulerResult , None ]:
213240 """
214241 Run the scheduler to process requests based on the configured mode, rate,
@@ -223,15 +250,17 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
223250 start_time = time .time ()
224251 end_time = start_time + self .max_duration if self .max_duration else math .inf
225252 max_number = float (self .max_number ) if self .max_number else math .inf
226- runner = self ._run_sync if self ._mode == "synchronous" else self ._run_async
227253 count_total = (
228254 self .max_number
229255 if self .max_number
230- else round (self .max_duration )
231- if self .max_duration
232- else 0
256+ else round (self .max_duration ) if self .max_duration else 0
233257 )
234258
259+ if self .batch_size :
260+ runner = self ._run_batch
261+ else :
262+ runner = self ._run_sync if self ._mode == "synchronous" else self ._run_async
263+
235264 # yield initial result for progress tracking
236265 yield SchedulerResult (
237266 completed = False ,
@@ -243,21 +272,30 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
243272 run_count = 0
244273 async for res in runner (benchmark , end_time , max_number ):
245274 run_count += 1
275+
246276 count_completed = (
247277 min (run_count , self .max_number )
248278 if self .max_number
249- else round (time .time () - start_time )
250- if self .max_duration
251- else 0
279+ else round (time .time () - start_time ) if self .max_duration else 0
252280 )
253281
254- yield SchedulerResult (
255- completed = False ,
256- count_total = count_total ,
257- count_completed = count_completed ,
258- benchmark = benchmark ,
259- current_result = res ,
260- )
282+ if self .batch_size :
283+
284+ yield SchedulerResult (
285+ completed = False ,
286+ count_total = count_total ,
287+ count_completed = count_completed ,
288+ benchmark = benchmark ,
289+ batch_results = res ,
290+ )
291+ else :
292+ yield SchedulerResult (
293+ completed = False ,
294+ count_total = count_total ,
295+ count_completed = count_completed ,
296+ benchmark = benchmark ,
297+ current_result = res ,
298+ )
261299
262300 logger .info ("Scheduler run completed" )
263301
@@ -267,9 +305,7 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
267305 count_completed = (
268306 benchmark .request_count + benchmark .error_count
269307 if self .max_number
270- else round (time .time () - start_time )
271- if self .max_duration
272- else 0
308+ else round (time .time () - start_time ) if self .max_duration else 0
273309 ),
274310 benchmark = benchmark ,
275311 )
@@ -372,3 +408,81 @@ async def _submit_task_coroutine(
372408 logger .warning ("Request {} failed: {}" , request , exc )
373409
374410 return TextGenerationError (request = request , message = str (exc ))
411+
412+ async def _run_batch (
413+ self , benchmark : TextGenerationBenchmark , end_time : float , max_number : float
414+ ) -> AsyncGenerator [SchedulerResult , None ]:
415+
416+ if self .batch_size is None :
417+ raise ValueError ("--batch-size CLI parameter is not set" )
418+
419+ batch = []
420+ count_completed = 0
421+
422+ for request , submit_at in zip (self .generator , self .load_generator .times ()):
423+ if time .time () >= end_time or count_completed >= max_number :
424+ break
425+
426+ if len (batch ) < self .batch_size :
427+ batch .append ((request , submit_at ))
428+
429+ if len (batch ) >= self .batch_size :
430+ results = await self ._process_batch (batch , benchmark , end_time )
431+ count_completed += len (
432+ [r for r in results if not isinstance (r , TextGenerationError )]
433+ )
434+
435+ yield results
436+ batch = []
437+
438+ if batch :
439+ results = await self ._process_batch (batch , benchmark , end_time )
440+ count_completed += len (
441+ [r for r in results if not isinstance (r , TextGenerationError )]
442+ )
443+ yield results
444+
445+ async def _process_batch (
446+ self ,
447+ batch : List [Tuple [TextGenerationRequest , float ]],
448+ benchmark : TextGenerationBenchmark ,
449+ end_time : float ,
450+ ) -> List [Union [TextGenerationResult , TextGenerationError ]]:
451+ try :
452+
453+ benchmark .request_started ()
454+ tasks = [
455+ self ._delayed_submit (request , submit_at , end_time )
456+ for request , submit_at in batch
457+ ]
458+
459+ timeout = end_time - time .time () if end_time < math .inf else None
460+
461+ results = await asyncio .wait_for (
462+ asyncio .gather (* tasks , return_exceptions = True ), timeout = timeout
463+ )
464+ processed_results = []
465+ for (req , _ ), result in zip (batch , results ):
466+ if isinstance (result , Exception ):
467+ error = TextGenerationError (request = req , message = str (result ))
468+ benchmark .request_completed (error )
469+ processed_results .append (error )
470+ else :
471+ benchmark .request_completed (result )
472+ processed_results .append (result )
473+ return processed_results
474+ except asyncio .TimeoutError :
475+ return [
476+ TextGenerationError (request = req , message = "Batch timeout" )
477+ for req , _ in batch
478+ ]
479+
480+ async def _delayed_submit (
481+ self , request : TextGenerationRequest , submit_at : float , end_time : float
482+ ) -> Union [TextGenerationResult , TextGenerationError ]:
483+ if submit_at > time .time ():
484+ await asyncio .sleep (submit_at - time .time ())
485+ if time .time () >= end_time :
486+ raise asyncio .TimeoutError ("Submission time exceeded end_time" )
487+
488+ return await self ._worker .submit (request )
0 commit comments