Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/guidellm/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import codecs
from pathlib import Path
from typing import get_args
from typing import Union

import click
from pydantic import ValidationError
Expand All @@ -16,12 +16,10 @@
from guidellm.config import print_config
from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
from guidellm.scheduler import StrategyType
from guidellm.utils import DefaultGroupHandler
from guidellm.utils import DefaultGroupHandler, get_literal_vals
from guidellm.utils import cli as cli_tools

STRATEGY_PROFILE_CHOICES = list(
set(list(get_args(ProfileType)) + list(get_args(StrategyType)))
)
STRATEGY_PROFILE_CHOICES = list(get_literal_vals(Union[ProfileType, StrategyType]))


@click.group()
Expand Down Expand Up @@ -70,10 +68,10 @@ def benchmark():
)
@click.option(
"--backend-type",
type=click.Choice(list(get_args(BackendType))),
type=click.Choice(list(get_literal_vals(BackendType))),
help=(
"The type of backend to use to run requests against. Defaults to 'openai_http'."
f" Supported types: {', '.join(get_args(BackendType))}"
f" Supported types: {', '.join(get_literal_vals(BackendType))}"
),
default=GenerativeTextScenario.get_default("backend_type"),
)
Expand Down
2 changes: 2 additions & 0 deletions src/guidellm/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
split_text,
split_text_list_by_length,
)
from .typing import get_literal_vals

__all__ = [
"SUPPORTED_TYPES",
Expand All @@ -67,6 +68,7 @@
"check_load_processor",
"clean_text",
"filter_text",
"get_literal_vals",
"is_punctuation",
"load_text",
"safe_add",
Expand Down
48 changes: 48 additions & 0 deletions src/guidellm/utils/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Annotated, Literal, Union, get_args, get_origin

if TYPE_CHECKING:
from collections.abc import Iterator

# Backwards compatibility for Python <3.10
try:
from types import UnionType # type: ignore[attr-defined]
except ImportError:
UnionType = Union

# Backwards compatibility for Python <3.12
try:
from typing import TypeAliasType # type: ignore[attr-defined]
except ImportError:
from typing_extensions import TypeAliasType


__all__ = ["get_literal_vals"]


def get_literal_vals(alias) -> frozenset[str]:
"""Extract all literal values from a (possibly nested) type alias."""

def resolve(alias) -> Iterator[str]:
origin = get_origin(alias)

# Base case: Literal types
if origin is Literal:
for literal_val in get_args(alias):
yield str(literal_val)
# Unwrap Annotated type
elif origin is Annotated:
yield from resolve(get_args(alias)[0])
# Unwrap TypeAliasTypes
elif isinstance(alias, TypeAliasType):
yield from resolve(alias.__value__)
# Iterate over unions
elif origin in (Union, UnionType):
for arg in get_args(alias):
yield from resolve(arg)
# Fallback
else:
yield str(alias)

return frozenset(resolve(alias))
123 changes: 123 additions & 0 deletions tests/unit/utils/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Test suite for the typing utilities module.
"""

from typing import Annotated, Literal, Union

import pytest
from typing_extensions import TypeAlias

from guidellm.utils.typing import get_literal_vals

# Local type definitions to avoid imports from other modules
LocalProfileType = Literal["synchronous", "async", "concurrent", "throughput", "sweep"]
LocalStrategyType = Annotated[
Literal["synchronous", "concurrent", "throughput", "constant", "poisson"],
"Valid strategy type identifiers for scheduling request patterns",
]
StrategyProfileType: TypeAlias = Union[LocalStrategyType, LocalProfileType]


class TestGetLiteralVals:
"""Test cases for the get_literal_vals function."""

@pytest.mark.sanity
def test_profile_type(self):
"""
Test extracting values from ProfileType.

### WRITTEN BY AI ###
"""
result = get_literal_vals(LocalProfileType)
expected = frozenset(
{"synchronous", "async", "concurrent", "throughput", "sweep"}
)
assert result == expected

@pytest.mark.sanity
def test_strategy_type(self):
"""
Test extracting values from StrategyType.

### WRITTEN BY AI ###
"""
result = get_literal_vals(LocalStrategyType)
expected = frozenset(
{"synchronous", "concurrent", "throughput", "constant", "poisson"}
)
assert result == expected

@pytest.mark.smoke
def test_inline_union_type(self):
"""
Test extracting values from inline union of ProfileType | StrategyType.

### WRITTEN BY AI ###
"""
result = get_literal_vals(Union[LocalProfileType, LocalStrategyType])
expected = frozenset(
{
"synchronous",
"async",
"concurrent",
"throughput",
"constant",
"poisson",
"sweep",
}
)
assert result == expected

@pytest.mark.smoke
def test_type_alias(self):
"""
Test extracting values from type alias union.

### WRITTEN BY AI ###
"""
result = get_literal_vals(StrategyProfileType)
expected = frozenset(
{
"synchronous",
"async",
"concurrent",
"throughput",
"constant",
"poisson",
"sweep",
}
)
assert result == expected

@pytest.mark.sanity
def test_single_literal(self):
"""
Test extracting values from single Literal type.

### WRITTEN BY AI ###
"""
result = get_literal_vals(Literal["test"])
expected = frozenset({"test"})
assert result == expected

@pytest.mark.sanity
def test_multi_literal(self):
"""
Test extracting values from multi-value Literal type.

### WRITTEN BY AI ###
"""
result = get_literal_vals(Literal["test", "test2"])
expected = frozenset({"test", "test2"})
assert result == expected

@pytest.mark.smoke
def test_literal_union(self):
"""
Test extracting values from union of Literal types.

### WRITTEN BY AI ###
"""
result = get_literal_vals(Union[Literal["test", "test2"], Literal["test3"]])
expected = frozenset({"test", "test2", "test3"})
assert result == expected
Loading