
Commit b728eae

Disable HPU Graphs
1 parent dbfa656 commit b728eae


2 files changed: +4 -26 lines changed


vllm_gaudi/lora/punica_wrapper/punica_hpu.py

Lines changed: 1 addition & 24 deletions
@@ -23,33 +23,10 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
                  device: Union[torch.device, str], **kwargs):
         # Increasing max_num_batched_tokens by 3x to handle increase in
         # tensor size due to padding.
+        # TODO: Need to check if this override is still required
         PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                    max_batches, device)
 
-    def _update_base_metadata(
-        self,
-        mapping: "LoRAMapping",
-        lora_index_to_id: list[Optional[int]],
-        max_loras: int,
-        vocab_size: int,
-        extra_vocab_size: int,
-    ):
-        (
-            base_indices,
-            sampler_indices,
-            sampler_indices_padded,
-            embeddings_indices,
-            indices_len,
-        ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
-                            extra_vocab_size, self.device)
-        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
-        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
-        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
-            sampler_indices_padded)
-        self._embeddings_indices[:embeddings_indices.
-                                 shape[0], :embeddings_indices.shape[1]].copy_(
-                                     embeddings_indices)
-        self.indices_len[:] = indices_len
 
     def add_lora_embedding(self,
                            y: torch.Tensor,
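With the override deleted, the HPU punica wrapper now inherits _update_base_metadata from PunicaWrapperBase. A minimal sketch of that fallback is below; it is not the repository's code, PunicaWrapperHPU is an assumed name for the subclass defined in punica_hpu.py, and the return value is a placeholder.

class PunicaWrapperBase:
    def _update_base_metadata(self, mapping, *rest):
        # Shared base-class implementation that the HPU wrapper now relies on.
        return "base-class path"

class PunicaWrapperHPU(PunicaWrapperBase):  # hypothetical name for the HPU subclass
    # The HPU-specific override shown in the deleted lines above is gone, so
    # lookups fall through to PunicaWrapperBase._update_base_metadata.
    pass

print(PunicaWrapperHPU()._update_base_metadata(None))  # prints "base-class path"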

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 3 additions & 2 deletions
@@ -442,9 +442,10 @@ def generate_proposals(self, *args, **kwargs):
 
 
 def _maybe_wrap_in_hpu_graph(*args, **kwargs):
-    return htorch.hpu.wrap_in_hpu_graph(
+    '''return htorch.hpu.wrap_in_hpu_graph(
         HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
-    ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs)
+    ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs)'''
+    return HpuModelAdapter(*args, **kwargs)
 
 
 def subtuple(obj: object,
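The net effect of this hunk is that the model adapter is no longer captured in an HPU Graph, even when the Habana bridge runs in lazy mode. A minimal before/after sketch is below; it is not the repository's code and assumes habana_frameworks.torch is importable as htorch, with the calls taken from the diff above.

def wrap_before(adapter_factory, *args, **kwargs):
    # Old behavior: wrap the adapter in an HPU Graph when lazy mode is active,
    # otherwise return the plain adapter.
    import habana_frameworks.torch as htorch  # assumes the Habana bridge is installed
    if htorch.utils.internal.is_lazy():
        return htorch.hpu.wrap_in_hpu_graph(adapter_factory(*args, **kwargs),
                                            disable_tensor_cache=True)
    return adapter_factory(*args, **kwargs)

def wrap_after(adapter_factory, *args, **kwargs):
    # New behavior after this commit: HPU Graphs are disabled and the plain
    # adapter is returned unconditionally.
    return adapter_factory(*args, **kwargs)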
