@@ -493,9 +493,117 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
#endif
}

+ __global__ void __launch_bounds__(64) dequantize_weights(
+     int* __restrict__ B,
+     half* __restrict__ scaling_factors,
+     int* __restrict__ zeros,
+     half* __restrict__ C,
+     int G
+ )
+ {
+   int j_factors1 = 4;
+   int row_stride2 = 4;
+   int split_k_iters = 1;
+   static constexpr uint32_t ZERO = 0x0;
+   half B_shared[32 * (128 + 8)];
+
+   half* B_shared_ptr2 = B_shared;
+
+   half B_shared_warp[32];
+   int OC = 512;
+
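+   // Each thread dequantizes one packed int32 (eight 4-bit weights) into eight
+   // fp16 outputs; N is the number of packed columns covered by the launch.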
+   int N = blockDim.x * gridDim.x;  // 2
+   int col = (blockIdx.x * blockDim.x + threadIdx.x);
+   int row = blockIdx.y * blockDim.y + threadIdx.y;
+   int index1 = 8 * col + 8 * row * N;
+   half* C_ptr2 = C + index1;
+
+   int index2 = col + row * N;
+   int* B_ptr2 = B + index2;
+
+   int index3 = col + (int)(row / G) * N;
+   int* zeros_ptr2 = zeros + index3;
+   int index4 = 8 * col + (int)(row / G) * N * 8;
+   half* scaling_factors_ptr2 = scaling_factors + index4;
+
+
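+   // Zeros and scales are shared by every group of G consecutive rows; the
+   // packed 4-bit zero points are expanded to eight fp16 lanes up front.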
+   uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
+   uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+   uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
+   int j=0;
+
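+   // Unpack the eight 4-bit weights to fp16, then compute (w - zero) * scale
+   // two fp16 lanes at a time with sub.f16x2 / fma.rn.f16x2 (the addend is 0).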
+   uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
+   uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+
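+   // Stage the eight dequantized values in the per-thread buffer, then copy
+   // them out to the corresponding slice of C.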
+   *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;
+
+   for (int i=0; i<8; ++i) {
+     *(C_ptr2 + i) = B_shared[i];
+   }
+ }
+
} // namespace awq
} // namespace vllm

+ torch::Tensor awq_dequantize(
+     torch::Tensor _kernel,
+     torch::Tensor _scaling_factors,
+     torch::Tensor _zeros,
+     int split_k_iters,
+     int thx,
+     int thy)
+ {
+   int in_c = _kernel.size(0);
+   int qout_c = _kernel.size(1);
+   int out_c = qout_c * 8;
+   int G = in_c / _scaling_factors.size(0);
+
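+   // thx/thy choose the launch shape; 0 means "derive it from the tensor":
+   // a single block spanning that dimension, or a (qout_c/8 x in_c/8) grid of
+   // 8x8 blocks when both are 0.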
+   int x_thread = thx;
+   int y_thread = thy;
+
+   int x_blocks = 1;
+   int y_blocks = 1;
+   if (thx==0) {
+     x_thread = qout_c;
+   }
+   if (thy==0) {
+     y_thread = in_c;
+   }
+   if (thx==0 && thy==0) {
+     x_thread = 8;
+     y_thread = 8;
+     x_blocks = (int)(qout_c / 8);
+     y_blocks = (int)(in_c / 8);
+   }
+
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
+
+   auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
+   at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
+
+   auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+   auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
+   auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+   auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+
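+   // With the default thx == thy == 0, the launch provides one thread per
+   // packed int32 (qout_c x in_c) and the kernel fills the in_c x out_c result.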
+   dim3 num_blocks(x_blocks, y_blocks);
+   dim3 threads_per_block(x_thread, y_thread);
+
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
+     kernel, scaling_factors, zeros, de_kernel, G);
+
+   return _de_kernel;
+ }
+
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]