Skip to content

Commit 43551dd

Browse files
committed
fixed predicts dim
Signed-off-by: xintin <[email protected]>
1 parent 3abd543 commit 43551dd

File tree

3 files changed: 31 additions and 20 deletions.

iree/turbine/kernel/wave/templates/speculative_decoding.py

Lines changed: 27 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -120,10 +120,10 @@ def tree_speculative_sampling(
120120
sum_relu = tkw.sum(relu_diff, dim=VOCAB_SIZE)
121121
cdf = tkw.cumsum(relu_diff, dim=VOCAB_SIZE)
122122

123-
threshold_u = tkw.broadcast(
123+
threshold_dist_u = tkw.broadcast(
124124
coin * sum_relu, target_shape=[BATCH_SIZE, NUM_DRAFT_TOKENS, VOCAB_SIZE]
125125
)
126-
greater_than_u = cdf > threshold_u
126+
greater_than_u = cdf > threshold_dist_u
127127
pad_token = tkl.Register[BATCH_SIZE, NUM_DRAFT_TOKENS, VOCAB_SIZE, tkl.i32](1e6)
128128
token_idx = tkl.Register[BATCH_SIZE, NUM_DRAFT_TOKENS, VOCAB_SIZE, tkl.i32](
129129
THREAD_0
@@ -146,25 +146,28 @@ def get_speculative_sampling_kernel(
146146
threshold_single: float,
147147
num_draft_tokens: int,
148148
vocab_size: int,
149+
seq_len: int,
149150
):
150151
CUR_INDEX = sympy.Symbol("CUR_INDEX")
151152
J = sympy.Symbol("J")
152153
BATCH_SIZE = tkl.sym.BATCH_SIZE
153154
NUM_DRAFT_TOKENS = tkl.sym.NUM_DRAFT_TOKENS
154155
VOCAB_SIZE = tkl.sym.VOCAB_SIZE
156+
SEQ_LEN = tkl.sym.SEQ_LEN
155157
BLOCK_BATCH_SIZE = tkl.sym.BLOCK_BATCH_SIZE
156158
BLOCK_NUM_DRAFT_TOK = tkl.sym.BLOCK_NUM_DRAFT_TOK
157159
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
158-
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
160+
GLOBAL_ADDRESS_SPACE = tkl.sym.GLOBAL_ADDRESS_SPACE
159161

160162
hyperparams = {
161163
BLOCK_NUM_DRAFT_TOK: 1,
162164
NUM_DRAFT_TOKENS: num_draft_tokens,
163165
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
164-
ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
166+
GLOBAL_ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
165167
BATCH_SIZE: batch_size,
166168
BLOCK_BATCH_SIZE: 1,
167169
VOCAB_SIZE: vocab_size,
170+
SEQ_LEN: seq_len,
168171
}
169172

170173
dynamic_symbols = []
@@ -233,7 +236,7 @@ def get_speculative_sampling_kernel(
233236
write_mapping_1d = tkw.IndexMapping(
234237
num_iterators=2,
235238
inputs={BATCH_SIZE: i, NUM_DRAFT_TOKENS: j},
236-
outputs={NUM_DRAFT_TOKENS: LAST_ACCEPTED_RETRIEVE_IDX},
239+
outputs={SEQ_LEN: LAST_ACCEPTED_RETRIEVE_IDX},
237240
)
238241

239242
write_mapping_3d = tkw.IndexMapping(
@@ -285,59 +288,63 @@ def write_with_zero_offset(x, y):
285288
accept_index_layout = tkl.MemoryLayout(shape=[batch_size, num_speculative_tokens])
286289
cur_prob_offset_vec_layout = tkl.MemoryLayout(shape=[batch_size, 1, 1])
287290
last_accepted_retrieve_idx_vec_layout = tkl.MemoryLayout(shape=[batch_size, 1, 1])
288-
predict_layout = tkl.MemoryLayout(shape=[batch_size * num_draft_tokens])
291+
predict_layout = tkl.MemoryLayout(shape=[seq_len])
289292

290293
# Kernel.
291294
# =================================================================================
292295
@tkw.wave(constraints)
293296
def speculative_sampling(
294297
uniform_samples: tkl.Memory[
295-
BATCH_SIZE, NUM_DRAFT_TOKENS, ADDRESS_SPACE_0, tkl.f32
298+
BATCH_SIZE, NUM_DRAFT_TOKENS, GLOBAL_ADDRESS_SPACE, tkl.f32
296299
],
297300
target_probs: tkl.Memory[
298-
BATCH_SIZE, NUM_DRAFT_TOKENS, VOCAB_SIZE, ADDRESS_SPACE_0, tkl.f32
301+
BATCH_SIZE, NUM_DRAFT_TOKENS, VOCAB_SIZE, GLOBAL_ADDRESS_SPACE, tkl.f32
299302
],
300303
draft_probs: tkl.Memory[
301-
BATCH_SIZE, NUM_DRAFT_TOKENS, VOCAB_SIZE, ADDRESS_SPACE_0, tkl.f32
304+
BATCH_SIZE, NUM_DRAFT_TOKENS, VOCAB_SIZE, GLOBAL_ADDRESS_SPACE, tkl.f32
305+
],
306+
candidates: tkl.Memory[
307+
BATCH_SIZE, NUM_DRAFT_TOKENS, GLOBAL_ADDRESS_SPACE, tkl.i32
302308
],
303-
candidates: tkl.Memory[BATCH_SIZE, NUM_DRAFT_TOKENS, ADDRESS_SPACE_0, tkl.i32],
304309
retrieve_index: tkl.Memory[
305-
BATCH_SIZE, NUM_DRAFT_TOKENS, ADDRESS_SPACE_0, tkl.i32
310+
BATCH_SIZE, NUM_DRAFT_TOKENS, GLOBAL_ADDRESS_SPACE, tkl.i32
306311
],
307312
retrieve_next_token: tkl.Memory[
308-
BATCH_SIZE, NUM_DRAFT_TOKENS, ADDRESS_SPACE_0, tkl.i32
313+
BATCH_SIZE, NUM_DRAFT_TOKENS, GLOBAL_ADDRESS_SPACE, tkl.i32
309314
],
310315
retrieve_next_sibling: tkl.Memory[
311-
BATCH_SIZE, NUM_DRAFT_TOKENS, ADDRESS_SPACE_0, tkl.i32
316+
BATCH_SIZE, NUM_DRAFT_TOKENS, GLOBAL_ADDRESS_SPACE, tkl.i32
312317
],
313318
# Outputs
314-
predicts: tkl.Memory[
315-
NUM_DRAFT_TOKENS, ADDRESS_SPACE_0, tkl.i32, predict_layout
316-
],
319+
predicts: tkl.Memory[SEQ_LEN, GLOBAL_ADDRESS_SPACE, tkl.i32, predict_layout],
317320
accept_token_num: tkl.Memory[
318321
BATCH_SIZE,
319322
NUM_DRAFT_TOKENS,
320323
VOCAB_SIZE,
321-
ADDRESS_SPACE_0,
324+
GLOBAL_ADDRESS_SPACE,
322325
tkl.i32,
323326
accept_token_num_layout,
324327
],
325328
accept_index: tkl.Memory[
326-
BATCH_SIZE, NUM_DRAFT_TOKENS, ADDRESS_SPACE_0, tkl.i32, accept_index_layout
329+
BATCH_SIZE,
330+
NUM_DRAFT_TOKENS,
331+
GLOBAL_ADDRESS_SPACE,
332+
tkl.i32,
333+
accept_index_layout,
327334
],
328335
cur_prob_offset_vec: tkl.Memory[
329336
BATCH_SIZE,
330337
NUM_DRAFT_TOKENS,
331338
VOCAB_SIZE,
332-
ADDRESS_SPACE_0,
339+
GLOBAL_ADDRESS_SPACE,
333340
tkl.i32,
334341
cur_prob_offset_vec_layout,
335342
],
336343
last_accepted_retrieve_idx_vec: tkl.Memory[
337344
BATCH_SIZE,
338345
NUM_DRAFT_TOKENS,
339346
VOCAB_SIZE,
340-
ADDRESS_SPACE_0,
347+
GLOBAL_ADDRESS_SPACE,
341348
tkl.i32,
342349
last_accepted_retrieve_idx_vec_layout,
343350
],

lit_tests/kernel/wave/speculative_decoding.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ def test_speculative_decoding():
2525
threshold_acc=0.01,
2626
num_draft_tokens=6,
2727
vocab_size=20,
28+
seq_len=12,
2829
)
2930

3031
# Create the kernel with the hyperparameters

tests/kernel/wave/speculative_decode_test.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,7 @@ def get_wave_speculative_sampling_kernel(
6161
threshold_single,
6262
num_draft_tokens,
6363
vocab_size,
64+
seq_len,
6465
):
6566
speculative_sampling, symbols, _, _ = get_speculative_sampling_kernel(
6667
batch_size,
@@ -69,6 +70,7 @@ def get_wave_speculative_sampling_kernel(
6970
threshold_single,
7071
num_draft_tokens,
7172
vocab_size,
73+
seq_len,
7274
)
7375
symbols.update(get_default_scheduling_params())
7476

@@ -188,6 +190,7 @@ def tree_speculative_sampling_target_only(
188190
threshold_single,
189191
num_draft_tokens,
190192
vocab_size,
193+
seq_len,
191194
)
192195
sampling_kernel(
193196
uniform_samples,

0 commit comments

Comments
 (0)