@@ -1227,7 +1226,6 @@ def _form_prefill_batch(self, contents):
 
         query_lens = _async_h2d_tensor(query_lens, torch.int32)
         token_ids = _async_h2d_tensor(token_ids, torch.int32)
-
         token_positions = _async_h2d_tensor(token_positions, torch.int32)
         token_slots = _async_h2d_tensor(token_slots, torch.int64)
         logits_indices = _async_h2d_tensor(logits_indices, torch.int32)
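For context, each `_async_h2d_tensor` call above stages a host-to-device copy of freshly built batch metadata. A minimal sketch of such a helper, assuming PyTorch's `non_blocking` copy semantics (the file's actual helper may differ):

```python
import torch

def _async_h2d_tensor(data, dtype, device='hpu'):
    # Build the tensor on host, then issue a non-blocking copy so the
    # transfer can overlap with the remaining host-side batch preparation.
    return torch.tensor(data, dtype=dtype, device='cpu').to(
        device, non_blocking=True)
```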
@@ -1294,7 +1293,6 @@ def _prepare_decode_inputs(self, num_decodes,
                                                  num_decodes, sum(num_blocks))[0]
 
         # dp aware padding
-        assert padded_batch_size is not None
         padded_batch_size += self.get_dp_padding(padded_batch_size)
 
         block_tables_list = []
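For context, the dp-aware padding step grows the local batch until every data-parallel rank runs the same padded shape, which keeps collectives and compiled graph shapes in lockstep. A hedged sketch of what a `get_dp_padding` helper can look like, assuming a `torch.distributed` max-reduction over a DP process group (`self.dp_size` and `self.dp_group` are hypothetical names, not necessarily this file's):

```python
import torch
import torch.distributed as dist

def get_dp_padding(self, padded_batch_size: int) -> int:
    # All DP ranks must agree on one batch size: take the max across
    # the group and return how many extra rows this rank must pad.
    if self.dp_size <= 1 or not dist.is_initialized():
        return 0
    local = torch.tensor(padded_batch_size, dtype=torch.int32)
    dist.all_reduce(local, op=dist.ReduceOp.MAX, group=self.dp_group)
    return int(local.item()) - padded_batch_size
```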
@@ -1754,8 +1752,6 @@ def execute_model(
 
         ######################### PREFILLS #########################
         if num_prefills > 0:
-            # Wuxun: merged prefill forward if enabled
-            # 2D bucketing or merged prefill bucketing
             htorch.core.mark_step()
             for idx, (req_id, prompt_len, token_ids, position_ids,
                       attn_metadata, logits_indices,
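For context, `htorch.core.mark_step()` cuts the lazily accumulated op sequence on Gaudi, so each per-request prefill iteration becomes its own compiled (and cached) HPU graph. A minimal illustration of the pattern, with a hypothetical `forward_fn` standing in for the real prefill body:

```python
import habana_frameworks.torch as htorch

def run_prefills(requests, forward_fn):
    # mark_step() at the loop boundary flushes the lazy graph, keeping
    # every request's forward pass a separately compiled unit.
    outputs = []
    for req in requests:
        htorch.core.mark_step()
        outputs.append(forward_fn(req))
    htorch.core.mark_step()
    return outputs
```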
@@ -2098,6 +2094,121 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
               f'used_mem:{format_bytes(total_mem)}')
         logger.info(msg)
 
+<<<<<<< HEAD
+=======
+    def warmup_scenario(self,
+                        batch_size,
+                        seq_or_block,
+                        num_blocks,
+                        is_prompt,
+                        kv_caches,
+                        num_iters=3,
+                        is_pt_profiler_run=True,
+                        align_worker=False,
+                        is_dummy_run=False) -> None:
+        """Dummy warmup run for memory usage and graph compilation."""
+
+        query_seq_len = seq_or_block if is_prompt else 1
+        input_ids = torch.zeros((batch_size, query_seq_len),
+                                dtype=torch.int32,
+                                device='cpu')
+        position_ids = torch.zeros((batch_size, query_seq_len),
+                                   dtype=torch.int32,
+                                   device='cpu')
+        slot_mapping = torch.zeros((batch_size, query_seq_len),
+                                   dtype=torch.int64,
+                                   device='cpu')
+
+        input_ids_device = _async_h2d_tensor_copy(input_ids, self.device)
+        position_ids_device = _async_h2d_tensor_copy(position_ids, self.device)
+        slot_mapping_device = _async_h2d_tensor_copy(slot_mapping, self.device)
+
+        use_graphs = is_dummy_run or self._use_graphs()
+        phase = "prompt" if is_prompt else "decode"
+        scenario_name = ("warmup_"
+                         f"{phase}_"
+                         f"bs{batch_size}_"
+                         f"seq{query_seq_len}_"
+                         f"ctx{num_blocks}_"
+                         f"graphs{'T' if use_graphs else 'F'}")
+        input_ids = torch.zeros((batch_size, query_seq_len),
+                                dtype=torch.int32,
+                                device='cpu')
+        position_ids = torch.zeros((batch_size, query_seq_len),
+                                   dtype=torch.int32,
+                                   device='cpu')
+        slot_mapping = torch.zeros((batch_size, query_seq_len),
+                                   dtype=torch.int64,
+                                   device='cpu')
+
+        input_ids_device = _async_h2d_tensor_copy(input_ids, self.device)
+        position_ids_device = _async_h2d_tensor_copy(position_ids, self.device)
+        slot_mapping_device = _async_h2d_tensor_copy(slot_mapping, self.device)
+        self.profiler.start('internal', scenario_name)
+
+        times = num_iters if use_graphs or is_pt_profiler_run else 1
+        for time_index in range(times):
+            if is_prompt:
+                seq_lens = torch.zeros((batch_size),
+                                       dtype=torch.int32,
+                                       device='cpu')
+                seq_lens.fill_(seq_or_block)
+                seq_lens_device = _async_h2d_tensor_copy(seq_lens, self.device)
+                block_list_device = None
+                if num_blocks:
+                    prefix_block_tables = torch.ones(
+                        (batch_size, num_blocks),
+                        dtype=torch.int32,
+                        device='cpu') * self._PAD_BLOCK_ID
+                    block_list_device = _async_h2d_tensor_copy(
+                        prefix_block_tables.flatten(), self.device)
+                attn_metadata = \
+                    HPUAttentionMetadataV1.make_prefill_metadata(
+                        attn_bias=None,
+                        seq_lens_tensor=seq_lens_device,
+                        context_lens_tensor=seq_lens_device,
+                        slot_mapping=slot_mapping_device,
+                        block_list=block_list_device,
+                        block_size=self.block_size)
+            else:
+                block_tables = [
+                    x.tolist()
+                    for x in np.array_split(np.arange(num_blocks), batch_size)
+                ]
+                block_list, block_groups, block_usage = \
+                    self.get_habana_paged_attn_buffers(
+                        slot_mapping=slot_mapping,
+                        block_tables=block_tables,
+                        batch_size=batch_size)
+                block_list_device = _async_h2d_tensor_copy(
+                    block_list, self.device)
+                block_usage_device = _async_h2d_tensor_copy(
+                    block_usage, self.device)
+                block_groups_device = _async_h2d_tensor_copy(
+                    block_groups, self.device)
+                attn_metadata = HPUAttentionMetadataV1.make_decode_metadata(
+                    block_list=block_list_device,
+                    block_usage=block_usage_device,
+                    block_groups=block_groups_device,
+                    num_decode_tokens=batch_size,
+                    input_positions=None,
+                    slot_mapping=slot_mapping_device,
+                    block_size=self.block_size)
+
+            logits_indices = torch.arange(0, batch_size, device='cpu')
+            logits_indices_device = _async_h2d_tensor_copy(logits_indices,
+                                                           self.device)
+            # Dummy run.
+            htorch.core.mark_step()
+            _ = self._execute_model_generic(input_ids_device,
+                                            position_ids_device,
+                                            attn_metadata,
+                                            logits_indices_device,
+                                            kv_caches, True)
+            # TODO: do sampling on logits, warmup sampler and prefill joiner
+            htorch.core.mark_step()
+        self.profiler.end()
+        return None
+
+>>>>>>> 68ee934 (fix)
 
     def log_warmup(self, phase, i, max_i, batch_size, seq_len, num_blocks):
         free_mem = format_bytes(
             HabanaMemoryProfiler.current_free_device_memory())
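For context, a warmup driver typically calls `warmup_scenario` once per bucket so that every (batch size, length) shape gets compiled before serving. A hedged usage sketch, with illustrative bucket lists and a hypothetical `runner` handle (the actual warmup loop in this file may differ):

```python
# Hypothetical bucket sweep; shapes are illustrative only.
prompt_buckets = [(1, 128), (4, 1024)]    # (batch_size, seq_len)
decode_buckets = [(32, 128), (64, 256)]   # (batch_size, num_blocks)

for bs, seq_len in prompt_buckets:
    runner.warmup_scenario(batch_size=bs, seq_or_block=seq_len,
                           num_blocks=0, is_prompt=True,
                           kv_caches=kv_caches)
for bs, blocks in decode_buckets:
    runner.warmup_scenario(batch_size=bs, seq_or_block=blocks,
                           num_blocks=blocks, is_prompt=False,
                           kv_caches=kv_caches)
```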