Skip to content

Commit dd7208f

Browse files
committed
Fixup for backend unit tests
Signed-off-by: Samuel Monson <[email protected]>
1 parent 870f195 commit dd7208f

File tree

3 files changed

+139
-929
lines changed

3 files changed

+139
-929
lines changed

tests/unit/backends/test_backend.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
import pytest
1212

1313
from guidellm.backends.backend import Backend, BackendType
14-
from guidellm.scheduler import BackendInterface
1514
from guidellm.schemas import (
1615
GenerationRequest,
17-
RequestTimings,
16+
RequestInfo,
1817
)
18+
from guidellm.schemas.request import GenerationRequestArguments
1919
from guidellm.utils import RegistryMixin
2020
from tests.unit.testing_utils import async_timeout
2121

@@ -69,7 +69,11 @@ async def default_model(self) -> str | None:
6969
def test_class_signatures(self):
7070
"""Test Backend inheritance and type relationships."""
7171
assert issubclass(Backend, RegistryMixin)
72-
assert isinstance(Backend, BackendInterface)
72+
# Check that Backend implements BackendInterface methods
73+
assert hasattr(Backend, "resolve")
74+
assert hasattr(Backend, "process_startup")
75+
assert hasattr(Backend, "process_shutdown")
76+
assert hasattr(Backend, "validate")
7377
assert hasattr(Backend, "create")
7478
assert hasattr(Backend, "register")
7579
assert hasattr(Backend, "get_registered_object")
@@ -149,15 +153,10 @@ async def test_interface_compatibility(self, valid_instances):
149153
instance, _ = valid_instances
150154

151155
# Test that Backend uses the correct generic types
152-
request = GenerationRequest(content="test")
153-
request_info = ScheduledRequestInfo(
154-
request_id="test-id",
155-
status="pending",
156-
scheduler_node_id=1,
157-
scheduler_process_id=1,
158-
scheduler_start_time=123.0,
159-
request_timings=RequestTimings(),
156+
request = GenerationRequest(
157+
request_type="text_completions", arguments=GenerationRequestArguments()
160158
)
159+
request_info = RequestInfo(request_id="test-id")
161160

162161
# Test resolve method
163162
async for response, info in instance.resolve(request, request_info):

tests/unit/backends/test_objects.py

Lines changed: 28 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from guidellm.schemas import (
1313
GenerationRequest,
1414
GenerationResponse,
15+
RequestInfo,
1516
RequestTimings,
1617
)
1718
from guidellm.schemas.request import GenerationRequestArguments
@@ -186,23 +187,15 @@ def test_class_signatures(self):
186187
expected_fields = [
187188
"request_id",
188189
"request_args",
189-
"value",
190-
"delta",
191-
"iterations",
192-
"request_prompt_tokens",
193-
"request_output_tokens",
194-
"response_prompt_tokens",
195-
"response_output_tokens",
190+
"text",
191+
"input_metrics",
192+
"output_metrics",
196193
]
197194
for field in expected_fields:
198195
assert field in fields
199196

200-
# Check properties exist
201-
assert hasattr(GenerationResponse, "prompt_tokens")
202-
assert hasattr(GenerationResponse, "output_tokens")
203-
assert hasattr(GenerationResponse, "total_tokens")
204-
assert hasattr(GenerationResponse, "preferred_prompt_tokens")
205-
assert hasattr(GenerationResponse, "preferred_output_tokens")
197+
# Check methods exist
198+
assert hasattr(GenerationResponse, "compile_stats")
206199

207200
@pytest.mark.smoke
208201
def test_initialization(self, valid_instances):
@@ -213,12 +206,12 @@ def test_initialization(self, valid_instances):
213206
assert instance.request_args == constructor_args["request_args"]
214207

215208
# Check defaults for optional fields
216-
if "value" not in constructor_args:
217-
assert instance.value is None
218-
if "delta" not in constructor_args:
219-
assert instance.delta is None
220-
if "iterations" not in constructor_args:
221-
assert instance.iterations == 0
209+
if "text" not in constructor_args:
210+
assert instance.text is None
211+
212+
# Check default metrics
213+
assert hasattr(instance, "input_metrics")
214+
assert hasattr(instance, "output_metrics")
222215

223216
@pytest.mark.sanity
224217
def test_invalid_initialization_values(self):
@@ -237,131 +230,27 @@ def test_invalid_initialization_missing(self):
237230
GenerationResponse(request_id="test") # Missing request_args
238231

239232
@pytest.mark.smoke
240-
def test_prompt_tokens_property(self):
241-
"""Test prompt_tokens property logic."""
242-
# When both are available, prefers response_prompt_tokens
243-
response1 = GenerationResponse(
244-
request_id="test",
245-
request_args={},
246-
request_prompt_tokens=50,
247-
response_prompt_tokens=55,
248-
)
249-
assert response1.prompt_tokens == 55
250-
251-
# When only request_prompt_tokens is available
252-
response2 = GenerationResponse(
253-
request_id="test", request_args={}, request_prompt_tokens=50
254-
)
255-
assert response2.prompt_tokens == 50
256-
257-
# When only response_prompt_tokens is available
258-
response3 = GenerationResponse(
259-
request_id="test", request_args={}, response_prompt_tokens=55
260-
)
261-
assert response3.prompt_tokens == 55
262-
263-
# When neither is available
264-
response4 = GenerationResponse(request_id="test", request_args={})
265-
assert response4.prompt_tokens is None
266-
267-
@pytest.mark.smoke
268-
def test_output_tokens_property(self):
269-
"""Test output_tokens property logic."""
270-
# When both are available, prefers response_output_tokens
271-
response1 = GenerationResponse(
272-
request_id="test",
273-
request_args={},
274-
request_output_tokens=100,
275-
response_output_tokens=95,
276-
)
277-
assert response1.output_tokens == 95
278-
279-
# When only request_output_tokens is available
280-
response2 = GenerationResponse(
281-
request_id="test", request_args={}, request_output_tokens=100
282-
)
283-
assert response2.output_tokens == 100
284-
285-
# When only response_output_tokens is available
286-
response3 = GenerationResponse(
287-
request_id="test", request_args={}, response_output_tokens=95
288-
)
289-
assert response3.output_tokens == 95
290-
291-
# When neither is available
292-
response4 = GenerationResponse(request_id="test", request_args={})
293-
assert response4.output_tokens is None
294-
295-
@pytest.mark.smoke
296-
def test_total_tokens_property(self):
297-
"""Test total_tokens property calculation."""
298-
# When both prompt and output tokens are available
299-
response1 = GenerationResponse(
300-
request_id="test",
301-
request_args={},
302-
response_prompt_tokens=50,
303-
response_output_tokens=100,
304-
)
305-
assert response1.total_tokens == 150
306-
307-
# When one is missing
308-
response2 = GenerationResponse(
309-
request_id="test", request_args={}, response_prompt_tokens=50
310-
)
311-
assert response2.total_tokens is None
312-
313-
# When both are missing
314-
response3 = GenerationResponse(request_id="test", request_args={})
315-
assert response3.total_tokens is None
233+
def test_compile_stats_method(self):
234+
"""Test compile_stats method functionality."""
235+
from guidellm.schemas.request import GenerationRequestArguments
316236

317-
@pytest.mark.smoke
318-
@pytest.mark.parametrize(
319-
("preferred_source", "expected_prompt", "expected_output"),
320-
[
321-
("request", 50, 100),
322-
("response", 55, 95),
323-
],
324-
)
325-
def test_preferred_token_methods(
326-
self, preferred_source, expected_prompt, expected_output
327-
):
328-
"""Test preferred_*_tokens methods."""
329237
response = GenerationResponse(
330-
request_id="test",
331-
request_args={},
332-
request_prompt_tokens=50,
333-
request_output_tokens=100,
334-
response_prompt_tokens=55,
335-
response_output_tokens=95,
238+
request_id="test-123", request_args="test_args", text="Generated response"
336239
)
337240

338-
assert response.preferred_prompt_tokens(preferred_source) == expected_prompt
339-
assert response.preferred_output_tokens(preferred_source) == expected_output
340-
341-
@pytest.mark.regression
342-
def test_preferred_tokens_fallback(self):
343-
"""Test preferred_*_tokens methods with fallback logic."""
344-
# Only response tokens available
345-
response1 = GenerationResponse(
346-
request_id="test",
347-
request_args={},
348-
response_prompt_tokens=55,
349-
response_output_tokens=95,
241+
request = GenerationRequest(
242+
request_id="test-123",
243+
request_type="text_completions",
244+
arguments=GenerationRequestArguments(),
350245
)
351246

352-
assert response1.preferred_prompt_tokens("request") == 55 # Falls back
353-
assert response1.preferred_output_tokens("request") == 95 # Falls back
354-
355-
# Only request tokens available
356-
response2 = GenerationResponse(
357-
request_id="test",
358-
request_args={},
359-
request_prompt_tokens=50,
360-
request_output_tokens=100,
361-
)
247+
request_info = RequestInfo(request_id="test-123")
362248

363-
assert response2.preferred_prompt_tokens("response") == 50 # Falls back
364-
assert response2.preferred_output_tokens("response") == 100 # Falls back
249+
# Test that compile_stats works
250+
stats = response.compile_stats(request, request_info)
251+
assert stats is not None
252+
assert hasattr(stats, "request_id")
253+
assert stats.request_id == "test-123"
365254

366255
@pytest.mark.sanity
367256
def test_marshalling(self, valid_instances):
@@ -376,8 +265,8 @@ def test_marshalling(self, valid_instances):
376265
reconstructed = GenerationResponse.model_validate(data_dict)
377266
assert reconstructed.request_id == instance.request_id
378267
assert reconstructed.request_args == instance.request_args
379-
assert reconstructed.value == instance.value
380-
assert reconstructed.iterations == instance.iterations
268+
if hasattr(instance, "text"):
269+
assert reconstructed.text == instance.text
381270

382271

383272
class TestRequestTimings:

0 commit comments

Comments
 (0)