Skip to content

Commit e922b5b

Browse files
committed
Addressed review comments
Use separate class for preprocess config. Signed-off-by: Jared O'Connell <[email protected]>
1 parent 86bca8e commit e922b5b

File tree

2 files changed

+61
-15
lines changed

2 files changed

+61
-15
lines changed

src/guidellm/data/config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import yaml
77
from pydantic import ValidationError
88

9-
from guidellm.data.schemas import DataNotSupportedError
10-
from guidellm.schemas import StandardBaseModel
9+
from guidellm.data.schemas import DataConfig, DataNotSupportedError
1110

12-
ConfigT = TypeVar("ConfigT", bound=StandardBaseModel)
11+
ConfigT = TypeVar("ConfigT", bound=DataConfig)
1312

1413

1514
def load_config(config: Any, config_class: type[ConfigT]) -> ConfigT | None:

src/guidellm/data/schemas.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from guidellm.schemas import StandardBaseModel
88

99
__all__ = [
10+
"DataConfig",
1011
"DataNotSupportedError",
1112
"GenerativeDatasetColumnType",
1213
"SyntheticTextDatasetConfig",
@@ -29,47 +30,54 @@ class DataNotSupportedError(Exception):
2930
Exception raised when the data format is not supported by deserializer or config.
3031
"""
3132

32-
class TokenCountConfig(StandardBaseModel):
33+
class DataConfig(StandardBaseModel):
34+
"""
35+
A generic parent class for various configs for the data package
36+
that can be passed in as key-value pairs or JSON.
37+
"""
38+
39+
class PreprocessDatasetConfig(DataConfig):
40+
3341
prompt_tokens: int = Field(
34-
description="The average number of text tokens in prompts.",
42+
description="The average number of text tokens retained or added to prompts.",
3543
gt=0,
3644
)
3745
prompt_tokens_stdev: int | None = Field(
38-
description="The standard deviation of the tokens in prompts.",
46+
description="The standard deviation of the number of tokens retained in or "
47+
"added to prompts.",
3948
gt=0,
4049
default=None,
4150
)
4251
prompt_tokens_min: int | None = Field(
43-
description="The minimum number of text tokens in prompts.",
52+
description="The minimum number of text tokens retained or added to prompts.",
4453
gt=0,
4554
default=None,
4655
)
4756
prompt_tokens_max: int | None = Field(
48-
description="The maximum number of text tokens in prompts.",
57+
description="The maximum number of text tokens retained or added to prompts.",
4958
gt=0,
5059
default=None,
5160
)
5261
output_tokens: int = Field(
53-
description="The average number of text tokens in outputs.",
62+
description="The average number of text tokens retained or added to outputs.",
5463
gt=0,
5564
)
5665
output_tokens_stdev: int | None = Field(
57-
description="The standard deviation of the tokens in outputs.",
66+
description="The standard deviation of the number of tokens retained or "
67+
"added to outputs.",
5868
gt=0,
5969
default=None,
6070
)
6171
output_tokens_min: int | None = Field(
62-
description="The minimum number of text tokens in outputs.",
72+
description="The minimum number of text tokens retained or added to outputs.",
6373
gt=0,
6474
default=None,
6575
)
6676
output_tokens_max: int | None = Field(
67-
description="The maximum number of text tokens in outputs.",
77+
description="The maximum number of text tokens retained or added to outputs.",
6878
gt=0,
6979
default=None,
7080
)
71-
72-
class PreprocessDatasetConfig(TokenCountConfig):
7381
prefix_tokens_max: int | None = Field(
7482
description="The maximum number of text tokens left in the prefixes.",
7583
gt=0,
@@ -94,7 +102,46 @@ class SyntheticTextPrefixBucketConfig(StandardBaseModel):
94102
)
95103

96104

97-
class SyntheticTextDatasetConfig(TokenCountConfig):
105+
class SyntheticTextDatasetConfig(DataConfig):
106+
prompt_tokens: int = Field(
107+
description="The average number of text tokens generated for prompts.",
108+
gt=0,
109+
)
110+
prompt_tokens_stdev: int | None = Field(
111+
description="The standard deviation of the tokens generated for prompts.",
112+
gt=0,
113+
default=None,
114+
)
115+
prompt_tokens_min: int | None = Field(
116+
description="The minimum number of text tokens generated for prompts.",
117+
gt=0,
118+
default=None,
119+
)
120+
prompt_tokens_max: int | None = Field(
121+
description="The maximum number of text tokens generated for prompts.",
122+
gt=0,
123+
default=None,
124+
)
125+
output_tokens: int = Field(
126+
description="The average number of text tokens generated for outputs.",
127+
gt=0,
128+
)
129+
output_tokens_stdev: int | None = Field(
130+
description="The standard deviation of the tokens generated for outputs.",
131+
gt=0,
132+
default=None,
133+
)
134+
output_tokens_min: int | None = Field(
135+
description="The minimum number of text tokens generated for outputs.",
136+
gt=0,
137+
default=None,
138+
)
139+
output_tokens_max: int | None = Field(
140+
description="The maximum number of text tokens generated for outputs.",
141+
gt=0,
142+
default=None,
143+
)
144+
98145
model_config = ConfigDict(
99146
extra="allow",
100147
)

0 commit comments

Comments
 (0)