Draft
Changes from 183 commits
Commits
185 commits
33a3a26
wip
WoosukKwon Aug 17, 2025
699bd79
Merge branch 'main' into woosuk/input-prep
WoosukKwon Aug 18, 2025
c472982
merge
WoosukKwon Aug 22, 2025
79e5eb3
wip
WoosukKwon Aug 22, 2025
64c8cce
rename
WoosukKwon Aug 22, 2025
48bca9a
merge
WoosukKwon Aug 23, 2025
a1e3745
wip
WoosukKwon Aug 25, 2025
da9cd26
Merge branch 'main' into woosuk/input-prep
WoosukKwon Aug 25, 2025
7b4b72e
fix
WoosukKwon Aug 25, 2025
65f9369
merge
WoosukKwon Aug 25, 2025
b1d5273
fix
WoosukKwon Aug 25, 2025
a851aaa
simplify
WoosukKwon Aug 25, 2025
e570b0a
merge
WoosukKwon Aug 28, 2025
d6d719f
Merge branch 'main' into woosuk/input-prep
WoosukKwon Aug 28, 2025
b21393c
Merge branch 'main' into woosuk/input-prep
WoosukKwon Aug 28, 2025
efba25e
minor
WoosukKwon Aug 28, 2025
e451045
fix
WoosukKwon Aug 28, 2025
19c0dfc
minor
WoosukKwon Aug 28, 2025
4055781
minor
WoosukKwon Aug 28, 2025
9ee9d0e
fix
WoosukKwon Aug 28, 2025
efcb786
merge
WoosukKwon Aug 31, 2025
e696f78
minor
WoosukKwon Aug 31, 2025
c11d1e6
optimize spec
WoosukKwon Aug 31, 2025
22771e5
work
WoosukKwon Sep 1, 2025
ba1a58f
MAX_SPEC_LEN
WoosukKwon Sep 1, 2025
62d23b3
fix
WoosukKwon Sep 1, 2025
af7b6c5
fix
WoosukKwon Sep 1, 2025
01bf16e
fix
WoosukKwon Sep 1, 2025
cc340e2
top_p top_k
WoosukKwon Sep 1, 2025
4c2a337
merge
WoosukKwon Sep 1, 2025
b16e2d9
fix
WoosukKwon Sep 1, 2025
23eae07
merge
WoosukKwon Sep 5, 2025
ead95fe
merge
WoosukKwon Sep 6, 2025
8e6cb9a
minor
WoosukKwon Sep 6, 2025
0c56069
merge
WoosukKwon Sep 6, 2025
6283995
minor
WoosukKwon Sep 7, 2025
286eeb9
merge
WoosukKwon Sep 7, 2025
5f95309
rename
WoosukKwon Sep 7, 2025
787e596
wip
WoosukKwon Sep 8, 2025
7a50a54
Merge branch 'main' into woosuk/input-prep
WoosukKwon Sep 13, 2025
9314a83
Merge branch 'main' into woosuk/input-prep
WoosukKwon Sep 14, 2025
caf963f
fix
WoosukKwon Sep 14, 2025
5c133fc
reorder
WoosukKwon Sep 14, 2025
e47bb99
fix
WoosukKwon Sep 14, 2025
eb3742c
fix
WoosukKwon Sep 14, 2025
633f9f0
Merge branch 'main' into woosuk/input-prep
WoosukKwon Sep 14, 2025
9a6fcca
fix
WoosukKwon Sep 14, 2025
8b3c13c
wip
WoosukKwon Sep 15, 2025
67852c1
minor
WoosukKwon Sep 15, 2025
69b1789
chunked prefilling
WoosukKwon Sep 15, 2025
f1981db
minor
WoosukKwon Sep 15, 2025
e107680
wip
WoosukKwon Sep 15, 2025
9f2becd
merge
WoosukKwon Sep 16, 2025
dfc84b1
wip
WoosukKwon Sep 15, 2025
83d1137
wip
WoosukKwon Sep 16, 2025
c320a33
skip warmup
WoosukKwon Sep 16, 2025
9151026
task
WoosukKwon Sep 16, 2025
c1d83f2
merge
WoosukKwon Sep 18, 2025
9050087
update
WoosukKwon Sep 18, 2025
92f337f
minor
WoosukKwon Sep 18, 2025
cbdb47d
working
WoosukKwon Sep 18, 2025
3f50030
fix
WoosukKwon Sep 18, 2025
a496283
minor
WoosukKwon Sep 18, 2025
bc6463a
hash
WoosukKwon Sep 18, 2025
aabfaa0
fix
WoosukKwon Sep 18, 2025
330058f
fix
WoosukKwon Sep 18, 2025
82e591f
remove
WoosukKwon Sep 18, 2025
8407fa0
fix
WoosukKwon Sep 18, 2025
e171e5b
merge
WoosukKwon Sep 18, 2025
2bb2cb1
revert
WoosukKwon Sep 18, 2025
67d8c0c
fix
WoosukKwon Sep 18, 2025
a98eff0
minor
WoosukKwon Sep 18, 2025
323a05b
update
WoosukKwon Sep 18, 2025
82da219
Implement topk_logprobs
WoosukKwon Sep 18, 2025
efda084
minor
WoosukKwon Sep 18, 2025
86dade7
fix
WoosukKwon Sep 18, 2025
d2be623
fix
WoosukKwon Sep 18, 2025
31619ff
fix
WoosukKwon Sep 18, 2025
b9c7448
logprobs
WoosukKwon Sep 19, 2025
8deedfa
-inf
WoosukKwon Sep 19, 2025
52ca2f5
sample
WoosukKwon Sep 19, 2025
af65838
dummy run
WoosukKwon Sep 19, 2025
8af8798
fix
WoosukKwon Sep 19, 2025
b405d78
DP sampler
WoosukKwon Sep 19, 2025
0d3de9e
fix
WoosukKwon Sep 19, 2025
3367277
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Sep 19, 2025
37478c1
async output
WoosukKwon Sep 19, 2025
9c75d89
minor
WoosukKwon Sep 19, 2025
d30c0d5
refactor
WoosukKwon Sep 19, 2025
4be2c66
fix
WoosukKwon Sep 19, 2025
a8e7071
minor
WoosukKwon Sep 19, 2025
c7f3e84
minor
WoosukKwon Sep 19, 2025
396bbe6
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Sep 19, 2025
010e39e
minor
WoosukKwon Sep 19, 2025
6f038fc
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Sep 19, 2025
a66aa37
minor:
WoosukKwon Sep 19, 2025
98ef239
minor
WoosukKwon Sep 19, 2025
158a468
random uuid
WoosukKwon Sep 20, 2025
913b8e9
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Sep 20, 2025
8aee6e9
64-bit for gumbel seed
WoosukKwon Sep 20, 2025
42ffdd9
wip
WoosukKwon Sep 20, 2025
631b5b4
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Sep 21, 2025
bc73f67
compute_logits
WoosukKwon Sep 21, 2025
fe5472d
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Sep 22, 2025
72f0a71
assert
WoosukKwon Sep 22, 2025
17c2c10
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Sep 23, 2025
42f9915
fix
WoosukKwon Sep 23, 2025
704def2
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Sep 23, 2025
ad2cf80
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Sep 24, 2025
866eef5
minor
WoosukKwon Sep 24, 2025
1107701
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Oct 30, 2025
09e4b2f
update
WoosukKwon Oct 30, 2025
5666a25
fix
WoosukKwon Oct 30, 2025
5c8049d
fix
WoosukKwon Oct 30, 2025
1c5c866
uint64
WoosukKwon Oct 30, 2025
8f8aaa8
forward context
WoosukKwon Nov 6, 2025
e40e85b
merge
WoosukKwon Nov 9, 2025
013daed
Add sample_tokens
WoosukKwon Nov 9, 2025
608fec3
fix lora
WoosukKwon Nov 9, 2025
bf3992c
allow torch compile
WoosukKwon Nov 9, 2025
a1249af
minor
WoosukKwon Nov 9, 2025
3ce8a08
Add DP
WoosukKwon Nov 9, 2025
b9ebedb
fix
WoosukKwon Nov 9, 2025
8d82fac
fix
WoosukKwon Nov 9, 2025
af23897
fix
WoosukKwon Nov 9, 2025
83943cd
minor
WoosukKwon Nov 9, 2025
cbd90df
fix
WoosukKwon Nov 9, 2025
5b5fd19
minor
WoosukKwon Nov 9, 2025
484135c
minor
WoosukKwon Nov 9, 2025
8912870
Add structured outputs
WoosukKwon Nov 9, 2025
312affc
fix
WoosukKwon Nov 9, 2025
523f27a
fix
WoosukKwon Nov 9, 2025
de64ce7
async structured outputs
WoosukKwon Nov 9, 2025
8b44f99
flag
WoosukKwon Nov 9, 2025
d8a8279
fix
WoosukKwon Nov 9, 2025
fe97bf9
fix dp
WoosukKwon Nov 9, 2025
8240f3a
minor
WoosukKwon Nov 9, 2025
ebdee19
minor
WoosukKwon Nov 9, 2025
e75ded3
minor
WoosukKwon Nov 9, 2025
493b4d6
minor
WoosukKwon Nov 9, 2025
75ef5f4
fix for DP
WoosukKwon Nov 9, 2025
724593b
fix
WoosukKwon Nov 9, 2025
6dc3d83
minor
WoosukKwon Nov 9, 2025
2b51ecb
skip sync in dummy run
WoosukKwon Nov 9, 2025
ecb2932
minor
WoosukKwon Nov 10, 2025
63e4387
minor
WoosukKwon Nov 10, 2025
dd254ce
code owner
WoosukKwon Nov 10, 2025
f510b9e
code owner
WoosukKwon Nov 10, 2025
a505e71
minor
WoosukKwon Nov 11, 2025
fb0782c
minor
WoosukKwon Nov 11, 2025
645650c
remove filtering for negative token
WoosukKwon Nov 12, 2025
2326a8c
minor on cudagraph utils
WoosukKwon Nov 12, 2025
e284750
merge
WoosukKwon Nov 12, 2025
4085ce8
minor
WoosukKwon Nov 12, 2025
1d8a671
fix
WoosukKwon Nov 12, 2025
31580e9
merge
WoosukKwon Nov 13, 2025
a0c396b
merge
WoosukKwon Nov 15, 2025
6da659f
mypy
WoosukKwon Nov 15, 2025
ff9a1aa
readme
WoosukKwon Nov 15, 2025
197ed08
fix
WoosukKwon Nov 16, 2025
a72b07e
fix cudagraph
WoosukKwon Nov 16, 2025
a9b4fa3
revert
WoosukKwon Nov 16, 2025
3da2e77
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Nov 16, 2025
ee2c3b0
minor
WoosukKwon Nov 16, 2025
995f1aa
simplify get_kv_cache_spec
WoosukKwon Nov 16, 2025
ed84190
support mla
WoosukKwon Nov 16, 2025
5ea5e7e
merge
WoosukKwon Nov 17, 2025
784371c
preempt
WoosukKwon Nov 17, 2025
1402b93
Optimize gumbel sampling
WoosukKwon Nov 17, 2025
d8b8e65
Merge branch 'main' into woosuk/model-runner-v2
WoosukKwon Nov 17, 2025
3306f84
tmp
WoosukKwon Nov 18, 2025
015dd25
impl
WoosukKwon Nov 19, 2025
d799827
minor
WoosukKwon Nov 19, 2025
334c7d7
opt
WoosukKwon Nov 20, 2025
fe389d5
minor
WoosukKwon Nov 20, 2025
1154803
Merge branch 'main' into woosuk/tmp-v2
WoosukKwon Nov 21, 2025
c0612da
minor
WoosukKwon Nov 21, 2025
a8430a7
Merge branch 'main' into woosuk/tmp-v2
WoosukKwon Nov 21, 2025
32a0359
rm spec_decode
WoosukKwon Nov 21, 2025
540f456
fix
WoosukKwon Nov 21, 2025
e3c16d0
fix
WoosukKwon Nov 21, 2025
f9ac765
fix
WoosukKwon Nov 21, 2025
8422fe8
postprocess
WoosukKwon Nov 21, 2025
ebe5363
fix
WoosukKwon Nov 21, 2025
e2156fa
Merge branch 'main' into woosuk/tmp-v2
WoosukKwon Nov 21, 2025
6 changes: 4 additions & 2 deletions vllm/v1/worker/gpu/async_utils.py
@@ -18,7 +18,7 @@ def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput,
num_sampled_tokens: np.ndarray,
num_sampled_tokens: torch.Tensor,
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
@@ -52,6 +52,7 @@ def __init__(
)
else:
self.logprobs_tensors = None
self.num_sampled_tokens = num_sampled_tokens.to("cpu", non_blocking=True)
self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
if self.model_runner_output.prompt_logprobs_dict:
for k, v in self.model_runner_output.prompt_logprobs_dict.items():

def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
num_sampled_tokens_np = self.num_sampled_tokens.numpy()

# NOTE(woosuk): The following code is to ensure compatibility with
# the existing model runner.
@@ -71,7 +73,7 @@ def get_output(self) -> ModelRunnerOutput:
sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
num_reqs = len(sampled_token_ids)
for i in range(num_reqs):
del sampled_token_ids[i][self.num_sampled_tokens[i] :]
del sampled_token_ids[i][num_sampled_tokens_np[i] :]
self.model_runner_output.sampled_token_ids = sampled_token_ids

if self.logprobs_tensors is not None:
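
The num_sampled_tokens change above follows the same pattern as the rest of AsyncOutput: enqueue a non-blocking device-to-host copy, record copy_event, and read the NumPy view only after the event has been synchronized in get_output(). Below is a minimal stand-alone sketch of that pattern; the names, shapes, and the explicit wait_stream call are illustrative and not taken from this PR.

import torch

def snapshot_to_cpu(t: torch.Tensor, stream: torch.cuda.Stream, event: torch.cuda.Event) -> torch.Tensor:
    # Make the side stream wait for the producer, then enqueue the D2H copy
    # and mark its completion with an event.
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        cpu_t = t.to("cpu", non_blocking=True)
        event.record(stream)
    return cpu_t

if torch.cuda.is_available():
    num_sampled = torch.randint(1, 4, (8,), dtype=torch.int32, device="cuda")
    stream, event = torch.cuda.Stream(), torch.cuda.Event()
    cpu_copy = snapshot_to_cpu(num_sampled, stream, event)
    event.synchronize()           # mirrors AsyncOutput.get_output()
    counts = cpu_copy.numpy()     # safe to read only after the sync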
14 changes: 8 additions & 6 deletions vllm/v1/worker/gpu/attn_utils.py
@@ -3,6 +3,7 @@
from collections.abc import Sequence
from typing import Any, cast

import numpy as np
import torch

from vllm.attention.backends.abstract import AttentionBackend
@@ -145,18 +146,19 @@ def build_attn_metadata(
num_reqs: int,
num_tokens: int,
query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer,
num_computed_tokens_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_np: np.ndarray,
num_computed_tokens_cpu: torch.Tensor | None,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
seq_lens_gpu = seq_lens.gpu[:num_reqs]
seq_lens_cpu = seq_lens.cpu[:num_reqs]
max_seq_len = int(seq_lens.np[:num_reqs].max())
seq_lens = seq_lens[:num_reqs]
seq_lens_cpu = torch.from_numpy(seq_lens_np)
max_seq_len = int(seq_lens_np.max())

attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens_gpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
max_seq_len=max_seq_len,
num_computed_tokens_cpu=num_computed_tokens_cpu,
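
build_attn_metadata now receives the sequence lengths twice: a device tensor (seq_lens) for the GPU side and a NumPy array (seq_lens_np) from which the CPU tensor and max_seq_len are derived. The snippet below only illustrates why torch.from_numpy makes that derivation essentially free; the values are made up.

import numpy as np
import torch

seq_lens_np = np.array([17, 5, 92], dtype=np.int32)  # hypothetical lengths
seq_lens_cpu = torch.from_numpy(seq_lens_np)          # zero-copy view of the same buffer
max_seq_len = int(seq_lens_np.max())

seq_lens_np[1] = 6
assert seq_lens_cpu[1].item() == 6  # the tensor sees the update because memory is shared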
3 changes: 1 addition & 2 deletions vllm/v1/worker/gpu/block_table.py
@@ -3,10 +3,9 @@
from collections.abc import Iterable

import torch
import triton
import triton.language as tl

from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer

8 changes: 4 additions & 4 deletions vllm/v1/worker/gpu/cudagraph_utils.py
@@ -97,14 +97,13 @@

# Prepare dummy inputs.
input_ids = input_buffers.input_ids.gpu[:batch_size]
positions = input_buffers.positions.gpu[:batch_size]
positions = input_buffers.positions[:batch_size]

input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
input_buffers.query_start_loc.np[batch_size:] = batch_size
input_buffers.query_start_loc.copy_to_gpu()
input_buffers.seq_lens.np[:batch_size] = self.max_model_len
input_buffers.seq_lens.np[batch_size:] = 0
input_buffers.seq_lens.copy_to_gpu()
input_buffers.seq_lens[:batch_size] = self.max_model_len
input_buffers.seq_lens[batch_size:] = 0

input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :batch_size]
num_tokens=batch_size,
query_start_loc=input_buffers.query_start_loc,
seq_lens=input_buffers.seq_lens,
seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
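
capture_graph now writes the dummy seq_lens directly into the GPU buffer and hands a matching seq_lens_np to build_attn_metadata. The padding rules (non-decreasing query_start_loc past the batch, zeroed unused seq_lens) are easiest to see with concrete numbers; the sizes below are hypothetical.

import numpy as np

batch_size, max_num_reqs, max_model_len = 4, 8, 4096
query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)
query_start_loc[: batch_size + 1] = np.arange(batch_size + 1)
query_start_loc[batch_size:] = batch_size   # keep it non-decreasing past the batch
seq_lens = np.zeros(max_num_reqs, dtype=np.int32)
seq_lens[:batch_size] = max_model_len
seq_lens[batch_size:] = 0                   # unused slots stay 0 for full CUDA graphs
print(query_start_loc.tolist())  # [0, 1, 2, 3, 4, 4, 4, 4, 4]
print(seq_lens.tolist())         # [4096, 4096, 4096, 4096, 0, 0, 0, 0]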
177 changes: 128 additions & 49 deletions vllm/v1/worker/gpu/input_batch.py
@@ -7,9 +7,8 @@
import numba.types as types
import numpy as np
import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton
from vllm.utils import random_uuid
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer

self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)

# Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
@@ -108,13 +107,15 @@ def make_dummy(
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
# seq_len equals to query_len
input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens
input_buffers.seq_lens.np[num_reqs:] = 0
seq_lens_np = input_buffers.seq_lens.np[:num_reqs]
seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs]
seq_lens_np = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
seq_lens_np[-1] += num_tokens % num_reqs
input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
input_buffers.seq_lens[num_reqs:] = 0
seq_lens = input_buffers.seq_lens[:num_reqs]

input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
positions = input_buffers.positions.copy_to_gpu(num_tokens)
positions = input_buffers.positions[:num_tokens]
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
return cls(
@@ -142,27 +143,25 @@ def make_dummy(
[
types.none(
types.int32[:], # idx_mapping
types.int32[:, :], # token_ids
types.int32[:], # num_computed_tokens
types.int32[:], # num_scheduled_tokens
types.int32[:, :], # prefill_token_ids
types.int32[:], # num_computed_prefill_tokens
types.int32[:], # prefill_len
types.int32[:], # input_ids
types.int64[:], # positions
types.int32[:], # query_start_loc
types.int32[:], # seq_lens
)
],
nopython=True,
cache=True,
)
def _prepare_inputs(
def _prepare_prefill_inputs(
idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
prefill_token_ids: np.ndarray, # [N, max_model_len]
num_computed_prefill_tokens: np.ndarray, # [N]
prefill_len: np.ndarray, # [N]
input_ids: np.ndarray, # [num_input_tokens]
positions: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
) -> None:
num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0
for i in range(num_reqs):
req_idx = idx_mapping[i]
query_len = num_scheduled_tokens[i]
start = num_computed_tokens[req_idx]
end = start + query_len
seq_lens[i] = end

start = num_computed_prefill_tokens[req_idx]
end = min(start + query_len, prefill_len[req_idx])
n = end - start

start_idx = cu_num_tokens
end_idx = start_idx + query_len
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
input_ids[start_idx : start_idx + n] = prefill_token_ids[req_idx, start:end]

cu_num_tokens = end_idx
cu_num_tokens = start_idx + query_len
query_start_loc[i + 1] = cu_num_tokens

# Pad the inputs for CUDA graphs.
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
# Fill unused with 0 for full cuda graph mode.
seq_lens[num_reqs:].fill(0)


def prepare_inputs(
def prepare_prefill_inputs(
idx_mapping: np.ndarray,
prefill_token_ids: np.ndarray,
num_computed_tokens: np.ndarray,
num_scheduled_tokens: np.ndarray,
total_num_tokens: int,
prefill_token_ids: np.ndarray,
num_computed_prefill_tokens: np.ndarray,
prefill_len: np.ndarray,
input_ids: CpuGpuBuffer,
positions: CpuGpuBuffer,
query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer,
num_tokens: int,
) -> None:
_prepare_inputs(
_prepare_prefill_inputs(
idx_mapping,
prefill_token_ids,
num_computed_tokens,
num_scheduled_tokens,
prefill_token_ids,
num_computed_prefill_tokens,
prefill_len,
input_ids.np,
positions.np,
query_start_loc.np,
seq_lens.np,
)
input_ids.copy_to_gpu(num_tokens)
positions.copy_to_gpu(num_tokens)
input_ids.copy_to_gpu(total_num_tokens)
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
# tensors from CPU to GPU, because they may include paddings needed
# for full CUDA graph mode.
query_start_loc.copy_to_gpu()
seq_lens.copy_to_gpu()
return


@triton.jit
def _combine_last_token_ids_kernel(
def _prepare_pos_seq_lens_kernel(
pos_ptr,
seq_lens_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
num_computed_tokens_ptr,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
req_id = tl.program_id(0)
num_reqs = tl.num_programs(0) - 1
if req_id == num_reqs:
# Pad unused seq_lens as 0 for full CUDA graphs.
for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(seq_lens_ptr + block, 0, mask=mask)
return

req_state_idx = tl.load(idx_mapping_ptr + req_id)
num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)

start = tl.load(query_start_loc_ptr + req_id)
end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = end - start

seq_len = num_computed_tokens + query_len
tl.store(seq_lens_ptr + req_id, seq_len)

for i in tl.range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
pos = num_computed_tokens + block
tl.store(pos_ptr + start + block, pos, mask=mask)


def prepare_pos_seq_lens(
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
num_computed_tokens: torch.Tensor,
pos: torch.Tensor,
seq_lens: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
# NOTE(woosuk): We do +1 because the last thread block is used
# to pad unused seq_lens as 0 for full CUDA graphs.
_prepare_pos_seq_lens_kernel[(num_reqs + 1,)](
pos,
seq_lens,
idx_mapping,
query_start_loc,
num_computed_tokens,
seq_lens.shape[0],
BLOCK_SIZE=1024,
)
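
For readers comparing against the Triton kernel, the following plain-PyTorch reference (a sketch for testing, not code from this PR) computes the same positions and seq_lens, including the zero padding of unused entries.

import torch

def prepare_pos_seq_lens_ref(
    idx_mapping: torch.Tensor,          # [num_reqs], batch index -> request state index
    query_start_loc: torch.Tensor,      # [num_reqs + 1]
    num_computed_tokens: torch.Tensor,  # [max_num_reqs]
    pos: torch.Tensor,                  # [max_num_tokens]
    seq_lens: torch.Tensor,             # [max_num_reqs]
) -> None:
    num_reqs = idx_mapping.shape[0]
    seq_lens[num_reqs:] = 0  # pad unused entries as 0 for full CUDA graphs
    for i in range(num_reqs):
        start = int(query_start_loc[i])
        end = int(query_start_loc[i + 1])
        computed = int(num_computed_tokens[int(idx_mapping[i])])
        seq_lens[i] = computed + (end - start)
        pos[start:end] = torch.arange(computed, computed + (end - start), dtype=pos.dtype)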


@triton.jit
def _combine_sampled_and_draft_tokens_kernel(
input_ids_ptr,
idx_mapping_ptr,
last_token_ids_ptr,
last_sampled_tokens_ptr,
query_start_loc_ptr,
seq_lens_ptr,
prefill_len_ptr,
# Handling prefill tokens.
return

last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
end = tl.load(query_start_loc_ptr + batch_idx + 1)
tl.store(input_ids_ptr + end - 1, last_token_id)


def combine_last_token_ids(
def combine_sampled_and_draft_tokens(
input_ids: torch.Tensor,
idx_mapping: torch.Tensor,
last_token_ids: torch.Tensor,
last_sampled_tokens: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
prefill_len: torch.Tensor,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
_combine_last_token_ids_kernel[(num_reqs,)](
_combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
input_ids,
idx_mapping,
last_token_ids,
last_sampled_tokens,
query_start_loc,
seq_lens,
prefill_len,
)
return input_ids


@triton.jit
def _update_num_computed_tokens_kernel(
idx_mapping_ptr,
num_computed_tokens_ptr,
query_start_loc_ptr,
):
req_id = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_id)

start = tl.load(query_start_loc_ptr + req_id)
end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = end - start

n = tl.load(num_computed_tokens_ptr + req_state_idx)
tl.store(num_computed_tokens_ptr + req_state_idx, n + query_len)


def update_num_computed_tokens(
idx_mapping: torch.Tensor,
num_computed_tokens: torch.Tensor,
query_start_loc: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_update_num_computed_tokens_kernel[(num_reqs,)](
idx_mapping,
num_computed_tokens,
query_start_loc,
)
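
The last kernel simply advances each active request's computed-token counter by its query length. An equivalent, purely illustrative PyTorch version with made-up values:

import torch

idx_mapping = torch.tensor([3, 0, 5])                        # batch slot -> request state index
query_start_loc = torch.tensor([0, 4, 6, 11], dtype=torch.int32)
num_computed_tokens = torch.zeros(8, dtype=torch.int32)

query_lens = query_start_loc[1:] - query_start_loc[:-1]      # per-request query lengths
num_computed_tokens.index_add_(0, idx_mapping, query_lens)    # scatter-add into request slots
print(num_computed_tokens.tolist())  # [2, 0, 0, 4, 0, 5, 0, 0]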