@@ -234,6 +234,97 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif // GGML_CUDA_F16
 
+#if defined(GGML_USE_HIPBLAS)
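+// __CUDA_ARCH__ is an NVCC-only macro; a high dummy value lets arch-gated CUDA paths compile on AMD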
+#define __CUDA_ARCH__ 1300
+
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
+    defined(__gfx1150__) || defined(__gfx1151__)
+#define RDNA3
+#endif
+
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
+#define RDNA2
+#endif
+
+#ifndef __has_builtin
+#define __has_builtin(x) 0
+#endif
+
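+// clang ext_vector_type: view a 32-bit word as four packed 8-bit lanes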
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+    const int8x4_t va = reinterpret_cast<const int8x4_t &>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t &>(b);
+#if __has_builtin(__builtin_elementwise_sub_sat)
+    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+    return reinterpret_cast<const int &>(c);
+#else
+    int8x4_t c;
+    int16_t tmp;
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+        tmp = va[i] - vb[i];
+        if (tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+        if (tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+        c[i] = tmp;
+    }
+    return reinterpret_cast<int &>(c);
+#endif // __has_builtin(__builtin_elementwise_sub_sat)
+}
+
+static __device__ __forceinline__ int __vsub4(const int a, const int b) {
+    return __vsubss4(a, b);
+}
+
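+// emulates CUDA __vcmpeq4: per-byte equality compare, each result byte is 0xff on match, 0x00 otherwise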
+static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
+    const uint8x4_t & va = reinterpret_cast<const uint8x4_t &>(a);
+    const uint8x4_t & vb = reinterpret_cast<const uint8x4_t &>(b);
+    unsigned int c;
+    uint8x4_t & vc = reinterpret_cast<uint8x4_t &>(c);
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
+    }
+    return c;
+}
+
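+// emulates CUDA __dp4a (dot product of four int8 pairs, accumulated into an int32):
+// uses AMD dot-product builtins or inline asm where the target provides them, plain scalar math otherwise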
+static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
+    c = __builtin_amdgcn_sdot4(a, b, c, false);
+#elif defined(RDNA3)
+    c = __builtin_amdgcn_sudot4(true, a, true, b, c, false);
+#elif defined(__gfx1010__) || defined(__gfx900__)
+    int tmp1;
+    int tmp2;
+    asm("\n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        "
+        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
+        : "v"(a), "v"(b)
+    );
+#else
+    const int8x4_t va = reinterpret_cast<const int8x4_t &>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t &>(b);
+    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+    return c;
+}
+#endif // defined(GGML_USE_HIPBLAS)
+
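+// FP16 arithmetic is used on all supported AMD targets and on NVIDIA from Pascal on;
+// FP16 tensor core mma is NVIDIA-only and requires Volta or newer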
+#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+
+#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
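+// host-side counterpart of FP16_MMA_AVAILABLE: AMD devices report cc >= CC_OFFSET_AMD and never take the mma path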
+static bool fp16_mma_available(const int cc) {
+    return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
+}
+
 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@@ -275,16 +366,28 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }
 
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#if FP16_AVAILABLE
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
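+    // AMD path: shuffle the packed value, then accumulate the two fp16 lanes individually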
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-    }
-    return a;
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
+        reinterpret_cast<half&>(a.x) += __low2half(a_other);
+        reinterpret_cast<half&>(a.y) += __high2half(a_other);
+    }
+    return a;
 #else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
 }
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
@@ -296,37 +399,38 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 }
 
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if FP16_AVAILABLE
 
-#if CUDART_VERSION >= CUDART_HMAX
-    return __hmax(a, b);
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
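+    // __hmax is only known to work from CUDART_HMAX on; emulate it through float on older CUDA runtimes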
+    return __float2half(fmaxf(__half2float(a), __half2float(b)));
 #else
-    return __half2float(a) > __half2float(b) ? a : b;
-#endif // CUDART_VERSION >= CUDART_HMAX
+    return __hmax(a, b);
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
 
 #else
-    GGML_UNUSED(a);
-    GGML_UNUSED(b);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+    NO_DEVICE_CODE;
+    GGML_UNUSED(b);
+    return a;
+#endif // FP16_AVAILABLE
 }
+
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 
 #if CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
 #else
     half2 ret;
-    reinterpret_cast<half&>(ret.x) = __low2float(a) > __low2float(b) ? __low2half(a) : __low2half(b);
-    reinterpret_cast<half&>(ret.y) = __high2float(a) > __high2float(b) ? __high2half(a) : __high2half(b);
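+    // per-half maximum computed in float, then packed back into a half2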
+    reinterpret_cast<half&>(ret.x) = __float2half(fmaxf(__low2float(a), __low2float(b)));
+    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
 #endif // CUDART_VERSION >= CUDART_HMAX
 
 #else
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 }
 
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
@@ -350,94 +454,6 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
 }
 #endif // CUDART_VERSION < 12000
 
-#if defined(GGML_USE_HIPBLAS)
-#define __CUDA_ARCH__ 1300
-
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
-    defined(__gfx1150__) || defined(__gfx1151__)
-#define RDNA3
-#endif
-
-#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
-    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
-#define RDNA2
-#endif
-
-#ifndef __has_builtin
-#define __has_builtin(x) 0
-#endif
-
-typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
-typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
-static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
-    const int8x4_t va = reinterpret_cast<const int8x4_t &>(a);
-    const int8x4_t vb = reinterpret_cast<const int8x4_t &>(b);
-#if __has_builtin(__builtin_elementwise_sub_sat)
-    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
-    return reinterpret_cast<const int &>(c);
-#else
-    int8x4_t c;
-    int16_t tmp;
-#pragma unroll
-    for (int i = 0; i < 4; i++) {
-        tmp = va[i] - vb[i];
-        if (tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
-        if (tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
-        c[i] = tmp;
-    }
-    return reinterpret_cast<int &>(c);
-#endif // __has_builtin(__builtin_elementwise_sub_sat)
-}
-
-static __device__ __forceinline__ int __vsub4(const int a, const int b) {
-    return __vsubss4(a, b);
-}
-
-static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
-    const uint8x4_t & va = reinterpret_cast<const uint8x4_t &>(a);
-    const uint8x4_t & vb = reinterpret_cast<const uint8x4_t &>(b);
-    unsigned int c;
-    uint8x4_t & vc = reinterpret_cast<uint8x4_t &>(c);
-#pragma unroll
-    for (int i = 0; i < 4; ++i) {
-        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
-    }
-    return c;
-}
-
-static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
-#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
-    c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(RDNA3)
-    c = __builtin_amdgcn_sudot4(true, a, true, b, c, false);
-#elif defined(__gfx1010__) || defined(__gfx900__)
-    int tmp1;
-    int tmp2;
-    asm("\n \
-        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
-        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
-        v_add3_u32 %0, %1, %2, %0 \n \
-        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
-        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
-        v_add3_u32 %0, %1, %2, %0 \n \
-        "
-        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
-        : "v"(a), "v"(b)
-    );
-#else
-    const int8x4_t va = reinterpret_cast<const int8x4_t &>(a);
-    const int8x4_t vb = reinterpret_cast<const int8x4_t &>(b);
-    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
-#endif
-    return c;
-}
-#endif // defined(GGML_USE_HIPBLAS)
-
-#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
-    defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
-
-#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
-
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 