@@ -1633,13 +1633,12 @@ def _prepare_prefill_inputs(
1633
1633
return all_batches [0 ], dummy_prefill_input_batches [
1634
1634
0 ] if dummy_prefill_input_batches else None
1635
1635
1636
- def _create_decode_input_data (
1637
- self ,
1638
- num_decodes ,
1639
- num_scheduled_tokens ,
1640
- context_lens ,
1641
- block_table_cpu_tensor ,
1642
- scheduler_output = None ) -> tuple [DecodeInputData , int ]:
1636
+ def _create_decode_input_data (self ,
1637
+ num_decodes ,
1638
+ num_scheduled_tokens ,
1639
+ context_lens ,
1640
+ block_table_cpu_tensor ,
1641
+ scheduler_output = None ) -> DecodeInputData :
1643
1642
# NOTE(kzawora): the +1 is what causes this entire thing to work,
1644
1643
# as in the paged attention, we don't fetch just the context from cache,
1645
1644
# but also kvs for the current token
@@ -1652,8 +1651,7 @@ def _create_decode_input_data(
1652
1651
num_decodes , sum (num_blocks ))[0 ]
1653
1652
1654
1653
# dp aware padding
1655
- num_pad_across_dp = self .get_dp_padding (padded_batch_size )
1656
- padded_batch_size += num_pad_across_dp
1654
+ padded_batch_size += self .get_dp_padding (padded_batch_size )
1657
1655
1658
1656
num_tokens_per_req = num_scheduled_tokens [:num_decodes ]
1659
1657
num_tokens = max (num_tokens_per_req )
@@ -1843,7 +1841,7 @@ def _create_decode_input_data(
1843
1841
block_size = self .block_size ,
1844
1842
query_start_loc = query_start_loc ,
1845
1843
),
1846
- spec_decode_metadata = spec_decode_metadata ), num_pad_across_dp
1844
+ spec_decode_metadata = spec_decode_metadata )
1847
1845
1848
1846
def _prepare_decode_inputs (
1849
1847
self ,
@@ -1868,7 +1866,8 @@ def _prepare_decode_inputs(
1868
1866
return self ._create_decode_input_data (
1869
1867
num_decodes , num_scheduled_tokens ,
1870
1868
self .input_batch .num_computed_tokens_cpu [:num_decodes ],
1871
- self .input_batch .block_table [0 ].get_cpu_tensor (), scheduler_output )
1869
+ self .input_batch .block_table [0 ].get_cpu_tensor (),
1870
+ scheduler_output ), None
1872
1871
1873
1872
def _create_dummy_decode_input_data (self ) -> DecodeInputData :
1874
1873
# create dummy decode input data with batch size 1
@@ -1877,13 +1876,10 @@ def _create_dummy_decode_input_data(self) -> DecodeInputData:
1877
1876
context_lens = [128 ]
1878
1877
block_table_cpu_tensor = torch .zeros ([self ._PAD_BLOCK_ID ],
1879
1878
dtype = torch .int32 ).reshape (1 , - 1 )
1880
- # num_computed_tokens_cpu = np.array([128], dtype=np.int32)
1881
- # token_ids = np.array(list(int(i) for i in range(context_lens[0])))
1882
-
1883
1879
return self ._create_decode_input_data (num_dummy_decodes ,
1884
1880
num_dummy_scheduled_tokens ,
1885
1881
context_lens ,
1886
- block_table_cpu_tensor )[ 0 ]
1882
+ block_table_cpu_tensor )
1887
1883
1888
1884
def _get_cumsum_and_arange (
1889
1885
self ,
@@ -2570,17 +2566,6 @@ def execute_model(
2570
2566
prompt_batch_idx = None ,
2571
2567
is_prompt = False )
2572
2568
self .profiler .record_counter (self .event_start , counters )
2573
- else :
2574
- if dummy_decode_input_data_across_dp is not None :
2575
- htorch .core .mark_step ()
2576
- _ , _ , dummy_logits_device = self ._execute_model_generic (
2577
- dummy_decode_input_data_across_dp .token_ids ,
2578
- dummy_decode_input_data_across_dp .position_ids ,
2579
- dummy_decode_input_data_across_dp .attn_metadata ,
2580
- dummy_decode_input_data_across_dp .logits_indices ,
2581
- self .kv_caches ,
2582
- warmup_mode = warmup_mode )
2583
- htorch .core .mark_step ()
2584
2569
2585
2570
################## Spec Decode ##################
2586
2571
# work on spec decode if max_gen_len > 1
@@ -2617,6 +2602,17 @@ def execute_model(
2617
2602
spec_decode_metadata , spec_decode_common_attn_metadata ,
2618
2603
decode_data )[:num_decodes ]
2619
2604
################## Spec Decode end ##################
2605
+ else :
2606
+ if dummy_decode_input_data_across_dp is not None :
2607
+ htorch .core .mark_step ()
2608
+ _ , _ , dummy_logits_device = self ._execute_model_generic (
2609
+ dummy_decode_input_data_across_dp .token_ids ,
2610
+ dummy_decode_input_data_across_dp .position_ids ,
2611
+ dummy_decode_input_data_across_dp .attn_metadata ,
2612
+ dummy_decode_input_data_across_dp .logits_indices ,
2613
+ self .kv_caches ,
2614
+ warmup_mode = warmup_mode )
2615
+ htorch .core .mark_step ()
2620
2616
2621
2617
if structured_output :
2622
2618
# Scheduler places cached before prompt
0 commit comments