@@ -1,16 +1,16 @@
 import itertools
 from abc import ABC, abstractmethod
-from typing import Generic, TypeVar
+from collections.abc import Sequence
+from typing import Generic
 
 from guidellm.backend.response import ResponseSummary
+from guidellm.config import settings
+from guidellm.preprocess.item import Item, ItemList
 from guidellm.request.request import GenerationRequest
+from guidellm.request.types import RequestT, ResponseT
 
 __all__ = ["GenerativeRequestSession", "RequestSession"]
 
-# TODO: Replace with specific types that implement needed features
-RequestT = TypeVar("RequestT")
-ResponseT = TypeVar("ResponseT")
-
 
 class RequestSession(ABC, Generic[RequestT, ResponseT]):
     @abstractmethod
@@ -30,46 +30,61 @@ def push_response(self, response: ResponseT) -> None: ...
     def complete(self) -> bool: ...
 
 
-# FIXME: Bad implementation. Can only handle string requests
 class GenerativeRequestSession(RequestSession[GenerationRequest, ResponseSummary]):
-    def __init__(self, prompts: list[GenerationRequest]) -> None:
-        if not prompts:
+    def __init__(self, items: ItemList) -> None:
+        if len(items) < 1:
             raise ValueError("Prompts cannot be empty")
 
-        self.prompts = prompts
-        self.responses: list[str] = []
+        self.prompts: Sequence[Item] = items
+        self.responses: list[Item] = []
 
     def __len__(self) -> int:
         return len(self.prompts)
 
     def get_next_request(self) -> GenerationRequest:
         completed_responses = len(self.responses)
-        base_request = self.prompts[completed_responses].model_copy(deep=True)
-        base_request.content = "".join(
+
+        # FIXME: Can only handle string requests
+        content = "".join(
             itertools.chain.from_iterable(
-                zip((x.content for x in self.prompts), self.responses + [""])
+                (x.value, y.value)
+                for x, y in zip(self.prompts, self.responses + [Item(value="")])
             )
         )
-        base_request.stats["prompt_tokens"] = sum(
-            x.stats["prompt_tokens"] for x in self.prompts[: completed_responses + 1]
+
+        prev_prompt_tokens = sum(
+            (x.prompt_tokens or 0) + (x.output_tokens or 0) for x in self.responses
         )
-        base_request.constraints["output_tokens"] = sum(
-            x.constraints["output_tokens"]
-            for x in self.prompts[: completed_responses + 1]
+        prompt_tokens = (
+            self.prompts[completed_responses].prompt_tokens or 0
+        ) + prev_prompt_tokens
+
+        output_tokens = self.prompts[completed_responses].output_tokens
+
+        return GenerationRequest(
+            request_type=settings.preferred_route,
+            content=content,
+            stats=(
+                {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
+            ),
+            constraints=(
+                {"output_tokens": output_tokens} if output_tokens is not None else {}
+            ),
         )
 
-        return base_request
-
     def get_next_delay(self) -> float:
         return 0.0
 
     def push_response(self, response: ResponseSummary) -> None:
         if len(self.responses) < len(self.prompts):
-            if response.response_output_tokens is not None:
-                self.prompts[len(self.responses)].constraints["output_tokens"] = (
-                    response.response_output_tokens
-                )
-            self.responses.append(response.value)
+            resp = Item(
+                value=response.value,
+                prompt_tokens=response.response_prompt_tokens
+                or response.request_prompt_tokens,
+                output_tokens=response.response_output_tokens
+                or response.request_output_tokens,
+            )
+            self.responses.append(resp)
         else:
             raise ValueError("Response list full")
 
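A note on the interleaving in the new `get_next_request`: `zip` truncates to the shorter of its two arguments, so pairing all prompts against the completed responses plus one empty `Item` sentinel yields exactly the turns up to and including the next prompt, alternating prompt and response text and ending on the new prompt. A minimal sketch of just that pattern, using a stand-in `Item` (the real class lives in `guidellm.preprocess.item`; only the `value` field visible in this diff is assumed):

```python
from dataclasses import dataclass
from itertools import chain


@dataclass
class Item:
    # Stand-in for guidellm.preprocess.item.Item, reduced to the one
    # field this sketch needs.
    value: str


prompts = [Item("What is 2+2?"), Item(" And times 3?")]
responses = [Item(" 4.")]  # one turn already completed

# Same shape as get_next_request: prompt1, response1, prompt2, ...
# zip() stops at the shorter sequence, so prompts past the upcoming turn
# are dropped, and the empty sentinel makes the text end on the new prompt.
content = "".join(
    chain.from_iterable(
        (x.value, y.value) for x, y in zip(prompts, responses + [Item(value="")])
    )
)
assert content == "What is 2+2? 4. And times 3?"
```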
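The token accounting follows from the same multi-turn model: each request resends the whole conversation, so the prompt-token estimate for the next turn is its own prompt tokens plus the prompt and output tokens of every completed turn, while the output-token constraint is just the next turn's own target. A worked version of that arithmetic, with illustrative numbers:

```python
# Completed turns, as (prompt_tokens, output_tokens) captured by push_response.
responses = [(12, 40), (8, 25)]
next_prompt_own_tokens = 10  # prompt_tokens of self.prompts[completed_responses]

# Prior prompts and prior outputs are both part of the context that gets
# resent, so both count toward the next request's prompt size.
prev_prompt_tokens = sum(p + o for p, o in responses)        # 52 + 33 = 85
prompt_tokens = next_prompt_own_tokens + prev_prompt_tokens  # 10 + 85 = 95
assert prompt_tokens == 95
```

Note that `push_response` prefers the backend-reported counts and falls back to the requested ones (`response.response_prompt_tokens or response.request_prompt_tokens`), so the sums above degrade gracefully when the backend does not report usage.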