Skip to content

Commit 00ee570

Browse files
committed
Fix tests broken in prior commit
Signed-off-by: Jared O'Connell <[email protected]>
1 parent a227a50 commit 00ee570

File tree

1 file changed

+35
-31
lines changed

1 file changed

+35
-31
lines changed

tests/unit/data/deserializers/test_synthetic.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111
import yaml
1212
from datasets import IterableDataset
1313

14-
from guidellm.data.deserializers.deserializer import DataNotSupportedError
14+
from guidellm.data import config as config_module
1515
from guidellm.data.deserializers.synthetic import (
1616
SyntheticTextDataset,
17-
SyntheticTextDatasetConfig,
1817
SyntheticTextDatasetDeserializer,
18+
)
19+
from guidellm.data.schemas import (
20+
DataNotSupportedError,
21+
SyntheticTextDatasetConfig,
1922
SyntheticTextPrefixBucketConfig,
2023
)
2124

@@ -409,13 +412,14 @@ def test_load_config_file_yaml(self):
409412
yaml_path = f.name
410413

411414
try:
412-
deserializer = SyntheticTextDatasetDeserializer()
413-
config = deserializer._load_config_file(yaml_path)
415+
loaded_config = config_module._load_config_file(
416+
yaml_path, SyntheticTextDatasetConfig,
417+
)
414418

415-
assert config.prompt_tokens == 60
416-
assert config.output_tokens == 15
417-
assert config.source == "yaml_test.txt"
418-
assert config.prefix_buckets[0].prefix_tokens == 3 # type: ignore [index]
419+
assert loaded_config.prompt_tokens == 60
420+
assert loaded_config.output_tokens == 15
421+
assert loaded_config.source == "yaml_test.txt"
422+
assert loaded_config.prefix_buckets[0].prefix_tokens == 3 # type: ignore [index]
419423
finally:
420424
Path(yaml_path).unlink()
421425

@@ -438,12 +442,13 @@ def test_load_config_file_config_extension(self):
438442
config_path = f.name
439443

440444
try:
441-
deserializer = SyntheticTextDatasetDeserializer()
442-
config = deserializer._load_config_file(config_path)
445+
loaded_config = config_module._load_config_file(
446+
config_path, SyntheticTextDatasetConfig,
447+
)
443448

444-
assert config.prompt_tokens == 90
445-
assert config.output_tokens == 35
446-
assert config.prefix_buckets[0].prefix_tokens == 2 # type: ignore [index]
449+
assert loaded_config.prompt_tokens == 90
450+
assert loaded_config.output_tokens == 35
451+
assert loaded_config.prefix_buckets[0].prefix_tokens == 2 # type: ignore [index]
447452
finally:
448453
Path(config_path).unlink()
449454

@@ -454,11 +459,10 @@ def test_load_config_str_json(self):
454459
### WRITTEN BY AI ###
455460
"""
456461
json_str = '{"prompt_tokens": 50, "output_tokens": 25}'
457-
deserializer = SyntheticTextDatasetDeserializer()
458-
config = deserializer._load_config_str(json_str)
462+
loaded_config = config_module._load_config_str(json_str, SyntheticTextDatasetConfig)
459463

460-
assert config.prompt_tokens == 50
461-
assert config.output_tokens == 25
464+
assert loaded_config.prompt_tokens == 50
465+
assert loaded_config.output_tokens == 25
462466

463467
@pytest.mark.smoke
464468
def test_load_config_str_key_value(self):
@@ -467,53 +471,53 @@ def test_load_config_str_key_value(self):
467471
### WRITTEN BY AI ###
468472
"""
469473
kv_str = "prompt_tokens=50,output_tokens=25"
470-
deserializer = SyntheticTextDatasetDeserializer()
471-
config = deserializer._load_config_str(kv_str)
474+
loaded_config = config_module._load_config_str(kv_str, SyntheticTextDatasetConfig)
472475

473-
assert config.prompt_tokens == 50
474-
assert config.output_tokens == 25
476+
assert loaded_config.prompt_tokens == 50
477+
assert loaded_config.output_tokens == 25
475478

476479
@pytest.mark.sanity
477480
def test_load_config_str_invalid_format(self):
478481
"""Test loading invalid format raises DataNotSupportedError.
479482
480483
### WRITTEN BY AI ###
481484
"""
482-
deserializer = SyntheticTextDatasetDeserializer()
483485
with pytest.raises(DataNotSupportedError, match="Unsupported string data"):
484-
deserializer._load_config_str("invalid_format_string")
486+
config_module._load_config_str(
487+
"invalid_format_string", SyntheticTextDatasetConfig,
488+
)
485489

486490
@pytest.mark.regression
487491
def test_load_config_file_non_existent(self):
488492
"""Test loading non-existent file returns None.
489493
490494
### WRITTEN BY AI ###
491495
"""
492-
deserializer = SyntheticTextDatasetDeserializer()
493-
config = deserializer._load_config_file("/non/existent/path.config")
494-
assert config is None
496+
loaded_config = config_module._load_config_file(
497+
"/non/existent/path.config", SyntheticTextDatasetConfig,
498+
)
499+
assert loaded_config is None
495500

496501
@pytest.mark.regression
497502
def test_load_config_str_non_string(self):
498503
"""Test loading non-string returns None.
499504
500505
### WRITTEN BY AI ###
501506
"""
502-
deserializer = SyntheticTextDatasetDeserializer()
503-
config = deserializer._load_config_str(123)
504-
assert config is None
507+
loaded_config = config_module._load_config_str(123, SyntheticTextDatasetConfig)
508+
assert loaded_config is None
505509

506510
@pytest.mark.smoke
507511
def test_call_with_config_object(self, mock_tokenizer):
508512
"""Test calling deserializer with SyntheticTextDatasetConfig.
509513
510514
### WRITTEN BY AI ###
511515
"""
512-
config = SyntheticTextDatasetConfig(prompt_tokens=50, output_tokens=25)
516+
config_input = SyntheticTextDatasetConfig(prompt_tokens=50, output_tokens=25)
513517
deserializer = SyntheticTextDatasetDeserializer()
514518

515519
result = deserializer(
516-
data=config,
520+
data=config_input,
517521
data_kwargs={},
518522
processor_factory=lambda: mock_tokenizer,
519523
random_seed=42,

0 commit comments

Comments
 (0)