Skip to content

Commit 83b0c77

Browse files
committed
Switch to internal dependency, and fix linter errors
1 parent bf3d175 commit 83b0c77

File tree

7 files changed

+129
-19
lines changed

7 files changed

+129
-19
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ dependencies = [
5656
"pyyaml>=6.0.0",
5757
"rich",
5858
"transformers",
59-
"click-default-group~=1.2.4"
6059
]
6160

6261
[project.optional-dependencies]

src/guidellm/__main__.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,21 @@
22
import codecs
33
from pathlib import Path
44
from typing import get_args
5-
from click_default_group import DefaultGroup
65

76
import click
87
from pydantic import ValidationError
98

109
from guidellm.backend import BackendType
1110
from guidellm.benchmark import (
1211
ProfileType,
13-
benchmark_generative_text,
1412
reimport_benchmarks_report,
1513
)
1614
from guidellm.benchmark.entrypoints import benchmark_with_scenario
1715
from guidellm.benchmark.scenario import GenerativeTextScenario, get_builtin_scenarios
1816
from guidellm.config import print_config
1917
from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
2018
from guidellm.scheduler import StrategyType
19+
from guidellm.utils import DefaultGroupHandler
2120
from guidellm.utils import cli as cli_tools
2221

2322
STRATEGY_PROFILE_CHOICES = set(
@@ -29,11 +28,11 @@
2928
def cli():
3029
pass
3130

31+
3232
@cli.group(
3333
help="Commands to run a new benchmark or load a prior one.",
34-
cls=DefaultGroup,
34+
cls=DefaultGroupHandler,
3535
default="run",
36-
default_if_no_args=True,
3736
)
3837
def benchmark():
3938
pass
@@ -334,15 +333,15 @@ def run(
334333
is_flag=False,
335334
flag_value=Path.cwd() / "benchmarks_reexported.json",
336335
help=(
337-
"Allows re-exporting the benchmarks to another format."
336+
"Allows re-exporting the benchmarks to another format. "
338337
"The path to save the output to. If it is a directory, "
339338
"it will save benchmarks.json under it. "
340339
"Otherwise, json, yaml, or csv files are supported for output types "
341-
"which will be read from the extension for the file path."
342-
"Optional. If the output path flag is not provided, the benchmarks "
343-
"will not be reexported. If the flag is present but no value is "
344-
"specified, it will default to the current directory with the file "
345-
"name benchmarks_reexported.json."
340+
"which will be read from the extension for the file path. "
341+
"This input is optional. If the output path flag is not provided, "
342+
"the benchmarks will not be reexported. If the flag is present but "
343+
"no value is specified, it will default to the current directory "
344+
"with the file name `benchmarks_reexported.json`."
346345
),
347346
)
348347
def from_file(path, output_path):
@@ -368,7 +367,7 @@ def decode_escaped_str(_ctx, _param, value):
368367
help=(
369368
"Print out the available configuration settings that can be set "
370369
"through environment variables."
371-
)
370+
),
372371
)
373372
def config():
374373
print_config()

src/guidellm/benchmark/entrypoints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ async def benchmark_generative_text(
147147

148148
return report, saved_path
149149

150+
150151
def reimport_benchmarks_report(file: Path, output_path: Optional[Path]) -> None:
151152
"""
152153
The command-line entry point for re-importing and displaying an

src/guidellm/benchmark/output.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,10 @@ def _file_setup(
242242
if path_suffix in [".csv"]:
243243
return path, "csv"
244244

245-
raise ValueError(f"Unsupported file extension: {path_suffix} for {path}; expected json, yaml, or csv.")
245+
raise ValueError(
246+
f"Unsupported file extension: {path_suffix} for {path}; "
247+
"expected json, yaml, or csv."
248+
)
246249

247250
@staticmethod
248251
def _benchmark_desc_headers_and_values(

src/guidellm/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .colors import Colors
2+
from .default_group import DefaultGroupHandler
23
from .hf_datasets import (
34
SUPPORTED_TYPES,
45
save_dataset_to_file,
@@ -20,6 +21,7 @@
2021
__all__ = [
2122
"SUPPORTED_TYPES",
2223
"Colors",
24+
"DefaultGroupHandler",
2325
"EndlessTextCreator",
2426
"IntegerRangeSampler",
2527
"check_load_processor",
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
File uses code adapted from code with the following license:
3+
4+
Copyright (c) 2015-2023, Heungsub Lee
5+
All rights reserved.
6+
7+
Redistribution and use in source and binary forms, with or without modification,
8+
are permitted provided that the following conditions are met:
9+
10+
Redistributions of source code must retain the above copyright notice, this
11+
list of conditions and the following disclaimer.
12+
13+
Redistributions in binary form must reproduce the above copyright notice, this
14+
list of conditions and the following disclaimer in the documentation and/or
15+
other materials provided with the distribution.
16+
17+
Neither the name of the copyright holder nor the names of its
18+
contributors may be used to endorse or promote products derived from
19+
this software without specific prior written permission.
20+
21+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
"""
32+
33+
__all__ = ["DefaultGroupHandler"]
34+
35+
import click
36+
37+
38+
class DefaultGroupHandler(click.Group):
39+
"""
40+
Allows the migration to a new sub-command by allowing the group to run
41+
one of its sub-commands as the no-args default command.
42+
"""
43+
44+
def __init__(self, *args, **kwargs):
45+
# To resolve as the default command.
46+
if not kwargs.get('ignore_unknown_options', True):
47+
raise ValueError('Default group accepts unknown options')
48+
self.ignore_unknown_options = True
49+
self.default_cmd_name = kwargs.pop('default', None)
50+
self.default_if_no_args = kwargs.pop('default_if_no_args', False)
51+
super(DefaultGroupHandler, self).__init__(*args, **kwargs)
52+
53+
def parse_args(self, ctx, args):
54+
if not args and self.default_if_no_args:
55+
args.insert(0, self.default_cmd_name)
56+
return super(DefaultGroupHandler, self).parse_args(ctx, args)
57+
58+
def get_command(self, ctx, cmd_name):
59+
if cmd_name not in self.commands:
60+
# If it doesn't match an existing command, use the default command name.
61+
ctx.arg0 = cmd_name
62+
cmd_name = self.default_cmd_name
63+
return super(DefaultGroupHandler, self).get_command(ctx, cmd_name)
64+
65+
def resolve_command(self, ctx, args):
66+
base = super(DefaultGroupHandler, self)
67+
cmd_name, cmd, args = base.resolve_command(ctx, args)
68+
if hasattr(ctx, 'arg0'):
69+
args.insert(0, ctx.arg0)
70+
cmd_name = cmd.name
71+
return cmd_name, cmd, args
72+
73+
def format_commands(self, ctx, formatter):
74+
"""
75+
Used to wrap the default formatter to clarify which command is the default.
76+
"""
77+
formatter = DefaultCommandFormatter(self, formatter, mark=' (default)')
78+
return super(DefaultGroupHandler, self).format_commands(ctx, formatter)
79+
80+
81+
class DefaultCommandFormatter(object):
82+
"""
83+
Wraps a formatter to edit the line for the default command to mark it
84+
with the specified mark string.
85+
"""
86+
87+
def __init__(self, group, formatter, mark='*'):
88+
self.group = group
89+
self.formatter = formatter
90+
self.mark = mark
91+
super().__init__()
92+
93+
def __getattr__(self, attr):
94+
return getattr(self.formatter, attr)
95+
96+
def write_dl(self, rows, *args, **kwargs):
97+
rows_ = []
98+
for cmd_name, help_msg in rows:
99+
if cmd_name == self.group.default_cmd_name:
100+
rows_.insert(0, (cmd_name + self.mark, help_msg))
101+
else:
102+
rows_.append((cmd_name, help_msg))
103+
return self.formatter.write_dl(rows_, *args, **kwargs)

tests/unit/entrypoints/test_benchmark_from_file_entrypoint.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
import filecmp
12
import os
23
import unittest
34
from pathlib import Path
45

56
import pytest
6-
import filecmp
77

88
from guidellm.benchmark import reimport_benchmarks_report
99

@@ -18,13 +18,14 @@ def _() -> Path:
1818

1919
return _
2020

21+
2122
@pytest.fixture
2223
def cleanup():
23-
to_delete = []
24+
to_delete: list[Path] = []
2425
yield to_delete
2526
for item in to_delete:
26-
if os.path.exists(item):
27-
os.remove(item)
27+
if item.exists():
28+
item.unlink() # Deletes the file
2829

2930

3031
def test_display_entrypoint_json(capfd, get_test_asset_dir):
@@ -58,14 +59,16 @@ def generic_test_display_entrypoint(filename, capfd, get_test_asset_dir):
5859
expected_output = file.read()
5960
assert out == expected_output
6061

62+
6163
def test_reexporting_benchmark(get_test_asset_dir, cleanup):
6264
asset_dir = get_test_asset_dir()
6365
source_file = asset_dir / "benchmarks_stripped.json"
6466
exported_file = asset_dir / "benchmarks_reexported.json"
65-
# If you need to inspect the output to see why it failed, comment out the following statement.
67+
# If you need to inspect the output to see why it failed, comment out
68+
# the cleanup statement.
6669
cleanup.append(exported_file)
6770
if exported_file.exists():
68-
os.remove(exported_file)
71+
exported_file.unlink()
6972
reimport_benchmarks_report(source_file, exported_file)
7073
# The reexported file should exist and be identical to the source.
7174
assert exported_file.exists()

0 commit comments

Comments
 (0)