Skip to content

Commit 2f06dc3

Browse files
committed
Update pydantic serialization/deserialization: rebase on current main, address review comments, finalize the remaining pieces, and fix tests
1 parent 2df71cd commit 2f06dc3

File tree

13 files changed

+226
-136
lines changed

13 files changed

+226
-136
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ repos:
2828
datasets,
2929
loguru,
3030
numpy,
31+
pydantic,
32+
pyyaml,
3133
openai,
3234
requests,
3335
transformers,
@@ -38,4 +40,5 @@ repos:
3840
# types
3941
types-click,
4042
types-requests,
43+
types-PyYAML,
4144
]

pyproject.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ dependencies = [
3030
"loguru",
3131
"numpy",
3232
"openai",
33+
"pydantic>=2.0.0",
34+
"pyyaml>=6.0.0",
3335
"requests",
3436
"transformers",
3537
]
@@ -48,7 +50,12 @@ dev = [
4850
"pytest-mock~=3.14.0",
4951
"ruff~=0.5.2",
5052
"tox~=4.16.0",
51-
"types-requests~=2.32.0"
53+
"types-requests~=2.32.0",
54+
55+
# type-checking
56+
"types-click",
57+
"types-requests",
58+
"types-PyYAML",
5259
]
5360

5461

src/guidellm/backend/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,18 @@ def submit(self, request: TextGenerationRequest) -> TextGenerationResult:
9191

9292
logger.info(f"Submitting request with prompt: {request.prompt}")
9393

94-
result = TextGenerationResult(TextGenerationRequest(prompt=request.prompt))
94+
result = TextGenerationResult(
95+
request=TextGenerationRequest(prompt=request.prompt)
96+
)
9597
result.start(request.prompt)
9698

9799
for response in self.make_request(request): # GenerativeResponse
98100
if response.type_ == "token_iter" and response.add_token:
99101
result.output_token(response.add_token)
100102
elif response.type_ == "final":
101103
result.end(
102-
response.prompt_token_count,
103-
response.output_token_count,
104+
prompt_token_count=response.prompt_token_count,
105+
output_token_count=response.output_token_count,
104106
)
105107

106108
logger.info(f"Request completed with output: {result.output}")

src/guidellm/core/distribution.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import List, Optional, Union
1+
from typing import List, Sequence
22

33
import numpy as np
44
from loguru import logger
5+
from pydantic import Field
56

67
from guidellm.core.serializable import Serializable
78

@@ -14,24 +15,12 @@ class Distribution(Serializable):
1415
statistical analyses.
1516
"""
1617

17-
def __init__(self, **data):
18-
super().__init__(**data)
19-
logger.debug(f"Initialized Distribution with data: {self.data}")
18+
data: Sequence[float] = Field(
19+
default_factory=list, description="The data points of the distribution."
20+
)
2021

21-
def __str__(self) -> str:
22-
"""
23-
Return a string representation of the Distribution.
24-
"""
25-
return (
26-
f"Distribution(mean={self.mean:.2f}, median={self.median:.2f}, "
27-
f"min={self.min}, max={self.max}, count={len(self.data)})"
28-
)
29-
30-
def __repr__(self) -> str:
31-
"""
32-
Return an unambiguous string representation of the Distribution for debugging.
33-
"""
34-
return f"Distribution(data={self.data})"
22+
def __str__(self):
23+
return f"Distribution({self.describe()})"
3524

3625
@property
3726
def mean(self) -> float:
@@ -99,7 +88,7 @@ def percentile(self, percentile: float) -> float:
9988
logger.warning("No data points available to calculate percentile.")
10089
return 0.0
10190

102-
percentile_value = np.percentile(self._data, percentile).item()
91+
percentile_value = np.percentile(self.data, percentile).item()
10392
logger.debug(f"Calculated {percentile}th percentile: {percentile_value}")
10493
return percentile_value
10594

@@ -180,15 +169,15 @@ def describe(self) -> dict:
180169
logger.debug(f"Generated description: {description}")
181170
return description
182171

183-
def add_data(self, new_data: Union[List[int], List[float]]):
172+
def add_data(self, new_data: Sequence[float]):
184173
"""
185174
Add new data points to the distribution.
186175
:param new_data: A list of new numerical data points to add.
187176
"""
188-
self.data.extend(new_data)
177+
self.data = list(self.data) + list(new_data)
189178
logger.debug(f"Added new data: {new_data}")
190179

191-
def remove_data(self, remove_data: Union[List[int], List[float]]):
180+
def remove_data(self, remove_data: Sequence[float]):
192181
"""
193182
Remove specified data points from the distribution.
194183
:param remove_data: A list of numerical data points to remove.

src/guidellm/core/request.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import uuid
2-
from typing import Dict, Optional, Any
2+
from typing import Any, Dict, Optional
3+
4+
from pydantic import Field
35

46
from guidellm.core.serializable import Serializable
57

@@ -9,24 +11,18 @@ class TextGenerationRequest(Serializable):
911
A class to represent a text generation request for generative AI workloads.
1012
"""
1113

12-
id: str
13-
prompt: str
14-
prompt_token_count: Optional[int]
15-
generated_token_count: Optional[int]
16-
params: Dict[str, Any]
17-
18-
def __init__(
19-
self,
20-
prompt: str,
21-
prompt_token_count: Optional[int] = None,
22-
generated_token_count: Optional[int] = None,
23-
params: Optional[Dict[str, Any]] = None,
24-
id: Optional[str] = None,
25-
):
26-
super().__init__(
27-
id=str(uuid.uuid4()) if id is None else id,
28-
prompt=prompt,
29-
prompt_token_count=prompt_token_count,
30-
generated_token_count=generated_token_count,
31-
params=params or {},
32-
)
14+
id: str = Field(
15+
default_factory=lambda: str(uuid.uuid4()),
16+
description="The unique identifier for the request.",
17+
)
18+
prompt: str = Field(description="The input prompt for the text generation.")
19+
prompt_token_count: Optional[int] = Field(
20+
default=None, description="The number of tokens in the input prompt."
21+
)
22+
generate_token_count: Optional[int] = Field(
23+
default=None, description="The number of tokens to generate."
24+
)
25+
params: Dict[str, Any] = Field(
26+
default_factory=dict,
27+
description="The parameters for the text generation request.",
28+
)

0 commit comments

Comments (0)