Skip to content

Commit c3b2849

Browse files
efikspytorchmergebot
authored andcommitted
[caffe2] Add AVX512 support for box_cox operator (pytorch#143627)
Summary: Reuse templetized implementation of box_cox caffe2 operator. * Duplicate .cc file of AVX2 * change intrinsics functions to use AVX512 instructions * override templates * extend the caller to use new methods * guard AVX512 with a gflag to allow smooth transition Differential Revision: D67433457 Pull Request resolved: pytorch#143627 Approved by: https://github.com/hl475
1 parent bf7747e commit c3b2849

File tree

2 files changed

+119
-1
lines changed

2 files changed

+119
-1
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#ifdef CAFFE2_PERF_USE_MKL
2+
#include <immintrin.h>
3+
4+
// Enable compiler vectorized version only if numerical consistency is not
5+
// required between dev and opt versions - disabled for now
6+
#ifndef FAST_VECTORIZED_KERNEL
7+
#define CPU_CAPABILITY_AVX512
8+
#include <ATen/cpu/vec/vec.h>
9+
10+
namespace at::vec {
11+
namespace {
12+
// Implements the vectorized version of std::max() operation,
13+
// which DOESNOT propagates NaN for second argument
14+
template <typename scalar_t>
15+
Vectorized<scalar_t> max(const Vectorized<scalar_t>& a, const Vectorized<scalar_t>& b);
16+
17+
template <>
18+
Vectorized<double> max(const Vectorized<double>& a, const Vectorized<double>& b) {
19+
// std::max(NaN, nonNan) -> NaN
20+
return _mm512_max_pd(b, a);
21+
}
22+
23+
template <>
24+
Vectorized<float> max(const Vectorized<float>& a, const Vectorized<float>& b) {
25+
// std::max(NaN, nonNan) -> NaN
26+
return _mm512_max_ps(b, a);
27+
}
28+
29+
// Implements recieprocal method based on newton-rapson method
30+
// 1. user RCP approximiation
31+
// 2. update with RCP = RCP * (2 - X * RCP)
32+
template <typename scalar_t>
33+
Vectorized<scalar_t> fast_recieprocal(const Vectorized<scalar_t>& b);
34+
template <typename scalar_t>
35+
scalar_t fast_recieprocal(scalar_t b);
36+
37+
template<>
38+
Vectorized<float> fast_recieprocal(const Vectorized<float>& b) {
39+
auto minus2 = _mm512_set1_ps(-2.f);
40+
auto rcp = _mm512_rcp14_ps(b);
41+
rcp = _mm512_mul_ps(rcp, _mm512_fnmsub_ps(rcp, b, minus2));
42+
rcp = _mm512_mul_ps(rcp, _mm512_fnmsub_ps(rcp, b, minus2));
43+
return rcp;
44+
}
45+
46+
template <>
47+
float fast_recieprocal(float b) {
48+
auto minus2 = _mm_set_ss(-2.f);
49+
auto b_reg = _mm_set_ss(b);
50+
auto rcp = _mm_rcp_ss(b_reg);
51+
rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2));
52+
rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2));
53+
return _mm_cvtss_f32(rcp);
54+
}
55+
56+
template<>
57+
Vectorized<double> fast_recieprocal(const Vectorized<double>& b) {
58+
auto minus2 = _mm512_set1_pd(-2.);
59+
auto rcp = _mm512_rcp14_pd(b);
60+
rcp = _mm512_mul_pd(rcp, _mm512_fnmsub_pd(rcp, b, minus2));
61+
rcp = _mm512_mul_pd(rcp, _mm512_fnmsub_pd(rcp, b, minus2));
62+
return rcp;
63+
}
64+
65+
template <>
66+
double fast_recieprocal(double b) {
67+
return 1./b;
68+
}
69+
} // namespace
70+
} // namespace at::vec
71+
#endif
72+
73+
#include "caffe2/perfkernels/batch_box_cox_vec.h"
74+
75+
namespace caffe2::details {
76+
77+
template <typename T>
78+
void compute_batch_box_cox__avx512(
79+
std::size_t N,
80+
std::size_t D,
81+
std::size_t block_size,
82+
const T* self_data,
83+
const T* __restrict lambda1_data,
84+
const T* __restrict lambda2_data,
85+
T* output_data) {
86+
compute_batch_box_cox_vec_fma<T>(
87+
N,
88+
D,
89+
block_size,
90+
self_data,
91+
lambda1_data,
92+
lambda2_data,
93+
output_data);
94+
}
95+
96+
// Vectorized version specializations for float and double
97+
template
98+
void compute_batch_box_cox__avx512<float>(
99+
std::size_t N,
100+
std::size_t D,
101+
std::size_t block_size,
102+
const float* self_data,
103+
const float* __restrict lambda1_data,
104+
const float* __restrict lambda2_data,
105+
float* output_data);
106+
107+
template
108+
void compute_batch_box_cox__avx512<double>(
109+
std::size_t N,
110+
std::size_t D,
111+
std::size_t block_size,
112+
const double* self_data,
113+
const double* __restrict lambda1_data,
114+
const double* __restrict lambda2_data,
115+
double* output_data);
116+
117+
} // namespace caffe2::detail
118+
#endif // CAFFE2_PERF_USE_MKL

caffe2/perfkernels/batch_box_cox_vec.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ void TileIndicesInPlace(std::vector<int>& v, const std::size_t D, const std::siz
2121
}
2222
}
2323
}
24-
} // namespace
2524

2625
// MKL VML function templates.
2726
template <typename T>
@@ -307,5 +306,6 @@ void compute_batch_box_cox_vec_fma(
307306
}
308307
}
309308
}
309+
} // namespace
310310

311311
} // namespace caffe2::details

0 commit comments

Comments
 (0)