 ###############################################################################
 
 import gc
+import gzip
 import json
 import os
 import queue
+import math
 import threading
 import time
 from contextlib import contextmanager
@@ -49,6 +51,99 @@ def run(self):
                 outfile.write(content)
 
 
+class HabanaProfilerCounterHelper:
+
+    def __init__(self):
+        self.niter = 0
+        self.average_real_throughput = None
+        self.logged_once = False
+        self.prompt_real_seq_lens = []
+        self.decode_real_seq_lens = []
+
+    def capture_decode_seq_stats(self, real_seq_lens):
+        self.decode_real_seq_lens = real_seq_lens
+
+    def capture_prompt_seq_stats(self, real_seq_lens):
+        self.prompt_real_seq_lens.append(real_seq_lens)
+
+    def reset_prompt_seq_stats(self):
+        self.prompt_real_seq_lens = []
+
+    def get_counter_dict(self, cache_config, duration, seq_len,
+                         batch_size_padded, real_batch_size, prompt_batch_idx,
+                         is_prompt):
+        throughput = batch_size_padded / (duration / 1e6)
+        throughput_effective = real_batch_size / (duration / 1e6)
+        if is_prompt:
+            real_max_seq_len = max(self.prompt_real_seq_lens[prompt_batch_idx])
+            real_num_tokens = sum(self.prompt_real_seq_lens[prompt_batch_idx])
+        else:
+            real_max_seq_len = max(self.decode_real_seq_lens)
+            real_num_tokens = sum(self.decode_real_seq_lens)
+        padded_num_tokens = batch_size_padded * seq_len
+        batch_token_utilization = real_num_tokens / padded_num_tokens
+        if self.average_real_throughput is None:
+            self.average_real_throughput = throughput_effective
+        else:  # https://www.heikohoffmann.de/htmlthesis/node134.html
+            self.average_real_throughput = self.average_real_throughput + 1 / (
+                self.niter + 1) * (throughput_effective -
+                                   self.average_real_throughput)
+        phase = "prompt" if is_prompt else "decode"
+        counters = {
+            f'{phase}_bucket_batch_size': batch_size_padded,
+            f'{phase}_batch_size': real_batch_size,
+            f'{phase}_bucket_seq_len': seq_len,
+            f'{phase}_seq_len': real_max_seq_len,
+            f'{phase}_bucket_gen_throughput': throughput,
+            f'{phase}_real_gen_throughput': throughput_effective,
+            f'{phase}_batch_token_utilization': batch_token_utilization,
+            'average_real_throughput': self.average_real_throughput,
+            'engine_iteration': self.niter,
+        }
+        self.niter += 1
+        if is_prompt:
+            prompt_bucket_in_throughput = (seq_len * batch_size_padded) / (
+                duration / 1e6)
+            prompt_real_in_throughput = sum(
+                self.prompt_real_seq_lens[prompt_batch_idx]) / (duration / 1e6)
+            counters[
+                f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput
+            counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput
+
+        # KV cache might not be created yet (e.g. for profiling run)
+        if cache_config.num_gpu_blocks is not None and \
+                cache_config.num_gpu_blocks != 0:
+            seq_lens = self.prompt_real_seq_lens[prompt_batch_idx] \
+                if is_prompt \
+                else self.decode_real_seq_lens
+            cache_num_blocks_used = [
+                math.ceil(sl / cache_config.block_size) for sl in seq_lens
+            ]
+            cache_total_num_blocks_used = sum(cache_num_blocks_used)
+            num_cache_blocks = cache_config.num_gpu_blocks
+            cache_total_num_free_blocks = \
+                num_cache_blocks - cache_total_num_blocks_used
+            cache_computed_utilization = \
+                cache_total_num_blocks_used / num_cache_blocks
+            max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size)
+            batch_block_utilization = cache_total_num_blocks_used / (
+                batch_size_padded * max_blocks_per_seq)
+            counters['cache_num_blocks_used'] = cache_total_num_blocks_used
+            counters['cache_num_free_blocks'] = cache_total_num_free_blocks
+            counters['cache_computed_utilization'] = cache_computed_utilization
+            counters[
+                f'{phase}_batch_block_utilization'] = batch_block_utilization
+        if not self.logged_once:
+            counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks
+            counters[
+                'const_gpu_memory_utilization'] = \
+                    cache_config.gpu_memory_utilization
+            counters['const_block_size'] = cache_config.block_size
+            self.logged_once = True
+
+        return counters
+
+
 class HabanaHighLevelProfiler:
     profiling_trace_events: queue.Queue = queue.Queue()
     event_tid = {'counter': 1, 'external': 2, 'internal': 3}
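For reference, the `average_real_throughput` update above is the storage-free incremental mean, avg_{k+1} = avg_k + (x_{k+1} - avg_k) / (k + 1), per the linked Hoffmann derivation. Below is a minimal usage sketch of the new helper; the call pattern, the `SimpleNamespace` stand-in for vLLM's `CacheConfig`, and all concrete numbers are assumptions rather than real call sites (those live in the HPU model runner), with `duration` taken to be in microseconds to match the `duration / 1e6` seconds conversion:

    from types import SimpleNamespace

    # Stand-in for vLLM's CacheConfig; only the fields read above are set.
    cache_config = SimpleNamespace(num_gpu_blocks=1024,
                                   block_size=128,
                                   gpu_memory_utilization=0.9)

    helper = HabanaProfilerCounterHelper()
    helper.capture_prompt_seq_stats([37, 512, 129])  # real prompt lengths
    counters = helper.get_counter_dict(
        cache_config=cache_config,
        duration=15_000,      # step duration in microseconds (assumed unit)
        seq_len=512,          # padded (bucketed) sequence length
        batch_size_padded=4,  # bucket batch size
        real_batch_size=3,    # sequences actually present
        prompt_batch_idx=0,
        is_prompt=True)
    # Padding waste shows up directly: 678 real tokens / 2048 padded ~= 0.33.
    print(counters['prompt_batch_token_utilization'])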
@@ -121,6 +216,55 @@ def end(self):
         event = self.event_cache.pop()
         event['dur'] = ts - event['ts']
         self._dump_with_sep(event)
+
+
+    def full_trace_handler(self, dir_name, use_gzip=False):
+
+        def handler_fn(prof) -> None:
+            if not os.path.isdir(dir_name):
+                try:
+                    os.makedirs(dir_name, exist_ok=True)
+                except Exception as e:
+                    raise RuntimeError("Can't create directory: " +
+                                       dir_name) from e
+            file_name = f"vllm.{time.time_ns()}.pt.trace.json"
+            file_path = os.path.join(dir_name, file_name)
+            prof.export_chrome_trace(file_path)
+            with open(file_path) as f:
+                pytorch_trace = json.load(f)
+            os.remove(file_path)
+            base = pytorch_trace['baseTimeNanoseconds'] / 1000
+            events = self.profiling_trace_events
+            while True:
+                try:
+                    event_str = events.get_nowait()
+                    event = json.loads(event_str[:-1])
+                    event['ts'] = event['ts'] - base
+                    pytorch_trace['traceEvents'].append(event)
+                except queue.Empty:
+                    break
+
+            pytorch_trace['traceEvents'].append({
+                "args": {
+                    "name": "vLLM"
+                },
+                "name": "process_name",
+                "ph": "M",
+                "pid": 1,
+                "tid": 0,
+                "ts": 0.0
+            })
+            if use_gzip:
+                file_path = file_path + ".gz"
+                with gzip.open(file_path, 'wt', encoding="ascii") as zipfile:
+                    json.dump(pytorch_trace, zipfile)
+            else:
+                with open(file_path, "w") as outfile:
+                    outfile.write(json.dumps(pytorch_trace))
+            logger().info("Saved full profiling to %s", file_path)
+
+        return handler_fn
+
 
     @contextmanager
     def record_event(self, type, name, args=None):
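`full_trace_handler` is a factory for a torch.profiler `on_trace_ready` callback: the returned `handler_fn` exports the PyTorch Chrome trace, rebases the queued high-level vLLM events onto the trace's `baseTimeNanoseconds` clock, injects a "vLLM" process-name metadata event, and writes a single merged JSON (gzipped when `use_gzip` is set). A hedged wiring sketch follows; the default constructor and the stand-in workload are assumptions, since the actual integration point is outside this diff:

    import torch

    profiler = HabanaHighLevelProfiler()
    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            on_trace_ready=profiler.full_trace_handler('./traces',
                                                       use_gzip=True)):
        torch.randn(1024, 1024) @ torch.randn(1024, 1024)  # stand-in workload

    # On exit the callback writes ./traces/vllm.<ns>.pt.trace.json.gz, which
    # Perfetto and chrome://tracing can open directly.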
@@ -224,4 +368,4 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         self.consumed_device_memory = \
             self.final_device_memory - self.initial_device_memory
         self.consumed_host_memory = \
-            self.final_host_memory - self.initial_host_memory
\ No newline at end of file
+            self.final_host_memory - self.initial_host_memory
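This final hunk appears to be a trailing-newline fix, since the removed and re-added lines are identical. The surrounding `__exit__` shows the memory-delta pattern: snapshot on entry, snapshot on exit, store the difference. A self-contained re-creation of that pattern for host memory only, where the class name, the `host_memory_used` helper, and the psutil RSS metric are all illustrative assumptions rather than code from this diff:

    import psutil


    def host_memory_used() -> int:
        # Resident-set size of the current process, in bytes.
        return psutil.Process().memory_info().rss


    class HostMemoryProfiler:

        def __enter__(self):
            self.initial_host_memory = host_memory_used()
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.final_host_memory = host_memory_used()
            self.consumed_host_memory = \
                self.final_host_memory - self.initial_host_memory


    with HostMemoryProfiler() as m:
        buf = bytearray(64 * 1024 * 1024)  # ~64 MiB host allocation
    print(f"consumed ~{m.consumed_host_memory / 2**20:.1f} MiB")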