|
2 | 2 | * Copyright (c) 2023, Tri Dao. |
3 | 3 | ******************************************************************************/ |
4 | 4 |
|
5 | | -#include <ATen/cuda/CUDAContext.h> |
6 | 5 | #include <c10/cuda/CUDAGuard.h> |
7 | | -#include <torch/extension.h> |
| 6 | +#include <c10/cuda/CUDAStream.h> |
| 7 | +#include <torch/python.h> |
8 | 8 | #include <vector> |
9 | 9 |
|
10 | 10 | #include "selective_scan.h" |
@@ -323,7 +323,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, |
323 | 323 |
|
324 | 324 | // Otherwise the kernel will be launched from cuda:0 device |
325 | 325 | // Cast to char to avoid compiler warning about narrowing |
326 | | - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; |
| 326 | + at::cuda::CUDAGuard device_guard{u.device()}; |
327 | 327 | auto stream = at::cuda::getCurrentCUDAStream().stream(); |
328 | 328 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { |
329 | 329 | DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] { |
@@ -478,7 +478,7 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, |
478 | 478 |
|
479 | 479 | // Otherwise the kernel will be launched from cuda:0 device |
480 | 480 | // Cast to char to avoid compiler warning about narrowing |
481 | | - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; |
| 481 | + at::cuda::CUDAGuard device_guard{u.device()}; |
482 | 482 | auto stream = at::cuda::getCurrentCUDAStream().stream(); |
483 | 483 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { |
484 | 484 | DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] { |
|
0 commit comments