diff --git a/exllamav2/exllamav2_ext/cuda/layer_norm.cu b/exllamav2/exllamav2_ext/cuda/layer_norm.cu index b286492e..a2ffa637 100644 --- a/exllamav2/exllamav2_ext/cuda/layer_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/layer_norm.cu @@ -3,8 +3,12 @@ #include "compat.cuh" #if defined(USE_ROCM) -#define NUM_WARPS (1024 / warpSize) -#define WARP_SIZE (warpSize) +#if defined(__GFX8__) || defined(__GFX9__) + #define WARP_SIZE 64 +#else + #define WARP_SIZE 32 +#endif +#define NUM_WARPS (1024 / WARP_SIZE) #else #define NUM_WARPS 32 #define WARP_SIZE 32 @@ -230,4 +234,3 @@ void layer_norm_cuda_update_y { graph->update_param_ptr(label, 0, 3, y); } - diff --git a/exllamav2/exllamav2_ext/cuda/rms_norm.cu b/exllamav2/exllamav2_ext/cuda/rms_norm.cu index 94155ade..e6aa6b5d 100644 --- a/exllamav2/exllamav2_ext/cuda/rms_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/rms_norm.cu @@ -3,8 +3,12 @@ #include "compat.cuh" #if defined(USE_ROCM) -#define NUM_WARPS (1024 / warpSize) -#define WARP_SIZE (warpSize) +#if defined(__GFX8__) || defined(__GFX9__) + #define WARP_SIZE 64 +#else + #define WARP_SIZE 32 +#endif +#define NUM_WARPS (1024 / WARP_SIZE) #else #define NUM_WARPS 32 #define WARP_SIZE 32