Skip to content

Commit 6308e6a

Browse files
committed
Port tests from old dataset test file
Signed-off-by: Jared O'Connell <[email protected]>
1 parent 11cf5d2 commit 6308e6a

File tree

3 files changed

+227
-292
lines changed

3 files changed

+227
-292
lines changed

tests/unit/data/test_entrypoints.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,29 @@
22
Unit tests for guidellm.data.entrypoints module, specifically process_dataset function.
33
"""
44

5+
import os
56
import json
7+
from typing import TYPE_CHECKING
68
from unittest.mock import MagicMock, patch
79

10+
if TYPE_CHECKING:
11+
from collections.abc import Iterator
12+
813
import pytest
914
import yaml
1015
from datasets import Dataset
1116
from transformers import PreTrainedTokenizerBase
1217

1318
from guidellm.data.entrypoints import (
19+
STRATEGY_HANDLERS,
1420
PromptTooShortError,
1521
ShortPromptStrategy,
22+
handle_concatenate_strategy,
23+
handle_error_strategy,
24+
handle_ignore_strategy,
25+
handle_pad_strategy,
1626
process_dataset,
27+
push_dataset_to_hub,
1728
)
1829

1930

@@ -735,6 +746,7 @@ def test_process_dataset_successful_processing(
735746

736747
# Verify each row has required fields
737748
for row in saved_dataset:
749+
assert "prompt" in row
738750
assert "prompt_tokens_count" in row
739751
assert "output_tokens_count" in row
740752
assert isinstance(row["prompt_tokens_count"], int)
@@ -775,6 +787,9 @@ def test_process_dataset_empty_after_filtering(
775787
short_prompt_strategy=ShortPromptStrategy.IGNORE,
776788
)
777789

790+
# Verify all expected calls were made (even though dataset is empty)
791+
mock_check_processor.assert_called_once()
792+
mock_deserializer_factory_class.deserialize.assert_called_once()
778793
# When all prompts are filtered out, save_dataset_to_file is not called
779794
# (the function returns early in _finalize_processed_dataset)
780795
# This is expected behavior - the function handles empty datasets gracefully
@@ -1630,3 +1645,215 @@ def test_prefix_buckets_trimming(
16301645
assert "prompt_tokens_count" in row
16311646
assert "output_tokens_count" in row
16321647

1648+
1649+
class TestShortPromptStrategyHandlers:
    """Unit tests for individual short prompt strategy handler functions."""

    @pytest.mark.sanity
    def test_handle_ignore_strategy_too_short(self, tokenizer_mock):
        """Test handle_ignore_strategy returns None for short prompts."""
        assert handle_ignore_strategy("short", 10, tokenizer_mock) is None
        tokenizer_mock.encode.assert_called_with("short")

    @pytest.mark.sanity
    def test_handle_ignore_strategy_sufficient_length(self, tokenizer_mock):
        """Test handle_ignore_strategy returns prompt for sufficient length."""
        kept = handle_ignore_strategy("long prompt", 5, tokenizer_mock)
        assert kept == "long prompt"
        tokenizer_mock.encode.assert_called_with("long prompt")

    @pytest.mark.sanity
    def test_handle_concatenate_strategy_enough_prompts(self, tokenizer_mock):
        """Test handle_concatenate_strategy with enough prompts."""
        remaining_rows = iter([{"prompt": "longer"}])
        combined = handle_concatenate_strategy(
            "short", 10, remaining_rows, "prompt", tokenizer_mock, "\n"
        )
        assert combined == "short\nlonger"

    @pytest.mark.sanity
    def test_handle_concatenate_strategy_not_enough_prompts(self, tokenizer_mock):
        """Test handle_concatenate_strategy without enough prompts."""
        # An exhausted iterator means there is nothing left to concatenate,
        # so the handler must give up and return None.
        exhausted: Iterator = iter([])
        combined = handle_concatenate_strategy(
            "short", 10, exhausted, "prompt", tokenizer_mock, ""
        )
        assert combined is None

    @pytest.mark.sanity
    def test_handle_pad_strategy(self, tokenizer_mock):
        """Test handle_pad_strategy pads short prompts."""
        padded = handle_pad_strategy("short", 10, tokenizer_mock, "p")
        assert padded.startswith("shortppppp")

    @pytest.mark.sanity
    def test_handle_error_strategy_valid_prompt(self, tokenizer_mock):
        """Test handle_error_strategy returns prompt for valid length."""
        assert handle_error_strategy("valid prompt", 5, tokenizer_mock) == "valid prompt"
        tokenizer_mock.encode.assert_called_with("valid prompt")

    @pytest.mark.sanity
    def test_handle_error_strategy_too_short_prompt(self, tokenizer_mock):
        """Test handle_error_strategy raises error for short prompts."""
        with pytest.raises(PromptTooShortError):
            handle_error_strategy("short", 10, tokenizer_mock)
1704+
class TestProcessDatasetPushToHub:
    """Test cases for push_to_hub functionality.

    The HF_TOKEN-dependent tests use ``patch.dict(os.environ, ...)`` so the
    process environment is snapshotted and restored around each test; the
    previous direct ``os.environ[...] = ...`` / ``del os.environ[...]``
    mutations leaked state across tests and made results order-dependent.
    """

    @pytest.mark.smoke
    @patch("guidellm.data.entrypoints.push_dataset_to_hub")
    @patch("guidellm.data.entrypoints.save_dataset_to_file")
    @patch("guidellm.data.entrypoints.DatasetDeserializerFactory")
    @patch("guidellm.data.entrypoints.check_load_processor")
    def test_process_dataset_push_to_hub_called(
        self,
        mock_check_processor,
        mock_deserializer_factory_class,
        mock_save_to_file,
        mock_push,
        tokenizer_mock,
        tmp_path,
    ):
        """Test that push_to_hub is called when push_to_hub=True."""
        # Create a dataset with prompts long enough to be processed
        sample_dataset = Dataset.from_dict({
            "prompt": ["abc " * 50],  # Long enough
        })

        mock_check_processor.return_value = tokenizer_mock
        mock_deserializer_factory_class.deserialize.return_value = sample_dataset

        output_path = tmp_path / "output.json"
        config = '{"prompt_tokens": 10, "output_tokens": 5}'

        process_dataset(
            data="input",
            output_path=output_path,
            processor=tokenizer_mock,
            config=config,
            push_to_hub=True,
            hub_dataset_id="id123",
        )

        # Verify push_to_hub was called with the correct arguments:
        # positional args are (hub_dataset_id, processed Dataset).
        assert mock_push.called
        call_args = mock_push.call_args
        assert call_args[0][0] == "id123"
        assert isinstance(call_args[0][1], Dataset)

    @pytest.mark.sanity
    @patch("guidellm.data.entrypoints.push_dataset_to_hub")
    @patch("guidellm.data.entrypoints.save_dataset_to_file")
    @patch("guidellm.data.entrypoints.DatasetDeserializerFactory")
    @patch("guidellm.data.entrypoints.check_load_processor")
    def test_process_dataset_push_to_hub_not_called(
        self,
        mock_check_processor,
        mock_deserializer_factory_class,
        mock_save_to_file,
        mock_push,
        tokenizer_mock,
        tmp_path,
    ):
        """Test that push_to_hub is not called when push_to_hub=False."""
        # Create a dataset with prompts long enough to be processed
        sample_dataset = Dataset.from_dict({
            "prompt": ["abc " * 50],  # Long enough
        })

        mock_check_processor.return_value = tokenizer_mock
        mock_deserializer_factory_class.deserialize.return_value = sample_dataset

        output_path = tmp_path / "output.json"
        config = '{"prompt_tokens": 10, "output_tokens": 5}'

        process_dataset(
            data="input",
            output_path=output_path,
            processor=tokenizer_mock,
            config=config,
            push_to_hub=False,
        )

        # Verify push_to_hub was not called
        mock_push.assert_not_called()

    @pytest.mark.regression
    def test_push_dataset_to_hub_success(self):
        """Test push_dataset_to_hub success case."""
        mock_dataset = MagicMock(spec=Dataset)
        # patch.dict restores the pre-test environment on exit, so setting
        # HF_TOKEN here cannot leak into other tests.
        with patch.dict(os.environ, {"HF_TOKEN": "token"}):
            push_dataset_to_hub("dataset_id", mock_dataset)
        mock_dataset.push_to_hub.assert_called_once_with("dataset_id", token="token")

    @pytest.mark.regression
    def test_push_dataset_to_hub_error_no_env(self):
        """Test push_dataset_to_hub raises error when HF_TOKEN is missing."""
        mock_dataset = MagicMock(spec=Dataset)
        # Snapshot the environment, then remove HF_TOKEN inside the snapshot;
        # the original value (if any) is restored when the block exits.
        with patch.dict(os.environ):
            os.environ.pop("HF_TOKEN", None)
            with pytest.raises(ValueError, match="hub_dataset_id and HF_TOKEN"):
                push_dataset_to_hub("dataset_id", mock_dataset)

    @pytest.mark.regression
    def test_push_dataset_to_hub_error_no_id(self):
        """Test push_dataset_to_hub raises error when hub_dataset_id is missing."""
        mock_dataset = MagicMock(spec=Dataset)
        with patch.dict(os.environ, {"HF_TOKEN": "token"}):
            with pytest.raises(ValueError, match="hub_dataset_id and HF_TOKEN"):
                push_dataset_to_hub(None, mock_dataset)
1811+
class TestProcessDatasetStrategyHandlerIntegration:
    """Test cases for strategy handler integration with process_dataset."""

    @pytest.mark.smoke
    @patch("guidellm.data.entrypoints.save_dataset_to_file")
    @patch("guidellm.data.entrypoints.DatasetDeserializerFactory")
    @patch("guidellm.data.entrypoints.check_load_processor")
    def test_strategy_handler_called(
        self,
        mock_check_processor,
        mock_deserializer_factory_class,
        mock_save_to_file,
        tokenizer_mock,
        tmp_path,
    ):
        """Test that strategy handlers are called during dataset processing."""
        handler_stub = MagicMock(return_value="processed_prompt")
        with patch.dict(STRATEGY_HANDLERS, {ShortPromptStrategy.IGNORE: handler_stub}):
            # Two prompts, each long enough to pass the length check.
            sample_dataset = Dataset.from_dict(
                {
                    "prompt": [
                        "abc" * 20,
                        "def" * 20,
                    ],
                }
            )

            mock_check_processor.return_value = tokenizer_mock
            mock_deserializer_factory_class.deserialize.return_value = sample_dataset

            process_dataset(
                data="input",
                output_path=tmp_path / "output.json",
                processor=tokenizer_mock,
                config='{"prompt_tokens": 10, "output_tokens": 5}',
                short_prompt_strategy=ShortPromptStrategy.IGNORE,
            )

            # The pipeline should have deserialized the data, validated the
            # processor, saved the result, and routed each row through the
            # patched IGNORE handler.
            mock_deserializer_factory_class.deserialize.assert_called_once()
            mock_check_processor.assert_called_once()
            assert mock_save_to_file.called
            if len(sample_dataset) > 0:
                assert handler_stub.called
1859+

tests/unit/preprocess/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)