Commit 0edaf75

[Attention][DBO] Add support for "splitting" the CommonAttentionMetadata (#21153)
Signed-off-by: Sage Moore <[email protected]>
1 parent 6e8d8c4 commit 0edaf75

2 files changed: +240 −0
Lines changed: 157 additions & 0 deletions

@@ -0,0 +1,157 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UbatchSlice,
                                              _make_metadata_with_slice,
                                              slice_query_start_locs,
                                              split_attn_metadata)


@pytest.fixture
def sample_query_start_loc():
    """Sample query_start_loc tensor for testing"""
    return torch.tensor([0, 5, 12, 20, 35, 50])


def test_basic_slice_middle(sample_query_start_loc):
    """Test slicing from the middle of the tensor"""
    req_slice = slice(1, 3)  # requests 1 and 2
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 7, 15])
    assert torch.equal(result, expected)


def test_slice_from_beginning(sample_query_start_loc):
    """Test slicing from the beginning of the tensor"""
    req_slice = slice(0, 2)  # requests 0 and 1
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 5, 12])
    assert torch.equal(result, expected)


def test_slice_to_end(sample_query_start_loc):
    """Test slicing to the end of the tensor"""
    req_slice = slice(3, 5)  # requests 3 and 4 (the last request)
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 15, 30])
    assert torch.equal(result, expected)


def test_single_element_slice(sample_query_start_loc):
    """Test a slice that selects a single request"""
    req_slice = slice(2, 3)  # request 2 only
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 8])
    assert torch.equal(result, expected)


def test_full_tensor_slice(sample_query_start_loc):
    """Test slicing the entire tensor"""
    req_slice = slice(0, 5)  # all requests
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 5, 12, 20, 35, 50])
    assert torch.equal(result, expected)


def test_slice_bounds_edge_cases(sample_query_start_loc):
    # Test a slice that goes exactly to the last element
    req_slice = slice(4, 5)  # last request only
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 15])
    assert torch.equal(result, expected)


@pytest.fixture
def small_decode_metadata():
    """Create metadata for a small decode batch"""
    batch_spec = BATCH_SPECS["small_decode"]
    device = torch.device("cpu")
    return create_common_attn_metadata(batch_spec,
                                       block_size=16,
                                       device=device)


@pytest.fixture
def large_decode_metadata():
    """Create metadata for a large decode batch"""
    batch_spec = BATCH_SPECS["large_decode"]
    device = torch.device("cpu")
    return create_common_attn_metadata(batch_spec,
                                       block_size=16,
                                       device=device)


@pytest.fixture
def mixed_small_metadata():
    """Create metadata for a mixed small batch"""
    batch_spec = BATCH_SPECS["mixed_small"]
    device = torch.device("cpu")
    return create_common_attn_metadata(batch_spec,
                                       block_size=16,
                                       device=device)


# Tests for _make_metadata_with_slice
def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
    """Test slicing decode batch metadata"""
    # Split off the first request only
    ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1))

    result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata)

    # Check sliced results
    assert result.num_reqs == 1  # slice(0, 1) gives 1 request
    assert result.num_actual_tokens == 1  # slice(0, 1) gives 1 token
    assert result.max_query_len == 1
    assert torch.equal(result.query_start_loc, torch.tensor([0, 1]))
    assert torch.equal(result.seq_lens, torch.tensor([32]))


def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
    """Test slicing mixed batch metadata"""
    ubatch_slice = UbatchSlice(slice(1, 3),
                               slice(1, 7))  # requests 1-2, tokens 1-6

    result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)

    assert result.num_reqs == 2  # slice(1, 3) gives 2 requests
    assert result.num_actual_tokens == 6  # slice(1, 7) gives 6 tokens
    assert result.max_query_len == 5
    assert torch.equal(result.query_start_loc, torch.tensor([0, 1, 6]))
    assert torch.equal(result.seq_lens, torch.tensor([40, 48]))


def test_split_attn_metadata_decode_batch(large_decode_metadata):
    """Test splitting a decode batch into two equal parts"""
    # In a pure decode batch each request contributes exactly one token,
    # so request indices and token indices coincide.
    num_tokens = large_decode_metadata.num_reqs
    mid_point = num_tokens // 2
    ubatch_slices = [
        UbatchSlice(slice(0, mid_point), slice(0, mid_point)),
        UbatchSlice(slice(mid_point, num_tokens), slice(mid_point,
                                                        num_tokens)),
    ]

    results = split_attn_metadata(ubatch_slices, large_decode_metadata)

    assert len(results) == 2

    # Check the first split
    assert results[0].num_reqs == mid_point
    assert results[0].num_actual_tokens == mid_point
    assert torch.equal(results[0].seq_lens, torch.tensor([2048] * mid_point))

    # Check the second split
    assert results[1].num_reqs == mid_point
    assert results[1].num_actual_tokens == mid_point
    assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
```
vllm/v1/attention/backends/utils.py

Lines changed: 83 additions & 0 deletions
@@ -63,6 +63,89 @@ class CommonAttentionMetadata:

```python
# (unchanged context: the CommonAttentionMetadata fields above end with
#  `causal: bool = True`)


@dataclass
class UbatchSlice:
    request_slice: slice
    token_slice: slice
```
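
The two slices index different spaces: `request_slice` selects rows of per-request tensors such as `seq_lens` and `block_table_tensor`, while `token_slice` selects positions in flattened per-token tensors such as `slot_mapping`. For a contiguous request range the matching token range can be read off `query_start_loc`; the sketch below shows one way to derive it (`make_ubatch_slice` is a hypothetical helper, not part of this commit):

```python
import torch

def make_ubatch_slice(query_start_loc_cpu: torch.Tensor, start_req: int,
                      stop_req: int) -> UbatchSlice:
    # query_start_loc_cpu[i] is the flattened-token offset where request i
    # begins, so a contiguous request range maps directly to a token range.
    token_start = int(query_start_loc_cpu[start_req])
    token_stop = int(query_start_loc_cpu[stop_req])
    return UbatchSlice(slice(start_req, stop_req),
                       slice(token_start, token_stop))
```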
```python
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Creates a new query_start_loc that corresponds to the requests in
    request_slice.

    Note: This function creates a new tensor to hold the new query_start_locs.
    This will break cudagraph compatibility.
    """
    return query_start_loc[request_slice.start:request_slice.stop + 1] - \
        query_start_loc[request_slice.start]
```
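
Two invariants of the result, implied by the tests above: slicing `n` requests yields `n + 1` boundaries starting at 0, and the last boundary equals the number of tokens those requests contribute. A minimal sanity-check sketch (illustrative only, not part of this commit):

```python
def check_sliced_qsl(query_start_loc: torch.Tensor,
                     request_slice: slice) -> None:
    sliced = slice_query_start_locs(query_start_loc, request_slice)
    num_requests = request_slice.stop - request_slice.start
    assert sliced.shape[0] == num_requests + 1
    assert sliced[0].item() == 0
    # Last boundary == total tokens contributed by the sliced requests.
    assert sliced[-1].item() == (query_start_loc[request_slice.stop] -
                                 query_start_loc[request_slice.start]).item()
```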
```python
def _make_metadata_with_slice(
        ubatch_slice: UbatchSlice,
        attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
    """
    This function creates a new CommonAttentionMetadata that corresponds to
    the requests included in ubatch_slice.
    """

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
                                             request_slice)
    # A slice must contain at least one request, i.e. two boundaries.
    assert len(query_start_loc) >= 2
    query_start_loc_cpu = slice_query_start_locs(
        attn_metadata.query_start_loc_cpu, request_slice)

    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
        request_slice]

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] -
                            query_start_loc_cpu[:-1])).item())

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
    )
```
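
To make the bookkeeping concrete, here is a hedged CPU-only sketch that constructs a tiny three-request batch directly (assuming `CommonAttentionMetadata` takes exactly the fields passed in the constructor call above, with `causal` left at its default; all literal values are illustrative) and slices off the first two requests:

```python
import torch

# Three requests with query lengths 2, 3, and 1 -> 6 flattened tokens.
qsl = torch.tensor([0, 2, 5, 6], dtype=torch.int32)
meta = CommonAttentionMetadata(
    query_start_loc=qsl,
    query_start_loc_cpu=qsl,
    seq_lens=torch.tensor([10, 12, 8]),
    seq_lens_cpu=torch.tensor([10, 12, 8]),
    num_computed_tokens_cpu=torch.tensor([8, 9, 7]),
    num_reqs=3,
    num_actual_tokens=6,
    max_query_len=3,
    block_table_tensor=torch.zeros((3, 4), dtype=torch.int32),
    slot_mapping=torch.arange(6, dtype=torch.int64),
)

# Requests 0-1 own tokens 0-4, since qsl[2] == 5.
first = _make_metadata_with_slice(UbatchSlice(slice(0, 2), slice(0, 5)), meta)
assert first.num_reqs == 2 and first.num_actual_tokens == 5
assert torch.equal(first.query_start_loc,
                   torch.tensor([0, 2, 5], dtype=torch.int32))
assert first.max_query_len == 3  # request 1 has the longest query
```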
```python
def split_attn_metadata(
    ubatch_slices: list[UbatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Creates a new CommonAttentionMetadata instance for each UbatchSlice in
    ubatch_slices, covering only the requests in that slice.

    Note: This function does not modify common_attn_metadata.
    """
    results = []
    for ubatch_slice in ubatch_slices:
        results.append(
            _make_metadata_with_slice(ubatch_slice, common_attn_metadata))
    return results


M = TypeVar("M")  # existing code below the insertion (unchanged context)
```
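
Finally, a hedged end-to-end sketch of the DBO-style use this commit enables: split a batch at a request boundary into two micro-batches and verify that no requests or tokens are lost. It reuses the illustrative `meta` object from the previous sketch and derives the token boundary from `query_start_loc_cpu`:

```python
split_req = 1  # request 0 in the first micro-batch, requests 1-2 in the second
token_boundary = int(meta.query_start_loc_cpu[split_req])  # == 2

ubatch_slices = [
    UbatchSlice(slice(0, split_req), slice(0, token_boundary)),
    UbatchSlice(slice(split_req, meta.num_reqs),
                slice(token_boundary, meta.num_actual_tokens)),
]
first, second = split_attn_metadata(ubatch_slices, meta)

# Conservation checks: the two micro-batches partition the original batch.
assert first.num_reqs + second.num_reqs == meta.num_reqs
assert (first.num_actual_tokens + second.num_actual_tokens ==
        meta.num_actual_tokens)
```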