Skip to content

Commit 6de9e41

Browse files
authored
[Softmax] Update Online Softmax bindings (#155)
* Update softmax.cu * Update softmax.py
1 parent 18375f7 commit 6de9e41

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

kernels/softmax/softmax.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,9 +600,15 @@ online_safe_softmax_f32x4_pack_per_token_kernel<(H/4)><<<grid, block>>>( \
   case 1024: \
     LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(1024) \
     break; \
+  case 2048: \
+    LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(2048) \
+    break; \
+  case 4096: \
+    LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(4096) \
+    break; \
   default: \
     throw std::runtime_error( \
-      "only support H: 128/256/512/1024; raise error if warp_num*4 > H"); \
+      "only support H: 128/256/.../4096;"); \
     break; \
 }
 

kernels/softmax/softmax.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
 
 def run_benchmark(perf_func: callable, x: torch.Tensor,
                   tag: str, out: Optional[torch.Tensor] = None,
-                  warmup: int = 10, iters: int = 1000,
+                  warmup: int = 10, iters: int = 100,
                   show_all: bool = False):
     if out is not None:
         out.fill_(0)
@@ -60,7 +60,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 N = 128 * 128
 print(" " * 45 + f"N={N}")
 print("-" * 100)
-x = torch.randn((N)).cuda().float()
+x = torch.randn((N), device="cuda").cuda().float()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32, x, "f32(fence)", out)
 run_benchmark(lib.softmax_f32x4, x, "f32x4(fence)", out)
@@ -71,7 +71,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 256
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
@@ -95,7 +95,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 512
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
@@ -119,7 +119,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 1024
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
@@ -143,10 +143,11 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 2048
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
 run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(lib.online_safe_softmax_f32x4_pack_per_token, x, "f32x4(safe+online)", out)
 run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
 
 print("-" * 100)
@@ -162,10 +163,11 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 4096
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
 run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(lib.online_safe_softmax_f32x4_pack_per_token, x, "f32x4(safe+online)", out)
 run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
 
 print("-" * 100)
@@ -180,7 +182,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 8192
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 x_f16 = x.half().contiguous()
 out_f16 = out.half().contiguous()
@@ -192,7 +194,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 8192, 8192
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 x_f16 = x.half().contiguous()
 out_f16 = out.half().contiguous()

0 commit comments

Comments
 (0)