Skip to content

Commit 904a4dd

Browse files
committed
fix type issues
1 parent 8fae0a1 commit 904a4dd

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/guidellm/presentation/data_models.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import random
22
from collections import defaultdict
33
from math import ceil
4-
from typing import List, Optional, Tuple, TYPE_CHECKING
4+
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
55

66
from pydantic import BaseModel, computed_field
77

@@ -11,12 +11,12 @@
1111
from guidellm.objects.statistics import DistributionSummary
1212

1313
class Bucket(BaseModel):
14-
value: float
14+
value: Union[float, int]
1515
count: int
1616

1717
@staticmethod
1818
def from_data(
19-
data: List[float],
19+
data: Union[List[float], List[int]],
2020
bucket_width: Optional[float] = None,
2121
n_buckets: Optional[int] = None,
2222
) -> Tuple[List["Bucket"], float]:
@@ -34,7 +34,7 @@ def from_data(
3434
else:
3535
n_buckets = ceil(range_v / bucket_width)
3636

37-
bucket_counts = defaultdict(int)
37+
bucket_counts: defaultdict[Union[float, int], int] = defaultdict(int)
3838
for val in data:
3939
idx = int((val - min_v) // bucket_width)
4040
if idx >= n_buckets:
@@ -125,10 +125,10 @@ def from_benchmarks(cls, benchmarks: list["GenerativeBenchmark"]):
125125
]
126126

127127
prompt_tokens = [
128-
req.prompt_tokens for bm in benchmarks for req in bm.requests.successful
128+
float(req.prompt_tokens) for bm in benchmarks for req in bm.requests.successful
129129
]
130130
output_tokens = [
131-
req.output_tokens for bm in benchmarks for req in bm.requests.successful
131+
float(req.output_tokens) for bm in benchmarks for req in bm.requests.successful
132132
]
133133

134134
prompt_token_buckets, _prompt_token_bucket_width = Bucket.from_data(
@@ -184,7 +184,6 @@ class TabularDistributionSummary(DistributionSummary):
184184
`percentile_rows` helper.
185185
"""
186186

187-
@computed_field
188187
@property
189188
def percentile_rows(self) -> list[dict[str, float]]:
190189
rows = [

0 commit comments

Comments
 (0)