
Commit 9d30a05

Authored by LucasWilkinson, Varun Sundar Rabindranath, and mgoin

[misc] CUDA Time Layerwise Profiler (#8337)

Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent 390be74 commit 9d30a05

File tree

8 files changed: +1390, -4 lines changed


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -184,6 +184,7 @@ steps:
     - python3 offline_inference_vision_language_multi_image.py
     - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
     - python3 offline_inference_encoder_decoder.py
+    - python3 offline_profile.py --model facebook/opt-125m
 
 - label: Prefix Caching Test # 9min
   #mirror_hardwares: [amd]

examples/offline_profile.py

Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
import inspect
import json
import os
import sys
from argparse import RawTextHelpFormatter
from dataclasses import asdict, dataclass
from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.profiler import layerwise_profile
from vllm.utils import FlexibleArgumentParser

BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256
OUTPUT_LEN_DEFAULT = 2


@dataclass
class ProfileContext:
    engine_args: EngineArgs
    prompt_len: int
    output_len: int
    batch_size: int
    save_chrome_traces_folder: Optional[str]


def get_dtype(dtype: str):
    if dtype == "torch.float":
        return torch.float
    else:
        return dtype


def run_profile(context: ProfileContext, csv_output: Optional[str],
                json_output: Optional[str]):
    print("Run profile with:")
    for key, value in asdict(context).items():
        print(f"  {key} = {value}")

    # Create sampling params
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=args.output_len,
                                     ignore_eos=True)

    # Create LLM
    llm = LLM(**asdict(context.engine_args))
    batch_size = context.batch_size
    prompt_len = context.prompt_len
    output_len = context.output_len

    scheduler_config = llm.llm_engine.scheduler_config
    max_model_len = llm.llm_engine.model_config.max_model_len
    max_num_batched_tokens = scheduler_config.max_num_batched_tokens
    max_num_seqs = scheduler_config.max_num_seqs

    if batch_size * prompt_len > max_num_batched_tokens:
        print(f"ERROR: chosen batch_size * prompt_len "
              f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
              f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
              f"and therefore cannot be run in a single profile step, please "
              f"choose a smaller batch size or prompt length, or increase "
              f"--max-num-batched-tokens")
        sys.exit(-1)
    if batch_size >= max_num_seqs:
        print(
            f"ERROR: chosen batch_size ({batch_size}) is larger than "
            f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
            f"single profile step, please choose a smaller batch size")
        sys.exit(-1)
    print("llm.llm_engine.model_config.max_model_len: ",
          llm.llm_engine.model_config.max_model_len)
    if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
        print(
            f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
            f"{output_len} = {prompt_len + output_len}) is larger than the "
            f"model's max_model_len ({max_model_len}), please choose a smaller "
            f"prompt_len or output_len, or increase --max-model-len")
        sys.exit(-1)

    def add_requests():
        for i in range(batch_size):
            prompt_token_ids = torch.randint(
                llm.llm_engine.model_config.get_vocab_size(),
                size=(prompt_len, )).tolist()

            llm.llm_engine.add_request(
                request_id=f"seq{i}",
                prompt={'prompt_token_ids': prompt_token_ids},
                params=sampling_params)

    def abort_requests():
        for i in range(batch_size):
            llm.llm_engine.abort_request(f"seq{i}")

    # Warm up run
    print("Warm up run ...")
    add_requests()
    llm.llm_engine.step()  # Prefill
    llm.llm_engine.step()  # Decode
    abort_requests()

    print("Profile run ...")
    add_requests()

    with layerwise_profile() as prefill_prof:
        llm.llm_engine.step()  # First step is prefill

    decode_profs = []
    for x in range(args.output_len - 1):
        with layerwise_profile() as decode_prof:
            llm.llm_engine.step()
        decode_profs.append(decode_prof)

    decode_results_list = [prof.results for prof in decode_profs]
    prefill_results = prefill_prof.results
    has_decode = len(decode_results_list) > 0

    LINE_WIDTH = 80
    print("=" * LINE_WIDTH)
    print(f"= Prefill Model Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * LINE_WIDTH)
    print()
    prefill_results.print_model_table()

    if has_decode:
        print()
        print("=" * LINE_WIDTH)
        print(f"= First Decode Step Model Table "
              f"(prompt_len={prompt_len}, batch_size={batch_size})")
        print("=" * LINE_WIDTH)
        print()
        decode_results_list[0].print_model_table()

    print()
    print("=" * LINE_WIDTH)
    print(f"= Prefill Summary Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * LINE_WIDTH)
    print()
    prefill_results.print_summary_table()

    if has_decode:
        print()
        print("=" * LINE_WIDTH)
        print(f"= First Decode Step Summary Table "
              f"(prompt_len={prompt_len}, batch_size={batch_size})")
        print("=" * LINE_WIDTH)
        print()
        decode_results_list[0].print_summary_table()

    if csv_output:
        csv_filename_base = csv_output.rstrip(".csv")
        prefill_results.export_model_stats_table_csv(
            csv_filename_base + "_prefill_model_table.csv")
        prefill_results.export_summary_stats_table_csv(
            csv_filename_base + "_prefill_summary_table.csv")

        if has_decode:
            decode_results_list[0].export_model_stats_table_csv(
                csv_filename_base + "_decode_model_table.csv")
            decode_results_list[0].export_summary_stats_table_csv(
                csv_filename_base + "_decode_summary_table.csv")

    if json_output:
        cuda_devices = [
            torch.cuda.get_device_properties(dev_idx)
            for dev_idx in range(torch.cuda.device_count())
        ]

        json_dict = {
            "context": {
                "python_version": f"{sys.version}",
                "torch_version": f"{torch.__version__}",
                "torch_cuda_version": f"{torch.version.cuda}",
                "cuda_devices": f"{cuda_devices}",
                **asdict(context)
            },
            "prefill": prefill_results.convert_stats_to_dict(),
        }

        if has_decode:
            for idx, dr in enumerate(decode_results_list):
                json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()

        with open(json_output.rstrip(".json") + ".json", "w+") as f:
            json.dump(json_dict, f, indent=2)

    if context.save_chrome_traces_folder is not None:
        os.makedirs(context.save_chrome_traces_folder, exist_ok=True)
        prefill_prof.profiler.export_chrome_trace(
            context.save_chrome_traces_folder + "/prefill.json")
        for idx, decode_prof in enumerate(decode_profs):
            decode_prof.profiler.export_chrome_trace(
                context.save_chrome_traces_folder + f"/decode_{idx + 1}.json")
        print("Traces saved as prefill.json and decode_1.json, etc."
              f" in folder {context.save_chrome_traces_folder}")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="""
Profile a model

    example:
    ```
    python examples/offline_profile.py \\
        --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
        --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
        --enforce-eager
    ```
    then you can use various tools to analyze the json output
    terminal ascii tables:
        ```
        python tools/profiler/print_layerwise_table.py \\
            --json-trace Llama31-8b-FP8.json --phase prefill --table summary
        ```
    or create matplotlib stacked bar charts:
        ```
        python tools/profiler/visualize_layerwise_profile.py \\
            --json-trace Llama31-8b-FP8.json \\
            --output-directory profile_breakdown --plot-metric pct_cuda_time
        ```
""",
                                    formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        "--csv",
        type=str,
        default=None,
        help="Export the results as multiple csv files. This should be the root "
        "filename, will create <filename>_prefill_model_table.csv, "
        "<filename>_prefill_summary_table.csv, "
        "<filename>_decode_model_table.csv, and "
        "<filename>_decode_summary_table.csv")
    parser.add_argument(
        "--json",
        type=str,
        default=None,
        help="Export the results as a json file. This should be the filename")
    parser.add_argument("--save-chrome-traces-folder",
                        type=str,
                        help="Save chrome traces for the prefill and decode; "
                        "will save traces as prefill.json and decode_1.json, "
                        "etc. inside this folder")
    parser.add_argument(
        "--prompt-len",
        type=int,
        default=PROMPT_LEN_DEFAULT,
        help=f"Length of the random prompt to use when profiling, all batched "
        f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
    parser.add_argument("--batch-size",
                        type=int,
                        default=BATCH_SIZE_DEFAULT,
                        help=f"Number of requests to run as a single batch, "
                        f"default={BATCH_SIZE_DEFAULT}")
    parser.add_argument(
        "--output-len",
        type=int,
        default=OUTPUT_LEN_DEFAULT,
        help=f"Number of llm steps to run (includes prefill and decode) "
        f"- default={OUTPUT_LEN_DEFAULT}")

    EngineArgs.add_cli_args(parser)

    args = parser.parse_args()

    context = ProfileContext(
        engine_args=EngineArgs.from_cli_args(args),
        **{
            k: v
            for k, v in vars(args).items()
            if k in inspect.signature(ProfileContext).parameters
        })
    run_profile(context, csv_output=args.csv, json_output=args.json)
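
Note: the file written via `--json` above ends up with one top-level `context` block plus a `prefill` entry and one `decode_N` entry per profiled decode step. A minimal sketch of inspecting that output, assuming the `Llama31-8b-FP8.json` name from the example invocation in the help text (the nesting below each phase comes from `convert_stats_to_dict()` and is not shown in this diff):

```
import json

# Minimal sketch: peek at the top level of a profile JSON written by
# examples/offline_profile.py --json <name>. Assumes "Llama31-8b-FP8.json"
# exists, as in the example invocation from the script's help text.
with open("Llama31-8b-FP8.json") as f:
    profile = json.load(f)

print("phases:", [k for k in profile if k != "context"])      # prefill, decode_1, ...
print("torch version:", profile["context"]["torch_version"])  # recorded by run_profile()
```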
tools/profiler/print_layerwise_table.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
import argparse
import json
from typing import Dict

from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry
from vllm.profiler.utils import TablePrinter, indent_string


def flatten_entries(entry_cls, profile_dict: Dict):
    entries_and_depth = []

    def get_entries(node, curr_depth=0):
        entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))

        for child in node["children"]:
            get_entries(
                child,
                curr_depth=curr_depth + 1,
            )

    for root in profile_dict:
        get_entries(root)

    return entries_and_depth


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--json-trace",
                        type=str,
                        required=True,
                        help="json trace file output by "
                        "examples/offline_profile.py")
    parser.add_argument("--phase",
                        type=str,
                        choices=["prefill", "decode_1"],
                        required=True,
                        help="The phase to print the table for.")
    parser.add_argument("--table",
                        type=str,
                        choices=["summary", "model"],
                        default="summary",
                        help="Which table to print, the summary table or the "
                        "layerwise model table")

    args = parser.parse_args()

    with open(args.json_trace, "r") as f:
        profile_data = json.load(f)

    if args.table == "summary":
        entries_and_depths = flatten_entries(
            SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
        column_widths = dict(name=80,
                             cuda_time_us=12,
                             pct_cuda_time=12,
                             invocations=15)
    elif args.table == "model":
        entries_and_depths = flatten_entries(
            ModelStatsEntry, profile_data[args.phase]["model_stats"])
        column_widths = dict(name=60,
                             cpu_time_us=12,
                             cuda_time_us=12,
                             pct_cuda_time=12,
                             trace=60)

    # indent entry names based on the depth
    entries = []
    for entry, depth in entries_and_depths:
        entry.name = indent_string(
            entry.name,
            indent=depth,
            indent_style=lambda indent: "|" + "-" * indent + " ")
        entries.append(entry)

    TablePrinter(type(entries[0]), column_widths).print_table(entries)
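
Note: `flatten_entries` walks a list of nested nodes of the form `{"entry": {...}, "children": [...]}`, as produced under `summary_stats`/`model_stats` in the profile JSON. Below is a self-contained sketch using a stand-in dataclass instead of the real `SummaryStatsEntry`; the field names mirror the summary `column_widths` keys above, and the sample tree and timings are made up for illustration:

```
from dataclasses import dataclass


# Stand-in for vllm.profiler.layerwise_profile.SummaryStatsEntry; an
# assumption for illustration only - the real class may carry more fields.
@dataclass
class StubSummaryEntry:
    name: str
    cuda_time_us: float
    pct_cuda_time: float
    invocations: int


def flatten_entries(entry_cls, profile_dict):
    # Same traversal as above: depth-first over {"entry": ..., "children": ...}.
    entries_and_depth = []

    def get_entries(node, curr_depth=0):
        entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))
        for child in node["children"]:
            get_entries(child, curr_depth=curr_depth + 1)

    for root in profile_dict:
        get_entries(root)
    return entries_and_depth


# Made-up "summary_stats" tree with the nesting flatten_entries expects.
summary_stats = [{
    "entry": {"name": "model", "cuda_time_us": 1000.0,
              "pct_cuda_time": 100.0, "invocations": 1},
    "children": [{
        "entry": {"name": "attention", "cuda_time_us": 400.0,
                  "pct_cuda_time": 40.0, "invocations": 32},
        "children": [],
    }],
}]

for entry, depth in flatten_entries(StubSummaryEntry, summary_stats):
    # Mirrors the "|--- " indent style applied in __main__ above.
    print("|" + "-" * depth + " " + entry.name, entry.cuda_time_us)
```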
