import os
import time
from dataclasses import dataclass, field, fields
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypeAlias, Union, Set, List
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeAlias, Union

import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
@@ -407,6 +407,10 @@ def forward(self, *args, **kwargs):
        # kwargs['attn_metadata'].slot_mapping, compared to untrimmed metadata
        kwargs = kwargs.copy()
        # selected_token_indices = kwargs.pop('selected_token_indices')
+        if 'lora_mask' in kwargs:
+            lora_mask = kwargs['lora_mask']
+            LoraMask.setLoraMask(lora_mask)
+            kwargs.pop('lora_mask')
        if 'warmup_mode' in kwargs:
            kwargs.pop('warmup_mode')
        input_ids = kwargs['input_ids']
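
A minimal sketch of the pattern this hunk introduces, assuming a wrapped `model` callable and the `LoraMask` helper already imported by this file (`run_adapter` is a hypothetical name, not part of the diff): the adapter strips the side-channel kwargs before delegating, so the wrapped model never receives arguments it does not accept.

def run_adapter(model, **kwargs):
    # Hypothetical helper mirroring the forward() change above.
    kwargs = kwargs.copy()
    lora_mask = kwargs.pop('lora_mask', None)
    if lora_mask is not None:
        # Publish the mask so the LoRA layers can read it for this step.
        LoraMask.setLoraMask(lora_mask)
    kwargs.pop('warmup_mode', None)
    return model(**kwargs)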
@@ -442,9 +446,11 @@ def generate_proposals(self, *args, **kwargs):


def _maybe_wrap_in_hpu_graph(*args, **kwargs):
-    '''return htorch.hpu.wrap_in_hpu_graph(
+    '''
+    return htorch.hpu.wrap_in_hpu_graph(
        HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
-    ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs)'''
+    ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs)
+    '''
    return HpuModelAdapter(*args, **kwargs)

@@ -649,7 +655,7 @@ def __init__(
        # TODO(madamczyk-intel): debug why increasing it lowers acc
        self.logits_rounding = 1

-    def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
+    def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: list[int],
                         is_prompt: bool):
        '''
        This is a helper function to create the mask for lora computations.
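
The docstring only states the intent, so here is a rough sketch of one way such a mask can be laid out (the shapes, the per-adapter column band, and every name below are assumptions for illustration, not taken from this diff): one row per token, one column band of width `max_lora_rank` per adapter slot, with ones in the band of the token's adapter and zeros elsewhere.

import torch

def sketch_lora_mask(token_lora_ids: list[int], max_loras: int,
                     max_lora_rank: int) -> torch.Tensor:
    # Rows: tokens; columns: max_loras bands, each max_lora_rank wide.
    mask = torch.zeros(len(token_lora_ids), max_loras * max_lora_rank)
    for row, lora_id in enumerate(token_lora_ids):
        if lora_id > 0:  # id 0 is treated as "no adapter"
            start = (lora_id - 1) * max_lora_rank
            mask[row, start:start + max_lora_rank] = 1.0
    return mask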
@@ -747,7 +753,7 @@ def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
        )
        return self.lora_manager.create_lora_manager(model)

-    def set_active_loras(self, lora_requests: Set[LoRARequest],
+    def set_active_loras(self, lora_requests: set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
@@ -1160,7 +1166,6 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes,
            logits_positions = list(
                range(query_len - num_output_logits, query_len))

-
            new_batch_contents = BatchContents(
                req_ids=[req_id],
                token_ids=[token_ids],
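
A quick worked check of the slice above (values invented): with `query_len = 7` and `num_output_logits = 2`, logits are kept only for the last two positions.

query_len = 7
num_output_logits = 2
logits_positions = list(range(query_len - num_output_logits, query_len))
assert logits_positions == [5, 6]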
@@ -1491,15 +1496,17 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata,
                "Configuration: (%s, %s, %s, %s) was not warmed-up!", phase,
                batch_size, seq_len, num_blocks)

-    def _execute_model_generic(self,
-                               token_ids,
-                               position_ids,
-                               attn_metadata,
-                               logits_indices,
-                               kv_caches,
-                               lora_logits_mask,
-                               warmup_mode=False,
-                               ):
+    def _execute_model_generic(
+        self,
+        token_ids,
+        position_ids,
+        attn_metadata,
+        logits_indices,
+        kv_caches,
+        lora_logits_mask,
+        lora_mask,
+        warmup_mode=False,
+    ):

        # FORWARD.
        batch_size = token_ids.size(0)
@@ -1519,7 +1526,8 @@ def _execute_model_generic(self,
        hidden_states = self.model.forward(input_ids=token_ids,
                                           positions=position_ids,
                                           attn_metadata=trimmed_attn_metadata,
-                                          kv_caches=kv_caches)
+                                          kv_caches=kv_caches,
+                                          lora_mask=lora_mask)
        # NOTE(kzawora): returning hidden_states is required in prompt logprobs
        # scenarios, as they will do logit processing on their own
        non_flattened_hidden_states = hidden_states
@@ -1695,29 +1703,35 @@ def execute_model(
            lora_ids = []
            lora_index_mapping = []
            lora_prompt_mapping = []
+            lora_mask = None
+            lora_logits_mask = None
            ###### Code for LoRA. Move to a function later #######
-            # We only need lora_mask and lora_logits_mask here, everything else
-            # could have been done in _prepare_inputs
-            for i, r_id in enumerate(req_id):
-                lora_request = self.requests[r_id].lora_request
-                lora_id = self.requests[r_id].lora_request.lora_int_id if lora_request else 0
-                if lora_id > 0:
-                    lora_requests.append(lora_request)
-                    lora_index_mapping += [lora_id] * (token_ids.shape[1])
-                    lora_prompt_mapping += [lora_id] #TODO: This may need to change for some cases
-                    lora_ids.append(lora_id)
-            lora_mapping = LoRAMapping(lora_index_mapping,
-                                       lora_prompt_mapping,
-                                       is_prefill=False)
-            self.set_active_loras(lora_requests, lora_mapping)
-            lora_mask, lora_logits_mask = self.create_lora_mask(
-                token_ids, lora_ids,True)
-            LoraMask.setLoraMask(lora_mask)
+            # We only need lora_mask and lora_logits_mask here,
+            # everything else could have been done in _prepare_inputs
+            if self.lora_config:
+                for i, r_id in enumerate(req_id):
+                    lora_request = self.requests[r_id].lora_request
+                    lora_id = self.requests[
+                        r_id].lora_request.lora_int_id if \
+                        lora_request else 0
+                    if lora_id > 0:
+                        lora_requests.append(lora_request)
+                        lora_index_mapping += [lora_id] * (token_ids.shape[1])
+                        lora_prompt_mapping += [
+                            lora_id
+                        ] #TODO: This may need to change for some cases
+                        lora_ids.append(lora_id)
+                lora_mapping = LoRAMapping(lora_index_mapping,
+                                           lora_prompt_mapping,
+                                           is_prefill=False)
+                self.set_active_loras(lora_requests, lora_mapping)
+                lora_mask, lora_logits_mask = self.create_lora_mask(
+                    token_ids, lora_ids, True)

            prefill_hidden_states_ts, logits_device = \
                self._execute_model_generic(
                    token_ids, position_ids, attn_metadata, logits_indices,
-                    self.kv_caches, lora_logits_mask)
+                    self.kv_caches, lora_logits_mask, lora_mask)
            htorch.core.mark_step()

            sampling_metadata = self._prepare_sampling(
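
One possible shape for the "Move to a function later" TODO above, as a hedged sketch only (the helper name and exact placement are assumptions, not part of this diff): it bundles the prefill-side LoRA bookkeeping and returns the two masks, or `(None, None)` when LoRA is disabled.

    def _prepare_lora_masks(self, req_ids, token_ids, is_prompt: bool):
        # Hypothetical refactor of the block above.
        if not self.lora_config:
            return None, None
        lora_requests, lora_ids = [], []
        lora_index_mapping, lora_prompt_mapping = [], []
        for r_id in req_ids:
            lora_request = self.requests[r_id].lora_request
            lora_id = lora_request.lora_int_id if lora_request else 0
            if lora_id > 0:
                lora_requests.append(lora_request)
                lora_index_mapping += [lora_id] * token_ids.shape[1]
                lora_prompt_mapping += [lora_id]
                lora_ids.append(lora_id)
        lora_mapping = LoRAMapping(lora_index_mapping,
                                   lora_prompt_mapping,
                                   is_prefill=False)
        self.set_active_loras(lora_requests, lora_mapping)
        return self.create_lora_mask(token_ids, lora_ids, is_prompt)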
@@ -1738,33 +1752,39 @@ def execute_model(
        lora_ids = []
        lora_index_mapping = []
        lora_prompt_mapping = []
+        lora_mask = None
+        lora_logits_mask = None
        ###### Code for LoRA. Move to a function later #######
-        for i, r_id in enumerate(pd_info.decode_req_ids):
-            lora_request = self.requests[r_id].lora_request
-            lora_id = self.requests[r_id].lora_request.lora_int_id if lora_request else 0
-            lora_requests = []
-            if lora_id > 0:
-                lora_requests.append(lora_request)
-                lora_index_mapping += [lora_id]
-                lora_prompt_mapping += [lora_id]
-                lora_ids.append(lora_id)
-        if decode_data.token_ids.shape[0] > len(pd_info.decode_req_ids): #TODO: Need to remove this hack for handling padding
-            for i in range(decode_data.token_ids.shape[0] - len(pd_info.decode_req_ids)):
+        if self.lora_config:
+            for i, r_id in enumerate(pd_info.decode_req_ids):
+                lora_request = self.requests[r_id].lora_request
+                lora_id = self.requests[
+                    r_id].lora_request.lora_int_id if lora_request else 0
+                lora_requests = []
+                if lora_id > 0:
+                    lora_requests.append(lora_request)
                    lora_index_mapping += [lora_id]
                    lora_prompt_mapping += [lora_id]
                    lora_ids.append(lora_id)
-        lora_mapping = LoRAMapping(lora_index_mapping,
-                                   lora_prompt_mapping,
-                                   is_prefill=False)
-        self.set_active_loras(lora_requests, lora_mapping)
-        lora_mask, lora_logits_mask = self.create_lora_mask(
-            decode_data.token_ids, lora_ids, False)
-        LoraMask.setLoraMask(lora_mask)
+            if decode_data.token_ids.shape[0] > len(
+                    pd_info.decode_req_ids
+            ): #TODO: Need to remove this hack for handling padding
+                for i in range(decode_data.token_ids.shape[0] -
+                               len(pd_info.decode_req_ids)):
+                    lora_index_mapping += [0]
+                    lora_prompt_mapping += [0]
+                    lora_ids.append(lora_id)
+            lora_mapping = LoRAMapping(lora_index_mapping,
+                                       lora_prompt_mapping,
+                                       is_prefill=False)
+            self.set_active_loras(lora_requests, lora_mapping)
+            lora_mask, lora_logits_mask = self.create_lora_mask(
+                decode_data.token_ids, lora_ids, False)

        _, logits_device = self._execute_model_generic(
            decode_data.token_ids, decode_data.position_ids,
            decode_data.attn_metadata, decode_data.logits_indices,
-            self.kv_caches, lora_logits_mask)
+            self.kv_caches, lora_logits_mask, lora_mask)
        htorch.core.mark_step()
        sampling_metadata = self._prepare_sampling(
            batch_changed,
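
To make the padding hack flagged by the TODO above concrete, a small worked example with invented numbers: the decode batch is padded from 3 real requests to 4 rows, so one extra entry with LoRA id 0 ("no adapter") is appended to the mappings.

num_padded_rows = 4      # stands in for decode_data.token_ids.shape[0]
num_real_requests = 3    # stands in for len(pd_info.decode_req_ids)
lora_index_mapping = [2, 0, 5]
lora_prompt_mapping = [2, 0, 5]
for _ in range(num_padded_rows - num_real_requests):
    lora_index_mapping += [0]
    lora_prompt_mapping += [0]
assert lora_index_mapping == [2, 0, 5, 0]
assert lora_prompt_mapping == [2, 0, 5, 0]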
@@ -2093,8 +2113,10 @@ def warmup_scenario(self,
                                                 self.device)

        # TODO: Fix the GC assert seen when this is enabled
-        dummy_lora_requests: List[LoRARequest] = []
-        dummy_lora_requests_per_seq: List[LoRARequest] = []
+        dummy_lora_requests: list[LoRARequest] = []
+        dummy_lora_requests_per_seq: list[LoRARequest] = []
+        lora_mask = None
+        lora_logits_mask = None
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
@@ -2122,20 +2144,19 @@ def warmup_scenario(self,
                lora_prompt_mapping += [lora_id]
                lora_ids.append(lora_id)
            lora_mapping = LoRAMapping(lora_index_mapping,
-                                           lora_prompt_mapping,
-                                           is_prefill=False)
+                                       lora_prompt_mapping,
+                                       is_prefill=False)
            self.set_active_loras(dummy_lora_requests_per_seq, lora_mapping)
            lora_mask, lora_logits_mask = self.create_lora_mask(
                input_ids, lora_ids, is_prompt)
-            LoraMask.setLoraMask(lora_mask)

        # Dummy run.
        htorch.core.mark_step()
        logits = self._execute_model_generic(input_ids_device,
                                             position_ids_device,
                                             attn_metadata,
-                                             logits_indices_device, kv_caches, lora_logits_mask,
-                                             True)
+                                             logits_indices_device, kv_caches,
+                                             lora_logits_mask, lora_mask, True)
        # TODO: do sampling on logits, warmup sampler and prefill joiner
        htorch.core.mark_step()
        if self.lora_config: