import sys
from argparse import RawTextHelpFormatter
from dataclasses import asdict, dataclass
-from typing import Optional
+from typing import Any, Dict, Generator, List, Optional, TypeAlias

import torch
+import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs

BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256
-OUTPUT_LEN_DEFAULT = 2


@dataclass
class ProfileContext:
    engine_args: EngineArgs
    prompt_len: int
-    output_len: int
    batch_size: int
-    save_chrome_traces_folder: Optional[str]
+
+    # The profiler can run in 2 modes,
+    # 1. Run profiler for user specified num_steps
+    num_steps: Optional[int] = None
+    # 2. Run profiler until all requests complete
+    complete_num_requests_per_step: Optional[int] = None
+
+    save_chrome_traces_folder: Optional[str] = None


def get_dtype(dtype: str):
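
A minimal sketch of how the two modes map onto `ProfileContext` (illustrative only; the model name and sizes are placeholders):

```python
from vllm.engine.arg_utils import EngineArgs

# Mode 1: profile a fixed number of engine steps.
ctx_num_steps = ProfileContext(
    engine_args=EngineArgs(model="facebook/opt-125m"),  # placeholder model
    prompt_len=256,
    batch_size=32,
    num_steps=8,
)

# Mode 2: run until all requests finish, retiring 8 requests per decode step.
ctx_to_completion = ProfileContext(
    engine_args=EngineArgs(model="facebook/opt-125m"),  # placeholder model
    prompt_len=256,
    batch_size=32,
    complete_num_requests_per_step=8,
)
```
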
@@ -34,23 +40,155 @@ def get_dtype(dtype: str):
    return dtype


+OutputLen_NumReqs_Map: TypeAlias = Dict[int, int]
+def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
+        -> OutputLen_NumReqs_Map:
+    """
+    Given the number of requests, batch_size, and the number of requests
+    that each engine-step should process, step_requests, determine the
+    output lengths of the requests such that step_requests is honoured.
+
+    Example:
+    if batch size = 128 and step_requests = [128, 128, 96, 64, 32, 1]
+    then return,
+    {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning,
+    32 requests should have output length 2,
+    32 requests should have output length 3,
+    32 requests should have output length 4,
+    31 requests should have output length 5,
+    1 request should have output length 6.
+
+    Args:
+        batch_size (int): Number of requests submitted for profile. This is
+            args.batch_size.
+        step_requests (List[int]): step_requests[i] is the number of requests
+            that the ith engine step should process.
+
+    Returns:
+        OutputLen_NumReqs_Map: A dictionary with output-length as keys and the
+            number of requests required to have that output-length as values.
+    """
+    ol_nr: OutputLen_NumReqs_Map = {}
+
+    # Number of requests that have been assigned an output-length
+    num_reqs_assigned: int = 0
+    num_steps: int = len(step_requests)
+
+    # Sanity check: the first step (prefill-step) must process all requests.
+    assert step_requests[0] == batch_size
+
+    # Begin assignments from the last step.
+    output_length: int = num_steps
+    for num_requests_at_step in reversed(step_requests):
+        if num_reqs_assigned == batch_size:
+            break
+
+        assert num_reqs_assigned < batch_size
+
+        # Remove the number of requests that have been determined
+        # to participate in this step and beyond.
+        num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned
+        assert num_reqs_unassigned_at_step >= 0
+
+        if num_reqs_unassigned_at_step > 0:
+            ol_nr[output_length] = num_reqs_unassigned_at_step
+            num_reqs_assigned += num_reqs_unassigned_at_step
+
+        output_length -= 1
+
+    # Sanity checks.
+    assert sum(ol_nr.values()) == batch_size, \
+        ("Number of requests in output-length assignment does not match "
+         f"batch-size.\n batch size {batch_size} - "
+         f"step requests {step_requests} - assignments {ol_nr}")
+
+    # Check that the output-length is in [1, num-steps]. Output length must be
+    # at least 1 as all requests must participate in the prefill-step.
+    assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \
+        ("Output lengths of requests should be in range "
+         f"[1, num-engine-steps].\n batch size {batch_size} - "
+         f"step requests {step_requests} - assignments {ol_nr}")
+
+    return ol_nr
+
+
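
As a quick sanity check, the docstring's example can be reproduced directly (illustrative snippet):

```python
step_requests = [128, 128, 96, 64, 32, 1]
ol_nr = compute_request_output_lengths(batch_size=128,
                                       step_requests=step_requests)
assert ol_nr == {2: 32, 3: 32, 4: 32, 5: 31, 6: 1}
assert sum(ol_nr.values()) == 128
```
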
115
+ def determine_requests_per_step (context : ProfileContext ) -> List [int ]:
116
+ """
117
+ Determine number of requests each engine step should process.
118
+ If context.num_steps is set, then all engine steps process the
119
+ same number of requests and the output list is of length
120
+ context.num_steps.
121
+
122
+ If context.complete_num_requests_per_step is set, then each decode step
123
+ processes fewer and fewer requests until there are no requests to process.
124
+ In this case, the output list is as big as the number of steps
125
+ required to process all requests.
126
+
127
+ Args:
128
+ context: ProfileContext object.
129
+
130
+ Returns:
131
+ List[int]: Number of requests to process for all engine-steps.
132
+ output[i], contains the number of requests that the ith step
133
+ should process.
134
+ """
135
+ if context .num_steps :
136
+ # All requests must run until num_engine_steps. This implies
137
+ # that their output lengths must be equal to num_engine_steps.
138
+ return [context .batch_size ] * context .num_steps
139
+
140
+ assert context .complete_num_requests_per_step and \
141
+ context .complete_num_requests_per_step > 0 , \
142
+ (f"Expected a positive complete_num_requests_per_step argument."
143
+ f"Instead got { context .complete_num_requests_per_step } " )
144
+
145
+ # We start dropping after the first decode step.
146
+ step_requests = [
147
+ context .batch_size , # prefill
148
+ context .batch_size , # decode
149
+ ]
150
+
151
+ num_running_requests = context .batch_size
152
+ num_running_requests -= context .complete_num_requests_per_step
153
+ while num_running_requests > 0 :
154
+ step_requests .append (num_running_requests )
155
+ num_running_requests -= context .complete_num_requests_per_step
156
+
157
+ if step_requests [- 1 ] != 1 :
158
+ # have 1 request running at the last step. This is often
159
+ # useful
160
+ step_requests .append (1 )
161
+
162
+ return step_requests
163
+
164
+
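
For example, the run_to_completion schedule quoted in the CLI help further down falls out of this function as follows (again, the EngineArgs model is a placeholder):

```python
ctx = ProfileContext(
    engine_args=EngineArgs(model="facebook/opt-125m"),  # placeholder model
    prompt_len=256,
    batch_size=128,
    complete_num_requests_per_step=32,
)
# Prefill, first decode, then 32 requests retired per step, plus a final
# single-request step.
assert determine_requests_per_step(ctx) == [128, 128, 96, 64, 32, 1]
```
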
def run_profile(context: ProfileContext, csv_output: Optional[str],
                json_output: Optional[str]):
    print("Run profile with:")
    for key, value in asdict(context).items():
        print(f"  {key} = {value}")

+    requests_per_step: List[int] = determine_requests_per_step(context)
+
+    ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
+        context.batch_size, requests_per_step)
+
+    num_steps_to_profile: int = len(requests_per_step)
+    max_output_len: int = max(ol_nr.keys())
+    assert max_output_len >= 1
+
    # Create sampling params
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=args.output_len,
-                                     ignore_eos=True)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        # max_tokens is set on a per-request basis.
+        max_tokens=None,
+        ignore_eos=True)

    # Create LLM
    llm = LLM(**asdict(context.engine_args))
    batch_size = context.batch_size
    prompt_len = context.prompt_len
-    output_len = context.output_len

    scheduler_config = llm.llm_engine.scheduler_config
    max_model_len = llm.llm_engine.model_config.max_model_len
@@ -65,24 +203,34 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
              f"choose a smaller batch size or prompt length, or increase "
              f"--max-num-batched-tokens")
        sys.exit(-1)
-    if batch_size >= max_num_seqs:
+    if batch_size > max_num_seqs:
        print(
            f"ERROR: chosen batch_size ({batch_size}) is larger than "
            f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
            f"single profile step, please choose a smaller batch size")
        sys.exit(-1)
    print("llm.llm_engine.model_config.max_model_len: ",
          llm.llm_engine.model_config.max_model_len)
-    if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
-        print(
-            f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
-            f"{output_len} = {prompt_len + output_len}) is larger than the "
-            f"model's max_model_len ({max_model_len}), please choose a smaller "
-            f"prompt_len or output_len, or increase --max-model-len")
+    if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len:
+        print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
+              f"{max_output_len} = {prompt_len + max_output_len}) is larger "
+              f"than the model's max_model_len ({max_model_len}), please "
+              f"choose a smaller prompt_len or max_output_len, or increase "
+              f"--max-model-len")
        sys.exit(-1)

    def add_requests():
+
+        def get_output_len_generator() -> Generator[int, Any, Any]:
+            for output_len, num_reqs in ol_nr.items():
+                for _ in range(num_reqs):
+                    yield output_len
+
+        output_len_generator = get_output_len_generator()
        for i in range(batch_size):
+            sampling_params.max_tokens = next(output_len_generator)
+            assert isinstance(sampling_params.max_tokens, int)
+
            prompt_token_ids = torch.randint(
                llm.llm_engine.model_config.get_vocab_size(),
                size=(prompt_len, )).tolist()
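
The nested generator above simply flattens the `{output_len: num_requests}` map into one `max_tokens` value per submitted request; a rough equivalent of what it yields (illustrative only):

```python
ol_nr = {2: 32, 3: 32, 4: 32, 5: 31, 6: 1}
per_request_max_tokens = [
    output_len for output_len, num_reqs in ol_nr.items()
    for _ in range(num_reqs)
]
assert len(per_request_max_tokens) == 128
assert per_request_max_tokens[0] == 2 and per_request_max_tokens[-1] == 6
```
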
@@ -110,8 +258,11 @@ def abort_requests():
    llm.llm_engine.step()  # First step is prefill

    decode_profs = []
-    for x in range(args.output_len - 1):
-        with layerwise_profile() as decode_prof:
+    for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
+        num_running_seqs = llm.llm_engine.scheduler[
+            0].get_num_unfinished_seq_groups()
+        with layerwise_profile(
+                num_running_seqs=num_running_seqs) as decode_prof:
            llm.llm_engine.step()
        decode_profs.append(decode_prof)
@@ -154,7 +305,8 @@ def abort_requests():
    decode_results_list[0].print_summary_table()

    if csv_output:
-        csv_filename_base = csv_output.rstrip(".csv")
+        csv_filename_base = csv_output[:-4] \
+            if csv_output.endswith('.csv') else csv_output
        prefill_results.export_model_stats_table_csv(
            csv_filename_base + "_prefill_model_table.csv")
        prefill_results.export_summary_stats_table_csv(
@@ -187,10 +339,10 @@ def abort_requests():
        for idx, dr in enumerate(decode_results_list):
            json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()

-        for idx, dr in enumerate(decode_results_list[1:]):
-            json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
-
-        with open(json_output.rstrip(".json") + ".json", "w+") as f:
+        # Add .json to json_output filename if it doesn't exist already.
+        json_output_file = json_output if json_output.endswith(
+            '.json') else json_output + '.json'
+        with open(json_output_file, "w+") as f:
            json.dump(json_dict, f, indent=2)
    pass
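
Both filename fixes above replace `str.rstrip`, which strips a trailing character set rather than a suffix; a small illustration of the behaviour being corrected:

```python
# rstrip(".csv") removes any trailing '.', 'c', 's' or 'v' characters,
# so it can eat part of the stem, not just the extension.
assert "profile_vc.csv".rstrip(".csv") == "profile_"

# The suffix-aware replacement keeps the stem intact.
name = "profile_vc.csv"
assert (name[:-4] if name.endswith(".csv") else name) == "profile_vc"
```
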
@@ -214,7 +366,7 @@ def abort_requests():
    python examples/offline_profile.py \\
        --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
        --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
-        --enforce-eager
+        --enforce-eager run_num_steps -n 2
    ```

    then you can use various tools to analyze the json output
@@ -261,17 +413,41 @@ def abort_requests():
        default=BATCH_SIZE_DEFAULT,
        help=f"Number of requests to run as a single batch, "
        f"default={BATCH_SIZE_DEFAULT}")
-    parser.add_argument(
-        "--output-len",
+
+    subparsers = parser.add_subparsers(dest="cmd")
+
+    run_num_steps_parser = subparsers.add_parser(
+        "run_num_steps",
+        help="This variation profiles n engine.step() invocations.")
+    run_num_steps_parser.add_argument(
+        '-n',
+        '--num-steps',
        type=int,
-        default=OUTPUT_LEN_DEFAULT,
-        help="Number of llm steps to run (includes prefill and decode) "
-        "- default={OUTPUT_LEN_DEFAULT}")
+        help="Number of engine steps to profile.\n"
+        "Setting it to 1, profiles only the prefill step.\n"
+        "Setting it to 2, profiles the prefill and first decode step.\n"
+        "Setting it to 3, profiles the prefill, 1st and 2nd decode steps,\n"
+        "and so on ...")
+
+    run_to_completion_parser = subparsers.add_parser(
+        "run_to_completion",
+        help="This variation profiles all the engine.step() invocations "
+        "until the engine exhausts all submitted requests.")
+    run_to_completion_parser.add_argument(
+        '-n',
+        '--complete-num-requests-per-step',
+        type=int,
+        help=
+        "Complete complete_num_requests_per_step requests every decode step. "
+        "E.g., with batch_size 128 and complete_num_requests_per_step 32, "
+        "the profiler is run for 6 engine steps, with the steps processing "
+        "128, 128, 96, 64, 32, 1 requests respectively.\n"
+        "Note that we tack on a one-request step at the end as it is often "
+        "useful.")

    EngineArgs.add_cli_args(parser)

    args = parser.parse_args()
-
    context = ProfileContext(
        engine_args=EngineArgs.from_cli_args(args),
        **{