Skip to content

Commit 90cf0cf

Browse files
committed
Support setting initial prefix_bucket from top level config
Signed-off-by: Samuel Monson <[email protected]>
1 parent 07da84f commit 90cf0cf

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

src/guidellm/data/deserializers/synthetic.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from collections.abc import Iterator
55
from pathlib import Path
66
from random import Random
7-
from typing import Any, Callable
7+
from typing import Any, Callable, Self
88

99
import yaml
1010
from datasets import Features, IterableDataset, Value
1111
from faker import Faker
12-
from pydantic import Field
12+
from pydantic import ConfigDict, Field, model_validator
1313
from transformers import PreTrainedTokenizerBase
1414

1515
from guidellm.data.deserializers.deserializer import (
@@ -34,7 +34,7 @@ class SyntheticTextPrefixBucketConfig(StandardBaseModel):
3434
default=100,
3535
)
3636
prefix_count: int = Field(
37-
description="The number of unique prefixs to generate for this bucket.",
37+
description="The number of unique prefixes to generate for this bucket.",
3838
ge=1,
3939
default=1,
4040
)
@@ -46,6 +46,10 @@ class SyntheticTextPrefixBucketConfig(StandardBaseModel):
4646

4747

4848
class SyntheticTextDatasetConfig(StandardBaseModel):
49+
model_config = ConfigDict(
50+
extra="allow",
51+
)
52+
4953
prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field(
5054
description="Buckets for the prefix tokens distribution.",
5155
default=None,
@@ -93,6 +97,26 @@ class SyntheticTextDatasetConfig(StandardBaseModel):
9397
default="data:prideandprejudice.txt.gz",
9498
)
9599

100+
@model_validator(mode="after")
101+
def check_prefix_options(self) -> Self:
102+
prefix_count = self.__pydantic_extra__.get("prefix_count", None)
103+
prefix_tokens = self.__pydantic_extra__.get("prefix_count", None)
104+
if prefix_count is not None or prefix_tokens is not None:
105+
if self.prefix_buckets:
106+
raise ValueError(
107+
"prefix_buckets is mutually exclusive"
108+
" with prefix_count and prefix_tokens"
109+
)
110+
111+
self.prefix_buckets = [
112+
SyntheticTextPrefixBucketConfig(
113+
prefix_count=prefix_count or 1,
114+
prefix_tokens=prefix_tokens or 0,
115+
)
116+
]
117+
118+
return self
119+
96120

97121
class SyntheticTextGenerator:
98122
def __init__(

0 commit comments

Comments
 (0)