|
| 1 | +#include <stdio.h> |
| 2 | +#include <stdlib.h> |
| 3 | +#include <float.h> |
| 4 | +#include <vector> |
| 5 | +#include <algorithm> |
| 6 | +#include <cuda_runtime.h> |
| 7 | +#include <cuda_fp16.h> |
| 8 | +#include <cuda_bf16.h> |
| 9 | +#include <cuda_fp8.h> |
| 10 | +#include <torch/types.h> |
| 11 | +#include <torch/extension.h> |
| 12 | + |
| 13 | +#define INT4(value) (reinterpret_cast<int4*>(&(value))[0]) |
| 14 | +#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0]) |
| 15 | +#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0]) |
| 16 | +#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) |
| 17 | +#define BLOCK_SIZE 256 |
| 18 | +#define theta 10000.0f |
| 19 | + |
| 20 | +__global__ void rope_f32_kernel(float* x, float* out, int seq_len, int N){ |
| 21 | + int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 22 | + float x1 = x[idx * 2]; |
| 23 | + float x2 = x[idx * 2 + 1]; |
| 24 | + int token_pos = idx / N; |
| 25 | + int token_idx = idx % N; |
| 26 | + float exp_v = 1.0f / powf(theta, token_idx / (N * 2)); |
| 27 | + float sin_v = sinf(token_pos / exp_v); |
| 28 | + float cos_v = cosf(token_pos / exp_v); |
| 29 | + float out1 = x1 * cos_v - x2 * sin_v; |
| 30 | + float out2 = x1 * sin_v + x2 * cos_v; |
| 31 | + out[idx * 2] = out1; |
| 32 | + out[idx * 2 + 1] = out2; |
| 33 | +} |
| 34 | + |
| 35 | +// another index method of rope. |
| 36 | +__global__ void rope_f32_v2_kernel(float* x, float* out, int seq_len, int N){ |
| 37 | + int token_pos = blockIdx.x; |
| 38 | + int tid = threadIdx.x; |
| 39 | + float x1 = x[token_pos * N * 2 + tid * 2]; |
| 40 | + float x2 = x[token_pos * N * 2 + tid * 2 + 1]; |
| 41 | + float exp_v = 1.0f / powf(theta, (int)(tid / 2) / (N * 2)); |
| 42 | + float sin_v = sinf(token_pos / exp_v); |
| 43 | + float cos_v = cosf(token_pos / exp_v); |
| 44 | + float out1 = x1 * cos_v - x2 * sin_v; |
| 45 | + float out2 = x1 * sin_v + x2 * cos_v; |
| 46 | + out[token_pos * N * 2 + tid * 2] = out1; |
| 47 | + out[token_pos * N * 2 + tid * 2 + 1] = out2; |
| 48 | +} |
| 49 | + |
| 50 | +__global__ void rope_f32x4_pack_kernel(float* x, float* out, int seq_len, int N){ |
| 51 | + int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 52 | + float4 x_v = FLOAT4(x[idx * 4]); |
| 53 | + int token_pos = idx / N; |
| 54 | + int token_idx = idx % N; |
| 55 | + float exp_f_v = 1.0f / powf(theta, token_idx * 2 / (N * 4)); |
| 56 | + float exp_s_v = 1.0f / powf(theta, ((token_idx * 2) + 1) / (N * 4)); |
| 57 | + float sin_f_v = sinf(token_pos / exp_f_v); |
| 58 | + float cos_f_v = cosf(token_pos / exp_f_v); |
| 59 | + float sin_s_v = sinf(token_pos / exp_s_v); |
| 60 | + float cos_s_v = cosf(token_pos / exp_s_v); |
| 61 | + float4 out_v; |
| 62 | + out_v.x = x_v.x * cos_f_v - x_v.y * sin_f_v; |
| 63 | + out_v.y = x_v.x * sin_f_v + x_v.y * cos_f_v; |
| 64 | + out_v.z = x_v.z * cos_s_v - x_v.w * sin_s_v; |
| 65 | + out_v.w = x_v.z * sin_s_v + x_v.w * cos_s_v; |
| 66 | + FLOAT4(out[idx * 4]) = out_v; |
| 67 | +} |
| 68 | + |
| 69 | +// --------------------- PyTorch bindings for custom kernel ----------------------- |
| 70 | +#define STRINGFY(str) #str |
| 71 | +#define TORCH_BINDING_COMMON_EXTENSION(func) \ |
| 72 | + m.def(STRINGFY(func), &func, STRINGFY(func)); |
| 73 | + |
| 74 | +#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ |
| 75 | +if (((T).options().dtype() != (th_type))) { \ |
| 76 | + std::cout << "Tensor Info:" << (T).options() << std::endl; \ |
| 77 | + throw std::runtime_error("values must be " #th_type); \ |
| 78 | +} |
| 79 | + |
| 80 | +void rope_f32(torch::Tensor x, torch::Tensor out) { |
| 81 | + CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32) |
| 82 | + CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32) |
| 83 | + int seq_len = x.size(0); |
| 84 | + int hidden_size = x.size(1); |
| 85 | + int N = (int)(hidden_size/2); |
| 86 | + dim3 grid((seq_len * N + BLOCK_SIZE - 1) / BLOCK_SIZE); |
| 87 | + dim3 block(BLOCK_SIZE); |
| 88 | + rope_f32_kernel<<<grid, block>>>( |
| 89 | + x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N); |
| 90 | +} |
| 91 | + |
| 92 | +void rope_f32_v2(torch::Tensor x, torch::Tensor out) { |
| 93 | + CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32) |
| 94 | + CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32) |
| 95 | + int seq_len = x.size(0); |
| 96 | + int hidden_size = x.size(1); |
| 97 | + int N = (int)(hidden_size/2); |
| 98 | + dim3 grid(seq_len); |
| 99 | + dim3 block(N); |
| 100 | + rope_f32_v2_kernel<<<grid, block>>>( |
| 101 | + x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N); |
| 102 | +} |
| 103 | + |
| 104 | +void rope_f32x4_pack(torch::Tensor x, torch::Tensor out) { |
| 105 | + CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32) |
| 106 | + CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32) |
| 107 | + int seq_len = x.size(0); |
| 108 | + int hidden_size = x.size(1); |
| 109 | + int N = (int)(hidden_size/4); |
| 110 | + dim3 grid((seq_len * N + BLOCK_SIZE - 1) / BLOCK_SIZE); |
| 111 | + dim3 block(BLOCK_SIZE); |
| 112 | + rope_f32x4_pack_kernel<<<grid, block>>>( |
| 113 | + x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N); |
| 114 | +} |
| 115 | + |
| 116 | +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| 117 | + TORCH_BINDING_COMMON_EXTENSION(rope_f32) |
| 118 | + TORCH_BINDING_COMMON_EXTENSION(rope_f32_v2) |
| 119 | + TORCH_BINDING_COMMON_EXTENSION(rope_f32x4_pack) |
| 120 | +} |
0 commit comments