|
| 1 | +/* |
| 2 | +Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama |
| 3 | +*/ |
| 4 | + |
| 5 | +#ifndef _matrix_view_cuh |
| 6 | +#define _matrix_view_cuh |
| 7 | + |
| 8 | +#include <cuda_runtime.h> |
| 9 | +#include <cuda_fp16.h> |
| 10 | + |
| 11 | +#include "qdq_util.cuh" |
| 12 | + |
| 13 | +namespace vllm { |
| 14 | +namespace gptq { |
| 15 | + |
| 16 | +class MatrixView_half |
| 17 | +{ |
| 18 | +public: |
| 19 | + const half* data; |
| 20 | + const int height; |
| 21 | + const int width; |
| 22 | + |
| 23 | + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) |
| 24 | + : data(data), height(height), width(width) |
| 25 | + { } |
| 26 | + |
| 27 | + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } |
| 28 | + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } |
| 29 | + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } |
| 30 | + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } |
| 31 | + |
| 32 | + __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const |
| 33 | + { |
| 34 | + half2* ptr = (half2*) item_ptr(row, column); |
| 35 | + half2 i01 = ptr[0]; |
| 36 | + half2 i23 = ptr[1]; |
| 37 | + items[0] = __low2half(i01); |
| 38 | + items[1] = __high2half(i01); |
| 39 | + items[2] = __low2half(i23); |
| 40 | + items[3] = __high2half(i23); |
| 41 | + } |
| 42 | + __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const |
| 43 | + { |
| 44 | + half2* ptr = (half2*)item_ptr(row, column); |
| 45 | + half2 i01 = ptr[0]; |
| 46 | + half2 i23 = ptr[1]; |
| 47 | + items[0] = __half2float(__low2half(i01)); |
| 48 | + items[1] = __half2float(__high2half(i01)); |
| 49 | + items[2] = __half2float(__low2half(i23)); |
| 50 | + items[3] = __half2float(__high2half(i23)); |
| 51 | + } |
| 52 | + |
| 53 | + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const |
| 54 | + { |
| 55 | + half2* ptr = (half2*)item_ptr(row, column); |
| 56 | + half2 i01 = ptr[0]; |
| 57 | + half2 i23 = ptr[1]; |
| 58 | + items[0] = __half2half2(__low2half(i01)); |
| 59 | + items[1] = __half2half2(__high2half(i01)); |
| 60 | + items[2] = __half2half2(__low2half(i23)); |
| 61 | + items[3] = __half2half2(__high2half(i23)); |
| 62 | + } |
| 63 | +}; |
| 64 | + |
| 65 | +class MatrixView_half_rw |
| 66 | +{ |
| 67 | +public: |
| 68 | + half* data; |
| 69 | + const int height; |
| 70 | + const int width; |
| 71 | + |
| 72 | + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) |
| 73 | + : data(data), height(height), width(width) |
| 74 | + { } |
| 75 | + |
| 76 | + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } |
| 77 | + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } |
| 78 | + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } |
| 79 | + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } |
| 80 | + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } |
| 81 | + |
| 82 | + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) |
| 83 | + { |
| 84 | + half2 v01 = __halves2half2(v0, v1); |
| 85 | + half2 v23 = __halves2half2(v2, v3); |
| 86 | + half2* ptr = (half2*) item_ptr(row, column); |
| 87 | + ptr[0] = v01; |
| 88 | + ptr[1] = v23; |
| 89 | + } |
| 90 | +}; |
| 91 | + |
| 92 | +class MatrixView_q4_row |
| 93 | +{ |
| 94 | +public: |
| 95 | + const uint32_t* data; |
| 96 | + const int height; |
| 97 | + const int width; |
| 98 | + |
| 99 | + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) |
| 100 | + : data(data), height(height), width(width) |
| 101 | + { } |
| 102 | + |
| 103 | + __device__ __forceinline__ int item(int row, int column) const |
| 104 | + { |
| 105 | + int shift = (column & 0x07) * 4; |
| 106 | + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; |
| 107 | + } |
| 108 | + |
| 109 | + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const |
| 110 | + { |
| 111 | + int shift = (column & 0x07) * 4; |
| 112 | + uint32_t d = data[row * width / 8 + column / 8] >> shift; |
| 113 | + items[0] = d & 0x0f; |
| 114 | + items[1] = (d >> 4) & 0x0f; |
| 115 | + } |
| 116 | + |
| 117 | + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const |
| 118 | + { |
| 119 | + int shift = (column & 0x07) * 4; |
| 120 | + uint32_t d = data[row * width / 8 + column / 8] >> shift; |
| 121 | + items[0] = d & 0x0f; |
| 122 | + items[1] = (d >> 4) & 0x0f; |
| 123 | + items[2] = (d >> 8) & 0x0f; |
| 124 | + items[3] = (d >> 12) & 0x0f; |
| 125 | + } |
| 126 | +}; |
| 127 | + |
| 128 | +class MatrixView_q4_column |
| 129 | +{ |
| 130 | +public: |
| 131 | + const uint32_t* data; |
| 132 | + const int height; |
| 133 | + const int width; |
| 134 | + |
| 135 | + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) |
| 136 | + : data(data), height(height), width(width) |
| 137 | + { } |
| 138 | + |
| 139 | + __device__ __forceinline__ int item(int row, int column) const |
| 140 | + { |
| 141 | + int shift = (row & 0x07) * 4; |
| 142 | + return (data[row / 8 * width + column] >> shift) & 0x0f; |
| 143 | + } |
| 144 | + |
| 145 | + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } |
| 146 | + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } |
| 147 | +}; |
| 148 | + |
| 149 | +} // namespace gptq |
| 150 | +} // namespace vllm |
| 151 | +#endif |
0 commit comments