@@ -114,12 +114,12 @@ def __init__(
             logger.error(err)
             raise err
 
-        self._generator = generator
-        self._worker = worker
-        self._mode = mode
-        self._rate = rate
-        self._max_number = max_number
-        self._max_duration = max_duration
+        self._generator: RequestGenerator = generator
+        self._worker: Backend = worker
+        self._mode: LoadGenerationMode = mode
+        self._rate: Optional[float] = rate
+        self._max_number: Optional[int] = max_number
+        self._max_duration: Optional[float] = max_duration
 
         self._load_generator = LoadGenerator(mode, rate)
@@ -227,9 +227,7 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
         count_total = (
             self.max_number
             if self.max_number
-            else round(self.max_duration)
-            if self.max_duration
-            else 0
+            else round(self.max_duration) if self.max_duration else 0
         )
 
         # yield initial result for progress tracking
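This hunk, like the two that follow, only reflows a nested conditional expression onto one line; behavior is unchanged. As a quick sanity check, the collapsed expression behaves like this hypothetical standalone helper (not part of the diff):

    from typing import Optional

    def count_total(max_number: Optional[int], max_duration: Optional[float]) -> int:
        # Chained conditional expressions associate to the right, so this reads:
        # "max_number if set, else round(max_duration) if set, else 0".
        return max_number if max_number else round(max_duration) if max_duration else 0

    assert count_total(10, None) == 10    # bounded by request count
    assert count_total(None, 12.4) == 12  # bounded by duration (rounded)
    assert count_total(None, None) == 0   # no bound configured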
@@ -246,9 +244,7 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
         count_completed = (
             min(run_count, self.max_number)
             if self.max_number
-            else round(time.time() - start_time)
-            if self.max_duration
-            else 0
+            else round(time.time() - start_time) if self.max_duration else 0
         )
 
         yield SchedulerResult(
@@ -267,16 +263,16 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
             count_completed=(
                 benchmark.request_count + benchmark.error_count
                 if self.max_number
-                else round(time.time() - start_time)
-                if self.max_duration
-                else 0
+                else round(time.time() - start_time) if self.max_duration else 0
             ),
             benchmark=benchmark,
         )
 
     async def _run_sync(
         self, benchmark: TextGenerationBenchmark, end_time: float, max_number: float
     ) -> AsyncGenerator[Union[TextGenerationResult, TextGenerationError], None]:
+        """Runs only for "synchronous" mode."""
+
         for index, (request, submit_at) in enumerate(
             zip(self.generator, self.load_generator.times())
         ):
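The new one-line docstring marks _run_sync as the "synchronous"-mode path. The loop visible here pairs each request with a submit timestamp from self.load_generator.times(); roughly, the pacing pattern looks like the following sketch (the names, the blocking wait, and the stop conditions shown are illustrative assumptions; the real method awaits a coroutine per request):

    import time
    from typing import Callable, Iterable

    def run_sync_sketch(
        requests: Iterable[str],
        submit_times: Iterable[float],
        submit: Callable[[str], None],
        end_time: float,
    ) -> None:
        # Pair each request with its scheduled submit time and stop once either
        # the wall clock or the scheduled time reaches the deadline.
        for request, submit_at in zip(requests, submit_times):
            if time.time() >= end_time or submit_at >= end_time:
                break
            time.sleep(max(0.0, submit_at - time.time()))  # wait for the slot
            submit(request)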
@@ -298,42 +294,80 @@ async def _run_sync(
     async def _run_async(
         self, benchmark: TextGenerationBenchmark, end_time: float, max_number: float
     ) -> AsyncGenerator[Union[TextGenerationResult, TextGenerationError], None]:
+        """
+        Notes:
+            If the load generation mode is set to 'consistent', submission
+            timestamps are not generated; instead, as many requests as possible
+            are submitted to simulate concurrently interacting clients.
+        """
+
         tasks = []
         completed = 0
 
-        for index, (request, submit_at) in enumerate(
-            zip(self.generator, self.load_generator.times())
-        ):
-            while (index + 1 - completed) >= settings.max_concurrency:
-                await asyncio.sleep(0.1)
+        def _completed(_task: asyncio.Task) -> None:
+            nonlocal completed
+            completed += 1
+            _res = _task.result()
 
-            if index >= max_number or time.time() >= end_time or submit_at >= end_time:
-                break
+            if _res:
+                benchmark.request_completed(_res)
+                logger.debug("Request completed: {}", _res)
 
-            logger.debug(
-                "Running asynchronous request={} at submit_at={}",
-                request,
-                submit_at,
-            )
-
-            def _completed(_task: asyncio.Task) -> None:
-                nonlocal completed
-                completed += 1
-                _res = _task.result()
-
-                if _res:
-                    benchmark.request_completed(_res)
-                    logger.debug("Request completed: {}", _res)
+        if self.mode == "consistent":
+            if self.rate is None:
+                raise ValueError(
+                    "The rate must be specified in order to provide concurrent execution"
+                )
+            for index, request in enumerate(self.generator):
+                while (index + 1 - completed) >= settings.max_concurrency:
+                    await asyncio.sleep(0.1)
+
+                if index >= max_number or time.time() >= end_time:
+                    break
+
+                logger.debug(f"Running concurrently request={request}")
+
+                benchmark.request_started()
+
+                # submit `rate` concurrent tasks for this request, accumulating
+                # them in the shared `tasks` list so they are all awaited later
+                for _ in range(int(self.rate)):
+                    task: asyncio.Task = asyncio.create_task(
+                        self._submit_task_coroutine(  # submit the call with 'Backend'
+                            request=request, submit_at=0.0, end_time=end_time
+                        )
+                    )
+                    task.add_done_callback(_completed)
+                    tasks.append(task)
+        else:
+            for index, (request, submit_at) in enumerate(
+                zip(self.generator, self.load_generator.times())
+            ):
+                while (index + 1 - completed) >= settings.max_concurrency:
+                    await asyncio.sleep(0.1)
+
+                if (
+                    index >= max_number
+                    or time.time() >= end_time
+                    or submit_at >= end_time
+                ):
+                    break
+
+                logger.debug(
+                    "Running asynchronous request={} at submit_at={}",
+                    request,
+                    submit_at,
+                )
 
-            benchmark.request_started()
-            task = asyncio.create_task(
-                self._submit_task_coroutine(request, submit_at, end_time)
-            )
-            task.add_done_callback(_completed)
-            tasks.append(task)
+                benchmark.request_started()
+                task = asyncio.create_task(
+                    self._submit_task_coroutine(request, submit_at, end_time)
+                )
+                task.add_done_callback(_completed)
+                tasks.append(task)
 
-            # release control to the event loop for other tasks
-            await asyncio.sleep(0.001)
+                # release control to the event loop for other tasks
+                await asyncio.sleep(0.001)
 
         for compl_task in asyncio.as_completed(tasks):
             task_res = await compl_task
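The 'consistent' branch fans each request out into `rate` concurrent tasks and tracks completions through a done-callback so the `settings.max_concurrency` back-pressure check still works. A self-contained sketch of that pattern follows (all names are illustrative; only the structure mirrors the diff):

    import asyncio

    async def fan_out(requests: list[str], rate: int, max_concurrency: int) -> int:
        completed = 0
        tasks: list[asyncio.Task] = []

        def _completed(_task: asyncio.Task) -> None:
            nonlocal completed
            completed += 1  # count finished tasks for the back-pressure check

        async def fake_call(request: str) -> str:
            await asyncio.sleep(0.01)  # stand-in for a backend request
            return request

        for index, request in enumerate(requests):
            # mirror the diff's back-pressure check: pause while the gap between
            # the submitted request index and completed tasks is too large
            while (index + 1 - completed) >= max_concurrency:
                await asyncio.sleep(0.1)
            for _ in range(rate):  # `rate` concurrent calls per request
                task = asyncio.create_task(fake_call(request))
                task.add_done_callback(_completed)
                tasks.append(task)

        await asyncio.gather(*tasks)
        return completed

    # asyncio.run(fan_out(["a", "b", "c"], rate=4, max_concurrency=8))  # 3 * 4 = 12 tasks

Accumulating every created task in a single list is what allows the trailing `asyncio.as_completed(tasks)` loop to drain all of them regardless of which branch produced them.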