Commit e9bc231

fix
Signed-off-by: Wuxun Zhang <[email protected]>
1 parent a00cfce commit e9bc231

File tree

2 files changed: +115 -7 lines changed


vllm_gaudi/distributed/device_communicators/hpu_communicator.py

Lines changed: 0 additions & 3 deletions
@@ -65,9 +65,6 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
     def dispatch(
             self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        all-gather based dispatch for HPUCommunicator.
-        """
         cu_tokens_across_dp_cpu = get_forward_context(
         ).dp_metadata.cu_tokens_across_dp_cpu
         hidden_states_across_dp = naive_multicast(hidden_states,
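
The deleted docstring described this as an all-gather based dispatch for HPUCommunicator: cu_tokens_across_dp_cpu holds cumulative token counts per data-parallel rank, and naive_multicast gathers every rank's tokens into one shared buffer. A minimal sketch of that pattern, assuming plain torch.distributed in place of vLLM's group helpers (names and signature here are illustrative, not the committed code):

import torch
import torch.distributed as dist

def naive_multicast_sketch(x: torch.Tensor,
                           cu_tokens_across_dp_cpu: torch.Tensor,
                           dp_rank: int,
                           dp_group=None) -> torch.Tensor:
    # One buffer row per token across all DP ranks; feature dim preserved.
    buffer = x.new_zeros((int(cu_tokens_across_dp_cpu[-1]), x.size(-1)))
    # This rank owns the slice [start, end) of the global token range.
    start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
    end = int(cu_tokens_across_dp_cpu[dp_rank])
    buffer[start:end].copy_(x)
    # Every other rank contributed zeros here, so a sum all-reduce
    # reconstructs the concatenation of all ranks' tokens.
    dist.all_reduce(buffer, group=dp_group)
    return buffer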

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 115 additions & 4 deletions
@@ -1227,7 +1227,6 @@ def _form_prefill_batch(self, contents):
 
         query_lens = _async_h2d_tensor(query_lens, torch.int32)
         token_ids = _async_h2d_tensor(token_ids, torch.int32)
-
         token_positions = _async_h2d_tensor(token_positions, torch.int32)
         token_slots = _async_h2d_tensor(token_slots, torch.int64)
         logits_indices = _async_h2d_tensor(logits_indices, torch.int32)
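
These _async_h2d_tensor calls stage each batch tensor on the device without blocking the host. The helper's definition is outside this hunk; a plausible minimal form is sketched below (an assumption for illustration, not the committed implementation):

import torch

def _async_h2d_tensor_sketch(data, dtype, device='hpu') -> torch.Tensor:
    # Materialize on host, then issue a non-blocking host-to-device copy
    # so batch assembly can overlap with the DMA transfer.
    return torch.tensor(data, dtype=dtype, device='cpu').to(
        device, non_blocking=True)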
@@ -1294,7 +1293,6 @@ def _prepare_decode_inputs(self, num_decodes,
                                            num_decodes, sum(num_blocks))[0]
 
         # # dp aware padding
-        assert padded_batch_size is not None
         padded_batch_size += self.get_dp_padding(padded_batch_size)
 
         block_tables_list = []
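
get_dp_padding returns the extra rows needed so every data-parallel rank runs the same padded batch size. A hedged sketch of the idea, using raw torch.distributed in place of vLLM's DP group plumbing (all names here are illustrative):

import torch
import torch.distributed as dist

def get_dp_padding_sketch(batch_size: int, dp_group=None) -> int:
    # Take the max batch size across data-parallel ranks so collective
    # ops later in the step see identically shaped tensors on every rank.
    size = torch.tensor([batch_size], dtype=torch.int32)
    dist.all_reduce(size, op=dist.ReduceOp.MAX, group=dp_group)
    return int(size.item()) - batch_size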
@@ -1754,8 +1752,6 @@ def execute_model(
 
         ######################### PREFILLS #########################
         if num_prefills > 0:
-            # Wuxun: merged prefill forward if enabled
-            # 2D bucketing or merged prefill bucketing
             htorch.core.mark_step()
             for idx, (req_id, prompt_len, token_ids, position_ids,
                       attn_metadata, logits_indices,
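
The surviving control flow follows the usual HPU lazy-mode pattern: htorch.core.mark_step() flushes the ops queued so far as one graph, so bracketing each prefill keeps per-request graphs small and reusable. Schematically (the loop body is a placeholder, not the committed code):

import habana_frameworks.torch as htorch

def run_prefills(prefill_batches, forward_one):
    htorch.core.mark_step()      # flush anything queued before the prefills
    for batch in prefill_batches:
        forward_one(batch)       # enqueue this request's forward pass
        htorch.core.mark_step()  # launch it as its own lazy-mode graph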
@@ -2098,6 +2094,121 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
                f'used_mem:{format_bytes(total_mem)}')
         logger.info(msg)
 
+    def warmup_scenario(self,
+                        batch_size,
+                        seq_or_block,
+                        num_blocks,
+                        is_prompt,
+                        kv_caches,
+                        num_iters=3,
+                        is_pt_profiler_run=True,
+                        align_worker=False,
+                        is_dummy_run=False) -> None:
+        """Dummy warmup run for memory usage and graph compilation."""
+
+        query_seq_len = seq_or_block if is_prompt else 1
+        input_ids = torch.zeros((batch_size, query_seq_len),
+                                dtype=torch.int32,
+                                device='cpu')
+        position_ids = torch.zeros((batch_size, query_seq_len),
+                                   dtype=torch.int32,
+                                   device='cpu')
+        slot_mapping = torch.zeros((batch_size, query_seq_len),
+                                   dtype=torch.int64,
+                                   device='cpu')
+
+        input_ids_device = _async_h2d_tensor_copy(input_ids, self.device)
+        position_ids_device = _async_h2d_tensor_copy(position_ids, self.device)
+        slot_mapping_device = _async_h2d_tensor_copy(slot_mapping, self.device)
+
+        use_graphs = is_dummy_run or self._use_graphs()
+        phase = "prompt" if is_prompt else "decode"
+        scenario_name = ("warmup_"
+                         f"{phase}_"
+                         f"bs{batch_size}_"
+                         f"seq{query_seq_len}_"
+                         f"ctx{num_blocks}_"
+                         f"graphs{'T' if use_graphs else 'F'}")
+        self.profiler.start('internal', scenario_name)
+
+        times = num_iters if use_graphs or is_pt_profiler_run else 1
+        for time_index in range(times):
+            if is_prompt:
+                seq_lens = torch.zeros((batch_size),
+                                       dtype=torch.int32,
+                                       device='cpu')
+                seq_lens.fill_(seq_or_block)
+                seq_lens_device = _async_h2d_tensor_copy(seq_lens, self.device)
+                block_list_device = None
+                if num_blocks:
+                    prefix_block_tables = torch.ones(
+                        (batch_size, num_blocks),
+                        dtype=torch.int32,
+                        device='cpu') * self._PAD_BLOCK_ID
+                    block_list_device = _async_h2d_tensor_copy(
+                        prefix_block_tables.flatten(), self.device)
+                attn_metadata = \
+                    HPUAttentionMetadataV1.make_prefill_metadata(
+                        attn_bias=None,
+                        seq_lens_tensor=seq_lens_device,
+                        context_lens_tensor=seq_lens_device,
+                        slot_mapping=slot_mapping_device,
+                        block_list=block_list_device,
+                        block_size=self.block_size)
+            else:
+                block_tables = [
+                    x.tolist()
+                    for x in np.array_split(np.arange(num_blocks), batch_size)
+                ]
+                block_list, block_groups, block_usage = \
+                    self.get_habana_paged_attn_buffers(
+                        slot_mapping=slot_mapping,
+                        block_tables=block_tables,
+                        batch_size=batch_size)
+                block_list_device = _async_h2d_tensor_copy(
+                    block_list, self.device)
+                block_usage_device = _async_h2d_tensor_copy(
+                    block_usage, self.device)
+                block_groups_device = _async_h2d_tensor_copy(
+                    block_groups, self.device)
+                attn_metadata = HPUAttentionMetadataV1.make_decode_metadata(
+                    block_list=block_list_device,
+                    block_usage=block_usage_device,
+                    block_groups=block_groups_device,
+                    num_decode_tokens=batch_size,
+                    input_positions=None,
+                    slot_mapping=slot_mapping_device,
+                    block_size=self.block_size)
+
+            logits_indices = torch.arange(0, batch_size, device='cpu')
+            logits_indices_device = _async_h2d_tensor_copy(logits_indices,
+                                                           self.device)
+            # Dummy run.
+            htorch.core.mark_step()
+            _ = self._execute_model_generic(input_ids_device,
+                                            position_ids_device,
+                                            attn_metadata,
+                                            logits_indices_device,
+                                            kv_caches, True)
+            # TODO: do sampling on logits, warmup sampler and prefill joiner
+            htorch.core.mark_step()
+        self.profiler.end()
+        return None
+
     def log_warmup(self, phase, i, max_i, batch_size, seq_len, num_blocks):
         free_mem = format_bytes(
             HabanaMemoryProfiler.current_free_device_memory())
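
For reference, a hedged usage sketch of the warmup_scenario method added above (bucket values and the runner/kv_caches handles are placeholders; the committed warmup loop drives this itself):

# Prompt buckets: batch_size x seq_len, no prior context blocks.
for batch_size, seq_len in [(1, 128), (4, 1024)]:
    runner.warmup_scenario(batch_size, seq_len, num_blocks=0,
                           is_prompt=True, kv_caches=kv_caches)

# Decode buckets: one query token per sequence, context held in KV blocks.
for batch_size, num_blocks in [(1, 128), (32, 2048)]:
    runner.warmup_scenario(batch_size, seq_or_block=num_blocks,
                           num_blocks=num_blocks,
                           is_prompt=False, kv_caches=kv_caches)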
