@@ -633,14 +633,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
633633        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true );
634634        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,          argsort_f32_i32_desc,           true );
635635        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                leaky_relu_f32,                 true );
636-         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         true );
637-         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         true );
638-         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         true );
639-         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        true );
640-         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        true );
641-         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        true );
642-         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    true );
643-         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    true );
636+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         ctx-> support_simdgroup_mm );
637+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         ctx-> support_simdgroup_mm );
638+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         ctx-> support_simdgroup_mm );
639+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        ctx-> support_simdgroup_mm );
640+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        ctx-> support_simdgroup_mm );
641+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        ctx-> support_simdgroup_mm );
642+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    ctx-> support_simdgroup_reduction );
643+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    ctx-> support_simdgroup_reduction );
644644        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true );
645645        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true );
646646        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true );
@@ -772,8 +772,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
772772        case  GGML_OP_TIMESTEP_EMBEDDING:
773773        case  GGML_OP_ARGSORT:
774774        case  GGML_OP_LEAKY_RELU:
775-         case  GGML_OP_FLASH_ATTN_EXT:
776775            return  true ;
776+         case  GGML_OP_FLASH_ATTN_EXT:
777+             return  ctx->support_simdgroup_mm ; //  TODO: over-restricted for vec-kernels
777778        case  GGML_OP_MUL_MAT:
778779        case  GGML_OP_MUL_MAT_ID:
779780            return  ctx->support_simdgroup_reduction  &&
0 commit comments