
Commit b50b3b0

Enable high-level profiler (#49)

Ported from HabanaAI/vllm-fork#1501.

Signed-off-by: Konrad Zawora <[email protected]>

3 files changed: +322 −113 lines

vllm_gaudi/extension/profiler.py

Lines changed: 145 additions & 1 deletion
@@ -3,9 +3,11 @@
 ###############################################################################
 
 import gc
+import gzip
 import json
 import os
 import queue
+import math
 import threading
 import time
 from contextlib import contextmanager
@@ -49,6 +51,99 @@ def run(self):
                 outfile.write(content)
 
 
+class HabanaProfilerCounterHelper:
+
+    def __init__(self):
+        self.niter = 0
+        self.average_real_throughput = None
+        self.logged_once = False
+        self.prompt_real_seq_lens = []
+        self.decode_real_seq_lens = []
+
+    def capture_decode_seq_stats(self, real_seq_lens):
+        self.decode_real_seq_lens = real_seq_lens
+
+    def capture_prompt_seq_stats(self, real_seq_lens):
+        self.prompt_real_seq_lens.append(real_seq_lens)
+
+    def reset_prompt_seq_stats(self):
+        self.prompt_real_seq_lens = []
+
+    def get_counter_dict(self, cache_config, duration, seq_len,
+                         batch_size_padded, real_batch_size, prompt_batch_idx,
+                         is_prompt):
+        throughput = batch_size_padded / (duration / 1e6)
+        throughput_effective = real_batch_size / (duration / 1e6)
+        if is_prompt:
+            real_max_seq_len = max(self.prompt_real_seq_lens[prompt_batch_idx])
+            real_num_tokens = sum(self.prompt_real_seq_lens[prompt_batch_idx])
+        else:
+            real_max_seq_len = max(self.decode_real_seq_lens)
+            real_num_tokens = sum(self.decode_real_seq_lens)
+        padded_num_tokens = batch_size_padded * seq_len
+        batch_token_utilization = real_num_tokens / padded_num_tokens
+        if self.average_real_throughput is None:
+            self.average_real_throughput = throughput_effective
+        else:  # https://www.heikohoffmann.de/htmlthesis/node134.html
+            self.average_real_throughput = self.average_real_throughput + 1 / (
+                self.niter + 1) * (throughput_effective -
+                                   self.average_real_throughput)
+        phase = "prompt" if is_prompt else "decode"
+        counters = {
+            f'{phase}_bucket_batch_size': batch_size_padded,
+            f'{phase}_batch_size': real_batch_size,
+            f'{phase}_bucket_seq_len': seq_len,
+            f'{phase}_seq_len': real_max_seq_len,
+            f'{phase}_bucket_gen_throughput': throughput,
+            f'{phase}_real_gen_throughput': throughput_effective,
+            f'{phase}_batch_token_utilization': batch_token_utilization,
+            'average_real_throughput': self.average_real_throughput,
+            'engine_iteration': self.niter,
+        }
+        self.niter += 1
+        if is_prompt:
+            prompt_bucket_in_throughput = (seq_len * batch_size_padded) / (
+                duration / 1e6)
+            prompt_real_in_throughput = sum(
+                self.prompt_real_seq_lens[prompt_batch_idx]) / (duration / 1e6)
+            counters[
+                f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput
+            counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput
+
+        # KV cache might not be created yet (e.g. for profiling run)
+        if cache_config.num_gpu_blocks is not None and \
+                cache_config.num_gpu_blocks != 0:
+            seq_lens = self.prompt_real_seq_lens[prompt_batch_idx] \
+                if is_prompt \
+                else self.decode_real_seq_lens
+            cache_num_blocks_used = [
+                math.ceil(sl / cache_config.block_size) for sl in seq_lens
+            ]
+            cache_total_num_blocks_used = sum(cache_num_blocks_used)
+            num_cache_blocks = cache_config.num_gpu_blocks
+            cache_total_num_free_blocks = \
+                num_cache_blocks - cache_total_num_blocks_used
+            cache_computed_utilization = \
+                cache_total_num_blocks_used / num_cache_blocks
+            max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size)
+            batch_block_utilization = cache_total_num_blocks_used / (
+                batch_size_padded * max_blocks_per_seq)
+            counters['cache_num_blocks_used'] = cache_total_num_blocks_used
+            counters['cache_num_free_blocks'] = cache_total_num_free_blocks
+            counters['cache_computed_utilization'] = cache_computed_utilization
+            counters[
+                f'{phase}_batch_block_utilization'] = batch_block_utilization
+        if not self.logged_once:
+            counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks
+            counters[
+                'const_gpu_memory_utilization'] = \
+                cache_config.gpu_memory_utilization
+            counters['const_block_size'] = cache_config.block_size
+            self.logged_once = True
+
+        return counters
+
+
 class HabanaHighLevelProfiler:
     profiling_trace_events: queue.Queue = queue.Queue()
     event_tid = {'counter': 1, 'external': 2, 'internal': 3}
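Note: the `average_real_throughput` update in `get_counter_dict` is the standard incremental-mean recurrence from the linked Hoffmann reference, mean_new = mean_old + (x − mean_old) / (n + 1), which tracks the average effective throughput without storing per-iteration samples. A standalone sanity check (illustrative only, not part of this commit) that the recurrence reproduces the plain arithmetic mean:

    # Illustrative check, not part of the commit: the running-mean recurrence
    # used for `average_real_throughput` equals the arithmetic mean.
    samples = [120.0, 95.5, 130.2, 110.8]  # hypothetical per-iteration throughputs

    avg = None
    for niter, x in enumerate(samples):
        if avg is None:
            avg = x
        else:
            # mean_{n+1} = mean_n + (x - mean_n) / (n + 1), with n = niter
            avg = avg + 1 / (niter + 1) * (x - avg)

    assert abs(avg - sum(samples) / len(samples)) < 1e-9
    print(avg)  # 114.125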
@@ -121,6 +216,55 @@ def end(self):
         event = self.event_cache.pop()
         event['dur'] = ts - event['ts']
         self._dump_with_sep(event)
+
+
+    def full_trace_handler(self, dir_name, use_gzip=False):
+
+        def handler_fn(prof) -> None:
+            if not os.path.isdir(dir_name):
+                try:
+                    os.makedirs(dir_name, exist_ok=True)
+                except Exception as e:
+                    raise RuntimeError("Can't create directory: " +
+                                       dir_name) from e
+            file_name = f"vllm.{time.time_ns()}.pt.trace.json"
+            file_path = os.path.join(dir_name, file_name)
+            prof.export_chrome_trace(file_path)
+            with open(file_path) as f:
+                pytorch_trace = json.load(f)
+            os.remove(file_path)
+            base = pytorch_trace['baseTimeNanoseconds'] / 1000
+            events = self.profiling_trace_events
+            while True:
+                try:
+                    event_str = events.get_nowait()
+                    event = json.loads(event_str[:-1])
+                    event['ts'] = event['ts'] - base
+                    pytorch_trace['traceEvents'].append(event)
+                except queue.Empty:
+                    break
+
+            pytorch_trace['traceEvents'].append({
+                "args": {
+                    "name": "vLLM"
+                },
+                "name": "process_name",
+                "ph": "M",
+                "pid": 1,
+                "tid": 0,
+                "ts": 0.0
+            })
+            if use_gzip:
+                file_path = file_path + ".gz"
+                with gzip.open(file_path, 'wt', encoding="ascii") as zipfile:
+                    json.dump(pytorch_trace, zipfile)
+            else:
+                with open(file_path, "w") as outfile:
+                    outfile.write(json.dumps(pytorch_trace))
+            logger().info("Saved full profiling to %s", file_path)
+
+        return handler_fn
+
 
     @contextmanager
     def record_event(self, type, name, args=None):
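Note: `handler_fn` receives a profiler object and calls `prof.export_chrome_trace(...)`, which matches the shape of a `torch.profiler` `on_trace_ready` callback: the PyTorch trace is exported, reloaded, shifted onto the high-level profiler's timebase via `baseTimeNanoseconds`, merged with the queued vLLM trace events, and written back out (optionally gzipped). A minimal sketch of wiring it up under that assumption; the activities, workload, and constructor call below are illustrative and not taken from this diff:

    # Sketch only: wiring full_trace_handler as a torch.profiler callback.
    import torch

    from vllm_gaudi.extension.profiler import HabanaHighLevelProfiler

    # Assumed default-constructible; __init__ is not shown in this diff.
    profiler = HabanaHighLevelProfiler()

    with torch.profiler.profile(
            # CPU activity only for illustration; HPU setups differ.
            activities=[torch.profiler.ProfilerActivity.CPU],
            on_trace_ready=profiler.full_trace_handler("profiler_out",
                                                       use_gzip=True),
    ):
        run_one_engine_step()  # hypothetical workload to profile

When the context exits, the handler merges both event streams into a single Chrome trace viewable in Perfetto or chrome://tracing.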
@@ -224,4 +368,4 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         self.consumed_device_memory = \
             self.final_device_memory - self.initial_device_memory
         self.consumed_host_memory = \
-            self.final_host_memory - self.initial_host_memory
+            self.final_host_memory - self.initial_host_memory
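Note: this last hunk only touches the final line of the memory profiler's `__exit__`, which records device and host memory consumed as final minus initial usage. Typical usage of such a context manager might look like the sketch below; the class name `HabanaMemoryProfiler` is taken from the upstream HabanaAI/vllm-fork profiler and is assumed, not confirmed by this diff:

    # Name assumed from upstream vllm-fork; this diff shows only __exit__.
    from vllm_gaudi.extension.profiler import HabanaMemoryProfiler

    with HabanaMemoryProfiler() as m:
        load_model_weights()  # hypothetical memory-heavy workload

    # Populated by __exit__ as final minus initial usage:
    print(m.consumed_device_memory, m.consumed_host_memory)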
