Commit 297d480

General implementations for core functionality
1 parent 7b57e18 · commit 297d480

31 files changed: +1839 −114 lines

src/guidellm/backend/base.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Iterator, List, Optional, Type
+from typing import Iterator, List, Optional, Type, Union
 from dataclasses import dataclass
 import uuid
 from loguru import logger
@@ -52,7 +52,7 @@ def inner_wrapper(wrapped_class: Type["Backend"]):
         return inner_wrapper

     @staticmethod
-    def create_backend(backend_type: BackendTypes, **kwargs) -> "Backend":
+    def create_backend(backend_type: Union[str, BackendTypes], **kwargs) -> "Backend":
         """
         Factory method to create a backend based on the backend type.
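
With this change, create_backend accepts either a BackendTypes member or its plain string value. The hunk does not show the method body, so the following is only a plausible normalization sketch; it assumes BackendTypes is a string-valued Enum and that the registry populated by inner_wrapper above is a dict named Backend._registry:

    @staticmethod
    def create_backend(backend_type: Union[str, BackendTypes], **kwargs) -> "Backend":
        # Coerce a raw string to the matching enum member; Enum lookup
        # raises ValueError for unknown values.
        if isinstance(backend_type, str):
            backend_type = BackendTypes(backend_type)
        # Instantiate the class registered for this backend type (assumed dict).
        return Backend._registry[backend_type](**kwargs)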

src/guidellm/backend/openai.py

Lines changed: 49 additions & 33 deletions
@@ -1,6 +1,5 @@
 import openai
 from typing import Iterator, List, Optional, Dict, Any
-from urllib.parse import urlparse
 from transformers import AutoTokenizer
 from loguru import logger
 from guidellm.backend import Backend, BackendTypes, GenerativeResponse
@@ -24,8 +23,10 @@ class OpenAIBackend(Backend):
     :type path: Optional[str]
     :param model: The OpenAI model to use, defaults to the first available model.
     :type model: Optional[str]
-    :param model_args: Additional model arguments for the request.
-    :type model_args: Optional[Dict[str, Any]]
+    :param api_key: The OpenAI API key to use.
+    :type api_key: Optional[str]
+    :param request_args: Optional arguments for the OpenAI request.
+    :type request_args: Dict[str, Any]
     """

     def __init__(
@@ -35,21 +36,30 @@ def __init__(
         port: Optional[int] = None,
         path: Optional[str] = None,
         model: Optional[str] = None,
-        **model_args,
+        api_key: Optional[str] = None,
+        **request_args,
     ):
-        if target:
-            parsed_url = urlparse(target)
-            self.host = parsed_url.hostname
-            self.port = parsed_url.port
-            self.path = parsed_url.path
-        else:
-            self.host = host
-            self.port = port
-            self.path = path
+        self.target = target
         self.model = model
-        self.model_args = model_args
-        openai.api_key = model_args.get("api_key", None)
-        logger.info(f"Initialized OpenAIBackend with model: {self.model}")
+        self.request_args = request_args
+
+        if not self.target:
+            if not host:
+                raise ValueError("Host is required if target is not provided.")
+
+            port_incl = f":{port}" if port else ""
+            path_incl = path if path else ""
+            self.target = f"http://{host}{port_incl}{path_incl}"
+
+        openai.api_base = self.target
+        openai.api_key = api_key
+
+        if not model:
+            self.model = self.default_model()
+
+        logger.info(
+            f"Initialized OpenAIBackend with target: {self.target} and model: {self.model}"
+        )

     def make_request(self, request: BenchmarkRequest) -> Iterator[GenerativeResponse]:
         """
@@ -61,14 +71,20 @@ def make_request(self, request: BenchmarkRequest) -> Iterator[GenerativeResponse
         :rtype: Iterator[GenerativeResponse]
         """
         logger.debug(f"Making request to OpenAI backend with prompt: {request.prompt}")
+        num_gen_tokens = request.params.get("generated_tokens", None)
+        request_args = {
+            "n": 1,
+        }
+
+        if num_gen_tokens:
+            request_args["max_tokens"] = num_gen_tokens
+            request_args["stop"] = None
+
+        if self.request_args:
+            request_args.update(self.request_args)
+
         response = openai.Completion.create(
-            engine=self.model or self.default_model(),
-            prompt=request.prompt,
-            max_tokens=request.params.get("max_tokens", 100),
-            n=request.params.get("n", 1),
-            stop=request.params.get("stop", None),
-            stream=True,
-            **self.model_args,
+            engine=self.model, prompt=request.prompt, stream=True, **request_args,
         )

         for chunk in response:
@@ -80,8 +96,16 @@ def make_request(self, request: BenchmarkRequest) -> Iterator[GenerativeResponse
                         type_="final",
                         output=choice["text"],
                         prompt=request.prompt,
-                        prompt_token_count=self._token_count(request.prompt),
-                        output_token_count=self._token_count(choice["text"]),
+                        prompt_token_count=(
+                            request.token_count
+                            if request.token_count
+                            else self._token_count(request.prompt)
+                        ),
+                        output_token_count=(
+                            num_gen_tokens
+                            if num_gen_tokens
+                            else self._token_count(choice["text"])
+                        ),
                     )
                     break
             else:
@@ -133,14 +157,6 @@ def model_tokenizer(self, model: str) -> Optional[Any]:
         return None

     def _token_count(self, text: str) -> int:
-        """
-        Count the number of tokens in a text.
-
-        :param text: The text to tokenize.
-        :type text: str
-        :return: The number of tokens.
-        :rtype: int
-        """
         token_count = len(text.split())
         logger.debug(f"Token count for text '{text}': {token_count}")
         return token_count
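
Read together, the constructor rework makes the two calls below equivalent; the values are illustrative (an assumed OpenAI-compatible server on localhost, not anything this commit ships):

    # Explicit target is used verbatim as openai.api_base.
    backend = OpenAIBackend(target="http://localhost:8000/v1", api_key="sk-example")

    # Host/port/path are assembled into the same URL:
    # "http://" + host + ":8000" + "/v1" -> "http://localhost:8000/v1"
    backend = OpenAIBackend(host="localhost", port=8000, path="/v1", api_key="sk-example")

Note the precedence in make_request: the defaults derived from the request ("n", "max_tokens", "stop") are built first and self.request_args is merged last via dict.update, so request arguments passed to the constructor override the per-request defaults.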

src/guidellm/core/request.py

Lines changed: 33 additions & 3 deletions
@@ -1,4 +1,5 @@
 from typing import Dict, Any, Optional
+import uuid


 __all__ = ["BenchmarkRequest"]
@@ -10,11 +11,18 @@ class BenchmarkRequest:

     :param prompt: The input prompt for the benchmark request.
     :type prompt: str
+    :param token_count: The number of tokens to generate, defaults to None.
+    :type token_count: Optional[int]
     :param params: Optional parameters for the benchmark request, defaults to None.
     :type params: Optional[Dict[str, Any]]
     """

-    def __init__(self, prompt: str, params: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self,
+        prompt: str,
+        token_count: Optional[int] = None,
+        params: Optional[Dict[str, Any]] = None,
+    ):
         """
         Initialize the BenchmarkRequest with a prompt and optional parameters.

@@ -23,9 +31,21 @@ def __init__(self, prompt: str, params: Optional[Dict[str, Any]] = None):
         :param params: Optional parameters for the benchmark request, defaults to None.
         :type params: Optional[Dict[str, Any]]
         """
+        self._id = str(uuid.uuid4())
         self._prompt = prompt
+        self._token_count = token_count
         self._params = params or {}

+    @property
+    def id(self) -> str:
+        """
+        Get the unique identifier for the benchmark request.
+
+        :return: The unique identifier.
+        :rtype: str
+        """
+        return self._id
+
     @property
     def prompt(self) -> str:
         """
@@ -36,6 +56,16 @@ def prompt(self) -> str:
         """
         return self._prompt

+    @property
+    def token_count(self) -> Optional[int]:
+        """
+        Get the number of tokens to generate for the benchmark request.
+
+        :return: The number of tokens to generate.
+        :rtype: Optional[int]
+        """
+        return self._token_count
+
     @property
     def params(self) -> Dict[str, Any]:
         """
@@ -53,7 +83,7 @@ def __str__(self) -> str:
         :return: String representation of the BenchmarkRequest.
         :rtype: str
         """
-        return f"BenchmarkRequest(prompt={self._prompt}, params={self._params})"
+        return f"BenchmarkRequest(id={self.id}, prompt={self._prompt}, params={self._params})"

     def __repr__(self) -> str:
         """
@@ -62,4 +92,4 @@ def __repr__(self) -> str:
         :return: Unambiguous string representation of the BenchmarkRequest.
         :rtype: str
         """
-        return f"BenchmarkRequest(prompt={self._prompt}, params={self._params})"
+        return f"BenchmarkRequest(id={self.id}, prompt={self._prompt}, params={self._params})"
