 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
 
 def ref_masked_attention(
@@ -87,7 +88,7 @@ def ref_single_query_cached_kv_attention(
         alibi_bias = None
         if alibi_slopes is not None:
             # Create the ALiBi bias used in the paged attention kernel.
-            position_ids = torch.arange(context_len, device="cuda").int()
+            position_ids = torch.arange(context_len, device=query.device).int()
             alibi_bias = (position_ids - context_len + 1).float()
             alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                 1, 1, -1)
@@ -105,6 +106,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 def test_paged_attention(
     kv_cache_factory,
     version: str,
@@ -115,18 +117,19 @@ def test_paged_attention(
     block_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
     query = torch.empty(num_seqs,
                         num_query_heads,
                         head_size,
                         dtype=dtype,
-                        device="cuda")
+                        device=gpu_id)
     query.uniform_(-scale, scale)
 
     assert num_query_heads % num_kv_heads == 0
@@ -135,12 +138,12 @@ def test_paged_attention(
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
                                    dtype=torch.float,
-                                   device="cuda")
+                                   device=gpu_id)
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
     context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
+    context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)
 
     # Create the block tables.
     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -151,12 +154,12 @@ def test_paged_attention(
             for _ in range(max_num_blocks_per_seq)
         ]
         block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                 num_kv_heads, head_size, dtype,
-                                                seed)
+                                                seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Call the paged attention kernel.
@@ -249,7 +252,7 @@ def ref_multi_query_kv_attention(
         attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                diagonal=1)
         attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype, device="cuda")
+        attn_mask = attn_mask.to(dtype=dtype, device=query.device)
 
         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
@@ -269,18 +272,20 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
     num_heads: Tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
     # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
@@ -294,7 +299,7 @@ def test_multi_query_kv_attention(
                       num_query_heads + 2 * num_kv_heads,
                       head_size,
                       dtype=dtype,
-                      device="cuda")
+                      device=gpu_id)
     qkv.uniform_(-scale, scale)
     query, key, value = qkv.split(
         [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
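A minimal sketch, assuming pytest and at least one available CUDA GPU, of the device-parametrization pattern this diff introduces: DEVICES enumerates GPU indices (one entry on single-GPU hosts, two otherwise), each test case receives one index, and a per-test "cuda:<index>" string is used for allocations so the code under test runs on that GPU. The test name and tensor shape below are illustrative, not part of the PR.

import pytest
import torch

# Mirrors the DEVICES definition added in the diff above.
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_allocation_follows_device(device: int) -> None:
    # Build the per-test device string, as the updated tests do.
    gpu_id = f"cuda:{device}"
    x = torch.empty(4, 8, device=gpu_id)
    # Tensors created with an explicit device string land on that GPU, so
    # downstream code can use x.device instead of a hard-coded "cuda".
    assert x.device == torch.device(gpu_id)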