Commit fb18a7e

Move LoRA configuration code to separate function
Signed-off-by: Vivek <[email protected]>
1 parent 9830a9e commit fb18a7e

1 file changed (+52, -77)

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 52 additions & 77 deletions
@@ -1642,6 +1642,51 @@ def _is_quant_with_inc(self):
         quant_config = os.getenv("QUANT_CONFIG", None) is not None
         return (self.model_config.quantization == "inc" or quant_config)
 
+    def _configure_lora(self, input, requests, req_ids, is_prompt):
+        lora_mask = None
+        lora_logits_mask = None
+        if self.lora_config:
+            if is_prompt:
+                lora_requests = [] if req_ids else requests
+                lora_ids = []
+                lora_index_mapping = []
+                lora_prompt_mapping = []
+                for i, r_id in enumerate(req_ids):
+                    lora_requests.append(requests[r_id].lora_request)
+                for lora_req in lora_requests:
+                    lora_id = lora_req.lora_int_id if lora_req else 0
+                    lora_index_mapping += [lora_id] * (input.shape[1])
+                    #TODO: This may need to change when logprobs
+                    # sampling is enabled
+                    lora_prompt_mapping += [lora_id]
+                    lora_ids.append(lora_id)
+            else:
+                lora_requests = []
+                # lora_ids, lora_index_mapping, lora_prompt_mapping
+                # filled with 0 (indicating no lora) to account for
+                # any padding
+                lora_ids = [0] * input.shape[0]
+                lora_index_mapping = [0] * input.shape[0]
+                lora_prompt_mapping = [0] * input.shape[0]
+                for i, r_id in enumerate(req_ids):
+                    lora_requests.append(requests[r_id].lora_request)
+
+                for i, lora_req in enumerate(lora_requests):
+                    lora_id = lora_req.lora_int_id if lora_req else 0
+                    lora_index_mapping[i] = lora_id
+                    lora_prompt_mapping[i] = lora_id
+                    lora_ids[i] = lora_id
+
+            # is_prefill should always be "False" for HPU
+            lora_mapping = LoRAMapping(lora_index_mapping,
+                                       lora_prompt_mapping,
+                                       is_prefill=False)
+            self.set_active_loras(lora_requests, lora_mapping)
+            lora_mask, lora_logits_mask = self.create_lora_mask(
+                input, lora_ids, is_prompt)
+
+        return lora_mask, lora_logits_mask
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -1729,38 +1774,11 @@ def execute_model
                       attn_metadata, logits_indices,
                       logits_requests) in enumerate(
                           zip(*shallow_tuple(prefill_data))):
+                lora_mask, lora_logits_mask = self._configure_lora(
+                    token_ids, self.requests, req_id, True)
                 self.event_start = self.profiler.get_timestamp_us()
                 self.profiler.start("internal", "prefill")
                 htorch.core.mark_step()
-                lora_requests = []
-                lora_ids = []
-                lora_index_mapping = []
-                lora_prompt_mapping = []
-                lora_mask = None
-                lora_logits_mask = None
-                ###### Code for LoRA. Move to a function later #######
-                # We only need lora_mask and lora_logits_mask here,
-                # everything else could have been done in _prepare_inputs
-                if self.lora_config:
-                    for i, r_id in enumerate(req_id):
-                        lora_request = self.requests[r_id].lora_request
-                        lora_id = self.requests[
-                            r_id].lora_request.lora_int_id if \
-                            lora_request else 0
-                        if lora_id > 0:
-                            lora_requests.append(lora_request)
-                        lora_index_mapping += [lora_id] * (token_ids.shape[1])
-                        lora_prompt_mapping += [
-                            lora_id
-                        ]  #TODO: This may need to change for some cases
-                        lora_ids.append(lora_id)
-                    lora_mapping = LoRAMapping(lora_index_mapping,
-                                               lora_prompt_mapping,
-                                               is_prefill=False)
-                    self.set_active_loras(lora_requests, lora_mapping)
-                    lora_mask, lora_logits_mask = self.create_lora_mask(
-                        token_ids, lora_ids, True)
-
                 prefill_hidden_states_ts, logits_device = \
                     self._execute_model_generic(
                         token_ids, position_ids, attn_metadata, logits_indices,
@@ -1795,43 +1813,13 @@ def execute_model
         ######################### DECODES #########################
         # Decodes run as one single batch with [padded_decode_bs, 1]
         if num_decodes > 0:
+            lora_mask, lora_logits_mask = self._configure_lora(
+                decode_data.token_ids, self.requests, pd_info.decode_req_ids,
+                False)
             self.event_start = self.profiler.get_timestamp_us()
             self.profiler.start("internal", "decode")
             assert decode_data is not None
             htorch.core.mark_step()
-            lora_requests = []
-            lora_ids = []
-            lora_index_mapping = []
-            lora_prompt_mapping = []
-            lora_mask = None
-            lora_logits_mask = None
-            ###### Code for LoRA. Move to a function later #######
-            if self.lora_config:
-                for i, r_id in enumerate(pd_info.decode_req_ids):
-                    lora_request = self.requests[r_id].lora_request
-                    lora_id = self.requests[
-                        r_id].lora_request.lora_int_id if lora_request else 0
-                    lora_requests = []
-                    if lora_id > 0:
-                        lora_requests.append(lora_request)
-                    lora_index_mapping += [lora_id]
-                    lora_prompt_mapping += [lora_id]
-                    lora_ids.append(lora_id)
-                if decode_data.token_ids.shape[0] > len(
-                        pd_info.decode_req_ids
-                ):  #TODO: Need to remove this hack for handling padding
-                    for i in range(decode_data.token_ids.shape[0] -
-                                   len(pd_info.decode_req_ids)):
-                        lora_index_mapping += [0]
-                        lora_prompt_mapping += [0]
-                        lora_ids.append(lora_id)
-                lora_mapping = LoRAMapping(lora_index_mapping,
-                                           lora_prompt_mapping,
-                                           is_prefill=False)
-                self.set_active_loras(lora_requests, lora_mapping)
-                lora_mask, lora_logits_mask = self.create_lora_mask(
-                    decode_data.token_ids, lora_ids, False)
-
             _, logits_device = self._execute_model_generic(
                 decode_data.token_ids, decode_data.position_ids,
                 decode_data.attn_metadata, decode_data.logits_indices,
@@ -2189,7 +2177,6 @@ def warmup_scenario(self,
         logits_indices_device = _async_h2d_tensor_copy(logits_indices,
                                                        self.device)
 
-        # TODO: Fix the GC assert seen when this is enabled
         dummy_lora_requests: list[LoRARequest] = []
         dummy_lora_requests_per_seq: list[LoRARequest] = []
         lora_mask = None
@@ -2212,20 +2199,8 @@
                 for idx in range(batch_size)
             ]
 
-            lora_ids = []
-            lora_index_mapping = []
-            lora_prompt_mapping = []
-            for idx in range(batch_size):
-                lora_id = dummy_lora_requests_per_seq[idx].lora_int_id
-                lora_index_mapping += [lora_id] * query_seq_len
-                lora_prompt_mapping += [lora_id]
-                lora_ids.append(lora_id)
-            lora_mapping = LoRAMapping(lora_index_mapping,
-                                       lora_prompt_mapping,
-                                       is_prefill=False)
-            self.set_active_loras(dummy_lora_requests_per_seq, lora_mapping)
-            lora_mask, lora_logits_mask = self.create_lora_mask(
-                input_ids, lora_ids, is_prompt)
+        lora_mask, lora_logits_mask = self._configure_lora(
+            input_ids, dummy_lora_requests_per_seq, [], True)
 
         # Dummy run.
         htorch.core.mark_step()
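
For reference, the sketch below re-creates, outside the runner, the mapping logic that the new _configure_lora helper centralizes: a per-token lora_index_mapping and a per-sequence lora_prompt_mapping on the prompt path, and zero-filled (no-LoRA) slots for padded decode batches. It is illustrative only; build_lora_mappings and the SimpleNamespace request stubs are hypothetical stand-ins, not code from hpu_model_runner.py, and the real helper additionally wraps the lists in a LoRAMapping with is_prefill=False, calls set_active_loras, and returns the masks from create_lora_mask.

    # Hypothetical stand-alone sketch (not runner code): build_lora_mappings and
    # the SimpleNamespace request stubs are illustrative stand-ins only.
    from types import SimpleNamespace


    def build_lora_mappings(input_shape, requests, req_ids, is_prompt):
        """Rebuild the mapping lists the way _configure_lora does (sketch only)."""
        if is_prompt:
            # Prompt path: one index entry per padded prompt token,
            # one prompt entry (and one id) per sequence.
            lora_ids, index_mapping, prompt_mapping = [], [], []
            for r_id in req_ids:
                lora_req = requests[r_id].lora_request
                lora_id = lora_req.lora_int_id if lora_req else 0
                index_mapping += [lora_id] * input_shape[1]
                prompt_mapping.append(lora_id)
                lora_ids.append(lora_id)
        else:
            # Decode path: pre-fill with 0 (no LoRA) so padded batch slots
            # stay inert, then overwrite the live slots.
            lora_ids = [0] * input_shape[0]
            index_mapping = [0] * input_shape[0]
            prompt_mapping = [0] * input_shape[0]
            for i, r_id in enumerate(req_ids):
                lora_req = requests[r_id].lora_request
                lora_id = lora_req.lora_int_id if lora_req else 0
                index_mapping[i] = lora_id
                prompt_mapping[i] = lora_id
                lora_ids[i] = lora_id
        return lora_ids, index_mapping, prompt_mapping


    if __name__ == "__main__":
        # Two requests: one with LoRA id 3, one without a LoRA adapter.
        reqs = {
            "a": SimpleNamespace(lora_request=SimpleNamespace(lora_int_id=3)),
            "b": SimpleNamespace(lora_request=None),
        }
        # Prompt shaped [2, 4]: index mapping repeats each id over 4 padded tokens.
        print(build_lora_mappings((2, 4), reqs, ["a", "b"], is_prompt=True))
        # Decode batch padded to 4 rows with 2 live requests: trailing slots stay 0.
        print(build_lora_mappings((4, 1), reqs, ["a", "b"], is_prompt=False))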
