
Commit efbce85

Varun Sundar Rabindranath authored
[misc] Layerwise profile updates (#10242)
Signed-off-by: Varun Sundar Rabindranath <[email protected]> Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent 2ca830d commit efbce85

File tree

5 files changed: +314 additions, -47 deletions


.buildkite/test-pipeline.yaml
Lines changed: 1 addition & 1 deletion

@@ -201,7 +201,7 @@ steps:
     - python3 offline_inference_classification.py
     - python3 offline_inference_embedding.py
     - python3 offline_inference_scoring.py
-    - python3 offline_profile.py --model facebook/opt-125m
+    - python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2
 
 - label: Prefix Caching Test # 9min
   mirror_hardwares: [amd]

examples/offline_profile.py
Lines changed: 206 additions & 30 deletions

@@ -4,9 +4,10 @@
 import sys
 from argparse import RawTextHelpFormatter
 from dataclasses import asdict, dataclass
-from typing import Optional
+from typing import Any, Dict, Generator, List, Optional, TypeAlias
 
 import torch
+import tqdm
 
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
@@ -15,16 +16,21 @@
 
 BATCH_SIZE_DEFAULT = 1
 PROMPT_LEN_DEFAULT = 256
-OUTPUT_LEN_DEFAULT = 2
 
 
 @dataclass
 class ProfileContext:
     engine_args: EngineArgs
     prompt_len: int
-    output_len: int
     batch_size: int
-    save_chrome_traces_folder: Optional[str]
+
+    # The profiler can run in 2 modes,
+    # 1. Run profiler for user specified num_steps
+    num_steps: Optional[int] = None
+    # 2. Run profiler until all requests complete
+    complete_num_requests_per_step: Optional[int] = None
+
+    save_chrome_traces_folder: Optional[str] = None
 
 
 def get_dtype(dtype: str):
@@ -34,23 +40,155 @@ def get_dtype(dtype: str):
     return dtype
 
 
+OutputLen_NumReqs_Map: TypeAlias = Dict[int, int]
+def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
+        -> OutputLen_NumReqs_Map:
+    """
+    Given the number of requests, batch_size, and the number of requests
+    that each engine-step should process, step_requests, determine the
+    output lengths of the requests such that step_requests is honoured.
+
+    Example:
+    if batch size = 128 and step_requests = [128, 128, 96, 64, 32, 1]
+    then return,
+    {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning,
+    32 requests should have output length 2,
+    32 requests should have output length 3,
+    32 requests should have output length 4,
+    31 requests should have output length 5,
+    1 request should have output length 6.
+
+    Args:
+        batch_size (int): Number of requests submitted for profile. This is
+            args.batch_size.
+        step_requests (List[int]): step_requests[i] is the number of requests
+            that the ith engine step should process.
+
+    Returns:
+        OutputLen_NumReqs_Map : A dictionary with output-length as keys and the
+            number of requests required to have that output-length as values.
+    """
+    ol_nr: OutputLen_NumReqs_Map = {}
+
+    # Number of requests that have been assigned an output-length
+    num_reqs_assigned: int = 0
+    num_steps: int = len(step_requests)
+
+    # Sanity check: the first step (prefill-step) must process all requests.
+    assert step_requests[0] == batch_size
+
+    # Begin assignments from the last step.
+    output_length: int = num_steps
+    for num_requests_at_step in reversed(step_requests):
+        if num_reqs_assigned == batch_size:
+            break
+
+        assert num_reqs_assigned < batch_size
+
+        # Remove the number of requests that have been determined
+        # to participate in this step and beyond.
+        num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned
+        assert num_reqs_unassigned_at_step >= 0
+
+        if num_reqs_unassigned_at_step > 0:
+            ol_nr[output_length] = num_reqs_unassigned_at_step
+            num_reqs_assigned += num_reqs_unassigned_at_step
+
+        output_length -= 1
+
+    # Sanity checks.
+    assert sum(ol_nr.values()) == batch_size, \
+        ("Number of requests in output-length assignment does not match "
+         f"batch-size.\n batch size {batch_size} - "
+         f"step requests {step_requests} - assignments {ol_nr}")
+
+    # Check that the output-length is in [1, num-steps]. Output length must be
+    # at least 1, as all requests must participate in the prefill-step.
+    assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \
+        ("Output lengths of requests should be in range "
+         f"[1, num-engine-steps].\n batch size {batch_size} - "
+         f"step requests {step_requests} - assignments {ol_nr}")
+
+    return ol_nr
+
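As a quick sanity check of the mapping described in the docstring above, the following hedged sketch reproduces the documented example. It is not part of the commit and assumes compute_request_output_lengths is in scope (e.g., imported from examples/offline_profile.py):

# Hedged check of the docstring example: batch_size 128 with the step
# schedule [128, 128, 96, 64, 32, 1] yields the documented assignment.
step_requests = [128, 128, 96, 64, 32, 1]
ol_nr = compute_request_output_lengths(batch_size=128,
                                       step_requests=step_requests)
assert ol_nr == {2: 32, 3: 32, 4: 32, 5: 31, 6: 1}
assert sum(ol_nr.values()) == 128  # every request gets exactly one output length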
+def determine_requests_per_step(context: ProfileContext) -> List[int]:
+    """
+    Determine the number of requests each engine step should process.
+    If context.num_steps is set, then all engine steps process the
+    same number of requests and the output list is of length
+    context.num_steps.
+
+    If context.complete_num_requests_per_step is set, then each decode step
+    processes fewer and fewer requests until there are no requests to process.
+    In this case, the output list is as big as the number of steps
+    required to process all requests.
+
+    Args:
+        context: ProfileContext object.
+
+    Returns:
+        List[int]: Number of requests to process for all engine-steps.
+            output[i] contains the number of requests that the ith step
+            should process.
+    """
+    if context.num_steps:
+        # All requests must run until num_engine_steps. This implies
+        # that their output lengths must be equal to num_engine_steps.
+        return [context.batch_size] * context.num_steps
+
+    assert context.complete_num_requests_per_step and \
+        context.complete_num_requests_per_step > 0, \
+        (f"Expected a positive complete_num_requests_per_step argument. "
+         f"Instead got {context.complete_num_requests_per_step}")
+
+    # We start dropping after the first decode step.
+    step_requests = [
+        context.batch_size,  # prefill
+        context.batch_size,  # decode
+    ]
+
+    num_running_requests = context.batch_size
+    num_running_requests -= context.complete_num_requests_per_step
+    while num_running_requests > 0:
+        step_requests.append(num_running_requests)
+        num_running_requests -= context.complete_num_requests_per_step
+
+    if step_requests[-1] != 1:
+        # Have 1 request running at the last step. This is often useful.
+        step_requests.append(1)
+
+    return step_requests
+
+
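To illustrate how the two ProfileContext modes drive this helper, here is a hedged sketch (not part of the commit). It assumes ProfileContext, EngineArgs, and determine_requests_per_step are importable, and the EngineArgs model value is only a placeholder:

# run_to_completion mode: complete 32 requests per decode step.
ctx_drain = ProfileContext(engine_args=EngineArgs(model="facebook/opt-125m"),
                           prompt_len=256,
                           batch_size=128,
                           complete_num_requests_per_step=32)
assert determine_requests_per_step(ctx_drain) == [128, 128, 96, 64, 32, 1]

# run_num_steps mode: every step processes the full batch.
ctx_fixed = ProfileContext(engine_args=EngineArgs(model="facebook/opt-125m"),
                           prompt_len=256,
                           batch_size=4,
                           num_steps=2)
assert determine_requests_per_step(ctx_fixed) == [4, 4]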
 def run_profile(context: ProfileContext, csv_output: Optional[str],
                 json_output: Optional[str]):
     print("Run profile with:")
     for key, value in asdict(context).items():
         print(f"  {key} = {value}")
 
+    requests_per_step: List[int] = determine_requests_per_step(context)
+
+    ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
+        context.batch_size, requests_per_step)
+
+    num_steps_to_profile: int = len(requests_per_step)
+    max_output_len: int = max(ol_nr.keys())
+    assert max_output_len >= 1
+
     # Create sampling params
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=args.output_len,
-                                     ignore_eos=True)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        # max_tokens is set on a per-request basis.
+        max_tokens=None,
+        ignore_eos=True)
 
     # Create LLM
     llm = LLM(**asdict(context.engine_args))
     batch_size = context.batch_size
     prompt_len = context.prompt_len
-    output_len = context.output_len
 
     scheduler_config = llm.llm_engine.scheduler_config
     max_model_len = llm.llm_engine.model_config.max_model_len
@@ -65,24 +203,34 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
               f"choose a smaller batch size or prompt length, or increase "
               f"--max-num-batched-tokens")
         sys.exit(-1)
-    if batch_size >= max_num_seqs:
+    if batch_size > max_num_seqs:
         print(
             f"ERROR: chosen batch_size ({batch_size}) is larger than "
             f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
             f"single profile step, please choose a smaller batch size")
         sys.exit(-1)
     print("llm.llm_engine.model_config.max_model_len: ",
           llm.llm_engine.model_config.max_model_len)
-    if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
-        print(
-            f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
-            f"{output_len} = {prompt_len + output_len}) is larger than the "
-            f"model's max_model_len ({max_model_len}), please choose a smaller "
-            f"prompt_len or output_len, or increase --max-model-len")
+    if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len:
+        print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
+              f"{max_output_len} = {prompt_len + max_output_len}) is larger "
+              f"than the model's max_model_len ({max_model_len}), please "
+              f"choose a smaller prompt_len or max_output_len, or increase "
+              f"--max-model-len")
         sys.exit(-1)
 
     def add_requests():
+
+        def get_output_len_generator() -> Generator[int, Any, Any]:
+            for output_len, num_reqs in ol_nr.items():
+                for _ in range(num_reqs):
+                    yield output_len
+
+        output_len_generator = get_output_len_generator()
         for i in range(batch_size):
+            sampling_params.max_tokens = next(output_len_generator)
+            assert isinstance(sampling_params.max_tokens, int)
+
             prompt_token_ids = torch.randint(
                 llm.llm_engine.model_config.get_vocab_size(),
                 size=(prompt_len, )).tolist()
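The generator above hands each request its own max_tokens value by flattening the output-length map. A small standalone sketch of that expansion (the example map values are made up, not taken from the commit):

from typing import Dict, Iterator

def expand_output_lengths(ol_nr: Dict[int, int]) -> Iterator[int]:
    # Yield one output length per request, mirroring get_output_len_generator.
    for output_len, num_reqs in ol_nr.items():
        for _ in range(num_reqs):
            yield output_len

# With 1 request of length 6 and 31 requests of length 5, the generator
# produces 32 per-request max_tokens values in submission order.
assert list(expand_output_lengths({6: 1, 5: 31})) == [6] + [5] * 31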
@@ -110,8 +258,11 @@ def abort_requests():
     llm.llm_engine.step()  # First step is prefill
 
     decode_profs = []
-    for x in range(args.output_len - 1):
-        with layerwise_profile() as decode_prof:
+    for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
+        num_running_seqs = llm.llm_engine.scheduler[
+            0].get_num_unfinished_seq_groups()
+        with layerwise_profile(
+                num_running_seqs=num_running_seqs) as decode_prof:
             llm.llm_engine.step()
         decode_profs.append(decode_prof)
 
@@ -154,7 +305,8 @@ def abort_requests():
         decode_results_list[0].print_summary_table()
 
     if csv_output:
-        csv_filename_base = csv_output.rstrip(".csv")
+        csv_filename_base = csv_output[:-4] \
+            if csv_output.endswith('.csv') else csv_output
         prefill_results.export_model_stats_table_csv(
             csv_filename_base + "_prefill_model_table.csv")
         prefill_results.export_summary_stats_table_csv(
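The filename handling changed because str.rstrip strips a trailing character set rather than a suffix. A brief illustration with made-up filenames (not from the commit):

# rstrip(".csv") removes any trailing '.', 'c', 's', 'v' characters,
# so it can eat into the real name:
assert "llama_results.csv".rstrip(".csv") == "llama_result"

# The endswith/slice form removes only an exact ".csv" suffix:
name = "llama_results.csv"
assert (name[:-4] if name.endswith(".csv") else name) == "llama_results"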
@@ -187,10 +339,10 @@ def abort_requests():
         for idx, dr in enumerate(decode_results_list):
             json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
 
-        for idx, dr in enumerate(decode_results_list[1:]):
-            json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
-
-        with open(json_output.rstrip(".json") + ".json", "w+") as f:
+        # Add .json to json_output filename if it doesn't exist already.
+        json_output_file = json_output if json_output.endswith(
+            '.json') else json_output + '.json'
+        with open(json_output_file, "w+") as f:
             json.dump(json_dict, f, indent=2)
     pass
 
@@ -214,7 +366,7 @@ def abort_requests():
     python examples/offline_profile.py \\
     --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
     --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
-    --enforce-eager
+    --enforce-eager run_num_steps -n 2
     ```
 
     then you can use various tools to analyze the json output
@@ -261,17 +413,41 @@ def abort_requests():
         default=BATCH_SIZE_DEFAULT,
         help=f"Number of requests to run as a single batch, "
         f"default={BATCH_SIZE_DEFAULT}")
-    parser.add_argument(
-        "--output-len",
+
+    subparsers = parser.add_subparsers(dest="cmd")
+
+    run_num_steps_parser = subparsers.add_parser(
+        "run_num_steps",
+        help="This variation profiles n engine.step() invocations.")
+    run_num_steps_parser.add_argument(
+        '-n',
+        '--num-steps',
         type=int,
-        default=OUTPUT_LEN_DEFAULT,
-        help="Number of llm steps to run (includes prefill and decode) "
-        "- default={OUTPUT_LEN_DEFAULT}")
+        help="Number of engine steps to profile.\n"
+        "Setting it to 1, profiles only the prefill step.\n"
+        "Setting it to 2, profiles the prefill and first decode step\n"
+        "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n"
+        "and so on ...")
+
+    run_to_completion_parser = subparsers.add_parser(
+        "run_to_completion",
+        help="This variation profiles all the engine.step() invocations "
+        "until the engine exhausts all submitted requests.")
+    run_to_completion_parser.add_argument(
+        '-n',
+        '--complete-num-requests-per-step',
+        type=int,
+        help=
+        "Complete complete_num_requests_per_step requests every decode step. "
+        "E.g., with batch_size 128 and complete_num_requests_per_step 32, "
+        "the profiler is run for 6 engine steps, with the steps processing "
+        "128, 128, 96, 64, 32, 1 requests respectively.\n"
+        "Note that we tack-on a one-request step at the end as it is often "
+        "useful.")
 
     EngineArgs.add_cli_args(parser)
 
     args = parser.parse_args()
-
     context = ProfileContext(
         engine_args=EngineArgs.from_cli_args(args),
         **{
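For reference, a hedged, self-contained sketch of how these subcommands parse; the parser construction is abbreviated, and only the option names shown above are taken from the commit:

import argparse

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="cmd")
subparsers.add_parser("run_num_steps").add_argument("-n", "--num-steps", type=int)
subparsers.add_parser("run_to_completion").add_argument(
    "-n", "--complete-num-requests-per-step", type=int)

# "run_num_steps -n 2" profiles the prefill step plus one decode step.
args = parser.parse_args(["run_num_steps", "-n", "2"])
assert args.cmd == "run_num_steps" and args.num_steps == 2

# "run_to_completion -n 32" completes 32 requests per decode step.
args = parser.parse_args(["run_to_completion", "-n", "32"])
assert args.cmd == "run_to_completion"
assert args.complete_num_requests_per_step == 32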

tools/profiler/print_layerwise_table.py
Lines changed: 7 additions & 2 deletions

@@ -34,9 +34,10 @@ def get_entries(node, curr_depth=0):
         "examples/offline_profile.py")
     parser.add_argument("--phase",
                         type=str,
-                        choices=["prefill", "decode_1"],
                         required=True,
-                        help="The phase to print the table for.")
+                        help="The phase to print the table for. This is either "
+                        "prefill or decode_n, where n is the decode step "
+                        "number")
     parser.add_argument("--table",
                         type=str,
                         choices=["summary", "model"],
@@ -49,6 +50,10 @@ def get_entries(node, curr_depth=0):
     with open(args.json_trace) as f:
         profile_data = json.load(f)
 
+    assert args.phase in profile_data, \
+        (f"Cannot find phase {args.phase} in profile data. Choose one among "
+         f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}')  # noqa
+
     if args.table == "summary":
         entries_and_depths = flatten_entries(
             SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
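Since decode phases are now keyed as decode_1, decode_2, and so on, the new assert surfaces the available phases when an unknown one is requested. A hedged illustration with a made-up profile_data dict (not from the commit):

profile_data = {"prefill": {}, "decode_1": {}, "decode_2": {}}

# The list comprehension in the assert message reports the valid choices:
available = [x for x in profile_data.keys() if "prefill" in x or "decode" in x]
assert available == ["prefill", "decode_1", "decode_2"]

# "--phase decode_2" passes the check; "--phase decode_3" would trip the assert.
assert "decode_2" in profile_data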
