
Commit 93636df

[LayerNorm][FP16] support fp16x8_pack_f32 kernel (#48)
* Update README.md
* Update layer_norm.cu
* Update layer_norm.py
* Update README.md
* Update README.md
* Update layer_norm.py
1 parent 54c761d commit 93636df

File tree

4 files changed (+164 lines, -54 lines)


README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -81,6 +81,7 @@
 | ✔️ [layer_norm_f16x2_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f16x8_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f16x8_pack_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
+| ✔️ [layer_norm_f16x8_pack_f32(per token)](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f16_f32(per token)](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f32(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f32x4(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
```

layer-norm/README.md

Lines changed: 42 additions & 35 deletions
````diff
@@ -10,6 +10,7 @@
 - [X] layer_norm_f16x2_f16_kernel
 - [X] layer_norm_f16x8_f16_kernel
 - [X] layer_norm_f16x8_pack_f16_kernel
+- [X] layer_norm_f16x8_pack_f32_kernel
 - [X] layer_norm_f16_f32_kernel
 - [X] PyTorch bindings
 
@@ -27,64 +28,70 @@ python3 layer_norm.py
 -------------------------------------------------------------------------------------
 N=4096, K=512
 -------------------------------------------------------------------------------------
-out_f32: ['-1.76292217 ', '0.04765211 ', '0.50859255 '], time:0.01897240ms
-out_f32x4: ['-1.76292217 ', '0.04765211 ', '0.50859255 '], time:0.00600266ms
-out_f32_th: ['-1.76119995 ', '0.04760556 ', '0.50809568 '], time:0.07085347ms
+out_f32: ['-0.95119929 ', '0.65728813 ', '-0.27701864 '], time:0.01898599ms
+out_f32x4: ['-0.95119929 ', '0.65728813 ', '-0.27701864 '], time:0.00600958ms
+out_f32_th: ['-0.95026982 ', '0.65664589 ', '-0.27674797 '], time:0.07345414ms
 -------------------------------------------------------------------------------------
-out_f16f16: ['-1.76367188 ', '0.04763794 ', '0.50878906 '], time:0.01869035ms
-out_f16f32: ['-1.76367188 ', '0.04766846 ', '0.50878906 '], time:0.01897883ms
-out_f16x2f16: ['-1.76367188 ', '0.04766846 ', '0.50878906 '], time:0.00951219ms
-out_f16x8f16: ['-1.76367188 ', '0.04766846 ', '0.50878906 '], time:0.00467825ms
-out_f16x8packf16: ['-1.76367188 ', '0.04763794 ', '0.50878906 '], time:0.00430202ms
-out_f16_th: ['-1.76171875 ', '0.04760742 ', '0.50830078 '], time:0.07009959ms
+out_f16f16: ['-0.95068359 ', '0.65722656 ', '-0.27709961 '], time:0.01866651ms
+out_f16f32: ['-0.95117188 ', '0.65722656 ', '-0.27709961 '], time:0.01897073ms
+out_f16x2f16: ['-0.95068359 ', '0.65722656 ', '-0.27709961 '], time:0.00952697ms
+out_f16x8f16: ['-0.95068359 ', '0.65722656 ', '-0.27709961 '], time:0.00470805ms
+out_f16x8packf16: ['-0.95117188 ', '0.65673828 ', '-0.27709961 '], time:0.00427437ms
+out_f16x8packf32: ['-0.95117188 ', '0.65722656 ', '-0.27709961 '], time:0.00418639ms
+out_f16_th: ['-0.94970703 ', '0.65673828 ', '-0.27685547 '], time:0.07291913ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
 N=4096, K=1024
 -------------------------------------------------------------------------------------
-out_f32: ['-0.65619785 ', '1.33576787 ', '-0.29172164 '], time:0.05123448ms
-out_f32x4: ['-0.65619785 ', '1.33576787 ', '-0.29172164 '], time:0.01073551ms
-out_f32_th: ['-0.65587735 ', '1.33511555 ', '-0.29157916 '], time:0.07034254ms
+out_f32: ['0.81839228 ', '0.36616057 ', '-1.71588480 '], time:0.05122757ms
+out_f32x4: ['0.81839228 ', '0.36616057 ', '-1.71588480 '], time:0.01071095ms
+out_f32_th: ['0.81799269 ', '0.36598179 ', '-1.71504688 '], time:0.07267237ms
 -------------------------------------------------------------------------------------
-out_f16f16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.05320668ms
-out_f16f32: ['-0.65576172 ', '1.3359375 ', '-0.29150391 '], time:0.05061388ms
-out_f16x2f16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.01861978ms
-out_f16x8f16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.00745845ms
-out_f16x8packf16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.00648832ms
-out_f16_th: ['-0.65527344 ', '1.33398438 ', '-0.29150391 '], time:0.07068610ms
+out_f16f16: ['0.81835938 ', '0.36596680 ', '-1.71484375 '], time:0.05317926ms
+out_f16f32: ['0.81835938 ', '0.36621094 ', '-1.71582031 '], time:0.05062103ms
+out_f16x2f16: ['0.81884766 ', '0.36621094 ', '-1.71679688 '], time:0.01855445ms
+out_f16x8f16: ['0.81884766 ', '0.36621094 ', '-1.71679688 '], time:0.00742888ms
+out_f16x8packf16: ['0.81884766 ', '0.36621094 ', '-1.71679688 '], time:0.00645399ms
+out_f16x8packf32: ['0.81835938 ', '0.36621094 ', '-1.71582031 '], time:0.00634456ms
+out_f16_th: ['0.81835938 ', '0.36596680 ', '-1.71582031 '], time:0.07386255ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
 N=4096, K=2048
 -------------------------------------------------------------------------------------
-out_f32x4: ['0.92044634 ', '0.37421227 ', '-2.49094558 '], time:0.02202415ms
-out_f32_th: ['0.92022169 ', '0.37412092 ', '-2.49033761 '], time:0.12026787ms
+out_f32x4: ['-0.65341073 ', '0.10270299 ', '-0.06597849 '], time:0.02200651ms
+out_f32_th: ['-0.65325129 ', '0.10267793 ', '-0.06596238 '], time:0.12027287ms
 -------------------------------------------------------------------------------------
-out_f16x2f16: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.05346847ms
-out_f16x8f16: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.01381087ms
-out_f16x8packf16: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.01159072ms
-out_f16_th: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.08454061ms
+out_f16x2f16: ['-0.65332031 ', '0.10266113 ', '-0.06591797 '], time:0.05352354ms
+out_f16x8f16: ['-0.65380859 ', '0.10272217 ', '-0.06597900 '], time:0.01377678ms
+out_f16x8packf16: ['-0.65332031 ', '0.10266113 ', '-0.06591797 '], time:0.01154637ms
+out_f16x8packf32: ['-0.65332031 ', '0.10272217 ', '-0.06597900 '], time:0.01166582ms
+out_f16_th: ['-0.65380859 ', '0.10272217 ', '-0.06597900 '], time:0.08442783ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
 N=4096, K=4096
 -------------------------------------------------------------------------------------
-out_f32x4: ['-2.05339074 ', '0.25924587 ', '0.42393678 '], time:0.18885875ms
-out_f32_th: ['-2.05314016 ', '0.25921422 ', '0.42388505 '], time:0.77834105ms
+out_f32x4: ['2.38733387 ', '-0.03023042 ', '0.66022825 '], time:0.18884635ms
+out_f32_th: ['2.38704205 ', '-0.03022672 ', '0.66014749 '], time:0.77852798ms
 -------------------------------------------------------------------------------------
-out_f16x8f16: ['-2.05273438 ', '0.2590332 ', '0.42382812 '], time:0.03327322ms
-out_f16x8packf16: ['-2.05273438 ', '0.2590332 ', '0.42382812 '], time:0.02402687ms
-out_f16_th: ['-2.05273438 ', '0.2590332 ', '0.42382812 '], time:0.17436218ms
+out_f16x8f16: ['2.38671875 ', '-0.03024292 ', '0.66015625 '], time:0.03325391ms
+out_f16x8packf16: ['2.38671875 ', '-0.03024292 ', '0.66015625 '], time:0.02401376ms
+out_f16x8packf32: ['2.38671875 ', '-0.03021240 ', '0.66064453 '], time:0.02381730ms
+out_f16_th: ['2.38671875 ', '-0.03021240 ', '0.66015625 '], time:0.17546010ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
 N=4096, K=8192
 -------------------------------------------------------------------------------------
-out_f16x8f16: ['-1.0234375 ', '-0.3371582 ', '-1.54882812 '], time:0.19311237ms
-out_f16x8packf16: ['-1.0234375 ', '-0.33691406 ', '-1.54882812 '], time:0.18668032ms
-out_f16_th: ['-1.0234375 ', '-0.33691406 ', '-1.54882812 '], time:0.84443021ms
+out_f16x8f16: ['0.15905762 ', '1.06542969 ', '-0.19396973 '], time:0.19306803ms
+out_f16x8packf16: ['0.15905762 ', '1.06542969 ', '-0.19396973 '], time:0.18665886ms
+out_f16x8packf32: ['0.15905762 ', '1.06542969 ', '-0.19396973 '], time:0.18657684ms
+out_f16_th: ['0.15905762 ', '1.06542969 ', '-0.19396973 '], time:0.84462571ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
 N=8192, K=8192
 -------------------------------------------------------------------------------------
-out_f16x8f16: ['-1.03320312 ', '0.41455078 ', '-0.49707031 '], time:0.38361049ms
-out_f16x8packf16: ['-1.03320312 ', '0.41455078 ', '-0.49707031 '], time:0.40809250ms
-out_f16_th: ['-1.03320312 ', '0.41455078 ', '-0.49707031 '], time:1.99517584ms
+out_f16x8f16: ['-0.53662109 ', '2.359375 ', '0.78027344 '], time:0.38366604ms
+out_f16x8packf16: ['-0.53662109 ', '2.359375 ', '0.78027344 '], time:0.40789628ms
+out_f16x8packf32: ['-0.53613281 ', '2.359375 ', '0.78027344 '], time:0.40818143ms
+out_f16_th: ['-0.53662109 ', '2.359375 ', '0.78027344 '], time:1.99523735ms
 -------------------------------------------------------------------------------------
 ```
````

layer-norm/layer_norm.cu

Lines changed: 115 additions & 19 deletions
```diff
@@ -376,6 +376,52 @@ __global__ void layer_norm_f16x8_pack_f16_kernel(half* x, half* y, float g, floa
   // TODO: support non 8-multiple K here
 }
 
+template<const int NUM_THREADS=256>
+__global__ void layer_norm_f16x8_pack_f32_kernel(half* x, half* y, float g, float b, int N, int K) {
+  int tid = threadIdx.x; // 0..K-1
+  int bid = blockIdx.x; // 0..N-1
+  int idx = (bid * blockDim.x + threadIdx.x) * 8;
+  const float epsilon = 1e-5f;
+
+  __shared__ float s_mean; // shared within block
+  __shared__ float s_variance; // shared within block
+  // temporary register(memory), .local space in ptx, addressable
+  half pack_x[8], pack_y[8]; // 8x16 bits=128 bits.
+  // reinterpret as float4 and load 128 bits in 1 memory issue.
+  LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits
+
+  float value = 0.0f;
+  #pragma unroll
+  for (int i = 0; i < 8; ++i) {
+    value += ((idx + i) < N * K ? __half2float(pack_x[i]) : 0.0f);
+  }
+  float sum = block_reduce_sum_f32<NUM_THREADS>(value);
+  if (tid == 0) s_mean = sum / (float) K;
+  // wait for s_mean in shared memory to be ready for all threads
+  __syncthreads();
+
+  float variance = 0.0f;
+  #pragma unroll
+  for (int i = 0; i < 8; ++i) {
+    float v_hat = __half2float(pack_x[i]) - s_mean;
+    variance += ((idx + i) < N * K ? v_hat * v_hat : 0.0f);
+  }
+  variance = block_reduce_sum_f32<NUM_THREADS>(variance);
+  if (tid == 0) s_variance = rsqrtf(variance / ((float) K + epsilon));
+  // wait for s_variance in shared memory to be ready for all threads
+  __syncthreads();
+
+  #pragma unroll
+  for (int i = 0; i < 8; ++i) {
+    pack_y[i] = __float2half(
+      __fmaf_rn(((__half2float(pack_x[i]) - s_mean) * s_variance), g, b)
+    );
+  }
+  // reinterpret as float4 and store 128 bits in 1 memory issue.
+  if ((idx + 7) < N * K) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
+  // TODO: support non 8-multiple K here
+}
+
 // --------------------- PyTorch bindings for custom kernel -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
```
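The new kernel reuses two helpers defined earlier in layer_norm.cu that are not visible in this diff: the LDST128BITS macro (one 128-bit load/store per 8 half values) and block_reduce_sum_f32 (block-wide sum). Compared with layer_norm_f16x8_pack_f16_kernel, the substantive change is that each element is converted with __half2float before the mean/variance reductions, so accumulation runs in fp32 while loads and stores stay packed fp16. For readers following along without the full file, a minimal sketch of what the assumed helpers typically look like (names match the kernel above; the exact implementation in the repo may differ):

```cuda
#define WARP_SIZE 32
// Reinterpret 8 contiguous half values (128 bits) as one float4 so a single
// ld/st instruction moves the whole pack.
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

// Butterfly (xor-shuffle) sum within one warp.
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
  #pragma unroll
  for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, mask);
  }
  return val;
}

// Block-wide sum: reduce within each warp, stage per-warp partials in shared
// memory, then reduce the partials. Thread 0 (the only consumer in the kernel
// above) ends up holding the full block sum.
template<const int NUM_THREADS = 256>
__device__ float block_reduce_sum_f32(float val) {
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  const int warp = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;
  static __shared__ float partial[NUM_WARPS];

  val = warp_reduce_sum_f32<WARP_SIZE>(val);
  if (lane == 0) partial[warp] = val;
  __syncthreads();

  val = (lane < NUM_WARPS) ? partial[lane] : 0.0f;
  val = warp_reduce_sum_f32<NUM_WARPS>(val);
  return val;
}
```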
```diff
@@ -463,24 +509,6 @@ layer_norm_f32x4_kernel<(K)/4><<<grid, block>>>( \
     break; \
   }
 
-void layer_norm_f32(torch::Tensor x, torch::Tensor y, float g, float b) {
-  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
-  CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
-  CHECK_TORCH_TENSOR_SHAPE(x, y)
-  const int N = x.size(0);
-  const int K = x.size(1);
-  DISPATCH_LAYER_NORM_F32_KERNEL(N, K)
-}
-
-void layer_norm_f32x4(torch::Tensor x, torch::Tensor y, float g, float b) {
-  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
-  CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
-  CHECK_TORCH_TENSOR_SHAPE(x, y)
-  const int N = x.size(0);
-  const int K = x.size(1);
-  DISPATCH_LAYER_NORM_F32x4_KERNEL(N, K)
-}
-
 // fp16
 #define LANUCH_LAYER_NORM_F16F16_KERNEL(K) \
 layer_norm_f16_f16_kernel<(K)><<<grid, block>>>( \
@@ -663,6 +691,65 @@ layer_norm_f16x8_pack_f16_kernel<(K)/8><<<grid, block>>>( \
     break; \
   }
 
+#define LANUCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(K) \
+layer_norm_f16x8_pack_f32_kernel<(K)/8><<<grid, block>>>( \
+  reinterpret_cast<half*>(x.data_ptr()), \
+  reinterpret_cast<half*>(y.data_ptr()), \
+  g, b, N, (K));
+
+#define DISPATCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(N, K) \
+  dim3 block((K)/8); \
+  dim3 grid((N)); \
+  switch ((K)) \
+  { \
+    case 64: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(64) \
+      break; \
+    case 128: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(128) \
+      break; \
+    case 256: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(256) \
+      break; \
+    case 512: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(512) \
+      break; \
+    case 1024: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(1024) \
+      break; \
+    case 2048: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(2048) \
+      break; \
+    case 4096: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(4096) \
+      break; \
+    case 8192: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(8192) \
+      break; \
+    default: \
+      throw std::runtime_error( \
+        "only support K: 64/128/.../1024*8"); \
+      break; \
+  }
+
+void layer_norm_f32(torch::Tensor x, torch::Tensor y, float g, float b) {
+  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
+  CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
+  CHECK_TORCH_TENSOR_SHAPE(x, y)
+  const int N = x.size(0);
+  const int K = x.size(1);
+  DISPATCH_LAYER_NORM_F32_KERNEL(N, K)
+}
+
+void layer_norm_f32x4(torch::Tensor x, torch::Tensor y, float g, float b) {
+  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
+  CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
+  CHECK_TORCH_TENSOR_SHAPE(x, y)
+  const int N = x.size(0);
+  const int K = x.size(1);
+  DISPATCH_LAYER_NORM_F32x4_KERNEL(N, K)
+}
+
 void layer_norm_f16_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
   CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
   CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
```
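The LANUCH/DISPATCH macro pair (the LANUCH spelling follows the existing fp16 macros in this file) selects a template instantiation per supported K: one block per token and one thread per 8 packed half elements, which is why only the listed multiples of 8 are accepted and the default branch throws. As a rough illustration of what a single branch boils down to, here is a hypothetical standalone wrapper for K = 512 (the function name is invented for the example and it is assumed to sit in layer_norm.cu below the kernel definition):

```cuda
#include <cuda_fp16.h>  // already included at the top of layer_norm.cu; repeated so the sketch stands alone

// Roughly what DISPATCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(N, 512) expands to,
// for a contiguous fp16 tensor of shape [N, 512].
void launch_layer_norm_f16x8_pack_f32_k512(half* x, half* y, float g, float b, int N) {
  constexpr int K = 512;
  dim3 block(K / 8);  // 64 threads; the NUM_THREADS template argument is also K/8
  dim3 grid(N);       // one block per row/token
  layer_norm_f16x8_pack_f32_kernel<K / 8><<<grid, block>>>(x, y, g, b, N, K);
}
```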
```diff
@@ -699,6 +786,14 @@ void layer_norm_f16x8_pack_f16(torch::Tensor x, torch::Tensor y, float g, float
   DISPATCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(N, K)
 }
 
+void layer_norm_f16x8_pack_f32(torch::Tensor x, torch::Tensor y, float g, float b) {
+  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
+  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
+  CHECK_TORCH_TENSOR_SHAPE(x, y)
+  const int N = x.size(0);
+  const int K = x.size(1);
+  DISPATCH_LAYER_NORM_F16x8_PACK_F32_KERNEL(N, K)
+}
 
 void layer_norm_f16_f32(torch::Tensor x, torch::Tensor y, float g, float b) {
   CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
@@ -713,9 +808,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(layer_norm_f32)
   TORCH_BINDING_COMMON_EXTENSION(layer_norm_f32x4)
   TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16_f16)
+  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16_f32)
   TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x2_f16)
   TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x8_f16)
   TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x8_pack_f16)
-  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16_f32)
+  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x8_pack_f32)
 }
 
```
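TORCH_BINDING_COMMON_EXTENSION itself is defined earlier in the bindings section and is outside this diff; it is assumed to be a thin wrapper over pybind11's m.def, roughly as sketched below (STRINGFY is visible as context in the first hunk of this file):

```cuda
// Assumed shape of the binding macro used inside PYBIND11_MODULE above;
// the real definition lives earlier in layer_norm.cu and is not shown in this diff.
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));
```

With that registration, layer_norm_f16x8_pack_f32 becomes callable from Python as lib.layer_norm_f16x8_pack_f32(x, y, g, b), which is exactly what the benchmark additions in layer_norm.py below exercise.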

layer-norm/layer_norm.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -82,6 +82,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 run_benchmark(lib.layer_norm_f16x2_f16, x_f16, "f16x2f16", out_f16)
 run_benchmark(lib.layer_norm_f16x8_f16, x_f16, "f16x8f16", out_f16)
 run_benchmark(lib.layer_norm_f16x8_pack_f16, x_f16, "f16x8packf16", out_f16)
+run_benchmark(lib.layer_norm_f16x8_pack_f32, x_f16, "f16x8packf32", out_f16)
 run_benchmark(naive_layer_norm, x_f16, "f16_th")
 print("-" * 85)
 
@@ -103,6 +104,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 run_benchmark(lib.layer_norm_f16x2_f16, x_f16, "f16x2f16", out_f16)
 run_benchmark(lib.layer_norm_f16x8_f16, x_f16, "f16x8f16", out_f16)
 run_benchmark(lib.layer_norm_f16x8_pack_f16, x_f16, "f16x8packf16", out_f16)
+run_benchmark(lib.layer_norm_f16x8_pack_f32, x_f16, "f16x8packf32", out_f16)
 run_benchmark(naive_layer_norm, x_f16, "f16_th")
 print("-" * 85)
 
@@ -121,6 +123,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 run_benchmark(lib.layer_norm_f16x2_f16, x_f16, "f16x2f16", out_f16)
 run_benchmark(lib.layer_norm_f16x8_f16, x_f16, "f16x8f16", out_f16)
 run_benchmark(lib.layer_norm_f16x8_pack_f16, x_f16, "f16x8packf16", out_f16)
+run_benchmark(lib.layer_norm_f16x8_pack_f32, x_f16, "f16x8packf32", out_f16)
 run_benchmark(naive_layer_norm, x_f16, "f16_th")
 print("-" * 85)
 
@@ -138,6 +141,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 out_f16 = out.half()
 run_benchmark(lib.layer_norm_f16x8_f16, x_f16, "f16x8f16", out_f16)
 run_benchmark(lib.layer_norm_f16x8_pack_f16, x_f16, "f16x8packf16", out_f16)
+run_benchmark(lib.layer_norm_f16x8_pack_f32, x_f16, "f16x8packf32", out_f16)
 run_benchmark(naive_layer_norm, x_f16, "f16_th")
 print("-" * 85)
 
@@ -149,6 +153,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 out_f16 = torch.zeros_like(x_f16).cuda().half().contiguous()
 run_benchmark(lib.layer_norm_f16x8_f16, x_f16, "f16x8f16", out_f16)
 run_benchmark(lib.layer_norm_f16x8_pack_f16, x_f16, "f16x8packf16", out_f16)
+run_benchmark(lib.layer_norm_f16x8_pack_f32, x_f16, "f16x8packf32", out_f16)
 run_benchmark(naive_layer_norm, x_f16, "f16_th")
 print("-" * 85)
 
@@ -160,5 +165,6 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 out_f16 = torch.zeros_like(x_f16).cuda().half().contiguous()
 run_benchmark(lib.layer_norm_f16x8_f16, x_f16, "f16x8f16", out_f16)
 run_benchmark(lib.layer_norm_f16x8_pack_f16, x_f16, "f16x8packf16", out_f16)
+run_benchmark(lib.layer_norm_f16x8_pack_f32, x_f16, "f16x8packf32", out_f16)
 run_benchmark(naive_layer_norm, x_f16, "f16_th")
 print("-" * 85)
```
