
Commit 6ae996a

[Misc] refactor argument parsing in examples (#16635)
Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>
1 parent b590adf commit 6ae996a

25 files changed: +606 −422 lines

examples/offline_inference/audio_language.py

Lines changed: 28 additions & 24 deletions
@@ -187,6 +187,33 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
     }


+def parse_args():
+    parser = FlexibleArgumentParser(
+        description='Demo on using vLLM for offline inference with '
+        'audio language models')
+    parser.add_argument('--model-type',
+                        '-m',
+                        type=str,
+                        default="ultravox",
+                        choices=model_example_map.keys(),
+                        help='Huggingface "model_type".')
+    parser.add_argument('--num-prompts',
+                        type=int,
+                        default=1,
+                        help='Number of prompts to run.')
+    parser.add_argument("--num-audios",
+                        type=int,
+                        default=1,
+                        choices=[0, 1, 2],
+                        help="Number of audio items per prompt.")
+    parser.add_argument("--seed",
+                        type=int,
+                        default=None,
+                        help="Set the seed when initializing `vllm.LLM`.")
+
+    return parser.parse_args()
+
+
 def main(args):
     model = args.model_type
     if model not in model_example_map:
@@ -240,28 +267,5 @@ def main(args):


 if __name__ == "__main__":
-    parser = FlexibleArgumentParser(
-        description='Demo on using vLLM for offline inference with '
-        'audio language models')
-    parser.add_argument('--model-type',
-                        '-m',
-                        type=str,
-                        default="ultravox",
-                        choices=model_example_map.keys(),
-                        help='Huggingface "model_type".')
-    parser.add_argument('--num-prompts',
-                        type=int,
-                        default=1,
-                        help='Number of prompts to run.')
-    parser.add_argument("--num-audios",
-                        type=int,
-                        default=1,
-                        choices=[0, 1, 2],
-                        help="Number of audio items per prompt.")
-    parser.add_argument("--seed",
-                        type=int,
-                        default=None,
-                        help="Set the seed when initializing `vllm.LLM`.")
-
-    args = parser.parse_args()
+    args = parse_args()
     main(args)

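The same refactor repeats across every touched example: CLI wiring moves out of the __main__ block into a dedicated parse_args() (or create_parser()) helper, leaving a two-line entry point. A minimal sketch of the pattern, using plain argparse and generic names rather than anything taken verbatim from the repository:

import argparse


def parse_args() -> argparse.Namespace:
    # All CLI wiring lives here; nothing parses sys.argv at import time.
    parser = argparse.ArgumentParser(description="Example script")
    parser.add_argument("--num-prompts", type=int, default=1,
                        help="Number of prompts to run.")
    return parser.parse_args()


def main(args: argparse.Namespace) -> None:
    print(f"Running with {args.num_prompts} prompt(s)")


if __name__ == "__main__":
    main(parse_args())
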
examples/offline_inference/basic/basic.py

Lines changed: 20 additions & 13 deletions
@@ -12,16 +12,23 @@
 # Create a sampling params object.
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

-# Create an LLM.
-llm = LLM(model="facebook/opt-125m")
-# Generate texts from the prompts. The output is a list of RequestOutput objects
-# that contain the prompt, generated text, and other information.
-outputs = llm.generate(prompts, sampling_params)
-# Print the outputs.
-print("\nGenerated Outputs:\n" + "-" * 60)
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}")
-    print(f"Output: {generated_text!r}")
-    print("-" * 60)
+
+def main():
+    # Create an LLM.
+    llm = LLM(model="facebook/opt-125m")
+    # Generate texts from the prompts.
+    # The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    print("\nGenerated Outputs:\n" + "-" * 60)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}")
+        print(f"Output: {generated_text!r}")
+        print("-" * 60)
+
+
+if __name__ == "__main__":
+    main()

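Beyond stylistic consistency, wrapping the top-level statements in main() behind the __name__ guard means importing basic.py no longer constructs an LLM as a side effect. A plausible further motivation, inferred rather than stated in the commit message: vLLM can launch worker subprocesses, and under the multiprocessing "spawn" start method any unguarded module-level code would re-execute in each child.
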
examples/offline_inference/basic/chat.py

Lines changed: 19 additions & 13 deletions
@@ -4,6 +4,24 @@
 from vllm.utils import FlexibleArgumentParser


+def create_parser():
+    parser = FlexibleArgumentParser()
+    # Add engine args
+    engine_group = parser.add_argument_group("Engine arguments")
+    EngineArgs.add_cli_args(engine_group)
+    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
+    # Add sampling params
+    sampling_group = parser.add_argument_group("Sampling parameters")
+    sampling_group.add_argument("--max-tokens", type=int)
+    sampling_group.add_argument("--temperature", type=float)
+    sampling_group.add_argument("--top-p", type=float)
+    sampling_group.add_argument("--top-k", type=int)
+    # Add example params
+    parser.add_argument("--chat-template-path", type=str)
+
+    return parser
+
+
 def main(args: dict):
     # Pop arguments not used by LLM
     max_tokens = args.pop("max_tokens")
@@ -82,18 +100,6 @@ def print_outputs(outputs):


 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    # Add engine args
-    engine_group = parser.add_argument_group("Engine arguments")
-    EngineArgs.add_cli_args(engine_group)
-    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
-    # Add sampling params
-    sampling_group = parser.add_argument_group("Sampling parameters")
-    sampling_group.add_argument("--max-tokens", type=int)
-    sampling_group.add_argument("--temperature", type=float)
-    sampling_group.add_argument("--top-p", type=float)
-    sampling_group.add_argument("--top-k", type=int)
-    # Add example params
-    parser.add_argument("--chat-template-path", type=str)
+    parser = create_parser()
     args: dict = vars(parser.parse_args())
     main(args)

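chat.py (and generate.py below) parse into a plain dict via vars(parser.parse_args()), pop the sampling flags off, and treat whatever remains as engine arguments. A hedged sketch of that consumption step; the LLM(**args) forwarding and the generate call are illustrative assumptions rather than lines from this diff:

from vllm import LLM, SamplingParams


def main(args: dict) -> None:
    # Pop arguments not used by LLM; flags the user left unset arrive as
    # None and fall back to SamplingParams defaults.
    sampling_kwargs = {
        key: value
        for key in ("max_tokens", "temperature", "top_p", "top_k")
        if (value := args.pop(key)) is not None
    }
    sampling_params = SamplingParams(**sampling_kwargs)
    # Everything still in args came from EngineArgs.add_cli_args, so it can
    # be forwarded to the engine wholesale.
    llm = LLM(**args)
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    print(outputs[0].outputs[0].text)
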
examples/offline_inference/basic/classify.py

Lines changed: 11 additions & 7 deletions
@@ -6,6 +6,16 @@
 from vllm.utils import FlexibleArgumentParser


+def parse_args():
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
+                        task="classify",
+                        enforce_eager=True)
+    return parser.parse_args()
+
+
 def main(args: Namespace):
     # Sample prompts.
     prompts = [
@@ -34,11 +44,5 @@ def main(args: Namespace):


 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    parser = EngineArgs.add_cli_args(parser)
-    # Set example specific arguments
-    parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
-                        task="classify",
-                        enforce_eager=True)
-    args = parser.parse_args()
+    args = parse_args()
     main(args)

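classify.py, embed.py, and score.py (below) all follow one recipe: attach the full engine CLI with EngineArgs.add_cli_args, then pin example-specific defaults via set_defaults so the script runs with zero flags while every flag stays overridable. A small standalone demonstration of the mechanism with plain argparse; the model and task strings are reused from the diff purely as example values:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model")
parser.add_argument("--task")
# set_defaults overrides add_argument defaults, but explicit flags win over both.
parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")

print(parser.parse_args([]))
# Namespace(model='jason9693/Qwen2.5-1.5B-apeach', task='classify')
print(parser.parse_args(["--task", "embed"]))
# Namespace(model='jason9693/Qwen2.5-1.5B-apeach', task='embed')
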
examples/offline_inference/basic/embed.py

Lines changed: 11 additions & 7 deletions
@@ -6,6 +6,16 @@
 from vllm.utils import FlexibleArgumentParser


+def parse_args():
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
+                        task="embed",
+                        enforce_eager=True)
+    return parser.parse_args()
+
+
 def main(args: Namespace):
     # Sample prompts.
     prompts = [
@@ -34,11 +44,5 @@ def main(args: Namespace):


 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    parser = EngineArgs.add_cli_args(parser)
-    # Set example specific arguments
-    parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
-                        task="embed",
-                        enforce_eager=True)
-    args = parser.parse_args()
+    args = parse_args()
     main(args)

examples/offline_inference/basic/generate.py

Lines changed: 20 additions & 12 deletions
@@ -4,6 +4,22 @@
 from vllm.utils import FlexibleArgumentParser


+def create_parser():
+    parser = FlexibleArgumentParser()
+    # Add engine args
+    engine_group = parser.add_argument_group("Engine arguments")
+    EngineArgs.add_cli_args(engine_group)
+    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
+    # Add sampling params
+    sampling_group = parser.add_argument_group("Sampling parameters")
+    sampling_group.add_argument("--max-tokens", type=int)
+    sampling_group.add_argument("--temperature", type=float)
+    sampling_group.add_argument("--top-p", type=float)
+    sampling_group.add_argument("--top-k", type=int)
+
+    return parser
+
+
 def main(args: dict):
     # Pop arguments not used by LLM
     max_tokens = args.pop("max_tokens")
@@ -35,23 +51,15 @@ def main(args: dict):
     ]
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
+    print("-" * 50)
     for output in outputs:
         prompt = output.prompt
         generated_text = output.outputs[0].text
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)


 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    # Add engine args
-    engine_group = parser.add_argument_group("Engine arguments")
-    EngineArgs.add_cli_args(engine_group)
-    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
-    # Add sampling params
-    sampling_group = parser.add_argument_group("Sampling parameters")
-    sampling_group.add_argument("--max-tokens", type=int)
-    sampling_group.add_argument("--temperature", type=float)
-    sampling_group.add_argument("--top-p", type=float)
-    sampling_group.add_argument("--top-k", type=int)
+    parser = create_parser()
     args: dict = vars(parser.parse_args())
     main(args)

examples/offline_inference/basic/score.py

Lines changed: 11 additions & 7 deletions
@@ -6,6 +6,16 @@
 from vllm.utils import FlexibleArgumentParser


+def parse_args():
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
+                        task="score",
+                        enforce_eager=True)
+    return parser.parse_args()
+
+
 def main(args: Namespace):
     # Sample prompts.
     text_1 = "What is the capital of France?"
@@ -30,11 +40,5 @@ def main(args: Namespace):


 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    parser = EngineArgs.add_cli_args(parser)
-    # Set example specific arguments
-    parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
-                        task="score",
-                        enforce_eager=True)
-    args = parser.parse_args()
+    args = parse_args()
     main(args)

examples/offline_inference/data_parallel.py

Lines changed: 36 additions & 31 deletions
@@ -34,6 +34,40 @@
 from vllm.utils import get_open_port


+def parse_args():
+    import argparse
+    parser = argparse.ArgumentParser(description="Data Parallel Inference")
+    parser.add_argument("--model",
+                        type=str,
+                        default="ibm-research/PowerMoE-3b",
+                        help="Model name or path")
+    parser.add_argument("--dp-size",
+                        type=int,
+                        default=2,
+                        help="Data parallel size")
+    parser.add_argument("--tp-size",
+                        type=int,
+                        default=2,
+                        help="Tensor parallel size")
+    parser.add_argument("--node-size",
+                        type=int,
+                        default=1,
+                        help="Total number of nodes")
+    parser.add_argument("--node-rank",
+                        type=int,
+                        default=0,
+                        help="Rank of the current node")
+    parser.add_argument("--master-addr",
+                        type=str,
+                        default="",
+                        help="Master node IP address")
+    parser.add_argument("--master-port",
+                        type=int,
+                        default=0,
+                        help="Master node port")
+    return parser.parse_args()
+
+
 def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
          dp_master_port, GPUs_per_dp_rank):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
@@ -95,37 +129,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,


 if __name__ == "__main__":
-    import argparse
-    parser = argparse.ArgumentParser(description="Data Parallel Inference")
-    parser.add_argument("--model",
-                        type=str,
-                        default="ibm-research/PowerMoE-3b",
-                        help="Model name or path")
-    parser.add_argument("--dp-size",
-                        type=int,
-                        default=2,
-                        help="Data parallel size")
-    parser.add_argument("--tp-size",
-                        type=int,
-                        default=2,
-                        help="Tensor parallel size")
-    parser.add_argument("--node-size",
-                        type=int,
-                        default=1,
-                        help="Total number of nodes")
-    parser.add_argument("--node-rank",
-                        type=int,
-                        default=0,
-                        help="Rank of the current node")
-    parser.add_argument("--master-addr",
-                        type=str,
-                        default="",
-                        help="Master node IP address")
-    parser.add_argument("--master-port",
-                        type=int,
-                        default=0,
-                        help="Master node port")
-    args = parser.parse_args()
+
+    args = parse_args()

     dp_size = args.dp_size
     tp_size = args.tp_size

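After parse_args(), the __main__ block derives per-node topology from the flags and forks one worker per local data-parallel rank. A hedged sketch of that launch shape; run_rank is a hypothetical stand-in for the example's main(), and the even dp_size/node_size split is an assumption consistent with the defaults:

from multiprocessing import Process


def run_rank(model: str, global_dp_rank: int) -> None:
    # Hypothetical stand-in: the real main() exports VLLM_DP_RANK and related
    # environment variables before constructing the LLM.
    print(f"DP rank {global_dp_rank} would serve {model}")


def launch(args) -> None:
    dp_per_node = args.dp_size // args.node_size  # assumes an even split
    procs = []
    for local_rank in range(dp_per_node):
        global_rank = args.node_rank * dp_per_node + local_rank
        proc = Process(target=run_rank, args=(args.model, global_rank))
        proc.start()
        procs.append(proc)
    for proc in procs:
        proc.join()
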
examples/offline_inference/eagle.py

Lines changed: 7 additions & 2 deletions
@@ -27,7 +27,7 @@ def load_prompts(dataset_path, num_prompts):
     return prompts[:num_prompts]


-def main():
+def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--dataset",
@@ -45,7 +45,12 @@ def main():
     parser.add_argument("--enable_chunked_prefill", action='store_true')
     parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
     parser.add_argument("--temp", type=float, default=0)
-    args = parser.parse_args()
+    return parser.parse_args()
+
+
+def main():
+
+    args = parse_args()

     model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
     eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"

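A practical payoff of splitting parse_args() out of main() here: the example can be driven programmatically by rewriting sys.argv before invoking it. An illustrative harness, not repository code, using only flags visible in this diff:

import sys

# parse_args() inside main() reads sys.argv, so injected flags take effect.
sys.argv = ["eagle.py", "--enable_chunked_prefill", "--temp", "0.7"]
# main()  # would now run with chunked prefill enabled and temperature 0.7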