Commit 5b07e34

Fix unit tests for refactor and schemas package
Signed-off-by: Mark Kurtz <[email protected]>
1 parent a40508e commit 5b07e34

20 files changed: +4310 -836 lines

src/guidellm/backends/response_handlers.py

Lines changed: 27 additions & 0 deletions
@@ -72,6 +72,33 @@ class GenerationResponseHandlerFactory(RegistryMixin[type[GenerationResponseHand
     responses from different generation services.
     """
 
+    @classmethod
+    def create(
+        cls,
+        request_type: str,
+        handler_overrides: dict[str, type[GenerationResponseHandler]] | None = None,
+    ) -> GenerationResponseHandler:
+        """
+        Create a response handler class for the given request type.
+
+        :param request_type: The type of generation request (e.g., "text_completions")
+        :param handler_overrides: Optional mapping of request types to handler classes
+            to override the default registry by checking first and then falling back
+            to the registered handlers.
+        :return: The corresponding instantiated GenerationResponseHandler
+        :raises ValueError: When no handler is registered for the request type
+        """
+        if handler_overrides and request_type in handler_overrides:
+            return handler_overrides[request_type]()
+
+        handler_cls = cls.get_registered_object(request_type)
+        if not handler_cls:
+            raise ValueError(
+                f"No response handler registered for type '{request_type}'."
+            )
+
+        return handler_cls()
+
 
 @GenerationResponseHandlerFactory.register("text_completions")
 class TextCompletionsResponseHandler(GenerationResponseHandler):
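
A quick usage sketch of the new factory method, based only on what the diff above shows (the CustomTextHandler subclass is hypothetical, and handlers are assumed to be constructible with no arguments, exactly as create() instantiates them):

    from guidellm.backends.response_handlers import (
        GenerationResponseHandlerFactory,
        TextCompletionsResponseHandler,
    )

    # Registry path: resolve the handler registered for the request type.
    handler = GenerationResponseHandlerFactory.create("text_completions")
    assert isinstance(handler, TextCompletionsResponseHandler)

    # Override path: the caller-supplied mapping is checked before the registry.
    class CustomTextHandler(TextCompletionsResponseHandler):
        """Hypothetical override used only for illustration."""

    handler = GenerationResponseHandlerFactory.create(
        "text_completions",
        handler_overrides={"text_completions": CustomTextHandler},
    )
    assert isinstance(handler, CustomTextHandler)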

src/guidellm/schemas/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -10,11 +10,17 @@
 from __future__ import annotations
 
 from .base import (
+    BaseModelT,
+    ErroredT,
+    IncompleteT,
     PydanticClassRegistryMixin,
+    RegisterClassT,
     ReloadableBaseModel,
     StandardBaseDict,
     StandardBaseModel,
     StatusBreakdown,
+    SuccessfulT,
+    TotalT,
 )
 from .info import RequestInfo, RequestTimings
 from .request import (
@@ -33,21 +39,27 @@
 )
 
 __all__ = [
+    "BaseModelT",
     "DistributionSummary",
+    "ErroredT",
     "FunctionObjT",
     "GenerationRequest",
     "GenerationRequestArguments",
     "GenerationResponse",
     "GenerativeRequestStats",
     "GenerativeRequestType",
+    "IncompleteT",
     "Percentiles",
     "PydanticClassRegistryMixin",
+    "RegisterClassT",
     "ReloadableBaseModel",
     "RequestInfo",
     "RequestTimings",
     "StandardBaseDict",
     "StandardBaseModel",
     "StatusBreakdown",
     "StatusDistributionSummary",
+    "SuccessfulT",
+    "TotalT",
     "UsageMetrics",
 ]
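
With this change the typing helpers from schemas.base (BaseModelT, ErroredT, IncompleteT, RegisterClassT, SuccessfulT, TotalT) are re-exported from the package root. A minimal sketch of what that enables, assuming BaseModelT is a TypeVar bound to a Pydantic model type as its name suggests (the clone_model helper is illustrative, not part of the commit):

    from guidellm.schemas import BaseModelT

    # Generic helper: the TypeVar keeps the concrete model type in the signature,
    # so callers get back the same type they passed in.
    def clone_model(model: BaseModelT) -> BaseModelT:
        return type(model).model_validate(model.model_dump())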

src/guidellm/schemas/base.py

Lines changed: 6 additions & 0 deletions
@@ -19,11 +19,17 @@
 from guidellm.utils.registry import RegistryMixin
 
 __all__ = [
+    "BaseModelT",
+    "ErroredT",
+    "IncompleteT",
     "PydanticClassRegistryMixin",
+    "RegisterClassT",
     "ReloadableBaseModel",
     "StandardBaseDict",
     "StandardBaseModel",
     "StatusBreakdown",
+    "SuccessfulT",
+    "TotalT",
 ]
 
 

src/guidellm/schemas/request.py

Lines changed: 4 additions & 3 deletions
@@ -73,7 +73,7 @@ def model_combine(
         Merge additional request arguments into the current instance.
 
         Combines method and stream fields by overwriting, while merging collection
-        fields like headers, params, json_body, and files by extending existing values.
+        fields like headers, params, body, and files by extending existing values.
 
         :param additional: Additional arguments to merge with current instance
         :return: Updated instance with merged arguments
@@ -88,9 +88,10 @@ def model_combine(
            if (val := additional_dict.get(overwrite)) is not None:
                setattr(self, overwrite, val)
 
-        for combine in ("headers", "params", "json_body", "files"):
+        for combine in ("headers", "params", "body", "files"):
             if (val := additional_dict.get(combine)) is not None:
-                setattr(self, combine, {**getattr(self, combine, {}), **val})
+                current = getattr(self, combine, None) or {}
+                setattr(self, combine, {**current, **val})
 
         return self
 
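
The new `getattr(self, combine, None) or {}` guard matters because these collection fields can exist on the model with a value of None; the old `getattr(self, combine, {})` only covered the missing-attribute case, so merging would hit `{**None, **val}` and fail. A minimal standalone sketch of the difference (SimpleArgs is a hypothetical stand-in for the request-arguments model):

    from __future__ import annotations


    class SimpleArgs:
        """Hypothetical stand-in: the attribute exists but defaults to None."""

        def __init__(self) -> None:
            self.headers: dict[str, str] | None = None


    args = SimpleArgs()
    incoming = {"Authorization": "Bearer token"}

    # Old behavior: getattr finds the attribute (value None), so the {} default
    # is never used and the unpacking raises TypeError.
    try:
        merged = {**getattr(args, "headers", {}), **incoming}
    except TypeError as err:
        print(f"old merge fails: {err}")

    # New behavior: normalize None to an empty dict before merging.
    current = getattr(args, "headers", None) or {}
    merged = {**current, **incoming}
    print(merged)  # {'Authorization': 'Bearer token'}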

src/guidellm/schemas/request_stats.py

Lines changed: 40 additions & 27 deletions
@@ -69,7 +69,11 @@ def request_start_time(self) -> float | None:
         """
         :return: Timestamp when the request started, or None if unavailable
         """
-        return self.info.timings.request_start or self.info.timings.resolve_start
+        return (
+            self.info.timings.request_start
+            if self.info.timings.request_start is not None
+            else self.info.timings.resolve_start
+        )
 
     @computed_field  # type: ignore[misc]
     @property
@@ -80,7 +84,11 @@ def request_end_time(self) -> float:
         if self.info.timings.resolve_end is None:
             raise ValueError("resolve_end timings should be set but is None.")
 
-        return self.info.timings.request_end or self.info.timings.resolve_end
+        return (
+            self.info.timings.request_end
+            if self.info.timings.request_end is not None
+            else self.info.timings.resolve_end
+        )
 
     @computed_field  # type: ignore[misc]
     @property
@@ -90,9 +98,9 @@ def request_latency(self) -> float | None:
 
         :return: Duration from request start to completion, or None if unavailable
         """
-        if not (start := self.info.timings.request_start) or not (
-            end := self.info.timings.request_end
-        ):
+        start = self.info.timings.request_start
+        end = self.info.timings.request_end
+        if start is None or end is None:
             return None
 
         return end - start
@@ -142,9 +150,9 @@ def time_to_first_token_ms(self) -> float | None:
         """
         :return: Time to first token generation in milliseconds, or None if unavailable
         """
-        if not (first_token := self.first_token_iteration) or not (
-            start := self.info.timings.request_start
-        ):
+        first_token = self.first_token_iteration
+        start = self.info.timings.request_start
+        if first_token is None or start is None:
             return None
 
         return 1000 * (first_token - start)
@@ -158,9 +166,10 @@ def time_per_output_token_ms(self) -> float | None:
         :return: Average milliseconds per output token, or None if unavailable
         """
         if (
-            not (start := self.info.timings.request_start)
-            or not (last_token := self.last_token_iteration)
-            or not (output_tokens := self.output_tokens)
+            (start := self.info.timings.request_start) is None
+            or (last_token := self.last_token_iteration) is None
+            or (output_tokens := self.output_tokens) is None
+            or output_tokens == 0
         ):
             return None
 
@@ -174,10 +183,13 @@ def inter_token_latency_ms(self) -> float | None:
 
         :return: Average milliseconds between token generations, or None if unavailable
         """
+        first_token = self.first_token_iteration
+        last_token = self.last_token_iteration
+        output_tokens = self.output_tokens
         if (
-            not (first_token := self.first_token_iteration)
-            or not (last_token := self.last_token_iteration)
-            or not (output_tokens := self.output_tokens)
+            first_token is None
+            or last_token is None
+            or output_tokens is None
             or output_tokens <= 1
         ):
             return None
@@ -257,29 +269,26 @@ def token_iterations(self) -> int:
         return self.info.timings.token_iterations
 
     @property
-    def prompt_tokens_timing(self) -> tuple[float, float] | None:
+    def prompt_tokens_timing(self) -> tuple[float, float]:
         """
-        :return: Tuple of (timestamp, token_count) for prompt processing, or None
-            if unavailable
+        :return: Tuple of (timestamp, token_count) for prompt processing
+        :raises ValueError: If resolve_end timings are not set
         """
-        if self.request_end_time is None:
-            # no end time, can't compute
-            return None
-
         return (
-            self.first_token_iteration or self.request_end_time,
+            (
+                self.first_token_iteration
+                if self.first_token_iteration is not None
+                else self.request_end_time
+            ),
             self.prompt_tokens or 0.0,
         )
 
     @property
     def output_tokens_timings(self) -> list[tuple[float, float]]:
         """
         :return: List of (timestamp, token_count) tuples for output token generations
+        :raises ValueError: If resolve_end timings are not set
         """
-        if self.request_end_time is None:
-            # no end time, can't compute
-            return []
-
         if (
             self.first_token_iteration is None
             or self.last_token_iteration is None
@@ -288,7 +297,11 @@ def output_tokens_timings(self) -> list[tuple[float, float]]:
             # No iteration data, return single timing at end with all tokens
             return [
                 (
-                    self.last_token_iteration or self.request_end_time,
+                    (
+                        self.last_token_iteration
+                        if self.last_token_iteration is not None
+                        else self.request_end_time
+                    ),
                     self.output_tokens or 0.0,
                 )
             ]
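
Most of the edits in this file replace truthiness checks (`or` / `not`) with explicit `is None` comparisons, because values such as a relative timestamp of 0.0 or an output-token count of 0 are falsy yet still meaningful (note the new explicit `or output_tokens == 0` guard above). A standalone illustration of the difference, using plain floats rather than the real timings model:

    request_start = 0.0    # a legitimate recorded value that happens to be falsy
    resolve_start = 12.5   # fallback value

    # Truthiness fallback (old style) silently discards the recorded value.
    old_value = request_start or resolve_start
    print(old_value)  # 12.5

    # Explicit None check (new style) falls back only when the value is missing.
    new_value = request_start if request_start is not None else resolve_start
    print(new_value)  # 0.0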

src/guidellm/schemas/statistics.py

Lines changed: 23 additions & 7 deletions
@@ -210,13 +210,7 @@ def from_pdf(
         count = len(pdf)
 
         total_sum = mean * count
-
-        if include_pdf is False:
-            sampled_pdf = None
-        elif include_pdf is True:
-            sampled_pdf = pdf.tolist()
-        else:
-            sampled_pdf = []
+        sampled_pdf = cls._sample_pdf(pdf, include_pdf)
 
         return DistributionSummary(
             mean=mean,
@@ -232,6 +226,28 @@ def from_pdf(
             pdf=sampled_pdf,
         )
 
+    @classmethod
+    def _sample_pdf(
+        cls, pdf: np.ndarray, include_pdf: bool | int
+    ) -> list[tuple[float, float]] | None:
+        """
+        Sample PDF based on include_pdf parameter.
+
+        :param pdf: PDF array to sample
+        :param include_pdf: False for None, True for full, int for sampled size
+        :return: Sampled PDF as list of tuples or None
+        """
+        if include_pdf is False:
+            return None
+        if include_pdf is True:
+            return pdf.tolist()
+        if isinstance(include_pdf, int) and include_pdf > 0:
+            if len(pdf) <= include_pdf:
+                return pdf.tolist()
+            sample_indices = np.linspace(0, len(pdf) - 1, include_pdf, dtype=int)
+            return pdf[sample_indices].tolist()
+        return []
+
     @classmethod
     def from_values(
         cls,
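
The new `_sample_pdf` helper adds an integer mode: a positive `include_pdf` smaller than the PDF length keeps that many evenly spaced entries, selected with `np.linspace` over the index range. A standalone sketch of that sampling step with synthetic data (not the real PDF array produced by `from_pdf`):

    import numpy as np

    # Synthetic (value, probability) rows standing in for the computed PDF.
    pdf = np.column_stack([np.arange(10, dtype=float), np.full(10, 0.1)])
    include_pdf = 4  # keep at most 4 evenly spaced rows

    if len(pdf) <= include_pdf:
        sampled = pdf.tolist()
    else:
        # Evenly spaced indices from the first row to the last, inclusive.
        sample_indices = np.linspace(0, len(pdf) - 1, include_pdf, dtype=int)
        print(sample_indices)  # [0 3 6 9]
        sampled = pdf[sample_indices].tolist()

    print(len(sampled))  # 4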
