@@ -1,15 +1,16 @@
 import itertools
 from abc import ABC, abstractmethod
-from typing import Generic, TypeVar
+from collections.abc import Sequence
+from typing import Generic
 
 from guidellm.backend.response import ResponseSummary
+from guidellm.config import settings
+from guidellm.preprocess.item import Item, ItemList
 from guidellm.request.request import GenerationRequest
+from guidellm.request.types import RequestT, ResponseT
 
 __all__ = ["GenerativeRequestSession", "RequestSession"]
 
-RequestT = TypeVar("RequestT")
-ResponseT = TypeVar("ResponseT")
-
 
 class RequestSession(ABC, Generic[RequestT, ResponseT]):
     """
@@ -35,44 +36,60 @@ def complete(self) -> bool: ...
 
 
 class GenerativeRequestSession(RequestSession[GenerationRequest, ResponseSummary]):
-    def __init__(self, prompts: list[GenerationRequest]) -> None:
-        if not prompts:
+    def __init__(self, items: ItemList) -> None:
+        if len(items) < 1:
             raise ValueError("Prompts cannot be empty")
 
-        self.prompts = prompts
-        self.responses: list[str] = []
+        self.prompts: Sequence[Item] = items
+        self.responses: list[Item] = []
 
     def __len__(self) -> int:
         return len(self.prompts)
 
     def get_next_request(self) -> GenerationRequest:
         completed_responses = len(self.responses)
-        base_request = self.prompts[completed_responses].model_copy(deep=True)
-        base_request.content = "".join(
+
+        # FIXME: Can only handle string requests
+        content = "".join(
             itertools.chain.from_iterable(
-                zip((x.content for x in self.prompts), self.responses + [""])
+                (x.value, y.value)
+                for x, y in zip(self.prompts, self.responses + [Item(value="")])
             )
         )
-        base_request.stats["prompt_tokens"] = sum(
-            x.stats["prompt_tokens"] for x in self.prompts[: completed_responses + 1]
+
+        prev_prompt_tokens = sum(
+            (x.prompt_tokens or 0) + (x.output_tokens or 0) for x in self.responses
         )
-        base_request.constraints["output_tokens"] = sum(
-            x.constraints["output_tokens"]
-            for x in self.prompts[: completed_responses + 1]
+        prompt_tokens = (
+            self.prompts[completed_responses].prompt_tokens or 0
+        ) + prev_prompt_tokens
+
+        output_tokens = self.prompts[completed_responses].output_tokens
+
+        return GenerationRequest(
+            request_type=settings.preferred_route,
+            content=content,
+            stats=(
+                {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
+            ),
+            constraints=(
+                {"output_tokens": output_tokens} if output_tokens is not None else {}
+            ),
         )
 
-        return base_request
-
     def get_next_delay(self) -> float:
         return 0.0
 
     def push_response(self, response: ResponseSummary) -> None:
         if len(self.responses) < len(self.prompts):
-            if response.response_output_tokens is not None:
-                self.prompts[len(self.responses)].constraints["output_tokens"] = (
-                    response.response_output_tokens
-                )
-            self.responses.append(response.value)
+            resp = Item(
+                value=response.value,
+                prompt_tokens=response.response_prompt_tokens
+                or response.request_prompt_tokens,
+                output_tokens=response.response_output_tokens
+                or response.request_output_tokens,
+            )
+            self.responses.append(resp)
         else:
             raise ValueError("Response list full")
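
For context, a minimal sketch of how a caller might drive the reworked session across turns. The ItemList constructor signature, the token-count values, and the send_to_backend helper are illustrative assumptions, not part of this commit:

# Hypothetical multi-turn driver. Assumes ItemList accepts Item objects
# positionally and that RequestSession.complete is exposed as a property;
# neither is confirmed by this diff.
from guidellm.backend.response import ResponseSummary
from guidellm.preprocess.item import Item, ItemList
# (import GenerativeRequestSession from the module patched above)

# Two user turns of a conversation, with per-turn token hints.
items = ItemList(
    Item(value="What is Python?", prompt_tokens=5, output_tokens=64),
    Item(value="Now show an example.", prompt_tokens=6, output_tokens=128),
)
session = GenerativeRequestSession(items)

while not session.complete:
    request = session.get_next_request()
    # request.content interleaves every prior prompt/response pair, and
    # stats["prompt_tokens"] accumulates prompt + output tokens of past
    # turns: turn 2 here carries 6 + (5 + 64) = 75 prompt tokens if the
    # backend echoed the hinted counts.
    summary: ResponseSummary = send_to_backend(request)  # placeholder backend call
    session.push_response(summary)

Note that push_response now records each turn as an Item, falling back from the backend-reported token counts to the requested ones, instead of mutating the stored prompts in place.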