Skip to content

Commit ff90beb

Browse files
committed
Scheduler refactor [utils]: misc extras
commit 906ff49 Author: Samuel Monson <[email protected]> Date: Wed Aug 27 15:54:32 2025 -0400 Add helper for converting literals to list of strings Signed-off-by: Samuel Monson <[email protected]>
1 parent 5fc9fab commit ff90beb

File tree

4 files changed

+178
-7
lines changed

4 files changed

+178
-7
lines changed

src/guidellm/__main__.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import codecs
33
from pathlib import Path
4-
from typing import get_args
4+
from typing import Union
55

66
import click
77
from pydantic import ValidationError
@@ -16,12 +16,10 @@
1616
from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
1717
from guidellm.scheduler import StrategyType
1818
from guidellm.settings import print_config
19-
from guidellm.utils import DefaultGroupHandler
19+
from guidellm.utils import DefaultGroupHandler, get_literal_vals
2020
from guidellm.utils import cli as cli_tools
2121

22-
STRATEGY_PROFILE_CHOICES = list(
23-
set(list(get_args(ProfileType)) + list(get_args(StrategyType)))
24-
)
22+
STRATEGY_PROFILE_CHOICES = list(get_literal_vals(Union[ProfileType, StrategyType]))
2523

2624

2725
@click.group()
@@ -70,10 +68,10 @@ def benchmark():
7068
)
7169
@click.option(
7270
"--backend-type",
73-
type=click.Choice(list(get_args(BackendType))),
71+
type=click.Choice(list(get_literal_vals(BackendType))),
7472
help=(
7573
"The type of backend to use to run requests against. Defaults to 'openai_http'."
76-
f" Supported types: {', '.join(get_args(BackendType))}"
74+
f" Supported types: {', '.join(get_literal_vals(BackendType))}"
7775
),
7876
default=GenerativeTextScenario.get_default("backend_type"),
7977
)

src/guidellm/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
split_text,
5757
split_text_list_by_length,
5858
)
59+
from .typing import get_literal_vals
5960

6061
__all__ = [
6162
"SUPPORTED_TYPES",
@@ -91,6 +92,7 @@
9192
"check_load_processor",
9293
"clean_text",
9394
"filter_text",
95+
"get_literal_vals",
9496
"is_punctuation",
9597
"load_text",
9698
"safe_add",

src/guidellm/utils/typing.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Annotated, Literal, Union, get_args, get_origin
4+
5+
if TYPE_CHECKING:
6+
from collections.abc import Iterator
7+
8+
# Backwards compatibility for Python <3.10
9+
try:
10+
from types import UnionType # type: ignore[attr-defined]
11+
except ImportError:
12+
UnionType = Union
13+
14+
# Backwards compatibility for Python <3.12
15+
try:
16+
from typing import TypeAliasType # type: ignore[attr-defined]
17+
except ImportError:
18+
from typing_extensions import TypeAliasType
19+
20+
21+
__all__ = ["get_literal_vals"]
22+
23+
24+
def get_literal_vals(alias) -> frozenset[str]:
25+
"""Extract all literal values from a (possibly nested) type alias."""
26+
27+
def resolve(alias) -> Iterator[str]:
28+
origin = get_origin(alias)
29+
30+
# Base case: Literal types
31+
if origin is Literal:
32+
for literal_val in get_args(alias):
33+
yield str(literal_val)
34+
# Unwrap Annotated type
35+
elif origin is Annotated:
36+
yield from resolve(get_args(alias)[0])
37+
# Unwrap TypeAliasTypes
38+
elif isinstance(alias, TypeAliasType):
39+
yield from resolve(alias.__value__)
40+
# Iterate over unions
41+
elif origin in (Union, UnionType):
42+
for arg in get_args(alias):
43+
yield from resolve(arg)
44+
# Fallback
45+
else:
46+
yield str(alias)
47+
48+
return frozenset(resolve(alias))

tests/unit/utils/test_typing.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
Test suite for the typing utilities module.
3+
"""
4+
5+
from typing import Annotated, Literal, Union
6+
7+
import pytest
8+
from typing_extensions import TypeAlias
9+
10+
from guidellm.utils.typing import get_literal_vals
11+
12+
# Local type definitions to avoid imports from other modules
13+
LocalProfileType = Literal["synchronous", "async", "concurrent", "throughput", "sweep"]
14+
LocalStrategyType = Annotated[
15+
Literal["synchronous", "concurrent", "throughput", "constant", "poisson"],
16+
"Valid strategy type identifiers for scheduling request patterns",
17+
]
18+
StrategyProfileType: TypeAlias = Union[LocalStrategyType, LocalProfileType]
19+
20+
21+
class TestGetLiteralVals:
22+
"""Test cases for the get_literal_vals function."""
23+
24+
@pytest.mark.sanity
25+
def test_profile_type(self):
26+
"""
27+
Test extracting values from ProfileType.
28+
29+
### WRITTEN BY AI ###
30+
"""
31+
result = get_literal_vals(LocalProfileType)
32+
expected = frozenset(
33+
{"synchronous", "async", "concurrent", "throughput", "sweep"}
34+
)
35+
assert result == expected
36+
37+
@pytest.mark.sanity
38+
def test_strategy_type(self):
39+
"""
40+
Test extracting values from StrategyType.
41+
42+
### WRITTEN BY AI ###
43+
"""
44+
result = get_literal_vals(LocalStrategyType)
45+
expected = frozenset(
46+
{"synchronous", "concurrent", "throughput", "constant", "poisson"}
47+
)
48+
assert result == expected
49+
50+
@pytest.mark.smoke
51+
def test_inline_union_type(self):
52+
"""
53+
Test extracting values from inline union of ProfileType | StrategyType.
54+
55+
### WRITTEN BY AI ###
56+
"""
57+
result = get_literal_vals(Union[LocalProfileType, LocalStrategyType])
58+
expected = frozenset(
59+
{
60+
"synchronous",
61+
"async",
62+
"concurrent",
63+
"throughput",
64+
"constant",
65+
"poisson",
66+
"sweep",
67+
}
68+
)
69+
assert result == expected
70+
71+
@pytest.mark.smoke
72+
def test_type_alias(self):
73+
"""
74+
Test extracting values from type alias union.
75+
76+
### WRITTEN BY AI ###
77+
"""
78+
result = get_literal_vals(StrategyProfileType)
79+
expected = frozenset(
80+
{
81+
"synchronous",
82+
"async",
83+
"concurrent",
84+
"throughput",
85+
"constant",
86+
"poisson",
87+
"sweep",
88+
}
89+
)
90+
assert result == expected
91+
92+
@pytest.mark.sanity
93+
def test_single_literal(self):
94+
"""
95+
Test extracting values from single Literal type.
96+
97+
### WRITTEN BY AI ###
98+
"""
99+
result = get_literal_vals(Literal["test"])
100+
expected = frozenset({"test"})
101+
assert result == expected
102+
103+
@pytest.mark.sanity
104+
def test_multi_literal(self):
105+
"""
106+
Test extracting values from multi-value Literal type.
107+
108+
### WRITTEN BY AI ###
109+
"""
110+
result = get_literal_vals(Literal["test", "test2"])
111+
expected = frozenset({"test", "test2"})
112+
assert result == expected
113+
114+
@pytest.mark.smoke
115+
def test_literal_union(self):
116+
"""
117+
Test extracting values from union of Literal types.
118+
119+
### WRITTEN BY AI ###
120+
"""
121+
result = get_literal_vals(Union[Literal["test", "test2"], Literal["test3"]])
122+
expected = frozenset({"test", "test2", "test3"})
123+
assert result == expected

0 commit comments

Comments
 (0)