Skip to content

Commit d072254

Browse files
Ryo-not-rioaditew01
authored andcommitted
Extend vec backend with BF16 SVE intrinsics (pytorch#143666)
- Following the work in pytorch#119571, BF16 SVE intrinsics are added to the Vectorized class, providing ~1.7x speedup on `silu` and `softmax`. - Added bf16 detection in CMake - Added a guard for native NEON code to prevent compilation errors @aditew01 @maajidkhann please have a look Pull Request resolved: pytorch#143666 Approved by: https://github.com/swolchok, https://github.com/aditew01 Co-authored-by: Aditya Tewari <[email protected]>
1 parent 68dfd44 commit d072254

File tree

15 files changed

+731
-43
lines changed

15 files changed

+731
-43
lines changed

aten/src/ATen/Version.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ std::string get_cpu_capability() {
105105
return "DEFAULT";
106106
case native::CPUCapability::ZVECTOR:
107107
return "Z VECTOR";
108+
#elif defined(HAVE_SVE256_BF16_CPU_DEFINITION)
109+
case native::CPUCapability::DEFAULT:
110+
return "DEFAULT";
111+
case native::CPUCapability::SVE256_BF16:
112+
return "SVE256_BF16";
108113
#elif defined(HAVE_SVE_CPU_DEFINITION)
109114
case native::CPUCapability::DEFAULT:
110115
return "DEFAULT";

aten/src/ATen/cpu/vec/sve/sve_helper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH
1717
typedef svuint32_t vls_uint32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
1818
typedef svuint64_t vls_uint64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
1919
typedef svfloat16_t vls_float16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
20+
typedef svbfloat16_t vls_bfloat16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
2021
typedef svfloat32_t vls_float32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
2122
typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
2223

@@ -41,6 +42,7 @@ typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDT
4142
#define ONE_U32 svdup_n_u32(1)
4243
#define ONE_U64 svdup_n_u64(1)
4344
#define ONE_F16 svdup_n_f16(1.f)
45+
#define ONE_BF16 svdup_n_bf16(1.f)
4446
#define ONE_F32 svdup_n_f32(1.f)
4547
#define ONE_F64 svdup_n_f64(1.0)
4648
#define ALL_S8_TRUE_MASK svdup_n_s8(0xff)
@@ -55,6 +57,8 @@ typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDT
5557
#define ALL_U8_FALSE_MASK svdup_n_u8(0x00)
5658
#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK)
5759
#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK)
60+
#define ALL_BF16_TRUE_MASK svreinterpret_bf16_s16(ALL_S16_TRUE_MASK)
61+
#define ALL_BF16_FALSE_MASK svreinterpret_bf16_s16(ALL_S16_FALSE_MASK)
5862
#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK)
5963
#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK)
6064
#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK)

0 commit comments

Comments
 (0)