1
1
import openai
2
2
from typing import Iterator , List , Optional , Dict , Any
3
- from urllib .parse import urlparse
4
3
from transformers import AutoTokenizer
5
4
from loguru import logger
6
5
from guidellm .backend import Backend , BackendTypes , GenerativeResponse
@@ -24,8 +23,10 @@ class OpenAIBackend(Backend):
24
23
:type path: Optional[str]
25
24
:param model: The OpenAI model to use, defaults to the first available model.
26
25
:type model: Optional[str]
27
- :param model_args: Additional model arguments for the request.
28
- :type model_args: Optional[Dict[str, Any]]
26
+ :param api_key: The OpenAI API key to use.
27
+ :type api_key: Optional[str]
28
+ :param request_args: Optional arguments for the OpenAI request.
29
+ :type request_args: Dict[str, Any]
29
30
"""
30
31
31
32
def __init__ (
@@ -35,21 +36,30 @@ def __init__(
35
36
port : Optional [int ] = None ,
36
37
path : Optional [str ] = None ,
37
38
model : Optional [str ] = None ,
38
- ** model_args ,
39
+ api_key : Optional [str ] = None ,
40
+ ** request_args ,
39
41
):
40
- if target :
41
- parsed_url = urlparse (target )
42
- self .host = parsed_url .hostname
43
- self .port = parsed_url .port
44
- self .path = parsed_url .path
45
- else :
46
- self .host = host
47
- self .port = port
48
- self .path = path
42
+ self .target = target
49
43
self .model = model
50
- self .model_args = model_args
51
- openai .api_key = model_args .get ("api_key" , None )
52
- logger .info (f"Initialized OpenAIBackend with model: { self .model } " )
44
+ self .request_args = request_args
45
+
46
+ if not self .target :
47
+ if not host :
48
+ raise ValueError ("Host is required if target is not provided." )
49
+
50
+ port_incl = f":{ port } " if port else ""
51
+ path_incl = path if path else ""
52
+ self .target = f"http://{ host } { port_incl } { path_incl } "
53
+
54
+ openai .api_base = self .target
55
+ openai .api_key = api_key
56
+
57
+ if not model :
58
+ self .model = self .default_model ()
59
+
60
+ logger .info (
61
+ f"Initialized OpenAIBackend with target: { self .target } and model: { self .model } "
62
+ )
53
63
54
64
def make_request (self , request : BenchmarkRequest ) -> Iterator [GenerativeResponse ]:
55
65
"""
@@ -61,14 +71,20 @@ def make_request(self, request: BenchmarkRequest) -> Iterator[GenerativeResponse
61
71
:rtype: Iterator[GenerativeResponse]
62
72
"""
63
73
logger .debug (f"Making request to OpenAI backend with prompt: { request .prompt } " )
74
+ num_gen_tokens = request .params .get ("generated_tokens" , None )
75
+ request_args = {
76
+ "n" : 1 ,
77
+ }
78
+
79
+ if num_gen_tokens :
80
+ request_args ["max_tokens" ] = num_gen_tokens
81
+ request_args ["stop" ] = None
82
+
83
+ if self .request_args :
84
+ request_args .update (self .request_args )
85
+
64
86
response = openai .Completion .create (
65
- engine = self .model or self .default_model (),
66
- prompt = request .prompt ,
67
- max_tokens = request .params .get ("max_tokens" , 100 ),
68
- n = request .params .get ("n" , 1 ),
69
- stop = request .params .get ("stop" , None ),
70
- stream = True ,
71
- ** self .model_args ,
87
+ engine = self .model , prompt = request .prompt , stream = True , ** request_args ,
72
88
)
73
89
74
90
for chunk in response :
@@ -80,8 +96,16 @@ def make_request(self, request: BenchmarkRequest) -> Iterator[GenerativeResponse
80
96
type_ = "final" ,
81
97
output = choice ["text" ],
82
98
prompt = request .prompt ,
83
- prompt_token_count = self ._token_count (request .prompt ),
84
- output_token_count = self ._token_count (choice ["text" ]),
99
+ prompt_token_count = (
100
+ request .token_count
101
+ if request .token_count
102
+ else self ._token_count (request .prompt )
103
+ ),
104
+ output_token_count = (
105
+ num_gen_tokens
106
+ if num_gen_tokens
107
+ else self ._token_count (choice ["text" ])
108
+ ),
85
109
)
86
110
break
87
111
else :
@@ -133,14 +157,6 @@ def model_tokenizer(self, model: str) -> Optional[Any]:
133
157
return None
134
158
135
159
def _token_count (self , text : str ) -> int :
136
- """
137
- Count the number of tokens in a text.
138
-
139
- :param text: The text to tokenize.
140
- :type text: str
141
- :return: The number of tokens.
142
- :rtype: int
143
- """
144
160
token_count = len (text .split ())
145
161
logger .debug (f"Token count for text '{ text } ': { token_count } " )
146
162
return token_count
0 commit comments