@@ -1642,6 +1642,51 @@ def _is_quant_with_inc(self):
         quant_config = os.getenv("QUANT_CONFIG", None) is not None
         return (self.model_config.quantization == "inc" or quant_config)
 
+    def _configure_lora(self, input, requests, req_ids, is_prompt):
+        lora_mask = None
+        lora_logits_mask = None
+        if self.lora_config:
+            if is_prompt:
+                lora_requests = [] if req_ids else requests
+                lora_ids = []
+                lora_index_mapping = []
+                lora_prompt_mapping = []
+                for i, r_id in enumerate(req_ids):
+                    lora_requests.append(requests[r_id].lora_request)
+                for lora_req in lora_requests:
+                    lora_id = lora_req.lora_int_id if lora_req else 0
+                    lora_index_mapping += [lora_id] * (input.shape[1])
+                    #TODO: This may need to change when logprobs
+                    # sampling is enabled
+                    lora_prompt_mapping += [lora_id]
+                    lora_ids.append(lora_id)
+            else:
+                lora_requests = []
+                # lora_ids, lora_index_mapping, lora_prompt_mapping
+                # filled with 0 (indicating no lora) to account for
+                # any padding
+                lora_ids = [0] * input.shape[0]
+                lora_index_mapping = [0] * input.shape[0]
+                lora_prompt_mapping = [0] * input.shape[0]
+                for i, r_id in enumerate(req_ids):
+                    lora_requests.append(requests[r_id].lora_request)
+
+                for i, lora_req in enumerate(lora_requests):
+                    lora_id = lora_req.lora_int_id if lora_req else 0
+                    lora_index_mapping[i] = lora_id
+                    lora_prompt_mapping[i] = lora_id
+                    lora_ids[i] = lora_id
+
+            # is_prefill should always be "False" for HPU
+            lora_mapping = LoRAMapping(lora_index_mapping,
+                                       lora_prompt_mapping,
+                                       is_prefill=False)
+            self.set_active_loras(lora_requests, lora_mapping)
+            lora_mask, lora_logits_mask = self.create_lora_mask(
+                input, lora_ids, is_prompt)
+
+        return lora_mask, lora_logits_mask
+
     @torch.inference_mode()
     def execute_model(
         self,
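The prompt path of the new _configure_lora helper repeats each request's LoRA id once per (padded) token position for the index mapping, while the prompt mapping keeps one entry per request. Below is a minimal, self-contained sketch of that construction, using plain integers in place of request objects; build_prompt_mappings is a hypothetical illustration, not part of the change.

# Illustrative sketch only; mimics the prompt-path mapping logic above.
def build_prompt_mappings(lora_ids_per_request, padded_seq_len):
    lora_index_mapping = []
    lora_prompt_mapping = []
    for lora_id in lora_ids_per_request:
        # one entry per padded token position of this request
        lora_index_mapping += [lora_id] * padded_seq_len
        # one entry per request (may change once logprobs sampling is enabled)
        lora_prompt_mapping += [lora_id]
    return lora_index_mapping, lora_prompt_mapping

# Two requests (LoRA ids 3 and 0, where 0 means "no LoRA") padded to length 4:
assert build_prompt_mappings([3, 0], 4) == ([3, 3, 3, 3, 0, 0, 0, 0], [3, 0])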
@@ -1729,38 +1774,11 @@ def execute_model(
                       attn_metadata, logits_indices,
                       logits_requests) in enumerate(
                           zip(*shallow_tuple(prefill_data))):
+                lora_mask, lora_logits_mask = self._configure_lora(
+                    token_ids, self.requests, req_id, True)
                 self.event_start = self.profiler.get_timestamp_us()
                 self.profiler.start("internal", "prefill")
                 htorch.core.mark_step()
-                lora_requests = []
-                lora_ids = []
-                lora_index_mapping = []
-                lora_prompt_mapping = []
-                lora_mask = None
-                lora_logits_mask = None
-                ###### Code for LoRA. Move to a function later #######
-                # We only need lora_mask and lora_logits_mask here,
-                # everything else could have been done in _prepare_inputs
-                if self.lora_config:
-                    for i, r_id in enumerate(req_id):
-                        lora_request = self.requests[r_id].lora_request
-                        lora_id = self.requests[
-                            r_id].lora_request.lora_int_id if \
-                            lora_request else 0
-                        if lora_id > 0:
-                            lora_requests.append(lora_request)
-                        lora_index_mapping += [lora_id] * (token_ids.shape[1])
-                        lora_prompt_mapping += [
-                            lora_id
-                        ]  #TODO: This may need to change for some cases
-                        lora_ids.append(lora_id)
-                    lora_mapping = LoRAMapping(lora_index_mapping,
-                                               lora_prompt_mapping,
-                                               is_prefill=False)
-                    self.set_active_loras(lora_requests, lora_mapping)
-                    lora_mask, lora_logits_mask = self.create_lora_mask(
-                        token_ids, lora_ids, True)
-
                 prefill_hidden_states_ts, logits_device = \
                     self._execute_model_generic(
                         token_ids, position_ids, attn_metadata, logits_indices,
@@ -1795,43 +1813,13 @@ def execute_model(
         ######################### DECODES #########################
         # Decodes run as one single batch with [padded_decode_bs, 1]
         if num_decodes > 0:
+            lora_mask, lora_logits_mask = self._configure_lora(
+                decode_data.token_ids, self.requests, pd_info.decode_req_ids,
+                False)
             self.event_start = self.profiler.get_timestamp_us()
             self.profiler.start("internal", "decode")
             assert decode_data is not None
             htorch.core.mark_step()
-            lora_requests = []
-            lora_ids = []
-            lora_index_mapping = []
-            lora_prompt_mapping = []
-            lora_mask = None
-            lora_logits_mask = None
-            ###### Code for LoRA. Move to a function later #######
-            if self.lora_config:
-                for i, r_id in enumerate(pd_info.decode_req_ids):
-                    lora_request = self.requests[r_id].lora_request
-                    lora_id = self.requests[
-                        r_id].lora_request.lora_int_id if lora_request else 0
-                    lora_requests = []
-                    if lora_id > 0:
-                        lora_requests.append(lora_request)
-                    lora_index_mapping += [lora_id]
-                    lora_prompt_mapping += [lora_id]
-                    lora_ids.append(lora_id)
-                if decode_data.token_ids.shape[0] > len(
-                        pd_info.decode_req_ids
-                ):  #TODO: Need to remove this hack for handling padding
-                    for i in range(decode_data.token_ids.shape[0] -
-                                   len(pd_info.decode_req_ids)):
-                        lora_index_mapping += [0]
-                        lora_prompt_mapping += [0]
-                        lora_ids.append(lora_id)
-                lora_mapping = LoRAMapping(lora_index_mapping,
-                                           lora_prompt_mapping,
-                                           is_prefill=False)
-                self.set_active_loras(lora_requests, lora_mapping)
-                lora_mask, lora_logits_mask = self.create_lora_mask(
-                    decode_data.token_ids, lora_ids, False)
-
             _, logits_device = self._execute_model_generic(
                 decode_data.token_ids, decode_data.position_ids,
                 decode_data.attn_metadata, decode_data.logits_indices,
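On the decode side, _configure_lora pre-fills all three mappings with 0 (no LoRA) for the full padded batch and then overwrites the leading entries with the ids of the live requests, which replaces the padding hack removed above. A small illustrative sketch of that fill-then-overwrite behavior follows; build_decode_mappings is hypothetical, with plain ints standing in for LoRARequest objects.

# Illustrative sketch only; mirrors the decode-path logic of _configure_lora.
def build_decode_mappings(lora_ids_per_request, padded_batch_size):
    # pre-fill with 0 so padded slots select no LoRA
    lora_ids = [0] * padded_batch_size
    lora_index_mapping = [0] * padded_batch_size
    lora_prompt_mapping = [0] * padded_batch_size
    for i, lora_id in enumerate(lora_ids_per_request):
        lora_index_mapping[i] = lora_id
        lora_prompt_mapping[i] = lora_id
        lora_ids[i] = lora_id
    return lora_ids, lora_index_mapping, lora_prompt_mapping

# Two live requests (LoRA ids 2 and 0) padded to a decode batch of 4:
assert build_decode_mappings([2, 0], 4) == (
    [2, 0, 0, 0], [2, 0, 0, 0], [2, 0, 0, 0])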
@@ -2189,7 +2177,6 @@ def warmup_scenario(self,
         logits_indices_device = _async_h2d_tensor_copy(logits_indices,
                                                        self.device)
 
-        # TODO: Fix the GC assert seen when this is enabled
         dummy_lora_requests: list[LoRARequest] = []
         dummy_lora_requests_per_seq: list[LoRARequest] = []
         lora_mask = None
@@ -2212,20 +2199,8 @@
                     for idx in range(batch_size)
                 ]
 
-            lora_ids = []
-            lora_index_mapping = []
-            lora_prompt_mapping = []
-            for idx in range(batch_size):
-                lora_id = dummy_lora_requests_per_seq[idx].lora_int_id
-                lora_index_mapping += [lora_id] * query_seq_len
-                lora_prompt_mapping += [lora_id]
-                lora_ids.append(lora_id)
-            lora_mapping = LoRAMapping(lora_index_mapping,
-                                       lora_prompt_mapping,
-                                       is_prefill=False)
-            self.set_active_loras(dummy_lora_requests_per_seq, lora_mapping)
-            lora_mask, lora_logits_mask = self.create_lora_mask(
-                input_ids, lora_ids, is_prompt)
+            lora_mask, lora_logits_mask = self._configure_lora(
+                input_ids, dummy_lora_requests_per_seq, [], True)
 
         # Dummy run.
         htorch.core.mark_step()
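Note that the warmup call above passes an empty req_ids list, so the "lora_requests = [] if req_ids else requests" line in _configure_lora falls back to using the supplied dummy requests directly instead of looking them up through self.requests. Below is a tiny hedged sketch of that selection rule; select_lora_requests is a hypothetical stand-in, not part of the change.

# Illustrative sketch only; shows the request-selection rule used by
# _configure_lora's prompt path.
def select_lora_requests(requests, req_ids):
    # non-empty req_ids: start empty and look each request up by id;
    # empty req_ids (warmup): use the provided request list as-is
    selected = [] if req_ids else list(requests)
    for r_id in req_ids:
        selected.append(requests[r_id].lora_request)
    return selected

# Warmup-style call: no request ids, so dummy requests pass straight through.
assert select_lora_requests(["dummy_a", "dummy_b"], []) == ["dummy_a", "dummy_b"]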