@@ -120,10 +120,10 @@ def tree_speculative_sampling(
120120 sum_relu = tkw .sum (relu_diff , dim = VOCAB_SIZE )
121121 cdf = tkw .cumsum (relu_diff , dim = VOCAB_SIZE )
122122
123- threshold_u = tkw .broadcast (
123+ threshold_dist_u = tkw .broadcast (
124124 coin * sum_relu , target_shape = [BATCH_SIZE , NUM_DRAFT_TOKENS , VOCAB_SIZE ]
125125 )
126- greater_than_u = cdf > threshold_u
126+ greater_than_u = cdf > threshold_dist_u
127127 pad_token = tkl .Register [BATCH_SIZE , NUM_DRAFT_TOKENS , VOCAB_SIZE , tkl .i32 ](1e6 )
128128 token_idx = tkl .Register [BATCH_SIZE , NUM_DRAFT_TOKENS , VOCAB_SIZE , tkl .i32 ](
129129 THREAD_0
@@ -146,25 +146,28 @@ def get_speculative_sampling_kernel(
146146 threshold_single : float ,
147147 num_draft_tokens : int ,
148148 vocab_size : int ,
149+ seq_len : int ,
149150):
150151 CUR_INDEX = sympy .Symbol ("CUR_INDEX" )
151152 J = sympy .Symbol ("J" )
152153 BATCH_SIZE = tkl .sym .BATCH_SIZE
153154 NUM_DRAFT_TOKENS = tkl .sym .NUM_DRAFT_TOKENS
154155 VOCAB_SIZE = tkl .sym .VOCAB_SIZE
156+ SEQ_LEN = tkl .sym .SEQ_LEN
155157 BLOCK_BATCH_SIZE = tkl .sym .BLOCK_BATCH_SIZE
156158 BLOCK_NUM_DRAFT_TOK = tkl .sym .BLOCK_NUM_DRAFT_TOK
157159 ADDRESS_SPACE = tkl .sym .ADDRESS_SPACE
158- ADDRESS_SPACE_0 = tkl .sym .ADDRESS_SPACE_0
160+ GLOBAL_ADDRESS_SPACE = tkl .sym .GLOBAL_ADDRESS_SPACE
159161
160162 hyperparams = {
161163 BLOCK_NUM_DRAFT_TOK : 1 ,
162164 NUM_DRAFT_TOKENS : num_draft_tokens ,
163165 ADDRESS_SPACE : SHARED_ADDRESS_SPACE ,
164- ADDRESS_SPACE_0 : GLOBAL_ADDRESS_SPACE ,
166+ GLOBAL_ADDRESS_SPACE : GLOBAL_ADDRESS_SPACE ,
165167 BATCH_SIZE : batch_size ,
166168 BLOCK_BATCH_SIZE : 1 ,
167169 VOCAB_SIZE : vocab_size ,
170+ SEQ_LEN : seq_len ,
168171 }
169172
170173 dynamic_symbols = []
@@ -233,7 +236,7 @@ def get_speculative_sampling_kernel(
233236 write_mapping_1d = tkw .IndexMapping (
234237 num_iterators = 2 ,
235238 inputs = {BATCH_SIZE : i , NUM_DRAFT_TOKENS : j },
236- outputs = {NUM_DRAFT_TOKENS : LAST_ACCEPTED_RETRIEVE_IDX },
239+ outputs = {SEQ_LEN : LAST_ACCEPTED_RETRIEVE_IDX },
237240 )
238241
239242 write_mapping_3d = tkw .IndexMapping (
@@ -285,59 +288,63 @@ def write_with_zero_offset(x, y):
285288 accept_index_layout = tkl .MemoryLayout (shape = [batch_size , num_speculative_tokens ])
286289 cur_prob_offset_vec_layout = tkl .MemoryLayout (shape = [batch_size , 1 , 1 ])
287290 last_accepted_retrieve_idx_vec_layout = tkl .MemoryLayout (shape = [batch_size , 1 , 1 ])
288- predict_layout = tkl .MemoryLayout (shape = [batch_size * num_draft_tokens ])
291+ predict_layout = tkl .MemoryLayout (shape = [seq_len ])
289292
290293 # Kernel.
291294 # =================================================================================
292295 @tkw .wave (constraints )
293296 def speculative_sampling (
294297 uniform_samples : tkl .Memory [
295- BATCH_SIZE , NUM_DRAFT_TOKENS , ADDRESS_SPACE_0 , tkl .f32
298+ BATCH_SIZE , NUM_DRAFT_TOKENS , GLOBAL_ADDRESS_SPACE , tkl .f32
296299 ],
297300 target_probs : tkl .Memory [
298- BATCH_SIZE , NUM_DRAFT_TOKENS , VOCAB_SIZE , ADDRESS_SPACE_0 , tkl .f32
301+ BATCH_SIZE , NUM_DRAFT_TOKENS , VOCAB_SIZE , GLOBAL_ADDRESS_SPACE , tkl .f32
299302 ],
300303 draft_probs : tkl .Memory [
301- BATCH_SIZE , NUM_DRAFT_TOKENS , VOCAB_SIZE , ADDRESS_SPACE_0 , tkl .f32
304+ BATCH_SIZE , NUM_DRAFT_TOKENS , VOCAB_SIZE , GLOBAL_ADDRESS_SPACE , tkl .f32
305+ ],
306+ candidates : tkl .Memory [
307+ BATCH_SIZE , NUM_DRAFT_TOKENS , GLOBAL_ADDRESS_SPACE , tkl .i32
302308 ],
303- candidates : tkl .Memory [BATCH_SIZE , NUM_DRAFT_TOKENS , ADDRESS_SPACE_0 , tkl .i32 ],
304309 retrieve_index : tkl .Memory [
305- BATCH_SIZE , NUM_DRAFT_TOKENS , ADDRESS_SPACE_0 , tkl .i32
310+ BATCH_SIZE , NUM_DRAFT_TOKENS , GLOBAL_ADDRESS_SPACE , tkl .i32
306311 ],
307312 retrieve_next_token : tkl .Memory [
308- BATCH_SIZE , NUM_DRAFT_TOKENS , ADDRESS_SPACE_0 , tkl .i32
313+ BATCH_SIZE , NUM_DRAFT_TOKENS , GLOBAL_ADDRESS_SPACE , tkl .i32
309314 ],
310315 retrieve_next_sibling : tkl .Memory [
311- BATCH_SIZE , NUM_DRAFT_TOKENS , ADDRESS_SPACE_0 , tkl .i32
316+ BATCH_SIZE , NUM_DRAFT_TOKENS , GLOBAL_ADDRESS_SPACE , tkl .i32
312317 ],
313318 # Outputs
314- predicts : tkl .Memory [
315- NUM_DRAFT_TOKENS , ADDRESS_SPACE_0 , tkl .i32 , predict_layout
316- ],
319+ predicts : tkl .Memory [SEQ_LEN , GLOBAL_ADDRESS_SPACE , tkl .i32 , predict_layout ],
317320 accept_token_num : tkl .Memory [
318321 BATCH_SIZE ,
319322 NUM_DRAFT_TOKENS ,
320323 VOCAB_SIZE ,
321- ADDRESS_SPACE_0 ,
324+ GLOBAL_ADDRESS_SPACE ,
322325 tkl .i32 ,
323326 accept_token_num_layout ,
324327 ],
325328 accept_index : tkl .Memory [
326- BATCH_SIZE , NUM_DRAFT_TOKENS , ADDRESS_SPACE_0 , tkl .i32 , accept_index_layout
329+ BATCH_SIZE ,
330+ NUM_DRAFT_TOKENS ,
331+ GLOBAL_ADDRESS_SPACE ,
332+ tkl .i32 ,
333+ accept_index_layout ,
327334 ],
328335 cur_prob_offset_vec : tkl .Memory [
329336 BATCH_SIZE ,
330337 NUM_DRAFT_TOKENS ,
331338 VOCAB_SIZE ,
332- ADDRESS_SPACE_0 ,
339+ GLOBAL_ADDRESS_SPACE ,
333340 tkl .i32 ,
334341 cur_prob_offset_vec_layout ,
335342 ],
336343 last_accepted_retrieve_idx_vec : tkl .Memory [
337344 BATCH_SIZE ,
338345 NUM_DRAFT_TOKENS ,
339346 VOCAB_SIZE ,
340- ADDRESS_SPACE_0 ,
347+ GLOBAL_ADDRESS_SPACE ,
341348 tkl .i32 ,
342349 last_accepted_retrieve_idx_vec_layout ,
343350 ],
0 commit comments