#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>
7+
// Publishes the current sequence value into the flag: flag[0] <- seq[0].
//
// Launch: <<<1, 1>>> on the producer stream. Both pointers must be
// device-accessible (device memory, or mapped pinned host memory).
// A host/CPU-side spinner polling `flag` can use this to observe that the
// stream has reached this point.
__global__ void write_flag_kernel(int64_t* flag, int64_t* seq) {
    if (threadIdx.x == 0) {
        // Volatile store so the write is actually issued to memory and
        // mirrors the volatile polling reads in wait_flag_kernel.
        volatile int64_t* flag_ptr = flag;
        flag_ptr[0] = seq[0];
    }
    // System-scope fence after the write: make it visible to the CPU and
    // other devices, not just to threads on this GPU.
    __threadfence_system();
}
16+
// Spins until flag[0] >= seq[0], i.e. until the producer has published a
// sequence value at least as new as the one we are waiting for.
//
// Launch: <<<1, 1>>> on the consumer stream. Blocks that stream until the
// condition holds. Uses __nanosleep, so this requires SM70+ (Volta or newer).
__global__ void wait_flag_kernel(int64_t* flag, int64_t* seq) {
    if (threadIdx.x == 0) {
        // Volatile so each iteration reloads host-written values instead of
        // reusing a cached/registered copy.
        volatile int64_t* flag_ptr = flag;
        volatile int64_t* seq_ptr = seq;
        const int64_t target = seq_ptr[0];
        while (flag_ptr[0] < target) {
            __nanosleep(128);  // back off briefly between polls (SM70+)
        }
    }
    // Acquire-style system fence: work enqueued after this kernel must not
    // observe data older than the flag value we just waited for.
    __threadfence_system();
}
29+
// Increments the sequence counter: seq[0] <- seq[0] + 1.
// Launch: <<<1, 1>>>; only the leader thread performs the update.
__global__ void seq_add_one_kernel(int64_t* seq) {
    const bool is_leader = (threadIdx.x == 0);
    if (is_leader) {
        const int64_t next = seq[0] + 1;
        seq[0] = next;
    }
    // System-wide fence so the updated counter becomes visible to host-side
    // and other-device readers.
    __threadfence_system();
}
36+
// Converts a CUDA runtime status into a TORCH_CHECK failure whose message is
// `msg` followed by the runtime's human-readable error string. No-op on
// cudaSuccess.
static void check_cuda(cudaError_t err, const char* msg) {
    if (err != cudaSuccess) {
        TORCH_CHECK(false, msg, ": ", cudaGetErrorString(err));
    }
}
40+
// Returns a CUDA tensor on `device_index` that aliases the memory of a pinned
// (page-locked) host `tensor` via the host allocation's device mapping. No
// copy is performed: writes through either view are seen by the other.
//
// The returned tensor captures `tensor` in its deleter so the pinned host
// allocation cannot be freed while the mapped device view is still alive
// (the original no-op deleter allowed a dangling device pointer).
//
// NOTE(review): assumes the pinned allocation is device-mappable (true under
// UVA for cudaHostAlloc/pinned torch tensors) — confirm for exotic setups.
torch::Tensor map_pinned_tensor(torch::Tensor tensor, int64_t device_index) {
    TORCH_CHECK(tensor.is_pinned(), "tensor must be pinned");
    void* host_ptr = tensor.data_ptr();
    void* device_ptr = nullptr;
    // The flags argument must be 0 per the CUDA runtime API contract.
    check_cuda(cudaHostGetDevicePointer(&device_ptr, host_ptr, 0),
               "cudaHostGetDevicePointer failed");
    auto options = tensor.options().device(torch::kCUDA, device_index);
    // Capture `tensor` by value: the closure holds a reference for the
    // lifetime of the mapped view, tying the two lifetimes together.
    return torch::from_blob(device_ptr, tensor.sizes(), tensor.strides(),
                            [tensor](void*) {}, options);
}
52+
// Enqueues write_flag_kernel on the current stream of `flag`'s device:
// flag[0] <- seq[0], published with a system-scope fence.
//
// `seq` must be int64 and device-accessible (CUDA memory or a mapped pinned
// host tensor, e.g. from map_pinned_tensor).
void write_flag(torch::Tensor flag, torch::Tensor seq) {
    TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor");
    TORCH_CHECK(flag.scalar_type() == torch::kInt64, "flag must be int64");
    TORCH_CHECK(seq.scalar_type() == torch::kInt64, "seq must be int64");
    // Kernel launches target the *current* device; make it match `flag`'s
    // device or the launch on its stream is invalid.
    c10::cuda::CUDAGuard guard(flag.device());
    auto stream = at::cuda::getCurrentCUDAStream(flag.device().index());
    write_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(),
                                           seq.data_ptr<int64_t>());
    check_cuda(cudaGetLastError(), "write_flag_kernel launch failed");
}
59+
// Enqueues wait_flag_kernel on the current stream of `flag`'s device: the
// stream blocks until flag[0] >= seq[0].
//
// `seq` must be int64 and device-accessible (CUDA memory or a mapped pinned
// host tensor). Requires SM70+ (the kernel uses __nanosleep).
void wait_flag(torch::Tensor flag, torch::Tensor seq) {
    TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor");
    TORCH_CHECK(flag.scalar_type() == torch::kInt64, "flag must be int64");
    TORCH_CHECK(seq.scalar_type() == torch::kInt64, "seq must be int64");
    // Kernel launches target the *current* device; make it match `flag`'s
    // device or the launch on its stream is invalid.
    c10::cuda::CUDAGuard guard(flag.device());
    auto stream = at::cuda::getCurrentCUDAStream(flag.device().index());
    wait_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(),
                                          seq.data_ptr<int64_t>());
    check_cuda(cudaGetLastError(), "wait_flag_kernel launch failed");
}
66+
// Enqueues seq_add_one_kernel on the current stream of `seq`'s device:
// seq[0] <- seq[0] + 1, followed by a system-scope fence.
void seq_add_one(torch::Tensor seq) {
    TORCH_CHECK(seq.is_cuda(), "seq must be a CUDA tensor");
    TORCH_CHECK(seq.scalar_type() == torch::kInt64, "seq must be int64");
    // Kernel launches target the *current* device; make it match `seq`'s
    // device or the launch on its stream is invalid.
    c10::cuda::CUDAGuard guard(seq.device());
    auto stream = at::cuda::getCurrentCUDAStream(seq.device().index());
    seq_add_one_kernel<<<1, 1, 0, stream>>>(seq.data_ptr<int64_t>());
    check_cuda(cudaGetLastError(), "seq_add_one_kernel launch failed");
}