11"""
2- Unit tests for GenerationRequest, GenerationResponse, GenerationRequestTimings .
2+ Unit tests for GenerationRequest, GenerationResponse, RequestTimings .
33"""
44
55from __future__ import annotations
99import pytest
1010from pydantic import ValidationError
1111
12- from guidellm .scheduler import MeasuredRequestTimings
13- from guidellm .schemas .response import (
12+ from guidellm .schemas import (
1413 GenerationRequest ,
15- GenerationRequestTimings ,
1614 GenerationResponse ,
15+ RequestTimings ,
1716)
17+ from guidellm .schemas .request import GenerationRequestArguments
1818from guidellm .utils import StandardBaseModel
1919
2020
@@ -23,17 +23,18 @@ class TestGenerationRequest:
2323
2424 @pytest .fixture (
2525 params = [
26- {"content" : "test content" },
2726 {
28- "content" : ["message1" , "message2" ],
27+ "request_type" : "text_completions" ,
28+ "arguments" : GenerationRequestArguments (),
29+ },
30+ {
2931 "request_type" : "chat_completions" ,
30- "params " : {"temperature" : 0.7 },
32+ "arguments " : GenerationRequestArguments ( body = {"temperature" : 0.7 }) ,
3133 },
3234 {
3335 "request_id" : "custom-id" ,
34- "content" : {"role" : "user" , "content" : "test" },
35- "stats" : {"prompt_tokens" : 50 },
36- "constraints" : {"output_tokens" : 100 },
36+ "request_type" : "text_completions" ,
37+ "arguments" : GenerationRequestArguments (body = {"prompt" : "test" }),
3738 },
3839 ]
3940 )
@@ -55,10 +56,9 @@ def test_class_signatures(self):
5556 expected_fields = [
5657 "request_id" ,
5758 "request_type" ,
58- "content" ,
59- "params" ,
60- "stats" ,
61- "constraints" ,
59+ "arguments" ,
60+ "input_metrics" ,
61+ "output_metrics" ,
6262 ]
6363 for field in expected_fields :
6464 assert field in fields
@@ -68,7 +68,7 @@ def test_initialization(self, valid_instances):
6868 """Test GenerationRequest initialization."""
6969 instance , constructor_args = valid_instances
7070 assert isinstance (instance , GenerationRequest )
71- assert instance .content == constructor_args ["content " ]
71+ assert instance .arguments == constructor_args ["arguments " ]
7272
7373 # Check defaults
7474 expected_request_type = constructor_args .get ("request_type" , "text_completions" )
@@ -84,21 +84,25 @@ def test_initialization(self, valid_instances):
8484 @pytest .mark .sanity
8585 def test_invalid_initialization_values (self ):
8686 """Test GenerationRequest with invalid field values."""
87- # Invalid request_type
87+ # Invalid request_type (not a string)
8888 with pytest .raises (ValidationError ):
89- GenerationRequest (content = "test" , request_type = "invalid_type" )
89+ GenerationRequest (request_type = 123 , arguments = GenerationRequestArguments () )
9090
9191 @pytest .mark .sanity
9292 def test_invalid_initialization_missing (self ):
9393 """Test GenerationRequest initialization without required field."""
9494 with pytest .raises (ValidationError ):
95- GenerationRequest () # Missing required 'content ' field
95+ GenerationRequest () # Missing required 'request_type ' field
9696
9797 @pytest .mark .smoke
9898 def test_auto_id_generation (self ):
9999 """Test that request_id is auto-generated if not provided."""
100- request1 = GenerationRequest (content = "test1" )
101- request2 = GenerationRequest (content = "test2" )
100+ request1 = GenerationRequest (
101+ request_type = "text_completions" , arguments = GenerationRequestArguments ()
102+ )
103+ request2 = GenerationRequest (
104+ request_type = "text_completions" , arguments = GenerationRequestArguments ()
105+ )
102106
103107 assert request1 .request_id != request2 .request_id
104108 assert len (request1 .request_id ) > 0
@@ -110,31 +114,40 @@ def test_auto_id_generation(self):
110114
111115 @pytest .mark .regression
112116 def test_content_types (self ):
113- """Test GenerationRequest with different content types."""
114- # String content
115- request1 = GenerationRequest (content = "string content" )
116- assert request1 .content == "string content"
117-
118- # List content
119- request2 = GenerationRequest (content = ["item1" , "item2" ])
120- assert request2 .content == ["item1" , "item2" ]
117+ """Test GenerationRequest with different argument types."""
118+ # Basic arguments
119+ request1 = GenerationRequest (
120+ request_type = "text_completions" , arguments = GenerationRequestArguments ()
121+ )
122+ assert isinstance (request1 .arguments , GenerationRequestArguments )
121123
122- # Dict content
123- dict_content = {"role" : "user" , "content" : "test" }
124- request3 = GenerationRequest (content = dict_content )
125- assert request3 .content == dict_content
124+ # Arguments with body
125+ request2 = GenerationRequest (
126+ request_type = "chat_completions" ,
127+ arguments = GenerationRequestArguments (body = {"prompt" : "test" }),
128+ )
129+ assert request2 .arguments .body == {"prompt" : "test" }
130+
131+ # Arguments with headers
132+ request3 = GenerationRequest (
133+ request_type = "text_completions" ,
134+ arguments = GenerationRequestArguments (
135+ headers = {"Authorization" : "Bearer token" }
136+ ),
137+ )
138+ assert request3 .arguments .headers == {"Authorization" : "Bearer token" }
126139
127140 @pytest .mark .sanity
128141 def test_marshalling (self , valid_instances ):
129142 """Test GenerationRequest serialization and deserialization."""
130143 instance , constructor_args = valid_instances
131144 data_dict = instance .model_dump ()
132145 assert isinstance (data_dict , dict )
133- assert data_dict [ "content" ] == constructor_args [ "content" ]
146+ assert "arguments" in data_dict
134147
135148 # Test reconstruction
136149 reconstructed = GenerationRequest .model_validate (data_dict )
137- assert reconstructed .content == instance .content
150+ assert reconstructed .arguments == instance .arguments
138151 assert reconstructed .request_type == instance .request_type
139152 assert reconstructed .request_id == instance .request_id
140153
@@ -146,18 +159,12 @@ class TestGenerationResponse:
146159 params = [
147160 {
148161 "request_id" : "test-123" ,
149- "request_args" : { "model" : " gpt-3.5-turbo"} ,
162+ "request_args" : "model= gpt-3.5-turbo" ,
150163 },
151164 {
152165 "request_id" : "test-456" ,
153- "request_args" : {"model" : "gpt-4" },
154- "value" : "Generated text" ,
155- "delta" : "new text" ,
156- "iterations" : 5 ,
157- "request_prompt_tokens" : 50 ,
158- "request_output_tokens" : 100 ,
159- "response_prompt_tokens" : 55 ,
160- "response_output_tokens" : 95 ,
166+ "request_args" : "model=gpt-4" ,
167+ "text" : "Generated text" ,
161168 },
162169 ]
163170 )
@@ -373,8 +380,8 @@ def test_marshalling(self, valid_instances):
373380 assert reconstructed .iterations == instance .iterations
374381
375382
376- class TestGenerationRequestTimings :
377- """Test cases for GenerationRequestTimings model."""
383+ class TestRequestTimings :
384+ """Test cases for RequestTimings model."""
378385
379386 @pytest .fixture (
380387 params = [
@@ -388,20 +395,20 @@ class TestGenerationRequestTimings:
388395 ]
389396 )
390397 def valid_instances (self , request ):
391- """Fixture providing valid GenerationRequestTimings instances."""
398+ """Fixture providing valid RequestTimings instances."""
392399 constructor_args = request .param
393- instance = GenerationRequestTimings (** constructor_args )
400+ instance = RequestTimings (** constructor_args )
394401 return instance , constructor_args
395402
396403 @pytest .mark .smoke
397404 def test_class_signatures (self ):
398- """Test GenerationRequestTimings inheritance and type relationships."""
399- assert issubclass (GenerationRequestTimings , MeasuredRequestTimings )
400- assert hasattr (GenerationRequestTimings , "model_dump" )
401- assert hasattr (GenerationRequestTimings , "model_validate" )
405+ """Test RequestTimings inheritance and type relationships."""
406+ assert issubclass (RequestTimings , RequestTimings )
407+ assert hasattr (RequestTimings , "model_dump" )
408+ assert hasattr (RequestTimings , "model_validate" )
402409
403- # Check inherited fields from MeasuredRequestTimings
404- fields = GenerationRequestTimings .model_fields
410+ # Check inherited fields from RequestTimings
411+ fields = RequestTimings .model_fields
405412 expected_inherited_fields = ["request_start" , "request_end" ]
406413 for field in expected_inherited_fields :
407414 assert field in fields
@@ -413,10 +420,10 @@ def test_class_signatures(self):
413420
414421 @pytest .mark .smoke
415422 def test_initialization (self , valid_instances ):
416- """Test GenerationRequestTimings initialization."""
423+ """Test RequestTimings initialization."""
417424 instance , constructor_args = valid_instances
418- assert isinstance (instance , GenerationRequestTimings )
419- assert isinstance (instance , MeasuredRequestTimings )
425+ assert isinstance (instance , RequestTimings )
426+ assert isinstance (instance , RequestTimings )
420427
421428 # Check field values
422429 expected_first = constructor_args .get ("first_iteration" )
@@ -426,40 +433,40 @@ def test_initialization(self, valid_instances):
426433
427434 @pytest .mark .sanity
428435 def test_invalid_initialization_values (self ):
429- """Test GenerationRequestTimings with invalid field values."""
436+ """Test RequestTimings with invalid field values."""
430437 # Invalid timestamp type
431438 with pytest .raises (ValidationError ):
432- GenerationRequestTimings (first_iteration = "not_float" )
439+ RequestTimings (first_iteration = "not_float" )
433440
434441 with pytest .raises (ValidationError ):
435- GenerationRequestTimings (last_iteration = "not_float" )
442+ RequestTimings (last_iteration = "not_float" )
436443
437444 @pytest .mark .smoke
438445 def test_optional_fields (self ):
439446 """Test that all timing fields are optional."""
440447 # Should be able to create with no fields
441- timings1 = GenerationRequestTimings ()
448+ timings1 = RequestTimings ()
442449 assert timings1 .first_iteration is None
443450 assert timings1 .last_iteration is None
444451
445452 # Should be able to create with only one field
446- timings2 = GenerationRequestTimings (first_iteration = 123.0 )
453+ timings2 = RequestTimings (first_iteration = 123.0 )
447454 assert timings2 .first_iteration == 123.0
448455 assert timings2 .last_iteration is None
449456
450- timings3 = GenerationRequestTimings (last_iteration = 456.0 )
457+ timings3 = RequestTimings (last_iteration = 456.0 )
451458 assert timings3 .first_iteration is None
452459 assert timings3 .last_iteration == 456.0
453460
454461 @pytest .mark .sanity
455462 def test_marshalling (self , valid_instances ):
456- """Test GenerationRequestTimings serialization and deserialization."""
463+ """Test RequestTimings serialization and deserialization."""
457464 instance , constructor_args = valid_instances
458465 data_dict = instance .model_dump ()
459466 assert isinstance (data_dict , dict )
460467
461468 # Test reconstruction
462- reconstructed = GenerationRequestTimings .model_validate (data_dict )
469+ reconstructed = RequestTimings .model_validate (data_dict )
463470 assert reconstructed .first_iteration == instance .first_iteration
464471 assert reconstructed .last_iteration == instance .last_iteration
465472 assert reconstructed .request_start == instance .request_start
0 commit comments