 
 def run_benchmark(perf_func: callable, x: torch.Tensor,
                   tag: str, out: Optional[torch.Tensor] = None,
-                  warmup: int = 10, iters: int = 1000,
+                  warmup: int = 10, iters: int = 100,
                   show_all: bool = False):
     if out is not None:
         out.fill_(0)
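Note on the iters change: dropping from 1000 to 100 timed iterations shortens the total run without changing the methodology, since each call is still preceded by a warmup phase. The body of run_benchmark is outside this hunk, so the sketch below only illustrates the standard warmup-then-synchronize timing pattern such a helper typically uses; time_kernel is a hypothetical name, not the repo's API.

import time
import torch

def time_kernel(fn, warmup: int = 10, iters: int = 100) -> float:
    # Warmup: let GPU clocks, caches, and lazy initialization settle.
    for _ in range(warmup):
        fn()
    # CUDA launches are asynchronous: drain the queue before starting the clock.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    # Wait for all timed kernels to finish before stopping the clock.
    torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1e3 / iters  # mean latency per call, ms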
@@ -60,7 +60,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 N = 128 * 128
 print(" " * 45 + f"N={N}")
 print("-" * 100)
-x = torch.randn((N)).cuda().float()
+x = torch.randn((N), device="cuda").cuda().float()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32, x, "f32(fence)", out)
 run_benchmark(lib.softmax_f32x4, x, "f32x4(fence)", out)
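The recurring change in this and the following hunks is the same one-liner: passing device="cuda" to torch.randn materializes the tensor directly on the GPU instead of filling it on the host and paying a host-to-device copy. One side effect worth noting: the trailing .cuda() is now redundant, since Tensor.cuda() returns the tensor unchanged when it already lives on the current CUDA device. A minimal illustration of the difference:

import torch

x_slow = torch.randn(128 * 128).cuda()          # CPU alloc + fill, then H2D copy
x_fast = torch.randn(128 * 128, device="cuda")  # allocated and filled on the GPU
assert x_fast.cuda() is x_fast                  # .cuda() is a no-op here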
@@ -71,7 +71,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 256
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
@@ -95,7 +95,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 512
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
@@ -119,7 +119,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 1024
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
@@ -143,10 +143,11 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 2048
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
 run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(lib.online_safe_softmax_f32x4_pack_per_token, x, "f32x4(safe+online)", out)
 run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
 
 print("-" * 100)
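The new f32x4(safe+online) row benchmarks the online-softmax kernel this PR wires in. A safe softmax normally takes two reduction passes over each row (one for the max, one for the exp-sum); the online variant fuses them by carrying a running max m and running normalizer d, rescaling d by exp(m_old - m_new) whenever the max grows (Milakov & Gimelshein, "Online normalizer calculation for softmax", 2018). Below is a plain-Python reference of that recurrence, useful for checking the kernel's numerics; it is a sketch, not the CUDA implementation itself.

import math
import torch

def online_softmax_row(x: torch.Tensor) -> torch.Tensor:
    m, d = float("-inf"), 0.0
    for v in x.tolist():
        # Single pass: fuse the max reduction and the normalizer sum.
        m_new = max(m, v)
        d = d * math.exp(m - m_new) + math.exp(v - m_new)
        m = m_new
    return torch.tensor([math.exp(v - m) / d for v in x.tolist()])

x = torch.randn(2048)
assert torch.allclose(online_softmax_row(x), torch.softmax(x, dim=0), atol=1e-5)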
@@ -162,10 +163,11 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 4096
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
 run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(lib.online_safe_softmax_f32x4_pack_per_token, x, "f32x4(safe+online)", out)
 run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
 
 print("-" * 100)
@@ -180,7 +182,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 4096, 8192
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 x_f16 = x.half().contiguous()
 out_f16 = out.half().contiguous()
@@ -192,7 +194,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 S, H = 8192, 8192
 print(" " * 45 + f"S={S}, H={H}")
 print("-" * 100)
-x = torch.randn((S, H)).cuda().float().contiguous()
+x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
 out = torch.zeros_like(x).cuda().float().contiguous()
 x_f16 = x.half().contiguous()
 out_f16 = out.half().contiguous()
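These last two hunks also prepare x_f16/out_f16 for half-precision kernels benchmarked further down, outside the visible diff. The .half() cast halves the bytes moved per element, and .contiguous() likely matters because the x4/pack suffixes in the kernel names suggest vectorized 128-bit loads that assume densely packed rows; both points are inferences from the naming, not stated in this diff. A small sketch of the setup:

import torch

x = torch.randn(8192, 8192, device="cuda")
x_f16 = x.half().contiguous()  # float16: half the memory traffic of float32
assert x_f16.dtype == torch.float16 and x_f16.is_contiguous()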