#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
+ #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+ #define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
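The new LDST128BITS macro views any 16-byte-aligned location as a float4 so that eight half values move in a single 128-bit transaction. A minimal sketch of the idea (illustrative only; copy8_f16 is not part of this commit):

__device__ void copy8_f16(half* src, half* dst) {
  // one 128-bit load and one 128-bit store instead of eight 16-bit accesses;
  // assumes src and dst are 16-byte aligned
  LDST128BITS(dst[0]) = LDST128BITS(src[0]);
}
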
// -------------------------------------- FP32 --------------------------------------
// Warp Reduce Sum
@@ -325,6 +327,55 @@ __global__ void layer_norm_f16_f32_kernel(half* x, half* y, float g, float b, int N, int K) {
  }
}

+ template<const int NUM_THREADS=256>
+ __global__ void layer_norm_f16x8_pack_f16_kernel(half* x, half* y, float g, float b, int N, int K) {
+   int tid = threadIdx.x;  // 0..K/8-1
+   int bid = blockIdx.x;   // 0..N-1
+   int idx = (bid * blockDim.x + threadIdx.x) * 8;
+   const half epsilon = __float2half(1e-5f);
+   const half g_ = __float2half(g);
+   const half b_ = __float2half(b);
+   const half K_ = __int2half_rn(K);
+   const half z_ = __float2half(0.0f);
+
+   __shared__ half s_mean;     // shared within block
+   __shared__ half s_variance; // shared within block
+   // temporary registers (may be lowered to addressable .local space in PTX)
+   half pack_x[8], pack_y[8];  // 8 x 16 bits = 128 bits
+   // reinterpret as float4 and load 128 bits in one memory issue
+   LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits
+
+   half value = z_;
+   #pragma unroll
+   for (int i = 0; i < 8; ++i) {
+     value += ((idx + i) < N * K ? pack_x[i] : z_);
+   }
+   half sum = block_reduce_sum_f16_f16<NUM_THREADS>(value);
+   if (tid == 0) s_mean = sum / K_;
+   // wait for s_mean in shared memory to be ready for all threads
+   __syncthreads();
+
+   half variance = z_;
+   #pragma unroll
+   for (int i = 0; i < 8; ++i) {
+     half v_hat = pack_x[i] - s_mean;
+     variance += ((idx + i) < N * K ? v_hat * v_hat : z_);
+   }
+   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
+   if (tid == 0) s_variance = hrsqrt(variance / K_ + epsilon);
+   // wait for s_variance in shared memory to be ready for all threads
+   __syncthreads();
+
+   #pragma unroll
+   for (int i = 0; i < 8; ++i) {
+     // TODO: use __hfma2, __hsub2, __hmul2 here
+     pack_y[i] = __hfma((pack_x[i] - s_mean) * s_variance, g_, b_);
+   }
+   // reinterpret as float4 and store 128 bits in one memory issue
+   if ((idx + 7) < N * K) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
+   // TODO: support K that is not a multiple of 8
+ }
+
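The TODO in the normalize loop hints at a packed-math variant. A possible sketch with half2 intrinsics (illustrative, not part of this commit; it reuses pack_x, pack_y, s_mean, s_variance, g_ and b_ from the kernel above):

half2 mean2 = __half2half2(s_mean);
half2 rstd2 = __half2half2(s_variance);
half2 g2    = __half2half2(g_);
half2 b2    = __half2half2(b_);
#pragma unroll
for (int i = 0; i < 8; i += 2) {
  half2 v = HALF2(pack_x[i]);                  // two halves per iteration
  half2 n = __hmul2(__hsub2(v, mean2), rstd2); // (x - mean) * rstd
  HALF2(pack_y[i]) = __hfma2(n, g2, b2);       // n * g + b
}

A guarded scalar loop over the last K % 8 elements would similarly address the second TODO (rows whose K is not a multiple of 8).
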
// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -350,7 +401,7 @@ layer_norm_f32_kernel<(K)><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F32_KERNEL(N, K) \
dim3 block((K)); \
- dim3 grid((N)); \
+ dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -382,7 +433,7 @@ layer_norm_f32x4_kernel<(K)/4><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F32x4_KERNEL(N, K) \
dim3 block((K)/4); \
- dim3 grid((N)); \
+ dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -400,9 +451,15 @@ layer_norm_f32x4_kernel<(K)/4><<<grid, block>>>( \
case 1024: \
LANUCH_LAYER_NORM_F32x4_KERNEL(1024) \
break; \
+ case 2048: \
+ LANUCH_LAYER_NORM_F32x4_KERNEL(2048) \
+ break; \
+ case 4096: \
+ LANUCH_LAYER_NORM_F32x4_KERNEL(4096) \
+ break; \
default: \
throw std::runtime_error( \
- "only support K: 64/128/256/512/1024"); \
+ "only support K: 64/128/.../1024*4"); \
break; \
}

@@ -433,7 +490,7 @@ layer_norm_f16_f16_kernel<(K)><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F16F16_KERNEL(N, K) \
dim3 block((K)); \
- dim3 grid((N)); \
+ dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -465,7 +522,7 @@ layer_norm_f16_f32_kernel<(K)><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F16F32_KERNEL(N, K) \
dim3 block((K)); \
- dim3 grid((N)); \
+ dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -497,7 +554,7 @@ layer_norm_f16x2_f16_kernel<(K)/2><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F16x2F16_KERNEL(N, K) \
dim3 block((K)/2); \
- dim3 grid((N)); \
+ dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -515,9 +572,12 @@ layer_norm_f16x2_f16_kernel<(K)/2><<<grid, block>>>( \
case 1024: \
LANUCH_LAYER_NORM_F16x2F16_KERNEL(1024) \
break; \
+ case 2048: \
+ LANUCH_LAYER_NORM_F16x2F16_KERNEL(2048) \
+ break; \
default: \
throw std::runtime_error( \
- "only support K: 64/128/256/512/1024"); \
+ "only support K: 64/128/.../1024*2"); \
break; \
}

@@ -529,7 +589,7 @@ layer_norm_f16x8_f16_kernel<(K)/8><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F16x8F16_KERNEL(N, K) \
dim3 block((K)/8); \
- dim3 grid((N)); \
+ dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -547,12 +607,62 @@ layer_norm_f16x8_f16_kernel<(K)/8><<<grid, block>>>( \
case 1024: \
LANUCH_LAYER_NORM_F16x8F16_KERNEL(1024) \
break; \
+ case 2048: \
+ LANUCH_LAYER_NORM_F16x8F16_KERNEL(2048) \
+ break; \
+ case 4096: \
+ LANUCH_LAYER_NORM_F16x8F16_KERNEL(4096) \
+ break; \
+ case 8192: \
+ LANUCH_LAYER_NORM_F16x8F16_KERNEL(8192) \
+ break; \
default: \
throw std::runtime_error( \
- "only support K: 64/128/256/512/1024"); \
+ "only support K: 64/128/.../1024*8"); \
break; \
}

+ #define LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(K) \
+ layer_norm_f16x8_pack_f16_kernel<(K)/8><<<grid, block>>>( \
+ reinterpret_cast<half*>(x.data_ptr()), \
+ reinterpret_cast<half*>(y.data_ptr()), \
+ g, b, N, (K));
+
+ #define DISPATCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(N, K) \
+ dim3 block((K)/8); \
+ dim3 grid((N)); \
+ switch ((K)) \
+ { \
+ case 64: \
+ LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(64) \
+ break; \
+ case 128: \
+ LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(128) \
+ break; \
+ case 256: \
+ LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(256) \
+ break; \
+ case 512: \
+ LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(512) \
+ break; \
+ case 1024: \
+ LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(1024) \
+ break; \
+ case 2048: \
+ LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(2048) \
+ break; \
+ case 4096: \
+ LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(4096) \
+ break; \
+ case 8192: \
+ LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(8192) \
+ break; \
+ default: \
+ throw std::runtime_error( \
+ "only support K: 64/128/.../1024*8"); \
+ break; \
+ }
+
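For reference, with K = 1024 the new dispatch macro above expands to roughly the following launch: one block per row, one thread per eight half values.

dim3 block(1024 / 8);  // 128 threads
dim3 grid(N);          // one block per row
layer_norm_f16x8_pack_f16_kernel<128><<<grid, block>>>(
    reinterpret_cast<half*>(x.data_ptr()),
    reinterpret_cast<half*>(y.data_ptr()),
    g, b, N, 1024);
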
void layer_norm_f16_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
@@ -580,6 +690,16 @@ void layer_norm_f16x8_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
  DISPATCH_LAYER_NORM_F16x8F16_KERNEL(N, K)
}

+ void layer_norm_f16x8_pack_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
+   CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
+   CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
+   CHECK_TORCH_TENSOR_SHAPE(x, y)
+   const int N = x.size(0);
+   const int K = x.size(1);
+   DISPATCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(N, K)
+ }
+
void layer_norm_f16_f32(torch::Tensor x, torch::Tensor y, float g, float b) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
@@ -595,6 +715,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16_f16)
  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x2_f16)
  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x8_f16)
+   TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x8_pack_f16)
  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16_f32)
}