Skip to content

Commit 06bb848

Browse files
feat: added custom helper functions (that helped load intrinsic
arguments in Rust) to C++ testfiles. Also added extra compilation flags
1 parent ab00695 commit 06bb848

File tree

2 files changed

+46
-7
lines changed

2 files changed

+46
-7
lines changed

crates/intrinsic-test/src/x86/compile.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ pub fn build_cpp_compilation(config: &ProcessedCli) -> Option<CppCompilation> {
2424
"-mavx512dq",
2525
"-mavx512cd",
2626
"-mavx512fp16",
27+
"-msha512",
28+
"-msm4",
2729
"-ferror-limit=1000",
2830
]);
2931

crates/intrinsic-test/src/x86/config.rs

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -270,17 +270,17 @@ std::ostream& operator<<(std::ostream& os, __m512i value) {
270270
}
271271
272272
// T1 is the `To` type, T2 is the `From` type
273-
template<typename T1, typename T2> T1 cast(T2 x) {{
274-
if (std::is_convertible<T2, T1>::value) {{
273+
template<typename T1, typename T2> T1 cast(T2 x) {
274+
if (std::is_convertible<T2, T1>::value) {
275275
return x;
276-
}} else if (sizeof(T1) == sizeof(T2)) {{
277-
T1 ret{{}};
276+
} else if (sizeof(T1) == sizeof(T2)) {
277+
T1 ret{};
278278
memcpy(&ret, &x, sizeof(T1));
279279
return ret;
280-
}} else {{
280+
} else {
281281
assert("T2 must either be convertable to T1, or have the same size as T1!");
282-
}}
283-
}}
282+
}
283+
}
284284
285285
#define _mm512_extract_intrinsic_test_epi8(m, lane) \
286286
_mm_extract_epi8(_mm512_extracti64x2_epi64((m), (lane) / 16), (lane) % 16)
@@ -299,6 +299,43 @@ template<typename T1, typename T2> T1 cast(T2 x) {{
299299
300300
#define _mm64_extract_intrinsic_test_epi32(m, lane) \
301301
_mm_cvtsi64_si32(_mm_srli_si64(m, (lane) * 32))
302+
303+
// Load f16 (__m128h) and cast to integer (__m128i)
304+
#define _mm_loadu_ph_to___m128i(mem_addr) _mm_castph_si128(_mm_loadu_ph(mem_addr))
305+
#define _mm256_loadu_ph_to___m256i(mem_addr) _mm256_castph_si256(_mm256_loadu_ph(mem_addr))
306+
#define _mm512_loadu_ph_to___m512i(mem_addr) _mm512_castph_si512(_mm512_loadu_ph(mem_addr))
307+
308+
// Load f32 (__m128) and cast to f16 (__m128h)
309+
#define _mm_loadu_ps_to___m128h(mem_addr) _mm_castps_ph(_mm_loadu_ps(mem_addr))
310+
#define _mm256_loadu_ps_to___m256h(mem_addr) _mm256_castps_ph(_mm256_loadu_ps(mem_addr))
311+
#define _mm512_loadu_ps_to___m512h(mem_addr) _mm512_castps_ph(_mm512_loadu_ps(mem_addr))
312+
313+
// Load integer types and cast to double (__m128d, __m256d, __m512d)
314+
#define _mm_loadu_epi16_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
315+
#define _mm256_loadu_epi16_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
316+
#define _mm512_loadu_epi16_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
317+
318+
#define _mm_loadu_epi32_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
319+
#define _mm256_loadu_epi32_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
320+
#define _mm512_loadu_epi32_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
321+
322+
#define _mm_loadu_epi64_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
323+
#define _mm256_loadu_epi64_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
324+
#define _mm512_loadu_epi64_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
325+
326+
// Load integer types and cast to float (__m128, __m256, __m512)
327+
#define _mm_loadu_epi16_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
328+
#define _mm256_loadu_epi16_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
329+
#define _mm512_loadu_epi16_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
330+
331+
#define _mm_loadu_epi32_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
332+
#define _mm256_loadu_epi32_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
333+
#define _mm512_loadu_epi32_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
334+
335+
#define _mm_loadu_epi64_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
336+
#define _mm256_loadu_epi64_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
337+
#define _mm512_loadu_epi64_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
338+
302339
"#;
303340

304341
pub const X86_CONFIGURATIONS: &str = r#"

0 commit comments

Comments
 (0)