
Commit 7465346

Clean-up. Minor fixes

Signed-off-by: Vivek <[email protected]>

1 parent: b728eae

2 files changed: +83 -69 lines

vllm_gaudi/lora/punica_wrapper/punica_hpu.py

Lines changed: 2 additions & 9 deletions
@@ -1,19 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import TYPE_CHECKING, Optional, Union, final
+from typing import Optional, Union, final
 
 import torch
 from vllm_gaudi.extension.ops import (dispatch_bgmv_embedding,
-                                        dispatch_bgmv_linear)
+                                      dispatch_bgmv_linear)
 
 from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
-from vllm.lora.punica_wrapper.utils import convert_mapping
-
-if TYPE_CHECKING:
-    # avoid circuit import
-    from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext
 
 
 @final
@@ -27,7 +21,6 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
         PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                    max_batches, device)
 
-
     def add_lora_embedding(self,
                            y: torch.Tensor,
                            x: torch.Tensor,
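
The dropped `if TYPE_CHECKING:` block is the usual typing-only import guard for breaking circular imports; it goes away here, presumably because nothing in the module references `LoRAMapping` or `LongContextLoRAContext` any longer, and `convert_mapping` is likewise unused. For reference, a minimal, generic illustration of that pattern (not code from this repository):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only while a type checker runs; skipped at runtime,
    # which is what breaks the circular-import cycle.
    from vllm.lora.layers import LoRAMapping


def describe(mapping: "LoRAMapping") -> str:
    # The string annotation is resolved by type checkers only,
    # so the runtime module never needs the real import.
    return type(mapping).__name__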

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 81 additions & 60 deletions
@@ -7,7 +7,7 @@
 import os
 import time
 from dataclasses import dataclass, field, fields
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypeAlias, Union, Set, List
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeAlias, Union
 
 import habana_frameworks.torch as htorch
 import habana_frameworks.torch.internal.bridge_config as bc
@@ -407,6 +407,10 @@ def forward(self, *args, **kwargs):
         # kwargs['attn_metadata'].slot_mapping, compared to untrimmed metadata
         kwargs = kwargs.copy()
         # selected_token_indices = kwargs.pop('selected_token_indices')
+        if 'lora_mask' in kwargs:
+            lora_mask = kwargs['lora_mask']
+            LoraMask.setLoraMask(lora_mask)
+            kwargs.pop('lora_mask')
         if 'warmup_mode' in kwargs:
             kwargs.pop('warmup_mode')
         input_ids = kwargs['input_ids']
@@ -442,9 +446,11 @@ def generate_proposals(self, *args, **kwargs):
 
 
 def _maybe_wrap_in_hpu_graph(*args, **kwargs):
-    '''return htorch.hpu.wrap_in_hpu_graph(
+    '''
+    return htorch.hpu.wrap_in_hpu_graph(
         HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
-    ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs)'''
+    ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs)
+    '''
     return HpuModelAdapter(*args, **kwargs)
 
 
@@ -649,7 +655,7 @@ def __init__(
         # TODO(madamczyk-intel): debug why increasing it lowers acc
         self.logits_rounding = 1
 
-    def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
+    def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: list[int],
                          is_prompt: bool):
         '''
         This is a helper function to create the mask for lora computations.
@@ -747,7 +753,7 @@ def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
         )
         return self.lora_manager.create_lora_manager(model)
 
-    def set_active_loras(self, lora_requests: Set[LoRARequest],
+    def set_active_loras(self, lora_requests: set[LoRARequest],
                          lora_mapping: LoRAMapping) -> None:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
@@ -1160,7 +1166,6 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes,
         logits_positions = list(
             range(query_len - num_output_logits, query_len))
 
-
         new_batch_contents = BatchContents(
             req_ids=[req_id],
             token_ids=[token_ids],
@@ -1491,15 +1496,17 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata,
             "Configuration: (%s, %s, %s, %s) was not warmed-up!", phase,
             batch_size, seq_len, num_blocks)
 
-    def _execute_model_generic(self,
-                               token_ids,
-                               position_ids,
-                               attn_metadata,
-                               logits_indices,
-                               kv_caches,
-                               lora_logits_mask,
-                               warmup_mode=False,
-                               ):
+    def _execute_model_generic(
+        self,
+        token_ids,
+        position_ids,
+        attn_metadata,
+        logits_indices,
+        kv_caches,
+        lora_logits_mask,
+        lora_mask,
+        warmup_mode=False,
+    ):
 
         # FORWARD.
         batch_size = token_ids.size(0)
@@ -1519,7 +1526,8 @@ def _execute_model_generic(self,
         hidden_states = self.model.forward(input_ids=token_ids,
                                            positions=position_ids,
                                            attn_metadata=trimmed_attn_metadata,
-                                           kv_caches=kv_caches)
+                                           kv_caches=kv_caches,
+                                           lora_mask=lora_mask)
         # NOTE(kzawora): returning hidden_states is required in prompt logprobs
         # scenarios, as they will do logit processing on their own
         non_flattened_hidden_states = hidden_states
@@ -1695,29 +1703,35 @@ def execute_model(
             lora_ids = []
             lora_index_mapping = []
             lora_prompt_mapping = []
+            lora_mask = None
+            lora_logits_mask = None
             ###### Code for LoRA. Move to a function later #######
-            # We only need lora_mask and lora_logits_mask here, everything else
-            # could have been done in _prepare_inputs
-            for i, r_id in enumerate(req_id):
-                lora_request = self.requests[r_id].lora_request
-                lora_id = self.requests[r_id].lora_request.lora_int_id if lora_request else 0
-                if lora_id > 0:
-                    lora_requests.append(lora_request)
-                lora_index_mapping += [lora_id] * (token_ids.shape[1])
-                lora_prompt_mapping += [lora_id] #TODO: This may need to change for some cases
-                lora_ids.append(lora_id)
-            lora_mapping = LoRAMapping(lora_index_mapping,
-                                       lora_prompt_mapping,
-                                       is_prefill=False)
-            self.set_active_loras(lora_requests, lora_mapping)
-            lora_mask, lora_logits_mask = self.create_lora_mask(
-                token_ids, lora_ids,True)
-            LoraMask.setLoraMask(lora_mask)
+            # We only need lora_mask and lora_logits_mask here,
+            # everything else could have been done in _prepare_inputs
+            if self.lora_config:
+                for i, r_id in enumerate(req_id):
+                    lora_request = self.requests[r_id].lora_request
+                    lora_id = self.requests[
+                        r_id].lora_request.lora_int_id if \
+                        lora_request else 0
+                    if lora_id > 0:
+                        lora_requests.append(lora_request)
+                    lora_index_mapping += [lora_id] * (token_ids.shape[1])
+                    lora_prompt_mapping += [
+                        lora_id
+                    ]  #TODO: This may need to change for some cases
+                    lora_ids.append(lora_id)
+                lora_mapping = LoRAMapping(lora_index_mapping,
+                                           lora_prompt_mapping,
+                                           is_prefill=False)
+                self.set_active_loras(lora_requests, lora_mapping)
+                lora_mask, lora_logits_mask = self.create_lora_mask(
+                    token_ids, lora_ids, True)
 
             prefill_hidden_states_ts, logits_device = \
                 self._execute_model_generic(
                     token_ids, position_ids, attn_metadata, logits_indices,
-                    self.kv_caches, lora_logits_mask)
+                    self.kv_caches, lora_logits_mask, lora_mask)
             htorch.core.mark_step()
 
             sampling_metadata = self._prepare_sampling(
@@ -1738,33 +1752,39 @@ def execute_model(
             lora_ids = []
             lora_index_mapping = []
             lora_prompt_mapping = []
+            lora_mask = None
+            lora_logits_mask = None
             ###### Code for LoRA. Move to a function later #######
-            for i, r_id in enumerate(pd_info.decode_req_ids):
-                lora_request = self.requests[r_id].lora_request
-                lora_id = self.requests[r_id].lora_request.lora_int_id if lora_request else 0
-                lora_requests = []
-                if lora_id > 0:
-                    lora_requests.append(lora_request)
-                lora_index_mapping += [lora_id]
-                lora_prompt_mapping += [lora_id]
-                lora_ids.append(lora_id)
-            if decode_data.token_ids.shape[0] > len(pd_info.decode_req_ids): #TODO: Need to remove this hack for handling padding
-                for i in range(decode_data.token_ids.shape[0] - len(pd_info.decode_req_ids)):
+            if self.lora_config:
+                for i, r_id in enumerate(pd_info.decode_req_ids):
+                    lora_request = self.requests[r_id].lora_request
+                    lora_id = self.requests[
+                        r_id].lora_request.lora_int_id if lora_request else 0
+                    lora_requests = []
+                    if lora_id > 0:
+                        lora_requests.append(lora_request)
                     lora_index_mapping += [lora_id]
                     lora_prompt_mapping += [lora_id]
                     lora_ids.append(lora_id)
-            lora_mapping = LoRAMapping(lora_index_mapping,
-                                       lora_prompt_mapping,
-                                       is_prefill=False)
-            self.set_active_loras(lora_requests, lora_mapping)
-            lora_mask, lora_logits_mask = self.create_lora_mask(
-                decode_data.token_ids, lora_ids, False)
-            LoraMask.setLoraMask(lora_mask)
+                if decode_data.token_ids.shape[0] > len(
+                        pd_info.decode_req_ids
+                ):  #TODO: Need to remove this hack for handling padding
+                    for i in range(decode_data.token_ids.shape[0] -
+                                   len(pd_info.decode_req_ids)):
+                        lora_index_mapping += [0]
+                        lora_prompt_mapping += [0]
+                        lora_ids.append(lora_id)
+                lora_mapping = LoRAMapping(lora_index_mapping,
+                                           lora_prompt_mapping,
+                                           is_prefill=False)
+                self.set_active_loras(lora_requests, lora_mapping)
+                lora_mask, lora_logits_mask = self.create_lora_mask(
+                    decode_data.token_ids, lora_ids, False)
 
             _, logits_device = self._execute_model_generic(
                 decode_data.token_ids, decode_data.position_ids,
                 decode_data.attn_metadata, decode_data.logits_indices,
-                self.kv_caches, lora_logits_mask)
+                self.kv_caches, lora_logits_mask, lora_mask)
             htorch.core.mark_step()
             sampling_metadata = self._prepare_sampling(
                 batch_changed,
@@ -2093,8 +2113,10 @@ def warmup_scenario(self,
                                                   self.device)
 
         # TODO: Fix the GC assert seen when this is enabled
-        dummy_lora_requests: List[LoRARequest] = []
-        dummy_lora_requests_per_seq: List[LoRARequest] = []
+        dummy_lora_requests: list[LoRARequest] = []
+        dummy_lora_requests_per_seq: list[LoRARequest] = []
+        lora_mask = None
+        lora_logits_mask = None
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
@@ -2122,20 +2144,19 @@ def warmup_scenario(self,
                     lora_prompt_mapping += [lora_id]
                     lora_ids.append(lora_id)
             lora_mapping = LoRAMapping(lora_index_mapping,
-                                         lora_prompt_mapping,
-                                         is_prefill=False)
+                                       lora_prompt_mapping,
+                                       is_prefill=False)
             self.set_active_loras(dummy_lora_requests_per_seq, lora_mapping)
             lora_mask, lora_logits_mask = self.create_lora_mask(
                 input_ids, lora_ids, is_prompt)
-            LoraMask.setLoraMask(lora_mask)
 
         # Dummy run.
         htorch.core.mark_step()
         logits = self._execute_model_generic(input_ids_device,
                                              position_ids_device,
                                              attn_metadata,
-                                             logits_indices_device, kv_caches, lora_logits_mask,
-                                             True)
+                                             logits_indices_device, kv_caches,
+                                             lora_logits_mask, lora_mask, True)
         # TODO: do sampling on logits, warmup sampler and prefill joiner
         htorch.core.mark_step()
         if self.lora_config:
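
Taken together, the hpu_model_runner.py changes stop calling `LoraMask.setLoraMask` directly in `execute_model` and `warmup_scenario`: the mask is now built only when `self.lora_config` is set, handed explicitly to `_execute_model_generic`, and consumed by the adapter's `forward`, which pops `lora_mask` from its kwargs and sets it before running the model. A minimal, self-contained sketch of that call order follows; every class and helper here is a stand-in, not the actual vllm_gaudi implementation:

import torch


class LoraMask:
    """Stand-in for the LoraMask helper referenced in the diff."""
    _mask = None

    @classmethod
    def setLoraMask(cls, mask):
        cls._mask = mask


class Adapter:
    """Stand-in for HpuModelAdapter: consumes 'lora_mask' in forward()."""

    def forward(self, *args, **kwargs):
        kwargs = kwargs.copy()
        if 'lora_mask' in kwargs:          # new in this commit
            LoraMask.setLoraMask(kwargs.pop('lora_mask'))
        return kwargs['input_ids']          # pretend model output


class Runner:
    """Stand-in for the runner-side LoRA-mask plumbing."""
    lora_config = True                      # pretend LoRA is enabled
    model = Adapter()

    def create_lora_mask(self, token_ids, lora_ids, is_prompt):
        # Stub: the real helper builds HPU-shaped token and logits masks.
        rows = token_ids.shape[0]
        return torch.ones(rows), torch.ones(rows)

    def _execute_model_generic(self, token_ids, lora_logits_mask, lora_mask):
        # lora_mask is now an explicit parameter, forwarded as a kwarg.
        return self.model.forward(input_ids=token_ids, lora_mask=lora_mask)

    def execute_model(self, token_ids, lora_ids):
        lora_mask = lora_logits_mask = None
        if self.lora_config:                # masks only when LoRA is configured
            lora_mask, lora_logits_mask = self.create_lora_mask(
                token_ids, lora_ids, True)
        return self._execute_model_generic(token_ids, lora_logits_mask,
                                           lora_mask)


print(Runner().execute_model(torch.zeros(2, 4, dtype=torch.long), [1, 1]))

Centralizing the setLoraMask call in the adapter means the prefill, decode, and warmup paths no longer each have to set the mask themselves.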
