Skip to content

Commit 97dc3f6

Browse files
committed
Add amx-avx512
1 parent 8dcc435 commit 97dc3f6

File tree

2 files changed

+73
-1
lines changed

2 files changed

+73
-1
lines changed

crates/core_arch/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
f16,
3535
aarch64_unstable_target_feature,
3636
bigint_helper_methods,
37-
funnel_shifts
37+
funnel_shifts,
38+
avx10_target_feature
3839
)]
3940
#![cfg_attr(test, feature(test, abi_vectorcall, stdarch_internal))]
4041
#![deny(clippy::missing_inline_in_public_items)]

crates/core_arch/src/x86_64/amx.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::core_arch::{simd::*, x86::*};
2+
13
#[cfg(test)]
24
use stdarch_test::assert_instr;
35

@@ -380,6 +382,67 @@ pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
380382
tmmultf32ps(DST as i8, A as i8, B as i8);
381383
}
382384

385+
/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
386+
/// elements to packed single-precision (32-bit) floating-point elements.
///
/// `TILE` is a const generic that selects the source tile; it is checked at compile time by
/// `static_assert_uimm_bits!(TILE, 3)`. `row` is forwarded unchanged to the `tcvtrowd2ps`
/// binding (declared with link name `llvm.x86.tcvtrowd2ps`).
///
/// This function is compiled with the `amx-avx512` and `avx10.2` target features enabled.
///
/// NOTE(review): `row` is presumably the index of the tile row to read; the valid range is not
/// visible here -- confirm against the instruction reference.
387+
#[inline]
388+
#[rustc_legacy_const_generics(0)]
389+
#[target_feature(enable = "amx-avx512,avx10.2")]
390+
#[cfg_attr(
391+
all(test, any(target_os = "linux", target_env = "msvc")),
392+
assert_instr(tcvtrowd2ps, TILE = 0)
393+
)]
394+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
395+
pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
396+
// Compile-time check on `TILE` with a width argument of 3 bits.
static_assert_uimm_bits!(TILE, 3);
397+
// The binding takes the tile as `i8` and returns `f32x16`; `as_m512()` converts that to the
// public `__m512` return type.
tcvtrowd2ps(TILE as i8, row).as_m512()
398+
}
399+
400+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
401+
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
402+
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
///
/// `TILE` is a const generic that selects the source tile; it is checked at compile time by
/// `static_assert_uimm_bits!(TILE, 3)`. `row` is forwarded unchanged to the `tcvtrowps2phh`
/// binding (declared with link name `llvm.x86.tcvtrowps2phh`).
///
/// This function is compiled with the `amx-avx512` and `avx10.2` target features enabled.
///
/// NOTE(review): `row` is presumably the index of the tile row to read, and the contents of the
/// other 16 bits of each 32-bit element are not visible from this code -- confirm against the
/// instruction reference.
403+
#[inline]
404+
#[rustc_legacy_const_generics(0)]
405+
#[target_feature(enable = "amx-avx512,avx10.2")]
406+
#[cfg_attr(
407+
all(test, any(target_os = "linux", target_env = "msvc")),
408+
assert_instr(tcvtrowps2phh, TILE = 0)
409+
)]
410+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
411+
pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
412+
// Compile-time check on `TILE` with a width argument of 3 bits.
static_assert_uimm_bits!(TILE, 3);
413+
// The binding takes the tile as `i8` and returns `f16x32`; `as_m512h()` converts that to the
// public `__m512h` return type.
tcvtrowps2phh(TILE as i8, row).as_m512h()
414+
}
415+
416+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
417+
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
418+
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
///
/// `TILE` is a const generic that selects the source tile; it is checked at compile time by
/// `static_assert_uimm_bits!(TILE, 3)`. `row` is forwarded unchanged to the `tcvtrowps2phl`
/// binding (declared with link name `llvm.x86.tcvtrowps2phl`).
///
/// This function is compiled with the `amx-avx512` and `avx10.2` target features enabled.
///
/// NOTE(review): `row` is presumably the index of the tile row to read, and the contents of the
/// other 16 bits of each 32-bit element are not visible from this code -- confirm against the
/// instruction reference.
419+
#[inline]
420+
#[rustc_legacy_const_generics(0)]
421+
#[target_feature(enable = "amx-avx512,avx10.2")]
422+
#[cfg_attr(
423+
all(test, any(target_os = "linux", target_env = "msvc")),
424+
assert_instr(tcvtrowps2phl, TILE = 0)
425+
)]
426+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
427+
pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
428+
// Compile-time check on `TILE` with a width argument of 3 bits.
static_assert_uimm_bits!(TILE, 3);
429+
// The binding takes the tile as `i8` and returns `f16x32`; `as_m512h()` converts that to the
// public `__m512h` return type.
tcvtrowps2phl(TILE as i8, row).as_m512h()
430+
}
431+
432+
/// Moves one row of tile data into a zmm vector register.
///
/// `TILE` is a const generic that selects the source tile; it is checked at compile time by
/// `static_assert_uimm_bits!(TILE, 3)`. `row` is forwarded unchanged to the `tilemovrow`
/// binding (declared with link name `llvm.x86.tilemovrow`).
///
/// This function is compiled with the `amx-avx512` and `avx10.2` target features enabled.
///
/// NOTE(review): `row` is presumably the index of the tile row to read; the valid range is not
/// visible here -- confirm against the instruction reference.
433+
#[inline]
434+
#[rustc_legacy_const_generics(0)]
435+
#[target_feature(enable = "amx-avx512,avx10.2")]
436+
#[cfg_attr(
437+
all(test, any(target_os = "linux", target_env = "msvc")),
438+
assert_instr(tilemovrow, TILE = 0)
439+
)]
440+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
441+
pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
442+
// Compile-time check on `TILE` with a width argument of 3 bits.
static_assert_uimm_bits!(TILE, 3);
443+
// The binding takes the tile as `i8` and returns `i32x16`; `as_m512i()` converts that to the
// public `__m512i` return type.
tilemovrow(TILE as i8, row).as_m512i()
444+
}
445+
383446
#[allow(improper_ctypes)]
384447
unsafe extern "C" {
385448
#[link_name = "llvm.x86.ldtilecfg"]
@@ -426,6 +489,14 @@ unsafe extern "C" {
426489
fn tileloaddrst164(dst: i8, base: *const u8, stride: usize);
427490
#[link_name = "llvm.x86.tmmultf32ps"]
428491
fn tmmultf32ps(dst: i8, a: i8, b: i8);
492+
#[link_name = "llvm.x86.tcvtrowd2ps"]
493+
fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
494+
#[link_name = "llvm.x86.tcvtrowps2phh"]
495+
fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
496+
#[link_name = "llvm.x86.tcvtrowps2phl"]
497+
fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
498+
#[link_name = "llvm.x86.tilemovrow"]
499+
fn tilemovrow(tile: i8, row: u32) -> i32x16;
429500
}
430501

431502
#[cfg(test)]

0 commit comments

Comments
 (0)