Skip to content

Commit b29375c

Browse files
committed
Reduce torch #include
1 parent e68c379 commit b29375c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

csrc/selective_scan/selective_scan.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
* Copyright (c) 2023, Tri Dao.
33
******************************************************************************/
44

5-
#include <ATen/cuda/CUDAContext.h>
65
#include <c10/cuda/CUDAGuard.h>
7-
#include <torch/extension.h>
6+
#include <c10/cuda/CUDAStream.h>
7+
#include <torch/python.h>
88
#include <vector>
99

1010
#include "selective_scan.h"
@@ -323,7 +323,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
323323

324324
// Otherwise the kernel will be launched from cuda:0 device
325325
// Cast to char to avoid compiler warning about narrowing (stale after this change: the guard now takes u.device() directly, so no cast is involved)
326-
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
326+
at::cuda::CUDAGuard device_guard{u.device()};
327327
auto stream = at::cuda::getCurrentCUDAStream().stream();
328328
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
329329
DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
@@ -478,7 +478,7 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
478478

479479
// Otherwise the kernel will be launched from cuda:0 device
480480
// Cast to char to avoid compiler warning about narrowing (stale after this change: the guard now takes u.device() directly, so no cast is involved)
481-
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
481+
at::cuda::CUDAGuard device_guard{u.device()};
482482
auto stream = at::cuda::getCurrentCUDAStream().stream();
483483
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
484484
DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {

0 commit comments

Comments
 (0)