From 906ff4903c074539fb17c3338fd0fb68ada204c4 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Wed, 27 Aug 2025 15:54:32 -0400 Subject: [PATCH] Add helper for converting literals to list of strings Signed-off-by: Samuel Monson --- src/guidellm/__main__.py | 12 ++-- src/guidellm/utils/__init__.py | 2 + src/guidellm/utils/typing.py | 48 +++++++++++++ tests/unit/utils/test_typing.py | 123 ++++++++++++++++++++++++++++++++ 4 files changed, 178 insertions(+), 7 deletions(-) create mode 100644 src/guidellm/utils/typing.py create mode 100644 tests/unit/utils/test_typing.py diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 7cba6a7c..ff650223 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -1,7 +1,7 @@ import asyncio import codecs from pathlib import Path -from typing import get_args +from typing import Union import click from pydantic import ValidationError @@ -16,12 +16,10 @@ from guidellm.config import print_config from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset from guidellm.scheduler import StrategyType -from guidellm.utils import DefaultGroupHandler +from guidellm.utils import DefaultGroupHandler, get_literal_vals from guidellm.utils import cli as cli_tools -STRATEGY_PROFILE_CHOICES = list( - set(list(get_args(ProfileType)) + list(get_args(StrategyType))) -) +STRATEGY_PROFILE_CHOICES = list(get_literal_vals(Union[ProfileType, StrategyType])) @click.group() @@ -70,10 +68,10 @@ def benchmark(): ) @click.option( "--backend-type", - type=click.Choice(list(get_args(BackendType))), + type=click.Choice(list(get_literal_vals(BackendType))), help=( "The type of backend to use to run requests against. Defaults to 'openai_http'." - f" Supported types: {', '.join(get_args(BackendType))}" + f" Supported types: {', '.join(get_literal_vals(BackendType))}" ), default=GenerativeTextScenario.get_default("backend_type"), ) diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index 576fe64d..78eae06b 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -42,6 +42,7 @@ split_text, split_text_list_by_length, ) +from .typing import get_literal_vals __all__ = [ "SUPPORTED_TYPES", @@ -67,6 +68,7 @@ "check_load_processor", "clean_text", "filter_text", + "get_literal_vals", "is_punctuation", "load_text", "safe_add", diff --git a/src/guidellm/utils/typing.py b/src/guidellm/utils/typing.py new file mode 100644 index 00000000..59358221 --- /dev/null +++ b/src/guidellm/utils/typing.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated, Literal, Union, get_args, get_origin + +if TYPE_CHECKING: + from collections.abc import Iterator + +# Backwards compatibility for Python <3.10 +try: + from types import UnionType # type: ignore[attr-defined] +except ImportError: + UnionType = Union + +# Backwards compatibility for Python <3.12 +try: + from typing import TypeAliasType # type: ignore[attr-defined] +except ImportError: + from typing_extensions import TypeAliasType + + +__all__ = ["get_literal_vals"] + + +def get_literal_vals(alias) -> frozenset[str]: + """Extract all literal values from a (possibly nested) type alias.""" + + def resolve(alias) -> Iterator[str]: + origin = get_origin(alias) + + # Base case: Literal types + if origin is Literal: + for literal_val in get_args(alias): + yield str(literal_val) + # Unwrap Annotated type + elif origin is Annotated: + yield from resolve(get_args(alias)[0]) + # Unwrap TypeAliasTypes + elif isinstance(alias, TypeAliasType): + yield from resolve(alias.__value__) + # Iterate over unions + elif origin in (Union, UnionType): + for arg in get_args(alias): + yield from resolve(arg) + # Fallback + else: + yield str(alias) + + return frozenset(resolve(alias)) diff --git a/tests/unit/utils/test_typing.py b/tests/unit/utils/test_typing.py new file mode 100644 index 00000000..fafa8765 --- /dev/null +++ b/tests/unit/utils/test_typing.py @@ -0,0 +1,123 @@ +""" +Test suite for the typing utilities module. +""" + +from typing import Annotated, Literal, Union + +import pytest +from typing_extensions import TypeAlias + +from guidellm.utils.typing import get_literal_vals + +# Local type definitions to avoid imports from other modules +LocalProfileType = Literal["synchronous", "async", "concurrent", "throughput", "sweep"] +LocalStrategyType = Annotated[ + Literal["synchronous", "concurrent", "throughput", "constant", "poisson"], + "Valid strategy type identifiers for scheduling request patterns", +] +StrategyProfileType: TypeAlias = Union[LocalStrategyType, LocalProfileType] + + +class TestGetLiteralVals: + """Test cases for the get_literal_vals function.""" + + @pytest.mark.sanity + def test_profile_type(self): + """ + Test extracting values from ProfileType. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(LocalProfileType) + expected = frozenset( + {"synchronous", "async", "concurrent", "throughput", "sweep"} + ) + assert result == expected + + @pytest.mark.sanity + def test_strategy_type(self): + """ + Test extracting values from StrategyType. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(LocalStrategyType) + expected = frozenset( + {"synchronous", "concurrent", "throughput", "constant", "poisson"} + ) + assert result == expected + + @pytest.mark.smoke + def test_inline_union_type(self): + """ + Test extracting values from inline union of ProfileType | StrategyType. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(Union[LocalProfileType, LocalStrategyType]) + expected = frozenset( + { + "synchronous", + "async", + "concurrent", + "throughput", + "constant", + "poisson", + "sweep", + } + ) + assert result == expected + + @pytest.mark.smoke + def test_type_alias(self): + """ + Test extracting values from type alias union. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(StrategyProfileType) + expected = frozenset( + { + "synchronous", + "async", + "concurrent", + "throughput", + "constant", + "poisson", + "sweep", + } + ) + assert result == expected + + @pytest.mark.sanity + def test_single_literal(self): + """ + Test extracting values from single Literal type. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(Literal["test"]) + expected = frozenset({"test"}) + assert result == expected + + @pytest.mark.sanity + def test_multi_literal(self): + """ + Test extracting values from multi-value Literal type. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(Literal["test", "test2"]) + expected = frozenset({"test", "test2"}) + assert result == expected + + @pytest.mark.smoke + def test_literal_union(self): + """ + Test extracting values from union of Literal types. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(Union[Literal["test", "test2"], Literal["test3"]]) + expected = frozenset({"test", "test2", "test3"}) + assert result == expected