Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/guidellm/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,26 @@ def __init__(
rate: Optional[Union[float, Sequence[float]]] = None,
max_number: Optional[int] = None,
max_duration: Optional[float] = None,
workers: int = 1,
):
self._backend = backend
self._generator = request_generator
self._max_number = max_number
self._max_duration = max_duration
self._profile_generator = ProfileGenerator(mode=mode, rate=rate)
self._workers = workers
logger.info("Executor initialized with mode: {}, rate: {}", mode, rate)

@property
def workers(self) -> int:
    """
    Number of concurrent worker tasks this executor is configured to run.

    :return: the configured worker (async task) count
    :rtype: int
    """
    # Simple read-only view over the value captured at construction time.
    worker_count = self._workers
    return worker_count

@property
def backend(self) -> Backend:
"""
Expand Down Expand Up @@ -154,8 +166,9 @@ async def run(self) -> AsyncGenerator[ExecutorResult, None]:
# limits args
"max_number": self.max_number,
"max_duration": self.max_duration,
"workers": self.workers,
}
profile_index = -1
self.profile_index = -1
logger.info("Starting Executor run")

yield ExecutorResult(
Expand All @@ -175,8 +188,9 @@ async def run(self) -> AsyncGenerator[ExecutorResult, None]:
rate=profile.load_gen_rate,
max_number=self.max_number or profile.args.get("max_number", None),
max_duration=self.max_duration,
concurrent_tasks=self.workers,
)
profile_index += 1
self.profile_index += 1

logger.info(
"Scheduling tasks with mode: {}, rate: {}",
Expand All @@ -199,7 +213,7 @@ async def run(self) -> AsyncGenerator[ExecutorResult, None]:
generation_modes=self.profile_generator.profile_generation_modes,
report=report,
scheduler_result=scheduler_result,
current_index=profile_index,
current_index=self.profile_index,
current_profile=profile,
)

Expand Down
11 changes: 11 additions & 0 deletions src/guidellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@
"the server's performance to stabilize."
),
)
@click.option(
"--workers",
type=int,
default=1,
help="The maximum number of concurrent workers to be created.",
)
@click.option(
"--output-path",
type=str,
Expand Down Expand Up @@ -162,6 +168,7 @@ def generate_benchmark_report_cli(
rate: Optional[float],
max_seconds: Optional[int],
max_requests: Union[Literal["dataset"], int, None],
workers: int,
output_path: str,
enable_continuous_refresh: bool,
):
Expand All @@ -179,6 +186,7 @@ def generate_benchmark_report_cli(
rate=rate,
max_seconds=max_seconds,
max_requests=max_requests,
workers=workers,
output_path=output_path,
cont_refresh_table=enable_continuous_refresh,
)
Expand All @@ -195,6 +203,7 @@ def generate_benchmark_report(
rate: Optional[float],
max_seconds: Optional[int],
max_requests: Union[Literal["dataset"], int, None],
workers: int,
output_path: str,
cont_refresh_table: bool,
) -> GuidanceReport:
Expand All @@ -215,6 +224,7 @@ def generate_benchmark_report(
:param rate: The specific request rate for constant and poisson rate types.
:param max_seconds: Maximum duration for each benchmark run in seconds.
:param max_requests: Maximum number of requests per benchmark run.
:param workers: Maximum number of concurrent workers.
:param output_path: Path to save the output report file.
:param cont_refresh_table: Continually refresh the table in the CLI
until the user exits.
Expand Down Expand Up @@ -269,6 +279,7 @@ def generate_benchmark_report(
len(request_generator) if max_requests == "dataset" else max_requests
),
max_duration=max_seconds,
workers=workers,
)

# Run executor
Expand Down
62 changes: 40 additions & 22 deletions src/guidellm/scheduler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
rate: Optional[float] = None,
max_number: Optional[int] = None,
max_duration: Optional[float] = None,
concurrent_tasks: int = 1,
):
logger.info(
"Scheduler initialized with params: generator={}, worker={}, mode={}, "
Expand Down Expand Up @@ -121,6 +122,9 @@ def __init__(
self._max_number = max_number
self._max_duration = max_duration

self._concurrent_tasks = concurrent_tasks
self._tasks: list[asyncio.Task] = []

self._load_generator = LoadGenerator(mode, rate)

@property
Expand Down Expand Up @@ -193,6 +197,17 @@ def load_generator(self) -> LoadGenerator:
"""
return self._load_generator

@property
def concurrent_tasks(self) -> int:
    """
    Number of concurrent worker tasks the scheduler will keep running.

    :return: the configured concurrent task count
    :rtype: int
    """
    # Read-only accessor for the value supplied to __init__.
    task_count = self._concurrent_tasks
    return task_count

@property
def benchmark_mode(self) -> Literal["asynchronous", "synchronous", "throughput"]:
"""
Expand Down Expand Up @@ -227,9 +242,7 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
count_total = (
self.max_number
if self.max_number
else round(self.max_duration)
if self.max_duration
else 0
else round(self.max_duration) if self.max_duration else 0
)

# yield initial result for progress tracking
Expand All @@ -246,9 +259,7 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
count_completed = (
min(run_count, self.max_number)
if self.max_number
else round(time.time() - start_time)
if self.max_duration
else 0
else round(time.time() - start_time) if self.max_duration else 0
)

yield SchedulerResult(
Expand All @@ -267,9 +278,7 @@ async def run(self) -> AsyncGenerator[SchedulerResult, None]:
count_completed=(
benchmark.request_count + benchmark.error_count
if self.max_number
else round(time.time() - start_time)
if self.max_duration
else 0
else round(time.time() - start_time) if self.max_duration else 0
),
benchmark=benchmark,
)
Expand All @@ -295,10 +304,9 @@ async def _run_sync(
logger.debug("Request completed with output: {}", result)
yield result

async def _run_async(
async def _concurrent_worker(
self, benchmark: TextGenerationBenchmark, end_time: float, max_number: float
) -> AsyncGenerator[Union[TextGenerationResult, TextGenerationError], None]:
tasks = []
):
completed = 0

for index, (request, submit_at) in enumerate(
Expand All @@ -310,32 +318,42 @@ async def _run_async(
if index >= max_number or time.time() >= end_time or submit_at >= end_time:
break

logger.debug(
"Running asynchronous request={} at submit_at={}",
request,
submit_at,
)
logger.debug(f"Running asynchronous {request=} at {submit_at=}")

def _completed(_task: asyncio.Task) -> None:
nonlocal completed
completed += 1

_res = _task.result()

if _res:
benchmark.request_completed(_res)
logger.debug("Request completed: {}", _res)
logger.debug(f"Request completed: {_res}")

benchmark.request_started()
task = asyncio.create_task(
self._submit_task_coroutine(request, submit_at, end_time)
)
task.add_done_callback(_completed)
tasks.append(task)
self._tasks.append(task)

# release control to the event loop
await asyncio.sleep(0)

async def _run_async(
self, benchmark: TextGenerationBenchmark, end_time: float, max_number: float
) -> AsyncGenerator[Union[TextGenerationResult, TextGenerationError], None]:

tasks = [
asyncio.create_task(
self._concurrent_worker(benchmark, end_time, max_number)
)
for _ in range(self.concurrent_tasks)
]

# release control to the event loop for other tasks
await asyncio.sleep(0.001)
await asyncio.gather(*tasks)

for compl_task in asyncio.as_completed(tasks):
for compl_task in asyncio.as_completed(self._tasks):
task_res = await compl_task
if task_res is not None:
yield task_res
Expand Down
Loading