Skip to content

Commit 870f195

Browse files
committed
Ensure all tests can run
Signed-off-by: Samuel Monson <[email protected]>
1 parent e1aedd3 commit 870f195

14 files changed

+371
-486
lines changed

tests/unit/backends/test_backend.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
import pytest
1212

1313
from guidellm.backends.backend import Backend, BackendType
14-
from guidellm.scheduler import BackendInterface, ScheduledRequestInfo
15-
from guidellm.schemas.response import (
14+
from guidellm.scheduler import BackendInterface
15+
from guidellm.schemas import (
1616
GenerationRequest,
17-
GenerationRequestTimings,
17+
RequestTimings,
1818
)
1919
from guidellm.utils import RegistryMixin
2020
from tests.unit.testing_utils import async_timeout
@@ -41,6 +41,7 @@ def valid_instances(self, request):
4141
constructor_args = request.param
4242

4343
class TestBackend(Backend):
44+
@property
4445
def info(self) -> dict[str, Any]:
4546
return {"type": self.type_}
4647

@@ -100,6 +101,7 @@ def test_invalid_initialization_values(self, field, value):
100101
"""Test Backend with invalid field values."""
101102

102103
class TestBackend(Backend):
104+
@property
103105
def info(self) -> dict[str, Any]:
104106
return {}
105107

@@ -154,7 +156,7 @@ async def test_interface_compatibility(self, valid_instances):
154156
scheduler_node_id=1,
155157
scheduler_process_id=1,
156158
scheduler_start_time=123.0,
157-
request_timings=GenerationRequestTimings(),
159+
request_timings=RequestTimings(),
158160
)
159161

160162
# Test resolve method

tests/unit/backends/test_objects.py

Lines changed: 71 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Unit tests for GenerationRequest, GenerationResponse, GenerationRequestTimings.
2+
Unit tests for GenerationRequest, GenerationResponse, RequestTimings.
33
"""
44

55
from __future__ import annotations
@@ -9,12 +9,12 @@
99
import pytest
1010
from pydantic import ValidationError
1111

12-
from guidellm.scheduler import MeasuredRequestTimings
13-
from guidellm.schemas.response import (
12+
from guidellm.schemas import (
1413
GenerationRequest,
15-
GenerationRequestTimings,
1614
GenerationResponse,
15+
RequestTimings,
1716
)
17+
from guidellm.schemas.request import GenerationRequestArguments
1818
from guidellm.utils import StandardBaseModel
1919

2020

@@ -23,17 +23,18 @@ class TestGenerationRequest:
2323

2424
@pytest.fixture(
2525
params=[
26-
{"content": "test content"},
2726
{
28-
"content": ["message1", "message2"],
27+
"request_type": "text_completions",
28+
"arguments": GenerationRequestArguments(),
29+
},
30+
{
2931
"request_type": "chat_completions",
30-
"params": {"temperature": 0.7},
32+
"arguments": GenerationRequestArguments(body={"temperature": 0.7}),
3133
},
3234
{
3335
"request_id": "custom-id",
34-
"content": {"role": "user", "content": "test"},
35-
"stats": {"prompt_tokens": 50},
36-
"constraints": {"output_tokens": 100},
36+
"request_type": "text_completions",
37+
"arguments": GenerationRequestArguments(body={"prompt": "test"}),
3738
},
3839
]
3940
)
@@ -55,10 +56,9 @@ def test_class_signatures(self):
5556
expected_fields = [
5657
"request_id",
5758
"request_type",
58-
"content",
59-
"params",
60-
"stats",
61-
"constraints",
59+
"arguments",
60+
"input_metrics",
61+
"output_metrics",
6262
]
6363
for field in expected_fields:
6464
assert field in fields
@@ -68,7 +68,7 @@ def test_initialization(self, valid_instances):
6868
"""Test GenerationRequest initialization."""
6969
instance, constructor_args = valid_instances
7070
assert isinstance(instance, GenerationRequest)
71-
assert instance.content == constructor_args["content"]
71+
assert instance.arguments == constructor_args["arguments"]
7272

7373
# Check defaults
7474
expected_request_type = constructor_args.get("request_type", "text_completions")
@@ -84,21 +84,25 @@ def test_initialization(self, valid_instances):
8484
@pytest.mark.sanity
8585
def test_invalid_initialization_values(self):
8686
"""Test GenerationRequest with invalid field values."""
87-
# Invalid request_type
87+
# Invalid request_type (not a string)
8888
with pytest.raises(ValidationError):
89-
GenerationRequest(content="test", request_type="invalid_type")
89+
GenerationRequest(request_type=123, arguments=GenerationRequestArguments())
9090

9191
@pytest.mark.sanity
9292
def test_invalid_initialization_missing(self):
9393
"""Test GenerationRequest initialization without required field."""
9494
with pytest.raises(ValidationError):
95-
GenerationRequest() # Missing required 'content' field
95+
GenerationRequest() # Missing required 'request_type' field
9696

9797
@pytest.mark.smoke
9898
def test_auto_id_generation(self):
9999
"""Test that request_id is auto-generated if not provided."""
100-
request1 = GenerationRequest(content="test1")
101-
request2 = GenerationRequest(content="test2")
100+
request1 = GenerationRequest(
101+
request_type="text_completions", arguments=GenerationRequestArguments()
102+
)
103+
request2 = GenerationRequest(
104+
request_type="text_completions", arguments=GenerationRequestArguments()
105+
)
102106

103107
assert request1.request_id != request2.request_id
104108
assert len(request1.request_id) > 0
@@ -110,31 +114,40 @@ def test_auto_id_generation(self):
110114

111115
@pytest.mark.regression
112116
def test_content_types(self):
113-
"""Test GenerationRequest with different content types."""
114-
# String content
115-
request1 = GenerationRequest(content="string content")
116-
assert request1.content == "string content"
117-
118-
# List content
119-
request2 = GenerationRequest(content=["item1", "item2"])
120-
assert request2.content == ["item1", "item2"]
117+
"""Test GenerationRequest with different argument types."""
118+
# Basic arguments
119+
request1 = GenerationRequest(
120+
request_type="text_completions", arguments=GenerationRequestArguments()
121+
)
122+
assert isinstance(request1.arguments, GenerationRequestArguments)
121123

122-
# Dict content
123-
dict_content = {"role": "user", "content": "test"}
124-
request3 = GenerationRequest(content=dict_content)
125-
assert request3.content == dict_content
124+
# Arguments with body
125+
request2 = GenerationRequest(
126+
request_type="chat_completions",
127+
arguments=GenerationRequestArguments(body={"prompt": "test"}),
128+
)
129+
assert request2.arguments.body == {"prompt": "test"}
130+
131+
# Arguments with headers
132+
request3 = GenerationRequest(
133+
request_type="text_completions",
134+
arguments=GenerationRequestArguments(
135+
headers={"Authorization": "Bearer token"}
136+
),
137+
)
138+
assert request3.arguments.headers == {"Authorization": "Bearer token"}
126139

127140
@pytest.mark.sanity
128141
def test_marshalling(self, valid_instances):
129142
"""Test GenerationRequest serialization and deserialization."""
130143
instance, constructor_args = valid_instances
131144
data_dict = instance.model_dump()
132145
assert isinstance(data_dict, dict)
133-
assert data_dict["content"] == constructor_args["content"]
146+
assert "arguments" in data_dict
134147

135148
# Test reconstruction
136149
reconstructed = GenerationRequest.model_validate(data_dict)
137-
assert reconstructed.content == instance.content
150+
assert reconstructed.arguments == instance.arguments
138151
assert reconstructed.request_type == instance.request_type
139152
assert reconstructed.request_id == instance.request_id
140153

@@ -146,18 +159,12 @@ class TestGenerationResponse:
146159
params=[
147160
{
148161
"request_id": "test-123",
149-
"request_args": {"model": "gpt-3.5-turbo"},
162+
"request_args": "model=gpt-3.5-turbo",
150163
},
151164
{
152165
"request_id": "test-456",
153-
"request_args": {"model": "gpt-4"},
154-
"value": "Generated text",
155-
"delta": "new text",
156-
"iterations": 5,
157-
"request_prompt_tokens": 50,
158-
"request_output_tokens": 100,
159-
"response_prompt_tokens": 55,
160-
"response_output_tokens": 95,
166+
"request_args": "model=gpt-4",
167+
"text": "Generated text",
161168
},
162169
]
163170
)
@@ -373,8 +380,8 @@ def test_marshalling(self, valid_instances):
373380
assert reconstructed.iterations == instance.iterations
374381

375382

376-
class TestGenerationRequestTimings:
377-
"""Test cases for GenerationRequestTimings model."""
383+
class TestRequestTimings:
384+
"""Test cases for RequestTimings model."""
378385

379386
@pytest.fixture(
380387
params=[
@@ -388,20 +395,20 @@ class TestGenerationRequestTimings:
388395
]
389396
)
390397
def valid_instances(self, request):
391-
"""Fixture providing valid GenerationRequestTimings instances."""
398+
"""Fixture providing valid RequestTimings instances."""
392399
constructor_args = request.param
393-
instance = GenerationRequestTimings(**constructor_args)
400+
instance = RequestTimings(**constructor_args)
394401
return instance, constructor_args
395402

396403
@pytest.mark.smoke
397404
def test_class_signatures(self):
398-
"""Test GenerationRequestTimings inheritance and type relationships."""
399-
assert issubclass(GenerationRequestTimings, MeasuredRequestTimings)
400-
assert hasattr(GenerationRequestTimings, "model_dump")
401-
assert hasattr(GenerationRequestTimings, "model_validate")
405+
"""Test RequestTimings inheritance and type relationships."""
406+
assert issubclass(RequestTimings, RequestTimings)
407+
assert hasattr(RequestTimings, "model_dump")
408+
assert hasattr(RequestTimings, "model_validate")
402409

403-
# Check inherited fields from MeasuredRequestTimings
404-
fields = GenerationRequestTimings.model_fields
410+
# Check inherited fields from RequestTimings
411+
fields = RequestTimings.model_fields
405412
expected_inherited_fields = ["request_start", "request_end"]
406413
for field in expected_inherited_fields:
407414
assert field in fields
@@ -413,10 +420,10 @@ def test_class_signatures(self):
413420

414421
@pytest.mark.smoke
415422
def test_initialization(self, valid_instances):
416-
"""Test GenerationRequestTimings initialization."""
423+
"""Test RequestTimings initialization."""
417424
instance, constructor_args = valid_instances
418-
assert isinstance(instance, GenerationRequestTimings)
419-
assert isinstance(instance, MeasuredRequestTimings)
425+
assert isinstance(instance, RequestTimings)
426+
assert isinstance(instance, RequestTimings)
420427

421428
# Check field values
422429
expected_first = constructor_args.get("first_iteration")
@@ -426,40 +433,40 @@ def test_initialization(self, valid_instances):
426433

427434
@pytest.mark.sanity
428435
def test_invalid_initialization_values(self):
429-
"""Test GenerationRequestTimings with invalid field values."""
436+
"""Test RequestTimings with invalid field values."""
430437
# Invalid timestamp type
431438
with pytest.raises(ValidationError):
432-
GenerationRequestTimings(first_iteration="not_float")
439+
RequestTimings(first_iteration="not_float")
433440

434441
with pytest.raises(ValidationError):
435-
GenerationRequestTimings(last_iteration="not_float")
442+
RequestTimings(last_iteration="not_float")
436443

437444
@pytest.mark.smoke
438445
def test_optional_fields(self):
439446
"""Test that all timing fields are optional."""
440447
# Should be able to create with no fields
441-
timings1 = GenerationRequestTimings()
448+
timings1 = RequestTimings()
442449
assert timings1.first_iteration is None
443450
assert timings1.last_iteration is None
444451

445452
# Should be able to create with only one field
446-
timings2 = GenerationRequestTimings(first_iteration=123.0)
453+
timings2 = RequestTimings(first_iteration=123.0)
447454
assert timings2.first_iteration == 123.0
448455
assert timings2.last_iteration is None
449456

450-
timings3 = GenerationRequestTimings(last_iteration=456.0)
457+
timings3 = RequestTimings(last_iteration=456.0)
451458
assert timings3.first_iteration is None
452459
assert timings3.last_iteration == 456.0
453460

454461
@pytest.mark.sanity
455462
def test_marshalling(self, valid_instances):
456-
"""Test GenerationRequestTimings serialization and deserialization."""
463+
"""Test RequestTimings serialization and deserialization."""
457464
instance, constructor_args = valid_instances
458465
data_dict = instance.model_dump()
459466
assert isinstance(data_dict, dict)
460467

461468
# Test reconstruction
462-
reconstructed = GenerationRequestTimings.model_validate(data_dict)
469+
reconstructed = RequestTimings.model_validate(data_dict)
463470
assert reconstructed.first_iteration == instance.first_iteration
464471
assert reconstructed.last_iteration == instance.last_iteration
465472
assert reconstructed.request_start == instance.request_start

tests/unit/backends/test_openai_backend.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,14 @@ class TestOpenAIHTTPBackend:
4343
{
4444
"target": "https://api.openai.com",
4545
"model": "gpt-4",
46-
"api_key": "test-key",
4746
"timeout": 30.0,
48-
"stream_response": False,
4947
},
5048
{
5149
"target": "http://test-server:8080",
5250
"model": "test-model",
53-
"api_key": "Bearer test-token",
54-
"organization": "test-org",
55-
"project": "test-proj",
5651
"timeout": 120.0,
5752
"http2": False,
5853
"follow_redirects": False,
59-
"max_output_tokens": 500,
60-
"extra_query": {"param": "value"},
61-
"extra_body": {"setting": "test"},
62-
"remove_from_body": ["unwanted"],
63-
"headers": {"Custom": "header"},
6454
"verify": True,
6555
},
6656
]

tests/unit/mock_backend.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010

1111
from lorem.text import TextLorem
1212

13-
from guidellm.backend.backend import Backend
14-
from guidellm.backend.objects import (
13+
from guidellm.backends import Backend
14+
from guidellm.schemas import (
1515
GenerationRequest,
16-
GenerationRequestTimings,
1716
GenerationResponse,
17+
RequestInfo,
18+
RequestTimings,
1819
)
19-
from guidellm.scheduler import ScheduledRequestInfo
2020

2121

2222
@Backend.register("mock")
@@ -96,9 +96,9 @@ async def default_model(self) -> str | None:
9696
async def resolve(
9797
self,
9898
request: GenerationRequest,
99-
request_info: ScheduledRequestInfo,
99+
request_info: RequestInfo,
100100
history: list[tuple[GenerationRequest, GenerationResponse]] | None = None,
101-
) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]:
101+
) -> AsyncIterator[tuple[GenerationResponse, RequestInfo]]:
102102
"""
103103
Process a generation request and yield progressive responses.
104104
@@ -133,7 +133,7 @@ async def resolve(
133133
)
134134

135135
# Initialize timings
136-
request_info.request_timings = GenerationRequestTimings()
136+
request_info.request_timings = RequestTimings()
137137
request_info.request_timings.request_start = time.time()
138138

139139
# Generate response iteratively

0 commit comments

Comments
 (0)