Commit 4183512

Fixes for json / yaml output
1 parent 649a86d commit 4183512

File tree: 7 files changed (+236, -120 lines)

src/guidellm/benchmark/benchmark.py

Lines changed: 103 additions & 40 deletions
@@ -1,15 +1,36 @@
 import random
 import uuid
-from typing import Any, Dict, List, Literal, Optional, TypeVar
+from typing import Any, Dict, List, Literal, Optional, TypeVar, Union
 
 from pydantic import Field, computed_field
 
-from guidellm.benchmark.profile import Profile
+from guidellm.benchmark.profile import (
+    AsyncProfile,
+    ConcurrentProfile,
+    Profile,
+    SweepProfile,
+    SynchronousProfile,
+    ThroughputProfile,
+)
 from guidellm.objects import (
     StandardBaseModel,
     StatusDistributionSummary,
 )
-from guidellm.scheduler import SchedulerRequestInfo, SchedulingStrategy
+from guidellm.request import (
+    GenerativeRequestLoaderDescription,
+    RequestLoaderDescription,
+)
+from guidellm.scheduler import (
+    AsyncConstantStrategy,
+    AsyncPoissonStrategy,
+    ConcurrentStrategy,
+    GenerativeRequestsWorkerDescription,
+    SchedulerRequestInfo,
+    SchedulingStrategy,
+    SynchronousStrategy,
+    ThroughputStrategy,
+    WorkerDescription,
+)
 
 __all__ = [
     "BENCH",
@@ -28,19 +49,36 @@ class BenchmarkArgs(StandardBaseModel):
     and how data was collected for it.
     """
 
-    profile: Profile = Field(
+    profile: Union[
+        AsyncProfile,
+        SweepProfile,
+        ConcurrentProfile,
+        ThroughputProfile,
+        SynchronousProfile,
+        Profile,
+    ] = Field(
         description=(
             "The profile used for the entire benchmark run that the strategy for "
             "this benchmark was pulled from."
-        )
+        ),
+        discriminator="type_",
     )
     strategy_index: int = Field(
         description=(
             "The index of the strategy in the profile that was used for this benchmark."
         )
     )
-    strategy: SchedulingStrategy = Field(
-        description="The scheduling strategy used to run this benchmark. "
+    strategy: Union[
+        ConcurrentStrategy,
+        SchedulingStrategy,
+        ThroughputStrategy,
+        SynchronousStrategy,
+        AsyncPoissonStrategy,
+        AsyncConstantStrategy,
+        SchedulingStrategy,
+    ] = Field(
+        description="The scheduling strategy used to run this benchmark. ",
+        discriminator="type_",
     )
     max_number: Optional[int] = Field(
         description="The maximum number of requests to run for this benchmark, if any."
@@ -208,6 +246,7 @@ class Benchmark(StandardBaseModel):
     what rates and concurrency values to use for subsequent strategies.
     """
 
+    type_: Literal["benchmark"] = "benchmark"
     id_: str = Field(
         default_factory=lambda: str(uuid.uuid4()),
         description="The unique identifier for the benchmark.",
@@ -228,17 +267,23 @@
             "The process statistics for the entire benchmark run across all requests."
         )
     )
-    worker: Optional[StandardBaseModel] = Field(
-        description=(
-            "The description and specifics for the worker used to resolve requests "
-            "for this benchmark."
+    worker: Optional[Union[GenerativeRequestsWorkerDescription, WorkerDescription]] = (
+        Field(
+            description=(
+                "The description and specifics for the worker used to resolve requests "
+                "for this benchmark."
+            ),
+            discriminator="type_",
         )
     )
-    request_loader: Optional[StandardBaseModel] = Field(
+    request_loader: Optional[
+        Union[GenerativeRequestLoaderDescription, RequestLoaderDescription]
+    ] = Field(
         description=(
             "The description and specifics for the request loader used to create "
             "requests for this benchmark."
-        )
+        ),
+        discriminator="type_",
     )
     extras: Dict[str, Any] = Field(
         description=(
@@ -263,6 +308,7 @@ class GenerativeTextResponseStats(StandardBaseModel):
     statistics for a generative text response.
     """
 
+    type_: Literal["generative_text_response"] = "generative_text_response"
     request_id: str = Field(
         description="The unique identifier for the request.",
     )
@@ -378,6 +424,7 @@ class GenerativeTextErrorStats(GenerativeTextResponseStats):
     error message and optional properties given the error occurred.
    """
 
+    type_: Literal["generative_text_error"] = "generative_text_error"
     error: str = Field(
         description=(
             "The error message for the error that occurred while making the request."
@@ -466,6 +513,7 @@ class GenerativeBenchmark(Benchmark):
     and end times for the benchmark, and the statistics for the requests and responses.
     """
 
+    type_: Literal["generative_benchmark"] = "generative_benchmark"
     successful_total: int = Field(
         description=(
             "The total number of completed requests in the benchmark, "
@@ -495,7 +543,7 @@
             "the benchmark. None if no sampling was applied."
         ),
     )
-    incomplete_requests: List[GenerativeTextResponseStats] = Field(
+    incomplete_requests: List[GenerativeTextErrorStats] = Field(
         description="The list of incomplete requests.",
     )
     errored_total: int = Field(
@@ -521,7 +569,7 @@
         description="The end time of the last request for the benchmark.",
     )
 
-    requests_latency: StatusDistributionSummary = Field(
+    request_latency: StatusDistributionSummary = Field(
         description="The distribution of latencies for the completed requests.",
     )
     prompt_token_count: StatusDistributionSummary = Field(
@@ -536,20 +584,20 @@
             "errored, and all requests."
         )
     )
-    times_to_first_token_ms: StatusDistributionSummary = Field(
+    time_to_first_token_ms: StatusDistributionSummary = Field(
         description=(
             "The distribution of latencies to receiving the first token in "
             "milliseconds for completed, errored, and all requests."
         ),
     )
-    times_per_output_token_ms: StatusDistributionSummary = Field(
+    time_per_output_token_ms: StatusDistributionSummary = Field(
         description=(
             "The distribution of latencies per output token in milliseconds for "
             "completed, errored, and all requests. "
             "This includes the time to generate the first token and all other tokens."
         ),
     )
-    inter_token_latencies_ms: StatusDistributionSummary = Field(
+    inter_token_latency_ms: StatusDistributionSummary = Field(
         description=(
             "The distribution of latencies between tokens in milliseconds for "
             "completed, errored, and all requests."
@@ -656,7 +704,7 @@ def create_sampled(
     def from_stats(
         run_id: str,
         successful: List[GenerativeTextResponseStats],
-        incomplete: List[GenerativeTextResponseStats],
+        incomplete: List[GenerativeTextErrorStats],
         errored: List[GenerativeTextErrorStats],
         args: BenchmarkArgs,
         run_stats: BenchmarkRunStats,
@@ -695,23 +743,38 @@ def from_stats(
         ]
         start_time = min(req.start_time for req in total)
         end_time = max(req.end_time for req in total)
-        total_with_prompt, total_types_with_prompt = zip(
-            *filter(
-                lambda val: bool(val[0].prompt_tokens),
-                zip(total, total_types),
+        total_with_prompt, total_types_with_prompt = (
+            zip(*filtered)
+            if (
+                filtered := list(
+                    filter(lambda val: bool(val[0].prompt), zip(total, total_types))
+                )
             )
-        )
-        total_with_output_first, total_types_with_output_first = zip(
-            *filter(
-                lambda val: bool(val[0].output_tokens),
-                zip(total, total_types),
+            else ([], [])
+        )
+        total_with_output_first, total_types_with_output_first = (
+            zip(*filtered)
+            if (
+                filtered := list(
+                    filter(
+                        lambda val: bool(val[0].output_tokens > 0),
+                        zip(total, total_types),
+                    )
+                )
             )
-        )
-        total_with_output_multi, total_types_with_output_multi = zip(
-            *filter(
-                lambda val: bool(val[0].output_tokens > 1),
-                zip(total, total_types),
+            else ([], [])
+        )
+        total_with_output_multi, total_types_with_output_multi = (
+            zip(*filtered)
+            if (
+                filtered := list(
+                    filter(
+                        lambda val: bool(val[0].output_tokens > 1),
+                        zip(total, total_types),
+                    )
+                )
             )
+            else ([], [])
         )
 
         return GenerativeBenchmark(
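Beyond readability, the rewritten unpacking fixes a crash path: when the filter matches nothing, the old x, y = zip(*filter(...)) form raises ValueError: not enough values to unpack (expected 2, got 0), e.g. for a run where no request produced output tokens. The walrus form materializes the filtered pairs first and falls back to a pair of empty lists. A standalone sketch of the failure and the fix pattern (Python 3.8+ for the := operator):

pairs = []  # e.g. no (request, type) pair survived the filter

try:
    kept, kinds = zip(*filter(lambda p: p[1] > 0, pairs))  # old pattern
except ValueError as err:
    print(err)  # not enough values to unpack (expected 2, got 0)

# Fixed pattern: materialize the filtered pairs, then fall back if empty.
kept, kinds = (
    zip(*filtered)
    if (filtered := list(filter(lambda p: p[1] > 0, pairs)))
    else ([], [])
)
assert (kept, kinds) == ([], [])

pairs = [("req-1", 3), ("req-2", 0), ("req-3", 5)]
kept, kinds = (
    zip(*filtered)
    if (filtered := list(filter(lambda p: p[1] > 0, pairs)))
    else ([], [])
)
assert (tuple(kept), tuple(kinds)) == (("req-1", "req-3"), (3, 5))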
@@ -739,35 +802,35 @@ def from_stats(
                 requests=[(req.start_time, req.end_time) for req in total],
                 distribution_type="concurrency",
             ),
-            requests_latency=StatusDistributionSummary.from_values(
+            request_latency=StatusDistributionSummary.from_values(
                 value_types=total_types,
                 values=[req.request_latency for req in total],
             ),
-            prompts_token_count=StatusDistributionSummary.from_values(
+            prompt_token_count=StatusDistributionSummary.from_values(
                 value_types=list(total_types_with_prompt),
                 values=[req.prompt_tokens for req in total_with_prompt],
             ),
-            outputs_token_count=StatusDistributionSummary.from_values(
+            output_token_count=StatusDistributionSummary.from_values(
                 value_types=list(total_types_with_output_first),
                 values=[req.output_tokens for req in total_with_output_first],
             ),
-            times_to_first_token_ms=StatusDistributionSummary.from_values(
+            time_to_first_token_ms=StatusDistributionSummary.from_values(
                 value_types=list(total_types_with_output_first),
                 values=[req.time_to_first_token_ms for req in total_with_output_first],
             ),
-            times_per_output_tokens_ms=StatusDistributionSummary.from_values(
+            time_per_output_token_ms=StatusDistributionSummary.from_values(
                 value_types=list(total_types_with_output_first),
                 values=[
                     req.time_per_output_token_ms for req in total_with_output_first
                 ],
                 weights=[req.output_tokens for req in total_with_output_first],
             ),
-            inter_token_latencies_ms=StatusDistributionSummary.from_values(
+            inter_token_latency_ms=StatusDistributionSummary.from_values(
                 value_types=list(total_types_with_output_multi),
                 values=[req.inter_token_latency_ms for req in total_with_output_multi],
                 weights=[req.output_tokens - 1 for req in total_with_output_multi],
             ),
-            outputs_tokens_per_second=StatusDistributionSummary.from_iterable_request_times(
+            output_tokens_per_second=StatusDistributionSummary.from_iterable_request_times(
                request_types=total_types_with_output_first,
                requests=[
                    (req.start_time, req.end_time) for req in total_with_output_first
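Note the weights= arguments on the renamed per-token fields: each request contributes its own mean per-token latency, so an unweighted average would let a 2-token response count as much as a 500-token one. Weighting by output_tokens (or by the output_tokens - 1 token gaps for inter-token latency) recovers the true per-token mean across all tokens. A quick check of that arithmetic with made-up numbers:

# Request A: 2 output tokens at 10 ms/token; request B: 8 tokens at 20 ms/token.
per_request_means = [10.0, 20.0]
output_tokens = [2, 8]

unweighted = sum(per_request_means) / len(per_request_means)  # 15.0, biased

weighted = sum(
    mean * tokens for mean, tokens in zip(per_request_means, output_tokens)
) / sum(output_tokens)
assert weighted == 18.0  # (2*10 + 8*20) / 10: the actual per-token mean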
