 """
 Unit tests for guidellm.data.entrypoints module, specifically process_dataset function.
 """

+import os
 import json
+from typing import TYPE_CHECKING
 from unittest.mock import MagicMock, patch

+if TYPE_CHECKING:
+    from collections.abc import Iterator
+
 import pytest
 import yaml
 from datasets import Dataset
 from transformers import PreTrainedTokenizerBase

 from guidellm.data.entrypoints import (
+    STRATEGY_HANDLERS,
     PromptTooShortError,
     ShortPromptStrategy,
+    handle_concatenate_strategy,
+    handle_error_strategy,
+    handle_ignore_strategy,
+    handle_pad_strategy,
     process_dataset,
+    push_dataset_to_hub,
 )

@@ -735,6 +746,7 @@ def test_process_dataset_successful_processing( |

     # Verify each row has required fields
     for row in saved_dataset:
+        assert "prompt" in row
         assert "prompt_tokens_count" in row
         assert "output_tokens_count" in row
         assert isinstance(row["prompt_tokens_count"], int)
@@ -775,6 +787,9 @@ def test_process_dataset_empty_after_filtering( |
         short_prompt_strategy=ShortPromptStrategy.IGNORE,
     )

+    # Verify all expected calls were made (even though the dataset is empty)
+    mock_check_processor.assert_called_once()
+    mock_deserializer_factory_class.deserialize.assert_called_once()
     # When all prompts are filtered out, save_dataset_to_file is not called
     # (the function returns early in _finalize_processed_dataset)
     # This is expected behavior - the function handles empty datasets gracefully
@@ -1630,3 +1645,215 @@ def test_prefix_buckets_trimming( |
         assert "prompt_tokens_count" in row
         assert "output_tokens_count" in row

+
+class TestShortPromptStrategyHandlers:
+    """Unit tests for individual short prompt strategy handler functions."""
+
+    @pytest.mark.sanity
+    def test_handle_ignore_strategy_too_short(self, tokenizer_mock):
+        """Test handle_ignore_strategy returns None for short prompts."""
+        result = handle_ignore_strategy("short", 10, tokenizer_mock)
+        assert result is None
+        tokenizer_mock.encode.assert_called_with("short")
+
+    @pytest.mark.sanity
+    def test_handle_ignore_strategy_sufficient_length(self, tokenizer_mock):
+        """Test handle_ignore_strategy returns prompt for sufficient length."""
+        result = handle_ignore_strategy("long prompt", 5, tokenizer_mock)
+        assert result == "long prompt"
+        tokenizer_mock.encode.assert_called_with("long prompt")
+
+    @pytest.mark.sanity
+    def test_handle_concatenate_strategy_enough_prompts(self, tokenizer_mock):
+        """Test handle_concatenate_strategy with enough prompts."""
+        dataset_iter = iter([{"prompt": "longer"}])
+        result = handle_concatenate_strategy(
+            "short", 10, dataset_iter, "prompt", tokenizer_mock, "\n"
+        )
+        assert result == "short\nlonger"
+
+    @pytest.mark.sanity
+    def test_handle_concatenate_strategy_not_enough_prompts(self, tokenizer_mock):
+        """Test handle_concatenate_strategy without enough prompts."""
+        dataset_iter: Iterator = iter([])
+        result = handle_concatenate_strategy(
+            "short", 10, dataset_iter, "prompt", tokenizer_mock, ""
+        )
+        assert result is None
+
+    @pytest.mark.sanity
+    def test_handle_pad_strategy(self, tokenizer_mock):
+        """Test handle_pad_strategy pads short prompts."""
+        result = handle_pad_strategy("short", 10, tokenizer_mock, "p")
+        assert result.startswith("shortppppp")
+
+    @pytest.mark.sanity
+    def test_handle_error_strategy_valid_prompt(self, tokenizer_mock):
+        """Test handle_error_strategy returns prompt for valid length."""
+        result = handle_error_strategy("valid prompt", 5, tokenizer_mock)
+        assert result == "valid prompt"
+        tokenizer_mock.encode.assert_called_with("valid prompt")
+
+    @pytest.mark.sanity
+    def test_handle_error_strategy_too_short_prompt(self, tokenizer_mock):
+        """Test handle_error_strategy raises error for short prompts."""
+        with pytest.raises(PromptTooShortError):
+            handle_error_strategy("short", 10, tokenizer_mock)
+
+
+class TestProcessDatasetPushToHub:
+    """Test cases for push_to_hub functionality."""
+
+    @pytest.mark.smoke
+    @patch("guidellm.data.entrypoints.push_dataset_to_hub")
+    @patch("guidellm.data.entrypoints.save_dataset_to_file")
+    @patch("guidellm.data.entrypoints.DatasetDeserializerFactory")
+    @patch("guidellm.data.entrypoints.check_load_processor")
+    def test_process_dataset_push_to_hub_called(
+        self,
+        mock_check_processor,
+        mock_deserializer_factory_class,
+        mock_save_to_file,
+        mock_push,
+        tokenizer_mock,
+        tmp_path,
+    ):
+        """Test that push_to_hub is called when push_to_hub=True."""
+        # Create a dataset with prompts long enough to be processed
+        sample_dataset = Dataset.from_dict({
+            "prompt": ["abc " * 50],  # Long enough
+        })
+
+        mock_check_processor.return_value = tokenizer_mock
+        mock_deserializer_factory_class.deserialize.return_value = sample_dataset
+
+        output_path = tmp_path / "output.json"
+        config = '{"prompt_tokens": 10, "output_tokens": 5}'
+
+        process_dataset(
+            data="input",
+            output_path=output_path,
+            processor=tokenizer_mock,
+            config=config,
+            push_to_hub=True,
+            hub_dataset_id="id123",
+        )
+
+        # Verify push_to_hub was called with the correct arguments
+        assert mock_push.called
+        call_args = mock_push.call_args
+        assert call_args[0][0] == "id123"
+        assert isinstance(call_args[0][1], Dataset)
+
+    @pytest.mark.sanity
+    @patch("guidellm.data.entrypoints.push_dataset_to_hub")
+    @patch("guidellm.data.entrypoints.save_dataset_to_file")
+    @patch("guidellm.data.entrypoints.DatasetDeserializerFactory")
+    @patch("guidellm.data.entrypoints.check_load_processor")
+    def test_process_dataset_push_to_hub_not_called(
+        self,
+        mock_check_processor,
+        mock_deserializer_factory_class,
+        mock_save_to_file,
+        mock_push,
+        tokenizer_mock,
+        tmp_path,
+    ):
+        """Test that push_to_hub is not called when push_to_hub=False."""
+        # Create a dataset with prompts long enough to be processed
+        sample_dataset = Dataset.from_dict({
+            "prompt": ["abc " * 50],  # Long enough
+        })
+
+        mock_check_processor.return_value = tokenizer_mock
+        mock_deserializer_factory_class.deserialize.return_value = sample_dataset
+
+        output_path = tmp_path / "output.json"
+        config = '{"prompt_tokens": 10, "output_tokens": 5}'
+
+        process_dataset(
+            data="input",
+            output_path=output_path,
+            processor=tokenizer_mock,
+            config=config,
+            push_to_hub=False,
+        )
+
+        # Verify push_to_hub was not called
+        mock_push.assert_not_called()
+
+    @pytest.mark.regression
+    def test_push_dataset_to_hub_success(self):
+        """Test push_dataset_to_hub success case."""
+        os.environ["HF_TOKEN"] = "token"
+        mock_dataset = MagicMock(spec=Dataset)
+        push_dataset_to_hub("dataset_id", mock_dataset)
+        mock_dataset.push_to_hub.assert_called_once_with("dataset_id", token="token")
+
+    @pytest.mark.regression
+    def test_push_dataset_to_hub_error_no_env(self):
+        """Test push_dataset_to_hub raises error when HF_TOKEN is missing."""
+        if "HF_TOKEN" in os.environ:
+            del os.environ["HF_TOKEN"]
+        mock_dataset = MagicMock(spec=Dataset)
+        with pytest.raises(ValueError, match="hub_dataset_id and HF_TOKEN"):
+            push_dataset_to_hub("dataset_id", mock_dataset)
+
+    @pytest.mark.regression
+    def test_push_dataset_to_hub_error_no_id(self):
+        """Test push_dataset_to_hub raises error when hub_dataset_id is missing."""
+        os.environ["HF_TOKEN"] = "token"
+        mock_dataset = MagicMock(spec=Dataset)
+        with pytest.raises(ValueError, match="hub_dataset_id and HF_TOKEN"):
+            push_dataset_to_hub(None, mock_dataset)
+
+
+class TestProcessDatasetStrategyHandlerIntegration:
+    """Test cases for strategy handler integration with process_dataset."""
+
+    @pytest.mark.smoke
+    @patch("guidellm.data.entrypoints.save_dataset_to_file")
+    @patch("guidellm.data.entrypoints.DatasetDeserializerFactory")
+    @patch("guidellm.data.entrypoints.check_load_processor")
+    def test_strategy_handler_called(
+        self,
+        mock_check_processor,
+        mock_deserializer_factory_class,
+        mock_save_to_file,
+        tokenizer_mock,
+        tmp_path,
+    ):
+        """Test that strategy handlers are called during dataset processing."""
+        mock_handler = MagicMock(return_value="processed_prompt")
+        with patch.dict(STRATEGY_HANDLERS, {ShortPromptStrategy.IGNORE: mock_handler}):
+            # Create a dataset with prompts that need processing
+            sample_dataset = Dataset.from_dict({
+                "prompt": [
+                    "abc" * 20,  # Long enough to pass
+                    "def" * 20,  # Long enough to pass
+                ],
+            })
+
+            mock_check_processor.return_value = tokenizer_mock
+            mock_deserializer_factory_class.deserialize.return_value = sample_dataset
+
+            output_path = tmp_path / "output.json"
+            config = '{"prompt_tokens": 10, "output_tokens": 5}'
+
+            process_dataset(
+                data="input",
+                output_path=output_path,
+                processor=tokenizer_mock,
+                config=config,
+                short_prompt_strategy=ShortPromptStrategy.IGNORE,
+            )
+
+            # Verify that the handler was called during processing
+            # The handler is called for each row that needs processing
+            mock_deserializer_factory_class.deserialize.assert_called_once()
+            mock_check_processor.assert_called_once()
+            assert mock_save_to_file.called
+            # Verify handler was called (at least once if there are rows to process)
+            if len(sample_dataset) > 0:
+                assert mock_handler.called
+
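
The push_to_hub tests above pin down the contract these tests expect from `push_dataset_to_hub`: it reads the `HF_TOKEN` environment variable, raises a `ValueError` whose message contains "hub_dataset_id and HF_TOKEN" when either the dataset id or the token is missing, and otherwise forwards to `Dataset.push_to_hub`. The real implementation lives in `guidellm.data.entrypoints`; the snippet below is only a sketch reconstructed from the assertions in this diff, with a hypothetical name and error message, not the library's actual code.

```python
# Inferred sketch only -- reconstructed from the test assertions above,
# not copied from guidellm.data.entrypoints.
import os
from typing import Optional

from datasets import Dataset


def push_dataset_to_hub_sketch(hub_dataset_id: Optional[str], dataset: Dataset) -> None:
    """Push `dataset` to the Hugging Face Hub under `hub_dataset_id`.

    Mirrors the behavior the tests assert: both the dataset id and the
    HF_TOKEN environment variable must be set, otherwise raise ValueError.
    """
    token = os.environ.get("HF_TOKEN")
    if not hub_dataset_id or not token:
        # The tests match this message via pytest.raises(..., match="hub_dataset_id and HF_TOKEN")
        raise ValueError("hub_dataset_id and HF_TOKEN must both be provided to push to the hub")
    # Exactly what test_push_dataset_to_hub_success asserts on the mock
    dataset.push_to_hub(hub_dataset_id, token=token)
```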