Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,9 @@ def __init__(
self.blocks_per_seq = blocks_per_seq
self.max_batch_size = int(max_batch_size)

bt = torch.full((self.max_batch_size, self.blocks_per_seq), -1, dtype=torch.int32, device=device)
for i in range(self.max_batch_size):
for j in range(self.blocks_per_seq):
bt[i, j] = i * self.blocks_per_seq + j
self._block_table = bt
self._block_table = torch.arange(self.max_batch_size, dtype=torch.int32, device=device).unsqueeze(
1
) * self.blocks_per_seq + torch.arange(self.blocks_per_seq, dtype=torch.int32, device=device).unsqueeze(0)

def build_attn_metadata(
self,
Expand Down Expand Up @@ -119,28 +117,26 @@ def build_attn_metadata(
return {}, torch.empty((0,), dtype=torch.int64, device=self.device), {}

# positions: for each request i, emit positions [seq_len-query_len .. seq_len-1]
pos_list: list[torch.Tensor] = []
for i in range(num_reqs):
ql = int(query_lens_i32[i].item())
sl = int(seq_lens_i32[i].item())
start = sl - ql
pos_list.append(torch.arange(start, sl, dtype=torch.int64))
positions_cpu = torch.cat(pos_list, dim=0)
# Vectorised: build a flat offset range, then add per-request start positions.
starts = (seq_lens_i32 - query_lens_i32)[:num_reqs] # [num_reqs]
max_ql = int(query_lens_i32[:num_reqs].max().item())
offsets = torch.arange(max_ql, dtype=torch.int64) # [max_ql]
# [num_reqs, max_ql] grid of absolute positions
pos_grid = starts.to(torch.int64).unsqueeze(1) + offsets.unsqueeze(0)
# Mask to valid tokens (offset < query_len for each request)
ql_expanded = query_lens_i32[:num_reqs].to(torch.int64).unsqueeze(1)
mask = offsets.unsqueeze(0) < ql_expanded # [num_reqs, max_ql]
positions_cpu = pos_grid[mask] # [num_tokens]

# slot_mapping: map each query token to a physical slot in the paged KV cache.
# We allocate per-request contiguous blocks; slot = base + position.
slot_mapping = torch.empty((num_tokens,), dtype=torch.int64, device="cpu")
cursor = 0
for i in range(num_reqs):
ql = int(query_lens_i32[i].item())
sl = int(seq_lens_i32[i].item())
start = sl - ql
for p in range(start, sl):
block_idx = p // self.block_size
offset = p % self.block_size
block_id = int(self._block_table[i, block_idx].item())
slot_mapping[cursor] = block_id * self.block_size + offset
cursor += 1
# Vectorised: compute block_id from block_table and derive slots.
req_ids = torch.repeat_interleave(
torch.arange(num_reqs, dtype=torch.int64), query_lens_i32[:num_reqs].to(torch.int64)
)
block_indices = (positions_cpu // self.block_size).to(torch.int64)
in_block_offsets = positions_cpu % self.block_size
block_ids = self._block_table.cpu()[req_ids, block_indices].to(torch.int64)
slot_mapping = block_ids * self.block_size + in_block_offsets

max_seq_len = int(seq_lens_i32[:num_reqs].max().item())
query_start_loc_gpu = qsl.to(device=self.device)
Expand Down