
Commit 77af974

[FIX] Support non-zero CUDA devices in custom kernels (#1959)
1 parent 4934d49 commit 77af974

12 files changed (+74, -30 lines)
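Every kernel launcher touched below follows the same pattern: construct a `c10::cuda::OptionalCUDAGuard` scoped to the device of one of the input tensors before querying the current stream and launching, so the kernel runs on that tensor's device instead of whatever device happens to be current (usually `cuda:0`). A minimal sketch of the pattern under assumptions — `my_kernel` and `my_op` are placeholder names, not code from this commit:

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Placeholder element-wise kernel, used only to illustrate the guard pattern.
__global__ void my_kernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * in[i];
}

void my_op(torch::Tensor& out, torch::Tensor& in) {
  int n = in.numel();
  dim3 grid((n + 255) / 256);
  dim3 block(256);
  // Make in's device the current device for this scope (a no-op if it already
  // is); getCurrentCUDAStream() then returns a stream on that device, so the
  // launch also works for tensors on cuda:1, cuda:2, ...
  const at::cuda::OptionalCUDAGuard device_guard(device_of(in));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  my_kernel<<<grid, block, 0, stream>>>(
      in.data_ptr<float>(), out.data_ptr<float>(), n);
}

The guard is constructed before getCurrentCUDAStream() for a reason: the stream query returns the stream of the current device, so the order matters.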

csrc/activation_kernels.cu

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,6 @@
-#include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
@@ -36,6 +37,7 @@ void silu_and_mul(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
@@ -71,6 +73,7 @@ __global__ void activation_kernel(
   int64_t num_tokens = input.numel() / d; \
   dim3 grid(num_tokens); \
   dim3 block(std::min(d, 1024)); \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
   VLLM_DISPATCH_FLOATING_TYPES( \
     input.scalar_type(), \

csrc/attention/attention_kernels.cu

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,7 @@
 
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
@@ -616,6 +617,7 @@ void paged_attention_v1_launcher(
 
   dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
@@ -784,6 +786,7 @@ void paged_attention_v2_launcher(
   int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
 
   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE(woosuk): To reduce the compilation time, we only compile for the

csrc/cache_kernels.cu

Lines changed: 5 additions & 0 deletions
@@ -1,5 +1,6 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
@@ -33,6 +34,7 @@ void swap_blocks(
   char *dst_ptr = static_cast<char*>(dst.data_ptr());
 
   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const at::cuda::OptionalCUDAGuard device_guard(src_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE(woosuk): This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {
@@ -127,6 +129,7 @@ void copy_blocks(
   const int numel_per_block = key_caches[0][0].numel();
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
@@ -207,6 +210,7 @@ void reshape_and_cache(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
@@ -367,6 +371,7 @@ void gather_cached_kv(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),

csrc/layernorm_kernels.cu

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,6 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "dispatch_utils.h"
 #include "reduction_utils.cuh"
@@ -76,6 +77,7 @@ void rms_norm(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
@@ -101,6 +103,7 @@ void fused_add_rms_norm(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),

csrc/pos_encoding_kernels.cu

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
@@ -94,6 +95,7 @@ void rotary_embedding(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),

csrc/quantization/squeezellm/quant_cuda_kernel.cu

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,7 @@
 // half-tensor
 #include <c10/cuda/CUDAStream.h>
 #include <ATen/cuda/CUDATensorMethods.cuh>
+#include <c10/cuda/CUDAGuard.h>
 
 #define BLOCKWIDTH 128
 #define BLOCKHEIGHT4 16
@@ -199,7 +200,7 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
-
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
   vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
 #ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),

tests/kernels/conftest.py

Lines changed: 3 additions & 2 deletions
@@ -12,6 +12,7 @@ def create_kv_caches(
     head_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: str,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -23,7 +24,7 @@
     for _ in range(num_layers):
         key_cache = torch.empty(size=key_cache_shape,
                                 dtype=dtype,
-                                device='cuda')
+                                device=device)
         key_cache.uniform_(-scale, scale)
         key_caches.append(key_cache)
 
@@ -32,7 +33,7 @@
     for _ in range(num_layers):
         value_cache = torch.empty(size=value_cache_shape,
                                   dtype=dtype,
-                                  device='cuda')
+                                  device=device)
         value_cache.uniform_(-scale, scale)
         value_caches.append(value_cache)
     return key_caches, value_caches
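With this change the KV-cache factory places the caches on whichever device the caller names; the device string is the new last positional argument (the attention tests below pass `gpu_id`). A hedged usage sketch — the first five argument names are not shown in this hunk and are inferred from the `kv_cache_factory(...)` call in test_attention.py, and the values are illustrative only:

import torch

# Assumes create_kv_caches from this conftest is in scope.
key_caches, value_caches = create_kv_caches(
    128, 16, 1, 8, 64,        # num_blocks, block_size, num_layers, num_heads, head_size (assumed order)
    torch.half, 0, "cuda:1")  # dtype, seed, device
assert all(kc.device == torch.device("cuda:1") for kc in key_caches)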

tests/kernels/test_activation.py

Lines changed: 13 additions & 3 deletions
@@ -7,22 +7,26 @@
 NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
 D = [512, 4096, 5120, 13824] # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
     layer = SiluAndMul()
     out = layer(x)
     ref_out = layer._forward(x)
@@ -33,16 +37,19 @@ def test_silu_and_mul(
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_gelu_new(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
     layer = NewGELU()
     out = layer(x)
     ref_out = layer._forward(x)
@@ -53,15 +60,18 @@ def test_gelu_new(
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 def test_gelu_fast(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
     layer = FastGELU()
     out = layer(x)
     ref_out = layer._forward(x)

tests/kernels/test_attention.py

Lines changed: 15 additions & 10 deletions
@@ -24,6 +24,7 @@
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
 
 def ref_masked_attention(
@@ -87,7 +88,7 @@ def ref_single_query_cached_kv_attention(
         alibi_bias = None
         if alibi_slopes is not None:
             # Create the ALiBi bias used in the paged attention kernel.
-            position_ids = torch.arange(context_len, device="cuda").int()
+            position_ids = torch.arange(context_len, device=query.device).int()
             alibi_bias = (position_ids - context_len + 1).float()
             alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                 1, 1, -1)
@@ -105,6 +106,7 @@ ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 def test_paged_attention(
     kv_cache_factory,
     version: str,
@@ -115,18 +117,19 @@ def test_paged_attention(
     block_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
     query = torch.empty(num_seqs,
                         num_query_heads,
                         head_size,
                         dtype=dtype,
-                        device="cuda")
+                        device=gpu_id)
     query.uniform_(-scale, scale)
 
     assert num_query_heads % num_kv_heads == 0
@@ -135,12 +138,12 @@
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
                                    dtype=torch.float,
-                                   device="cuda")
+                                   device=gpu_id)
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
     context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
+    context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)
 
     # Create the block tables.
     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -151,12 +154,12 @@
             for _ in range(max_num_blocks_per_seq)
         ]
         block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                 num_kv_heads, head_size, dtype,
-                                                seed)
+                                                seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Call the paged attention kernel.
@@ -249,7 +252,7 @@ def ref_multi_query_kv_attention(
         attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                diagonal=1)
         attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype, device="cuda")
+        attn_mask = attn_mask.to(dtype=dtype, device=query.device)
 
         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
@@ -269,18 +272,20 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
     num_heads: Tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
     # As the xformers library is already tested with its own tests, we can use
     # a smaller MAX_SEQ_LEN here.
@@ -294,7 +299,7 @@ def test_multi_query_kv_attention(
                       num_query_heads + 2 * num_kv_heads,
                       head_size,
                       dtype=dtype,
-                      device="cuda")
+                      device=gpu_id)
     qkv.uniform_(-scale, scale)
     query, key, value = qkv.split(
         [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
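Outside pytest, a quick way to confirm the fix is to call one of the custom ops directly on a tensor that lives on a non-default GPU. A hedged sketch, assuming at least two GPUs and that the compiled extension is importable as `vllm.activation_ops` with the `silu_and_mul(out, x)` binding from csrc above:

import torch
from vllm import activation_ops  # assumed import path for the compiled extension

x = torch.randn(16, 2 * 512, dtype=torch.half, device="cuda:1")
out = torch.empty(16, 512, dtype=torch.half, device="cuda:1")
activation_ops.silu_and_mul(out, x)  # before this fix, the launch targeted cuda:0's stream
torch.cuda.synchronize(device="cuda:1")
print(out.float().abs().mean().item())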
