@@ -208,18 +208,93 @@ impl DebugHexF16 for __m512i {
208
208
}
209
209
"# ;
210
210
211
- pub const LANE_FUNCTION_HELPERS : & str = r#"
212
- typedef _Float16 float16_t;
213
- typedef float float32_t;
214
- typedef double float64_t;
215
-
216
- #define __int64 long long
217
- #define __int32 int
211
+ pub const PLATFORM_C_FORWARD_DECLARATIONS : & str = r#"
212
+ #ifndef X86_DECLARATIONS
213
+ #define X86_DECLARATIONS
214
+ typedef _Float16 float16_t;
215
+ typedef float float32_t;
216
+ typedef double float64_t;
217
+
218
+ #define __int64 long long
219
+ #define __int32 int
218
220
219
- std::ostream& operator<<(std::ostream& os, _Float16 value);
220
- std::ostream& operator<<(std::ostream& os, __m128i value);
221
- std::ostream& operator<<(std::ostream& os, __m256i value);
222
- std::ostream& operator<<(std::ostream& os, __m512i value);
221
+ std::ostream& operator<<(std::ostream& os, _Float16 value);
222
+ std::ostream& operator<<(std::ostream& os, __m128i value);
223
+ std::ostream& operator<<(std::ostream& os, __m256i value);
224
+ std::ostream& operator<<(std::ostream& os, __m512i value);
225
+
226
+ #define _mm512_extract_intrinsic_test_epi8(m, lane) \
227
+ _mm_extract_epi8(_mm512_extracti64x2_epi64((m), (lane) / 16), (lane) % 16)
228
+
229
+ #define _mm512_extract_intrinsic_test_epi16(m, lane) \
230
+ _mm_extract_epi16(_mm512_extracti64x2_epi64((m), (lane) / 8), (lane) % 8)
231
+
232
+ #define _mm512_extract_intrinsic_test_epi32(m, lane) \
233
+ _mm_extract_epi32(_mm512_extracti64x2_epi64((m), (lane) / 4), (lane) % 4)
234
+
235
+ #define _mm512_extract_intrinsic_test_epi64(m, lane) \
236
+ _mm_extract_epi64(_mm512_extracti64x2_epi64((m), (lane) / 2), (lane) % 2)
237
+
238
+ #define _mm64_extract_intrinsic_test_epi8(m, lane) \
239
+ ((_mm_extract_pi16((m), (lane) / 2) >> (((lane) % 2) * 8)) & 0xFF)
240
+
241
+ #define _mm64_extract_intrinsic_test_epi32(m, lane) \
242
+ _mm_cvtsi64_si32(_mm_srli_si64(m, (lane) * 32))
243
+
244
+ // Load f16 (__m128h) and cast to integer (__m128i)
245
+ #define _mm_loadu_ph_to___m128i(mem_addr) _mm_castph_si128(_mm_loadu_ph(mem_addr))
246
+ #define _mm256_loadu_ph_to___m256i(mem_addr) _mm256_castph_si256(_mm256_loadu_ph(mem_addr))
247
+ #define _mm512_loadu_ph_to___m512i(mem_addr) _mm512_castph_si512(_mm512_loadu_ph(mem_addr))
248
+
249
+ // Load f32 (__m128) and cast to f16 (__m128h)
250
+ #define _mm_loadu_ps_to___m128h(mem_addr) _mm_castps_ph(_mm_loadu_ps(mem_addr))
251
+ #define _mm256_loadu_ps_to___m256h(mem_addr) _mm256_castps_ph(_mm256_loadu_ps(mem_addr))
252
+ #define _mm512_loadu_ps_to___m512h(mem_addr) _mm512_castps_ph(_mm512_loadu_ps(mem_addr))
253
+
254
+ // Load integer types and cast to double (__m128d, __m256d, __m512d)
255
+ #define _mm_loadu_epi16_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
256
+ #define _mm256_loadu_epi16_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
257
+ #define _mm512_loadu_epi16_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
258
+
259
+ #define _mm_loadu_epi32_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
260
+ #define _mm256_loadu_epi32_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
261
+ #define _mm512_loadu_epi32_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
262
+
263
+ #define _mm_loadu_epi64_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
264
+ #define _mm256_loadu_epi64_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
265
+ #define _mm512_loadu_epi64_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
266
+
267
+ // Load integer types and cast to float (__m128, __m256, __m512)
268
+ #define _mm_loadu_epi16_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
269
+ #define _mm256_loadu_epi16_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
270
+ #define _mm512_loadu_epi16_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
271
+
272
+ #define _mm_loadu_epi32_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
273
+ #define _mm256_loadu_epi32_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
274
+ #define _mm512_loadu_epi32_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
275
+
276
+ #define _mm_loadu_epi64_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
277
+ #define _mm256_loadu_epi64_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
278
+ #define _mm512_loadu_epi64_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
279
+
280
+
281
+ // T1 is the `To` type, T2 is the `From` type
282
+ template<typename T1, typename T2> T1 cast(T2 x) {
283
+ if constexpr (std::is_convertible_v<T2, T1>) {
284
+ return x;
285
+ } else if constexpr (sizeof(T1) == sizeof(T2)) {
286
+ T1 ret{};
287
+ std::memcpy(&ret, &x, sizeof(T1));
288
+ return ret;
289
+ } else {
290
+ static_assert(sizeof(T1) == sizeof(T2) || std::is_convertible_v<T2, T1>,
291
+ "T2 must either be convertible to T1, or have the same size as T1!");
292
+ return T1{};
293
+ }
294
+ }
295
+ #endif
296
+ "# ;
297
+ pub const PLATFORM_C_DEFINITIONS : & str = r#"
223
298
224
299
std::ostream& operator<<(std::ostream& os, _Float16 value) {
225
300
uint16_t temp = 0;
@@ -268,74 +343,6 @@ std::ostream& operator<<(std::ostream& os, __m512i value) {
268
343
os << ss.str();
269
344
return os;
270
345
}
271
-
272
- // T1 is the `To` type, T2 is the `From` type
273
- template<typename T1, typename T2> T1 cast(T2 x) {
274
- if (std::is_convertible<T2, T1>::value) {
275
- return x;
276
- } else if (sizeof(T1) == sizeof(T2)) {
277
- T1 ret{};
278
- memcpy(&ret, &x, sizeof(T1));
279
- return ret;
280
- } else {
281
- assert("T2 must either be convertable to T1, or have the same size as T1!");
282
- }
283
- }
284
-
285
- #define _mm512_extract_intrinsic_test_epi8(m, lane) \
286
- _mm_extract_epi8(_mm512_extracti64x2_epi64((m), (lane) / 16), (lane) % 16)
287
-
288
- #define _mm512_extract_intrinsic_test_epi16(m, lane) \
289
- _mm_extract_epi16(_mm512_extracti64x2_epi64((m), (lane) / 8), (lane) % 8)
290
-
291
- #define _mm512_extract_intrinsic_test_epi32(m, lane) \
292
- _mm_extract_epi32(_mm512_extracti64x2_epi64((m), (lane) / 4), (lane) % 4)
293
-
294
- #define _mm512_extract_intrinsic_test_epi64(m, lane) \
295
- _mm_extract_epi64(_mm512_extracti64x2_epi64((m), (lane) / 2), (lane) % 2)
296
-
297
- #define _mm64_extract_intrinsic_test_epi8(m, lane) \
298
- ((_mm_extract_pi16((m), (lane) / 2) >> (((lane) % 2) * 8)) & 0xFF)
299
-
300
- #define _mm64_extract_intrinsic_test_epi32(m, lane) \
301
- _mm_cvtsi64_si32(_mm_srli_si64(m, (lane) * 32))
302
-
303
- // Load f16 (__m128h) and cast to integer (__m128i)
304
- #define _mm_loadu_ph_to___m128i(mem_addr) _mm_castph_si128(_mm_loadu_ph(mem_addr))
305
- #define _mm256_loadu_ph_to___m256i(mem_addr) _mm256_castph_si256(_mm256_loadu_ph(mem_addr))
306
- #define _mm512_loadu_ph_to___m512i(mem_addr) _mm512_castph_si512(_mm512_loadu_ph(mem_addr))
307
-
308
- // Load f32 (__m128) and cast to f16 (__m128h)
309
- #define _mm_loadu_ps_to___m128h(mem_addr) _mm_castps_ph(_mm_loadu_ps(mem_addr))
310
- #define _mm256_loadu_ps_to___m256h(mem_addr) _mm256_castps_ph(_mm256_loadu_ps(mem_addr))
311
- #define _mm512_loadu_ps_to___m512h(mem_addr) _mm512_castps_ph(_mm512_loadu_ps(mem_addr))
312
-
313
- // Load integer types and cast to double (__m128d, __m256d, __m512d)
314
- #define _mm_loadu_epi16_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
315
- #define _mm256_loadu_epi16_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
316
- #define _mm512_loadu_epi16_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
317
-
318
- #define _mm_loadu_epi32_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
319
- #define _mm256_loadu_epi32_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
320
- #define _mm512_loadu_epi32_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
321
-
322
- #define _mm_loadu_epi64_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
323
- #define _mm256_loadu_epi64_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
324
- #define _mm512_loadu_epi64_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
325
-
326
- // Load integer types and cast to float (__m128, __m256, __m512)
327
- #define _mm_loadu_epi16_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
328
- #define _mm256_loadu_epi16_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
329
- #define _mm512_loadu_epi16_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
330
-
331
- #define _mm_loadu_epi32_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
332
- #define _mm256_loadu_epi32_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
333
- #define _mm512_loadu_epi32_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
334
-
335
- #define _mm_loadu_epi64_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
336
- #define _mm256_loadu_epi64_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
337
- #define _mm512_loadu_epi64_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
338
-
339
346
"# ;
340
347
341
348
pub const X86_CONFIGURATIONS : & str = r#"
0 commit comments