@@ -104,6 +104,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
104104 explicit FP16Vec16 (bool , void * ptr)
105105 : reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
106106
107+ explicit FP16Vec16 (const c10::Half v) : reg(_mm256_set1_epi16(v.x)) {}
108+
107109 explicit FP16Vec16 (const FP32Vec16&);
108110
109111 void save (void * ptr) const { _mm256_storeu_si256 ((__m256i*)ptr, reg); }
@@ -141,6 +143,8 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
141143 explicit BF16Vec16 (bool , void * ptr)
142144 : reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
143145
146+ explicit BF16Vec16 (const c10::BFloat16 v) : reg(_mm256_set1_epi16(v.x)) {}
147+
144148 explicit BF16Vec16 (const FP32Vec16&);
145149
146150 void save (void * ptr) const { _mm256_storeu_si256 ((__m256i*)ptr, reg); }
@@ -350,6 +354,22 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
350354
351355 explicit FP32Vec16 (__m512 data) : reg(data) {}
352356
357+ // de-pack 4 bit values
358+ explicit FP32Vec16 (int64_t value, const FP32Vec16& lut) {
359+ int64_t mask_0 = 0x0F0F0F0F0F0F0F0F ;
360+ int64_t mask_1 = 0xF0F0F0F0F0F0F0F0 ;
361+ int64_t value_0 = value & mask_0;
362+ int64_t value_1 = value & mask_1;
363+ __m128i vec_0 = _mm_movpi64_epi64 ((__m64)value_0);
364+ __m128i vec_1 = _mm_movpi64_epi64 ((__m64)value_1);
365+ vec_0 = _mm_cvtepu8_epi16 (vec_0);
366+ vec_1 = _mm_cvtepu8_epi16 (vec_1);
367+ vec_1 = _mm_slli_epi16 (vec_1, 4 );
368+ __m128i vec = _mm_or_si128 (vec_0, vec_1);
369+ __m512i vec_i32 = _mm512_cvtepu8_epi32 (vec);
370+ reg = _mm512_permutexvar_ps (vec_i32, lut.reg );
371+ }
372+
353373 explicit FP32Vec16 (const FP32Vec4& data)
354374 : reg((__m512)_mm512_inserti32x4(
355375 _mm512_inserti32x4 (
@@ -426,14 +446,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
426446
427447 float get_last_elem () const { return _mm512_cvtss_f32 (reg); }
428448
429- template <int group_size>
430- float reduce_sub_sum (int idx) {
431- static_assert (VEC_ELEM_NUM % group_size == 0 );
432- constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
433- __mmask16 mask = _cvtu32_mask16 (base_mask << (idx * group_size));
434- return _mm512_mask_reduce_add_ps (mask, reg);
435- }
436-
437449 void save (float * ptr) const { _mm512_storeu_ps (ptr, reg); }
438450
439451 void save (float * ptr, const int elem_num) const {
@@ -755,6 +767,25 @@ inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
755767inline void non_temporal_save (FP32Vec16& vec, void * ptr) {
756768 _mm512_stream_ps ((float *)ptr, vec.reg );
757769}
770+
771+ static void interleave_save (const BF16Vec16& vec0, const BF16Vec16& vec1,
772+ void * ptr) {
773+ __m512i vec_0 = _mm512_cvtepu16_epi32 (vec0.reg );
774+ __m512i vec_1 = _mm512_cvtepu16_epi32 (vec1.reg );
775+ vec_1 = _mm512_slli_epi32 (vec_1, 16 );
776+ vec_0 = _mm512_or_si512 (vec_0, vec_1);
777+ _mm512_storeu_epi32 (ptr, vec_0);
778+ }
779+
780+ static void interleave_save (const FP16Vec16& vec0, const FP16Vec16& vec1,
781+ void * ptr) {
782+ __m512i vec_0 = _mm512_cvtepu16_epi32 (vec0.reg );
783+ __m512i vec_1 = _mm512_cvtepu16_epi32 (vec1.reg );
784+ vec_1 = _mm512_slli_epi32 (vec_1, 16 );
785+ vec_0 = _mm512_or_si512 (vec_0, vec_1);
786+ _mm512_storeu_epi32 (ptr, vec_0);
787+ }
788+
758789#endif
759790
// Issues a full memory fence via the MFENCE intrinsic.
inline void mem_barrier() {
  _mm_mfence();
}
0 commit comments