
Commit 2906e78

bear-zd and DefTruth authored
[RoPE] Add minimal RoPE f32/f32x4 pack impl (#80)
* [RoPE]: Minimal version of RoPE implementation. Add f32/x4.
* Update rope.cu
* Update rope.py
* Update README.md
* Update rope.py
* Update rope.cu
* Update README.md
---------
Co-authored-by: DefTruth <[email protected]>
1 parent ba4998d commit 2906e78


5 files changed: +287 -0 lines changed


README.md

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,8 @@
 | ✔️ [safe_softmax_f16x8_pack_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
 | ✔️ [online_safe_softmax_f32](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
 | ✔️ [online_safe_softmax_f32x4_pack](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [rope_f32](./rope/rope.cu)|f32|f32|[link](./rope/)|⭐️⭐️|
+| ✔️ [rope_f32x4_pack](./rope/rope.cu)|f32|f32|[link](./rope/)|⭐️⭐️|
 | ✔️ [layer_norm_f32](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f32x4](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f16_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|

rope/.gitignore

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
*.so
*.a
*.dylib
*.dll
*.lib
.DS_Store
build
*.whl
tmp
rope/README.md

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# RoPE

## 0x00 Description

A basic version of RoPE: a minimal implementation of the rotary position embedding used in Llama. A reference sketch of the rotation follows the list below.

It includes the following:

- [X] rope_f32_kernel
- [X] rope_f32x4_kernel (float4 vectorized version)
- [X] PyTorch bindings
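For reference, the per-pair rotation these kernels implement can be written in a few lines of PyTorch. This is only an illustrative sketch — the helper name `rope_rotate_pairs` is not part of this commit, and `rope.py` below uses the equivalent complex-number formulation instead:

```python
import torch

def rope_rotate_pairs(x: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    # x: [seq_len, hidden_size]; each adjacent pair (x[2i], x[2i+1]) is rotated
    # by the angle pos * theta^(-2i / hidden_size).
    seq_len, dim = x.shape
    pos = torch.arange(seq_len, device=x.device, dtype=torch.float32)                        # [seq_len]
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, device=x.device).float() / dim))     # [dim//2]
    angle = torch.outer(pos, inv_freq)                                                       # [seq_len, dim//2]
    cos, sin = angle.cos(), angle.sin()
    x1, x2 = x[:, 0::2], x[:, 1::2]          # even / odd element of each pair
    out = torch.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out
```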
## Testing

```bash
# Build only for the Ada architecture. If TORCH_CUDA_ARCH_LIST is not set, all
# architectures (Volta, Ampere, Ada, Hopper, ...) are compiled by default, which takes much longer.
export TORCH_CUDA_ARCH_LIST=Ada
python3 rope.py
```

Output:

```bash
----------------------------------------------------------------------------------------------------
M=4096, N=512
----------------------------------------------------------------------------------------------------
out_f32: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.006247ms
out_f32x4_pack: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.005484ms
out_f32_th: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.734866ms
----------------------------------------------------------------------------------------------------
M=4096, N=1024
----------------------------------------------------------------------------------------------------
out_f32: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:0.010335ms
out_f32x4_pack: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:0.008714ms
out_f32_th: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:1.447463ms
----------------------------------------------------------------------------------------------------
M=8192, N=512
----------------------------------------------------------------------------------------------------
out_f32: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:0.010288ms
out_f32x4_pack: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:0.008750ms
out_f32_th: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:1.434934ms
----------------------------------------------------------------------------------------------------
M=8192, N=1024
----------------------------------------------------------------------------------------------------
out_f32: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:0.018394ms
out_f32x4_pack: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:0.015330ms
out_f32_th: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:2.518094ms
----------------------------------------------------------------------------------------------------
```

rope/rope.cu

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>

#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define BLOCK_SIZE 256
#define theta 10000.0f  // RoPE base frequency

// x: [seq_len, hidden_size], N = hidden_size / 2.
// Each thread rotates one (even, odd) pair by the angle pos * theta^(-2i/hidden_size).
__global__ void rope_f32_kernel(float* x, float* out, int seq_len, int N){
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= seq_len * N) return;  // guard the rounded-up grid
  float x1 = x[idx * 2];
  float x2 = x[idx * 2 + 1];
  int token_pos = idx / N;  // position in the sequence
  int token_idx = idx % N;  // pair index i within the hidden dimension
  // inv_freq = theta^(-2i/hidden_size); cast to float to avoid integer division.
  float inv_freq = 1.0f / powf(theta, (float)(token_idx * 2) / (N * 2));
  float sin_v = sinf(token_pos * inv_freq);
  float cos_v = cosf(token_pos * inv_freq);
  float out1 = x1 * cos_v - x2 * sin_v;
  float out2 = x1 * sin_v + x2 * cos_v;
  out[idx * 2] = out1;
  out[idx * 2 + 1] = out2;
}

// another indexing scheme for rope: one block per token, one thread per pair.
__global__ void rope_f32_v2_kernel(float* x, float* out, int seq_len, int N){
  int token_pos = blockIdx.x;
  int tid = threadIdx.x;  // pair index i within the token
  float x1 = x[token_pos * N * 2 + tid * 2];
  float x2 = x[token_pos * N * 2 + tid * 2 + 1];
  float inv_freq = 1.0f / powf(theta, (float)(tid * 2) / (N * 2));
  float sin_v = sinf(token_pos * inv_freq);
  float cos_v = cosf(token_pos * inv_freq);
  float out1 = x1 * cos_v - x2 * sin_v;
  float out2 = x1 * sin_v + x2 * cos_v;
  out[token_pos * N * 2 + tid * 2] = out1;
  out[token_pos * N * 2 + tid * 2 + 1] = out2;
}

// float4-vectorized version: N = hidden_size / 4. Each thread loads 4 floats
// (two adjacent pairs) in a single 128-bit transaction and rotates both pairs.
__global__ void rope_f32x4_pack_kernel(float* x, float* out, int seq_len, int N){
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= seq_len * N) return;  // guard the rounded-up grid
  float4 x_v = FLOAT4(x[idx * 4]);
  int token_pos = idx / N;
  int token_idx = idx % N;
  // the two pairs handled here have indices 2*token_idx and 2*token_idx + 1
  // within a token of hidden_size = 4N, so the exponents are 4*token_idx/(4N)
  // and (4*token_idx + 2)/(4N).
  float inv_freq_f = 1.0f / powf(theta, (float)(token_idx * 4) / (N * 4));
  float inv_freq_s = 1.0f / powf(theta, (float)(token_idx * 4 + 2) / (N * 4));
  float sin_f_v = sinf(token_pos * inv_freq_f);
  float cos_f_v = cosf(token_pos * inv_freq_f);
  float sin_s_v = sinf(token_pos * inv_freq_s);
  float cos_s_v = cosf(token_pos * inv_freq_s);
  float4 out_v;
  out_v.x = x_v.x * cos_f_v - x_v.y * sin_f_v;
  out_v.y = x_v.x * sin_f_v + x_v.y * cos_f_v;
  out_v.z = x_v.z * cos_s_v - x_v.w * sin_s_v;
  out_v.w = x_v.z * sin_s_v + x_v.w * cos_s_v;
  FLOAT4(out[idx * 4]) = out_v;
}

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));

#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                     \
  if (((T).options().dtype() != (th_type))) {                    \
    std::cout << "Tensor Info:" << (T).options() << std::endl;   \
    throw std::runtime_error("values must be " #th_type);        \
  }

void rope_f32(torch::Tensor x, torch::Tensor out) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32)
  int seq_len = x.size(0);
  int hidden_size = x.size(1);
  int N = (int)(hidden_size / 2);
  dim3 grid((seq_len * N + BLOCK_SIZE - 1) / BLOCK_SIZE);
  dim3 block(BLOCK_SIZE);
  rope_f32_kernel<<<grid, block>>>(
    x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N);
}

void rope_f32_v2(torch::Tensor x, torch::Tensor out) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32)
  int seq_len = x.size(0);
  int hidden_size = x.size(1);
  int N = (int)(hidden_size / 2);
  dim3 grid(seq_len);
  dim3 block(N);  // one thread per pair; requires hidden_size/2 <= 1024
  rope_f32_v2_kernel<<<grid, block>>>(
    x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N);
}

void rope_f32x4_pack(torch::Tensor x, torch::Tensor out) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32)
  int seq_len = x.size(0);
  int hidden_size = x.size(1);
  int N = (int)(hidden_size / 4);
  dim3 grid((seq_len * N + BLOCK_SIZE - 1) / BLOCK_SIZE);
  dim3 block(BLOCK_SIZE);
  rope_f32x4_pack_kernel<<<grid, block>>>(
    x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(rope_f32)
  TORCH_BINDING_COMMON_EXTENSION(rope_f32_v2)
  TORCH_BINDING_COMMON_EXTENSION(rope_f32x4_pack)
}

rope/rope.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
import torch
import time
import math
from torch.utils.cpp_extension import load
from functools import partial
from typing import Optional
from typing import Tuple
import torch.nn as nn
import torch.nn.functional as F
torch.set_grad_enabled(False)

# Load the CUDA kernels as a Python module (JIT-compiled by torch.utils.cpp_extension).
lib = load(
    name="rope",
    sources=["rope.cu"],
    extra_cuda_cflags=[
        "-O3",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ],
    extra_cflags=["-std=c++17"],
)


def run_benchmark(
    perf_func: callable,
    a: torch.Tensor,
    tag: str,
    out: Optional[torch.Tensor] = None,
    warmup: int = 2,
    iters: int = 20,
    show_all: bool = False,
):
    # warmup
    if out is not None:
        out.fill_(0)
        for i in range(warmup):
            perf_func(a, out)
    else:
        for i in range(warmup):
            _ = perf_func(a)
    torch.cuda.synchronize()

    # timed iterations
    start = time.time()
    if out is not None:
        for i in range(iters):
            perf_func(a, out)
    else:
        for i in range(iters):
            out = perf_func(a)
    torch.cuda.synchronize()
    end = time.time()

    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    out_info = f"out_{tag}"
    out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
    out_val = [round(v, 8) for v in out_val]
    out_val = [f"{v:<12}" for v in out_val]
    print(f"{out_info:>20}: {out_val}, time:{mean_time:.6f}ms")
    if show_all:
        print(out)
    return out.clone(), mean_time


def naive_rope(
    x: torch.Tensor,
    theta: float = 10000.0,
) -> torch.Tensor:
    # get the shape of x (head/batch dimensions are ignored; in this benchmark
    # x is simply [seq_len, dim]).
    dim = x.shape[-1]
    seq_len = x.shape[-2]
    x_ = x.float().reshape(*x.shape[:-1], -1, 2)
    # x_: [seq_len, dim//2, 2]
    x_ = torch.view_as_complex(x_)
    # pack neighboring elements into complex numbers
    # x_: [seq_len, dim//2], e.g. tensor([(1.6116-0.5772j), ...])
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs).float().cuda()
    # get the rotation angles pos * theta^(-2i/dim); note they are rebuilt on
    # the CPU and copied to the GPU on every call, which adds CPU work and a
    # host-to-device transfer to this reference path.
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    # do the rotation in the complex domain, then flatten pairs back to [seq_len, dim]
    xq_out = torch.view_as_real(x_ * freqs_cis).flatten(1)
    return xq_out.type_as(x)


print("-" * 100)
M = [4096, 8192]
N = [512, 1024]
MN = [[m, n] for m in M for n in N]
for M, N in MN:
    print(" " * 40 + f"M={M}, N={N}")
    print("-" * 100)
    x = torch.randn((M, N)).cuda().float().contiguous()
    out = torch.zeros_like(x).cuda().float().contiguous()
    run_benchmark(lib.rope_f32, x, "f32", out)
    run_benchmark(lib.rope_f32x4_pack, x, "f32x4_pack", out)
    run_benchmark(naive_rope, x, "f32_th")
    print("-" * 100)
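Beyond eyeballing the first three printed values, an elementwise comparison against the PyTorch reference is a quick sanity check. The following is a hypothetical snippet, not part of this commit; it assumes it runs at the end of rope.py where `lib` and `naive_rope` are already defined, and the `atol` value is an arbitrary choice to absorb `--use_fast_math` rounding:

```python
# Hypothetical sanity check (not in the committed rope.py): compare the f32
# kernel against the complex-number reference for a single shape.
x = torch.randn((4096, 512)).cuda().float().contiguous()
out = torch.zeros_like(x)
lib.rope_f32(x, out)
ref = naive_rope(x)
print("max abs diff:", (out - ref).abs().max().item())
assert torch.allclose(out, ref, atol=1e-3), "kernel and reference disagree"
```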
