diff --git a/cmake/configuring_primitive_list.cmake b/cmake/configuring_primitive_list.cmake index 64d23acd51e..32551f8ec87 100644 --- a/cmake/configuring_primitive_list.cmake +++ b/cmake/configuring_primitive_list.cmake @@ -58,7 +58,7 @@ if (DNNL_ENABLE_PRIMITIVE_GPU_ISA STREQUAL "ALL") else() foreach(isa ${DNNL_ENABLE_PRIMITIVE_GPU_ISA}) string(TOUPPER ${isa} uisa) - if(NOT "${uisa}" MATCHES "^(XELP|XEHP|XEHPG|XEHPC|XE2|XE3)$") + if(NOT "${uisa}" MATCHES "^(XELP|XEHP|XEHPG|XEHPC|XE2|XE3|XE3P)$") message(FATAL_ERROR "Unsupported primitive GPU ISA: ${uisa}") endif() set(BUILD_${uisa} TRUE) diff --git a/cmake/options.cmake b/cmake/options.cmake index 03836e93cf6..e4ab85ec77d 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -151,7 +151,7 @@ set(DNNL_ENABLE_PRIMITIVE_GPU_ISA "ALL" CACHE STRING implementations will always be available. Valid values: - ALL (the default). Includes all ISA to be enabled. - ;;... Includes only selected ISA to be enabled. - Possible values are: XELP, XEHP, XEHPG, XEHPC, XE2, XE3.") + Possible values are: XELP, XEHP, XEHPG, XEHPC, XE2, XE3, XE3P.") set(ONEDNN_ENABLE_GEMM_KERNELS_ISA "ALL" CACHE STRING "Specifies an ISA set of GeMM kernels residing in x64/gemm folder to be diff --git a/include/oneapi/dnnl/dnnl_config.h.in b/include/oneapi/dnnl/dnnl_config.h.in index b5ede0d9685..cf233fed6e5 100644 --- a/include/oneapi/dnnl/dnnl_config.h.in +++ b/include/oneapi/dnnl/dnnl_config.h.in @@ -227,6 +227,7 @@ #cmakedefine01 BUILD_XEHPC #cmakedefine01 BUILD_XE2 #cmakedefine01 BUILD_XE3 +#cmakedefine01 BUILD_XE3P // GeMM kernels ISA controls #cmakedefine01 BUILD_GEMM_KERNELS_ALL #cmakedefine01 BUILD_GEMM_KERNELS_NONE diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 54b3627675c..9d4b58b50b5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -125,6 +125,12 @@ if(UNIX) endif() endif() +# TODO: Remove these after the next pull-down from main. 
+if(DNNL_WITH_XE3P) +add_definitions_with_host_compiler(-DDNNL_WITH_XE3P) +add_definitions_with_host_compiler(-DXE3P) +endif() + add_subdirectory(common) if(NOT DNNL_CPU_RUNTIME STREQUAL "NONE") diff --git a/src/common/float4.cpp b/src/common/float4.cpp index db90c41b993..debb7f3d5e7 100644 --- a/src/common/float4.cpp +++ b/src/common/float4.cpp @@ -113,8 +113,10 @@ uint8_t float2e3m0(float f) { min_diff = diff; raw_bits = idx; } - // Special case for midpoint, we round to even (so even index) - if ((diff == min_diff) && !(idx & 1)) raw_bits = idx; + // Special case for midpoint: + // - towards 0 for 0.125 + // - up for other ties + if ((diff == min_diff) && idx != 1) raw_bits = idx; } assert(raw_bits < 8); // reapply sign diff --git a/src/common/impl_registration.hpp b/src/common/impl_registration.hpp index 79637a16ec5..250cc4b6a82 100644 --- a/src/common/impl_registration.hpp +++ b/src/common/impl_registration.hpp @@ -239,4 +239,10 @@ #define REG_XE3_ISA(...) #endif +#if BUILD_PRIMITIVE_GPU_ISA_ALL || BUILD_XE3P +#define REG_XE3P_ISA(...) __VA_ARGS__ +#else +#define REG_XE3P_ISA(...) 
+#endif + #endif diff --git a/src/gpu/CMakeLists.txt b/src/gpu/CMakeLists.txt index c1b86d54c65..0b2d7c991e2 100644 --- a/src/gpu/CMakeLists.txt +++ b/src/gpu/CMakeLists.txt @@ -19,6 +19,10 @@ file(GLOB SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp ) +if(DNNL_WITH_XE3P) +add_definitions_with_host_compiler(-DXE3P=1) +endif() + set(OBJ_LIB ${LIB_PACKAGE_NAME}_gpu) add_library(${OBJ_LIB} OBJECT ${SOURCES}) set_property(GLOBAL APPEND PROPERTY DNNL_LIB_DEPS diff --git a/src/gpu/intel/compute/device_info.cpp b/src/gpu/intel/compute/device_info.cpp index af4f22a42dd..17b13e2ea82 100644 --- a/src/gpu/intel/compute/device_info.cpp +++ b/src/gpu/intel/compute/device_info.cpp @@ -45,6 +45,9 @@ uint64_t get_future_extensions( case gpu_arch_t::xe2: case gpu_arch_t::xe_hpc: case gpu_arch_t::xe3: + case gpu_arch_t::xe3p_35_10: + case gpu_arch_t::xe3p_35_11: + case gpu_arch_t::xe3p_35_unknown: extensions |= (uint64_t)device_ext_t::intel_global_float_atomics; extensions |= (uint64_t)device_ext_t::intel_variable_eu_thread_count; @@ -109,7 +112,13 @@ bool device_info_t::mayiuse_sub_group(int size) const { case gpu_arch_t::xe_lp: case gpu_arch_t::xe_hp: case gpu_arch_t::xe_hpg: return utils::one_of(size, 8, 16, 32); - default: return utils::one_of(size, 16, 32); + case gpu_arch_t::xe_hpc: + case gpu_arch_t::xe2: + case gpu_arch_t::xe3: + case gpu_arch_t::xe3p_35_10: + case gpu_arch_t::xe3p_35_11: + case gpu_arch_t::xe3p_35_unknown: return utils::one_of(size, 16, 32); + default: return utils::one_of(size, 32); } } @@ -145,6 +154,9 @@ int device_info_t::max_eus_per_wg(gpu_arch_t gpu_arch) { switch (gpu_arch) { case gpu::intel::compute::gpu_arch_t::xe_hpc: case gpu::intel::compute::gpu_arch_t::xe2: + case gpu_arch_t::xe3p_35_10: + case gpu_arch_t::xe3p_35_11: + case gpu_arch_t::xe3p_35_unknown: case gpu::intel::compute::gpu_arch_t::xe3: return 8; case gpu::intel::compute::gpu_arch_t::xe_lp: case gpu::intel::compute::gpu_arch_t::xe_hp: @@ -158,6 +170,9 @@ int 
device_info_t::max_subgroup_size(gpu_arch_t gpu_arch) { switch (gpu_arch) { case gpu::intel::compute::gpu_arch_t::xe_hpc: case gpu::intel::compute::gpu_arch_t::xe2: + case gpu_arch_t::xe3p_35_10: + case gpu_arch_t::xe3p_35_11: + case gpu_arch_t::xe3p_35_unknown: case gpu::intel::compute::gpu_arch_t::xe3: return 32; case gpu::intel::compute::gpu_arch_t::xe_lp: case gpu::intel::compute::gpu_arch_t::xe_hp: @@ -179,6 +194,9 @@ int device_info_t::min_subgroup_size() const { case gpu_arch_t::xe_hpg: return 8; case gpu_arch_t::xe_hpc: case gpu_arch_t::xe2: + case gpu_arch_t::xe3p_35_10: + case gpu_arch_t::xe3p_35_11: + case gpu_arch_t::xe3p_35_unknown: case gpu_arch_t::xe3: return 16; default: return 0; } @@ -188,6 +206,9 @@ int device_info_t::max_exec_size(gpu_arch_t gpu_arch) { switch (gpu_arch) { case gpu::intel::compute::gpu_arch_t::xe_hpc: case gpu::intel::compute::gpu_arch_t::xe2: + case gpu::intel::compute::gpu_arch_t::xe3p_35_10: + case gpu::intel::compute::gpu_arch_t::xe3p_35_11: + case gpu::intel::compute::gpu_arch_t::xe3p_35_unknown: case gpu::intel::compute::gpu_arch_t::xe3: return 128; default: return 64; } @@ -221,6 +242,9 @@ int device_info_t::threads_per_eu(gpu_arch_t gpu_arch, bool large_grf_mode) { case gpu::intel::compute::gpu_arch_t::xe_hpg: case gpu::intel::compute::gpu_arch_t::xe_hpc: case gpu::intel::compute::gpu_arch_t::xe2: + case gpu::intel::compute::gpu_arch_t::xe3p_35_10: + case gpu::intel::compute::gpu_arch_t::xe3p_35_11: + case gpu::intel::compute::gpu_arch_t::xe3p_35_unknown: case gpu::intel::compute::gpu_arch_t::xe3: return large_grf_mode ? 
4 : 8; case gpu::intel::compute::gpu_arch_t::unknown: return 7; @@ -238,6 +262,11 @@ int device_info_t::max_slm_size(gpu_arch_t gpu_arch) { case gpu::intel::compute::gpu_arch_t::xe_hpg: case gpu::intel::compute::gpu_arch_t::xe_hpc: case gpu::intel::compute::gpu_arch_t::xe2: + case gpu::intel::compute::gpu_arch_t::xe3p_35_10: + case gpu::intel::compute::gpu_arch_t::xe3p_35_11: + case gpu::intel::compute::gpu_arch_t::xe3p_35_unknown: + slm_size = 3 * (1 << 17); + break; case gpu::intel::compute::gpu_arch_t::xe3: slm_size = (1 << 17); break; case gpu::intel::compute::gpu_arch_t::unknown: assert(!"not expected"); } @@ -269,6 +298,9 @@ size_t device_info_t::icache_size() const { case gpu::intel::compute::gpu_arch_t::xe_hpc: return 80 * 1024; case gpu::intel::compute::gpu_arch_t::xe2: return 96 * 1024; case gpu::intel::compute::gpu_arch_t::xe3: return 96 * 1024; + case gpu::intel::compute::gpu_arch_t::xe3p_35_10: + case gpu::intel::compute::gpu_arch_t::xe3p_35_11: + case gpu::intel::compute::gpu_arch_t::xe3p_35_unknown: return 80 * 1024; case gpu::intel::compute::gpu_arch_t::unknown: assert(!"not expected"); } return 0; diff --git a/src/gpu/intel/compute/device_info.hpp b/src/gpu/intel/compute/device_info.hpp index 6e2319be1bb..5f915c866ff 100644 --- a/src/gpu/intel/compute/device_info.hpp +++ b/src/gpu/intel/compute/device_info.hpp @@ -41,7 +41,18 @@ namespace gpu { namespace intel { namespace compute { -enum class gpu_arch_t { unknown, xe_lp, xe_hp, xe_hpg, xe_hpc, xe2, xe3 }; +enum class gpu_arch_t { + unknown, + xe_lp, + xe_hp, + xe_hpg, + xe_hpc, + xe2, + xe3, + xe3p_35_10, + xe3p_35_11, + xe3p_35_unknown, +}; // Memory for storing ngen::Product to avoid directly including nGEN because of // header dependencies outside of src/gpu/intel. 
@@ -58,6 +69,9 @@ static inline const char *to_string(gpu_arch_t arch) { CASE(xe_hpc); CASE(xe2); CASE(xe3); + CASE(xe3p_35_10); + CASE(xe3p_35_11); + CASE(xe3p_35_unknown); return "unknown"; #undef CASE } @@ -71,6 +85,9 @@ static inline gpu_arch_t str2gpu_arch(const char *str) { CASE(xe_hpc); CASE(xe2); CASE(xe3); + CASE(xe3p_35_10); + CASE(xe3p_35_11); + CASE(xe3p_35_unknown); return gpu_arch_t::unknown; #undef CASE } @@ -253,6 +270,8 @@ struct device_info_t { bool has_native(data_type_t type) const; + bool is_efficient_64bit() const { return is_efficient_64bit_; } + const std::vector &get_cache_blob() const { return serialized_device_info_.get_data(); } @@ -282,6 +301,7 @@ struct device_info_t { bool mayiuse_systolic_ = false; bool mayiuse_ngen_kernels_ = false; bool mayiuse_system_memory_allocators_ = false; + bool is_efficient_64bit_ = false; std::string name_; xpu::runtime_version_t runtime_version_; diff --git a/src/gpu/intel/conv/jit/config.cpp b/src/gpu/intel/conv/jit/config.cpp index c54f5d0fdd0..2e92949bae7 100644 --- a/src/gpu/intel/conv/jit/config.cpp +++ b/src/gpu/intel/conv/jit/config.cpp @@ -1094,6 +1094,8 @@ status_t init_vec_size(config_t &cfg) { int default_regs(const config_t &cfg) { if (!cfg.hw().large_grf_support()) return 128; + if (cfg.hw() == ngen::HW::XE3P_35_11 && cfg.is_dpas_or_dpasw_fma()) + return 512; if (cfg.is_dpas_or_dpasw_fma()) return 256; return 128; } diff --git a/src/gpu/intel/conv/jit/ir_builder.cpp b/src/gpu/intel/conv/jit/ir_builder.cpp index 09c0a81f7e7..4d53dbc2c85 100644 --- a/src/gpu/intel/conv/jit/ir_builder.cpp +++ b/src/gpu/intel/conv/jit/ir_builder.cpp @@ -251,6 +251,9 @@ class compute_builder_t { alloc_updater.update(buf_mgr_); } + // Assign {Fwd} for dpas when applicable. + if (cfg_.hw() >= ngen::HW::XE3P_35_10) + x2r_mul_stmt_ = inject_dpas_fwd(x2r_mul_stmt_); // Assign {Atomic} for dpas(w) when applicable. 
x2r_mul_stmt_ = inject_dpas_atomic(x2r_mul_stmt_); } diff --git a/src/gpu/intel/conv/jit/model_bridge.cpp b/src/gpu/intel/conv/jit/model_bridge.cpp index 5a9d7e522b1..271b83ddc73 100644 --- a/src/gpu/intel/conv/jit/model_bridge.cpp +++ b/src/gpu/intel/conv/jit/model_bridge.cpp @@ -59,6 +59,9 @@ hw_t to_hw(ngen::HW hw) { case ngen::HW::XeHPC: return hw_t::xehpc; case ngen::HW::Xe2: return hw_t::xehpc; case ngen::HW::Xe3: return hw_t::xehpc; + case ngen::HW::XE3P_35_10: + case ngen::HW::XE3P_35_11: + case ngen::HW::XE3P_UNKNOWN: return hw_t::xehpc; default: gpu_error_not_expected() << "Unknown HW: " << to_string(hw); } return hw_t::undef; diff --git a/src/gpu/intel/conv/jit/plan.cpp b/src/gpu/intel/conv/jit/plan.cpp index 347b4b80b98..31fcfa631ac 100644 --- a/src/gpu/intel/conv/jit/plan.cpp +++ b/src/gpu/intel/conv/jit/plan.cpp @@ -1355,7 +1355,8 @@ struct fma_context_t { bool is_dpas = is_dp_fma(fma); bool is_a = (abc == abc_kind_t::a); auto type = (is_a ? a_type : b_type); - bool cvt_f16 = (layout.type().is_fp8() || layout.type().is_fp4()); + bool cvt_f16 = ((hw < ngen::HW::XE3P_35_10 && layout.type().is_fp8()) + || (hw < ngen::HW::XE3P_35_11 && layout.type().is_fp4())); int type_size = (cvt_f16 ? 2 : type.size()); if (is_dpas) { int sdepth = 8; @@ -2209,11 +2210,51 @@ class plan_builder_t { return plan_status_t::success; } + // Extends the view to cover 256 contiguous bytes for more efficient + // prefetching. 
+ void maybe_extend_prefetch_thread_view_to_256_bytes( + view_t &thr_view) const { + auto thr_layout = thr_view.create_pseudo_vlayout(); + auto &blocks = thr_layout.blocks(); + if (blocks.size() <= 1) return; + + auto &b0 = blocks[0]; + auto &b1 = blocks[1]; + if (!b1.stride.is_fixed() || !b0.stride.is_fixed()) return; + auto inner_var = thr_view.vvars()[b0.idx]; + bool is_block_strided + = (b0.stride == stride_t(1)) && (b1.stride > b0.size); + int type_size = thr_layout.type().size(); + dim_t full_dim_size + = gemm_schedule_.a_view().vdims()[b0.idx] * type_size; + bool size_ge_256b = (full_dim_size >= 256); + dim_t b0_size = b0.size * type_size; + bool prefetch_lt_256b = (b0_size < 256); + bool is_inner_loop = gemm_schedule_.is_inner_loop(inner_var); + // Extend if the following conditions are satisfied: + // - The inner block (b0) is dense and smaller than 256 bytes + // - The original tensor has at least 256 bytes across b0 dimension + // - The inner block dimensions corresponds to the inner loop + // dimension. We want to prefetch extra cache lines only if they are + // going to be used by the next iterations. 
+ if (is_block_strided && size_ge_256b && prefetch_lt_256b + && is_inner_loop) { + gpu_assert(thr_view.vdims()[b0.idx] == b0.size); + int factor = 256 / b0_size; + thr_view.set_vdim(inner_var, b0.size * factor, + thr_view.vstart()[b0.idx], + /*overwrite=*/true); + } + } + plan_status_t init_x_prefetch_plan(abc_kind_t abc, const view_t &tg_view, grid_info_t &grid, send_plan_t &prefetch) const { if (!use_prefetch(abc)) return plan_status_t::success; auto &tg = cfg_.thread_group_grid(); auto thr_view = tg_view.split(tg, &grid); + if (cfg_.hw() == ngen::HW::XE3P_35_11) { + maybe_extend_prefetch_thread_view_to_256_bytes(thr_view); + } auto params = get_send_params(cfg_.options(), send_op_t::prefetch, send_address_t::a64, fma_kind_t::undef, abc, thr_view, gemm_schedule_); diff --git a/src/gpu/intel/gemm/jit.hpp b/src/gpu/intel/gemm/jit.hpp index de33788ab83..999a661e6cb 100644 --- a/src/gpu/intel/gemm/jit.hpp +++ b/src/gpu/intel/gemm/jit.hpp @@ -195,6 +195,7 @@ struct gen_t : public primitive_t { // Check GPU architecture. bool arch_ok = utils::one_of(arch_, arch_t::xe_lp, arch_t::xe_hp, arch_t::xe_hpg, arch_t::xe_hpc, arch_t::xe2, arch_t::xe3); + arch_ok |= (arch_ >= arch_t::xe3p_35_10); VDISPATCH_GEMM(arch_ok, VERBOSE_UNSUPPORTED_ARCH, "gpu"); VDISPATCH_GEMM(IMPLICATION(with_binary, arch_ >= arch_t::xe_hp), @@ -215,7 +216,7 @@ struct gen_t : public primitive_t { || intel_engine->mayiuse(compute::device_ext_t:: intel_subgroup_split_matrix_multiply_accumulate); - bool is_integrated = intel_engine->device_info()->is_integrated(); + bool is_integrated = dev_info_->is_integrated(); // Size checks for fused reduction kernels. 
if (with_sum_ab()) { @@ -260,6 +261,9 @@ struct gen_t : public primitive_t { !with_eltwise && !with_binary), VERBOSE_UNSUPPORTED_POSTOP); + if (arch_ >= arch_t::xe3p_35_10) + kernel_desc_.set_efficient_64b(dev_info_->is_efficient_64bit()); + bool print_verbose = get_verbose(verbose_t::debuginfo) >= 5; bool kernel_success = false; auto lda = ld(DNNL_ARG_A); diff --git a/src/gpu/intel/gemm/jit/CMakeLists.txt b/src/gpu/intel/gemm/jit/CMakeLists.txt index e0b7e6b655c..b7ad271cee1 100644 --- a/src/gpu/intel/gemm/jit/CMakeLists.txt +++ b/src/gpu/intel/gemm/jit/CMakeLists.txt @@ -21,15 +21,16 @@ endif() # Use oneDNN names for ALL to ensure string replacement functions correctly set(GPUS ${DNNL_ENABLE_PRIMITIVE_GPU_ISA}) -string(REPLACE "ALL" "XELP;XEHP;XEHPG;XEHPC;XE2;XE3" GPUS "${GPUS}") +string(REPLACE "ALL" "XELP;XEHP;XEHPG;XEHPC;XE2;XE3;XE3P" GPUS "${GPUS}") string(REPLACE "XELP" "12LP" GPUS "${GPUS}") string(REPLACE "XEHPG" "12p7" GPUS "${GPUS}") string(REPLACE "XEHPC" "12p8" GPUS "${GPUS}") string(REPLACE "XEHP" "12HP" GPUS "${GPUS}") string(REPLACE "XE2" "Xe2" GPUS "${GPUS}") string(REPLACE "XE3" "Xe3" GPUS "${GPUS}") +string(REPLACE "XE3P" "Xe3P" GPUS "${GPUS}") -set(ALL_GPUS "12LP;12HP;12p7;12p8;Xe2;Xe3") +set(ALL_GPUS "12LP;12HP;12p7;12p8;Xe2;Xe3;Xe3P") foreach(GPU ${GPUS}) if(NOT ${GPU} IN_LIST ALL_GPUS) message(FATAL_ERROR "Unknown GPU architecture: ${GPU}") @@ -63,7 +64,7 @@ if(DPCPP_HOST_COMPILER_KIND STREQUAL "DEFAULT") ) if (DNNL_ENABLE_PRIMITIVE_GPU_ISA STREQUAL "ALL") - set(DNNL_GPU_ISA_LIST "XELP;XEHP;XEHPG;XEHPC;XE2;XE3") + set(DNNL_GPU_ISA_LIST "XELP;XEHP;XEHPG;XEHPC;XE2;XE3;XE3P") else() foreach(isa ${DNNL_ENABLE_PRIMITIVE_GPU_ISA}) string(TOUPPER ${isa} ISA) diff --git a/src/gpu/intel/gemm/jit/dsl/hw.cpp b/src/gpu/intel/gemm/jit/dsl/hw.cpp index d7f18fae267..2dd77d7f7b3 100644 --- a/src/gpu/intel/gemm/jit/dsl/hw.cpp +++ b/src/gpu/intel/gemm/jit/dsl/hw.cpp @@ -21,13 +21,14 @@ GEMMSTONE_NAMESPACE_START namespace dsl { -hw_t::hw_t(const ngen::Product 
&product, int eu_count, size_t max_wg_size, - size_t l3_cache_size, attr_t attr) +hw_t::hw_t(const ngen::Product &product, int eu_count, int max_wg_size, + size_t l3_cache_size, bool efficient_64_bit, attr_t attr) : product_(product) , hw_(ngen::getCore(product.family)) , eu_count_(eu_count) , max_wg_size_(max_wg_size) , l3_cache_size_(l3_cache_size) + , efficient_64_bit_(efficient_64_bit) , attr_(attr) {} ngen::Product hw_t::product() const { @@ -62,8 +63,11 @@ int hw_t::eus_per_core() const { case ngen::HW::XeHP: case ngen::HW::XeHPC: case ngen::HW::Xe2: - case ngen::HW::Xe3: return 8; - default: stub(); return 8; + case ngen::HW::Xe3: + case ngen::HW::XE3P_35_10: + case ngen::HW::XE3P_35_11: + case ngen::HW::XE3P_UNKNOWN: return 8; + default: gpu_error_not_expected(); return 8; } } int hw_t::threads_per_eu(int regs) const { @@ -74,8 +78,11 @@ int hw_t::threads_per_eu(int regs) const { case ngen::HW::XeHPG: case ngen::HW::XeHPC: case ngen::HW::Xe2: - case ngen::HW::Xe3: return is_large_grf ? 4 : 8; - default: stub(); return 8; + case ngen::HW::Xe3: + case ngen::HW::XE3P_35_10: + case ngen::HW::XE3P_35_11: + case ngen::HW::XE3P_UNKNOWN: return is_large_grf ? 
4 : 8; + default: gpu_error_not_expected(); return 8; } } @@ -86,8 +93,11 @@ int hw_t::cache_line_size() const { case ngen::HW::XeHPG: case ngen::HW::XeHPC: case ngen::HW::Xe2: - case ngen::HW::Xe3: return 64; - default: stub(); + case ngen::HW::Xe3: + case ngen::HW::XE3P_35_10: + case ngen::HW::XE3P_35_11: + case ngen::HW::XE3P_UNKNOWN: return 64; + default: gpu_error_not_expected(); } return 0; } diff --git a/src/gpu/intel/gemm/jit/dsl/ir/codegen/bank_conflict_allocation.cpp b/src/gpu/intel/gemm/jit/dsl/ir/codegen/bank_conflict_allocation.cpp index 1838f1e5ea0..cee1cdb87d7 100644 --- a/src/gpu/intel/gemm/jit/dsl/ir/codegen/bank_conflict_allocation.cpp +++ b/src/gpu/intel/gemm/jit/dsl/ir/codegen/bank_conflict_allocation.cpp @@ -81,8 +81,11 @@ struct hw_context_t { case ngen::HW::XeHPG: return 8; case ngen::HW::XeHPC: case ngen::HW::Xe2: - case ngen::HW::Xe3: return 16; - default: stub(); + case ngen::HW::Xe3: + case ngen::HW::XE3P_35_10: + case ngen::HW::XE3P_35_11: + case ngen::HW::XE3P_UNKNOWN: return 16; + default: gpu_error_not_expected(); } return -1; } @@ -251,7 +254,7 @@ struct reg_mask_t : public stringify_t { } static const int chunk_bits = 64; - static const int max_regs = 256; + static const int max_regs = 512; static const int max_nchunks = max_regs / chunk_bits; const hw_context_t *hw_ctx = nullptr; diff --git a/src/gpu/intel/gemm/jit/dsl/ir/codegen/codegen.cpp b/src/gpu/intel/gemm/jit/dsl/ir/codegen/codegen.cpp index 5b93f120be2..7d153a611b4 100644 --- a/src/gpu/intel/gemm/jit/dsl/ir/codegen/codegen.cpp +++ b/src/gpu/intel/gemm/jit/dsl/ir/codegen/codegen.cpp @@ -606,13 +606,16 @@ class ir_to_ngen_t final : public codegen_extension_iface_t, } else { dsl_assert(dst.byte_offset() == src0.getByteOffset()) << "dst/src0 must be aligned to the same GRF offset."; - align_src_dst_offset(host_, scope, mod, dst, src1, src2); if (mad_func.dst_type == type_t::f64() && src1.reg_data().getHS() == 0 && src1.reg_data().getVS() == 0) { + align_src_dst_offset(host_, 
scope, mod, dst, src1); + align_src_dst_offset(host_, scope, mod, dst, src2, true); // Workaround for sporadic f64 mad errors with broadcast src1 on XeHPC. host_->mad(mod, dst, src0, src2, src1); } else { + align_src_dst_offset(host_, scope, mod, dst, src1, true); + align_src_dst_offset(host_, scope, mod, dst, src2); host_->mad(mod, dst, src0, src1, src2); } } @@ -821,6 +824,21 @@ class ir_to_ngen_t final : public codegen_extension_iface_t, std::vector last_used_header_regs_; }; +bool is_src1_ok(ngen::HW hw, const ngen_operand_t &dst, + const ngen_operand_t &src0, const ngen_operand_t &src1) { + if (one_of(hw, + {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, + ngen::HW::XE3P_UNKNOWN})) { + if (!src1.is_reg_data()) return true; + auto src1_rd = src1.reg_data(); + if (src1_rd.isScalar()) return true; + auto dst_rd = dst.reg_data(); + return src1_rd.getType() == dst_rd.getType() + && src1_rd.getHS() == dst_rd.getHS(); + } + return true; +} + // Evaluates expression by emitting instructions with nGEN. template class expr_evaluator_t : public ir_visitor_t { @@ -853,6 +871,45 @@ class expr_evaluator_t : public ir_visitor_t { host_->sel(dst_operand.mod(), dst_operand.reg_data(), bind.reg_data(), 0); } else { + const auto grf_size = ngen::GRF::bytes(hw()); + if (hw() >= ngen::HW::XE3P_35_10 + && bind.is_reg_buf_data()) { + auto mod = dst_operand.mod(); + auto dst = dst_operand.reg_data(); + auto src = bind.reg_data(); + auto exec_size = mod.getExecSize(); + const auto dst_stride = dst.getHS(); + const auto dst_byte_offset = dst.getByteOffset(); + const auto dst_type_size = dst.getBytes(); + const auto dst_bytes = dst_type_size + * ((exec_size - 1) * dst_stride + 1); + const auto dst_end_byte = dst_byte_offset + dst_bytes; + if (dst_end_byte > grf_size) { // Compressed instruction + const auto tail_bytes = grf_size - dst_byte_offset; + // Index where we cross the grf boundary. 
+ const auto cidx = 1 + + (tail_bytes / dst_type_size - 1) + / dst_stride; + const auto src_width = src.getWidth(); + const auto src_hs = src.getHS(); + const auto src_vs = src.getVS(); + const auto x = cidx % src_width; + const auto y = cidx / src_width; + const auto src_base = src.getBase() + + (src_hs * x + src_vs * y) / grf_size; + const auto dst_base = dst.getBase(); + if (src_base == dst_base) { + const auto src_type = src.getType(); + const int nregs = dst_bytes / grf_size; + auto tmp = scope_.alloc_range(nregs); + auto t = tmp.sub(hw(), 0, src_type)(dst_stride); + host_->emov(exec_size, t, src); + host_->emov(mod, dst, t); + scope_.safeRelease(tmp); + return dst_operand; + } + } + } host_->emov(dst_operand.mod(), dst_operand, bind); } return dst_operand; @@ -1280,7 +1337,8 @@ class expr_evaluator_t : public ir_visitor_t { auto &dst = _dst; auto src0 = _src0; auto src1 = _src1; - align_src_dst_offset(host_, scope_, mod, dst, src0, src1); + align_src_dst_offset(host_, scope_, mod, dst, src0); + align_src_dst_offset(host_, scope_, mod, dst, src1, true); switch (obj.op_kind) { case op_kind_t::_add: host_->eadd(mod, dst, src0, src1); break; case op_kind_t::_sub: host_->eadd(mod, dst, src0, -src1); break; @@ -1554,6 +1612,10 @@ class expr_evaluator_t : public ir_visitor_t { auto t = tmp.format(0, obj.elems(), 1, w_type); reg_buf_data_t t_strided; bool align_with_dst = false; + if (one_of(hw(), + {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, + ngen::HW::XE3P_UNKNOWN})) + align_with_dst = true; if (align_with_dst) { int w_stride = dst_stride * (ngen::getBytes(dst.type()) / w_size); int tmp_strided_regs @@ -1652,6 +1714,12 @@ ngen::NEOInterfaceHandler generate_ngen_interface( if (setup_flags.has_dpas || options.require_dpas()) interface.requireDPAS(); if (setup_flags.has_send_atomics) interface.requireGlobalAtomics(); + if (one_of(options.hw(), + {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, + ngen::HW::XE3P_UNKNOWN}) + && !options.hw().efficient_64_bit()) + 
interface.setEfficient64Bit(false); + for (size_t i = 0; i < kernel_iface.nargs(); i++) { auto &name = kernel_iface[i].as().name; auto &type = kernel_iface[i].type(); @@ -1707,6 +1775,12 @@ ngen::NEOInterfaceHandler generate_ngen_interface( #define GEMMSTONE_XE3_ISA(...) #endif +#ifdef GEMMSTONE_BUILD_XE3P +#define GEMMSTONE_XE3P_ISA(...) __VA_ARGS__ +#else +#define GEMMSTONE_XE3P_ISA(...) +#endif + #define GPU_HW_CASE_(hw) \ case ngen::HW::hw: { \ GPU_HW_CASE(ngen::HW::hw); \ @@ -1721,6 +1795,9 @@ ngen::NEOInterfaceHandler generate_ngen_interface( GEMMSTONE_XEHPC_ISA(GPU_HW_CASE_(XeHPC)); \ GEMMSTONE_XE2_ISA(GPU_HW_CASE_(Xe2)); \ GEMMSTONE_XE3_ISA(GPU_HW_CASE_(Xe3)); \ + GEMMSTONE_XE3P_ISA(GPU_HW_CASE_(XE3P_35_10)); \ + GEMMSTONE_XE3P_ISA(GPU_HW_CASE_(XE3P_35_11)); \ + GEMMSTONE_XE3P_ISA(GPU_HW_CASE_(XE3P_UNKNOWN)); \ default: dsl_assert(false) << "Unexpected GPU architecture"; \ } @@ -1793,6 +1870,19 @@ GEMMSTONE_XE3_ISA( template void ir::convert_ir_to_ngen>( const stmt_t &body, sycl_gen_t &host, const walk_order_t *kernel_grid_walk_order)); +GEMMSTONE_XE3P_ISA( + template void ir::convert_ir_to_ngen>( + const stmt_t &body, sycl_gen_t &host, + const walk_order_t *kernel_grid_walk_order)); +GEMMSTONE_XE3P_ISA( + template void ir::convert_ir_to_ngen>( + const stmt_t &body, sycl_gen_t &host, + const walk_order_t *kernel_grid_walk_order)); +GEMMSTONE_XE3P_ISA(template void + ir::convert_ir_to_ngen>( + const stmt_t &body, + sycl_gen_t &host, + const walk_order_t *kernel_grid_walk_order)); ::sycl::kernel make_kernel( const kernel_t &ir_kernel, ::sycl::context ctx, ::sycl::device dev) { @@ -1844,6 +1934,18 @@ GEMMSTONE_XE3_ISA( template void ir::convert_ir_to_ngen>( const stmt_t &body, ocl_gen_t &host, const walk_order_t *kernel_grid_walk_order)); +GEMMSTONE_XE3P_ISA( + template void ir::convert_ir_to_ngen>( + const stmt_t &body, ocl_gen_t &host, + const walk_order_t *kernel_grid_walk_order)); +GEMMSTONE_XE3P_ISA( + template void ir::convert_ir_to_ngen>( + const stmt_t 
&body, ocl_gen_t &host, + const walk_order_t *kernel_grid_walk_order)); +GEMMSTONE_XE3P_ISA( + template void ir::convert_ir_to_ngen>( + const stmt_t &body, ocl_gen_t &host, + const walk_order_t *kernel_grid_walk_order)); cl_kernel make_kernel( const kernel_t &ir_kernel, cl_context ctx, cl_device_id dev) { diff --git a/src/gpu/intel/gemm/jit/dsl/ir/codegen/kernel.hpp b/src/gpu/intel/gemm/jit/dsl/ir/codegen/kernel.hpp index f0fb4fd792e..bead7f487d1 100644 --- a/src/gpu/intel/gemm/jit/dsl/ir/codegen/kernel.hpp +++ b/src/gpu/intel/gemm/jit/dsl/ir/codegen/kernel.hpp @@ -513,7 +513,7 @@ class ir_to_ngen_generator_t : public BaseGeneratorT { auto src2 = _src2; auto scope = ngen_register_scope_t(ra_); align_src_dst_offset(this, scope, mod, dst, src0); - align_src_dst_offset(this, scope, mod, dst, src1); + align_src_dst_offset(this, scope, mod, dst, src1, true); if (getHardware() >= ngen::HW::XeHP) { if (src2.is_reg_data()) { align_src_dst_offset(this, scope, mod, dst, src2); @@ -527,7 +527,7 @@ class ir_to_ngen_generator_t : public BaseGeneratorT { } add(mod, dst.reg_data(), src0.reg_data(), src1.reg_data()); if (src2.is_reg_data()) { - align_src_dst_offset(this, scope, mod, dst, src2); + align_src_dst_offset(this, scope, mod, dst, src2, true); add(mod, dst.reg_data(), dst.reg_data(), src2.reg_data()); } else { add(mod, dst.reg_data(), dst.reg_data(), src2.immediate()); @@ -541,7 +541,7 @@ class ir_to_ngen_generator_t : public BaseGeneratorT { auto src1 = _src1; auto src2 = _src2; auto scope = ngen_register_scope_t(ra_); - align_src_dst_offset(this, scope, mod, dst, src1); + align_src_dst_offset(this, scope, mod, dst, src1, true); if (src2.is_reg_data()) { align_src_dst_offset(this, scope, mod, dst, src0); align_src_dst_offset(this, scope, mod, dst, src2); @@ -900,6 +900,10 @@ class ir_to_ngen_generator_t : public BaseGeneratorT { // qot = (x * m) >> p bool use_mach = true; + if (one_of(hw_info(), + {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, + ngen::HW::XE3P_UNKNOWN})) + 
use_mach = false; if (use_mach) { auto acc = acc0.retype(div_type); mul(1, acc[0], _x, m & 0xFFFF); diff --git a/src/gpu/intel/gemm/jit/dsl/ir/codegen/ngen_helpers.hpp b/src/gpu/intel/gemm/jit/dsl/ir/codegen/ngen_helpers.hpp index db0b9d1e9d6..e471cfa285c 100644 --- a/src/gpu/intel/gemm/jit/dsl/ir/codegen/ngen_helpers.hpp +++ b/src/gpu/intel/gemm/jit/dsl/ir/codegen/ngen_helpers.hpp @@ -28,11 +28,11 @@ GEMMSTONE_NAMESPACE_START namespace dsl { namespace ir { -constexpr ngen::DataType ngen_f4_e3m0() { +constexpr ngen::DataType ngen_e3m0() { return static_cast(0x5B); } -constexpr ngen::DataType ngen_f4_e2m1() { +constexpr ngen::DataType ngen_e2m1() { return static_cast(0x5A); } @@ -65,9 +65,9 @@ inline ngen::DataType to_ngen(const type_t &type) { if (type.base() == type_t::_kind()) return ngen::DataType::ngen_enum // Until f4_e3m0 lands in ngen - if (type.base() == type_t::f4_e3m0()) return ngen_f4_e3m0(); + if (type.base() == type_t::f4_e3m0()) return ngen_e3m0(); // Until f4_e2m1 lands in ngen - if (type.base() == type_t::f4_e2m1()) return ngen_f4_e2m1(); + if (type.base() == type_t::f4_e2m1()) return ngen_e2m1(); CASE(bf16, bf); CASE(f16, hf); @@ -99,8 +99,8 @@ inline type_t to_ir(ngen::DataType type) { #define CASE(_kind, ngen_enum) \ if (type == ngen::DataType::ngen_enum) return type_t::_kind(); - if (type == ngen_f4_e3m0()) return type_t::f4_e3m0(); - if (type == ngen_f4_e2m1()) return type_t::f4_e2m1(); + if (type == ngen_e3m0()) return type_t::f4_e3m0(); + if (type == ngen_e2m1()) return type_t::f4_e2m1(); CASE(bf16, bf); CASE(f16, hf); diff --git a/src/gpu/intel/gemm/jit/dsl/ir/codegen/reorder.hpp b/src/gpu/intel/gemm/jit/dsl/ir/codegen/reorder.hpp index 54ca41d69f4..9f05b8104b8 100644 --- a/src/gpu/intel/gemm/jit/dsl/ir/codegen/reorder.hpp +++ b/src/gpu/intel/gemm/jit/dsl/ir/codegen/reorder.hpp @@ -124,7 +124,7 @@ void emit_reorder_1d_tile(GeneratorT *host, ngen_register_scope_t &scope, template void align_src_dst_offset(GeneratorT *host, 
ngen_register_scope_t &scope, const ngen::InstructionModifier &mod, const reg_buf_data_t &dst, - reg_buf_data_t &src) { + reg_buf_data_t &src, bool align_stride = false) { int src_stride = src.hs(); // src is broadcasted, no need to align, return. if (src_stride == 0) return; @@ -145,10 +145,29 @@ void align_src_dst_offset(GeneratorT *host, ngen_register_scope_t &scope, // - <1; 1, 0>, which is a more compatible representation of int grf_src = grf_size / std::max(src.hs(), 1); int grf_dst = grf_size / std::max(dst.hs(), 1); + auto new_src_type = src.type(); + bool needs_stride_alignment = false; + bool align_bf = is_bf_to_f && src.hs(); + if (scope.hw() >= ngen::HW::XE3P_35_10 && (align_stride || align_bf)) { + auto src_stride_bytes = src.hs() * src_type_size; + auto dst_stride_bytes = dst.hs() * dst_type_size; + needs_stride_alignment = (src_stride_bytes != dst_stride_bytes); + src_stride = dst_stride_bytes / src_type_size; + if (align_bf) { + needs_stride_alignment = true; + src_stride = dst.hs(); + new_src_type = dst.type(); + src_type_size = ngen::getBytes(new_src_type); + } + } // If src is aligned with dst, return. - if ((is_xf || is_bf_to_f) && src_off % grf_src == dst_off % grf_dst) return; - if (!is_xf && src_byte_off % grf_size == dst_byte_off % grf_size) return; + if (!needs_stride_alignment) { + if ((is_xf || is_bf_to_f) && src_off % grf_src == dst_off % grf_dst) + return; + if (!is_xf && src_byte_off % grf_size == dst_byte_off % grf_size) + return; + } int new_src_off = (is_xf ? 
dst_off * src_type_size / dst_type_size : dst_off * dst_type_size / src_type_size); @@ -156,24 +175,17 @@ void align_src_dst_offset(GeneratorT *host, ngen_register_scope_t &scope, int src_size = std::max(src_type_size * esize * src_stride, src_type_size); auto new_src = scope.alloc_reg_buf_data( div_up(src_size + new_src_off * src_type_size, grf_size)); - new_src = new_src.format(new_src_off, esize, src_stride, src.type()); + new_src = new_src.format(new_src_off, esize, src_stride, new_src_type); + emit_reorder_1d_tile( - host, scope, esize, src, src_stride, new_src, src_stride); + host, scope, esize, src, src.hs(), new_src, src_stride); src = std::move(new_src); } -template -void align_src_dst_offset(GeneratorT *host, ngen_register_scope_t &scope, - const ngen::InstructionModifier &mod, const reg_buf_data_t &dst, - reg_buf_data_t &src0, reg_buf_data_t &src1) { - align_src_dst_offset(host, scope, mod, dst, src0); - align_src_dst_offset(host, scope, mod, dst, src1); -} - template void align_src_dst_offset(GeneratorT *host, ngen_register_scope_t &scope, const ngen::InstructionModifier &mod, const ngen_operand_t &dst, - ngen_operand_t &src) { + ngen_operand_t &src, bool align_stride = false) { if (!src.is_reg_data()) return; auto rd = src.reg_buf_data(); @@ -183,9 +195,10 @@ void align_src_dst_offset(GeneratorT *host, ngen_register_scope_t &scope, // GRF boundary. 
reg_buf_data_t dummy(reg_buf_t(rd.hw(), ngen::GRFRange(0, 1))); // This call returns early if everything is already aligned nicely - align_src_dst_offset(host, scope, mod, dummy, rd); + align_src_dst_offset(host, scope, mod, dummy, rd, align_stride); } else { - align_src_dst_offset(host, scope, mod, dst.reg_buf_data(), rd); + align_src_dst_offset( + host, scope, mod, dst.reg_buf_data(), rd, align_stride); } if (rd == src.reg_buf_data()) return; diff --git a/src/gpu/intel/gemm/jit/dsl/ir/codegen/send.hpp b/src/gpu/intel/gemm/jit/dsl/ir/codegen/send.hpp index 1ab6e1b4110..01382f0d089 100644 --- a/src/gpu/intel/gemm/jit/dsl/ir/codegen/send.hpp +++ b/src/gpu/intel/gemm/jit/dsl/ir/codegen/send.hpp @@ -47,6 +47,15 @@ inline ngen::CacheSettingsLSC get_cache_settings( ret = ngen::CacheSettingsLSC::L1C_L3C; } break; + case ngen::HW::XE3P_35_10: + case ngen::HW::XE3P_35_11: + case ngen::HW::XE3P_UNKNOWN: + if (is_store) { + ret = ngen::CacheSettingsLSC::L1UC_L2UC_L3WB; + } else if (is_load || is_prefetch) { + ret = ngen::CacheSettingsLSC::L1C_L2C_L3C; + } + break; default: break; } break; @@ -298,6 +307,7 @@ class send_impl_t { switch (op) { case send_op_t::atomic_add: return ngen::AtomicOp::add; case send_op_t::atomic_fadd: return ngen::AtomicOp::fadd; + case send_op_t::atomic_bfadd: return ngen::AtomicOp::bfadd; case send_op_t::atomic_cmpwr: return ngen::AtomicOp::cmpwr; default: stub(); } diff --git a/src/gpu/intel/gemm/jit/dsl/ir/core.hpp b/src/gpu/intel/gemm/jit/dsl/ir/core.hpp index c44f37e613c..2b1392e639b 100644 --- a/src/gpu/intel/gemm/jit/dsl/ir/core.hpp +++ b/src/gpu/intel/gemm/jit/dsl/ir/core.hpp @@ -1709,6 +1709,7 @@ class instruction_modifier_attr_t is_first = false; }; if (mod.isAtomic()) append("Atomic"); + if (mod.isFwd()) append("Fwd"); for (auto item : mod.getSWSB()) { if (item.hasTokenSet()) { append(std::string("$") diff --git a/src/gpu/intel/gemm/jit/dsl/ir/fma.cpp b/src/gpu/intel/gemm/jit/dsl/ir/fma.cpp index feaf59c835b..9764d4d73d4 100644 --- 
a/src/gpu/intel/gemm/jit/dsl/ir/fma.cpp +++ b/src/gpu/intel/gemm/jit/dsl/ir/fma.cpp @@ -52,6 +52,7 @@ int get_simd_size(const hw_t &hw, const fma_kind_t kind, const type_t &a, } bool dpas_t::is_src_type(type_t type) { + if (type.is_bf8() || type.is_hf8()) return true; return type.is_x8() || type.is_bf16() || type.is_f16() || type.is_tf32(); } diff --git a/src/gpu/intel/gemm/jit/dsl/ir/pass/dpas.cpp b/src/gpu/intel/gemm/jit/dsl/ir/pass/dpas.cpp index 8e9d7d136ec..1ad4287fc2d 100644 --- a/src/gpu/intel/gemm/jit/dsl/ir/pass/dpas.cpp +++ b/src/gpu/intel/gemm/jit/dsl/ir/pass/dpas.cpp @@ -96,6 +96,53 @@ class dpas_atomic_mutator_t : public mul_mutator_t { } }; +class dpas_fwd_mutator_t : public mul_mutator_t { +public: + stmt_t mutate_mul_impl( + const std::vector &_entries) const override { + auto entries = _entries; + int nentries = (int)entries.size(); + stmt_t ret; + for (int i = 0; i < nentries; i++) { + auto &ei = entries[i]; + if (ei.stmt.is_empty()) continue; + if (!ei.is_dpas_8x8()) { + ret = ret.append(ei.stmt); + continue; + } + auto &cur_dst = dpas_t::arg_dst(ei.stmt); + int fwd_idx = -1; + for (int j = i + 1; j < nentries; j++) { + auto &ej = entries[j]; + if (ej.stmt.is_empty() || !ej.is_dpas_8x8()) continue; + auto &dst = dpas_t::arg_dst(ej.stmt); + auto &src0 = dpas_t::arg_src0(ej.stmt); + if (dst.is_equal(cur_dst) && src0.is_equal(cur_dst)) { + fwd_idx = j; + break; + } + } + if (fwd_idx != -1) { + auto tmp_mod = ngen::InstructionModifier(); + tmp_mod.setBranchCtrl(true); + auto fwd_attr = instruction_modifier_attr_t::make(tmp_mod); + auto s = ei.stmt; + s = fwd_attr.apply_to(s); + ret = ret.append(s); + ret = ret.append(entries[fwd_idx].stmt); + entries[fwd_idx].stmt = stmt_t(); + continue; + } + ret = ret.append(ei.stmt); + } + return ret; + } +}; + +stmt_t inject_dpas_fwd(const stmt_t &stmt) { + return dpas_fwd_mutator_t().mutate(stmt); +} + stmt_t inject_dpas_atomic(const stmt_t &stmt, bool filter_by_label) { if (filter_by_label) return 
dpas_atomic_mutator_t().mutate(stmt); auto ret = dpas_atomic_mutator_t().mutate_mul(stmt); diff --git a/src/gpu/intel/gemm/jit/dsl/ir/pass/dpas.hpp b/src/gpu/intel/gemm/jit/dsl/ir/pass/dpas.hpp index 932fb53fd5c..48280a47117 100644 --- a/src/gpu/intel/gemm/jit/dsl/ir/pass/dpas.hpp +++ b/src/gpu/intel/gemm/jit/dsl/ir/pass/dpas.hpp @@ -25,9 +25,10 @@ namespace ir { // Adds {Atomic} modifier to dpas/dpasw instructions when applicable. stmt_t inject_dpas_atomic(const stmt_t &stmt, bool filter_by_label = true); +// Adds {Fwd} modifier to dpas/dpasw instructions when applicable. +stmt_t inject_dpas_fwd(const stmt_t &stmt); } // namespace ir } // namespace dsl GEMMSTONE_NAMESPACE_END - #endif diff --git a/src/gpu/intel/gemm/jit/dsl/ir/send.hpp b/src/gpu/intel/gemm/jit/dsl/ir/send.hpp index dd32badb76d..ae5befa8462 100644 --- a/src/gpu/intel/gemm/jit/dsl/ir/send.hpp +++ b/src/gpu/intel/gemm/jit/dsl/ir/send.hpp @@ -36,6 +36,7 @@ enum class send_op_t { undef, atomic_add, atomic_fadd, + atomic_bfadd, atomic_cmpwr, load, load_2d, diff --git a/src/gpu/intel/gemm/jit/dsl/runtime.cpp b/src/gpu/intel/gemm/jit/dsl/runtime.cpp index 7f385f721d8..4562dad4514 100644 --- a/src/gpu/intel/gemm/jit/dsl/runtime.cpp +++ b/src/gpu/intel/gemm/jit/dsl/runtime.cpp @@ -117,9 +117,9 @@ dsl::hw_t get_hardware(cl_device_id device, cl_context context) { CL_DEVICE_FEATURE_CAPABILITIES_INTEL, sizeof(cl_bitfield), &attrs_cl, nullptr); if (err) return {}; + ngen::HW hw = ngen::getCore(product.family); - if (ngen::getCore(product.family) >= ngen::HW::XeHPC) - attr |= dsl::hw::attr_t::large_grf; + if (hw >= ngen::HW::XeHPC) attr |= dsl::hw::attr_t::large_grf; if (attrs_cl & CL_DEVICE_FEATURE_FLAG_DPAS_INTEL) attr |= dsl::hw::attr_t::systolic; @@ -128,7 +128,10 @@ dsl::hw_t get_hardware(cl_device_id device, cl_context context) { && product.family != ngen::ProductFamily::MTL) attr |= dsl::hw::attr_t::atomic_fp64; - return dsl::hw_t(product, eu_count, max_wg_size, l3_cache_size, attr); + bool 
is_efficient_64bit = ngen::OpenCLCodeGenerator< + ngen::HW::Unknown>::detectEfficient64Bit(context, device, hw); + return dsl::hw_t(product, eu_count, (int)max_wg_size, l3_cache_size, + is_efficient_64bit, attr); } #ifdef GEMMSTONE_WITH_BINARY_RUNTIME @@ -211,8 +214,8 @@ dsl::hw_t get_hardware(ze_device_handle_t device, ze_context_handle_t context) { } dsl::hw::attr_t attr = {}; - if (ngen::getCore(product.family) >= ngen::HW::XeHPC) - attr |= dsl::hw::attr_t::large_grf; + ngen::HW hw = ngen::getCore(product.family); + if (hw >= ngen::HW::XeHPC) attr |= dsl::hw::attr_t::large_grf; { auto deviceModPropsExt = ze_intel_device_module_dp_exp_properties_t(); @@ -248,7 +251,11 @@ dsl::hw_t get_hardware(ze_device_handle_t device, ze_context_handle_t context) { attr |= dsl::hw::attr_t::atomic_fp64; } - return dsl::hw_t(product, eu_count, max_wg_size, l3_cache_size, attr); + bool is_efficient_64bit = ngen::LevelZeroCodeGenerator< + ngen::HW::Unknown>::detectEfficient64Bit(context, device, hw); + + return dsl::hw_t(product, eu_count, max_wg_size, l3_cache_size, + is_efficient_64bit, attr); } LevelZeroKernelAndModule make_kernel(const GEMMKernelDesc &desc, diff --git a/src/gpu/intel/gemm/jit/gen_kernel.cpp b/src/gpu/intel/gemm/jit/gen_kernel.cpp index 4b239b73ae0..df1ac20a9f0 100644 --- a/src/gpu/intel/gemm/jit/gen_kernel.cpp +++ b/src/gpu/intel/gemm/jit/gen_kernel.cpp @@ -222,6 +222,16 @@ status_t gen_desc_t::finalize(const char *tags) { problem_.B.setAlignment(nstl::max(problem_.B.alignment, 16)); } + if (utils::one_of(hw_, ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, + ngen::HW::XE3P_UNKNOWN)) { + // Use XeHPC banking if reusing XeHPC strategies (legacy mode) + if (!efficient_64b_) strategy_.raHW = ngen::HW::XeHPC; + + // Disable named barriers to avoid simulator errors, allow fallback to pvc strategies. + strategy_.namedBarriers[0] = 0; + strategy_.namedBarriers[1] = 0; + } + // Disable global k parallelization if it wouldn't be used. 
if (strategy_.kParallel && k_ >= 0) { auto k_min = aux_params_.k0 * aux_params_.wgK; @@ -310,26 +320,31 @@ status_t gen_desc_t::finalize(const char *tags) { if (problem_.aScale2D() && problem_.aqGroupK % minOuterProductCount(hw_, problem_, strategy_) - != 0) - return status::unimplemented; + != 0) { + if (!problem_.Ta.isF4() || !problem_.Tb.isF4()) + return status::unimplemented; + } if (problem_.bScale2D() && problem_.bqGroupK % minOuterProductCount(hw_, problem_, strategy_) - != 0) - return status::unimplemented; + != 0) { + if (!problem_.Ta.isF4() || !problem_.Tb.isF4()) + return status::unimplemented; + } - // If the M/N group size is equal to M or N, align up to a multiple of unroll size. - // Currently this is incompatible with precomputed reductions. - // XXX: Increase group size to a large value before aligning to increase reusability. + // TODO: Fix kChain handling with BDPAS. + if (problem_.preferBDPAS(hw_)) { strategy_.kChain = 1; } + + // If the M/N group size is equal to M or N, align up to a multiple of unroll size + // XXX: Increase group size to a large value before aligning to increase reusability + // TODO: Refactor M/N groups/thread setting to preserve MN group count. 
constexpr int perMNGroupSize = 1 << 24; - if (problem_.aqGroupM == m_ - && (!problem_.forceGroupSumsA || problem_.aqGroupM > 1)) { + if (problem_.aqGroupM == m_ && (!problem_.preferBDPAS(hw_) || m_ > 1)) { problem_.aqGroupM = std::max(problem_.aqGroupM, perMNGroupSize); problem_.aqGroupM = utils::rnd_up(problem_.aqGroupM, strategy_.unroll[LoopM]); } - if (problem_.bqGroupN == n_ - && (!problem_.forceGroupSumsB || problem_.bqGroupN > 1)) { + if (problem_.bqGroupN == n_ && (!problem_.preferBDPAS(hw_) || n_ > 1)) { problem_.bqGroupN = std::max(problem_.bqGroupN, perMNGroupSize); problem_.bqGroupN = utils::rnd_up(problem_.bqGroupN, strategy_.unroll[LoopN]); @@ -357,6 +372,9 @@ void gen_desc_t::update_driver_info() { REG_XEHPC_ISA(ARCH_DISPATCH(XeHPC)) REG_XE2_ISA(ARCH_DISPATCH(Xe2)) REG_XE3_ISA(ARCH_DISPATCH(Xe3)) + REG_XE3P_ISA(ARCH_DISPATCH(XE3P_35_10)) + REG_XE3P_ISA(ARCH_DISPATCH(XE3P_35_11)) + REG_XE3P_ISA(ARCH_DISPATCH(XE3P_UNKNOWN)) default: assert(!"Unsupported architecture"); driver_info_ = entry_->driverInfo; @@ -386,6 +404,11 @@ gen_nocopy_desc_t::select_kernel(compute::gpu_arch_t arch, int stepping, // Select a kernel from the catalog. std::vector match_params; MatchParams base(hw_, has_systolic, is_integrated, problem); + /* Reuse PVC strategies for legacy mode on Xe3p */ + if (utils::one_of(hw_, ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, + ngen::HW::XE3P_UNKNOWN) + && !efficient_64b_) + base.selector.hw = kcatalog::HWTagXeHPC; // By default gemmstone assumes that the accumulation type must be at least // as wide as the output type. For oneDNN this restriction is not needed. 
@@ -725,6 +748,9 @@ void gen_xe_systolic_kernel_desc_t::choose_unrolls(compute::gpu_arch_t arch, case compute::gpu_arch_t::xe_hpc: case compute::gpu_arch_t::xe2: case compute::gpu_arch_t::xe3: + case compute::gpu_arch_t::xe3p_35_10: + case compute::gpu_arch_t::xe3p_35_11: + case compute::gpu_arch_t::xe3p_35_unknown: if (utils::one_of(a_type, f16, bf16)) { if (unroll_m != 0) unroll_n = (unroll_m > 16) ? 32 : 16; @@ -810,6 +836,7 @@ void gen_kernel_t::init_interface() { || problem.needsBGroupSums()) { interface_.newArgument("ldbq", DataType::d); } + if (problem.hasCMXScale()) interface_.newArgument("ldcq", DataType::d); if (problem.usesCOPtr()) { interface_.newArgument( @@ -931,6 +958,7 @@ void gen_kernel_t::init_interface() { if (desc()->hw_ >= HW::XeHPG) interface_.allowArgumentRearrangement(false); interface_.externalName(kernel_name()); + interface_.setEfficient64Bit(desc_.efficient_64b_); } dsl::kernel_t get_dsl_kernel(const GEMMProblem &problem, @@ -987,6 +1015,9 @@ status_t gen_kernel_t::get_kernel( REG_XEHPC_ISA(ARCH_DISPATCH(XeHPC)) REG_XE2_ISA(ARCH_DISPATCH(Xe2)) REG_XE3_ISA(ARCH_DISPATCH(Xe3)) + REG_XE3P_ISA(ARCH_DISPATCH(XE3P_35_10)) + REG_XE3P_ISA(ARCH_DISPATCH(XE3P_35_11)) + REG_XE3P_ISA(ARCH_DISPATCH(XE3P_UNKNOWN)) default: assert(!"Unsupported architecture"); break; } } catch (const ngen::out_of_registers_exception &err) { diff --git a/src/gpu/intel/gemm/jit/gen_kernel.hpp b/src/gpu/intel/gemm/jit/gen_kernel.hpp index 1a0a5b8110c..122781c864d 100644 --- a/src/gpu/intel/gemm/jit/gen_kernel.hpp +++ b/src/gpu/intel/gemm/jit/gen_kernel.hpp @@ -96,6 +96,10 @@ struct gen_desc_t { problem_ = problem; } + void set_efficient_64b(bool efficient_64b) { + efficient_64b_ = efficient_64b; + } + protected: compute::gpu_arch_t arch_; ngen::HW hw_ = ngen::HW::Unknown; @@ -106,6 +110,8 @@ struct gen_desc_t { gemmstone::EvaluateAuxOutput aux_params_; gemmstone::CommonDriverInfo driver_info_; + bool efficient_64b_ = false; + /* optional information to fine-tune kernel */ 
int m_ = -1, n_ = -1, k_ = -1; int eu_count_ = -1; diff --git a/src/gpu/intel/gemm/jit/generator/microkernel_selector.cpp b/src/gpu/intel/gemm/jit/generator/microkernel_selector.cpp index 810b7e20bfd..85113af86b4 100644 --- a/src/gpu/intel/gemm/jit/generator/microkernel_selector.cpp +++ b/src/gpu/intel/gemm/jit/generator/microkernel_selector.cpp @@ -356,6 +356,9 @@ Package selectGEMM(const GEMMOptions &options, HWInformation hwInfo, SizeParams REG_XEHPC_ISA(ARCH_DISPATCH(XeHPC)) REG_XE2_ISA(ARCH_DISPATCH(Xe2)) REG_XE3_ISA(ARCH_DISPATCH(Xe3)) + REG_XE3_ISA(ARCH_DISPATCH(XE3P_35_10)) + REG_XE3_ISA(ARCH_DISPATCH(XE3P_35_11)) + REG_XE3_ISA(ARCH_DISPATCH(XE3P_UNKNOWN)) default: throw std::runtime_error("Unsupported architecture"); } #undef ARCH_DISPATCH diff --git a/src/gpu/intel/gemm/jit/generator/pieces/address_setup.cxx b/src/gpu/intel/gemm/jit/generator/pieces/address_setup.cxx index 241ae8d3b3d..baf56559ab6 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/address_setup.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/address_setup.cxx @@ -229,7 +229,7 @@ void Generator::setupAddr(Type T, const GRFRange &addr, const BO &ptr, const eadd(simd2, addr[2].uq(), addr[udStride].ud(0)(udStride), ptrShifted, strategy, state); eadd(simd1, addr[0].uq(), addr[0].ud(0)(udStride), ptrShifted, strategy, state); } else if (ptrShifted != 0) { - if (consecutive > 1 || tblock > 1) + if (consecutive > 1 || tblock > 1 || hw >= HW::XE3P_35_10) { mulConstant(simdSize, addr, iv, stride); add(simdSize, addr, addr, ptrShifted); @@ -378,8 +378,11 @@ void Generator::setupAddr(Type T, const GRFRange &addr, const BO &ptr, const auto pitch = bw * bcount * block.ebytes; if (pitch < 64 || pitch & 0xF) hw_unsupported(); mov(1, addr[0].ud(4), pitch - 1); - } else + } else { add(1, addr[0].ud(4), bld, -1); + if (!doBaseAdjust) + max_(1, addr[0].ud(4), addr[0].ud(4), addr[0].ud(2)); + } mov(1, addr[0].ud(7), (bw - 1) | ((bh - 1) << 8) | ((bcount - 1) << 16)); diff --git 
a/src/gpu/intel/gemm/jit/generator/pieces/c_update.cxx b/src/gpu/intel/gemm/jit/generator/pieces/c_update.cxx index c8002f5696e..7b09d848b97 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/c_update.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/c_update.cxx @@ -1686,9 +1686,9 @@ void Generator::doAlternateCRemainder(COperation op, const GEMMProblem &prob bool nonuniformSubs = false; if (!uniform) { - static constexpr int maxGRFs = 256; - uint8_t baseIndices[maxGRFs] = {0}; - uint16_t offIndices[maxGRFs] = {0}; + int maxGRFs = ( hw == HW::XE3P_35_11 ? 512 : 256 ); + std::vector baseIndices(maxGRFs, 0); + std::vector offIndices(maxGRFs, 0); // Workaround for spurious maybe-uninitialized warning in GCC11 for (int i = 0; i < maxGRFs; i++) offIndices[i] = 0; diff --git a/src/gpu/intel/gemm/jit/generator/pieces/common.cxx b/src/gpu/intel/gemm/jit/generator/pieces/common.cxx index d1a04fe9ff6..a3c7bd17748 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/common.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/common.cxx @@ -668,6 +668,8 @@ void Generator::initState(const CommonProblem &problem, const CommonStrategy state.ra.setRegisterCount(strategy.GRFs); state.tokenAllocator = TokenAllocator(hw, strategy.GRFs); + setEfficient64Bit(interface.getEfficient64Bit()); + if (problem.gtpinSupport) interface.requireScratch(128); diff --git a/src/gpu/intel/gemm/jit/generator/pieces/compute_utils.hpp b/src/gpu/intel/gemm/jit/generator/pieces/compute_utils.hpp index a635dcd8208..48ecba05e25 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/compute_utils.hpp +++ b/src/gpu/intel/gemm/jit/generator/pieces/compute_utils.hpp @@ -34,11 +34,12 @@ struct SystolicParams { int opsPerChan; // # of FMAs/stage int sdepth; // Number of stages (systolic depth) int rcountMax; // Maximum repeat count (# of RHS) + int rcountMin; // Minimum repeat count (# of RHS) int ksys; // Total number of FMAs int osys; // Output vector length }; -static inline SystolicParams 
systolicParams(ngen::HW hw, GEMMProblem problem, const GEMMStrategy &strategy) +static inline SystolicParams systolicParams(ngen::HW hw, GEMMProblem problem) { problem.autoTypeConversions(hw, true); @@ -48,6 +49,7 @@ static inline SystolicParams systolicParams(ngen::HW hw, GEMMProblem problem, co params.ksys = params.sdepth * params.opsPerChan; params.osys = ngen::GRF::bytes(hw) / std::max(problem.Tc_compute().real().size(), 4); params.rcountMax = 8; + params.rcountMin = problem.preferBDPAS(hw) ? 8 : 0; return params; } @@ -56,7 +58,7 @@ static inline SystolicParams systolicParams(ngen::HW hw, GEMMProblem problem, co static inline int minOuterProductCount(ngen::HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) { if (strategy.systolic) { - auto params = systolicParams(hw, problem, strategy); + auto params = systolicParams(hw, problem); return params.ksys; } int kfma = std::max(strategy.dotVL, 1); @@ -117,14 +119,15 @@ static inline std::tuple targetSLMCrosspack(ngen::HW hw, const GEMMProb static inline std::tuple targetKernelTiling(ngen::HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) { if (strategy.systolic) { - auto params = systolicParams(hw, problem, strategy); + auto params = systolicParams(hw, problem); bool cColMajor = isRegisterColMajor(problem.Tc, problem.C, strategy.C); auto tileO_V = params.osys; auto tileI_N = params.ksys; + auto tileI_K = params.rcountMin; if (strategy.unroll[cColMajor ? LoopN : LoopM] == 1) tileI_N = 0; - return cColMajor ? std::make_tuple(tileO_V, 0, tileI_N, 0) - : std::make_tuple(0, tileI_N, 0, tileO_V); + return cColMajor ? 
std::make_tuple(tileO_V, 0, tileI_N, tileI_K) + : std::make_tuple(tileI_K, tileI_N, 0, tileO_V); } return std::make_tuple(0,0,0,0); } diff --git a/src/gpu/intel/gemm/jit/generator/pieces/copy_plan.cpp b/src/gpu/intel/gemm/jit/generator/pieces/copy_plan.cpp index d593d93adc1..6aeaaf5a060 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/copy_plan.cpp +++ b/src/gpu/intel/gemm/jit/generator/pieces/copy_plan.cpp @@ -281,9 +281,12 @@ void CopyPlan::transform() optimizeWriteCombine(); optimizeWriteSpread(); + legalizeImmediateTypes(); + sort(SortType::PhaseOnly); - legalizeImmediateTypes(); + legalizeShfl(); + #if GEMMSTONE_ENABLE_COPY_PLAN_DUMP if (getVerbose(GEMMVerbose::DebugInfo) >= 170) @@ -419,9 +422,9 @@ CopyOperand CopyPlan::getResource(CopyResource::Kind kind) res = &resources.back(); } if (!res->src) { - std::array data; + auto data = res->getData(); res->preinitialized = false; - if (int n = res->getData(data)) + if (int n = std::get<1>(data)) res->src = newTemp(DataType::ud, (n + 3) >> 2, 1); } return res->src; @@ -725,10 +728,13 @@ void CopyPlan::split2DRegions() for (auto &i: insns) { if ((is2D(i.dst) && !is4(i.dst.type)) || is2D(i.src1) || is2D(i.src2)) stub("Unsupported 2D region"); - if (is2D(i.src0)) { + if (is2D(i.src0)){ + if(i.dst.stride > 4) + continue; if (i.flag) stub("Unsupported predication"); int w = i.src0.width, vs = i.src0.vs, hs = i.src0.stride; - bool splitH = (w * w >= i.simd); + bool is_xe3p = one_of(hw, {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, ngen::HW::XE3P_UNKNOWN}); + bool splitH = (w * w >= i.simd || (is_xe3p && i.dst.stride * w >= 8)); int nsplit = splitH ? (i.simd / w) : w; i.simd /= nsplit; i.src0.stride = splitH ? 
hs : vs; @@ -786,6 +792,11 @@ void CopyPlan::planTypeConversions() if (st == dt) i.moveToIntegerPipe(); + bool is_xe3p = one_of(hw, {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, ngen::HW::XE3P_UNKNOWN}); + if (is_xe3p && is4(st) && one_of(getBits(dt), {8, 16})) + if (planShflUpconvertXe3p(i)) + continue; + if (is4(st) && one_of(dt, {ngen_b16_h4x(), ngen_b16_l4x()})) plan4BitShifts(i); else if (isInt4(st) && isInt4(dt) && st != dt) { @@ -955,6 +966,105 @@ void CopyPlan::planTypeConversions() } } +// Upconvert 4-bit types to 8/16 bits using shfl.idx4. +bool CopyPlan::planShflUpconvertXe3p(CopyInstruction &i) +{ + // Cases handled: (with 16-bit upconversion; 8-bit similar) + // 1a) + // mov y:uw<1> x:u4<8;2,1> --> shfl.idx4 y.0:ud<1> lut:ud x.(n/2):ub<4> + // 1b) + // mov y.0:uw<2> x.n:u4<8> --> same as 1a + // mov y.1:uw<2> x.(n+1):u4<8> + // 2) + // mov y:uw<1> x:u4<1> --> mov y:ub<4> x:ub<1> + // mov y:uw<1> x:u4<8;2,1> (--> case 1a) + // 3) + // mov y:uw x:u4<1> --> mov t:uw<1> x:u4<1> (--> case 2 (<1>), 1a (<8;2,1>)) + // OR <8;2,1> mov y:uw t:uw<1> + // + // If dst is integral, only use shfl.idx4 in case 1 and only when src and dst have valid offsets for shfl.idx4. + + auto st = i.src0.type, dt = i.dst.type; + bool _16 = (getBytes(dt) == 2); + + bool laneAligned = (i.src0.vs == 8 && i.src0.width * getBytes(dt) == 8 && i.src0.stride == 1); + if ((i.src0.vs || i.src0.width) && !laneAligned) + return false; /* unsupported 2D region */ + if (!laneAligned && i.src0.stride != 1) + return false; /* expect stride 1 */ + if (i.src0.offset & (_16 ? 
1 : 3)) + return false; /* unaligned input */ + + auto x = i.src0, y = i.dst; + bool copySrc = !laneAligned || x.byteOffset() >= 4; + bool copyDst = (y.stride != 1 || y.offset != 0); + + if (isInt(dt) && (copySrc || copyDst)) + return false; /* use normal sequence */ + + if (i.simd < 16) return false; + auto lut = getResource(CopyResource::makeShflLUT(st, dt)); + if (!lut) + return false; /* no LUT available */ + lut.type = DataType::ud; + lut.stride = 0; /* will be fixed up later */ + + int orig_simd = i.simd; + if (copySrc){ + i.simd /= 2; + i.simd = std::max(16, i.simd); + } + + auto ie = splitMultiple<3>(i); + + x.offset >>= (_16 ? 1 : 2); + x.type = (_16 ? DataType::ub : DataType::uw); + + if (copyDst) + y = newTemp(dt, i.simd, 1); + + if (copySrc) { + ie[0]->op = Opcode::mov; + ie[0]->src0 = x; + x = y; + x.type = (_16 ? DataType::ub : DataType::uw); + x.stride = (_16 ? 4 : 2); + ie[0]->dst = x; + } + + ie[1]->op = Opcode::shfl; + ie[1]->dst = y; + ie[1]->dst.type = DataType::ud; + ie[1]->src0 = lut; + ie[1]->src1 = x; + + if (copyDst) { + ie[2]->op = Opcode::mov; + ie[2]->src0 = y; + ie[2]->simd = orig_simd; + } else + ie[2]->invalidate(); + + return true; +} + +void CopyPlan::legalizeShfl() +{ + for (auto &i: insns) { + if (i.op != Opcode::shfl) continue; + if (!one_of(i.simd, {16, 32})) stub(); + if (i.simd == 32) { + i.src0.stride = 1; + i.src0.width = 16; + i.src0.vs = 0; + }else{ + i.src0.stride = 0; + i.src0.width = 1; + i.src0.vs = 1; + } + } +} + // Unpack 4-bit src type into 16 bits (zero extended), used in many conversion sequences. 
void CopyPlan::planUnpack4To16(CopyInstruction &i) { @@ -1118,7 +1228,7 @@ bool CopyPlan::bfArithmeticOK(const CopyInstruction &i) const CopyOperand CopyPlan::bfImmediate(uint16_t bits, bool ternary) { - if (ternary) { + if (ternary) { auto kind = CopyResource::makeConstant32(uint32_t(bits) << 16); auto val = getResource(kind); val.stride = 0; @@ -1131,6 +1241,7 @@ CopyOperand CopyPlan::bfImmediate(uint16_t bits, bool ternary) } }; + // {b,ub}->bf sequence. void CopyPlan::planInt8ToBF(CopyInstruction &i) { @@ -1138,6 +1249,11 @@ void CopyPlan::planInt8ToBF(CopyInstruction &i) copyThrough(i, DataType::f); return; } + bool is_xe3p = one_of(hw, {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, ngen::HW::XE3P_UNKNOWN}); + if (is_xe3p){ + copyThrough(i, DataType::f); + return; + } auto ie = splitMultiple<3>(i); @@ -1172,6 +1288,27 @@ void CopyPlan::planInt8ToBF(CopyInstruction &i) ie[2]->src1 = Immediate::hf(0x4000); } +void CopyPlan::legalizeBfImmediate(CopyInstruction &i1){ + if (i1.src1.kind != CopyOperand::Immediate) return; + auto op = i1.op; + auto temp = newTemp(DataType::uw, i1.simd, 1); + auto src0 = i1.src0; + auto dst = i1.dst; + + i1.op = Opcode::mov; + i1.dst = temp; + i1.src0 = Immediate::uw(i1.src1.value >> 16); + i1.src0.type = DataType::uw; + + auto &i2 = split(i1); + + i2.op = op; + i2.dst = dst; + i2.src0 = src0; + i2.src1 = temp; + i2.src1.type = DataType::bf; +} + // s4/u4 -> hf/bf sequence. 
void CopyPlan::planInt4ToF16(CopyInstruction &i) { @@ -1291,6 +1428,10 @@ void CopyPlan::planInt4Upconversion(CopyInstruction &i) if (i.src0.neg || i.hasCMod()) stub("Unsupported modifier"); i.sat = false; + if (hw >= HW::XE3P_35_10 && one_of(getBits(i.dst.type), {8, 16})) + if (planShflUpconvertXe3p(i)) + return; + bool s4 = (i.src0.type == DataType::s4); if (i.src0.stride == 1 && i.simd > 1) { @@ -1339,8 +1480,14 @@ void CopyPlan::planInt4Upconversion(CopyInstruction &i) } } else { bool even = (i.src0.offset % 2 == 0); + if ( i.dst.stride > 4 ) stub("Unsupported stride."); i.src0.stride /= 2; i.src0.offset /= 2; + if ( getBits(i.dst.type) < 8 ) { + i.dst.type = DataType::ub; + i.dst.stride /= 2; + i.dst.offset /= 2; + } if (even) { // Low nybbles @@ -1388,31 +1535,26 @@ void CopyPlan::plan4BitShifts(CopyInstruction &i) bool high = (i.dst.type == ngen_b16_h4x()); - std::array ie = {&i, nullptr}; - + bool even = (i.src0.offset % 2 == 0); + i.op = high ? Opcode::shl : Opcode::shr; + i.dst.type = DataType::uw; + i.src0.type = DataType::ub; + i.src0.offset /= 2; // Split into high and low nybble conversions if both are present. - if (i.src0.stride == 1) { - auto &i0 = i; - i0.dst.stride *= 2; - i0.src0.stride *= 2; - i0.simd /= 2; + if (i.src0.stride == 1 && i.simd > 1) { + i.src0.stride = 1; + i.src1 = high ? 12 + : 0; + i.simd /= 2; + i.dst.stride *= 2; auto &i1 = split(i, false); i1.dst.offset += i1.dst.stride / 2; - i1.src0.offset += i1.src0.stride / 2; - ie[1] = &i1; - } - - // Convert to shifts. - for (auto ip: ie) if (ip) { - bool even = (ip->src0.offset % 2 == 0); + i1.src1 = high ? 8 : 4; - ip->op = high ? Opcode::shl : Opcode::shr; - ip->src0.type = DataType::ub; - ip->src0.stride /= 2; - ip->src0.offset /= 2; - ip->src1 = high ? (even ? 12 : 8) - : (even ? 0 : 4); - ip->dst.type = DataType::uw; + } else { + i.src0.stride /= 2; + i.src1 = high ? (even ? 12 : 8) + : (even ? 
0 : 4); } } @@ -1423,6 +1565,7 @@ void CopyPlan::planInt4Downconversion(CopyInstruction &i) auto st = i.src0.type, dt = i.dst.type; bool s4 = (dt == DataType::s4); + if (!one_of(dt, {DataType::s4, DataType::u4})) stub(); if (isD(st) || isQ(st)) { copyThrough(i, (isSigned(st) && s4) ? DataType::w : DataType::uw, 1); return; @@ -1432,6 +1575,8 @@ void CopyPlan::planInt4Downconversion(CopyInstruction &i) int tmp_elems = ddst.stride > 4 ? simd * 2 : simd; auto tmp = newTemp(DataType::uw, tmp_elems, 1); + int sStride = ssrc.stride * getBytes(ssrc.type) * 2; + int dStride = ddst.stride / (getBytes(ssrc.type) * 2); if (i.sat) { auto ie = splitMultiple<3>(i); auto ssrc = i.src0; @@ -1470,8 +1615,6 @@ void CopyPlan::planInt4Downconversion(CopyInstruction &i) auto ie = splitMultiple<5>(i); auto osrc = i.src0; auto stmp = newTemp(DataType::uw, simd, 1); - int sStride = ssrc.stride * getBytes(ssrc.type) * 2; - int dStride = ddst.stride / (getBytes(ssrc.type) * 2); ie[0]->op = Opcode::mov; ie[0]->dst = stmp; @@ -1560,6 +1703,7 @@ void CopyPlan::planInt4Downconversion(CopyInstruction &i) ie[1]->src0 = stmp; ie[1]->src0.offset += 1; ie[1]->src0.stride *= 2; + ie[1]->src1 = Immediate::uw(0x4); ie[2]->op = Opcode::bfn; @@ -1647,6 +1791,7 @@ void CopyPlan::planInt4Downconversion(CopyInstruction &i) ie[4]->src0 = tmp; ie[4]->src0.type = DataType::ub; ie[4]->src0.stride *= 2; + } } @@ -2083,86 +2228,170 @@ void CopyPlan::planEmulatedHFToF4(CopyInstruction &i) return; } - auto ie = splitMultiple<11>(i); + if (hw >= HW::XE3P_35_10) { + auto t0 = newTemp(DataType::hf, i.simd/2, 1); + auto t1 = newTemp(DataType::hf, i.simd/2, 1); + auto ie = splitMultiple<5>(i); + int simd = i.simd; + int dstStride = y.stride; + bool needPack = (y.stride > 1 || y.width > 2); - auto t0 = newTemp(DataType::hf, i.simd, 1); - auto t1 = newTemp(DataType::hf, i.simd, 1); - auto t0UW = t0, t1UW = t1; - t0UW.type = t1UW.type = DataType::uw; - - auto flag = newFlag(i.simd); - - // Clamp and round. 
- ie[0]->op = Opcode::mad; - ie[0]->cmod = ConditionModifier::lt; - ie[0]->flag = flag; - ie[0]->dst = t1; - ie[0]->src0 = Immediate::hf(e2m1 ? 0x8004 : 0x8002); - ie[0]->src1 = abs(x); - ie[0]->src2 = Immediate::hf(e2m1 ? 0x0002 : 0x0004); + ie[0]->op = Opcode::mov; + ie[0]->simd = simd / 4; + ie[0]->dst = t0; + ie[0]->dst.stride = 1; + ie[0]->dst.type = DataType::ud; + ie[0]->src0 = x; + ie[0]->src0.type = DataType::ud; + ie[0]->src0.stride = 2; - ie[1]->op = Opcode::sel; - ie[1]->cmod = ConditionModifier::lt; - ie[1]->dst = t0; - ie[1]->src0.abs = true; - ie[1]->src1 = Immediate::hf(e2m1 ? 0x4600 : 0x4C00); + ie[1]->op = Opcode::mov; + ie[1]->simd = simd / 4; + ie[1]->dst = t1; + ie[1]->dst.type = DataType::ud; + ie[1]->dst.stride = 1; + ie[1]->src0 = x; + ie[1]->src0.type = DataType::ud; + ie[1]->src0.offset = 1; + ie[1]->src0.stride = 2; + + ie[2]->op = Opcode::dnscl; + ie[2]->simd = simd / 4; + ie[2]->dst = t0; + ie[2]->dst.type = y.type; + ie[2]->dst.stride = 1; + ie[2]->src0 = t0; + ie[2]->src0.type = x.type; + ie[2]->src0.stride = 1; + ie[2]->src1 = t1; + ie[2]->src1.type = x.type; + ie[2]->src1.stride = 1; + ie[2]->src2.type = DataType::ud; + + ie[3]->op = Opcode::mov; + ie[3]->simd = simd / 2; + ie[3]->dst = needPack ? t0 : y; + ie[3]->dst.type = DataType::ub; + ie[3]->dst.offset = needPack ? 0 : y.offset / 2; + ie[3]->dst.stride = needPack ? 1 : std::max(y.vs / 2, 1); + ie[3]->src0 = t0; + ie[3]->src0.type = DataType::ub; + ie[3]->src0.stride = 2; + + if ( needPack ){ + ie[4]->op = Opcode::mov; + ie[4]->dst = y; + ie[4]->dst.type = DataType::u4; + ie[4]->dst.stride = dstStride; + ie[4]->src0 = t0; + ie[4]->src0.type = DataType::u4; + ie[4]->src0.stride = 1; + } else { + ie[4]->invalidate(); + } - ie[2]->op = Opcode::mul; - ie[2]->src0 = ie[2]->dst = t0; - ie[2]->src1 = Immediate::hf(e2m1 ? 
0x0400 : 0x0C00); + } else + { + auto ie = splitMultiple<13>(i); - ie[3]->op = Opcode::mad; - ie[3]->flag = flag; - ie[3]->dst = t0; - ie[3]->src0 = Immediate::hf(0x0800); - ie[3]->src1 = t1; - ie[3]->src2 = Immediate::hf(e2m1 ? 0x6000 : 0x6400); - ie[4]->op = Opcode::add; - ie[4]->src0 = ie[4]->dst = t0UW; - ie[4]->src1 = Immediate::w(e2m1 ? -0x0100 : -0x200); + auto t0 = newTemp(DataType::hf, i.simd, 1); + auto t1 = newTemp(DataType::hf, i.simd, 1); + auto t0UW = t0, t1UW = t1; + t0UW.type = t1UW.type = DataType::uw; - ie[5]->op = Opcode::and_; - ie[5]->flag = flag; - ie[5]->cmod = ConditionModifier::nz; - ie[5]->dst = CopyOperand(); - ie[5]->dst.type = DataType::uw; - ie[5]->src0 = t0UW; - ie[5]->src1 = Immediate::uw(e2m1 ? 0x03FF : 0x07FF); - - ie[6]->op = Opcode::add; - ie[6]->flag = flag; - ie[6]->src0 = ie[6]->dst = t0UW; - ie[6]->src1 = Immediate::uw(e2m1 ? 0x0200 : 0x0400); - - ie[7]->op = Opcode::shl; - ie[7]->src0 = ie[7]->dst = t0UW; - ie[7]->src1 = Immediate::uw(e2m1 ? 3 : 2); - - // Restore sign. - ie[8]->op = Opcode::bfn; - ie[8]->src0 = x; - ie[8]->src0.type = DataType::uw; - ie[8]->src1 = ie[8]->dst = t0UW; - ie[8]->src2 = 0x8000; - ie[8]->ctrl = 0xEC; - - // Pack into bytes. - ie[9]->op = Opcode::shr; - ie[9]->src0 = ie[9]->dst = t0UW; - ie[9]->src1 = Immediate::uw(12); + auto flag = newFlag(i.simd); - ie[10]->op = Opcode::mov; - ie[10]->dst = y; - ie[10]->dst.type = DataType::u4; - ie[10]->src0 = t0UW; + // Clamp and round. + ie[0]->op = Opcode::mad; + ie[0]->cmod = ConditionModifier::lt; + ie[0]->flag = flag; + ie[0]->dst = t1; + ie[0]->src0 = Immediate::hf(e2m1 ? 0x8004 : 0x8002); + ie[0]->src1 = abs(x); + ie[0]->src2 = Immediate::hf(e2m1 ? 0x0002 : 0x0004); + + ie[1]->op = Opcode::sel; + ie[1]->cmod = ConditionModifier::lt; + ie[1]->dst = t0; + ie[1]->src0.abs = true; + ie[1]->src1 = Immediate::hf(e2m1 ? 0x4600 : 0x4C00); + + ie[2]->op = Opcode::mul; + ie[2]->src0 = ie[2]->dst = t0; + ie[2]->src1 = Immediate::hf(e2m1 ? 
0x0400 : 0x0C00); + + ie[3]->op = Opcode::mad; + ie[3]->flag = flag; + ie[3]->dst = t0; + ie[3]->src0 = Immediate::hf(0x0800); + ie[3]->src1 = t1; + ie[3]->src2 = Immediate::hf(e2m1 ? 0x6000 : 0x6400); + + ie[4]->op = Opcode::add; + ie[4]->src0 = ie[4]->dst = t0UW; + ie[4]->src1 = Immediate::w(e2m1 ? -0x0100 : -0x200); + + if (e2m1) { + ie[5]->invalidate(); + ie[6]->invalidate(); + } else { + ie[5]->op = Opcode::cmp; + ie[5]->cmod = ConditionModifier::gt; + ie[5]->flag = flag; + ie[5]->dst = CopyOperand(); + ie[5]->dst.type = DataType::hf; + ie[5]->src0 = t0; + ie[5]->src1 = Immediate::hf(0x0200); + + ie[6]->op = Opcode::or_; + ie[6]->flag = flag; + ie[6]->dst = t0UW; + ie[6]->src0 = t0UW; + ie[6]->src1 = 0x1; + } + ie[7]->op = Opcode::and_; + ie[7]->flag = flag; + ie[7]->cmod = ConditionModifier::nz; + ie[7]->dst = CopyOperand(); + ie[7]->dst.type = DataType::uw; + ie[7]->src0 = t0UW; + ie[7]->src1 = Immediate::uw(e2m1 ? 0x03FF : 0x07FF); + + ie[8]->op = Opcode::add; + ie[8]->flag = flag; + ie[8]->src0 = ie[8]->dst = t0UW; + ie[8]->src1 = Immediate::uw(e2m1 ? 0x0200 : 0x0400); + + ie[9]->op = Opcode::shl; + ie[9]->src0 = ie[9]->dst = t0UW; + ie[9]->src1 = Immediate::uw(e2m1 ? 3 : 2); + + // Restore sign. + ie[10]->op = Opcode::bfn; + ie[10]->src0 = ie[10]->dst = t0UW; + ie[10]->src1 = x; + ie[10]->src1.type = DataType::uw; + ie[10]->src2 = 0x8000; + ie[10]->ctrl = 0xCA; + + // Pack into bytes. + ie[11]->op = Opcode::shr; + ie[11]->src0 = ie[11]->dst = t0UW; + ie[11]->src1 = Immediate::uw(12); + + ie[12]->op = Opcode::mov; + ie[12]->dst = y; + ie[12]->dst.type = DataType::u4; + ie[12]->src0 = t0UW; + } } // Check that no types smaller than a byte are present. 
void CopyPlan::checkNoSubbytes() { for (auto &i: insns) - if (is4(i.dst.type) || is4(i.src0.type) || is4(i.src1.type) || is4(i.src2.type)) + if ((is4(i.dst.type) && i.op != Opcode::dnscl) || is4(i.src0.type) || is4(i.src1.type) || is4(i.src2.type)) stub("Unexpected 4-bit type"); } @@ -2266,6 +2495,9 @@ void CopyPlan::legalizeSIMD(bool initial) // Fracture instruction into legal SIMD lengths. int simd0 = std::min(rounddown_pow2(i.simd), simdMax); + bool is_xe3p = one_of(hw, {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, ngen::HW::XE3P_UNKNOWN}); + if (is_xe3p && simd0 == 2) simd0 = 1; + if (!initial && forceSIMD1(i)) simd0 = 1; @@ -2337,6 +2569,17 @@ inline bool legalPackedBF(HW hw, const CopyOperand &op) return (op.stride == 1 && (op.offset & (align - 1)) == 0); } +inline bool isCommutative(Opcode op) { + switch (op) { + case Opcode::add: + case Opcode::mul: + case Opcode::and_: + case Opcode::or_: + case Opcode::xor_: return true; + default: return false; + } +} + void CopyPlan::planEmulatedSIMD1(CopyInstruction &i) { // Convert SIMD1 instruction to SIMD2. 
@@ -2367,6 +2610,8 @@ void CopyPlan::legalizeRegions() auto dt = i.dst.type; if (!i.dst && (hw < ngen::HW::XeHPC || i.op != Opcode::cmp)) continue; + if (i.dst.stride == 0) stub("Illegal dst stride"); + if (isFP4(dt)) continue; /* Check for special packed conversion cases */ if (i.op == Opcode::mov && ((s0t == DataType::hf && isFP8(dt)) @@ -2468,6 +2713,9 @@ void CopyPlan::legalizeRegions() } } + int dstBO = i.dst.byteOffset(); + int dstBS = i.dst.byteStride(); + /* Check for swizzling */ bool canSwizzle = true, splitQWMov = false; if (hw >= HW::XeHP) { @@ -2480,12 +2728,26 @@ void CopyPlan::legalizeRegions() if (isFP(dt)) canSwizzle = false; } + if (hw >= HW::XE3P_35_10) { + auto isFlat = [&] (const CopyOperand &op) { + if (!op) return true; + if (isBroadcast(op)) return true; + auto bo = op.byteOffset(); + auto bs = op.byteStride(); + return (bo == dstBO) && (bs == dstBS); + }; + + if (!isFlat(i.src1)) { + if (isCommutative(i.op) && i.src0.kind == CopyOperand::GRF && isFlat(i.src0)) + std::swap(i.src0, i.src1); + else + canSwizzle = false; + } + } - int dstBO = i.dst.byteOffset(); int src0BO = i.src0.byteOffset(); int src1BO = i.src1.byteOffset(); int src2BO = i.src2.byteOffset(); - int dstBS = i.dst.byteStride(); int src0BS = i.src0.byteStride(); int src1BS = i.src1.byteStride(); int src2BS = i.src2.byteStride(); @@ -2523,9 +2785,11 @@ void CopyPlan::legalizeRegions() repositionDst(i, stride, offset); } continue; - } else if (src0BS < dstBS) + } else if (src0BS < dstBS){ restrideSrc0(i, dstBS >> getLog2Bytes(s0t)); - else if (src0BS > dstBS) + rerun = true; + } + else if (src0BS > dstBS) restrideDst(i, src0BS >> getLog2Bytes(dt)); } @@ -2668,6 +2932,7 @@ void CopyPlan::legalizeNegation() // Pass to legalize immediate types. 
void CopyPlan::legalizeImmediateTypes() { + bool is_xe3p = one_of(hw, {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, ngen::HW::XE3P_UNKNOWN}); for (auto &i: insns) { for (auto *op: {&i.src0, &i.src1, &i.src2}) { if (op->kind != CopyOperand::Immediate) @@ -2676,8 +2941,13 @@ void CopyPlan::legalizeImmediateTypes() op->type = DataType::uw; else if (one_of(op->type, {DataType::b, DataType::s4})) op->type = DataType::w; + else if (is_xe3p && i.op != Opcode::mov && op->type == DataType::f && i.dst.type == DataType::bf) + legalizeBfImmediate(i); } } + mergeChanges(); + legalizeRegions(); + } // Pass to sort instructions by phase and dst. @@ -2731,10 +3001,11 @@ void CopyPlan::optimizeZip(bool zip2DSrc0) if (i1.op != i2.op || i1.phase != i2.phase || i1.dst.grf != i2.dst.grf || i1.flag) break; if (i1.simd != i2.simd) continue; + bool xe3pPlus = (hw >= ngen::HW::XE3P_35_10); - auto zippable = [](const CopyOperand &o1, const CopyOperand &o2, bool zip2D = false, bool zipImm = false) { + auto zippable = [&](const CopyOperand &o1, const CopyOperand &o2, bool zip2D = false, bool zipImm = false) { if (o1.kind != o2.kind) return false; - if (o1.kind == CopyOperand::Immediate) return (o1.value == o2.value || zipImm); + if (o1.kind == CopyOperand::Immediate) return (o1.value == o2.value || (zipImm && !xe3pPlus)); if (o1.kind != CopyOperand::GRF) return true; if (o1.type != o2.type || o1.stride != o2.stride || o1.grf != o2.grf) return false; if (o1.temp != o2.temp) return false; @@ -3030,7 +3301,7 @@ void CopyPlan::optimizeWriteCombine() if (!isB(i.dst.type)) return false; if (!(isB(st) || isW(st) || isD(st) || st == DataType::f)) return false; if (multiGRF(hw, i, i.dst)) return false; - return true; + return !mayOverlap(hw, i, i.dst, i.src0); }; if (!canWC(hw, i0)) { @@ -3338,6 +3609,7 @@ void CopyPlan::materializeTemps(const GRFAllocator &grfAllocator, const FlagAllo } } + /* Update resources with assignments */ for (auto &r: resources) if (r.src.temp) { r.src.temp = false; @@ 
-3347,19 +3619,69 @@ void CopyPlan::materializeTemps(const GRFAllocator &grfAllocator, const FlagAllo temps.clear(); } -int CopyResource::getData(std::array &data) const + +CopyResource::Kind CopyResource::makeConstant32(uint32_t c) +{ + return static_cast(constantBase | c); +} + +std::tuple CopyResource::getData() const { - if (kind & constantBase) { + const uint8_t *base = nullptr; + size_t n = 0; + if (kind & constantBase){ + static std::array data; std::memcpy(&data[0], &kind, sizeof(uint32_t)); - return sizeof(uint32_t); + base = (const uint8_t *) &data[0]; + n=sizeof(uint32_t); + } + + switch (kind) { + case null: break; + default: { + DataType st, dt; + if (!decodeShflLUT(st, dt)) break; + +#define LUT16(ST, DT, V0, V1, V2, V3, V4, V5, V6, V7, V8, V9, VA, VB, VC, VD, VE, VF) \ + if (st == DataType::ST && dt == DataType::DT) { \ + static const uint16_t table[32] = {V0, V0, V1, V1, V2, V2, V3, V3, \ + V4, V4, V5, V5, V6, V6, V7, V7, \ + V8, V8, V9, V9, VA, VA, VB, VB, \ + VC, VC, VD, VD, VE, VE, VF, VF}; \ + base = (const uint8_t *) table; \ + n = sizeof(table); \ + } + + LUT16(e2m1, hf, 0x0, 0x3800, 0x3c00, 0x3e00, 0x4000, 0x4200, 0x4400, 0x4600, 0x8000, 0xb800, 0xbc00, 0xbe00, 0xc000, 0xc200, 0xc400, 0xc600) + LUT16(e3m0, hf, 0x0, 0x3400, 0x3800, 0x3c00, 0x4000, 0x4400, 0x4800, 0x4c00, 0x8000, 0xb400, 0xb800, 0xbc00, 0xc000, 0xc400, 0xc800, 0xcc00) + LUT16(e2m1, bf, 0x0, 0x3f00, 0x3f80, 0x3fc0, 0x4000, 0x4040, 0x4080, 0x40c0, 0x8000, 0xbf00, 0xbf80, 0xbfc0, 0xc000, 0xc040, 0xc080, 0xc0c0) + LUT16(e3m0, bf, 0x0, 0x3e80, 0x3f00, 0x3f80, 0x4000, 0x4080, 0x4100, 0x4180, 0x8000, 0xbe80, 0xbf00, 0xbf80, 0xc000, 0xc080, 0xc100, 0xc180) + + LUT16(u4, hf, 0x0, 0x3c00, 0x4000, 0x4200, 0x4400, 0x4500, 0x4600, 0x4700, 0x4800, 0x4880, 0x4900, 0x4980, 0x4a00, 0x4a80, 0x4b00, 0x4b80) + LUT16(s4, hf, 0x0, 0x3c00, 0x4000, 0x4200, 0x4400, 0x4500, 0x4600, 0x4700, 0xc800, 0xc700, 0xc600, 0xc500, 0xc400, 0xc200, 0xc000, 0xbc00) + LUT16(u4, bf, 0x0, 0x3f80, 0x4000, 0x4040, 
0x4080, 0x40a0, 0x40c0, 0x40e0, 0x4100, 0x4110, 0x4120, 0x4130, 0x4140, 0x4150, 0x4160, 0x4170) + LUT16(s4, bf, 0x0, 0x3f80, 0x4000, 0x4040, 0x4080, 0x40a0, 0x40c0, 0x40e0, 0xc100, 0xc0e0, 0xc0c0, 0xc0a0, 0xc080, 0xc040, 0xc000, 0xbf80) + + break; + } } - return 0; + return std::make_tuple(base, (int) (n / sizeof(*base))); } -CopyResource::Kind CopyResource::makeConstant32(uint32_t c) +CopyResource::Kind CopyResource::makeShflLUT(DataType from, DataType to) { - return static_cast(constantBase | c); + return static_cast(shflLUTBase | static_cast(from) | (static_cast(to) << 8)); +} + +bool CopyResource::decodeShflLUT(DataType &from, DataType &to) const +{ + if (kind & shflLUTBase) { + from = static_cast(kind); + to = static_cast(kind >> 8); + return true; + } + return false; } #if GEMMSTONE_ENABLE_COPY_PLAN_DUMP @@ -3371,10 +3693,12 @@ int CopyPlan::cycleCount() const return count; } -void CopyPlan::dump() const +void CopyPlan::dump(int n) const { - for (const auto &i: insns) - i.dump(*this); + for (int i = 0; i < (int)insns.size(); ++i){ + if(n < 0 || i < n) + insns[i].dump(*this); + } } void CopyInstruction::dump(const CopyPlan &plan) const @@ -3385,6 +3709,9 @@ void CopyInstruction::dump(const CopyPlan &plan) const std::cout << ")\t"; } + if (op == Opcode::shfl) + std::cout << "shfl.idx4"; + else std::cout << getMnemonic(op, HW::Gen9); switch (op) { case Opcode::bfn: std::cout << ".(" << BFN::nodes[ctrl].str() << ')'; break; diff --git a/src/gpu/intel/gemm/jit/generator/pieces/copy_plan.hpp b/src/gpu/intel/gemm/jit/generator/pieces/copy_plan.hpp index 2b5fd216ac4..ea801e3d62e 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/copy_plan.hpp +++ b/src/gpu/intel/gemm/jit/generator/pieces/copy_plan.hpp @@ -147,7 +147,8 @@ struct CopyResource enum Kind : uint64_t { null = 0, - constantBase = 0x100000000, + constantBase = 0x100000000, + shflLUTBase = 0x80000000, } kind; CopyOperand src; bool preinitialized = true; @@ -156,11 +157,13 @@ struct CopyResource template 
inline void initialize(Generator &g); - static Kind makeConstant32(uint32_t c); + static Kind makeShflLUT(ngen::DataType from, ngen::DataType to); + bool decodeShflLUT(ngen::DataType &from, ngen::DataType &to) const; + protected: - int getData(std::array &data) const; + std::tuple getData() const; }; class CopyPlan @@ -188,7 +191,7 @@ class CopyPlan int tempFlagBytes() const; #if GEMMSTONE_ENABLE_COPY_PLAN_DUMP - void dump() const; + void dump(int n = -1) const; int cycleCount() const; #endif @@ -251,6 +254,10 @@ class CopyPlan void planEmulatedHFToF4(CopyInstruction &i); void planE8M0ToF(CopyInstruction &i); void planBFNEmulation(); + void emulateBooleanFunction(); + bool planShflUpconvertXe3p(CopyInstruction &i); + void legalizeShfl(); + void legalizeBfImmediate(CopyInstruction &i1); void legalizeSIMD(bool initial = false); void legalizeRegions(); void legalizeNegation(); @@ -265,7 +272,6 @@ class CopyPlan void optimizeIntegerDownconvert(); void optimizeSaturate(); void optimizeMoveToIntPipe(); - CopyOperand bfImmediate(uint16_t bits, bool ternary); CopyOperand zipImmediates(const CopyOperand &o1, const CopyOperand &o2); @@ -351,6 +357,14 @@ void CopyInstruction::execute(Generator &g) g.math(ngenModifiers(), fc, dst.ngen(), src0.ngen(), src1.ngen()); break; } + case Opcode::dnscl: { + uint8_t mode = 0; + g.dnscl(ngenModifiers(), mode, RoundingType::rne, dst.ngen(), src0.ngen(), src1.ngen(), src2.ngen()); + break; + } + case Opcode::shfl: + g.shfl.idx4(ngenModifiers(), dst.ngen(), src0.ngen(), src1.ngen()); + break; default: stub("Unsupported opcode"); } @@ -364,13 +378,14 @@ void CopyResource::initialize(Generator &g) { using namespace ngen; - std::array data; - int n = getData(data); + const uint8_t *data; + int n; + std::tie(data, n) = getData(); - auto dataUQ = (const uint64_t *) &data[0]; - auto dataDF = (const double *) &data[0]; - auto dataUD = (const uint32_t *) &data[0]; - auto dataF = (const float *) &data[0]; + auto dataUQ = (const uint64_t *) data; + 
auto dataDF = (const double *) data; + auto dataUD = (const uint32_t *) data; + auto dataF = (const float *) data; int n32 = (n + 3) >> 2; const bool do64 = (g.getHardware() >= HW::XeHPC); diff --git a/src/gpu/intel/gemm/jit/generator/pieces/emulation.cxx b/src/gpu/intel/gemm/jit/generator/pieces/emulation.cxx index d6b0d4286af..32ed25374d4 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/emulation.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/emulation.cxx @@ -113,6 +113,21 @@ void Generator::emov(const ngen::InstructionModifier &mod, ngen::RegData dst EmulationImplementation::emov(*this, mod, dst, src0, strategy.emulate, loc); } +template +template +void Generator::emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst, const ngen::RegData &src0, const ngen::RegData &src1, const CommonStrategy &strategy, CommonState &state, ngen::SourceLocation loc) +{ + bool is_xe3p = one_of(hw, {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, ngen::HW::XE3P_UNKNOWN}); + if (is_xe3p && (dst.getType() == DataType::bf && src1.getType() == DataType::f)){ + auto tempRange = state.ra.alloc_range(div_up(mod.getExecSize(), elementsPerGRF(hw, DataType::bf)));; + auto temp = tempRange[0].bf(dst.getOffset())(1); + mov(mod, temp, src1); + mul(mod, dst, src0, temp); + state.ra.safeRelease(tempRange); + }else + ngen::EmulationImplementation::emul
(*this, mod, dst, src0, src1, strategy.emulate, state.emulate, loc); +} + template template void Generator::eadd(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, diff --git a/src/gpu/intel/gemm/jit/generator/pieces/gemm.cxx b/src/gpu/intel/gemm/jit/generator/pieces/gemm.cxx index 541b8c7f605..7184d14e5d5 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/gemm.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/gemm.cxx @@ -444,18 +444,12 @@ void Generator::gemm(GEMMProblem &problem, GEMMStrategy &strategy, GEMMState state.ra.safeRelease(idM); state.ra.safeRelease(idN); state.ra.safeRelease(idK); - if (!strategy.persistentLoop()) { - state.ra.safeRelease(state.inputs.localSizeM); - state.ra.safeRelease(state.inputs.localSizeN); - } if (anyKParallelFixed) { state.ra.safeRelease(state.inputs.localIDK); if (!strategy.persistentLoop()) state.ra.safeRelease(state.inputs.localSizeK); } if (strategy.linearOrder() || strategy.persistentLoop()) { - state.ra.safeRelease(state.inputs.groupIDM); - state.ra.safeRelease(state.inputs.groupIDN); state.ra.claim(state.nextGroupIDM); state.ra.claim(state.nextGroupIDN); } @@ -498,8 +492,6 @@ void Generator::gemm(GEMMProblem &problem, GEMMStrategy &strategy, GEMMState state.ra.safeRelease(state.inputs.k0); } - state.ra.safeRelease(state.inputs.localIDM); - state.ra.safeRelease(state.inputs.localIDN); if (!strategy.needsMNLocalIDs()) state.lidM = state.lidN = invalid; @@ -579,6 +571,17 @@ void Generator::gemm(GEMMProblem &problem, GEMMStrategy &strategy, GEMMState gemmSetupABC(problem, strategy, state); gemmSubkernel(problem, strategy, state); + if (strategy.linearOrder() || strategy.persistentLoop()) { + state.ra.safeRelease(state.inputs.groupIDM); + state.ra.safeRelease(state.inputs.groupIDN); + } + if (!strategy.persistentLoop()) { + state.ra.safeRelease(state.inputs.localSizeM); + state.ra.safeRelease(state.inputs.localSizeN); + } + state.ra.safeRelease(state.inputs.localIDM); + 
state.ra.safeRelease(state.inputs.localIDN); + mark(lKernelDone); // Persistent thread loop. Advance group ID and re-enter kernel if there's more work to do. @@ -843,6 +846,8 @@ void Generator::gemmSubkernel(GEMMProblem &problem, GEMMStrategy &strategy, auto modStrategy = strategy; gemmDowngradeAccess(problem, modStrategy, state); + gemmCalcWGIndices(problem, modStrategy, state); + gemmCalcWGRemainders(problem, modStrategy, state); status << "Unaligned A/B" << status_stream::endl; if (!gemmMEdge(problem, modStrategy, state)) { diff --git a/src/gpu/intel/gemm/jit/generator/pieces/gemm_setup.cxx b/src/gpu/intel/gemm/jit/generator/pieces/gemm_setup.cxx index 4b312161915..6f11820998c 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/gemm_setup.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/gemm_setup.cxx @@ -565,6 +565,7 @@ void Generator::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr if(problem.hasCMXScale()){ bOffsetCs[b] = state.inputs.strideScaleC[b]; } + if(problem.hasAOffsetPtr()){ bOffsetAo[b] = state.inputs.strideOffsetA[b]; } @@ -595,6 +596,7 @@ void Generator::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr if(problem.hasCMXScale()){ emul(1, bOffsetCs[b], state.inputs.strideScaleC[b], state.batchID[b], strategy, state); } + if(problem.hasAOffsetPtr()){ emul(1, bOffsetAo[b], state.inputs.strideOffsetA[b], state.batchID[b], strategy, state); } @@ -621,6 +623,7 @@ void Generator::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr state.offsetCs = state.ra.alloc_sub(ngen::DataType::q); emov(1, state.offsetCs, 0, strategy, state); } + if(problem.hasAOffsetPtr() && state.offsetAo.isInvalid()){ state.offsetAo = state.ra.alloc_sub(state.offsetA.getType()); emov(1, state.offsetAo, 0, strategy, state); @@ -654,6 +657,7 @@ void Generator::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr if(problem.hasCMXScale()){ eadd(1, state.offsetCs, state.offsetCs, bOffsetCs[b], strategy, state); } + 
if(problem.hasAOffsetPtr()){ eadd(1, state.offsetAo, state.offsetAo, bOffsetAo[b], strategy, state); } @@ -960,6 +964,7 @@ void Generator::gemmScaleInputs(const GEMMProblem &problem, const GEMMStrate state.ra.safeRelease(inputs.ldaq); state.ra.safeRelease(inputs.ldbq); + //state.ra.safeRelease(inputs.ldcq); state.ra.safeRelease(inputs.offsetAq); state.ra.safeRelease(inputs.offsetBq); } @@ -980,6 +985,32 @@ void Generator::gemmCalcWGRemainders(const GEMMProblem &problem, const GEMMS if (strategy.coopB != CoopSplit::FullK) state.ra.safeRelease(state.wgJ0); } +// Calculate workgroup m/n indices. +template +void Generator::gemmCalcWGIndices(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state) +{ + Subregister idM, idN; + + idM = state.ra.alloc_sub(getHint(HintType::TempComp1, strategy)); + idN = state.ra.alloc_sub(getHint(HintType::TempComp0, strategy)); + + if (strategy.fixedWG(problem)) { + mulConstant(1, idM, state.inputs.groupIDM, strategy.wg[LoopM]); + mulConstant(1, idN, state.inputs.groupIDN, strategy.wg[LoopN]); + } else { + mul(1, idM, state.inputs.groupIDM, state.inputs.localSizeM.uw()); + mul(1, idN, state.inputs.groupIDN, state.inputs.localSizeN.uw()); + } + bool gemmtBarriers = problem.gemmt() && strategy.needsBarrier(); + if (wgRemCheck(problem, strategy) || gemmtBarriers) { + state.wgI0 = state.ra.alloc_sub(getHint(HintType::TempComp0, strategy)); + state.wgJ0 = state.ra.alloc_sub(getHint(HintType::TempComp1, strategy)); + mulConstant(1, state.wgI0, idM, strategy.unroll[LoopM]); + mulConstant(1, state.wgJ0, idN, strategy.unroll[LoopN]); + } + state.ra.safeRelease(idM); + state.ra.safeRelease(idN); +} // Cache multiples of lda/ldb for later address calculations. 
template void Generator::gemmCacheLDABMultiples(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, bool doA, bool doB) @@ -1660,7 +1691,7 @@ bool Generator::gemmAccumulateCSetup(GEMMProblem &problem, GEMMStrategy &str state.Br_layout = RegisterLayout(hw, Tb, strategy.kb_load, unrollN, state.B_layout.colMajor(), crosspackB, tileK_B, tileN_B, true, splitB); // Prepare to repack C if needed, and choose repack tile size. - if (Tc != Tc_compute) { + if (Tc != Tc_compute || problem.forceLateQuant(hw, minOuterProductCount(hw, problem, strategy))) { auto &period = state.cRepackPeriod; int panel = strategy.cRepackPanel; if (panel == 0) @@ -3173,6 +3204,7 @@ void Generator::gemmInitState(GEMMProblem &problem, GEMMStrategy &strategy, state.tempCStrategy.padded = true; } + state.useBDPAS = problem.preferBDPAS(hw) && strategy.systolic; } GEMMSTONE_NAMESPACE_END diff --git a/src/gpu/intel/gemm/jit/generator/pieces/hw_template_instantiations.cxx b/src/gpu/intel/gemm/jit/generator/pieces/hw_template_instantiations.cxx index 502b5d5eb07..82bcb3e2697 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/hw_template_instantiations.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/hw_template_instantiations.cxx @@ -39,4 +39,7 @@ REG_XEHPG_ISA(template class Generator); REG_XEHPC_ISA(template class Generator); REG_XE2_ISA(template class Generator); REG_XE3_ISA(template class Generator); +REG_XE3P_ISA(template class Generator); +REG_XE3P_ISA(template class Generator); +REG_XE3P_ISA(template class Generator); #endif diff --git a/src/gpu/intel/gemm/jit/generator/pieces/hw_utils.hpp b/src/gpu/intel/gemm/jit/generator/pieces/hw_utils.hpp index 7ed0652b6bb..95b57c1a249 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/hw_utils.hpp +++ b/src/gpu/intel/gemm/jit/generator/pieces/hw_utils.hpp @@ -72,6 +72,7 @@ static inline bool hasNativeAtomicAdd(ngen::HW hw, Type T, const MatrixAddressin bool floatAtomics = (astrategy.base.getModel() == ModelA64); if 
(astrategy.newDP) floatAtomics |= (astrategy.base.getModel() != ModelSLM); + if (hw >= HW::XE3P_35_10) floatAtomics = true; if (T.isInt4()) return false; @@ -79,6 +80,8 @@ static inline bool hasNativeAtomicAdd(ngen::HW hw, Type T, const MatrixAddressin return true; else if (T == Type::f32) return floatAtomics && (hw >= HW::XeHP); + else if (T == Type::f16 || T == Type::bf16) + return (hw >= HW::XE3P_35_10); else if (T == Type::f64) return floatAtomics && (hw >= HW::XeHPC); else @@ -92,9 +95,12 @@ static inline size_t slmCapacity(ngen::HW hw) case HW::Gen12LP: case HW::XeHP: case HW::XeHPG: - case HW::XeHPC: return 131072; - case HW::Xe2: return 131072; - case HW::Xe3: return 131072; + case HW::XeHPC: return 131072; + case HW::Xe2: return 131072; + case HW::Xe3: return 131072; + case HW::XE3P_35_10: return 196608; + case HW::XE3P_35_11: + case HW::XE3P_UNKNOWN: return 393216; default: return 0; } @@ -122,6 +128,9 @@ static inline int eusPerSubslice(ngen::HW hw) switch (hw) { case HW::XeHPC: case HW::Xe2: + case HW::XE3P_35_10: + case HW::XE3P_35_11: + case HW::XE3P_UNKNOWN: case HW::Xe3: return 8; case HW::Gen12LP: @@ -156,6 +165,7 @@ static inline int block2DMinAlignment(ngen::HW hw, const MatrixAddressing &atype if (!isBlock2D(astrategy.accessType) && !asIfBlock2D) return 0; if (hw == HW::Xe2) return 16; if (hw == HW::Xe3) return 16; + if (hw >= HW::XE3P_35_10) return 4; return (isTransposing(astrategy.accessType) || astrategy.prefetch) ? 
4 : 8; } @@ -163,6 +173,7 @@ static inline int block2DMinAlignment(ngen::HW hw, const MatrixAddressing &atype static inline int block2DBaseAlignment(ngen::HW hw, int stepping) { using namespace ngen; + if (hw >= HW::XE3P_35_10) return 4; if (hw == HW::XeHPC && stepping < SteppingPVCXTB4) return 128; return 64; diff --git a/src/gpu/intel/gemm/jit/generator/pieces/k_loop.cxx b/src/gpu/intel/gemm/jit/generator/pieces/k_loop.cxx index e63e523054c..339d9daee8e 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/k_loop.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/k_loop.cxx @@ -973,9 +973,9 @@ void Generator::kLoop(KLoop type, const GEMMProblem &problem, GEMMStrategy & if (dequantizeA) gemmDequantizeAB(true, sublayout, Ar_sublayout, regs, state.Ar_regs, h, k_load, k_repack, kaq_load, problem, strategy, state, s4Shift); else - if (repackA) + if (repackA) { copyRegisters(sublayout, Ar_sublayout, regs, state.Ar_regs, 0, har, false, strategy, state, false, s4Shift); - else if (convertA) + } else if (convertA) convert(regs, Ta_load, Ta, strategy, state); }; @@ -997,7 +997,7 @@ void Generator::kLoop(KLoop type, const GEMMProblem &problem, GEMMStrategy & if (dequantizeB) gemmDequantizeAB(false, layout, state.Br_layout, regs, state.Br_regs, h, k_load, k_repack, kbq_load, problem, strategy, state); - else + else if (repackB) copyRegisters(layout, state.Br_layout, regs, state.Br_regs, hbr, 0, false, strategy, state); else if (convertB) diff --git a/src/gpu/intel/gemm/jit/generator/pieces/math_helpers.cxx b/src/gpu/intel/gemm/jit/generator/pieces/math_helpers.cxx index 19119880147..a13f66f03af 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/math_helpers.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/math_helpers.cxx @@ -242,7 +242,7 @@ void Generator::divDown(const Subregister &dst, const Subregister &src0, con bool emulate = strategy.emulate.emulate64_mul; Subregister tmp; auto shift = state.ra.alloc_sub(); - auto pop = state.ra.alloc_sub(); + auto pop = 
state.ra.alloc_sub(); cbit(1, pop, src1); fbh(1, shift, src1); cmp(1 | gt | flag, pop, 1); diff --git a/src/gpu/intel/gemm/jit/generator/pieces/matrix_access.cxx b/src/gpu/intel/gemm/jit/generator/pieces/matrix_access.cxx index 09fe9eda68f..6f948dbd80f 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/matrix_access.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/matrix_access.cxx @@ -310,7 +310,7 @@ void Generator::atomicAddMatrixBlock(Type T, const GRF &src, const RegisterB auto mod = simd | maskMod | ExecutionOffset(eoff); if (block.ebytes * block.count != T.real().size()) stub(); if (astrategy.newDP) { - auto op = T.isFP() ? AtomicOp::fadd + auto op = T.isFP() ? (hw >= HW::XE3P_35_10 && T == Type::bf16) ? AtomicOp::bfadd : AtomicOp::fadd : AtomicOp::add; atomic(op, mod, specLSC, astrategy.base, getAddress(addr[hoff], block, astrategy), curSrc); } else switch (T.real()) { diff --git a/src/gpu/intel/gemm/jit/generator/pieces/matrix_multiply.cxx b/src/gpu/intel/gemm/jit/generator/pieces/matrix_multiply.cxx index 62f94fa026a..003205f7c63 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/matrix_multiply.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/matrix_multiply.cxx @@ -378,7 +378,7 @@ void Generator::outerProductSystolic(int h, int ha, int hb, int opCount, boo { auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc_compute(); bool globalCM = state.C_layout.colMajor(); - auto params = systolicParams(hw, problem, strategy); + auto params = systolicParams(hw, problem); auto ksys = params.ksys; auto osys = params.osys; auto sdepth = params.sdepth; @@ -425,6 +425,7 @@ void Generator::outerProductSystolic(int h, int ha, int hb, int opCount, boo for (int x = 0; x < nx; x += xinc) { Subregister A0, B0, C0; + RegData AS, BS, AS0, BS0; int rcount = 0, ybase = 0, hhbase = 0; auto issueDPAS = [&](bool last) { @@ -448,9 +449,14 @@ void Generator::outerProductSystolic(int h, int ha, int hb, int opCount, boo if (rc != 8 && strategy.extendedAtomicFMA) 
hw_unsupported(); } + if (hhbase + ksys < opCount) mod |= Fwd; + if (startRepackC && hhbase == 0) srcC0 = null.retype(C0.getType()); + if (state.useBDPAS) { + bdpas(mod, sdepth, rc, C0, srcC0, V0, N0, AS0, BS0); + } else { useDPASW ? dpasw(mod, sdepth, rc, C0, srcC0, V0, N0) : dpas(mod, sdepth, rc, C0, srcC0, V0, N0); @@ -472,6 +478,30 @@ void Generator::outerProductSystolic(int h, int ha, int hb, int opCount, boo Subregister A, B, C; const int cxCompA = -1, cxCompB = -1, cxCompC = -1, cBuffer = 0; + auto bdpasScaleArg = [&](const RegisterLayout Xr_scaleLayout, Type Tx_scaleOp, + GRFMultirange Xr_scaleRegs, bool isA, int x, int k) { + RegData XS; + if (state.useBDPAS && !Xr_scaleLayout.empty()) { + int neq; + const RegisterBlock *qblock; + + int r, c, io0, jo0; + r = Xr_scaleLayout.rows(); + c = Xr_scaleLayout.cols(); + + if (isA) { + io0 = x; + jo0 = (k / problem.aqGroupK) % c; + } else { + io0 = (k / problem.bqGroupK) % r; + jo0 = x; + } + XS = Xr_scaleLayout.find(io0, jo0, Xr_scaleRegs, &neq, &qblock); + } else + XS = NullRegister().setType(Type::ngen_e8m0()); + + return XS; + }; if (y < ny) { if (strategy.dpasw && (y % (2 * dpaswTile) >= dpaswTile)) @@ -486,6 +516,8 @@ void Generator::outerProductSystolic(int h, int ha, int hb, int opCount, boo C = state.Cr_layout.find(i % Cr_unrollM, j % Cr_unrollN, state.Cr_regs, &nc, &C_block, cxCompC); else C = state.C_layout.find(i, j, state.C_regs[cBuffer], &nc, &C_block, cxCompC); + AS = bdpasScaleArg(state.Ar_scaleLayout, state.Ta_scaleInt, state.Ar_scaleRegs, true, i, h + hh); + BS = bdpasScaleArg(state.Br_scaleLayout, state.Tb_scaleInt, state.Br_scaleRegs, false, j, h + hh); } else if (state.systolicSumA) { A = A_layout.find(x, hha, A_regs, &na, &A_block); B = state.sysSumAll1s[0]; @@ -495,6 +527,8 @@ void Generator::outerProductSystolic(int h, int ha, int hb, int opCount, boo C = state.Asr_layout.find(x % Cr_unrollM, 0, state.Asr_regs, &nc, &C_block); else C = state.As_layout.find(x, 0, state.As_regs, &nc, &C_block); 
+ AS = bdpasScaleArg(state.Ar_scaleLayout, state.Ta_scaleInt, state.Ar_scaleRegs, true, x, h + hh); + BS = NullRegister().setType(Type::ngen_e8m0()); } else { A = state.sysSumAll1s[0]; na = elementsPerGRF(hw, Ta); @@ -504,6 +538,8 @@ void Generator::outerProductSystolic(int h, int ha, int hb, int opCount, boo C = state.Bsr_layout.find(0, x % Cr_unrollN, state.Bsr_regs, &nc, &C_block); else C = state.Bs_layout.find(0, x, state.Bs_regs, &nc, &C_block); + AS = NullRegister().setType(Type::ngen_e8m0()); + BS = bdpasScaleArg(state.Br_scaleLayout, state.Tb_scaleInt, state.Br_scaleRegs, false, h + hh, x); } int nv = globalCM ? na : nb; @@ -540,6 +576,7 @@ void Generator::outerProductSystolic(int h, int ha, int hb, int opCount, boo else { if (strategy.dpasw && y < ny && rcount > 0 && rcount != dpaswTile) stub(); if (A0.isValid()) issueDPAS(false); + AS0 = AS; BS0 = BS; A0 = A; B0 = B; C0 = C; rcount = 1; A0.setType(Ta.ngen()); B0.setType(Tb.ngen()); diff --git a/src/gpu/intel/gemm/jit/generator/pieces/monolithic_k_loop_dpasw.cxx b/src/gpu/intel/gemm/jit/generator/pieces/monolithic_k_loop_dpasw.cxx index db3c6d518a3..33ad3b7ebe4 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/monolithic_k_loop_dpasw.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/monolithic_k_loop_dpasw.cxx @@ -66,7 +66,7 @@ template bool Generator::sysgemmAccumulateC(GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state) { using namespace sysgemm; - auto params = systolicParams(hw, problem, strategy); + auto params = systolicParams(hw, problem); auto unrollM = strategy.unroll[LoopM]; auto unrollN = strategy.unroll[LoopN]; auto wgM = strategy.wg[LoopM]; @@ -1410,7 +1410,7 @@ template bool Generator::sysgemm2AccumulateC(GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state) { using namespace sysgemm2; - auto params = systolicParams(hw, problem, strategy); + auto params = systolicParams(hw, problem); auto unrollM = strategy.unroll[LoopM]; auto unrollN = 
strategy.unroll[LoopN]; auto localIDM = state.lidM; diff --git a/src/gpu/intel/gemm/jit/generator/pieces/quantization.cxx b/src/gpu/intel/gemm/jit/generator/pieces/quantization.cxx index 00c432653be..c89cd3ef1f7 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/quantization.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/quantization.cxx @@ -80,8 +80,9 @@ bool Generator::gemmMake2DQuantizationLayouts(bool isA, const GEMMProblem &p int cpoDiv = 1; if (Txo_int.isInt8()) Txo_int = Type::s16, cpoDiv = 2; - - if (xs2D && (Txs.paddedSize() > Tx.paddedSize()) && Tx.isInteger()) { + // Use lateScale for cases of applying scale to inputs that will be natively dpas'd + // but do not support add/mul. + if (xs2D && ((Txs.paddedSize() > Tx.paddedSize() && Tx.isInteger()) || problem.forceLateQuant(hw, minOuterProductCount(hw, problem, strategy)) || state.useBDPAS)) { lateScale = true; Txs_int = problem.Tc; } @@ -96,6 +97,8 @@ bool Generator::gemmMake2DQuantizationLayouts(bool isA, const GEMMProblem &p if (lateOffset && (Txo.isInt4() || Txo.isInt8())) Txo_int = Type::s32; + if (Txs == Type::f8_e8m0 && state.useBDPAS) + Txs_int = Type::f8_e8m0; // Get tile sizes, depending on whether A/B are copied to SLM. // For late scaling (after compute), scales are always applied to the whole tile. 
int r, c, k, rNoSLM, cNoSLM; @@ -191,6 +194,16 @@ bool Generator::gemmMake2DQuantizationLayouts(bool isA, const GEMMProblem &p int m, int n, int cp, bool forceRepack) mutable { if (cp > 1 || (cColMajor && (cp != src[0].crosspack)) || Txq != Txq_int || forceRepack) { bool allowPartialRegs = false; + // Native MXFP DPAS support + if (state.useBDPAS) { + allowPartialRegs = true; + cp = 1; + if (isA) { + tileR = problem.aqGroupK; + } else { + tileC = problem.bqGroupK; + } + } repack = RegisterLayout(hw, Txq_int, m, n, wantCM, cp, tileR, tileC, allowPartialRegs); } }; @@ -223,6 +236,20 @@ void Generator::gemmRepack2DQuantizationData(Type Ts, Type Td, const Registe for (int doffC = 0; doffC < layoutDst.cols(); doffC += layoutSrc.cols()) copyRegisters(Ts, Td, layoutSrc, layoutDst, src, dst, doffR, doffC, false, strategy, state); + int r = layoutDst.rows(); + int c = layoutDst.cols(); + // FP4 bdpas with group > 32 requires duplicating scales into upper half of registers. + if(state.useBDPAS && r * c < minOuterProductCount(hw, problem, strategy)){ + int halfRegElems = elementsPerGRF(hw, Td) / 2; + if(r * c > halfRegElems) stub(); + RegisterLayout offsetLayout(layoutDst); + for( auto &b : offsetLayout){ + if(b.nr * b.nc > halfRegElems) stub(); + b.offsetBytes += (b.nr * b.nc); + } + copyRegisters(Td, Td, layoutDst, offsetLayout, dst, dst, 0, 0, false, strategy, state); + } + // Duplicate data in padded region. TODO: do this as part of the copy. int cp = layoutDst[0].crosspack; int p0 = layoutDst[0].colMajor ? 
layoutDst[0].nc : layoutDst[0].nr; @@ -305,7 +332,7 @@ template void Generator::gemmDequantizeOperation(bool doA, Type T, Type Tq, BinaryOp op, const RegisterLayout &layout, const RegisterLayout &qlayout, const GRFMultirange ®s, const GRFMultirange &qregs, - int h, int kab_load, int kq_load, const GEMMProblem &problem, CommonState &state) + int h, int kab_load, int kq_load, const GEMMProblem &problem, const CommonStrategy &strategy, CommonState &state) { int xqGroupK = doA ? problem.aqGroupK : problem.bqGroupK; int xqGroupMN = doA ? problem.aqGroupM : problem.bqGroupN; @@ -398,7 +425,7 @@ void Generator::gemmDequantizeOperation(bool doA, Type T, Type Tq, BinaryOp add(simd, data(strided), data(strided), -qdata(strideq)); break; case BinaryOp::Mul: - mul(simd, data(strided), data(strided), qdata(strideq)); + emul(simd, data(strided), data(strided), qdata(strideq), strategy, state); break; case BinaryOp::ScaleSub: if (T != Type::f16) stub(); @@ -467,7 +494,7 @@ void Generator::dequantizeInt4(bool doA, const RegisterLayout &layoutSrc, co // so two multiplications are needed. if (!layoutOffset.empty()) { if (!problem) stub(); - gemmDequantizeOperation(doA, Type::f16, Type::f16, BinaryOp::ScaleSub, *effLayoutDst, layoutOffset, *effDst, offset, h, kab_load, kq_load, *problem, state); + gemmDequantizeOperation(doA, Type::f16, Type::f16, BinaryOp::ScaleSub, *effLayoutDst, layoutOffset, *effDst, offset, h, kab_load, kq_load, *problem, strategy, state); } else { map(hw, Type::f16, *effDst, *effLayoutDst, strategy, [&](int esize, RegData r) { s4 ? mad(esize, r, Immediate::hf(0x9800), r, Immediate::hf(0x6C00)) /* 0x9800 = -8*2^(-12), 0x6C00 = 2^12 */ @@ -484,7 +511,7 @@ void Generator::dequantizeInt4(bool doA, const RegisterLayout &layoutSrc, co // this could be merged into the previous multiplication. 
if (!f32 && !layoutScale.empty()) { if (!problem) stub(); - gemmDequantizeOperation(doA, Type::f16, Type::f16, BinaryOp::Mul, *effLayoutDst, layoutScale, *effDst, scale, h, kab_load, kq_load, *problem, state); + gemmDequantizeOperation(doA, Type::f16, Type::f16, BinaryOp::Mul, *effLayoutDst, layoutScale, *effDst, scale, h, kab_load, kq_load, *problem, strategy, state); } // 6) Convert to dst type if needed. @@ -496,7 +523,7 @@ void Generator::dequantizeInt4(bool doA, const RegisterLayout &layoutSrc, co // 7) Apply scales for f32 after f16->f32 upconversion. if (f32 && !layoutScale.empty()) { if (!problem) stub(); - gemmDequantizeOperation(doA, Type::f32, Type::f32, BinaryOp::Mul, layoutDst, layoutScale, dst, scale, h, kab_load, kq_load, *problem, state); + gemmDequantizeOperation(doA, Type::f32, Type::f32, BinaryOp::Mul, layoutDst, layoutScale, dst, scale, h, kab_load, kq_load, *problem, strategy, state); } } @@ -563,12 +590,18 @@ void Generator::gemmDequantizeAB(bool doA, const RegisterLayout &layoutSrc, convert(src, Tsrc, Tx_int, strategy, state); if (xo2D) { - gemmDequantizeOperation(doA, Tx_int, Txo_int, BinaryOp::Sub, layoutDst, oLayout, dst, oRegs, h, kab_load, kq_load, problem, state); + if (!state.useBDPAS) + { + gemmDequantizeOperation(doA, Tx_int, Txo_int, BinaryOp::Sub, layoutDst, oLayout, dst, oRegs, h, kab_load, kq_load, problem, strategy, state); convert(dst, Tx_int, Tdst, strategy, state); + } } if (xs2D) - gemmDequantizeOperation(doA, Tdst, Txs_int, BinaryOp::Mul, layoutDst, sLayout, dst, sRegs, h, kab_load, kq_load, problem, state); + if (!state.useBDPAS) + { + gemmDequantizeOperation(doA, Tdst, Txs_int, BinaryOp::Mul, layoutDst, sLayout, dst, sRegs, h, kab_load, kq_load, problem, strategy, state); + } } if (ms < md || ns < nd) { diff --git a/src/gpu/intel/gemm/jit/generator/pieces/register_allocation.cxx b/src/gpu/intel/gemm/jit/generator/pieces/register_allocation.cxx index f0e5f55fefa..4a90b4b2801 100644 --- 
a/src/gpu/intel/gemm/jit/generator/pieces/register_allocation.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/register_allocation.cxx @@ -205,7 +205,7 @@ void Generator::gemmAllocRegs(GEMMProblem &problem, GEMMStrategy &strategy, C_chunk = alignup_pow2(C_chunk, Bundle(0, 0).group_size(raHW) * 2); if (strategy.systolic) { - auto params = systolicParams(hw, problem, strategy); + auto params = systolicParams(hw, problem); C_chunk = std::max(C_chunk, (params.osys * params.rcountMax * Tc.real()) / GRF::bytes(hw)); Vr_chunk = std::max(Vr_chunk, (params.osys * params.ksys * Tv.real()) / GRF::bytes(hw)); Nr_chunk = std::max(Nr_chunk, (params.rcountMax * params.ksys * Tn.real()) / GRF::bytes(hw)); diff --git a/src/gpu/intel/gemm/jit/generator/pieces/state.hpp b/src/gpu/intel/gemm/jit/generator/pieces/state.hpp index e6331d5a396..a6731d4dfde 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/state.hpp +++ b/src/gpu/intel/gemm/jit/generator/pieces/state.hpp @@ -339,6 +339,7 @@ struct GEMMState : public CommonState { ngen::GRF tmpCScales; // Internal storage of dynamic scales ngen::Subregister statusFlagAddr; // uq bool systolicSumA = false, systolicSumB = false; + bool useBDPAS = false; bool lateKLoopCheck = false; bool splitBarrierAlways = false; int ka_loadRem, kb_loadRem; diff --git a/src/gpu/intel/gemm/jit/generator/pieces/state_utils.cxx b/src/gpu/intel/gemm/jit/generator/pieces/state_utils.cxx index 40d61ac1b08..806bfa315ce 100644 --- a/src/gpu/intel/gemm/jit/generator/pieces/state_utils.cxx +++ b/src/gpu/intel/gemm/jit/generator/pieces/state_utils.cxx @@ -38,9 +38,9 @@ void Generator::saveMNLocalIDs(const GEMMStrategy &strategy, GEMMState &stat template void Generator::saveKLocalIDSize(const GEMMStrategy &strategy, GEMMState &state) { - state.lidszKStorage = state.ra.alloc_sub(getHint(HintType::LongTerm, strategy)); + state.lidszKStorage = state.ra.alloc_sub(getHint(HintType::LongTerm, strategy)); state.lidK = state.lidszKStorage.uw(0); - state.lszK = 
state.lidszKStorage.uw(1); + state.lszK = state.lidszKStorage.ud(1); mov(1, state.lidK, state.inputs.localIDK); mov(1, state.lszK, state.inputs.localSizeK); } diff --git a/src/gpu/intel/gemm/jit/generator/strategy.cpp b/src/gpu/intel/gemm/jit/generator/strategy.cpp index b03d7e61972..c12959253f0 100644 --- a/src/gpu/intel/gemm/jit/generator/strategy.cpp +++ b/src/gpu/intel/gemm/jit/generator/strategy.cpp @@ -43,6 +43,7 @@ void CommonStrategy::preflight(HW hw, const CommonProblem &problem) bool emulateNeedsAcc = emulate.emulate64 || emulate.emulateDWxDW || emulate.emulate64_mul; if (moveR0 == MoveR0::Acc && emulateNeedsAcc) moveR0 = MoveR0::None; + if (hw >= HW::XE3P_35_10) moveR0 = MoveR0::None; spf &= !fused; } @@ -234,6 +235,10 @@ void GEMMStrategy::preflight(HW hw, const GEMMProblem &problem) // 64-bit emulation > r0 header storage. if (AccumulatorRegister::count(hw, GRFs, problem.Tc.real().ngen()) == 0) kChain = 1; + // Using acc and mad not working on xe3p + bool is_xe3p = one_of(hw, {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, ngen::HW::XE3P_UNKNOWN}); + if (!systolic && !dotVL && is_xe3p) + kChain = 1; cAccumulators &= (kChain == 1); bool emulateNeedsAcc = emulate.emulate64 || emulate.emulateDWxDW; @@ -284,7 +289,7 @@ void GEMMStrategy::preflight(HW hw, const GEMMProblem &problem) // Systolic handling. 
if (systolic) { - auto params = systolicParams(hw, problem, *this); + auto params = systolicParams(hw, problem); ukAlign = lcm(ukAlign, params.ksys); auto tileX = params.osys; @@ -324,7 +329,7 @@ void GEMMStrategy::preflight(HW hw, const GEMMProblem &problem) } if (dpasw) { - auto params = systolicParams(hw, problem, *this); + auto params = systolicParams(hw, problem); if (globalCM) { if (!fusedM()) stub(); B.dpasw = true; diff --git a/src/gpu/intel/gemm/jit/generator/strategy_parser.cpp b/src/gpu/intel/gemm/jit/generator/strategy_parser.cpp index 2e0e71e9425..8bf4203fc21 100644 --- a/src/gpu/intel/gemm/jit/generator/strategy_parser.cpp +++ b/src/gpu/intel/gemm/jit/generator/strategy_parser.cpp @@ -93,8 +93,38 @@ CacheSettingsLSC getCaching(char l1, char l3) } } +CacheSettingsLSC getCaching(char l1, char l2, char l3) { + if (l3 == 'u' || l3 == 'i') return getCaching(l1, l2); + + if (l3 == 'c' || l3 == 'b') { + bool l2cached = (l2 == 'c') || (l2 == 'b'); + switch (l1) { + case 'u': + return l2cached ? CacheSettingsLSC::L1UC_L2C_L3C + : CacheSettingsLSC::L1UC_L2UC_L3C; + case 't': + case 'c': + return l2cached ? CacheSettingsLSC::L1C_L2C_L3C + : CacheSettingsLSC::L1C_L2UC_L3C; + case 's': + return l2cached ? 
CacheSettingsLSC::L1S_L2C_L3C + : CacheSettingsLSC::L1S_L2UC_L3C; + case 'b': + if (!l2cached) return CacheSettingsLSC::L1WB_L2UC_L3WB; + default: break; + } + } + + throw std::runtime_error("Unknown cache setting"); +} + CacheSettingsLSC getCachingEntry(std::stringstream &s, HW hw) { + if (hw >= HW::XE3P_35_10) { + char l1, l2, l3; + s >> l1 >> l2 >> l3; + return getCaching(l1, l2, l3); + } else { char l1, l3; s >> l1 >> l3; @@ -112,6 +142,8 @@ void getCaching(std::stringstream &s, HW hw, MatrixAddressingStrategy &astrategy cachingW = CacheSettingsLSC::L1WB_L3WB; if (hw >= HW::XeHPC) cachingW = CacheSettingsLSC::L1UC_L3WB; + if (hw >= HW::XE3P_35_10) + cachingR = CacheSettingsLSC::L1C_L2C_L3C; } if (s.peek() == '{') { @@ -233,6 +265,7 @@ void parseStrategy(const char *str, HW hw, const GEMMProblem &problem, GEMMStrat strategy.A.cachingW = CacheSettingsLSC::Default; strategy.B.cachingW = CacheSettingsLSC::Default; strategy.CO.cachingR = CacheSettingsLSC::L1C_L3C; + if (hw >= HW::XE3P_35_10) strategy.CO.cachingR = CacheSettingsLSC::L1C_L2C_L3C; strategy.A_prefetch.prefetch = true; strategy.B_prefetch.prefetch = true; strategy.C_prefetch.prefetch = true; @@ -253,6 +286,8 @@ void parseStrategy(const char *str, HW hw, const GEMMProblem &problem, GEMMStrat strategy.AB_prefetchL3.base = getAddressBase(strategy.l3PrefetchA ? 
asA : asB); if (strategy.AB_prefetchL3.cachingR == CacheSettingsLSC::Default) { strategy.AB_prefetchL3.cachingR = CacheSettingsLSC::L1UC_L3C; + if (hw >= HW::XE3P_35_10) + strategy.AB_prefetchL3.cachingR = CacheSettingsLSC::L1UC_L2C_L3C; } strategy.A.padded |= isPacked(problem.A.layout); diff --git a/src/gpu/intel/gemm/jit/include/gemmstone/dsl/hw.hpp b/src/gpu/intel/gemm/jit/include/gemmstone/dsl/hw.hpp index 71d6c88f9e1..39d5e046d7a 100644 --- a/src/gpu/intel/gemm/jit/include/gemmstone/dsl/hw.hpp +++ b/src/gpu/intel/gemm/jit/include/gemmstone/dsl/hw.hpp @@ -68,8 +68,8 @@ class hw_t : public stringify_t { public: using attr_t = hw::attr_t; hw_t() = default; - explicit hw_t(const ngen::Product &product, int eu_count, - size_t max_wg_size, size_t l3_cache_size, attr_t attr); + explicit hw_t(const ngen::Product &product, int eu_count, int max_wg_size, + size_t l3_cache_size, bool efficient_64_bit, attr_t attr); ngen::Product product() const; ngen::ProductFamily family() const; @@ -85,6 +85,7 @@ class hw_t : public stringify_t { int grf_size() const; int systolic_support() const { return any(attr_ & attr_t::systolic); } size_t l3_cache_size() const { return l3_cache_size_; } + bool efficient_64_bit() const { return efficient_64_bit_; } size_t max_tg_size(int regs, int simd) const; @@ -129,6 +130,7 @@ class hw_t : public stringify_t { int eu_count_ = 0; size_t max_wg_size_ = 0; size_t l3_cache_size_ = 0; + bool efficient_64_bit_ = false; attr_t attr_ = attr_t::none; }; diff --git a/src/gpu/intel/gemm/jit/include/gemmstone/generator.hpp b/src/gpu/intel/gemm/jit/include/gemmstone/generator.hpp index 72da9f4a2d5..1e5ec6476a1 100644 --- a/src/gpu/intel/gemm/jit/include/gemmstone/generator.hpp +++ b/src/gpu/intel/gemm/jit/include/gemmstone/generator.hpp @@ -274,7 +274,7 @@ template class Generator : public GENERATOR_BASE(hw) { template void emov(const ngen::InstructionModifier &mod, ngen::RegData dst, ngen::Immediate src0, const CommonStrategy &strategy, CommonState &state, 
ngen::SourceLocation loc = {}) { ngen::EmulationImplementation::emov
(*this, mod, dst, src0, strategy.emulate, loc); } template void eadd(const ngen::InstructionModifier &mod, const ngen::RegData &dst, const ngen::RegData &src0, const ngen::RegData &src1, const CommonStrategy &strategy, CommonState &state, ngen::SourceLocation loc = {}); template void eadd(const ngen::InstructionModifier &mod, const ngen::RegData &dst, const ngen::RegData &src0, ngen::Immediate src1, const CommonStrategy &strategy, const CommonState &state, ngen::SourceLocation loc = {}) { ngen::EmulationImplementation::eadd
(*this, mod, dst, src0, src1, strategy.emulate, state.emulate, loc); } - template void emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst, const ngen::RegData &src0, const ngen::RegData &src1, const CommonStrategy &strategy, const CommonState &state, ngen::SourceLocation loc = {}) { ngen::EmulationImplementation::emul
(*this, mod, dst, src0, src1, strategy.emulate, state.emulate, loc); } + template void emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst, const ngen::RegData &src0, const ngen::RegData &src1, const CommonStrategy &strategy, CommonState &state, ngen::SourceLocation loc = {}); template void emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst, const ngen::RegData &src0, ngen::Immediate src1, const CommonStrategy &strategy, const CommonState &state, ngen::SourceLocation loc = {}) { ngen::EmulationImplementation::emul
(*this, mod, dst, src0, src1, strategy.emulate, state.emulate, loc); } template void eshl(const ngen::InstructionModifier &mod, ngen::RegData dst, ngen::RegData src0, uint16_t src1, const CommonStrategy &strategy, const CommonState &state, ngen::SourceLocation loc = {}) { ngen::EmulationImplementation::eshl
(*this, mod, dst, src0, src1, strategy.emulate, state.emulate, loc); } template void eshr(const ngen::InstructionModifier &mod, ngen::RegData dst, ngen::RegData src0, uint16_t src1, const CommonStrategy &strategy, const CommonState &state, ngen::SourceLocation loc = {}) { ngen::EmulationImplementation::eshr
(*this, mod, dst, src0, src1, strategy.emulate, state.emulate, loc); } @@ -338,6 +338,7 @@ template class Generator : public GENERATOR_BASE(hw) { void gemmReverseLoops(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state); void gemmScaleInputs(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state); void gemmCalcWGRemainders(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state); + void gemmCalcWGIndices(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state); void gemmGetBatchIDs(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state); void gemmReleaseBatchIDs(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state); @@ -516,9 +517,8 @@ template class Generator : public GENERATOR_BASE(hw) { void gemmRepack2DOffsetData(Type Text, const RegisterLayout &layoutSrc, const RegisterLayout &layoutDst, const GRFMultirange &src, const GRFMultirange &dst, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state); void dequantizeInt4Shift(Type Tsrc, GRFMultirange src, const CommonStrategy &strategy); void dequantizeInt4(bool doA, const RegisterLayout &layoutSrc, const RegisterLayout &layoutDst, const RegisterLayout &layoutOffset, const RegisterLayout &layoutScale, const GRFMultirange &src, const GRFMultirange &dst, const GRFMultirange &offset, const GRFMultirange &scale, int offR, int offC, int h, int kab_load, int kq_load, const GEMMProblem *problem, const CommonStrategy &strategy, CommonState &state, bool s4Shift = true); - void gemmDequantizeOperation(bool doA, Type T, Type Tq, BinaryOp op, const RegisterLayout &layout, const RegisterLayout &qlayout, const GRFMultirange ®s, const GRFMultirange &qregs, int h, int kab_load, int kq_load, const GEMMProblem &problem, CommonState &state); + void gemmDequantizeOperation(bool doA, Type T, Type Tq, BinaryOp op, const RegisterLayout &layout, const RegisterLayout &qlayout, const GRFMultirange ®s, 
const GRFMultirange &qregs, int h, int kab_load, int kq_load, const GEMMProblem &problem, const CommonStrategy &strategy, CommonState &state); void gemmDequantizeAB(bool doA, const RegisterLayout &layoutSrc, const RegisterLayout &layoutDst, const GRFMultirange &src, const GRFMultirange &dst, int h, int kab_load, int kab_repack, int kq_load, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, bool s4Shift = true); - // register_allocation.cxx ngen::Bundle getHint(HintType type); ngen::Bundle getHint(HintType type, const CommonStrategy &strategy); diff --git a/src/gpu/intel/gemm/jit/include/gemmstone/kernel_catalog.hpp b/src/gpu/intel/gemm/jit/include/gemmstone/kernel_catalog.hpp index 6dc19b5b41b..31445badc56 100644 --- a/src/gpu/intel/gemm/jit/include/gemmstone/kernel_catalog.hpp +++ b/src/gpu/intel/gemm/jit/include/gemmstone/kernel_catalog.hpp @@ -77,6 +77,7 @@ enum HWTags : char { HWTagXeHPC = 'F', HWTagXe2 = 'G', HWTagXe3 = 'H', + HWTagXe3p = 'I', }; struct Selector { diff --git a/src/gpu/intel/gemm/jit/include/gemmstone/kernel_selector.hpp b/src/gpu/intel/gemm/jit/include/gemmstone/kernel_selector.hpp index 79d73bba817..cc062a471d4 100644 --- a/src/gpu/intel/gemm/jit/include/gemmstone/kernel_selector.hpp +++ b/src/gpu/intel/gemm/jit/include/gemmstone/kernel_selector.hpp @@ -48,6 +48,8 @@ struct MatchParamsBase char precisionCExt = 0; bool ignoreSizes = false; bool ignoreCase = false; + bool ReqBDPASDims = false; + bool reqNUnroll32 = false; int stepping = 0; int alignment[3] = {0, 0, 0}; int unrollReq[3] = {1, 1, 1}; diff --git a/src/gpu/intel/gemm/jit/include/gemmstone/problem.hpp b/src/gpu/intel/gemm/jit/include/gemmstone/problem.hpp index c6b0db2c8e6..700a13c7bb2 100644 --- a/src/gpu/intel/gemm/jit/include/gemmstone/problem.hpp +++ b/src/gpu/intel/gemm/jit/include/gemmstone/problem.hpp @@ -191,6 +191,7 @@ struct GEMMProblem : public CommonProblem { bool sumA = false, sumB = false; // If true, calculate A row sums/B column sums and 
store in CO. bool forceGroupSumsA = false; bool forceGroupSumsB = false; + bool bdpasEnabled = false; // bdpas enabled for problem. bool cMXScale = false; MatrixAddressing sroundSeed; PostOpsProblem postOps; // Fused post operations to apply @@ -261,12 +262,33 @@ struct GEMMProblem : public CommonProblem { bool quantized2DA() const { return forceGroupSumsB || aOffset2D() || aScale2D(); } bool quantized2DB() const { return forceGroupSumsA || bOffset2D() || bScale2D(); } - bool earlyDequantizeA() const { return (aOffset == ABOffset::Calc && earlyDequantizableOffset(Ta_ext, Tao, Ta)) || (aScale2D() && (Ta_scale.isSubsetOf(Ta) || Ta.isFP())); } - bool earlyDequantizeB() const { return (bOffset == ABOffset::Calc && earlyDequantizableOffset(Tb_ext, Tbo, Tb)) || (bScale2D() && (Tb_scale.isSubsetOf(Tb) || Tb.isFP())); } + bool earlyDequantizeA() const { return (aOffset == ABOffset::Calc && earlyDequantizableOffset(Ta_ext, Tao, Ta)) || (aScale2D() && (Ta_scale.isSubsetOf(Ta) || (Ta.isFP() && !Ta.isF4() && !Ta.isF8()))); } + bool earlyDequantizeB() const { return (bOffset == ABOffset::Calc && earlyDequantizableOffset(Tb_ext, Tbo, Tb)) || (bScale2D() && (Tb_scale.isSubsetOf(Tb) || (Tb.isFP() && !Tb.isF4() && !Tb.isF8()))); } bool needsASums() const { return sumA || (bOffset == ABOffset::Calc && !earlyDequantizeB() && !quantized2DB()); } bool needsBSums() const { return sumB || (aOffset == ABOffset::Calc && !earlyDequantizeA() && !quantized2DA()); } + + bool nativeBDPAS(ngen::HW hw) const { + return (((Ta == Tb) && (Ta.isF8() || Ta == Type::f16 || Ta == Type::bf16) && hw >= ngen::Core::XE3P_35_10) || (Ta.isF4() && Tb.isF4() && hw >= ngen::Core::XE3P_35_11)); + } + bool forceLateQuant(ngen::HW hw, int minOPCount) const { + bool fp4_fp8_dpas = ((Ta.isF8() && Tb.isF8()) || (Ta.isF4() && Tb.isF4())) && nativeBDPAS(hw); + return fp4_fp8_dpas && ((aScale2D() && !preferBDPAS(hw) && aqGroupK % minOPCount == 0) + || (bScale2D() && !preferBDPAS(hw) && bqGroupK % minOPCount == 0)); + } + 
bool forceUpconvertQuant(ngen::HW hw) const { + // Cover cases where scale group < ksys by upconverting, using normal dpas and scale routines. + return nativeBDPAS(hw) && Ta.isF4() && Tb.isF4() && ((aScale2D() && !preferBDPAS(hw) && aqGroupK % 64 != 0) + || (bScale2D() && !preferBDPAS(hw) && bqGroupK % 64 != 0)); + } + bool preferBDPAS(ngen::HW hw) const { + bool useBDPAS = (bdpasEnabled && nativeBDPAS(hw) && (aScale2D() || bScale2D())); + if (aScale2D()) useBDPAS &= (Ta_scale == Type::f8_e8m0) && (aqGroupK % 32 == 0); + if (bScale2D()) useBDPAS &= (Tb_scale == Type::f8_e8m0) && (bqGroupK % 32 == 0); + return useBDPAS; + } + bool needsAGroupSums() const { return (bOffset == ABOffset::Calc && quantized2DB() && !earlyDequantizableOffset(Tb_ext, Tbo, Tb)); } bool needsBGroupSums() const { return (aOffset == ABOffset::Calc && quantized2DA() && !earlyDequantizableOffset(Ta_ext, Tao, Ta)); } @@ -308,6 +330,7 @@ struct GEMMProblem : public CommonProblem { s.append(AO, BO, CO); s.append(A_scale, B_scale, C_scale); s.append(checkBeta0); + s.append(bdpasEnabled); s.append(aOffset, bOffset); s.append(aoPtrDims, boPtrDims, coPtrDims); s.append(asPtrDims, bsPtrDims, csPtrDims); @@ -340,11 +363,16 @@ void GEMMProblem::autoTypeConversions(ngen::HW hw, bool systolicAvailable) if (Ta == Ta_ext.asSigned()) Ta = Ta_ext; if (Tb == Tb_ext.asSigned()) Tb = Tb_ext; - + if (hw < HW::XE3P_35_10 || !systolicAvailable) + { if (Ta.isF8()) Ta = Type::f16; if (Tb.isF8()) Tb = Type::f16; + } + if (hw < HW::XE3P_35_11 || !systolicAvailable || forceUpconvertQuant(hw)) + { if (Ta.isF4()) Ta = Type::f16; if (Tb.isF4()) Tb = Type::f16; + } if (!systolicAvailable && Tc == Type::f32) { if (Ta == Type::f16) Ta = Type::f32; diff --git a/src/gpu/intel/gemm/jit/include/gemmstone/type.hpp b/src/gpu/intel/gemm/jit/include/gemmstone/type.hpp index 16527b8b600..8ab1fd069a1 100644 --- a/src/gpu/intel/gemm/jit/include/gemmstone/type.hpp +++ b/src/gpu/intel/gemm/jit/include/gemmstone/type.hpp @@ -106,9 +106,8 @@ 
class Type { static constexpr ngen::DataType ngen_nf4() { return static_cast(0x58); } static constexpr ngen::DataType ngen_e8m0() { return static_cast(0x79); } - // Not a valid nGEN DataType; for gemmstone internal use only - static constexpr ngen::DataType ngen_e2m1() { return static_cast(0x5A);} - static constexpr ngen::DataType ngen_e3m0() { return static_cast(0x5B);} + static constexpr ngen::DataType ngen_e2m1() { return ngen::DataType::e2m1; } + static constexpr ngen::DataType ngen_e3m0() { return ngen::DataType::e3m0; } ngen::DataType ngen() const diff --git a/src/gpu/intel/gemm/jit/include/internal/generator_inline.hxx b/src/gpu/intel/gemm/jit/include/internal/generator_inline.hxx index 8db10cf564f..416f16084b6 100644 --- a/src/gpu/intel/gemm/jit/include/internal/generator_inline.hxx +++ b/src/gpu/intel/gemm/jit/include/internal/generator_inline.hxx @@ -19,6 +19,7 @@ static inline int r0DWords(ngen::HW hw) { + if (hw >= ngen::HW::XE3P_35_10) return 16; return 8; } diff --git a/src/gpu/intel/gemm/jit/pd.cpp b/src/gpu/intel/gemm/jit/pd.cpp index 4a2d63b24e0..357fa08d48b 100644 --- a/src/gpu/intel/gemm/jit/pd.cpp +++ b/src/gpu/intel/gemm/jit/pd.cpp @@ -156,6 +156,7 @@ status_t pd_t::init_post_ops() { bool converted; CHECK(maybe_convert_scales_to_postop(a_scale_md_, DNNL_ARG_A, a_scales.get_data_type(), a_scales.is_mx(), converted)); + if (converted) a_quant.scale_ndims = -1; } @@ -418,12 +419,26 @@ bool pd_t::scales_ok() { return false; if (!x_scales.has_default_groups()) { + const memory_desc_t *md = nullptr; + switch (s) { + // Swap descriptors to follow column major format + case DNNL_ARG_A: md = &desc()->b_desc; break; + case DNNL_ARG_B: md = &desc()->a_desc; break; + case DNNL_ARG_C: md = &desc()->c_desc; break; + } + if (!md) gpu_error_not_expected(); + int count = 0; + for (int i = 0; i < 2; i++) { + int gs = x_scales.get_group(i); + int dim = md->dims[md->ndims - 2 + i]; + if (1 < gs && gs < dim) count++; + } + if (count > 1) return false; + // Dynamic 
Dst Quant only supported with `1x32` groups. if (s == DNNL_ARG_C && with_mx_scale() && x_scales.get_group(0) != 1 && x_scales.get_group(1) != 32) return false; - // Other dynamic quant unsupported - if (x_scales.is_dynamic()) return false; } } @@ -722,6 +737,13 @@ status_t pd_t::init_GEMMProblem( if (problem.aqGroupK == 0) problem.aqGroupK = problem.bqGroupK; if (problem.bqGroupK == 0) problem.bqGroupK = problem.aqGroupK; } + // Disable bdpas with unsupported k dim. + // TODO: Enable 2D block, masking scale loads. + if (problem.nativeBDPAS(hw)) { + if ((!(problem.Ta.isF4() || problem.Tb.isF4()) || k % 64 == 0)) + problem.bdpasEnabled = true; + } + return status::success; } diff --git a/src/gpu/intel/gemm/jit/pd.hpp b/src/gpu/intel/gemm/jit/pd.hpp index 192637ec8a7..df5473fa56c 100644 --- a/src/gpu/intel/gemm/jit/pd.hpp +++ b/src/gpu/intel/gemm/jit/pd.hpp @@ -266,16 +266,19 @@ struct pd_t : public gemm::pd_t { auto attr_info = attr_info_t::create(attr()); return attr_info.with_host_src_zp; } + bool c_zp_host_scalar() const { auto attr_info = attr_info_t::create(attr()); return attr_info.with_host_dst_zp; } + int a_q2d_group_k() const { return a_quant.group_k; } int a_q2d_group_m() const { return a_quant.group_m; } int b_q2d_group_k() const { return b_quant.group_k; } int b_q2d_group_n() const { return b_quant.group_n; } int c_q2d_group_m() const { return c_quant.group_m; } int c_q2d_group_n() const { return c_quant.group_n; } + int align(int arg) const { auto dt = get_type(arg); auto align = utils::max_pow2_div(types::elements_to_bytes(dt, ld(arg))); diff --git a/src/gpu/intel/gemm/jit/selector/db/kernel.db b/src/gpu/intel/gemm/jit/selector/db/kernel.db index 33a06504921..cfb97479aac 100644 --- a/src/gpu/intel/gemm/jit/selector/db/kernel.db +++ b/src/gpu/intel/gemm/jit/selector/db/kernel.db @@ -1236,10 +1236,67 @@ auto _CATALOG_ = kcatalog::toArray({ {{'G', "gemm", {"Q", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, 
{1, 1, 1}, "I"}, "aS16x2+S32@48 aB16+S32@24 aB wg 1x16 sys af fn nmk vav hi sr br sb1 dm grf256", {16, (LoopType) 255, 256, {(LoopType) 145, (LoopType) 255, (LoopType) 255}, {262144, 131072, 16777216}, {262144, 131072, 16777216}, {16, 8, 16}, {1, 16, 1}, 1, (WGType) 1, 257, 0, 0, {1, 2, 4}, {true, true, true}}, {'E', 17, {208528, 88784.2, 0, 0, 0, 0, 8.7426, 1.85755, 10.5788, 23.598, 0.0796937, 0.0905916, 0.0100367, 1, 1, 0.997218, 2.8365e-13}}}, {{'G', "gemm", {"Q", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "I"}, "aS16x2+S64@136 aB64+S32@120 aB wg 1x4x4 kr ca4x2 ks32 ql nb 0x4 sys af fn nmk vav hi sr br dm grf256", {16, (LoopType) 255, 256, {(LoopType) 145, (LoopType) 255, (LoopType) 2}, {262144, 65536, 16777216}, {262144, 65536, 16777216}, {16, 4, 64}, {1, 4, 4}, 1, (WGType) 1, 261, 4096, 4096, {1, 2, 4}, {true, true, true}}, {'E', 17, {354836, 21143.7, -2072.14, 491.338, 0, 0, 1.22446, 2.00589, 8.31791, 21.2217, 0.0792542, 0.0801645, 0.0067101, 1, 1.00543, 0.977751, 1.29564e-13}}}, {{'G', "gemm", {"Q", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "I"}, "aS32+S32@56 aB64+S16@40 aB wg 4x1x8 kr sys af vav hi sr br sn dm grf256", {16, (LoopType) 255, 256, {(LoopType) 144, (LoopType) 255, (LoopType) 2}, {262144, 131072, 16777216}, {262144, 131072, 16777216}, {16, 8, 64}, {4, 1, 8}, 1, (WGType) 1, 261, 0, 2048, {1, 2, 4}, {true, true, true}}, {'E', 17, {272848, 16581.6, 663.387, 116.699, 0, 0, 0.628749, 2.33321, 2.54671, 7.05347, 0.0908753, 0.0908753, 0, 1, 1.00016, 0.998154, 1.47604e-14}}}, +{{'G', "gemm", {"Q", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "I"}, "aS16+S64@40 aB16+S1,64@8 aB wg 8x4 cab4x2 ks16 nb 8x4 sys xaf rr fx vav hi sr br sn dm grf256", {16, (LoopType) 255, 256, {(LoopType) 144, (LoopType) 255, (LoopType) 255}, {524288, 1048576, 16777216}, {524288, 1048576, 
16777216}, {32, 64, 16}, {8, 4, 1}, 1, (WGType) 1, 257, 65536, 0, {1, 2, 4}, {true, true, true}}, {'E', 17, {206051, 496759, 0, 0, 0, 0, 1.65633, 1.69184, 3.52378, 8.6919, 0.007993, 0.007993, 0, 0.97378, 1.11547, 0.997756, 7.28903e-13}}}, {{'G', "gemm", {"Q", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "ABI"}, "at16x2+m64@56 am32+m32@56 aB wg 8x4 xaf fx rr vav hi pt sr br sb64 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 1048576, 16777216}, {524288, 1048576, 32}, {32, 64, 32}, {8, 4, 1}, 1, (WGType) 1, 441, 0, 0, {128, 128, 128}, {true, true, true}}, {'E', 17, {204616, 159564, 0, 0, 4.01506e+06, 3.64029e+06, 0.770293, 1.80107, 3.5616, 8.3821, 0.00782289, 0.00782289, 0, 1, 1.01417, 0.906822, 7.62417e-13}}}, {{'G', "gemm", {"Q", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "ABIs"}, "at32+m64@48 am32+m32@48 aB wg 8x4 af rr vav hi pt sr br sb64 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 917504, 16777216}, {524288, 917504, 32}, {32, 56, 32}, {8, 4, 1}, 1, (WGType) 1, 441, 0, 0, {128, 128, 128}, {true, true, true}}, {'E', 17, {198239, 147417, 0, 0, 3.82617e+06, 3.36556e+06, 0.677253, 1.79465, 3.55224, 8.41829, 0.00787345, 0.00787345, 0, 0.996828, 1, 0.994257, 4.35566e-13}}}, {{'G', "gemm", {"Q", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "ABI"}, "at16x2+m32@48 am32+m16@64 aB wg 4x2x4 kr xaf st vav hi pt sr br sb64 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 32}, {16, 16, 32}, {4, 2, 4}, 1, (WGType) 1, 445, 0, 8192, {128, 128, 128}, {true, true, true}}, {'E', 17, {237840, 25265.3, 6336.89, 1991.14, 599704, 859419, 0.635209, 1.7164, 3.62105, 7.82688, 
0.0117555, 0.0035939, 0.00799229, 0.956586, 1.1111, 0.916636, 1.52516e-12}}}, {{'G', "gemm", {"Q", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "I"}, "aS16+S64@40 aB16+S1,64@8 aB wg 8x4 cab4x2 ks16 nb 8x4 sys xaf rr fx vav hi sr br sn dm grf256", {16, (LoopType) 255, 256, {(LoopType) 144, (LoopType) 255, (LoopType) 255}, {524288, 1048576, 16777216}, {524288, 1048576, 16777216}, {32, 64, 16}, {8, 4, 1}, 1, (WGType) 1, 257, 65536, 0, {1, 2, 4}, {true, true, true}}, {'E', 17, {206051, 496759, 0, 0, 0, 0, 1.65633, 1.69184, 3.52378, 8.6919, 0.007993, 0.007993, 0, 0.97378, 1.11547, 0.997756, 7.28903e-13}}}, {{'H', "gemm", {"F", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, 25, -1}, {-1, 32, -1}, {-1, 25, -1}, {-1, 32, -1}, {16, 16, 1}, "IAB"}, "at32+m128@80 am32+m128@80 aB wg 8x1x4 ikr wx2 xaf vav hi pt sr br sb128 bk0 sm sn bm0 nmk sys", {16, (LoopType) 255, 128, {(LoopType) 209, (LoopType) 255, (LoopType) 2}, {16777216, 262144, 16777216}, {262144, 262144, 16777216}, {16, 16, 128}, {8, 1, 4}, 2, (WGType) 1, 4357, 0, 8192, {16, 16, 4}, {true, true, true}}, {'W', 1, {256}}}, -{{'H', "gemm", {"O", "O", "I"}, {"T", "N", "N"}}, {-1, -1, {3072, 3072, 720}, {5120, 5120, 1280}, {3072, 3072, 720}, {5120, 5120, 1280}, {16, 16, 1}, "ABI"}, "at32+m64@80 am64x2+m64@32 aB wg 2x8 sys af vav hi sr br sm sn dm grf256", {16, (LoopType) 255, 256, {(LoopType) 144, (LoopType) 255, (LoopType) 255}, {1048576, 524288, 16777216}, {1048576, 524288, 16777216}, {64, 32, 64}, {2, 8, 1}, 1, (WGType) 1, 268435713, 0, 0, {16, 16, 4}, {true, true, true}}, {'E', 17, {212791, 72569, 0, 0, 0, 0, 0.437659, 0.23103, 2.07126, 3.65249, 0.00447061, 0.00033606, 0.00573596, 0.841664, 1.02219, 1.02726, -1.67975e-14}}} +{{'H', "gemm", {"O", "O", "I"}, {"T", "N", "N"}}, {-1, -1, {3072, 3072, 720}, {5120, 5120, 1280}, {3072, 3072, 720}, {5120, 5120, 1280}, {16, 16, 1}, "ABI"}, "at32+m64@80 am64x2+m64@32 aB wg 2x8 sys af vav hi sr br sm sn dm grf256", {16, 
(LoopType) 255, 256, {(LoopType) 144, (LoopType) 255, (LoopType) 255}, {1048576, 524288, 16777216}, {1048576, 524288, 16777216}, {64, 32, 64}, {2, 8, 1}, 1, (WGType) 1, 268435713, 0, 0, {16, 16, 4}, {true, true, true}}, {'E', 17, {212791, 72569, 0, 0, 0, 0, 0.437659, 0.23103, 2.07126, 3.65249, 0.00447061, 0.00033606, 0.00573596, 0.841664, 1.02219, 1.02726, -1.67975e-14}}}, +{{'I', "gemm", {"B", "B", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIph"}, "av32+m16@128 am32+m32@128 aB wg 8x4 sys xaf fx st kc2 sr br grf512 sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"B", "B", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ih"}, "aB32+B16@128 aB32+B32@128 aB wg 8x4 sys xaf fx st kc2 sr br dm sn grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"B", "B", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, 8, -1}, {-1, -1, -1}, {-1, 8, -1}, {4, 4, 1}, "ABI"}, "av16+B32@32 am16+m32@32 aB wg 2x4 li nmk pt sr br sb64 bk0 sys ska", {16, (LoopType) 255, 128, {(LoopType) 225, (LoopType) 255, (LoopType) 255}, {1048576, 32768, 16777216}, {1048576, 32768, 16777216}, {64, 2, 16}, {2, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {128}}}, +{{'I', "gemm", {"B", "B", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIph"}, "av32+m16@128 at32+m16@128 aB wg 8x4 sys xaf fx st kc2 sr br grf512 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, 
(LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"B", "B", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+B16@64 aB8+B16@64 aB wg 4x8 kc8 sr br grf256 l4 bo pt", {16, (LoopType) 255, 256, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 8}, {4, 8, 1}, 1, (WGType) 1, 257, 0, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"B", "B", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIph"}, "at32+m32@128 am32+m32@128 aB wg 8x4 sys xaf fx st kc2 sr br grf512 sm sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"B", "B", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ih"}, "aS32+B32@128 aB32+B32@128 aB wg 8x4 sys xaf fx st kc2 sr br dm sm sn grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"B", "B", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, 8, -1}, {-1, -1, -1}, {-1, 8, -1}, {4, 4, 1}, "ABIp"}, "at32 am128 aB wg 2x1x2 kr sys xaf st li pt nmk sb128 bk0 bm0", {16, (LoopType) 255, 128, {(LoopType) 225, (LoopType) 255, (LoopType) 2}, {16777216, 131072, 16777216}, {262144, 131072, 16777216}, {16, 8, 128}, {2, 1, 2}, 1, (WGType) 0, 261, 0, 1024, {4, 4, 4}, {true, true, true}}, {'W', 1, {128}}}, +{{'I', "gemm", {"B", "B", "S"}, {"T", "T", 
"N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIqh"}, "am32+m32@128 av32+m16@128 aS wg 8x4 sys xaf fx st kc2 sr grf512 sm bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"B", "B", "S"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ih"}, "aB32+B32@128 aB32+B16@128 aS wg 8x4 sys xaf fx st kc2 sr dm sm grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"D", "D", "D"}, {"A", "B", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+B8@16 aB8+B8@16 aB kc8 wg 8x4 grf256 bo pt", {16, (LoopType) 255, 256, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 8}, {8, 4, 1}, 1, (WGType) 1, 256, 0, 0, {128, 128, 8}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"D", "D", "D"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "h"}, "aB8+m8@24 aS8+m16@24 aB wg 4x8 kc8 sr br grf512 sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 524288, 16777216}, {1048576, 524288, 16777216}, {64, 32, 8}, {4, 8, 1}, 1, (WGType) 1, 257, 0, 0, {8, 8, 8}, {true, true, true}}, {'W', 1, {2048}}}, +{{'I', "gemm", {"D", "D", "D"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "h"}, "aB8+m8@24 aB8+m8@24 aB wg 4x8 kc8 sr br grf512 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 524288, 16777216}, 
{1048576, 524288, 16777216}, {64, 32, 8}, {4, 8, 1}, 1, (WGType) 1, 257, 0, 0, {8, 8, 8}, {true, true, true}}, {'W', 1, {2048}}}, +{{'I', "gemm", {"D", "D", "D"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "h"}, "aS8+m16@24 aS8+m16@24 aB wg 4x8 kc8 sr br grf512 sm sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 524288, 16777216}, {1048576, 524288, 16777216}, {64, 32, 8}, {4, 8, 1}, 1, (WGType) 1, 257, 0, 0, {8, 8, 8}, {true, true, true}}, {'W', 1, {2048}}}, +{{'I', "gemm", {"D", "D", "D"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "h"}, "aS8+m16@24 aB8+m8@24 aB wg 4x8 kc8 sr br grf512 sm bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 524288, 16777216}, {1048576, 524288, 16777216}, {64, 32, 8}, {4, 8, 1}, 1, (WGType) 1, 257, 0, 0, {8, 8, 8}, {true, true, true}}, {'W', 1, {2048}}}, +{{'I', "gemm", {"F", "B", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "#I"}, "aB16x2 aB16x2 aB wg 4x8 cab4 ks64 af vav hi sr br bk0 sn dm grf256 sys l4 bo pt", {16, (LoopType) 255, 256, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 131072, 16777216}, {1048576, 131072, 16777216}, {64, 8, 64}, {4, 8, 1}, 1, (WGType) 1, 257, 163840, 0, {1, 2, 4}, {true, true, true}}, {'E', 17, {1.07581e+06, 764320, 0, 0, 0, 0, 0.804535, 1.46469, 0.96438, 2.27185, 0.0120677, 0.0120677, 0, 1, 1.38109, 0.955498, 2.48341e-12}}}, +{{'I', "gemm", {"F", "H", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "#I"}, "aB16x2 aB16x2 aB wg 4x8 cab4 ks64 af vav hi sr br bk0 sn dm grf256 sys l4 bo pt", {16, (LoopType) 255, 256, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 131072, 16777216}, {1048576, 131072, 16777216}, {64, 8, 64}, {4, 8, 1}, 1, (WGType) 1, 257, 163840, 0, {1, 2, 4}, {true, 
true, true}}, {'E', 17, {1.07581e+06, 764320, 0, 0, 0, 0, 0.804535, 1.46469, 0.96438, 2.27185, 0.0120677, 0.0120677, 0, 1, 1.38109, 0.955498, 2.48341e-12}}}, +{{'I', "gemm", {"H", "H", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIph"}, "av32+m16@128 am32+m32@128 aB wg 8x4 sys xaf fx st kc2 sr br grf512 sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"H", "H", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ih"}, "aB32+B16@128 aB32+B32@128 aB wg 8x4 sys xaf fx st kc2 sr br dm sn grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"H", "H", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, 8, -1}, {-1, -1, -1}, {-1, 8, -1}, {4, 4, 1}, "ABI"}, "av16+B32@32 am16+m32@32 aB wg 2x4 li nmk pt sr br sb64 bk0 sys ska", {16, (LoopType) 255, 128, {(LoopType) 225, (LoopType) 255, (LoopType) 255}, {1048576, 32768, 16777216}, {1048576, 32768, 16777216}, {64, 2, 16}, {2, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {128}}}, +{{'I', "gemm", {"H", "H", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIph"}, "av32+m16@128 at32+m16@128 aB wg 8x4 sys xaf fx st kc2 sr br grf512 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"H", "H", "S"}, {"N", "T", 
"N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+B16@64 aB8+B16@64 aB wg 4x8 kc8 sr br grf256 l4 bo pt", {16, (LoopType) 255, 256, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 8}, {4, 8, 1}, 1, (WGType) 1, 257, 0, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"H", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIph"}, "at32+m32@128 am32+m32@128 aB wg 8x4 sys xaf fx st kc2 sr br grf512 sm sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"H", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ih"}, "aS32+B32@128 aB32+B32@128 aB wg 8x4 sys xaf fx st kc2 sr br dm sm sn grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"H", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, 8, -1}, {-1, -1, -1}, {-1, 8, -1}, {4, 4, 1}, "ABIp"}, "at32 am128 aB wg 2x1x2 kr sys xaf st li pt nmk sb128 bk0 bm0", {16, (LoopType) 255, 128, {(LoopType) 225, (LoopType) 255, (LoopType) 2}, {16777216, 131072, 16777216}, {262144, 131072, 16777216}, {16, 8, 128}, {2, 1, 2}, 1, (WGType) 0, 261, 0, 1024, {4, 4, 4}, {true, true, true}}, {'W', 1, {128}}}, +{{'I', "gemm", {"H", "H", "S"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIqh"}, "am32+m32@128 av32+m16@128 aS wg 8x4 sys xaf fx st kc2 sr grf512 sm bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 
255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"H", "H", "S"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ih"}, "aB32+B32@128 aB32+B16@128 aS wg 8x4 sys xaf fx st kc2 sr dm sm grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 32}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"O", "B", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "#I"}, "aB16x2 aB16x2 aB wg 4x8 cab4 ks64 af vav hi sr br bk0 sn nb 4x8 dm grf256 sys l4 bo pt", {16, (LoopType) 255, 256, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 131072, 16777216}, {1048576, 131072, 16777216}, {64, 8, 64}, {4, 8, 1}, 1, (WGType) 1, 257, 163840, 0, {1, 2, 4}, {true, true, true}}, {'E', 17, {1.07581e+06, 764320, 0, 0, 0, 0, 0.804535, 1.46469, 0.96438, 2.27185, 0.0120677, 0.0120677, 0, 1, 1.38109, 0.955498, 2.48341e-12}}}, +{{'I', "gemm", {"O", "O", "I"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "IABh"}, "av64+m32@256 am64+m64@256 aB wg 8x4 sys xaf fx st kc2 sr br grf512 sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"O", "O", "I"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ih"}, "aB64+B32@256 aB64+B64@256 aB wg 8x4 sys xaf fx st kc2 sr br dm sn grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, 
{524288, 524288, 16777216}, {32, 32, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"O", "O", "I"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, 8, -1}, {-1, -1, -1}, {-1, 8, -1}, {4, 4, 1}, "ABI"}, "av32+B64@64 am32+m64@64 aB wg 2x4 li nmk pt sr br sb64 bk0 sys ska", {16, (LoopType) 255, 128, {(LoopType) 225, (LoopType) 255, (LoopType) 255}, {1048576, 32768, 16777216}, {1048576, 32768, 16777216}, {64, 2, 32}, {2, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {128}}}, +{{'I', "gemm", {"O", "O", "I"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIph"}, "av64+m32@256 at64+m32@256 aB wg 8x4 sys xaf fx st kc2 sr br grf512 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"O", "O", "I"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+B16@64 aB8+B16@64 aB wg 4x8 sr br grf256 l4 bo pt", {16, (LoopType) 255, 256, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 8}, {4, 8, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"O", "O", "I"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "IABh"}, "at64+m64@256 am64+m64@256 aB wg 8x4 sys xaf fx st kc2 sr br grf512 sm sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"O", "O", "I"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, 
{-1, -1, -1}, {1, 1, 1}, "Ih"}, "aS64+B64@256 aB64+B64@256 aB wg 8x4 sys xaf fx st kc2 sr br dm sm sn grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"O", "O", "I"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, 8, -1}, {-1, -1, -1}, {-1, 8, -1}, {4, 4, 1}, "ABIp"}, "at64x2 am128 aB wg 2x1x2 kr sys xaf st li pt nmk sb128 bk0 bm0", {16, (LoopType) 255, 128, {(LoopType) 225, (LoopType) 255, (LoopType) 2}, {16777216, 131072, 16777216}, {262144, 131072, 16777216}, {16, 8, 128}, {2, 1, 2}, 1, (WGType) 0, 261, 0, 1024, {4, 4, 4}, {true, true, true}}, {'W', 1, {128}}}, +{{'I', "gemm", {"O", "O", "I"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "IABh"}, "am64+m64@256 av64+m32@256 aS wg 8x4 sys xaf fx st kc2 sr grf512 sm bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"O", "O", "I"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ih"}, "aB64+B64@256 aB64+B32@256 aS wg 8x4 sys xaf fx st kc2 sr dm sm grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "I"}, "aB32+B16@128 aB32+B32@128 aB wg 4x4 sys kc2 sr br dm sn grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 
16777216}, {524288, 524288, 16777216}, {32, 32, 64}, {4, 4, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ip"}, "aB32 aB32 aB wg 8x4 cab3 ks32 af vav di hi sr br bk0 sn dm grf256 sys kv afb l4 bo pt", {16, (LoopType) 255, 256, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {262144, 262144, 16777216}, {262144, 262144, 32}, {16, 16, 32}, {8, 4, 1}, 1, (WGType) 1, 441, 18432, 0, {1, 1, 4}, {true, true, true}}, {'W', 1, {256}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIph"}, "av64+m32@256 am64+m64@256 aB wg 8x4 sys xaf fx st kc2 sr br grf512 sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, 8, -1}, {-1, -1, -1}, {-1, 8, -1}, {4, 4, 1}, "ABI"}, "av32+B64@64 am32+m64@64 aB wg 2x4 li nmk pt sr br sb64 bk0 sys ska", {16, (LoopType) 255, 128, {(LoopType) 225, (LoopType) 255, (LoopType) 255}, {1048576, 32768, 16777216}, {1048576, 32768, 16777216}, {64, 2, 32}, {2, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {128}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, 16, -1}, {1, 1, 1}, "#I"}, "aB32+m16@48 aS32 aB wg 16x1x2 kr cb4x2 ks32 xaf vav di li nmk sr br bk0 grf256 sys kv afb l4 bo pt", {16, (LoopType) 255, 256, {(LoopType) 193, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 32}, {16, 16, 32}, {16, 1, 2}, 1, (WGType) 1, 445, 2048, 16384, {1, 1, 4}, {true, true, true}}, {'W', 1, {256}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"N", "T", 
"N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIp"}, "av64+m32@128 at64+m32@128 aB wg 8x4 sys xaf fx st kc2 sr br grf512 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "I"}, "aS32+S32@48 aB32+S16@48 aB wg 4x1 af vav nmk li sr br br sb256 bk0 sm grf256 sys l4 kd bo pt", {16, (LoopType) 255, 256, {(LoopType) 193, (LoopType) 255, (LoopType) 255}, {262144, 16384, 16777216}, {262144, 16384, 16777216}, {16, 1, 32}, {4, 1, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'W', 1, {16}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIph"}, "at64+m64@256 am64+m64@256 aB wg 8x4 sys xaf fx st kc2 sr br grf512 sm sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ih"}, "aS64+B64@256 aB64+B64@256 aB wg 8x4 sys xaf fx st kc2 sr br dm sm sn grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, 8, -1}, {-1, -1, -1}, {-1, 8, -1}, {4, 4, 1}, "ABIp"}, "at64x2 am128 aB wg 2x1x2 kr sys xaf st li pt nmk sb128 bk0 bm0", {16, (LoopType) 255, 
128, {(LoopType) 225, (LoopType) 255, (LoopType) 2}, {16777216, 131072, 16777216}, {262144, 131072, 16777216}, {16, 8, 128}, {2, 1, 2}, 1, (WGType) 0, 261, 0, 1024, {4, 4, 4}, {true, true, true}}, {'W', 1, {128}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "I"}, "aS32 aS32 aB sys grf256 cab2 wg 4x4 l4 sr br bo pt", {16, (LoopType) 255, 256, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {786432, 524288, 16777216}, {786432, 524288, 16777216}, {48, 32, 32}, {4, 4, 1}, 1, (WGType) 1, 257, 20480, 0, {1, 1, 4}, {true, true, true}}, {'W', 1, {1536}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "ABIqh"}, "am64+m64@256 av64+m32@256 aS wg 8x4 sys xaf fx st kc2 sr grf512 sm bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"Q", "Q", "S"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ih"}, "aB64+B64@256 aB64+B16@256 aS wg 8x4 sys xaf fx st kc2 sr dm sm grf512 l4 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 64}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"S", "S", "S"}, {"A", "B", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+B8@16 aB8+B8@16 aB kc8 bo pt", {16, (LoopType) 255, 128, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 8}, {2, 8, 1}, 1, (WGType) 1, 256, 0, 0, {128, 128, 4}, {true, true, true}}, {'W', 1, {1024}}}, +{{'I', "gemm", {"S", "S", "S"}, {"N", "N", "N"}}, 
{-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "h"}, "aB8+m8@24 aS8+m16@24 aB wg 8x4 kc8 sr br grf512 sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 8}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"S", "S", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "h"}, "aB8+m8@24 aB8+m8@24 aB wg 8x4 kc8 sr br grf512 bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 8}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"S", "S", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "h"}, "aS8+m16@24 aS8+m16@24 aB wg 8x4 kc8 sr br grf512 sm sn bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 8}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}}, +{{'I', "gemm", {"S", "S", "S"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "h"}, "aS8+m16@24 aB8+m8@24 aB wg 8x4 kc8 sr br grf512 sm bo pt", {16, (LoopType) 255, 512, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {1048576, 1048576, 16777216}, {1048576, 1048576, 16777216}, {64, 64, 8}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'W', 1, {4096}}} }); diff --git a/src/gpu/intel/gemm/jit/selector/kernel_selector.cpp b/src/gpu/intel/gemm/jit/selector/kernel_selector.cpp index 9688d73e96a..ccbe7409f79 100644 --- a/src/gpu/intel/gemm/jit/selector/kernel_selector.cpp +++ b/src/gpu/intel/gemm/jit/selector/kernel_selector.cpp @@ -289,6 +289,7 @@ const std::vector select(const 
kcatalog::Catalog &catal switch (hw) { case HWTagXe2: hw = HWTagXeHPC; break; case HWTagXe3: hw = HWTagXe2; break; + case HWTagXe3p: hw = HWTagXe3; break; default: hw = 0; break; } } while (hw); @@ -339,6 +340,14 @@ MatchParamsBase::MatchParamsBase(ngen::HW hw, bool systolicAvailable, bool isInt if(problem.Tbo.is4() || problem.Tb_scale.is4()){ unrollReq[LoopN] = 2; } + + ReqBDPASDims = problem.preferBDPAS(hw); + + if (ReqBDPASDims) { + unrollReq[LoopM] = 8; + unrollReq[LoopN] = 8; + } + if(problem.hasCMXScale() && unrollReq[LoopN] % 32){ unrollReq[LoopN] = 32; } @@ -351,6 +360,9 @@ MatchParamsBase::MatchParamsBase(ngen::HW hw, bool systolicAvailable, bool isInt case ngen::HW::XeHPC: selector.hw = kcatalog::HWTagXeHPC; break; case ngen::HW::Xe2: selector.hw = kcatalog::HWTagXe2; break; case ngen::HW::Xe3: selector.hw = kcatalog::HWTagXe3; break; + case ngen::HW::XE3P_35_10: + case ngen::HW::XE3P_35_11: + case ngen::HW::XE3P_UNKNOWN: selector.hw = kcatalog::HWTagXe3p; break; } auto &C = problem.C; @@ -432,6 +444,8 @@ MatchParamsBase::MatchParamsBase(ngen::HW hw, bool systolicAvailable, bool isInt *tagPtr++ = ReqXe2Block2D; if (hw == ngen::HW::Xe3) *tagPtr++ = ReqXe2Block2D; + if (one_of(hw, {ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, ngen::HW::XE3P_UNKNOWN})) *tagPtr++ = ReqXe2Block2D; + sizes.batch = sizes.m = sizes.n = sizes.k = 0; } diff --git a/src/gpu/intel/gemm/with_post_ops.cl b/src/gpu/intel/gemm/with_post_ops.cl index b5bba206721..472d67118c4 100644 --- a/src/gpu/intel/gemm/with_post_ops.cl +++ b/src/gpu/intel/gemm/with_post_ops.cl @@ -114,6 +114,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, #endif if (DST_ZERO_POINT) accumulator += dst_zp[0]; } + #if WITH_DROPOUT #if !DROPOUT_USE_HOST_SCALARS long dropout_seed = dropout_seed_buf[0]; diff --git a/src/gpu/intel/gemm/with_post_ops.cpp b/src/gpu/intel/gemm/with_post_ops.cpp index 23a1c741367..99f9a78fbac 100644 --- a/src/gpu/intel/gemm/with_post_ops.cpp +++ b/src/gpu/intel/gemm/with_post_ops.cpp 
@@ -215,6 +215,7 @@ status_t with_post_ops_t::pd_t::init_kernel_ctx( def_memory_desc_info(kernel_ctx, src_info, "SRC", false); def_memory_desc_info(kernel_ctx, bias_info, "BIAS", false); + if (dynamic_scales_) { dnnl_memory_desc d_md(*dst_md(0)); d_md.data_type = acc_type_; @@ -227,6 +228,7 @@ status_t with_post_ops_t::pd_t::init_kernel_ctx( } int ndims = src_info.ndims; + kernel_ctx.set_data_type(dynamic_scales_ ? acc_type_ : c_type); kernel_ctx.require_stateless_addressing(has_large_buffers()); @@ -361,6 +363,7 @@ status_t with_post_ops_t::execute(const exec_ctx_t &ctx) const { const auto group_size = pd()->attr()->scales_.get_group(DNNL_ARG_DST, -1); const auto c_d = nested_ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md()); + const int last = c_d.ndims() - 1; const dim_t D3 = c_d.ndims() > 5 ? c_d.dims()[last - 5] : 1; const dim_t D2 = c_d.ndims() > 4 ? c_d.dims()[last - 4] : 1; diff --git a/src/gpu/intel/include/types_specific.h b/src/gpu/intel/include/types_specific.h index 6735922d9c7..325c575a3bc 100644 --- a/src/gpu/intel/include/types_specific.h +++ b/src/gpu/intel/include/types_specific.h @@ -796,6 +796,8 @@ #elif WEI_SCALES_DT_BF8 #define WEI_SCALES_TO_REF(x) convert_float(cvt_f8_e5m2_to_hf(x)) #define REF_TO_WEI_SCALES(x) cvt_hf_to_f8_e5m2(convert_half(x)) +#elif WEI_SCALES_DT_E8M0 +#define WEI_SCALES_TO_REF(x) cvt_e8m0_to_f32(x) #elif WEI_SCALES_DT_F16 #define WEI_SCALES_TO_REF(x) convert_float(x) #define REF_TO_WEI_SCALES(x) convert_half(x) @@ -817,6 +819,8 @@ #elif SRC_SCALES_DT_BF8 #define SRC_SCALES_TO_REF(x) convert_float(cvt_f8_e5m2_to_hf(x)) #define REF_TO_SRC_SCALES(x) cvt_hf_to_f8_e5m2(convert_half(x)) +#elif SRC_SCALES_DT_E8M0 +#define SRC_SCALES_TO_REF(x) cvt_e8m0_to_f32(x) #elif SRC_SCALES_DT_F16 #define SRC_SCALES_TO_REF(x) convert_float(x) #define REF_TO_SRC_SCALES(x) convert_half(x) diff --git a/src/gpu/intel/jit/binary_format.cpp b/src/gpu/intel/jit/binary_format.cpp index 29a577625ba..77983b46262 100644 --- 
a/src/gpu/intel/jit/binary_format.cpp +++ b/src/gpu/intel/jit/binary_format.cpp @@ -59,7 +59,7 @@ class binary_format_kernel_t : public generator_t { FORWARD(hw); public: - binary_format_kernel_t() + binary_format_kernel_t(const engine_t *engine) : generator_t(debug_config_t {GENERATOR_NAME, GENERATOR_LINE}) { auto low_half = [](uint64_t q) -> uint32_t { return q & 0xFFFFFFFF; }; @@ -78,6 +78,9 @@ class binary_format_kernel_t : public generator_t { requireSIMD((GRF::bytes(hw) == 64) ? 16 : 8); requireLocalID(3); // r1-r3 requireLocalSize(); // r7.0-2:ud + if (utils::one_of(hw, ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, + ngen::HW::XE3P_UNKNOWN)) + setEfficient64Bit(engine->device_info()->is_efficient_64bit()); finalizeInterface(); Label doWrite; @@ -162,14 +165,13 @@ class binary_format_kernel_t : public generator_t { threadend(SWSB(sb2, 1), r127); } - static compute::kernel_t make_kernel( - intel::engine_t *engine, bool *skip_check) { + static compute::kernel_t make_kernel(engine_t *engine, bool *skip_check) { compute::kernel_t kernel; *skip_check = false; if (hw != HW::Unknown) { - binary_format_kernel_t binary_format_kernel; + binary_format_kernel_t binary_format_kernel(engine); auto status = engine->create_kernel(&kernel, &binary_format_kernel); @@ -201,6 +203,19 @@ class binary_format_kernel_t : public generator_t { kernel = binary_format_kernel_t::make_kernel( engine, skip_check); break; + case compute::gpu_arch_t::xe3p_35_10: + kernel = binary_format_kernel_t< + HW::XE3P_35_10>::make_kernel(engine, skip_check); + break; + case compute::gpu_arch_t::xe3p_35_11: + kernel = binary_format_kernel_t< + HW::XE3P_35_11>::make_kernel(engine, skip_check); + break; + case compute::gpu_arch_t::xe3p_35_unknown: + kernel = binary_format_kernel_t< + HW::XE3P_UNKNOWN>::make_kernel(engine, skip_check); + break; + case compute::gpu_arch_t::unknown: VWARN(common, runtime, "unknown gpu platform - optimizations are disabled " @@ -216,6 +231,13 @@ class binary_format_kernel_t : 
public generator_t { status_t gpu_supports_binary_format(bool *ok, impl::engine_t *engine) { *ok = false; + auto gpu_engine = utils::downcast(engine); + + if (!gpu_engine) { + VERROR(common, runtime, "bad engine kind, expected a gpu engine"); + return status::invalid_arguments; + } + #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL if (engine->runtime_kind() == runtime_kind::ocl) { // Here we are doing a check for a temporary OpenCL engine while the @@ -228,13 +250,6 @@ status_t gpu_supports_binary_format(bool *ok, impl::engine_t *engine) { } #endif - auto gpu_engine = utils::downcast(engine); - - if (!gpu_engine) { - VERROR(common, runtime, "bad engine kind, expected a gpu engine"); - return status::invalid_arguments; - } - impl::stream_t *stream_generic; auto status = gpu_engine->get_service_stream(stream_generic); if (status != status::success) return status::runtime_error; diff --git a/src/gpu/intel/jit/eltwise_injector.cpp b/src/gpu/intel/jit/eltwise_injector.cpp index ccea097bd7a..9c447f95dc6 100644 --- a/src/gpu/intel/jit/eltwise_injector.cpp +++ b/src/gpu/intel/jit/eltwise_injector.cpp @@ -52,7 +52,8 @@ int eltwise_injector_f32_t::min_scratch_regs() { case eltwise_square: return 0; case eltwise_swish: return 1; case eltwise_tanh: - case eltwise_tanh_use_dst_for_bwd: return 2; + case eltwise_tanh_use_dst_for_bwd: + return (hw() < gpu_xe3p_35_10) ? 2 : 1; case eltwise_round: return 0; case eltwise_linear: return 0; case eltwise_clip: @@ -72,7 +73,7 @@ int eltwise_injector_f32_t::min_scratch_regs() { case eltwise_square: return 0; case eltwise_linear: return 0; case eltwise_clip: return 1; - case eltwise_gelu_tanh: return 2; + case eltwise_gelu_tanh: return (hw() < gpu_xe3p_35_10) ? 2 : 1; default: assert(!"unsupported eltwise algorithm"); } } @@ -162,24 +163,25 @@ int eltwise_injector_f32_t::phase_count(alg_kind_t alg) { case eltwise_relu: case eltwise_relu_use_dst_for_bwd: return (alpha_ == 0) ? 
1 : 2; case eltwise_soft_relu: return 10; - case eltwise_swish: return 5; + case eltwise_swish: return (hw() < gpu_xe3p_35_10) ? 5 : 3; case eltwise_tanh: case eltwise_tanh_use_dst_for_bwd: - return (use_tanh_compat()) ? 9 : 6; + return (hw() < gpu_xe3p_35_10) ? 6 : 1; case eltwise_linear: return (beta_ == 0) ? 1 : 2; case eltwise_clip: case eltwise_clip_v2: case eltwise_clip_v2_use_dst_for_bwd: return 2; case eltwise_gelu_tanh: return 8; case eltwise_logistic: - case eltwise_logistic_use_dst_for_bwd: return 4; + case eltwise_logistic_use_dst_for_bwd: + return (hw() < gpu_xe3p_35_10) ? 4 : 1; default: break; } } else { switch (alg) { case eltwise_abs: return 2; case eltwise_clip: return 4; - case eltwise_gelu_tanh: return 14; + case eltwise_gelu_tanh: return (hw() < gpu_xe3p_35_10) ? 14 : 8; default: break; } } @@ -260,17 +262,24 @@ void eltwise_injector_f32_t::square_compute_fwd( template void eltwise_injector_f32_t::tanh_compute_fwd( int simd, const ngen::GRF &r, int phase, int off, int batch) { - const float log2e = 1.44269502162933349609375f; // log_2(e) - auto one_half = scratch_[0].f(7); - auto a = scratch_[off + batch].f(); - switch (phase) { - case 0: h->mul(simd, a, abs(r), 2.f * log2e); break; - case 1: h->exp(simd, a, a); break; - case 2: h->mad(simd, a, one_half, a, one_half); break; - case 3: h->inv(simd, a, a); break; - case 4: h->add(simd, a, -a, 1.f); break; - case 5: h->csel(simd | ge | f0[0], r, a, -a, r); break; - default: assert(!"invalid phase"); + if (hw() < gpu_xe3p_35_10) { + const float log2e = 1.44269502162933349609375f; // log_2(e) + auto one_half = scratch_[0].f(7); + auto a = scratch_[off + batch].f(); + switch (phase) { + case 0: h->mul(simd, a, abs(r), 2.f * log2e); break; + case 1: h->exp(simd, a, a); break; + case 2: h->mad(simd, a, one_half, a, one_half); break; + case 3: h->inv(simd, a, a); break; + case 4: h->add(simd, a, -a, 1.f); break; + case 5: h->csel(simd | ge | f0[0], r, a, -a, r); break; + default: assert(!"invalid 
phase"); + } + } else { + switch (phase) { + case 0: h->tanh(simd, r, r); break; + default: assert(!"invalid phase"); + } } } @@ -536,15 +545,24 @@ void eltwise_injector_f32_t::philox_4x32( template void eltwise_injector_f32_t::swish_compute_fwd( int simd, const ngen::GRF &r, int phase, int off) { - const float log2e = 1.442695f; // log_2(e) auto temp = scratch_[off].f(); - switch (phase) { - case 0: h->mul(simd, temp, r, -1.f * log2e * alpha_); break; - case 1: h->exp(simd, temp, temp); break; - case 2: h->add(simd, temp, temp, 1.f); break; - case 3: h->inv(simd, temp, temp); break; - case 4: h->mul(simd, r, r, temp); break; - default: assert(!"invalid phase"); + if (hw() < gpu_xe3p_35_10) { + const float log2e = 1.442695f; // log_2(e) + switch (phase) { + case 0: h->mul(simd, temp, r, -1.f * log2e * alpha_); break; + case 1: h->exp(simd, temp, temp); break; + case 2: h->add(simd, temp, temp, 1.f); break; + case 3: h->inv(simd, temp, temp); break; + case 4: h->mul(simd, r, r, temp); break; + default: assert(!"invalid phase"); + } + } else { + switch (phase) { + case 0: h->mul(simd, temp, r, alpha_); break; + case 1: h->sigm(simd, temp, temp); break; + case 2: h->mul(simd, r, r, temp); break; + default: assert(!"invalid phase"); + } } } @@ -597,13 +615,20 @@ void eltwise_injector_f32_t::gelu_tanh_compute_fwd( template void eltwise_injector_f32_t::logistic_compute_fwd( int simd, const ngen::GRF &r, int phase) { - const float log2e = 1.442695f; // log_2(e) - switch (phase) { - case 0: h->mul(simd, r, r, -1.f * log2e); break; - case 1: h->exp(simd, r, r); break; - case 2: h->add(simd, r, r, 1.f); break; - case 3: h->inv(simd, r, r); break; - default: assert(!"invalid phase"); + if (hw() < gpu_xe3p_35_10) { + const float log2e = 1.442695f; // log_2(e) + switch (phase) { + case 0: h->mul(simd, r, r, -1.f * log2e); break; + case 1: h->exp(simd, r, r); break; + case 2: h->add(simd, r, r, 1.f); break; + case 3: h->inv(simd, r, r); break; + default: assert(!"invalid 
phase"); + } + } else { + switch (phase) { + case 0: h->sigm(simd, r, r); break; + default: assert(!"invalid phase"); + } } } @@ -644,8 +669,10 @@ void eltwise_injector_f32_t::clip_prepare_bwd() { template void eltwise_injector_f32_t::tanh_prepare_fwd() { - auto one_half = scratch_[0].f(7); - h->mov(1, one_half, 0.5f); + if (hw() < gpu_xe3p_35_10) { + auto one_half = scratch_[0].f(7); + h->mov(1, one_half, 0.5f); + } } template @@ -715,23 +742,37 @@ void eltwise_injector_f32_t::gelu_tanh_compute_bwd( if (hw() == gpu_xe_hp) msimd = 16; auto a = scratch_[off].f(); - auto b = scratch_[off + batch].f(); - switch (phase) { - case 0: h->mul(simd, a, r, r); break; - case 1: h->mul(simd, b, a, 3.0f * k); break; - case 2: h->mul(simd, a, a, k); break; - case 3: h->mad(simd, a, r, a, r); break; - case 4: h->mad(simd, b, r, b, r); break; - case 5: h->mul(simd, a, a, -2 * sqrt_2_over_pi * log2e); break; - case 6: h->mul(simd, b, b, 2 * sqrt_2_over_pi); break; - case 7: h->exp(msimd, a, a); break; - case 8: h->add(simd, r, a, 1.0f); break; - case 9: h->inv(msimd, r, r); break; - case 10: h->mul(simd, a, a, r); break; - case 11: h->mul(simd, a, a, b); break; - case 12: h->add(simd, a, a, 1.0f); break; - case 13: h->mul(simd, r, r, a); break; - default: assert(!"invalid phase"); + if (hw() < gpu_xe3p_35_10) { + auto b = scratch_[off + batch].f(); + switch (phase) { + case 0: h->mul(simd, a, r, r); break; + case 1: h->mul(simd, b, a, 3.0f * k); break; + case 2: h->mul(simd, a, a, k); break; + case 3: h->mad(simd, a, r, a, r); break; + case 4: h->mad(simd, b, r, b, r); break; + case 5: h->mul(simd, a, a, -2 * sqrt_2_over_pi * log2e); break; + case 6: h->mul(simd, b, b, 2 * sqrt_2_over_pi); break; + case 7: h->exp(msimd, a, a); break; + case 8: h->add(simd, r, a, 1.0f); break; + case 9: h->inv(msimd, r, r); break; + case 10: h->mul(simd, a, a, r); break; + case 11: h->mul(simd, a, a, b); break; + case 12: h->add(simd, a, a, 1.0f); break; + case 13: h->mul(simd, r, r, a); break; + 
default: assert(!"invalid phase"); + } + } else { + switch (phase) { + case 0: h->mul(simd, a, r, k); break; + case 1: h->mad(simd, a, half(1.f), a, r); break; + case 2: h->mul(simd, a, a, r); break; + case 3: h->mul(simd, a, a, sqrt_2_over_pi); break; + case 4: h->tanh(simd, r, a); break; + case 5: h->mad(simd, a, a, a, -r); break; + case 6: h->mad(simd, r, half(0.5f), r, half(0.5f)); break; + case 7: h->mad(simd, r, r, r, a); break; + default: assert(!"invalid phase"); + } } } @@ -1079,6 +1120,10 @@ REG_XEHPG_ISA(template struct eltwise_injector_f32_t>); REG_XEHPC_ISA(template struct eltwise_injector_f32_t>); REG_XE2_ISA(template struct eltwise_injector_f32_t>); REG_XE3_ISA(template struct eltwise_injector_f32_t>); +REG_XE3P_ISA(template struct eltwise_injector_f32_t>); +REG_XE3P_ISA(template struct eltwise_injector_f32_t>); +REG_XE3P_ISA( + template struct eltwise_injector_f32_t>); #ifdef NGEN_ASM template struct eltwise_injector_f32_t; diff --git a/src/gpu/intel/jit/emulated_generator.cpp b/src/gpu/intel/jit/emulated_generator.cpp index 9fc095cd369..a411f2e2f89 100644 --- a/src/gpu/intel/jit/emulated_generator.cpp +++ b/src/gpu/intel/jit/emulated_generator.cpp @@ -449,6 +449,9 @@ REG_XEHPG_ISA(template class emulated_generator_t); REG_XEHPC_ISA(template class emulated_generator_t); REG_XE2_ISA(template class emulated_generator_t); REG_XE3_ISA(template class emulated_generator_t); +REG_XE3P_ISA(template class emulated_generator_t); +REG_XE3P_ISA(template class emulated_generator_t); +REG_XE3P_ISA(template class emulated_generator_t); } // namespace jit } // namespace intel diff --git a/src/gpu/intel/jit/generator.hpp b/src/gpu/intel/jit/generator.hpp index 778db6d2515..83142a2dbe4 100644 --- a/src/gpu/intel/jit/generator.hpp +++ b/src/gpu/intel/jit/generator.hpp @@ -56,6 +56,9 @@ constexpr gpu_gen_t gpu_xe_hpg = ngen::HW::XeHPG; constexpr gpu_gen_t gpu_xe_hpc = ngen::HW::XeHPC; constexpr gpu_gen_t gpu_xe2 = ngen::HW::Xe2; constexpr gpu_gen_t gpu_xe3 = 
ngen::HW::Xe3; +constexpr gpu_gen_t gpu_xe3p_35_10 = ngen::HW::XE3P_35_10; +constexpr gpu_gen_t gpu_xe3p_35_11 = ngen::HW::XE3P_35_11; +constexpr gpu_gen_t gpu_xe3p_35_unknown = ngen::HW::XE3P_UNKNOWN; #if (!defined(NDEBUG) || defined(DNNL_DEV_MODE)) #define GENERATOR_NAME __FILE__ diff --git a/src/gpu/intel/jit/ir/block_2d_utils.hpp b/src/gpu/intel/jit/ir/block_2d_utils.hpp index c953e5c48c0..56d7eb2bda6 100644 --- a/src/gpu/intel/jit/ir/block_2d_utils.hpp +++ b/src/gpu/intel/jit/ir/block_2d_utils.hpp @@ -41,6 +41,9 @@ inline int block_2d_base_alignment(ngen::HW hw) { case ngen::HW::XeHPC: return 64; case ngen::HW::Xe2: case ngen::HW::Xe3: return 64; + case ngen::HW::XE3P_35_10: + case ngen::HW::XE3P_35_11: + case ngen::HW::XE3P_UNKNOWN: return 4; default: gpu_error_not_expected(); } return 0; @@ -72,6 +75,9 @@ inline int block_2d_pitch_alignment(ngen::HW hw) { case ngen::HW::XeHPC: return 8; case ngen::HW::Xe2: return 16; case ngen::HW::Xe3: return 16; + case ngen::HW::XE3P_35_10: + case ngen::HW::XE3P_35_11: + case ngen::HW::XE3P_UNKNOWN: return 4; default: gpu_error_not_expected(); } return 0; @@ -89,9 +95,14 @@ inline bool block_2d_pitch_ok( return true; } -inline int block_2d_max_count( - bool is_store, bool is_transpose, int block_width, int type_size) { +inline int block_2d_max_count(ngen::HW hw, bool is_prefetch, bool is_store, + bool is_transpose, int block_width, int type_size) { if (is_store || is_transpose) return 1; + if (utils::one_of(hw, ngen::HW::XE3P_35_10, ngen::HW::XE3P_35_11, + ngen::HW::XE3P_UNKNOWN) + && is_prefetch) { + return 256 / (block_width * type_size); + } return 64 / (block_width * type_size); } diff --git a/src/gpu/intel/jit/ir/gemm_schedule.hpp b/src/gpu/intel/jit/ir/gemm_schedule.hpp index 81cfc3bbcca..59cacd7bee2 100644 --- a/src/gpu/intel/jit/ir/gemm_schedule.hpp +++ b/src/gpu/intel/jit/ir/gemm_schedule.hpp @@ -675,6 +675,16 @@ class gemm_schedule_t { } } + bool is_inner_loop(const expr_t &v) const { + for (size_t i = 0; i < 
vars_.size(); i++) { + auto &loop = find_loop(vars_[i]); + if (!loop.is_leaf() || loop.kind() != loop_kind_t::serial) continue; + if (to_cpp(loop.bound()) == 1) continue; + return find_root_var(vars_[i]).is_same(v); + } + return false; + } + // Sets init and step for loop defined by `var`. // Used to create loop that avoids skip conditions: // for (var = init ; var < bound; var += step) { @@ -976,6 +986,14 @@ class gemm_schedule_t { return loops_[var]; } + expr_t find_root_var(const expr_t &var) const { + auto *loop = &find_loop(var); + while (!loop->is_root()) { + loop = &find_loop(loop->parent_vars()[0]); + } + return loop->var(); + } + int loop_level(const expr_t &var) const { for (int i = 0; i < int(vars_.size()); i++) { if (vars_[i].is_same(var)) return i; diff --git a/src/gpu/intel/jit/ir/hw.hpp b/src/gpu/intel/jit/ir/hw.hpp index 2922f37fb41..f6b945f2259 100644 --- a/src/gpu/intel/jit/ir/hw.hpp +++ b/src/gpu/intel/jit/ir/hw.hpp @@ -48,7 +48,9 @@ inline dsl::hw_t make_ir_hw(const impl::engine_t *engine) { if (device_info->mayiuse_float_atomic_add(data_type::f64)) attr |= dsl::hw_t::attr_t::atomic_fp64; - return dsl::hw_t(product, eu_count, max_wg_size, l3_cache_size, attr); + bool efficient_64_bit = device_info->is_efficient_64bit(); + return dsl::hw_t(product, eu_count, max_wg_size, l3_cache_size, + efficient_64_bit, attr); } inline bool prefer_large_grf( diff --git a/src/gpu/intel/jit/ir/send_builder.cpp b/src/gpu/intel/jit/ir/send_builder.cpp index ed5e17ab60e..5c562a76462 100644 --- a/src/gpu/intel/jit/ir/send_builder.cpp +++ b/src/gpu/intel/jit/ir/send_builder.cpp @@ -398,6 +398,7 @@ bool access_builder_t::try_build_2d(send_params_t &send_params) { if (!hint.type.is_undef()) vlayout = reinterpret(vlayout, hint.type); bool is_store = (send_op_ == send_op_t::store); + bool is_prefetch = (send_op_ == send_op_t::prefetch); auto send_type = dsl::type_t::u(vlayout.type().size() * 8); auto blocks = vlayout.blocks(); if (blocks.size() < 2) return false; @@ 
-464,8 +465,8 @@ bool access_builder_t::try_build_2d(send_params_t &send_params) { // Try to reduce the number of messages by increasing count per message. int try_count = count * 2; - int max_count - = block_2d_max_count(is_store, transpose, width, mem_type_.size()); + int max_count = block_2d_max_count(ir_ctx_->hw(), is_prefetch, is_store, + transpose, width, mem_type_.size()); while (try_count <= max_count) { if (b0.size % (try_count * width) != 0) break; count = try_count; @@ -654,8 +655,10 @@ bool access_builder_t::fixup_send_2d_params(const dsl::type_t &send_type, int factor = 64 / surface_width_size; if (h % factor != 0) return false; - int max_count = block_2d_max_count( - send_op_ == send_op_t::store, transpose, w, send_type.size()); + int max_count = block_2d_max_count(ir_ctx_->hw(), + send_op_ == send_op_t::prefetch, send_op_ == send_op_t::store, + transpose, w, send_type.size()); + if (factor > max_count) return false; vnni_permute_factor = factor; diff --git a/src/gpu/intel/jit/ir/send_plan.cpp b/src/gpu/intel/jit/ir/send_plan.cpp index e3cd7390158..60fe7ba8cd4 100644 --- a/src/gpu/intel/jit/ir/send_plan.cpp +++ b/src/gpu/intel/jit/ir/send_plan.cpp @@ -660,22 +660,25 @@ struct send_2d_params_t { bool is_store() const { return send_op == send_op_t::store; } - int max_count() const { - return block_2d_max_count(is_store(), transpose, w, type.size()); + bool is_prefetch() const { return send_op == send_op_t::prefetch; } + + int max_count(const dsl::hw_t &hw) const { + return block_2d_max_count(hw.ngen_hw(), is_prefetch(), is_store(), + transpose, w, type.size()); } // Reduce the number of messages by increasing count per // message. 
- void try_promote_count() { + void try_promote_count(const dsl::hw_t &hw) { if (vnni_factor != 1) return; - while (c * 2 <= max_count()) { + while (c * 2 <= max_count(hw)) { if (w_rcount % 2 != 0) break; c *= 2; w_rcount /= 2; } } - bool apply_vnni_factor(int factor) { + bool apply_vnni_factor(int factor, const dsl::hw_t &hw) { if (factor == 0) return true; if (use_xy) return fail_2d( @@ -690,9 +693,9 @@ struct send_2d_params_t { if (H % factor != 0) return fail_2d("Can't apply VNNI factor: invalid surface height."); if (c != 1) return fail_2d("Can't apply VNNI factor: invalid count."); - if (factor > max_count()) + if (factor > max_count(hw)) return fail_2d( - "Can't apply VNNI factor: factor exceeds max_count()."); + "Can't apply VNNI factor: factor exceeds max_count(hw)."); W *= factor; H /= factor; P *= factor; @@ -1756,14 +1759,15 @@ class send_2d_helper_t { params_.h_tidx = h_tidx; params_.h_vstride = into(h_vstride); - if (!params_.apply_vnni_factor(hint.vnni_permute_factor)) return false; + if (!params_.apply_vnni_factor(hint.vnni_permute_factor, info_.hw())) + return false; if (!params_.is_supported(info_.hw())) return false; if (!base_alignment_ok(vlayout, mod_info, h_tdim, h_vstride)) return false; if (!x_alignment_ok(w_tdim, mod_info)) return false; if (!masks_ok()) return false; - params_.try_promote_count(); + params_.try_promote_count(info_.hw()); params_.is_valid = true; return true; } diff --git a/src/gpu/intel/jit/ir/v2/send.hpp b/src/gpu/intel/jit/ir/v2/send.hpp index 0e0dfe6c27d..60801b7a29c 100644 --- a/src/gpu/intel/jit/ir/v2/send.hpp +++ b/src/gpu/intel/jit/ir/v2/send.hpp @@ -511,7 +511,7 @@ struct send_2d_desc_t { // Reduce the number of messages by increasing count per // message. 
void try_promote_count() { - int max_count = block_2d_max_count( + int max_count = block_2d_max_count(hw, op == send_op_t::prefetch, op == send_op_t::store, transpose, w, type.size()); while (c * 2 <= max_count) { if (w_rcount % 2 != 0) break; diff --git a/src/gpu/intel/jit/post_op_injector.cpp b/src/gpu/intel/jit/post_op_injector.cpp index af9fbab9f7e..aaf6d9f8cd9 100644 --- a/src/gpu/intel/jit/post_op_injector.cpp +++ b/src/gpu/intel/jit/post_op_injector.cpp @@ -69,6 +69,9 @@ REG_XEHPG_ISA(template struct post_op_injector_t>); REG_XEHPC_ISA(template struct post_op_injector_t>); REG_XE2_ISA(template struct post_op_injector_t>); REG_XE3_ISA(template struct post_op_injector_t>); +REG_XE3P_ISA(template struct post_op_injector_t>); +REG_XE3P_ISA(template struct post_op_injector_t>); +REG_XE3P_ISA(template struct post_op_injector_t>); #ifdef NGEN_ASM template struct post_op_injector_t; diff --git a/src/gpu/intel/jit/reduction_injector.cpp b/src/gpu/intel/jit/reduction_injector.cpp index be2f121f28e..3c20c814ad3 100644 --- a/src/gpu/intel/jit/reduction_injector.cpp +++ b/src/gpu/intel/jit/reduction_injector.cpp @@ -276,6 +276,12 @@ REG_XEHPG_ISA(template struct reduction_injector_f32_t>); REG_XEHPC_ISA(template struct reduction_injector_f32_t>); REG_XE2_ISA(template struct reduction_injector_f32_t>); REG_XE3_ISA(template struct reduction_injector_f32_t>); +REG_XE3P_ISA( + template struct reduction_injector_f32_t>); +REG_XE3P_ISA( + template struct reduction_injector_f32_t>); +REG_XE3P_ISA(template struct reduction_injector_f32_t< + code_gen>); #ifdef NGEN_ASM template struct reduction_injector_f32_t; diff --git a/src/gpu/intel/jit/utils/type_bridge.hpp b/src/gpu/intel/jit/utils/type_bridge.hpp index c7ecb733111..35d8093e366 100644 --- a/src/gpu/intel/jit/utils/type_bridge.hpp +++ b/src/gpu/intel/jit/utils/type_bridge.hpp @@ -82,6 +82,10 @@ inline ngen::HW convert_dnnl_arch_to_ngen(compute::gpu_arch_t gpu_arch) { case compute::gpu_arch_t::xe_hpc: return ngen::HW::XeHPC; 
case compute::gpu_arch_t::xe2: return ngen::HW::Xe2; case compute::gpu_arch_t::xe3: return ngen::HW::Xe3; + case compute::gpu_arch_t::xe3p_35_10: return ngen::HW::XE3P_35_10; + case compute::gpu_arch_t::xe3p_35_11: return ngen::HW::XE3P_35_11; + case compute::gpu_arch_t::xe3p_35_unknown: + return ngen::HW::XE3P_UNKNOWN; case compute::gpu_arch_t::unknown: return ngen::HW::Unknown; } return ngen::HW::Unknown; @@ -95,6 +99,10 @@ inline compute::gpu_arch_t convert_ngen_arch_to_dnnl(ngen::HW gpu_arch) { case ngen::HW::XeHPC: return compute::gpu_arch_t::xe_hpc; case ngen::HW::Xe2: return compute::gpu_arch_t::xe2; case ngen::HW::Xe3: return compute::gpu_arch_t::xe3; + case ngen::HW::XE3P_35_10: return compute::gpu_arch_t::xe3p_35_10; + case ngen::HW::XE3P_35_11: return compute::gpu_arch_t::xe3p_35_11; + case ngen::HW::XE3P_UNKNOWN: + return compute::gpu_arch_t::xe3p_35_unknown; case ngen::HW::Gen9: case ngen::HW::Gen10: case ngen::HW::Gen11: diff --git a/src/gpu/intel/jit/utils/utils.hpp b/src/gpu/intel/jit/utils/utils.hpp index 0763a1f7be2..cabb54ee301 100644 --- a/src/gpu/intel/jit/utils/utils.hpp +++ b/src/gpu/intel/jit/utils/utils.hpp @@ -708,6 +708,9 @@ static auto hw_names = nstl::to_array({ make_enum_name(ngen::Core::XeHPC, "xehpc"), make_enum_name(ngen::Core::Xe2, "xe2"), make_enum_name(ngen::Core::Xe3, "xe3"), + make_enum_name(ngen::Core::XE3P_35_10, "xe3p_35_10"), + make_enum_name(ngen::Core::XE3P_35_11, "xe3p_35_11"), + make_enum_name(ngen::Core::XE3P_UNKNOWN, "xe3p_35_unknown"), }); GPU_DEFINE_PARSE_ENUM(ngen::HW, hw_names) @@ -726,6 +729,9 @@ static auto product_family_names = nstl::to_array({ make_enum_name(ngen::ProductFamily::PVC, "pvc"), make_enum_name(ngen::ProductFamily::GenericXe2, "xe2"), make_enum_name(ngen::ProductFamily::GenericXe3, "xe3"), + make_enum_name(ngen::ProductFamily::XE3P_35_10, "xe3p_35_10"), + make_enum_name(ngen::ProductFamily::XE3P_35_11, "xe3p_35_11"), + make_enum_name(ngen::ProductFamily::XE3P_UNKNOWN, "xe3p_35_unknown"), }); 
GPU_DEFINE_PARSE_ENUM(ngen::ProductFamily, product_family_names) @@ -1227,7 +1233,6 @@ void deserialize_from_hex(T &t, const std::string &s_hex) { GPU_HW_CASE(ngen::HW::hw); \ break; \ } - #define GPU_HW_SWITCH(hw) \ switch (hw) { \ REG_XELP_ISA(GPU_HW_CASE_(XeLP)); \ @@ -1236,6 +1241,9 @@ void deserialize_from_hex(T &t, const std::string &s_hex) { REG_XEHPC_ISA(GPU_HW_CASE_(XeHPC)); \ REG_XE2_ISA(GPU_HW_CASE_(Xe2)); \ REG_XE3_ISA(GPU_HW_CASE_(Xe3)); \ + REG_XE3P_ISA(GPU_HW_CASE_(XE3P_35_10)); \ + REG_XE3P_ISA(GPU_HW_CASE_(XE3P_35_11)); \ + REG_XE3P_ISA(GPU_HW_CASE_(XE3P_UNKNOWN)); \ default: gpu_assert(false) << "Unexpected GPU architecture"; \ } diff --git a/src/gpu/intel/matmul/gemm.cpp b/src/gpu/intel/matmul/gemm.cpp index aeb26a1b7ab..1139e9217cd 100644 --- a/src/gpu/intel/matmul/gemm.cpp +++ b/src/gpu/intel/matmul/gemm.cpp @@ -66,10 +66,12 @@ status_t gemm_t::execute(const exec_ctx_t &ctx) const { DNNL_ARG_ATTR_PRECOMPUTED_REDUCTIONS | DNNL_ARG_WEIGHTS); args.b_group_sums = &CTX_IN_STORAGE( DNNL_ARG_ATTR_PRECOMPUTED_REDUCTIONS | DNNL_ARG_SRC); + args.dropout_offset = &CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_OFFSET); args.dropout_seed = &CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_SEED); args.dropout_prob = &CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_PROBABILITY); args.dropout_mask = &CTX_OUT_STORAGE(DNNL_ARG_ATTR_DROPOUT_MASK); + args.sround_seed = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ROUNDING_SEED); args.exec_args = ctx.args(); gemm::desc_t desc; diff --git a/src/gpu/intel/matmul/ref.cl b/src/gpu/intel/matmul/ref.cl index 4709e5551e4..4bf9c76a26c 100644 --- a/src/gpu/intel/matmul/ref.cl +++ b/src/gpu/intel/matmul/ref.cl @@ -312,6 +312,7 @@ __kernel void ref_matmul(__global SRC_DATA_T *A, __global WEI_DATA_T *B, #if WITH_DST_SCALES #if DST_SCALES_MASK == 0 po_acc /= DST_SCALES_TO_REF(dst_scales[0]); + #elif WITH_DYN_DST_SCALE == 0 po_acc /= DST_SCALES_TO_REF(dst_scales[n]); #endif @@ -324,6 +325,7 @@ __kernel void ref_matmul(__global SRC_DATA_T *A, __global WEI_DATA_T *B, 
C[dst_off] = TO_DST(po_acc); #endif #else // WITH_BIAS || NON_DEFAULT_ATTRS + #if WITH_DYN_DST_SCALE ((__global ACC_DATA_T *)C)[dst_off] = acc; #else diff --git a/src/gpu/intel/matmul/ref.cpp b/src/gpu/intel/matmul/ref.cpp index 0555dbd2e61..19cddc98d5f 100644 --- a/src/gpu/intel/matmul/ref.cpp +++ b/src/gpu/intel/matmul/ref.cpp @@ -39,6 +39,7 @@ status_t ref_t::execute_ref(const exec_ctx_t &ctx) const { auto &dst_scales = (dyn_scales ? CTX_OUT_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST) : CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST)); + const auto &a0 = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC); const auto &b0 = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS); diff --git a/src/gpu/intel/matmul/ref.hpp b/src/gpu/intel/matmul/ref.hpp index 48e14328cbe..4d82222c025 100644 --- a/src/gpu/intel/matmul/ref.hpp +++ b/src/gpu/intel/matmul/ref.hpp @@ -86,7 +86,6 @@ struct ref_t : public primitive_t { && utils::one_of(wei_dt_, f16, s8, u8, s4, u4); const bool is_bf16 = src_dt_ == bf16 && utils::one_of(wei_dt_, bf16, s8, u8, s4, u4); - const bool is_f8 = utils::one_of(src_dt_, f8_e5m2, f8_e4m3) || utils::one_of(wei_dt_, f8_e5m2, f8_e4m3); const bool is_f4 diff --git a/src/gpu/intel/ocl/device_info.cpp b/src/gpu/intel/ocl/device_info.cpp index e55be6d92f5..a2464c918d3 100644 --- a/src/gpu/intel/ocl/device_info.cpp +++ b/src/gpu/intel/ocl/device_info.cpp @@ -44,7 +44,7 @@ status_t device_info_t::init_arch(impl::engine_t *engine) { CHECK(init_gpu_hw_info(engine, device, context, ip_version_, gpu_arch_, gpu_product_, native_extensions_, mayiuse_systolic_, - mayiuse_ngen_kernels_)); + mayiuse_ngen_kernels_, is_efficient_64bit_)); err = xpu::ocl::clReleaseContext(context); OCL_CHECK(err); diff --git a/src/gpu/intel/ocl/hw_info.cpp b/src/gpu/intel/ocl/hw_info.cpp index 7872e1b2abb..1c7084f291a 100644 --- a/src/gpu/intel/ocl/hw_info.cpp +++ b/src/gpu/intel/ocl/hw_info.cpp @@ -59,10 +59,12 @@ xpu::runtime_version_t get_driver_version(cl_device_id 
device) { status_t init_gpu_hw_info(impl::engine_t *engine, cl_device_id device, cl_context ctx, uint32_t &ip_version, compute::gpu_arch_t &gpu_arch, compute::gpu_product_t &product_, uint64_t &native_extensions, - bool &mayiuse_systolic, bool &mayiuse_ngen_kernels) { + bool &mayiuse_systolic, bool &mayiuse_ngen_kernels, + bool &is_efficient_64bit) { using namespace ngen; ngen::Product product = ngen::OpenCLCodeGenerator::detectHWInfo(ctx, device); + HW hw = getCore(product.family); bool is_xelpg = (product.family == ngen::ProductFamily::ARL || product.family == ngen::ProductFamily::MTL); @@ -74,13 +76,20 @@ status_t init_gpu_hw_info(impl::engine_t *engine, cl_device_id device, CHECK(get_ocl_device_enabled_native_float_atomics( device, native_extensions, is_xelpg)); - auto status - = jit::gpu_supports_binary_format(&mayiuse_ngen_kernels, engine); - if (status != status::success) { - VWARN(common, runtime, - "ngen fallback (gpu does not support binary format kernels)"); - mayiuse_ngen_kernels = false; - } + if (hw <= ngen::HW::Xe3) { + auto status = jit::gpu_supports_binary_format( + &mayiuse_ngen_kernels, engine); + if (status != status::success) { + VWARN(common, runtime, + "ngen fallback (gpu does not support binary format " + "kernels)"); + mayiuse_ngen_kernels = false; + } + } else if (hw != ngen::HW::Unknown) + mayiuse_ngen_kernels = true; + + is_efficient_64bit = OpenCLCodeGenerator::detectEfficient64Bit( + ctx, device, hw); ip_version = 0; OCL_CHECK(xpu::ocl::clGetDeviceInfo(device, CL_DEVICE_IP_VERSION_INTEL, diff --git a/src/gpu/intel/ocl/hw_info.hpp b/src/gpu/intel/ocl/hw_info.hpp index db6ce129886..0b72c84804c 100644 --- a/src/gpu/intel/ocl/hw_info.hpp +++ b/src/gpu/intel/ocl/hw_info.hpp @@ -33,7 +33,8 @@ xpu::runtime_version_t get_driver_version(cl_device_id device); status_t init_gpu_hw_info(impl::engine_t *engine, cl_device_id device, cl_context ctx, uint32_t &ip_version, compute::gpu_arch_t &gpu_arch, compute::gpu_product_t &product, uint64_t 
&native_extensions, - bool &mayiuse_systolic, bool &mayiuse_ngen_kernels); + bool &mayiuse_systolic, bool &mayiuse_ngen_kernels, + bool &is_efficient_64bit_); } // namespace ocl } // namespace intel diff --git a/src/gpu/intel/ocl/utils.cpp b/src/gpu/intel/ocl/utils.cpp index 4360dbb3dab..5460334e45e 100644 --- a/src/gpu/intel/ocl/utils.cpp +++ b/src/gpu/intel/ocl/utils.cpp @@ -219,6 +219,70 @@ status_t get_ocl_kernel_binary(cl_kernel ocl_kernel, xpu::binary_t &binary) { return status::success; } +void debugdump_processed_source(const std::string &source, + const std::string &options, const std::string &cl_options) { +#if defined(__linux__) && defined(DNNL_DEV_MODE) + if (get_verbose(verbose_t::debuginfo) >= 10) { + auto get_defines = [](const std::string &from) { + std::string ret; + size_t pos = 0; + while (pos < from.length()) { + // Find next define argument + pos = from.find("-D", pos); + + // Generate argument, quotes are interpreted literally, but + // other special shell characters need escaped. 
Does not + // currently handle quotes with the ' character or nested quotes + char quote_parity = true; + while (pos < from.length()) { + if (quote_parity + && utils::one_of(from[pos], '~', '#', '$', '&', '*', + '(', ')', '\\', '|', '[', ']', '{', '}', + ';', '\'', '<', '>', '/', '?', '!')) { + ret += '\\'; + } + ret += from[pos]; + if (from[pos] == '"') quote_parity ^= true; + if (from[pos] == ' ' && quote_parity) break; + + pos++; + } + } + return ret; + }; + auto execute_command + = [](const std::string &cmd, const std::string &stdin) { + std::string result; + std::array buffer; + FILE *pipe = popen(cmd.c_str(), "w"); + fputs(stdin.c_str(), pipe); + if (pipe) { + while (fgets(buffer.data(), buffer.size(), pipe) != nullptr) { + result += buffer.data(); + } + } + pclose(pipe); + return result; + }; + + // Run utilities to evaluate preprocessor defines and format the file + // Theoretically, we can accomplish this task with libclang, but it + // seems more work than it is worth. Instead, wrapping this in OCL_DEBUG + // so that calls to the system are not included in the default build. 
+ + // Due to the use of a different C preprocessor, warnings should not be + // ignored, as they may correspond to a different behavior in the OpenCL + // C preprocessor + auto o = get_defines(options) + get_defines(cl_options); + std::string preprocess_cmd + = std::string() + "cpp -P " + o + " | clang-format"; + execute_command(preprocess_cmd, source); + std::cout << "OCL_ARCH_OPTIONS: " << cl_options << std::endl; + std::cout << "OCL_OPTIONS: " << options << std::endl; + } +#endif +} + status_t get_kernel_arg_types(cl_kernel ocl_kernel, std::vector *arg_types) { cl_uint nargs; diff --git a/src/gpu/intel/primitive_conf.cpp b/src/gpu/intel/primitive_conf.cpp index fd5c599f707..e1836896496 100644 --- a/src/gpu/intel/primitive_conf.cpp +++ b/src/gpu/intel/primitive_conf.cpp @@ -713,6 +713,7 @@ status_t def_attr_info_impl(compute::kernel_ctx_t &kernel_ctx, kernel_ctx.define_int("WITH_HOST_WEI_SCALE", attr_info.with_host_wei_scale); kernel_ctx.define_int("WITH_HOST_DST_SCALE", attr_info.with_host_dst_scale); kernel_ctx.define_int("WITH_DYN_DST_SCALE", attr_info.with_dyn_dst_scale); + kernel_ctx.define_int("WITH_MX_DST_SCALE", attr_info.with_dyn_dst_scale); return def_post_ops_cfg(kernel_ctx, post_ops, dst_md); } diff --git a/src/gpu/intel/primitive_conf.hpp b/src/gpu/intel/primitive_conf.hpp index 0b97a32f067..4ee2a7cd810 100644 --- a/src/gpu/intel/primitive_conf.hpp +++ b/src/gpu/intel/primitive_conf.hpp @@ -104,6 +104,7 @@ struct attr_info_t { bool with_host_wei_scale; bool with_host_dst_scale; bool with_dyn_dst_scale; + bool with_host_src_zp; bool with_host_wei_zp; bool with_host_dst_zp; diff --git a/src/gpu/intel/sycl/device_info.cpp b/src/gpu/intel/sycl/device_info.cpp index 7fd821e8f9c..e1e1a2f888a 100644 --- a/src/gpu/intel/sycl/device_info.cpp +++ b/src/gpu/intel/sycl/device_info.cpp @@ -49,14 +49,14 @@ status_t device_info_t::init_arch(impl::engine_t *engine) { status = gpu::intel::ocl::init_gpu_hw_info(engine, ocl_dev, ocl_ctx, ip_version_, gpu_arch_, 
gpu_product_, native_extensions_, - mayiuse_systolic_, mayiuse_ngen_kernels_); + mayiuse_systolic_, mayiuse_ngen_kernels_, is_efficient_64bit_); } else if (be == xpu::sycl::backend_t::ze) { auto ze_dev = xpu::sycl::compat::get_native(device); auto ze_ctx = xpu::sycl::compat::get_native(ctx); status = gpu::intel::ze::init_gpu_hw_info(engine, ze_dev, ze_ctx, ip_version_, gpu_arch_, gpu_product_, native_extensions_, - mayiuse_systolic_, mayiuse_ngen_kernels_); + mayiuse_systolic_, mayiuse_ngen_kernels_, is_efficient_64bit_); } else { assert(!"not_expected"); status = status::unimplemented; @@ -145,6 +145,7 @@ status_t device_info_t::init_attributes(impl::engine_t *engine) { = device.get_info<::sycl::info::device::address_bits>(); max_allocation_size_ = device.get_info<::sycl::info::device::max_mem_alloc_size>(); + return status::success; } diff --git a/src/gpu/intel/ze/device_info.cpp b/src/gpu/intel/ze/device_info.cpp index b046b8a555f..7525e23c0a3 100644 --- a/src/gpu/intel/ze/device_info.cpp +++ b/src/gpu/intel/ze/device_info.cpp @@ -44,7 +44,7 @@ status_t device_info_t::init_arch(impl::engine_t *engine) { return init_gpu_hw_info(engine, device, context, ip_version_, gpu_arch_, gpu_product_, native_extensions_, mayiuse_systolic_, - mayiuse_ngen_kernels_); + mayiuse_ngen_kernels_, is_efficient_64bit_); } status_t device_info_t::init_runtime_version(impl::engine_t *engine) { diff --git a/src/gpu/intel/ze/utils.cpp b/src/gpu/intel/ze/utils.cpp index fd07de9b6f0..2aac307b584 100644 --- a/src/gpu/intel/ze/utils.cpp +++ b/src/gpu/intel/ze/utils.cpp @@ -161,7 +161,7 @@ status_t init_gpu_hw_info(impl::engine_t *engine, ze_device_handle_t device, ze_context_handle_t context, uint32_t &ip_version, compute::gpu_arch_t &gpu_arch, compute::gpu_product_t &product_, uint64_t &native_extensions, bool &mayiuse_systolic, - bool &mayiuse_ngen_kernels) { + bool &mayiuse_ngen_kernels, bool &is_efficient_64bit) { using namespace ngen; ngen::Product product = 
LevelZeroCodeGenerator::detectHWInfo( context, device); @@ -186,6 +186,9 @@ status_t init_gpu_hw_info(impl::engine_t *engine, ze_device_handle_t device, CHECK(get_ze_device_enabled_native_float_atomics( device, native_extensions)); + is_efficient_64bit + = LevelZeroCodeGenerator::detectEfficient64Bit( + context, device, getCore(product.family)); auto status = jit::gpu_supports_binary_format(&mayiuse_ngen_kernels, engine); if (status != status::success) mayiuse_ngen_kernels = false; diff --git a/src/gpu/intel/ze/utils.hpp b/src/gpu/intel/ze/utils.hpp index dbffc3cddf4..59e14a89891 100644 --- a/src/gpu/intel/ze/utils.hpp +++ b/src/gpu/intel/ze/utils.hpp @@ -31,7 +31,7 @@ status_t init_gpu_hw_info(impl::engine_t *engine, ze_device_handle_t device, ze_context_handle_t context, uint32_t &ip_version, compute::gpu_arch_t &gpu_arch, compute::gpu_product_t &product, uint64_t &native_extensions, bool &mayiuse_systolic, - bool &mayiuse_ngen_kernels); + bool &mayiuse_ngen_kernels, bool &is_efficient_64bit); status_t get_module_binary( ze_module_handle_t module_handle, xpu::binary_t &binary); diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp index 31081d5cc55..d1917b6f843 100644 --- a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp @@ -173,17 +173,20 @@ status_t sdp_primitive_config_t::initial_check( status::unimplemented, "Q, K, V are not found"); dims q_dims = ltw(inputs[q_id]).vdims(); - size_t q_ndims = q_dims.size(); - const dim_t seq_len_q = q_dims[q_ndims - 2]; - const dim_t head_size_qk = q_dims[q_ndims - 1]; + + const size_t q_ndims = q_dims.size(); + VCHECK_SDP_PRIMITIVE(q_ndims == 4, status::unimplemented, + "input ndims (Q) can only be == 4"); + const dim_t seq_len_q = q_dims[2]; + const dim_t head_size_qk = q_dims[3]; const bool thinq = seq_len_q < 16; const bool opt_prefill = head_size_qk <= 64 && !thinq; - 
VCHECK_SDP_PRIMITIVE(!is_f32 || (!has_genindex && !opt_prefill), - status::unimplemented, - "f32 fused sdpa supported for: causal mask or cases with " - "head_size(%d) <= 64, seq_len(%d) >= 16", - static_cast(head_size_qk), static_cast(seq_len_q)); + if (is_f32 && !(has_genindex || opt_prefill)) { + VCHECK_SDP_PRIMITIVE(false, status::unimplemented, + "f32 fused sdpa supported for: causal mask or cases with " + "head_size <= 64, seq_len >= 16"); + } // sdp_primitive only supports single scale value. if (scale) { diff --git a/tests/benchdnn/matmul/ref_matmul.cpp b/tests/benchdnn/matmul/ref_matmul.cpp index 3d97d7ce153..4f03e2baa90 100644 --- a/tests/benchdnn/matmul/ref_matmul.cpp +++ b/tests/benchdnn/matmul/ref_matmul.cpp @@ -192,7 +192,6 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) { prb, src_mb, m, gK * smallest_k_group + k); const auto wei_off = wei_ba_off_f( prb, wei_mb, gK * smallest_k_group + k, n); - auto s = src_scale * (src_m.get_f32_elem(src_off) - src_zp); auto w = wei_scale * (wei_m.get_f32_elem(wei_off) - wei_zp); diff --git a/tests/gtests/graph/api/test_c_api_compile.cpp b/tests/gtests/graph/api/test_c_api_compile.cpp index 8d02b17da33..d311dde87f0 100644 --- a/tests/gtests/graph/api/test_c_api_compile.cpp +++ b/tests/gtests/graph/api/test_c_api_compile.cpp @@ -1010,7 +1010,14 @@ TEST(CAPI, CompileSumConv2DStridedBN) { *(compiled_partition + 1), &num_inplace_pairs, &inplace_pairs), dnnl_success); +// Blocked layout is supported on intel gpu only, so here we will get one +// in-place pair on gpu of other vendors +#if DNNL_GPU_RUNTIME != DNNL_RUNTIME_NONE \ + && DNNL_GPU_VENDOR == DNNL_VENDOR_INTEL EXPECT_EQ(num_inplace_pairs, 0U); +#else + EXPECT_EQ(num_inplace_pairs, engine == dnnl_gpu ? 
1U : 0U); +#endif COMPILE_SUM_CONV2D_STRIDED_BN_DESTROY_PLUS; #undef COMPILE_SUM_CONV2D_STRIDED_BN_DESTROY diff --git a/third_party/ngen/ngen.hpp b/third_party/ngen/ngen.hpp index f92421587bf..92e575bb4e4 100644 --- a/third_party/ngen/ngen.hpp +++ b/third_party/ngen/ngen.hpp @@ -62,6 +62,9 @@ template struct Instruction12Dispatch { using type = Instruction12 template <> struct Instruction12Dispatch { using type = InstructionXeHPC; }; template <> struct Instruction12Dispatch { using type = InstructionXeHPC; }; template <> struct Instruction12Dispatch { using type = InstructionXeHPC; }; +template <> struct Instruction12Dispatch { using type = InstructionXe3p; }; +template <> struct Instruction12Dispatch { using type = InstructionXe3p; }; +template <> struct Instruction12Dispatch { using type = InstructionXe3p; }; // MSVC v140 workaround for enum comparison in template arguments. static constexpr bool hwLT(HW hw1, HW hw2) { return hw1 < hw2; } @@ -126,6 +129,7 @@ class BinaryCodeGenerator code.push_back(i.qword[1]); } + void addFixup(LabelFixup fixup) { fixup.anchor = length(); fixups.push_back(fixup); @@ -170,6 +174,7 @@ class BinaryCodeGenerator for (uint32_t id : other.labels) man.offsetTarget(id, offset); + other.appended = true; } @@ -204,6 +209,7 @@ class BinaryCodeGenerator private: InstructionModifier defaultModifier; + bool useEfficient64Bit = (hw >= HW::XE3P_35_10); LabelManager labelManager; InstructionStream rootStream; @@ -255,6 +261,8 @@ class BinaryCodeGenerator void opBfn(Opcode op, DataType defaultType, const InstructionModifier &mod, int bfnCtrl, D dst, S0 src0, RegData src1, S2 src2, SourceLocation loc); void opDpas(Opcode op, DataType defaultType, const InstructionModifier &mod, int sdepth, int rcount, RegData dst, RegData src0, RegData src1, RegData src2, SourceLocation loc); + void opBdpas(Opcode op, DataType defaultType, const InstructionModifier &mod, int sdepth, int rcount, RegData dst, RegData src0, RegData src1, RegData src2, RegData src3, 
RegData src4, SourceLocation loc); + template typename std::enable_if::type opSend(Opcode op, const InstructionModifier &mod, SharedFunction sfid, const RegData &dst, const RegData &src0, const RegData &src1, int src1Length, uint32_t exdesc, D desc, SourceLocation loc); template @@ -276,6 +284,8 @@ class BinaryCodeGenerator template typename std::enable_if::type opSends(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, RegData exdesc, D desc, SourceLocation loc); + void opSendg(Opcode op, const InstructionModifier &mod, SharedFunction sfid, const RegData &dst, RegData src0, int src0Len, const RegData &src1, int src1Len, RegData ind0, RegData ind1, uint64_t desc, SourceLocation loc); + template typename std::enable_if::type opBranch(Opcode op, const InstructionModifier &mod, const RegData &dst, int32_t jip, int32_t uip, SourceLocation loc); template @@ -300,6 +310,9 @@ class BinaryCodeGenerator typename std::enable_if::type opJmpi(Opcode op, const InstructionModifier &mod, const RegData &dst, RegData src0, uint32_t jip, SourceLocation loc); void opJmpi(Opcode op, const InstructionModifier &mod, const RegData &dst, const RegData &src0, Label &jip, SourceLocation loc); + void opShflLfsr(Opcode op, uint8_t fc, DataType defaultType, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc); + void opShflLfsr(Opcode op, uint8_t fc, DataType defaultType, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc); + void opSync(Opcode op, SyncFunction fc, const InstructionModifier &mod, SourceLocation loc); void opSync(Opcode op, SyncFunction fc, const InstructionModifier &mod, RegData src0, SourceLocation loc); void opSync(Opcode op, SyncFunction fc, const InstructionModifier &mod, const Immediate &src0, SourceLocation loc); @@ -309,6 +322,7 @@ class BinaryCodeGenerator template void opDirective(Directive 
directive, RegData src0, S1 src1, SourceLocation loc); + static constexpr14 InstructionModifier defaultMods() { return GRF::bytes(hw) >> 2; } @@ -321,6 +335,7 @@ class BinaryCodeGenerator explicit BinaryCodeGenerator(Product product_, DebugConfig debugConfig = {}) : product{product_}, debugLine(debugConfig), defaultModifier{}, labelManager{}, + lfsr{this}, shfl{this}, sync{this}, load{this}, store{this}, atomic{this} { _workaround_(); @@ -354,6 +369,9 @@ class BinaryCodeGenerator bool getDefaultNoMask() const { return defaultModifier.isWrEn(); } bool getDefaultAutoSWSB() const { return defaultModifier.isAutoSWSB(); } + void setEfficient64Bit(bool def = true) { useEfficient64Bit = def; } + bool getEfficient64Bit() const { return useEfficient64Bit; } + // Stream handling. void pushStream() { pushStream(new InstructionStream()); } void pushStream(InstructionStream *s) { streamStack.push_back(s); } @@ -747,23 +765,36 @@ class BinaryCodeGenerator } template void mac(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { +#ifdef NGEN_SAFE + if (hardware >= HW::XE3P_35_10) unsupported(); +#endif opX(Opcode::mac, getDataType
(), mod, dst, src0, src1, loc); } template void mac(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { +#ifdef NGEN_SAFE + if (hardware >= HW::XE3P_35_10) unsupported(); +#endif opX(Opcode::mac, getDataType
(), mod, dst, src0, src1, loc); } template void mach(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { +#ifdef NGEN_SAFE + if (hardware >= HW::XE3P_35_10) unsupported(); +#endif opX(Opcode::mach, getDataType
(), (hw >= HW::XeHPC) ? mod : (mod | AccWrEn), dst, src0, src1, loc); } template void mach(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { +#ifdef NGEN_SAFE + if (hardware >= HW::XE3P_35_10) unsupported(); +#endif opX(Opcode::mach, getDataType
(), (hw >= HW::XeHPC) ? mod : (mod | AccWrEn), dst, src0, src1, loc); } template void macl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { #ifdef NGEN_SAFE + if (hardware >= HW::XE3P_35_10) unsupported(); if (hw < HW::Gen10) unsupported(); #endif opX((hw >= HW::XeHPC) ? Opcode::macl : Opcode::mach, getDataType
(), mod, dst, src0, src1, loc); @@ -771,6 +802,7 @@ class BinaryCodeGenerator template void macl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { #ifdef NGEN_SAFE + if (hardware >= HW::XE3P_35_10) unsupported(); if (hw < HW::Gen10) unsupported(); #endif opX((hw >= HW::XeHPC) ? Opcode::macl : Opcode::mach, getDataType
(), mod, dst, src0, src1, loc); @@ -926,6 +958,14 @@ class BinaryCodeGenerator src1 = src1.forceInt32(); opX(Opcode::mul, getDataType
(), mod, dst, src0, src1, loc); } + template + void mullh(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + opX(Opcode::mullh, getDataType
(), mod, dst, src0, src1, loc); + } + template + void mullh(const InstructionModifier &mod, const RegData &dst, const RegData &src0, Immediate src1, SourceLocation loc = {}) { + opX(Opcode::mullh, getDataType
(), mod, dst, src0, src1, loc); + } void nop(SourceLocation loc = {}) { opNop(isGen12 ? Opcode::nop_gen12 : Opcode::nop, loc); } @@ -1115,6 +1155,114 @@ class BinaryCodeGenerator void sendsc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, const RegData &desc, SourceLocation loc = {}) { opSends(Opcode::sendsc, mod, dst, src0, src1, exdesc, desc, loc); } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, NullRegister(), NullRegister(), desc, loc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, ind0, NullRegister(), desc, loc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, ind0, ind1, desc, loc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), NullRegister(), NullRegister(), desc, loc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), ind0, NullRegister(), desc, loc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const 
RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), ind0, ind1, desc, loc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0Len, NullRegister(), 0, NullRegister(), NullRegister(), desc, loc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0Len, NullRegister(), 0, ind0, NullRegister(), desc, loc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0Len, NullRegister(), 0, ind0, ind1, desc, loc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, NullRegister(), NullRegister(), desc, loc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, ind0, NullRegister(), desc, loc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, ind0, ind1, 
desc, loc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), NullRegister(), NullRegister(), desc, loc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), ind0, NullRegister(), desc, loc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), ind0, ind1, desc, loc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0Len, NullRegister(), 0, NullRegister(), NullRegister(), desc, loc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0Len, NullRegister(), 0, ind0, NullRegister(), desc, loc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0Len, NullRegister(), 0, ind0, ind1, desc, loc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, uint64_t desc, 
SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, NullRegister(), NullRegister(), desc, loc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, ind0, NullRegister(), desc, loc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, ind0, ind1, desc, loc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), NullRegister(), NullRegister(), desc, loc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), ind0, NullRegister(), desc, loc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), ind0, ind1, desc, loc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0Len, NullRegister(), 0, NullRegister(), NullRegister(), desc, loc); + 
} + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0Len, NullRegister(), 0, ind0, NullRegister(), desc, loc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0Len, NullRegister(), 0, ind0, ind1, desc, loc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, NullRegister(), NullRegister(), desc, loc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, ind0, NullRegister(), desc, loc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0.getLen(), NullRegister(), 0, ind0, ind1, desc, loc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), NullRegister(), NullRegister(), desc, loc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) 
{ + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), ind0, NullRegister(), desc, loc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0.getLen(), src1, src1.getLen(), ind0, ind1, desc, loc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0Len, NullRegister(), 0, NullRegister(), NullRegister(), desc, loc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0Len, NullRegister(), 0, ind0, NullRegister(), desc, loc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0Len, NullRegister(), 0, ind0, ind1, desc, loc); + } template void shl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { opX(isGen12 ? Opcode::shl_gen12 : Opcode::shl, getDataType
(), mod, dst, src0, src1, loc); @@ -1169,6 +1317,78 @@ class BinaryCodeGenerator opX(isGen12 ? Opcode::xor_gen12 : Opcode::xor_, getDataType
(), mod, dst, src0, src1, loc); } + template + void bdpas(const InstructionModifier &mod, uint8_t sdepth, uint8_t rcount, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2, const RegData &src3, const RegData &src4, SourceLocation loc = {}) { + auto emod = mod | defaultModifier; + if (emod.isAutoSWSB()) { + if (!src3.isARF()) wrdep(GRF(src3.getBase()), loc); + if (!src4.isARF()) wrdep(GRF(src4.getBase()), loc); + } + opBdpas(Opcode::bdpas, getDataType
(), mod, sdepth, rcount, dst, src0, src1, src2, src3, src4, loc); + } + template + void dnscl(const InstructionModifier &mod, uint8_t mode, RoundingType rnd, RegData dst, RegData src0, RegData src1, const RegData &src2, SourceLocation loc = {}) { + auto ctrl = encodeDnsclCtrl(mode, rnd, dst, src0, src1); + opBfn(Opcode::dnscl, getDataType
(), mod, ctrl, dst, src0, src1, src2, loc); + } + +private: + struct LFSR { + BinaryCodeGenerator &parent; + + LFSR(BinaryCodeGenerator *parent_) : parent(*parent_) {} + + void operator()(LFSRFunction fc, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + parent.opShflLfsr(Opcode::lfsr, static_cast(fc), DataType::invalid, mod, dst, src0, src1, loc); + } + + template + void b32(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + parent.opShflLfsr(Opcode::lfsr, static_cast(LFSRFunction::b32), getDataType
(), mod, dst, src0, src1, loc); + } + template + void b32(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { + parent.opShflLfsr(Opcode::lfsr, static_cast(LFSRFunction::b32), getDataType
(), mod, dst, src0, src1, loc); + } + template + void b16v2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + parent.opShflLfsr(Opcode::lfsr, static_cast(LFSRFunction::b16v2), getDataType
(), mod, dst, src0, src1, loc); + } + template + void b16v2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { + parent.opShflLfsr(Opcode::lfsr, static_cast(LFSRFunction::b16v2), getDataType
(), mod, dst, src0, src1, loc); + } + template + void b8v4(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + parent.opShflLfsr(Opcode::lfsr, static_cast(LFSRFunction::b8v4), getDataType
(), mod, dst, src0, src1, loc); + } + template + void b8v4(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { + parent.opShflLfsr(Opcode::lfsr, static_cast(LFSRFunction::b8v4), getDataType
(), mod, dst, src0, src1, loc); + } + }; +public: + LFSR lfsr; + +private: + struct Shfl { + BinaryCodeGenerator &parent; + + Shfl(BinaryCodeGenerator *parent_) : parent(*parent_) {} + + void operator()(ShuffleFunction fc, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + parent.opShflLfsr(Opcode::shfl, static_cast(fc), DataType::invalid, mod, dst, src0, src1, loc); + } + + template + void idx4(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + parent.opShflLfsr(Opcode::shfl, static_cast(ShuffleFunction::idx4), getDataType
(), mod, dst, src0, src1, loc); + } + }; +public: + Shfl shfl; + + private: struct Sync { BinaryCodeGenerator &parent; @@ -1488,8 +1708,25 @@ NGEN_FORWARD_SCOPE_OP_NAMES(scope) \ NGEN_FORWARD_SCOPE_MIN_MAX(scope) \ NGEN_FORWARD_SCOPE_REGISTERS(scope) -#define NGEN_FORWARD_SCOPE_EXTRA1(scope) -#define NGEN_FORWARD_SCOPE_EXTRA_ELF_OVERRIDES(hw) +#define NGEN_FORWARD_SCOPE_EXTRA1(scope) \ + NGEN_FORWARD_SCOPE_DT_OP(bdpas, scope) \ + NGEN_FORWARD_SCOPE_DT_OP(dnscl, scope) \ + NGEN_FORWARD_SCOPE_DT_OP(mullh, scope) \ + using scope::lfsr; \ + using scope::shfl; \ + NGEN_FORWARD_SCOPE_OP(sendg, scope) \ + NGEN_FORWARD_SCOPE_OP(sendgc, scope) \ + NGEN_FORWARD_SCOPE_OP(sendgx, scope) \ + NGEN_FORWARD_SCOPE_OP(sendgxc, scope) \ + NGEN_FORWARD_SCOPE_DT_OP(sigm, scope) \ + bool getEfficient64Bit() { return scope::getEfficient64Bit(); } \ + void setEfficient64Bit(bool def = true) {return scope::setEfficient64Bit(def);} + +#define NGEN_FORWARD_EXTRA_ELF_OVERRIDES(hw) \ + template void setEfficient64Bit(Targs &&...args) { \ + NGEN_NAMESPACE::BinaryCodeGenerator::setEfficient64Bit( \ + std::forward(args)...); \ + } #define NGEN_FORWARD_SCOPE_EXTRA2(scope) @@ -1641,7 +1878,81 @@ using scope::L1S_L3UC; using scope::L1S_L3C; using scope::L1IAR_L3C; using scope using scope::L1WT_L3UC; using scope::L1WT_L3WB; using scope::L1S_L3WB; using scope::L1WB_L3WB; \ using scope::L1C_L3CC; using scope::L1UC_L3CC; \ using scope::s0; -#define NGEN_FORWARD_SCOPE_REGISTERS_EXTRA1(scope) +#define NGEN_FORWARD_SCOPE_REGISTERS_EXTRA1(scope) \ +using scope::Fwd; \ +using scope::r256; using scope::r257; using scope::r258; using scope::r259; \ +using scope::r260; using scope::r261; using scope::r262; using scope::r263; \ +using scope::r264; using scope::r265; using scope::r266; using scope::r267; \ +using scope::r268; using scope::r269; using scope::r270; using scope::r271; \ +using scope::r272; using scope::r273; using scope::r274; using scope::r275; \ +using scope::r276; using scope::r277; using 
scope::r278; using scope::r279; \ +using scope::r280; using scope::r281; using scope::r282; using scope::r283; \ +using scope::r284; using scope::r285; using scope::r286; using scope::r287; \ +using scope::r288; using scope::r289; using scope::r290; using scope::r291; \ +using scope::r292; using scope::r293; using scope::r294; using scope::r295; \ +using scope::r296; using scope::r297; using scope::r298; using scope::r299; \ +using scope::r300; using scope::r301; using scope::r302; using scope::r303; \ +using scope::r304; using scope::r305; using scope::r306; using scope::r307; \ +using scope::r308; using scope::r309; using scope::r310; using scope::r311; \ +using scope::r312; using scope::r313; using scope::r314; using scope::r315; \ +using scope::r316; using scope::r317; using scope::r318; using scope::r319; \ +using scope::r320; using scope::r321; using scope::r322; using scope::r323; \ +using scope::r324; using scope::r325; using scope::r326; using scope::r327; \ +using scope::r328; using scope::r329; using scope::r330; using scope::r331; \ +using scope::r332; using scope::r333; using scope::r334; using scope::r335; \ +using scope::r336; using scope::r337; using scope::r338; using scope::r339; \ +using scope::r340; using scope::r341; using scope::r342; using scope::r343; \ +using scope::r344; using scope::r345; using scope::r346; using scope::r347; \ +using scope::r348; using scope::r349; using scope::r350; using scope::r351; \ +using scope::r352; using scope::r353; using scope::r354; using scope::r355; \ +using scope::r356; using scope::r357; using scope::r358; using scope::r359; \ +using scope::r360; using scope::r361; using scope::r362; using scope::r363; \ +using scope::r364; using scope::r365; using scope::r366; using scope::r367; \ +using scope::r368; using scope::r369; using scope::r370; using scope::r371; \ +using scope::r372; using scope::r373; using scope::r374; using scope::r375; \ +using scope::r376; using scope::r377; using scope::r378; using 
scope::r379; \ +using scope::r380; using scope::r381; using scope::r382; using scope::r383; \ +using scope::r384; using scope::r385; using scope::r386; using scope::r387; \ +using scope::r388; using scope::r389; using scope::r390; using scope::r391; \ +using scope::r392; using scope::r393; using scope::r394; using scope::r395; \ +using scope::r396; using scope::r397; using scope::r398; using scope::r399; \ +using scope::r400; using scope::r401; using scope::r402; using scope::r403; \ +using scope::r404; using scope::r405; using scope::r406; using scope::r407; \ +using scope::r408; using scope::r409; using scope::r410; using scope::r411; \ +using scope::r412; using scope::r413; using scope::r414; using scope::r415; \ +using scope::r416; using scope::r417; using scope::r418; using scope::r419; \ +using scope::r420; using scope::r421; using scope::r422; using scope::r423; \ +using scope::r424; using scope::r425; using scope::r426; using scope::r427; \ +using scope::r428; using scope::r429; using scope::r430; using scope::r431; \ +using scope::r432; using scope::r433; using scope::r434; using scope::r435; \ +using scope::r436; using scope::r437; using scope::r438; using scope::r439; \ +using scope::r440; using scope::r441; using scope::r442; using scope::r443; \ +using scope::r444; using scope::r445; using scope::r446; using scope::r447; \ +using scope::r448; using scope::r449; using scope::r450; using scope::r451; \ +using scope::r452; using scope::r453; using scope::r454; using scope::r455; \ +using scope::r456; using scope::r457; using scope::r458; using scope::r459; \ +using scope::r460; using scope::r461; using scope::r462; using scope::r463; \ +using scope::r464; using scope::r465; using scope::r466; using scope::r467; \ +using scope::r468; using scope::r469; using scope::r470; using scope::r471; \ +using scope::r472; using scope::r473; using scope::r474; using scope::r475; \ +using scope::r476; using scope::r477; using scope::r478; using scope::r479; \ +using 
scope::r480; using scope::r481; using scope::r482; using scope::r483; \ +using scope::r484; using scope::r485; using scope::r486; using scope::r487; \ +using scope::r488; using scope::r489; using scope::r490; using scope::r491; \ +using scope::r492; using scope::r493; using scope::r494; using scope::r495; \ +using scope::r496; using scope::r497; using scope::r498; using scope::r499; \ +using scope::r500; using scope::r501; using scope::r502; using scope::r503; \ +using scope::r504; using scope::r505; using scope::r506; using scope::r507; \ +using scope::r508; using scope::r509; using scope::r510; using scope::r511; \ +using scope::A64_A32U; using scope::A64_A32S; using scope::Overfetch; \ +using scope::L1UC_L2UC_L3UC; using scope::L1UC_L2UC_L3C; using scope::L1UC_L2C_L3UC; \ +using scope::L1UC_L2C_L3C; using scope::L1C_L2UC_L3UC; using scope::L1C_L2UC_L3C; \ +using scope::L1C_L2C_L3UC; using scope::L1C_L2C_L3C; using scope::L1S_L2UC_L3UC; \ +using scope::L1S_L2UC_L3C; using scope::L1S_L2C_L3UC; using scope::L1S_L2C_L3C; \ +using scope::L1IAR_L2IAR_L3IAR; using scope::L1UC_L2UC_L3WB; using scope::L1UC_L2WB_L3UC; \ +using scope::L1WT_L2UC_L3UC; using scope::L1WT_L2UC_L3WB; using scope::L1WT_L2WB_L3UC; \ +using scope::L1S_L2UC_L3WB; using scope::L1S_L2WB_L3UC; using scope::L1S_L2WB_L3WB; \ +using scope::L1WB_L2WB_L3UC; using scope::L1WB_L2UC_L3WB; #define NGEN_FORWARD_SCOPE_REGISTERS_EXTRA2(scope) #define NGEN_FORWARD_SCOPE_REGISTERS(scope) \ NGEN_FORWARD_SCOPE_REGISTERS_BASE(scope) \ @@ -1674,6 +1985,7 @@ static inline Instruction12 encodeSyncInsertion(autoswsb::SyncInsertion &si) { Instruction12 i; + i.common.opcode = static_cast(Opcode::sync); i.common.swsb = (hw >= HW::XeHPC) ? 
SWSBInfoXeHPC(si.swsb, Opcode::sync).raw() : SWSBInfo12(si.swsb, Opcode::sync).raw(); @@ -1765,6 +2077,7 @@ std::vector BinaryCodeGenerator::getCode() result.resize(reinterpret_cast(pdst) - result.data()); } + return result; } @@ -1830,6 +2143,13 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM i.binary.cmod = static_cast(mod.getCMod()); + if (hw >= HW::XE3P_35_10) { + if (op == Opcode::math) + i.binaryXe3pImm.src0Reg8 = 0; + i.binaryXe3p.dstReg8 = getHighBit(dst); + i.binaryXe3p.src0Reg8 = getHighBit(src0); + } + db(i, loc); } @@ -1907,6 +2227,9 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM i.imm64.high = val >> 32; } + if (hw >= HW::XE3P_35_10) + i.unaryXe3pImm.dstReg8 = getHighBit(dst); + db(i, loc); } @@ -1987,6 +2310,14 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM i.binary.cmod = static_cast(mod.getCMod()); + if (hw >= HW::XE3P_35_10) { + i.binaryXe3pImm.src0Reg8 = 0; + i.binaryXe3p.dstReg8 = getHighBit(dst); + i.binaryXe3p.src0Reg8 = getHighBit(src0); + i.binaryXe3p.src1Reg8 = getHighBit(src1); + i.binaryXe3p.src1Scalar = checkSrc1Scalar(op, src1, dst, tag); + } + db(i, loc); } @@ -2063,6 +2394,11 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM i.binary.src1Imm = true; i.imm32.value = uint32_t(static_cast(src1)); + if (hw >= HW::XE3P_35_10) { + i.binaryXe3pImm.dstReg8 = getHighBit(dst); + i.binaryXe3pImm.src0Reg8 = getHighBit(src0); + } + db(i, loc); } @@ -2177,6 +2513,8 @@ BinaryCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionM i.ternary.cmod = static_cast(mod.getCMod()); + encodeTernary512GRF(i, dst, src0, src1, src2, tag); + db(i, loc); } @@ -2229,6 +2567,8 @@ void BinaryCodeGenerator::opBfn(Opcode op, DataType defaultType, const Instr i.bfn.bfnCtrl03 = (bfnCtrl >> 0); i.bfn.bfnCtrl47 = (bfnCtrl >> 4); + encodeTernary512GRF(i, dst, src0, src1, src2, tag); + db(i, loc); } @@ -2250,6 +2590,7 @@ static inline 
void encodeDPAS(Instruction12 &i, Opcode op, DataType defaultType, i.ternary.src2 = encodeTernaryOperand12(src2, tag).bits; encodeTernaryTypes(i, dst, src0, src1, src2); + encodeTernary512GRF(i, dst, src0, src1, src2, tag); i.dpas.rcount = rcount - 1; i.dpas.sdepth = utils::log2(sdepth); @@ -2270,6 +2611,38 @@ void BinaryCodeGenerator::opDpas(Opcode op, DataType defaultType, const Inst db(i, loc); } +template +void BinaryCodeGenerator::opBdpas(Opcode op, DataType defaultType, const InstructionModifier &mod, int sdepth, int rcount, + RegData dst, RegData src0, RegData src1, RegData src2, RegData src3, RegData src4, SourceLocation loc) +{ + if (hw < HW::XE3P_35_10) unsupported(); + if (sdepth != 8 || rcount != 8) unsupported(); + + Instruction12 i{}; + + encodeDPAS(i, op, defaultType, mod | defaultModifier, sdepth, rcount, dst, src0, src1, src2); + + src3.fixup(hw, mod.getExecSize(), 0, DataType::ub, 3, 5); + src4.fixup(hw, mod.getExecSize(), 0, DataType::ub, 4, 5); + + int s3r = src3.getBase(), s4r = src4.getBase(); + + i.bdpas.src3RegFile = src3.getRegFile8(); + i.bdpas.src4RegFile = src4.getRegFile8(); + + i.bdpas.src3Reg0 = s3r; + i.bdpas.src3Reg1_2 = s3r >> 1; + i.bdpas.src3Reg3_6 = s3r >> 3; + i.bdpas.src3Reg7_8 = s3r >> 7; + i.bdpas.src3SubReg4_5 = src3.getByteOffset() >> 4; + + i.bdpas.src4Reg0_3 = s4r; + i.bdpas.src4Reg4_8 = s4r >> 4; + i.bdpas.src4SubReg3_5 = src4.getByteOffset() >> 3; + + db(i, loc); +} + template template typename std::enable_if::type @@ -2325,6 +2698,11 @@ BinaryCodeGenerator::opSend(Opcode op, const InstructionModifier &mod, Share if (src0Indirect) i.send.exDesc6_10 = src0.getOffset() >> 1; +#ifdef NGEN_SAFE + if (getHighBit(dst) || getHighBit(src0) || getHighBit(src1)) + throw limited_to_256_grf_exception(); +#endif + db(i, loc); } @@ -2452,6 +2830,102 @@ BinaryCodeGenerator::opSends(Opcode op, const InstructionModifier &mod, cons opSend(mop, mod, static_cast(exdesc & 0x1F), dst, src0, src1, -1, exdesc, desc, loc); } +static inline 
unsigned encodeSendgxRegNum(RegData r) +{ + if (r.isNull()) + return 0x1FF; +#ifdef NGEN_SAFE + else if (r.isARF()) + throw invalid_arf_exception(); + else if (r.getBase() == 0x1FF) + throw r511_not_allowed_exception(); +#endif + else + return r.getBase(); +} + +template +void BinaryCodeGenerator::opSendg(Opcode op, const InstructionModifier &mod, SharedFunction sfid, + const RegData &dst, RegData src0, int src0Len, const RegData &src1, int src1Len, + RegData ind0, RegData ind1, uint64_t desc, SourceLocation loc) +{ + + typename EncodingTag12Dispatch::tag tag; + Instruction12 i{}; + InstructionModifier emod = mod | defaultModifier; + + bool src0Indirect = src0.isIndirect(); + if (src0Indirect) + src0 = src0.getIndirectReg(); + + encodeCommon12(i, op, emod, dst, tag); + + i.sendg.eot = emod.isEOT(); + + if (op == Opcode::sendgx || op == Opcode::sendgxc) { + unsigned dstReg = encodeSendgxRegNum(dst); + unsigned src0Reg = encodeSendgxRegNum(src0); + unsigned src1Reg = encodeSendgxRegNum(src1); + + i.sendg.dstReg = dstReg; + i.sendg.src0Reg = src0Reg; + i.sendg.src1Reg = src1Reg; + + i.sendgx.dstReg8 = dstReg >> 8; + i.sendgx.src0Reg8 = src0Reg >> 8; + i.sendgx.src1Reg8 = src1Reg >> 8; + } else { + i.sendg.dstReg = dst.getBase(); + i.sendg.src0Reg = src0.getBase(); + i.sendg.src1Reg = src1.getBase(); + + i.sendg.dstRegFile = dst.getRegFile8(); + i.sendg.src0RegFile = src0.getRegFile8(); + i.sendg.src1RegFile = src1.getRegFile8(); + +#ifdef NGEN_SAFE + if (getHighBit(dst) || getHighBit(src0) || getHighBit(src1)) + throw limited_to_256_grf_exception(); +#endif + } + + i.sendg.src0Len = src0Len; + i.sendg.src1Len = src1Len; + + i.sendg.sfid = static_cast(sfid) & 0xF; + + i.sendg.desc0_15 = desc; + i.sendg.desc16_27 = desc >> 16; + i.sendg.desc28_29 = desc >> 28; + i.sendg.desc30_31 = desc >> 30; + i.sendg.desc32_39 = desc >> 32; + i.sendg.desc40_41 = desc >> 40; + i.sendg.ind1_desc42_46 = desc >> 42; + + if (src0Indirect) + i.sendg.src1Len = src0.getOffset() >> 1; + + 
i.sendg.ind0Present = !ind0.isNull(); + i.sendg.ind1Present = !ind1.isNull(); + + if (i.sendg.ind0Present) { +#ifdef NGEN_SAFE + if (!ind0.isARF() || ind0.getARFType() != ARFType::s) + throw invalid_arf_exception(); +#endif + i.sendg.ind0 = ind0.getByteOffset() >> 3; + } + if (i.sendg.ind1Present) { +#ifdef NGEN_SAFE + if (!ind1.isARF() || ind1.getARFType() != ARFType::s) + throw invalid_arf_exception(); +#endif + i.sendg.ind1_desc42_46 = ind1.getByteOffset() >> 3; + } + + db(i, loc); +} + template template typename std::enable_if::type @@ -2495,6 +2969,9 @@ BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, con i.branches.jip = jip; i.branches.uip = uip; + if (hw >= HW::XE3P_35_10) + i.branchXe3p.dstReg8 = getHighBit(dst); + db(i, loc); } @@ -2540,6 +3017,9 @@ BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, con i.binary.src0Imm = true; i.branches.jip = jip; + if (hw >= HW::XE3P_35_10) + i.branchXe3p.dstReg8 = getHighBit(dst); + db(i, loc); } @@ -2587,6 +3067,11 @@ BinaryCodeGenerator::opBranch(Opcode op, const InstructionModifier &mod, con i.binary.src0 &= 0xFFFF; + if (hw >= HW::XE3P_35_10) { + i.branchXe3p.dstReg8 = getHighBit(dst); + i.branchXe3p.src0Reg8 = getHighBit(src0); + } + db(i, loc); } @@ -2657,6 +3142,24 @@ void BinaryCodeGenerator::opJmpi(Opcode op, const InstructionModifier &mod, addFixup(LabelFixup(jip.getID(labelManager), LabelFixup::JIPOffsetJMPI)); } +template +void BinaryCodeGenerator::opShflLfsr(Opcode op, uint8_t fc, DataType defaultType, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc) +{ + InstructionModifier mmod = mod; + + mmod.setCMod(static_cast(fc)); + opX(op, defaultType, mmod, dst, src0, src1, loc); +} + +template +void BinaryCodeGenerator::opShflLfsr(Opcode op, uint8_t fc, DataType defaultType, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc) +{ + 
InstructionModifier mmod = mod; + + mmod.setCMod(static_cast(fc)); + opX(op, defaultType, mmod, dst, src0, src1, loc); +} + template void BinaryCodeGenerator::opSync(Opcode op, SyncFunction fc, const InstructionModifier &mod, SourceLocation loc) { @@ -2736,6 +3239,7 @@ void BinaryCodeGenerator::opDirective(Directive directive, RegData src0, S1 opX(Opcode::directive, DataType::ud, InstructionModifier::createAutoSWSB(), GRF(static_cast(directive)), src0, src1, loc); } + } /* namespace NGEN_NAMESPACE */ #if defined(__clang__) diff --git a/third_party/ngen/ngen_asm.hpp b/third_party/ngen/ngen_asm.hpp index 7d1a12b28c4..7afa58c9c78 100644 --- a/third_party/ngen/ngen_asm.hpp +++ b/third_party/ngen/ngen_asm.hpp @@ -73,6 +73,7 @@ inline void RegData::outputText(std::ostream &str, PrintDetail detail, LabelMana } } + if (detail <= PrintDetail::base) return; if (!isIndirect() && !isNull()) @@ -148,6 +149,7 @@ inline void GRFRange::outputText(std::ostream &str, PrintDetail detail, LabelMan str << 'r' << int(base) << ':' << int(len); } + inline void Label::outputText(std::ostream &str, PrintDetail detail, LabelManager &man) { str << 'L' << getID(man); } @@ -186,6 +188,7 @@ struct AsmOperand { AsmOperand(Label label_) : label{label_}, type{Type::label} {} AsmOperand(GRFRange range_) : range{range_}, type{Type::range} {} AsmOperand(uint32_t imm_) : imm{imm_}, type{Type::imm} {} + AsmOperand(uint64_t imm_) : imm{imm_}, type{Type::imm} {} void outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const { switch (type) { @@ -385,6 +388,13 @@ bool AsmInstruction::getOperandRegion(autoswsb::DependencyRegion ®ion, int op region = DependencyRegion(hw, desc.parts.messageLen, sreg); return true; } + if ((op == Opcode::sendg || op == Opcode::sendgc) && opNum == 1 + && src[0].type == AsmOperand::Type::reg && src[0].reg.isIndirect()) { + auto sreg = src[0].reg.getIndirectReg(); + sreg.setRegion(1, 1, 0); + region = DependencyRegion(hw, ext >> 8, sreg); + return true; + } 
return false; default: return false; } @@ -421,6 +431,18 @@ bool AsmInstruction::getOperandRegion(autoswsb::DependencyRegion ®ion, int op region = DependencyRegion(); else region = DependencyRegion(hw, GRFRange(rd.getBase(), len)); + } else if (op == Opcode::sendg || op == Opcode::sendgc || op == Opcode::sendgx || op == Opcode::sendgxc) { + if (opNum == -1 && !rd.isNull() && src[4].type == AsmOperand::Type::imm) { + SendgMessageDescriptor desc; + desc.all = static_cast(src[4].imm); + int execSize = mod.getExecSize(); + int len = desc.dstLen(hw, execSize, static_cast(ext & 0xF)); + if (len == -1) + region = DependencyRegion(); + else + region = DependencyRegion(hw, GRFRange(rd.getBase(), len)); + } else + region = DependencyRegion(hw, mod.getExecSize(), rd); } else if (op == Opcode::dpas || op == Opcode::dpasw) { unsigned sdepth = ext >> 8; unsigned rcount = ext & 0xFF; @@ -432,7 +454,7 @@ bool AsmInstruction::getOperandRegion(autoswsb::DependencyRegion ®ion, int op case 1: len = sdepth; break; case 2: if (op == Opcode::dpasw) rcount = (rcount + 1) >> 1; - len = GRF::bytesToGRFs(hw, operand.reg.getByteOffset() + sdepth * rcount * 4); + len = GRF::bytesToGRFs(hw, operand.reg.getByteOffset() + sdepth * rcount * 4 * operand.reg.getDwords()); break; default: return false; } @@ -462,11 +484,13 @@ class AsmCodeGenerator { #include "ngen_compiler_fix.hpp" public: explicit AsmCodeGenerator(Product product_) : hardware(getCore(product_.family)), product(product_), defaultOutput{nullptr}, + lfsr{this}, shfl{this}, sync{this}, load{this}, store{this}, atomic{this} { isGen12 = (hardware >= HW::Gen12LP); _workaround_(); streamStack.push_back(new InstructionStream()); + useEfficient64Bit = (hardware >= HW::XE3P_35_10); } explicit AsmCodeGenerator(HW hardware_, int stepping_ = 0) : AsmCodeGenerator({genericProductFamily(hardware_), 0, PlatformType::Unknown}) {} @@ -553,6 +577,8 @@ class AsmCodeGenerator { std::vector streamStack; std::unique_ptr> cancelAutoSWSB_; + bool 
useEfficient64Bit = false; + inline void unsupported(); // Output functions. @@ -614,6 +640,24 @@ class AsmCodeGenerator { i.src[2].imm = uint32_t(exdesc | static_cast(sf)); } } + template + static inline T uniformizeInd(T r) { return r; } + static inline RegData uniformizeInd(RegData r) { r.setOffset(r.getByteOffset()); r.setType(DataType::ub); return r; } + template + static inline void applyLen(T &r, int len) {} + static inline void applyLen(RegisterRange &r, int len) { if (len > 0 && r.isValid()) r = RegisterRange(r[0], len); } + + template + void opSendg(Opcode op, const InstructionModifier &mod, SharedFunction sf, RegData dst, const S0 &src0, const S1 &src1, I0 ind0, I1 ind1, uint64_t desc) { + (void) streamStack.back()->append(op, static_cast(sf), mod | defaultModifier, &labelManager, dst, src0, src1, uniformizeInd(ind0), uniformizeInd(ind1), desc); + } + template + void opSendg(Opcode op, const InstructionModifier &mod, SharedFunction sf, RegData dst, const RegData &src0, int src0Len, const I0 &ind0, const I1 &ind1, uint64_t desc) { + if (src0.isIndirect()) + (void) streamStack.back()->append(op, static_cast(sf) | 0x80 | (src0Len << 8), mod | defaultModifier, &labelManager, dst, src0, NoOperand(), ind0, ind1, desc); + else + opSendg(op, mod, sf, dst, GRFRange(src0.getBase(), src0Len), NullRegister(), uniformizeInd(ind0), uniformizeInd(ind1), desc); + } void opDpas(Opcode op, DataType defaultType, const InstructionModifier &mod, int sdepth, int rcount, RegData dst, RegData src0, RegData src1, RegData src2) { dst.fixup(hardware, 1, 0, defaultType, -1, 3); src0.fixup(hardware, 1, 0, defaultType, 0, 3); @@ -621,6 +665,15 @@ class AsmCodeGenerator { src2.fixup(hardware, 1, 0, defaultType, 2, 3); (void) streamStack.back()->append(op, (sdepth << 8) | rcount, mod | defaultModifier, &labelManager, dst, src0, src1, src2); } + void opBdpas(Opcode op, DataType defaultType, const InstructionModifier &mod, int sdepth, int rcount, RegData dst, RegData src0, RegData src1, 
RegData src2, RegData src3, RegData src4) { + dst.fixup(hardware, 1, 0, defaultType, -1, 3); + src0.fixup(hardware, 1, 0, defaultType, 0, 3); + src1.fixup(hardware, 1, 0, defaultType, 1, 3); + src2.fixup(hardware, 1, 0, defaultType, 2, 3); + src3.fixup(hardware, 1, 0, DataType::ub, 3, 5); + src4.fixup(hardware, 1, 0, DataType::ub, 4, 5); + (void) streamStack.back()->append(op, (sdepth << 8) | rcount, mod | defaultModifier, &labelManager, dst, src0, src1, src2, src3, src4); + } template void opCall(Opcode op, const InstructionModifier &mod, D dst, S0 src0) { (void) streamStack.back()->append(op, 0, mod | defaultModifier | NoMask, &labelManager, dst, src0); } @@ -645,6 +698,7 @@ class AsmCodeGenerator { inline void outMods(std::ostream &out, const InstructionModifier &mod, Opcode op, ModPlacementType location, uint16_t ext = 0, uint32_t ext2 = 0); inline void outComment(std::ostream &out, const AsmInstruction &i); + InstructionModifier defaultMods() const { return GRF::bytes(hardware) >> 2; } @@ -656,6 +710,9 @@ class AsmCodeGenerator { bool getDefaultNoMask() const { return defaultModifier.isWrEn(); } bool getDefaultAutoSWSB() const { return defaultModifier.isAutoSWSB(); } + void setEfficient64Bit(bool def = true) { useEfficient64Bit = def; } + bool getEfficient64Bit() const { return useEfficient64Bit; } + // Stream handling. void pushStream() { pushStream(new InstructionStream()); } void pushStream(InstructionStream &s) { pushStream(&s); } @@ -1020,27 +1077,33 @@ class AsmCodeGenerator { } template void mac(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + if (hardware >= HW::XE3P_35_10) unsupported(); opX(Opcode::mac, getDataType
(), mod, dst, src0, src1); } template void mac(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { + if (hardware >= HW::XE3P_35_10) unsupported(); opX(Opcode::mac, getDataType
(), mod, dst, src0, src1); } template void mach(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + if (hardware >= HW::XE3P_35_10) unsupported(); opX(Opcode::mach, getDataType
(), (hardware >= HW::XeHPC) ? mod : (mod | AccWrEn), dst, src0, src1); } template void mach(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { + if (hardware >= HW::XE3P_35_10) unsupported(); opX(Opcode::mach, getDataType
(), (hardware >= HW::XeHPC) ? mod : (mod | AccWrEn), dst, src0, src1); } template void macl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + if (hardware >= HW::XE3P_35_10) unsupported(); if (hardware < HW::Gen10) unsupported(); opX((hardware >= HW::XeHPC) ? Opcode::macl : Opcode::mach, getDataType
(), mod, dst, src0, src1); } template void macl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { + if (hardware >= HW::XE3P_35_10) unsupported(); if (hardware < HW::Gen10) unsupported(); opX((hardware >= HW::XeHPC) ? Opcode::macl : Opcode::mach, getDataType
(), mod, dst, src0, src1); } @@ -1172,6 +1235,14 @@ class AsmCodeGenerator { src1 = src1.forceInt32(); opX(Opcode::mul, getDataType
(), mod, dst, src0, src1); } + template + void mullh(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + opX(Opcode::mullh, getDataType
(), mod, dst, src0, src1); + } + template + void mullh(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { + opX(Opcode::mullh, getDataType
(), mod, dst, src0, src1); + } void nop(SourceLocation loc = {}) { opX(isGen12 ? Opcode::nop_gen12 : Opcode::nop); } @@ -1379,6 +1450,126 @@ class AsmCodeGenerator { #endif sendc(mod, static_cast(0), dst, src0, src1, exdesc, desc); } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, NullRegister(), NoOperand(), NoOperand(), desc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, NullRegister(), ind0, NoOperand(), desc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, NullRegister(), ind0, ind1, desc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src1, NoOperand(), NoOperand(), desc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src1, ind0, NoOperand(), desc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src1, ind0, ind1, desc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, 
uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0Len, NoOperand(), NoOperand(), desc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0Len, ind0, NoOperand(), desc); + } + void sendg(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendg, mod, sf, dst, src0, src0Len, ind0, ind1, desc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, NullRegister(), NoOperand(), NoOperand(), desc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, NullRegister(), ind0, NoOperand(), desc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, NullRegister(), ind0, ind1, desc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src1, NoOperand(), NoOperand(), desc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, 
src1, ind0, NoOperand(), desc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src1, ind0, ind1, desc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0Len, NoOperand(), NoOperand(), desc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0Len, ind0, NoOperand(), desc); + } + void sendgc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgc, mod, sf, dst, src0, src0Len, ind0, ind1, desc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, NullRegister(), NoOperand(), NoOperand(), desc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + if (ind0.isNull()) + opSendg(Opcode::sendgx, mod, sf, dst, src0, NullRegister(), NoOperand(), NoOperand(), desc); + else + opSendg(Opcode::sendgx, mod, sf, dst, src0, NullRegister(), ind0, NoOperand(), desc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, 
mod, sf, dst, src0, NullRegister(), ind0, ind1, desc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src1, NoOperand(), NoOperand(), desc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + if (ind0.isNull()) + opSendg(Opcode::sendgx, mod, sf, dst, src0, src1, NoOperand(), NoOperand(), desc); + else + opSendg(Opcode::sendgx, mod, sf, dst, src0, src1, ind0, NoOperand(), desc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src1, ind0, ind1, desc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0Len, NoOperand(), NoOperand(), desc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0Len, ind0, NoOperand(), desc); + } + void sendgx(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgx, mod, sf, dst, src0, src0Len, ind0, ind1, desc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, uint64_t desc, SourceLocation loc = {}) { + 
opSendg(Opcode::sendgxc, mod, sf, dst, src0, NullRegister(), NoOperand(), NoOperand(), desc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + if (ind0.isNull()) + opSendg(Opcode::sendgxc, mod, sf, dst, src0, NullRegister(), NoOperand(), NoOperand(), desc); + else + opSendg(Opcode::sendgxc, mod, sf, dst, src0, NullRegister(), ind0, NoOperand(), desc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, NullRegister(), ind0, ind1, desc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src1, NoOperand(), NoOperand(), desc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + if (ind0.isNull()) + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src1, NoOperand(), NoOperand(), desc); + else + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src1, ind0, NoOperand(), desc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegisterRange &src0, const RegisterRange &src1, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src1, ind0, ind1, desc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0Len, NoOperand(), NoOperand(), desc); + 
} + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0Len, ind0, NoOperand(), desc); + } + void sendgxc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, int src0Len, const RegData &ind0, const RegData &ind1, uint64_t desc, SourceLocation loc = {}) { + opSendg(Opcode::sendgxc, mod, sf, dst, src0, src0Len, ind0, ind1, desc); + } template void shl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { opX(isGen12 ? Opcode::shl_gen12 : Opcode::shl, getDataType
(), mod, dst, src0, src1); @@ -1431,6 +1622,79 @@ class AsmCodeGenerator { opX(isGen12 ? Opcode::xor_gen12 : Opcode::xor_, getDataType
(), mod, dst, src0, src1); } + template + void bdpas(const InstructionModifier &mod, uint8_t sdepth, uint8_t rcount, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2, const RegData &src3, const RegData &src4, SourceLocation loc = {}) { + auto emod = mod | defaultModifier; + if (emod.isAutoSWSB()) { + if (!src3.isARF()) wrdep(GRF(src3.getBase())); + if (!src4.isARF()) wrdep(GRF(src4.getBase())); + } + opBdpas(Opcode::bdpas, getDataType
(), mod, sdepth, rcount, dst, src0, src1, src2, src3, src4); + } + + template + void dnscl(const InstructionModifier &mod, uint8_t mode, RoundingType rnd, RegData dst, RegData src0, RegData src1, const RegData &src2, SourceLocation loc = {}) { + auto ctrl = encodeDnsclCtrl(mode, rnd, dst, src0, src1); + opX(Opcode::dnscl, getDataType
(), mod, dst, src0, src1, src2, ctrl); + } + +private: + struct LFSR { + AsmCodeGenerator &parent; + + LFSR(AsmCodeGenerator *parent_) : parent(*parent_) {} + + void operator()(LFSRFunction fc, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { + parent.opX(Opcode::lfsr, DataType::invalid, mod, dst, src0, src1, NoOperand(), static_cast(fc)); + } + + template + void b32(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { + parent.opX(Opcode::lfsr, getDataType
(), mod, dst, src0, src1, NoOperand(), static_cast(LFSRFunction::b32)); + } + template + void b32(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { + parent.opX(Opcode::lfsr, getDataType
(), mod, dst, src0, src1, NoOperand(), static_cast(LFSRFunction::b32)); + } + template + void b16v2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { + parent.opX(Opcode::lfsr, getDataType
(), mod, dst, src0, src1, NoOperand(), static_cast(LFSRFunction::b16v2)); + } + template + void b16v2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { + parent.opX(Opcode::lfsr, getDataType
(), mod, dst, src0, src1, NoOperand(), static_cast(LFSRFunction::b16v2)); + } + template + void b8v4(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { + parent.opX(Opcode::lfsr, getDataType
(), mod, dst, src0, src1, NoOperand(), static_cast(LFSRFunction::b8v4)); + } + template + void b8v4(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { + parent.opX(Opcode::lfsr, getDataType
(), mod, dst, src0, src1, NoOperand(), static_cast(LFSRFunction::b8v4)); + } + }; +public: + LFSR lfsr; + +private: + struct Shfl { + AsmCodeGenerator &parent; + + Shfl(AsmCodeGenerator *parent_) : parent(*parent_) {} + + void operator()(ShuffleFunction fc, const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + parent.opX(Opcode::shfl, DataType::invalid, mod, dst, src0, src1, NoOperand(), static_cast(fc)); + } + + template + void idx4(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + parent.opX(Opcode::shfl, getDataType
(), mod, dst, src0, src1, NoOperand(), static_cast(ShuffleFunction::idx4)); + } + }; +public: + Shfl shfl; + + private: struct Sync { AsmCodeGenerator &parent; @@ -1698,21 +1962,27 @@ static const char *getMnemonic(Opcode op, HW hw) "bfe", "bfi1", "bfi2", "", "", "", "", "", "jmpi", "brd", "if", "brc", "else", "endif", "", "while", "break", "cont", "halt", "calla", "call", "ret", "goto", "join", - "wait", "send", "sendc", "sends", "sendsc", "", "", "", - "math", "", "", "", "", "", "", "", + "wait", "send", "sendc", "sendg", "sendgc", "sendgx", "sendgxc", "", + "math", "lfsr", "", "", "", "", "", "", "add", "mul", "avg", "frc", "rndu", "rndd", "rnde", "rndz", "mac", "mach", "lzd", "fbh", "fbl", "cbit", "addc", "subb", - "sad2", "sada2", "add3", "macl", "srnd", "dph", "dp3", "dp2", - "dp4a", "dpas", "dpasw", "mad", "lrp", "madm", "", "", + "shfl", "sada2", "add3", "macl", "srnd", "dnscl", "dp3", "dp2", + "dp4a", "dpas", "dpasw", "mad", "bdpas", "madm", "", "mullh", "nop", "mov", "sel", "movi", "not", "and", "or", "xor", "shr", "shl", "smov", "bfn", "asr", "", "ror", "rol", "cmp", "cmpn", "csel", "", "", "", "", "bfrev", "bfe", "bfi1", "bfi2", "", "", "", "nop", "" }; + const char *mnemonic = names[static_cast(op) & 0x7F]; if (hw < HW::Gen12LP) switch (op) { + case Opcode::sends: mnemonic = "sends"; break; + case Opcode::sendsc: mnemonic = "sendsc"; break; + case Opcode::sad2: mnemonic = "sad2"; break; + case Opcode::dph: mnemonic = "dph"; break; + case Opcode::lrp: mnemonic = "lrp"; break; case Opcode::mov: mnemonic = "mov"; break; case Opcode::line: mnemonic = "line"; break; case Opcode::pln: mnemonic = "pln"; break; @@ -1723,6 +1993,7 @@ static const char *getMnemonic(Opcode op, HW hw) return mnemonic; } + void AsmCodeGenerator::outComment(std::ostream &out, const AsmInstruction &i) { bool newLine = true; @@ -1747,6 +2018,8 @@ void AsmCodeGenerator::outX(std::ostream &out, const AsmInstruction &i, int &lin case Opcode::sends: case Opcode::sendc: case 
Opcode::sendsc: + case Opcode::sendgx: + case Opcode::sendgxc: ddst = dsrc[0] = dsrc[1] = PrintDetail::base; dsrc[2] = dsrc[3] = PrintDetail::sub_no_type; break; @@ -1761,6 +2034,9 @@ void AsmCodeGenerator::outX(std::ostream &out, const AsmInstruction &i, int &lin case Opcode::ret: dsrc[0] = PrintDetail::sub_no_type; break; + case Opcode::bdpas: + if (isGen12) dsrc[3] = dsrc[4] = PrintDetail::sub; + /* fall through */ case Opcode::dpas: case Opcode::dpasw: if (isGen12) ddst = dsrc[0] = dsrc[1] = dsrc[2] = PrintDetail::sub; @@ -1790,6 +2066,7 @@ void AsmCodeGenerator::outX(std::ostream &out, const AsmInstruction &i, int &lin bool showLen = false; if (i.ext & 0x80) { showLen |= (n == 1 && (i.op == Opcode::send || i.op == Opcode::sendc) && hardware >= HW::XeHPG); + showLen |= (n == 0 && (i.op == Opcode::sendg || i.op == Opcode::sendgc || i.op == Opcode::sendgx || i.op == Opcode::sendgxc)); } if (showLen) @@ -1805,6 +2082,7 @@ void AsmCodeGenerator::outX(std::ostream &out, const AsmInstruction &i, int &lin out << std::endl; } + void AsmCodeGenerator::outExt(std::ostream &out, const AsmInstruction &i) { switch (i.opcode()) { @@ -1816,12 +2094,26 @@ void AsmCodeGenerator::outExt(std::ostream &out, const AsmInstruction &i) } if (isGen12) switch (i.opcode()) { + case Opcode::sendgx: + case Opcode::sendgxc: case Opcode::send: case Opcode::sendc: case Opcode::sends: case Opcode::sendsc: out << '.' << getMnemonic(static_cast(i.ext & 0xF), hardware); break; case Opcode::sync: out << '.' << static_cast(i.ext); break; case Opcode::bfn: out << ".0x" << std::hex << i.ext << std::dec; break; + case Opcode::dnscl: { + const char *sts[2] = {"hf", "bf"}; + const char *dts[4] = {"e3m0", "e2m1", "int4", ""}; + const char *rts[2] = {"srnd", "rne"}; + int dt = i.ext & 0x3, mode = i.ext >> 4; + bool st = i.ext & 0x4, rt = i.ext & 0x8; + out << '.' << sts[st] << "to" << dts[dt] << ".mode" << mode << '.' << rts[rt]; + break; + } + case Opcode::lfsr: out << '.' 
<< static_cast(i.ext); break; + case Opcode::shfl: out << '.' << static_cast(i.ext); break; + case Opcode::bdpas: case Opcode::dpas: case Opcode::dpasw: { int sdepth = i.ext >> 8; @@ -1917,6 +2209,7 @@ void AsmCodeGenerator::outMods(std::ostream &out, const InstructionModifier &mod if (!isGen12 && mod.getThreadCtrl() == ThreadCtrl::Switch) printPostMod("Switch"); if (!isGen12 && mod.getThreadCtrl() == ThreadCtrl::NoPreempt) printPostMod("NoPreempt"); if (mod.isAccWrEn() && hardware < HW::XeHPC) printPostMod("AccWrEn"); + if (mod.isFwd() && hardware >= HW::XeHPC) printPostMod("Fwd"); if (mod.isCompact()) printPostMod("Compact"); if (mod.isBreakpoint()) printPostMod("Breakpoint"); if (mod.isSerialized()) printPostMod("Serialize"); @@ -1929,6 +2222,7 @@ void AsmCodeGenerator::outMods(std::ostream &out, const InstructionModifier &mod } } + } /* namespace NGEN_NAMESPACE */ #endif diff --git a/third_party/ngen/ngen_auto_swsb.hpp b/third_party/ngen/ngen_auto_swsb.hpp index 1226db79a7e..68e395c009e 100644 --- a/third_party/ngen/ngen_auto_swsb.hpp +++ b/third_party/ngen/ngen_auto_swsb.hpp @@ -116,7 +116,7 @@ class GeneralizedPipe { }; struct DependencyRegion { - uint8_t base, size; + uint16_t base, size; uint8_t unspecified : 1; uint8_t checkWAW : 1; uint8_t rf : 2; @@ -143,6 +143,12 @@ struct DependencyRegion { } void clear() { *this = DependencyRegion(hw); unspecified = false; checkWAW = false; rf = 0; } + void duplicateLH() { + for (int i = 0; i < size; i++) + masks[size + i] = masks[i]; + size *= 2; + } + #ifdef NGEN_DEBUG inline void dump() const; #endif @@ -218,7 +224,7 @@ class DependencyTable { }; enum : int { - maxGRF = 256, + maxGRF = 512, grfListIdxUnspecified = maxGRF // GRF list index for all unspecified regions. }; @@ -232,7 +238,7 @@ class DependencyTable { std::vector> deps; // List of all Dependencies (active or not) std::vector frags; // List of all DependencyFragments (active or not) - std::array heads[NListTypes]; // Heads of doubly-linked lists. 
+ std::array heads[NListTypes]; // Heads of doubly-linked lists. static bool isHeadLink(uint32_t id) { return ((id & 0x80000000) != 0) && (id != none); } static uint32_t readHeadLink(uint32_t id) { return id & 0x7FFFFFFF; } @@ -295,7 +301,7 @@ struct BasicBlock { DependencyTable incoming; // Table of dependencies produced by prior BBs (temporary). std::vector syncs; // List of sync instructions to generate. std::vector movs; // List of mov instructions to generate. - std::vector> opRegions; // Cache of instruction operand regions. + std::vector> opRegions; // Cache of instruction operand regions. bool enablePVCWARWA = false; // Enable workaround for PVC WAR bug. const DependencyRegion &getOperandRegion(int inum, int opNum) const { @@ -339,6 +345,7 @@ inline GeneralizedPipe getPipe(HW hw, const Instruction &insn, bool checkOOO = t { auto op = insn.opcode(); + // Check jumps and no-ops if (isBranch(op) || op == Opcode::nop_gen12 || op == Opcode::sync || op == Opcode::illegal || op == Opcode::directive) return GeneralizedPipe(); @@ -348,6 +355,7 @@ inline GeneralizedPipe getPipe(HW hw, const Instruction &insn, bool checkOOO = t if (!checkOOO) return GeneralizedPipe(); switch (op) { + case Opcode::bdpas: case Opcode::dpas: case Opcode::dpasw: return GeneralizedPipe::Systolic(); @@ -356,6 +364,10 @@ inline GeneralizedPipe getPipe(HW hw, const Instruction &insn, bool checkOOO = t return GeneralizedPipe::Math(); case Opcode::send: case Opcode::sendc: + case Opcode::sendg: + case Opcode::sendgc: + case Opcode::sendgx: + case Opcode::sendgxc: return GeneralizedPipe(insn.sfid()); } } @@ -375,6 +387,8 @@ inline GeneralizedPipe getPipe(HW hw, const Instruction &insn, bool checkOOO = t unsigned lmask = (hw >= HW::XeHPC) ? 
0b1011 : 0b0011; if ((dt & lmask) == lmask) mask = PipeMaskL; + else if ((hw >= HW::XE3P_35_10) && (op == Opcode::mov_gen12 || op == Opcode::srnd) && (dt != insn.src0Typecode())) + mask = PipeMaskI; else if (dt & 8) mask = PipeMaskF; else @@ -465,6 +479,7 @@ DependencyRegion::DependencyRegion(HW hw_, int esize, RegData rr) checkWAW = false; rf = rr.getRegFile(); + int hs = rr.getHS(), vs = rr.getVS(); int nh = rr.getWidth(); if (nh == 0) nh = 1; @@ -622,8 +637,13 @@ inline int estimateLatency(HW hw, const Instruction &insn) switch (insn.opcode()) { default: case Opcode::math: return (hw == HW::Gen12LP) ? 20 : 17; + case Opcode::bdpas: case Opcode::dpas: case Opcode::dpasw: return 20; // need correct value + case Opcode::sendg: + case Opcode::sendgc: + case Opcode::sendgx: + case Opcode::sendgxc: case Opcode::send: case Opcode::sendc: { switch (insn.sfid()) { @@ -1394,7 +1414,7 @@ inline BasicBlockList getBasicBlocks(HW hw, const Program &program) // Decode and cache operand regions, handling any nodep pseudo-instructions. 
bb.opRegions.resize(bb.iend - bb.istart); - std::array ignoreDeps = {false}; + std::array ignoreDeps = {false}; DependencyRegion subDstRegion(hw); subDstRegion.clear(); @@ -1409,6 +1429,8 @@ inline BasicBlockList getBasicBlocks(HW hw, const Program &program) case Directive::ignoredep_src0: ignoreDeps[1] = true; break; case Directive::ignoredep_src1: ignoreDeps[2] = true; break; case Directive::ignoredep_src2: ignoreDeps[3] = true; break; + case Directive::ignoredep_src3: ignoreDeps[4] = true; break; + case Directive::ignoredep_src4: ignoreDeps[5] = true; break; case Directive::subdep_dst: #ifdef NGEN_SAFE if (!subDstRegion.empty()) @@ -1428,7 +1450,7 @@ inline BasicBlockList getBasicBlocks(HW hw, const Program &program) continue; } - for (int srcN = -1; srcN < 3; srcN++) { + for (int srcN = -1; srcN < 5; srcN++) { regions[srcN + 1].hw = hw; if (ignoreDeps[srcN + 1] || !insn.getOperandRegion(regions[srcN + 1], srcN)) regions[srcN + 1].clear(); @@ -1646,7 +1668,7 @@ inline uint8_t chooseSBID(HW hw, int tokens, Program &program, const BasicBlock // Priority 2: assign SBID based on base register of dst, src1, src0 (in that order), // if it's unclaimed or expired. - for (int opNum : {-1, 1, 0}) { + for (int opNum : {-1, 1, 0, 2, 3}) { auto ®ion = bb.getOperandRegion(inum, opNum); if (region.size > 0) { auto sbid = preferredSBID(tokens, region.base); @@ -1737,12 +1759,12 @@ PVCWARWA analyzePVCWARWA(HW hw, Program &program, BasicBlock &bb, int phase, if (sameBB && consumeOp.pipe.type() != GeneralizedPipe::vSystolic) { // Check if we have a src at least as large as our dst. int srcN; - for (srcN = 0; srcN <= 2; srcN++) { + for (srcN = 0; srcN <= 4; srcN++) { if (regions[srcN + 1].unspecified) continue; if (bboxContains(regions[srcN + 1], regions[0])) break; } - if (srcN >= 2) srcN = -1; + if (srcN >= 4) srcN = -1; // Check for potential read suppression. 
if (srcN >= 0 && consumeOp.pipe.inOrder()) { @@ -1794,7 +1816,7 @@ PVCWARWA analyzePVCWARWA(HW hw, Program &program, BasicBlock &bb, int phase, } // Case 2: walk forward, looking for a new target send instruction. - auto eligibleSend = [=, &program, &dep](uint32_t inum) { + auto eligibleSend = [=, &program](uint32_t inum) { auto &insn = program[inum]; if (inum != dep.inum && insn.predicated()) return false; @@ -2065,7 +2087,7 @@ inline void analyze(HW hw, int tokens, Program &program, BasicBlock &bb, int pha bool assignSBID = (phase == 1) && tokenInsn && tokenInfo.tokenTBD && !insn.atomic(); // Collect operands. - for (int srcN = 2; srcN >= -1; srcN--) { + for (int srcN = 4; srcN >= -1; srcN--) { // Skip non-GRF operands. // Special case: check for cr/sr/ce source operands and force A@1 if any. if (regions[srcN + 1].empty()) { @@ -2395,6 +2417,13 @@ inline void analyze(HW hw, int tokens, Program &program, BasicBlock &bb, int pha // First pass: record pipeline SWSB dependencies for later entry into consumer table. recordIOPreconsumes(generated); + // mullh takes up two execution slots. Consume dependencies on the first; + // (already done); produce them on the second (up next). + if (opcode == Opcode::mullh) { + incrementCounters(getPipeMask(hw, insn)); + consumeOp.counters = counters; + } + // Add producer dependencies for all operands. // Also record token timeout. // During phase 0, only do this for OOO instructions, and if dst not null, only dst. 
@@ -2405,7 +2434,7 @@ inline void analyze(HW hw, int tokens, Program &program, BasicBlock &bb, int pha produceOp.tokenTime = estimateLatency(hw, insn); } - for (int srcN = -1; srcN < 3; srcN++) { + for (int srcN = -1; srcN < 5; srcN++) { if (!regions[srcN + 1].empty()) { produceOp.rw = (srcN < 0); if (tokenInfo.hasToken()) { diff --git a/third_party/ngen/ngen_compiler_fix.hpp b/third_party/ngen/ngen_compiler_fix.hpp index ac79ad98aee..45bdca053ed 100644 --- a/third_party/ngen/ngen_compiler_fix.hpp +++ b/third_party/ngen/ngen_compiler_fix.hpp @@ -86,6 +86,70 @@ void _workaround_() { (void) r244.getBase(); (void) r245.getBase(); (void) r246.getBase(); (void) r247.getBase(); (void) r248.getBase(); (void) r249.getBase(); (void) r250.getBase(); (void) r251.getBase(); (void) r252.getBase(); (void) r253.getBase(); (void) r254.getBase(); (void) r255.getBase(); + (void) r256.getBase(); (void) r257.getBase(); (void) r258.getBase(); (void) r259.getBase(); + (void) r260.getBase(); (void) r261.getBase(); (void) r262.getBase(); (void) r263.getBase(); + (void) r264.getBase(); (void) r265.getBase(); (void) r266.getBase(); (void) r267.getBase(); + (void) r268.getBase(); (void) r269.getBase(); (void) r270.getBase(); (void) r271.getBase(); + (void) r272.getBase(); (void) r273.getBase(); (void) r274.getBase(); (void) r275.getBase(); + (void) r276.getBase(); (void) r277.getBase(); (void) r278.getBase(); (void) r279.getBase(); + (void) r280.getBase(); (void) r281.getBase(); (void) r282.getBase(); (void) r283.getBase(); + (void) r284.getBase(); (void) r285.getBase(); (void) r286.getBase(); (void) r287.getBase(); + (void) r288.getBase(); (void) r289.getBase(); (void) r290.getBase(); (void) r291.getBase(); + (void) r292.getBase(); (void) r293.getBase(); (void) r294.getBase(); (void) r295.getBase(); + (void) r296.getBase(); (void) r297.getBase(); (void) r298.getBase(); (void) r299.getBase(); + (void) r300.getBase(); (void) r301.getBase(); (void) r302.getBase(); (void) r303.getBase(); + 
(void) r304.getBase(); (void) r305.getBase(); (void) r306.getBase(); (void) r307.getBase(); + (void) r308.getBase(); (void) r309.getBase(); (void) r310.getBase(); (void) r311.getBase(); + (void) r312.getBase(); (void) r313.getBase(); (void) r314.getBase(); (void) r315.getBase(); + (void) r316.getBase(); (void) r317.getBase(); (void) r318.getBase(); (void) r319.getBase(); + (void) r320.getBase(); (void) r321.getBase(); (void) r322.getBase(); (void) r323.getBase(); + (void) r324.getBase(); (void) r325.getBase(); (void) r326.getBase(); (void) r327.getBase(); + (void) r328.getBase(); (void) r329.getBase(); (void) r330.getBase(); (void) r331.getBase(); + (void) r332.getBase(); (void) r333.getBase(); (void) r334.getBase(); (void) r335.getBase(); + (void) r336.getBase(); (void) r337.getBase(); (void) r338.getBase(); (void) r339.getBase(); + (void) r340.getBase(); (void) r341.getBase(); (void) r342.getBase(); (void) r343.getBase(); + (void) r344.getBase(); (void) r345.getBase(); (void) r346.getBase(); (void) r347.getBase(); + (void) r348.getBase(); (void) r349.getBase(); (void) r350.getBase(); (void) r351.getBase(); + (void) r352.getBase(); (void) r353.getBase(); (void) r354.getBase(); (void) r355.getBase(); + (void) r356.getBase(); (void) r357.getBase(); (void) r358.getBase(); (void) r359.getBase(); + (void) r360.getBase(); (void) r361.getBase(); (void) r362.getBase(); (void) r363.getBase(); + (void) r364.getBase(); (void) r365.getBase(); (void) r366.getBase(); (void) r367.getBase(); + (void) r368.getBase(); (void) r369.getBase(); (void) r370.getBase(); (void) r371.getBase(); + (void) r372.getBase(); (void) r373.getBase(); (void) r374.getBase(); (void) r375.getBase(); + (void) r376.getBase(); (void) r377.getBase(); (void) r378.getBase(); (void) r379.getBase(); + (void) r380.getBase(); (void) r381.getBase(); (void) r382.getBase(); (void) r383.getBase(); + (void) r384.getBase(); (void) r385.getBase(); (void) r386.getBase(); (void) r387.getBase(); + (void) r388.getBase(); 
(void) r389.getBase(); (void) r390.getBase(); (void) r391.getBase(); + (void) r392.getBase(); (void) r393.getBase(); (void) r394.getBase(); (void) r395.getBase(); + (void) r396.getBase(); (void) r397.getBase(); (void) r398.getBase(); (void) r399.getBase(); + (void) r400.getBase(); (void) r401.getBase(); (void) r402.getBase(); (void) r403.getBase(); + (void) r404.getBase(); (void) r405.getBase(); (void) r406.getBase(); (void) r407.getBase(); + (void) r408.getBase(); (void) r409.getBase(); (void) r410.getBase(); (void) r411.getBase(); + (void) r412.getBase(); (void) r413.getBase(); (void) r414.getBase(); (void) r415.getBase(); + (void) r416.getBase(); (void) r417.getBase(); (void) r418.getBase(); (void) r419.getBase(); + (void) r420.getBase(); (void) r421.getBase(); (void) r422.getBase(); (void) r423.getBase(); + (void) r424.getBase(); (void) r425.getBase(); (void) r426.getBase(); (void) r427.getBase(); + (void) r428.getBase(); (void) r429.getBase(); (void) r430.getBase(); (void) r431.getBase(); + (void) r432.getBase(); (void) r433.getBase(); (void) r434.getBase(); (void) r435.getBase(); + (void) r436.getBase(); (void) r437.getBase(); (void) r438.getBase(); (void) r439.getBase(); + (void) r440.getBase(); (void) r441.getBase(); (void) r442.getBase(); (void) r443.getBase(); + (void) r444.getBase(); (void) r445.getBase(); (void) r446.getBase(); (void) r447.getBase(); + (void) r448.getBase(); (void) r449.getBase(); (void) r450.getBase(); (void) r451.getBase(); + (void) r452.getBase(); (void) r453.getBase(); (void) r454.getBase(); (void) r455.getBase(); + (void) r456.getBase(); (void) r457.getBase(); (void) r458.getBase(); (void) r459.getBase(); + (void) r460.getBase(); (void) r461.getBase(); (void) r462.getBase(); (void) r463.getBase(); + (void) r464.getBase(); (void) r465.getBase(); (void) r466.getBase(); (void) r467.getBase(); + (void) r468.getBase(); (void) r469.getBase(); (void) r470.getBase(); (void) r471.getBase(); + (void) r472.getBase(); (void) r473.getBase(); 
(void) r474.getBase(); (void) r475.getBase(); + (void) r476.getBase(); (void) r477.getBase(); (void) r478.getBase(); (void) r479.getBase(); + (void) r480.getBase(); (void) r481.getBase(); (void) r482.getBase(); (void) r483.getBase(); + (void) r484.getBase(); (void) r485.getBase(); (void) r486.getBase(); (void) r487.getBase(); + (void) r488.getBase(); (void) r489.getBase(); (void) r490.getBase(); (void) r491.getBase(); + (void) r492.getBase(); (void) r493.getBase(); (void) r494.getBase(); (void) r495.getBase(); + (void) r496.getBase(); (void) r497.getBase(); (void) r498.getBase(); (void) r499.getBase(); + (void) r500.getBase(); (void) r501.getBase(); (void) r502.getBase(); (void) r503.getBase(); + (void) r504.getBase(); (void) r505.getBase(); (void) r506.getBase(); (void) r507.getBase(); + (void) r508.getBase(); (void) r509.getBase(); (void) r510.getBase(); (void) r511.getBase(); (void) null.getBase(); (void) a0.getBase(); @@ -133,6 +197,7 @@ void _workaround_() { (void) NoDDChk.getAll(); (void) AccWrEn.getAll(); (void) NoSrcDepSet.getAll(); + (void) Fwd.getAll(); (void) Breakpoint.getAll(); (void) sat.getAll(); (void) NoMask.getAll(); @@ -185,6 +250,8 @@ void _workaround_() { (void) A64.getModel(); (void) A64NC.getModel(); (void) SLM.getModel(); + (void) A64_A32U.getModel(); + (void) A64_A32S.getModel(); (void) D8.desc; (void) D8T.desc; (void) D16.desc; (void) D16T.desc; (void) D32.desc; (void) D32T.desc; @@ -201,5 +268,6 @@ void _workaround_() { (void) V64.desc; (void) V64T.desc; (void) transpose.desc; (void) vnni.desc; + (void) Overfetch.desc; } diff --git a/third_party/ngen/ngen_core.hpp b/third_party/ngen/ngen_core.hpp index 4979826ef83..5b65ba9b517 100644 --- a/third_party/ngen/ngen_core.hpp +++ b/third_party/ngen/ngen_core.hpp @@ -229,7 +229,15 @@ class invalid_address_mode_exception : public std::runtime_error { }; class invalid_address_modifier_exception : public std::runtime_error { public: - invalid_address_modifier_exception(SourceLocation loc = {}) : 
std::runtime_error("Invalid address offset" + loc.str(" at ")) {} + invalid_address_modifier_exception(SourceLocation loc = {}) : std::runtime_error("Invalid address offset or scaling factor" + loc.str(" at ")) {} +}; +class limited_to_256_grf_exception : public std::runtime_error { +public: + limited_to_256_grf_exception(SourceLocation loc = {}) : std::runtime_error("This instruction only supports r0-r255" + loc.str(" at ")) {} +}; +class r511_not_allowed_exception : public std::runtime_error { +public: + r511_not_allowed_exception(SourceLocation loc = {}) : std::runtime_error("r511 cannot be used here" + loc.str(" at ")) {} }; #endif @@ -249,6 +257,9 @@ enum class Core { Gen12p8 = XeHPC, /* Deprecated -- will be removed in the future */ Xe2, Xe3, + XE3P_35_10, + XE3P_35_11, + XE3P_UNKNOWN, }; typedef Core HW; @@ -275,6 +286,10 @@ enum class ProductFamily : int { BMG, LNL, GenericXe3, + GenericXe3p, + XE3P_35_10, + XE3P_35_11, + XE3P_UNKNOWN, }; enum class PlatformType {Unknown, Integrated, Discrete}; @@ -301,12 +316,15 @@ static inline constexpr14 PlatformType getPlatformType(ProductFamily family) { case ProductFamily::MTL: case ProductFamily::ARL: case ProductFamily::LNL: + case ProductFamily::XE3P_35_10: return PlatformType::Integrated; // Could be integrated or discrete case ProductFamily::GenericXeLP: case ProductFamily::GenericXeHPG: case ProductFamily::GenericXe2: case ProductFamily::GenericXe3: + case ProductFamily::GenericXe3p: + case ProductFamily::XE3P_UNKNOWN: return PlatformType::Unknown; // Guaranteed discrete case ProductFamily::GenericXeHP: @@ -315,6 +333,7 @@ static inline constexpr14 PlatformType getPlatformType(ProductFamily family) { case ProductFamily::PVC: case ProductFamily::PVCVG: case ProductFamily::BMG: + case ProductFamily::XE3P_35_11: return PlatformType::Discrete; case ProductFamily::Unknown: return PlatformType::Unknown; @@ -334,12 +353,19 @@ static inline constexpr14 ProductFamily genericProductFamily(HW hw) case HW::XeHPC: return 
ProductFamily::GenericXeHPC; case HW::Xe2: return ProductFamily::GenericXe2; case HW::Xe3: return ProductFamily::GenericXe3; + case HW::XE3P_35_10: + case HW::XE3P_35_11: + case HW::XE3P_UNKNOWN: return ProductFamily::GenericXe3p; default: return ProductFamily::Unknown; } } static inline constexpr14 Core getCore(ProductFamily family) { + if (family >= ProductFamily::XE3P_UNKNOWN) return Core::XE3P_UNKNOWN; + if (family >= ProductFamily::XE3P_35_11) return Core::XE3P_35_11; + if (family >= ProductFamily::XE3P_35_10) return Core::XE3P_35_10; + if (family >= ProductFamily::GenericXe3p) return Core::XE3P_35_10; if (family >= ProductFamily::GenericXe3) return Core::Xe3; if (family >= ProductFamily::GenericXe2) return Core::Xe2; if (family >= ProductFamily::GenericXeHPC) return Core::XeHPC; @@ -390,14 +416,16 @@ enum class DataType : uint8_t { s4 = 0x5D, u2 = 0x3E, s2 = 0x3F, + e2m1 = 0x5A, + e3m0 = 0x5B, invalid = 0x60 }; #ifdef NGEN_ASM static inline std::ostream &operator<<(std::ostream &str, DataType type) { - static const char *names[32] = {"ud", "d", "uw", "w", "ub", "b", "df", "f", "uq", "q", "hf", "bf", "bf8", "uv", "v", "vf", - "tf32", "hf8", "", "", "", "", "", "", "", "", "", "", "u4", "s4", "u2", "s2"}; + static const char *names[32] = {"ud", "d", "uw", "w", "ub", "b", "df", "f", "uq", "q", "hf", "bf", "bf8", "uv", "v", "vf", + "tf32", "hf8", "", "", "", "", "", "", "", "", "e2m1", "e3m0", "u4", "s4", "u2", "s2"}; str << names[static_cast(type) & 0x1F]; return str; } @@ -458,6 +486,13 @@ template <> inline DataType getDataType() { return DataType::u2; } #ifdef NGEN_INT2_TYPE template <> inline DataType getDataType() { return DataType::s2; } #endif +#ifdef NGEN_E2M1_TYPE +template <> inline DataType getDataType() { return DataType::e2m1; } +#endif +#ifdef NGEN_E3M0_TYPE +template <> inline DataType getDataType() { return DataType::e3m0; } +#endif + static inline constexpr14 DataType rawType(DataType dt) { switch (getLog2Bits(dt)) { @@ -485,11 +520,17 @@ enum 
class MathFunction : uint8_t { irem = 0xD, invm = 0xE, rsqtm = 0xF, + tanh = 0x19, + sigm = 0x1A, }; static inline int mathArgCount(HW hw, MathFunction func) { + if (hw >= HW::XE3P_35_10) { + static const char argCounts[16] = {0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 2, 2, 2, 2, 1}; + return argCounts[static_cast(func) & 0xF]; + } static const char argCounts[16] = {0, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 1}; return argCounts[static_cast(func) & 0xF]; } @@ -497,8 +538,9 @@ static inline int mathArgCount(HW hw, MathFunction func) #ifdef NGEN_ASM static inline std::ostream &operator<<(std::ostream &str, MathFunction func) { - static const char *names[16] = {"", "inv", "log", "exp", "sqt", "rsqt", "sin", "cos", "", "fdiv", "pow", "idiv", "iqot", "irem", "invm", "rsqtm"}; - str << names[static_cast(func) & 0xF]; + static const char *names[32] = {"", "inv", "log", "exp", "sqt", "rsqt", "sin", "cos", "", "fdiv", "pow", "idiv", "iqot", "irem", "invm", "rsqtm", + "", "", "", "", "", "", "", "", "", "tanh", "sigm", "", "", "", "", ""}; + str << names[static_cast(func) & 0x1F]; return str; } #endif @@ -530,6 +572,42 @@ static inline std::ostream &operator<<(std::ostream &str, SyncFunction func) #endif +// Rounding types for dnscl. +enum class RoundingType : uint8_t { + rne = 0, + srnd = 1, +}; + +// Shuffle function codes. +enum class ShuffleFunction : uint8_t { + idx4 = 0x6, +}; + +#ifdef NGEN_ASM +static inline std::ostream &operator<<(std::ostream &str, ShuffleFunction func) +{ + static const char *names[16] = {"", "", "", "", "", "", "idx4", "", "", "", "", "", "", "", "", ""}; + str << names[static_cast(func) & 0xF]; + return str; +} +#endif + +// LFSR function codes. 
+enum class LFSRFunction : uint8_t { + b32 = 0, + b16v2 = 1, + b8v4 = 2, +}; + +#ifdef NGEN_ASM +static inline std::ostream &operator<<(std::ostream &str, LFSRFunction func) +{ + static const char *names[4] = {"b32", "b16v2", "b8v4", ""}; + str << names[static_cast(func) & 0x3]; + return str; +} +#endif + // Shared function IDs (SFIDs). enum class SharedFunction : uint8_t { null = 0x0, @@ -748,7 +826,7 @@ class RegData { #endif return static_cast(rf); } - constexpr bool isARF() const { return isValid() && rf == RegFileARF; } + constexpr bool isARF() const { return rf == RegFileARF; } constexpr int getARFBase() const { return base & 0xF; } constexpr ARFType getARFType() const { return static_cast(base >> 4); } constexpr bool isIndirect() const { return indirect; } @@ -1016,6 +1094,8 @@ class Subregister : public RegData Subregister tf32(int offset = 0) const { return reinterpret(offset, DataType::tf32); } Subregister bf8(int offset = 0) const { return reinterpret(offset, DataType::bf8); } Subregister hf8(int offset = 0) const { return reinterpret(offset, DataType::hf8); } + Subregister e2m1(int offset = 0) const { return reinterpret(offset, DataType::e2m1); } + Subregister e3m0(int offset = 0) const { return reinterpret(offset, DataType::e3m0); } }; // Single register. 
@@ -1081,6 +1161,8 @@ class Register : public RegData constexpr14 Subregister tf32(int offset) const { return sub(offset, DataType::tf32); } constexpr14 Subregister bf8(int offset) const { return sub(offset, DataType::bf8); } constexpr14 Subregister hf8(int offset) const { return sub(offset, DataType::hf8); } + constexpr14 Subregister e2m1(int offset) const { return sub(offset, DataType::e2m1); } + constexpr14 Subregister e3m0(int offset) const { return sub(offset, DataType::e3m0); } constexpr14 Register uq() const { return retype(DataType::uq); } constexpr14 Register q() const { return retype(DataType::q); } @@ -1101,6 +1183,8 @@ class Register : public RegData constexpr14 Register tf32() const { return retype(DataType::tf32); } constexpr14 Register bf8() const { return retype(DataType::bf8); } constexpr14 Register hf8() const { return retype(DataType::hf8); } + constexpr14 Register e2m1() const { return retype(DataType::e2m1); } + constexpr14 Register e3m0() const { return retype(DataType::e3m0); } constexpr14 Subregister operator[](int offset) const { return sub(offset, getType()); } @@ -1143,6 +1227,8 @@ class GRF : public Register constexpr14 Subregister tf32(int offset) const { return sub(offset, DataType::tf32); } constexpr14 Subregister bf8(int offset) const { return sub(offset, DataType::bf8); } constexpr14 Subregister hf8(int offset) const { return sub(offset, DataType::hf8); } + constexpr14 Subregister e2m1(int offset) const { return sub(offset, DataType::e2m1); } + constexpr14 Subregister e3m0(int offset) const { return sub(offset, DataType::e3m0); } constexpr14 GRF uq() const { return retype(DataType::uq); } constexpr14 GRF q() const { return retype(DataType::q); } @@ -1163,12 +1249,15 @@ class GRF : public Register constexpr14 GRF tf32() const { return retype(DataType::tf32); } constexpr14 GRF bf8() const { return retype(DataType::bf8); } constexpr14 GRF hf8() const { return retype(DataType::hf8); } + constexpr14 GRF e2m1() const { return 
retype(DataType::e2m1); } + constexpr14 GRF e3m0() const { return retype(DataType::e3m0); } Align16Operand swizzle(int s0, int s1, int s2, int s3) const { return Align16Operand(*this, s0, s1, s2, s3); } Align16Operand enable(bool c0, bool c1, bool c2, bool c3) const { return Align16Operand(*this, (int(c3) << 3) | (int(c2) << 2) | (int(c1) << 1) | int(c0)); } Align16Operand noSwizzle() const { return swizzle(0, 1, 2, 3); } Align16Operand enableAll() const { return enable(true, true, true, true); } + GRF &operator=(const Invalid &i) { this->invalidate(); return *this; } GRF &operator+=(const int &inc) { @@ -1198,14 +1287,16 @@ class GRF : public Register inline GRFDisp operator+(Offset2D offset) const; inline GRFDisp operator-(Offset2D offset) const; + inline GRFDisp operator*(int scale) const; static constexpr int log2Bytes(HW hw) { return (hw >= HW::XeHPC) ? 6 : 5; } static constexpr int bytes(HW hw) { return (1 << log2Bytes(hw)); } static constexpr int bytesToGRFs(HW hw, unsigned x) { return (x + bytes(hw) - 1) >> log2Bytes(hw); } - static constexpr int maxRegs() { return 256; } + static constexpr int maxRegs() { return 512; } }; + class ARF : public Register { public: @@ -1274,6 +1365,7 @@ constexpr14 RegData RegData::getIndirectReg() const { return ARF(type, 0)[getIndirectOff()]; } + // An "extended register" is a combination of a regular GRF and some extra accumulator bits, used for math macro operations. 
class ExtendedReg { RegData base; @@ -1431,6 +1523,7 @@ class FlowControlRegister : public ARF explicit constexpr FlowControlRegister(int reg_ = 0) : ARF(ARFType::fc, reg_, DataType::ud) {} }; + class Offset2D { public: int16_t x, y; @@ -1443,6 +1536,8 @@ class GRFDisp { protected: GRF base; int32_t disp; + uint16_t scale = 0; + int16_t ind0SubReg = -1; public: GRFDisp(const GRF &base_, int32_t disp_) : base(base_), disp(disp_) {} @@ -1452,6 +1547,10 @@ class GRFDisp { case RegFileGRF: base = reinterpret_cast(rd); return; case RegFileARF: if (rd.getARFType() == ARFType::null) return; + if (rd.getARFType() == ARFType::s) { + ind0SubReg = rd.getByteOffset(); + return; + } break; default: break; } @@ -1462,6 +1561,10 @@ class GRFDisp { GRFDisp(const GRF &base_, Offset2D offset) : base(base_), disp((uint32_t(uint16_t(offset.y)) << 16) | uint16_t(offset.x)) {} + GRFDisp(const GRF &base_, int32_t disp_, int scale_, int ind0SubReg_ = -1) : base(base_), disp(disp_), scale(scale_), ind0SubReg(ind0SubReg_) {} + GRFDisp(const GRF &base_, int32_t disp_, int scale_, ScalarRegister ind0) : base(base_), disp(disp_), scale(scale_), ind0SubReg(ind0.getByteOffset()) {} + + constexpr GRF getBase() const { return base; } constexpr int32_t getDisp() const { return disp; } @@ -1470,8 +1573,17 @@ class GRFDisp { void clearDisp() { disp = 0; } - GRFDisp operator+(int offset) const { return GRFDisp(base, disp + offset); } - GRFDisp operator-(int offset) const { return GRFDisp(base, disp - offset); } + constexpr int getScale() const { return scale; } + + RegData getInd0() const { + if (ind0SubReg >= 0) + return ScalarRegister(0)[ind0SubReg]; + else + return NullRegister(); + } + + GRFDisp operator+(int offset) const { return GRFDisp(base, disp + offset, scale, ind0SubReg); } + GRFDisp operator-(int offset) const { return GRFDisp(base, disp - offset, scale, ind0SubReg); } }; GRFDisp GRF::operator+(int offset) const { return GRFDisp(*this, offset); } @@ -1480,6 +1592,18 @@ GRFDisp 
GRF::operator-(int offset) const { return *this + (-offset); } GRFDisp GRF::operator+(Offset2D offset) const { return GRFDisp(*this, offset); } GRFDisp GRF::operator-(Offset2D offset) const { return *this + (-offset); } +GRFDisp GRF::operator*(int scale) const { return GRFDisp(*this, 0, scale); } + +inline GRFDisp operator+(ScalarRegister s, GRF base) { + return GRFDisp(base, 0, 0, s); +} +inline GRFDisp operator+(ScalarRegister s, GRFDisp addr) { + return GRFDisp(addr.getBase(), addr.getDisp(), addr.getScale(), s); +} +inline GRFDisp operator+(GRF base, ScalarRegister s) { return s + base; } +inline GRFDisp operator+(GRFDisp addr, ScalarRegister s) { return s + addr; } + + GRFDisp Subregister::operator+(int offset) const { #ifdef NGEN_SAFE @@ -1583,10 +1707,10 @@ class IndirectRegisterFrame { // GRFRange represents a contiguous range of GRF registers. class GRFRange { protected: - uint8_t base; - uint8_t len; + uint16_t base; + uint16_t len; - static constexpr uint8_t invalidLen = 0xFF; + static constexpr uint16_t invalidLen = 0xFFFF; public: GRFRange() : GRFRange(0, invalidLen) {} @@ -1727,6 +1851,7 @@ enum class ThreadCtrl { NoPreempt = 3 }; + enum class Opcode { illegal = 0x00, sync = 0x01, @@ -1770,7 +1895,12 @@ enum class Opcode { sendc = 0x32, sends = 0x33, sendsc = 0x34, + sendg = 0x33, + sendgc = 0x34, + sendgx = 0x35, + sendgxc = 0x36, math = 0x38, + lfsr = 0x39, add = 0x40, mul = 0x41, avg = 0x42, @@ -1788,12 +1918,14 @@ enum class Opcode { addc = 0x4E, subb = 0x4F, sad2 = 0x50, + shfl = 0x50, sada2 = 0x51, add3 = 0x52, macl = 0x53, srnd = 0x54, dp4 = 0x54, dph = 0x55, + dnscl = 0x55, dp3 = 0x56, dp2 = 0x57, dp4a = 0x58, @@ -1803,7 +1935,9 @@ enum class Opcode { dpasw = 0x5A, mad = 0x5B, lrp = 0x5C, + bdpas = 0x5C, madm = 0x5D, + mullh = 0x5F, nop_gen12 = 0x60, mov_gen12 = 0x61, sel_gen12 = 0x62, @@ -1830,13 +1964,16 @@ enum class Opcode { directive = 0x7F, /* not a valid opcode; used internally by nGEN */ }; -enum class Operand {dst = 0, src0 = 1, src1 
= 2, src2 = 3}; + +enum class Operand {dst = 0, src0 = 1, src1 = 2, src2 = 3, src3 = 4, src4 = 5}; enum class Directive { ignoredep_dst = 0, ignoredep_src0 = 1, ignoredep_src1 = 2, ignoredep_src2 = 3, + ignoredep_src3 = 4, + ignoredep_src4 = 5, subdep_dst = 8, wrdep = 0x10, fencedep = 0x11, @@ -1850,6 +1987,8 @@ static inline bool isSend(Opcode op) case Opcode::sendc: case Opcode::sends: case Opcode::sendsc: + case Opcode::sendgx: + case Opcode::sendgxc: return true; default: return false; @@ -1865,6 +2004,8 @@ static inline bool trackedByToken(HW hw, Opcode op, unsigned dstTypecode) case Opcode::dpas: case Opcode::dpasw: return true; + case Opcode::bdpas: + return (hw >= HW::XE3P_35_10); default: if (isSend(op)) return true; if (hw == HW::XeHPG && dstTypecode == 0b1011 /* :df */) return true; @@ -1882,6 +2023,7 @@ static inline bool isDirective(Opcode op) return (op == Opcode::directive); } + class AllPipes {}; enum class Pipe : uint8_t { Default = 0, @@ -2054,6 +2196,7 @@ class InstructionModifier { constexpr ConditionModifier getCMod() const { return static_cast(parts.cmod); } constexpr bool isAccWrEn() const { return parts.accWrCtrl; } constexpr bool getBranchCtrl() const { return parts.accWrCtrl; } + constexpr bool isFwd() const { return parts.accWrCtrl; } constexpr bool isCompact() const { return parts.cmptCtrl; } constexpr bool isBreakpoint() const { return parts.debugCtrl; } constexpr bool isSaturate() const { return parts.saturate; } @@ -2091,6 +2234,7 @@ class InstructionModifier { constexpr /* implicit */ InstructionModifier(ConditionModifier cmod_) : all{static_cast(cmod_) << 24} {} + constexpr14 /* implicit */ InstructionModifier(int execSize_) : InstructionModifier() { setExecSize(execSize_); } @@ -2126,6 +2270,9 @@ class InstructionModifier { static constexpr InstructionModifier createAccWrCtrl() { return InstructionModifier(false, false, false, 0, true, false, false, false, false, false, false, false); } + static constexpr InstructionModifier 
createFwd() { + return createAccWrCtrl(); + } static constexpr InstructionModifier createDebugCtrl() { return InstructionModifier(false, false, false, 0, false, true, false, false, false, false, false, false); } @@ -2591,7 +2738,7 @@ union MessageDescriptor { unsigned cache : 4; unsigned : 9; unsigned model : 2; - unsigned : 1; + unsigned overfetch : 1; /* storage location only, not supported in HW */ } standardLSC; struct { unsigned : 12; @@ -2645,6 +2792,8 @@ union ExtendedMessageDescriptor { ExtendedMessageDescriptor& operator=(SharedFunction sfid_) { parts.sfid = static_cast(sfid_); return *this; } }; +union SendgMessageDescriptor; + enum class AtomicOp : uint16_t { cmpwr_2w = 0x00, and_ = 0x1801, @@ -2673,6 +2822,11 @@ enum class AtomicOp : uint16_t { store = mov, cmpxchg = cmpwr, fcmpxchg = fcmpwr, + bfadd = 0x21FF, + bfsub = 0x22FF, + bfmin = 0x23FF, + bfmax = 0x24FF, + bfcmpxchg = 0x25FF, }; static inline int operandCount(AtomicOp op) { @@ -2685,6 +2839,7 @@ static inline int operandCount(AtomicOp op) { case AtomicOp::cmpwr_2w: case AtomicOp::cmpwr: case AtomicOp::fcmpwr: + case AtomicOp::bfcmpxchg: return 3; default: return 2; @@ -2710,6 +2865,8 @@ enum AddressModel : uint8_t { ModelScratch = 0x40, ModelSS = 0x80, ModelBSS = 0x81, + ModelA64A32U = 0xA4, + ModelA64A32S = 0xB4, }; class AddressBase { @@ -2739,6 +2896,12 @@ class AddressBase { static constexpr AddressBase createA64(bool coherent) { return AddressBase(coherent ? 
0xFF : 0xFD, ModelA64); } + static constexpr AddressBase createA64A32U() { + return AddressBase(0, ModelA64A32U); + } + static constexpr AddressBase createA64A32S() { + return AddressBase(0, ModelA64A32S); + } static constexpr AddressBase createSLM() { return AddressBase(0xFE, ModelSLM); } @@ -2776,6 +2939,11 @@ class AddressBase { class hdc_base { public: + template inline void getDescriptor(HW hw, int esize, SharedFunction &sfid, AddressBase base, SendgMessageDescriptor &desc, int &addrLen, int &dataLen, const GRFDisp &addr) const { +#ifdef NGEN_SAFE + throw unsupported_message(); +#endif + } protected: void hwCheck(HW hw) const { #ifdef NGEN_SAFE @@ -2943,9 +3111,14 @@ class scattered_atomic : public hdc_base { public: void applyAtomicOp(AtomicOp op, const RegData &dst, MessageDescriptor &desc) const { +#ifdef NGEN_SAFE + if ((static_cast(op) & 0xFF) == 0xFF) + throw unsupported_message(); +#endif desc.atomic.returnData = !dst.isNull(); desc.atomic.atomicOp = static_cast(op) & 0xF; } + inline void applyAtomicOp(AtomicOp op, SendgMessageDescriptor &desc) const {} }; class scattered_word : public scattered_atomic { @@ -3204,18 +3377,23 @@ enum class LSCOpcode : uint8_t { ccs_update = 0x1D, rsi = 0x1E, fence = 0x1F, + atomic_bfadd = 0x21, + atomic_bfsub = 0x22, + atomic_bfmin = 0x23, + atomic_bfmax = 0x24, + atomic_bfcmpxchg = 0x25, }; enum class DataSizeLSC : uint16_t { - D8 = 0x0100, - D16 = 0x0201, - D32 = 0x0402, - D64 = 0x0803, - D8U32 = 0x0404, - D16U32 = 0x0405, + D8 = 0x0800, + D16 = 0x1001, + D32 = 0x2002, + D64 = 0x4003, + D8U32 = 0x2004, + D16U32 = 0x2005, }; -static inline constexpr unsigned getRegisterWidth(DataSizeLSC dsize) { +static inline constexpr unsigned getBitWidth(DataSizeLSC dsize) { return static_cast(dsize) >> 8; } @@ -3230,6 +3408,20 @@ enum class CacheSettingsLSC : uint8_t { L1IAR_L3C = 14, L1WB_L3WB = 14, L1UC_L3CC = 5, L1C_L3CC = 9, + L1UC_L2UC_L3UC = 2, + L1UC_L2UC_L3C = 3, L1UC_L2UC_L3WB = 3, + L1UC_L2C_L3UC = 4, L1UC_L2WB_L3UC = 4, + 
L1UC_L2C_L3C = 5, + L1C_L2UC_L3UC = 6, L1WT_L2UC_L3UC = 6, + L1C_L2UC_L3C = 7, L1WT_L2UC_L3WB = 7, + L1C_L2C_L3UC = 8, L1WT_L2WB_L3UC = 8, + L1C_L2C_L3C = 9, + L1S_L2UC_L3UC = 10, + L1S_L2UC_L3C = 11, L1S_L2UC_L3WB = 11, + L1S_L2C_L3UC = 12, L1S_L2WB_L3UC = 12, + L1S_L2C_L3C = 13, L1S_L2WB_L3WB = 13, + L1IAR_L2IAR_L3IAR = 14, L1WB_L2WB_L3UC = 14, + L1WB_L2UC_L3WB = 15, }; enum FenceScopeLSC : uint8_t { @@ -3254,12 +3446,12 @@ enum FlushTypeLSC : uint8_t { struct DataSpecLSC { MessageDescriptor desc; uint16_t vcount = 0; - uint8_t dbytes = 0; + uint8_t dbits = 0; enum { AddrSize16 = 1, AddrSize32 = 2, AddrSize64 = 3 }; enum { AddrFlat = 0, AddrSS = 1, AddrBSS = 2, AddrBTI = 3 }; - explicit constexpr DataSpecLSC(MessageDescriptor desc_, uint8_t vcount_ = 0, uint8_t dbytes_ = 0) : desc(desc_), vcount(vcount_), dbytes(dbytes_) {} + explicit constexpr DataSpecLSC(MessageDescriptor desc_, uint8_t vcount_ = 0, uint8_t dbits_ = 0) : desc(desc_), vcount(vcount_), dbits(dbits_) {} /* implicit */ DataSpecLSC(ChannelMask m) { desc.standardLSC.opcode = static_cast(LSCOpcode::load_cmask); desc.cmask.cmask = static_cast(m) ^ 0xF; @@ -3268,7 +3460,7 @@ struct DataSpecLSC { /* implicit */ DataSpecLSC(CacheSettingsLSC s) { desc.standardLSC.cache = static_cast(s); } - /* implicit */ constexpr DataSpecLSC(DataSizeLSC d) : desc((static_cast(d) & 0x7) << 9), dbytes(getRegisterWidth(d)) {} + /* implicit */ constexpr DataSpecLSC(DataSizeLSC d) : desc((static_cast(d) & 0x7) << 9), dbits(getBitWidth(d)) {} DataSpecLSC operator()(int vcount) const { auto vsEncoded = (vcount <= 4) ? 
(vcount - 1) : (utils::log2(vcount) + 1); @@ -3279,10 +3471,12 @@ struct DataSpecLSC { *this = *this | other; return *this; } + uint8_t dbytes() const { return dbits >> 3; } static constexpr DataSpecLSC createV(unsigned vcount, unsigned venc) { return DataSpecLSC{MessageDescriptor(venc << 12), uint8_t(vcount), 0}; } static constexpr DataSpecLSC createTranspose() { return DataSpecLSC{MessageDescriptor(1 << 15)}; } static constexpr DataSpecLSC createVNNI() { return DataSpecLSC{MessageDescriptor(1 << 7)}; } + static constexpr DataSpecLSC createOverfetch() { return DataSpecLSC{MessageDescriptor(1u << 31)}; } template void getDescriptors(HW hw, const InstructionModifier &mod, AddressBase base, MessageDescriptor &desc, ExtendedMessageDescriptor &exdesc, const GRFDisp &addr) const { @@ -3291,6 +3485,7 @@ struct DataSpecLSC { exdesc = (base.getModel() == ModelSLM) ? SharedFunction::slm : SharedFunction::ugm; desc.standardLSC.addrSize = a64 ? AddrSize64 : AddrSize32; + desc.standardLSC.overfetch = false; if (base.getModel() == ModelA32) base = AddressBase::createBTS(0xFF); @@ -3328,11 +3523,11 @@ struct DataSpecLSC { auto vc = std::max(vcount, 1); if (this->desc.standardLSC.transpose && !desc.standardLSC.opcode) { desc.parts.messageLen = 1; - desc.parts.responseLen = GRF::bytesToGRFs(hw, dbytes * vc); + desc.parts.responseLen = GRF::bytesToGRFs(hw, dbytes() * vc); } else { auto effSIMDGRFs = 1 + ((mod.getExecSize()) >> (GRF::log2Bytes(hw) - 1)); desc.parts.messageLen = effSIMDGRFs * (a64 ? 
2 : 1); - desc.parts.responseLen = effSIMDGRFs * vc * (1 + (dbytes >> 3)); + desc.parts.responseLen = effSIMDGRFs * vc * (1 + (dbytes() >> 3)); } if (access == Access::Write) @@ -3344,13 +3539,15 @@ struct DataSpecLSC { desc.standardLSC.opcode = static_cast(op) >> 8; } + template inline void getDescriptor(HW hw, int esize, SharedFunction &sfid, AddressBase base, SendgMessageDescriptor &desc, int &addrLen, int &dataLen, const GRFDisp &addr) const; + inline void applyAtomicOp(AtomicOp op, SendgMessageDescriptor &desc) const; }; static inline DataSpecLSC scattered(const DataSpecLSC &dtype, int vsize = 1) { return dtype(vsize); } static inline DataSpecLSC block(const DataSpecLSC &dtype, int vsize = 1) { return dtype(vsize) | DataSpecLSC::createTranspose(); } inline constexpr DataSpecLSC operator|(const DataSpecLSC &s1, const DataSpecLSC &s2) { - return DataSpecLSC{s1.desc | s2.desc, uint8_t(s1.vcount | s2.vcount), uint8_t(s1.dbytes | s2.dbytes)}; + return DataSpecLSC{s1.desc | s2.desc, uint8_t(s1.vcount | s2.vcount), uint8_t(s1.dbits | s2.dbits)}; } class block_2d : public DataSpecLSC { @@ -3370,6 +3567,7 @@ class block_2d : public DataSpecLSC { base.checkModel(ModelA64); desc = this->desc; + desc.standardLSC.overfetch = false; desc.standardLSC.opcode = static_cast((access == Access::Write) ? 
LSCOpcode::store_2dblock : LSCOpcode::load_2dblock); desc.standardLSC.model = AddrFlat; @@ -3377,7 +3575,7 @@ class block_2d : public DataSpecLSC { auto w = width, h = height; if (this->desc.standardLSC.transpose) std::swap(w, h); desc.parts.messageLen = 1; - desc.parts.responseLen = std::min(count * GRF::bytesToGRFs(hw, utils::roundup_pow2(w) * h * this->dbytes), 31); + desc.parts.responseLen = std::min(count * GRF::bytesToGRFs(hw, utils::roundup_pow2(w) * h * this->dbytes()), 31); exdesc = SharedFunction::ugm; @@ -3385,6 +3583,7 @@ class block_2d : public DataSpecLSC { exdesc.block2D.yOffset = addr.getDispY(); } + template inline void getDescriptor(HW hw, int esize, SharedFunction &sfid, AddressBase base, SendgMessageDescriptor &desc, int &addrLen, int &dataLen, const GRFDisp &addr) const; }; // Generate descriptors for a load operation. @@ -3429,6 +3628,405 @@ static inline void encodeAtomicDescriptors(HW hw, MessageDescriptor &desc, Exten } +/********************************************************************/ +/* New send encoding and decoding. 
*/ +/********************************************************************/ +enum GatewayOpcode { + eot = 0, + bar = 4, + nbar = 5, + save_bar = 8, + restore_bar = 9, + eotr = 10, + restore_btd_stack = 11, + sip_bar = 12, +}; + + +union SendgMessageDescriptor { + uint64_t all; + struct { + uint64_t opcode : 6; + uint64_t : 58; + } common; + struct { + uint64_t : 7; + uint64_t vlen : 3; + uint64_t transpose : 1; + uint64_t dataSize : 3; + uint64_t addrSize : 2; + uint64_t cacheMode : 4; + uint64_t : 1; + uint64_t overfetch : 1; + uint64_t : 22; + uint64_t scale : 2; + uint64_t : 18; + } mem; + struct { + uint64_t : 7; + uint64_t cmask : 4; + uint64_t : 53; + } cmask; + struct { + uint64_t : 22; + int64_t offset : 22; + uint64_t : 20; + } flat; + struct { + uint64_t : 22; + uint64_t ssIdx : 5; + int64_t offset : 17; + uint64_t : 20; + } surface; + struct { + uint64_t : 9; + uint64_t vnni : 1; + uint64_t transpose : 1; + uint64_t : 11; + int64_t xOffset : 12; + int64_t yOffset : 12; + uint64_t : 18; + } block2D; + struct { + uint64_t : 8; + uint64_t flushType : 3; + uint64_t fenceScope : 3; + uint64_t : 50; + } fence; + struct { + uint64_t : 7; + uint64_t activeOnly : 1; + uint64_t legacy : 1; + uint64_t : 55; + } barrier; + struct { + uint64_t : 7; + uint64_t replay : 2; + uint64_t : 55; + } eot; + + constexpr SendgMessageDescriptor() : all(0) {} + explicit constexpr SendgMessageDescriptor(uint64_t all_) : all(all_) {} + + int vectorLength() const { + const int vlDecode[8] = {1, 2, 3, 4, 8, 16, 32, 64}; + return vlDecode[mem.vlen]; + } + + int log2ElementBytesMem() const { + return mem.dataSize & 0x3; + } + + int elementBytesMem() const { + return 1 << (mem.dataSize & 0x3); + } + + int elementBytesReg() const { + const int dsDecode[8] = {1, 2, 4, 8, 4, 4, 0, 0}; + return dsDecode[mem.dataSize]; + } + + + // Return # destination registers if known, and -1 if not. 
+ inline int dstLen(HW hw, int execSize, SharedFunction sfid) const + { + int effSIMDGRFs = 1 + (execSize >> (GRF::log2Bytes(hw) - 1)); + + switch (sfid) { + case SharedFunction::ugm: + case SharedFunction::tgm: + case SharedFunction::slm: + case SharedFunction::urb: + switch (static_cast(common.opcode)) { + case LSCOpcode::load: { + int vc = vectorLength(); + int dbytes = elementBytesReg(); + if (mem.transpose) + return GRF::bytesToGRFs(hw, dbytes * vc); + else + return effSIMDGRFs * vc * (1 + (dbytes >> 3)); + break; + } + case LSCOpcode::load_cmask: { + int vc = utils::popcnt(cmask.cmask); + return effSIMDGRFs * vc; + break; + } + case LSCOpcode::load_2dblock: + return -1; /* cannot determine from descriptor */ + case LSCOpcode::fence: + return 1; + case LSCOpcode::atomic_inc: + case LSCOpcode::atomic_dec: + case LSCOpcode::atomic_load: + case LSCOpcode::atomic_add: + case LSCOpcode::atomic_sub: + case LSCOpcode::atomic_min: + case LSCOpcode::atomic_max: + case LSCOpcode::atomic_umin: + case LSCOpcode::atomic_umax: + case LSCOpcode::atomic_cmpxchg: + case LSCOpcode::atomic_fadd: + case LSCOpcode::atomic_fsub: + case LSCOpcode::atomic_fmin: + case LSCOpcode::atomic_fmax: + case LSCOpcode::atomic_fcmpxchg: + case LSCOpcode::atomic_and: + case LSCOpcode::atomic_or: + case LSCOpcode::atomic_xor: + return effSIMDGRFs * (1 + (elementBytesReg() >> 3)); + default: + return 0; + } + break; + case SharedFunction::gtwy: + switch (static_cast(common.opcode)) { + case GatewayOpcode::sip_bar: + case GatewayOpcode::save_bar: + return 1; + default: + return 0; + } + break; + default: break; + } + + return -1; + } + + inline int src0Len(HW hw, int execSize, SharedFunction sfid) const + { + switch (sfid) { + case SharedFunction::slm: + case SharedFunction::ugm: + case SharedFunction::tgm: + case SharedFunction::urb: + if (static_cast(common.opcode) == LSCOpcode::fence) return 0; + if (sfid == SharedFunction::slm) return 1; + if (mem.addrSize == 0b10) return GRF::bytesToGRFs(hw, 
execSize * 8); + return 1; + case SharedFunction::gtwy: + switch (static_cast(common.opcode)) { + case GatewayOpcode::eot: + case GatewayOpcode::eotr: + case GatewayOpcode::bar: + case GatewayOpcode::sip_bar: + case GatewayOpcode::nbar: + case GatewayOpcode::restore_bar: + return 1; + case GatewayOpcode::save_bar: + return 1; + default: return 0; + } + break; + default: break; + } + return -1; + } + + inline int src1Len(HW hw, int execSize, SharedFunction sfid) const + { + int effSIMDGRFs = 1 + (execSize >> (GRF::log2Bytes(hw) - 1)); + + switch (sfid) { + case SharedFunction::ugm: + case SharedFunction::tgm: + case SharedFunction::slm: + case SharedFunction::urb: + switch (static_cast(common.opcode)) { + case LSCOpcode::store: { + int vc = vectorLength(); + int dbytes = elementBytesReg(); + if (mem.transpose) + return GRF::bytesToGRFs(hw, dbytes * vc); + else + return effSIMDGRFs * vc * (1 + (dbytes >> 3)); + break; + } + case LSCOpcode::store_cmask: { + int vc = utils::popcnt(cmask.cmask); + return effSIMDGRFs * vc; + break; + } + case LSCOpcode::store_2dblock: + return -1; /* cannot determine from descriptor */ + case LSCOpcode::atomic_add: + case LSCOpcode::atomic_sub: + case LSCOpcode::atomic_min: + case LSCOpcode::atomic_max: + case LSCOpcode::atomic_umin: + case LSCOpcode::atomic_umax: + case LSCOpcode::atomic_fadd: + case LSCOpcode::atomic_fsub: + case LSCOpcode::atomic_fmin: + case LSCOpcode::atomic_fmax: + case LSCOpcode::atomic_and: + case LSCOpcode::atomic_or: + case LSCOpcode::atomic_xor: + return effSIMDGRFs * (1 + (elementBytesReg() >> 3)); + case LSCOpcode::atomic_cmpxchg: + case LSCOpcode::atomic_fcmpxchg: + return 2 * effSIMDGRFs * (1 + (elementBytesReg() >> 3)); + default: return 0; + } + break; + case SharedFunction::gtwy: + return 0; + default: break; + } + + return -1; + } +}; + +static_assert(sizeof(SendgMessageDescriptor) == 8, "SendgMessageDescriptor has been padded by compiler"); + +static inline unsigned encodeScaleLSC(int scale) +{ + if 
(scale <= 2) return scale; + if (scale == 4) return 3; +#ifdef NGEN_SAFE + throw invalid_address_modifier_exception(); +#endif + return 0; +} + +template +void DataSpecLSC::getDescriptor(HW hw, int execSize, SharedFunction &sfid, AddressBase base, SendgMessageDescriptor &desc, int &addrLen, int &dataLen, const GRFDisp &addr) const +{ + SharedFunction defaultSFID = SharedFunction::ugm; + + desc.common.opcode = this->desc.standardLSC.opcode; + if (access == Access::Write) + desc.common.opcode |= static_cast(LSCOpcode::store); + desc.cmask.cmask = this->desc.cmask.cmask; /* or vlen + transpose */ + desc.mem.dataSize = this->desc.standardLSC.dataSize; + desc.mem.cacheMode = this->desc.standardLSC.cache; + desc.mem.scale = encodeScaleLSC(addr.getScale() / (std::max(vcount, 1) * desc.elementBytesMem())); + desc.mem.overfetch = this->desc.standardLSC.overfetch; + + bool flat = true; + + auto model = base.getModel(); + if (model == ModelA64) { + auto ind0 = addr.getInd0(); + auto base = addr.getBase(); + if (!ind0.isNull() && base.isValid()) switch (base.getType()) { + case DataType::ud: model = ModelA64A32U; break; + case DataType::d: model = ModelA64A32S; break; + default: break; + } + } + + switch (model) { + case ModelA64: desc.mem.addrSize = 0b10; break; + case ModelA64A32U: desc.mem.addrSize = 0b00; break; + case ModelA64A32S: desc.mem.addrSize = 0b01; break; + case ModelSLM: + defaultSFID = SharedFunction::slm; + desc.mem.addrSize = 0b00; + break; + case ModelSS: + case ModelBSS: + flat = false; + desc.mem.addrSize = 0b11; + desc.surface.ssIdx = base.getIndex(); + break; + default: +#ifdef NGEN_SAFE + throw invalid_model_exception(); +#endif + break; + } + + int offsetShift = desc.log2ElementBytesMem(); + int sdisp = addr.getDisp() >> offsetShift; + + if (flat) { + desc.flat.offset = sdisp; +#ifdef NGEN_SAFE + if ((desc.flat.offset << offsetShift) != addr.getDisp()) + throw invalid_address_modifier_exception(); +#endif + } else { + desc.surface.offset = sdisp; 
+#ifdef NGEN_SAFE + if ((desc.surface.offset << offsetShift) != addr.getDisp()) + throw invalid_address_modifier_exception(); +#endif + } + + auto vc = std::max(vcount, 1); + bool block = this->desc.standardLSC.transpose && this->desc.standardLSC.opcode == static_cast(LSCOpcode::load); + if (block) { + addrLen = 1; + dataLen = GRF::bytesToGRFs(hw, dbytes() * vc); + } else { + auto effSIMDGRFs = 1 + (execSize >> (GRF::log2Bytes(hw) - 1)); + addrLen = effSIMDGRFs * (base.isA64() ? 2 : 1); + dataLen = effSIMDGRFs * vc * (1 + (dbytes() >> 3)); + } + + if (sfid == SharedFunction::automatic) + sfid = defaultSFID; +} + +void DataSpecLSC::applyAtomicOp(AtomicOp op, SendgMessageDescriptor &desc) const +{ + desc.common.opcode = static_cast(op) >> 8; +} + +template +void block_2d::getDescriptor(HW hw, int execSize, SharedFunction &sfid, AddressBase base, SendgMessageDescriptor &desc, int &addrLen, int &dataLen, const GRFDisp &addr) const +{ + auto addrNoDisp = addr; + addrNoDisp.clearDisp(); + + DataSpecLSC::getDescriptor(hw, execSize, sfid, base, desc, addrLen, dataLen, addrNoDisp); + desc.common.opcode = static_cast((access == Access::Write) ? 
LSCOpcode::store_2dblock : LSCOpcode::load_2dblock); + desc.block2D.vnni = this->desc.block2D.vnni; + desc.block2D.xOffset = addr.getDispX(); + desc.block2D.yOffset = addr.getDispY(); + +#ifdef NGEN_SAFE + if (desc.block2D.xOffset != addr.getDispX() || desc.block2D.yOffset != addr.getDispY()) + throw invalid_address_modifier_exception(); +#endif + + auto w = width, h = height; + if (desc.mem.transpose) std::swap(w, h); + + addrLen = 1; + dataLen = std::min(count * GRF::bytesToGRFs(hw, utils::roundup_pow2(w) * h * this->dbytes()), 31); + + if (sfid == SharedFunction::automatic) + sfid = SharedFunction::ugm; +} + +template +static inline void encodeLoadDescriptor(HW hw, SendgMessageDescriptor &desc, SharedFunction &sfid, int &dstLen, int &src0Len, + const InstructionModifier &mod, const DataSpec &spec, AddressBase base, const GRFDisp &addr) +{ + spec.template getDescriptor(hw, mod.getExecSize(), sfid, base, desc, src0Len, dstLen, addr); +} + +template +static inline void encodeStoreDescriptor(HW hw, SendgMessageDescriptor &desc, SharedFunction &sfid, int &src0Len, int &src1Len, + const InstructionModifier &mod, const DataSpec &spec, AddressBase base, const GRFDisp &addr) +{ + spec.template getDescriptor(hw, mod.getExecSize(), sfid, base, desc, src0Len, src1Len, addr); +} + +template +static inline void encodeAtomicDescriptor(HW hw, SendgMessageDescriptor &desc, SharedFunction &sfid, int &src0Len, int &src1Len, + AtomicOp op, const InstructionModifier &mod, const DataSpec &spec, AddressBase base, const GRFDisp &addr) +{ + spec.template getDescriptor(hw, mod.getExecSize(), sfid, base, desc, src0Len, src1Len, addr); + spec.applyAtomicOp(op, desc); +} + + + } /* namespace NGEN_NAMESPACE */ #if defined(__clang__) diff --git a/third_party/ngen/ngen_debuginfo.hpp b/third_party/ngen/ngen_debuginfo.hpp index 561b69eac88..a92621cbee6 100644 --- a/third_party/ngen/ngen_debuginfo.hpp +++ b/third_party/ngen/ngen_debuginfo.hpp @@ -17,6 +17,7 @@ #ifndef NGEN_DEBUGINFO_HPP #define 
NGEN_DEBUGINFO_HPP +#include #include #include #include diff --git a/third_party/ngen/ngen_decoder.hpp b/third_party/ngen/ngen_decoder.hpp index e3cdc0054cd..02b2dcd8844 100644 --- a/third_party/ngen/ngen_decoder.hpp +++ b/third_party/ngen/ngen_decoder.hpp @@ -80,6 +80,8 @@ bool Decoder::getOperandRegion(autoswsb::DependencyRegion ®ion, int opNum) co { checkCompaction(); region.hw = hw; + if (hw >= HW::XE3P_35_10) + return get().getOperandRegion(region, opNum); if (hw >= HW::XeHPC) return get().getOperandRegion(region, opNum); if (hw >= HW::Gen12LP) diff --git a/third_party/ngen/ngen_elf.hpp b/third_party/ngen/ngen_elf.hpp index 19329c09129..dd7cb1d4fc3 100644 --- a/third_party/ngen/ngen_elf.hpp +++ b/third_party/ngen/ngen_elf.hpp @@ -69,6 +69,8 @@ class ELFCodeGenerator : public BinaryCodeGenerator void requireWalkOrder(int o1, int o2, int o3) { interface_.requireWalkOrder(o1, o2, o3); } void requireWorkgroup(size_t x, size_t y = 1, size_t z = 1) { interface_.requireWorkgroup(x, y, z); } + void setEfficient64Bit(bool def = true) { BinaryCodeGenerator::setEfficient64Bit(def); interface_.setEfficient64Bit(def); } + void finalizeInterface() { interface_.finalize(); } template @@ -583,7 +585,8 @@ template NGEN_NAMESPACE::Subregister getGroupID(Targs&&... a void prologue() { NGEN_NAMESPACE::ELFCodeGenerator::prologue(); } \ void epilogue(const NGEN_NAMESPACE::RegData &r0_info = NGEN_NAMESPACE::RegData()) { NGEN_NAMESPACE::ELFCodeGenerator::epilogue(r0_info); } -#define NGEN_FORWARD_SCOPE_ELF_EXTRA(scope) +#define NGEN_FORWARD_SCOPE_ELF_EXTRA(scope) \ +template void setEfficient64Bit(Targs&&... 
args) { scope::setEfficient64Bit(std::forward(args)...); } #define NGEN_FORWARD_SCOPE_ELF_EXTRA2(scope) diff --git a/third_party/ngen/ngen_emulation.hpp b/third_party/ngen/ngen_emulation.hpp index 269ab1bd195..242d78bbe85 100644 --- a/third_party/ngen/ngen_emulation.hpp +++ b/third_party/ngen/ngen_emulation.hpp @@ -53,6 +53,8 @@ struct EmulationStrategy { else emulate64_mul = emulate64_logic = true; } + if (hw_ >= HW::XE3P_35_10) + emulateDWxDW = emulate64_mul = false; emulate64_mul |= emulate64; } }; @@ -663,6 +665,13 @@ struct EmulationImplementation { g.mov(mod, dstHi, dstLo, loc); g.mov(mod, dstLo, acc, loc); + } else if (dstQ && s0D && ((s1W && !s1Immed) && !emulateDWxDW)) { + RegData dstLo, dstHi; + splitToDW(dst, dstLo, dstHi); + if(dstLo.getBase() == src0.getBase() && src0.getOffset() == dstLo.getOffset()) + stub(); + g.mov(mod, dstLo, src1); + g.mul(mod, dst, src0, dstLo); } else if (dstD && s0D && s1D && strategy.emulateDWxDW) { int ne1 = GRF::bytes(g.getHardware()) >> 2; diff --git a/third_party/ngen/ngen_gen12.hpp b/third_party/ngen/ngen_gen12.hpp index 3741b6c43f7..8591d184acb 100644 --- a/third_party/ngen/ngen_gen12.hpp +++ b/third_party/ngen/ngen_gen12.hpp @@ -33,6 +33,11 @@ template struct EncodingTag12Dispatch { using tag = EncodingTag12; template <> struct EncodingTag12Dispatch { using tag = EncodingTagXeHPC; }; template <> struct EncodingTag12Dispatch { using tag = EncodingTagXeHPC; }; template <> struct EncodingTag12Dispatch { using tag = EncodingTagXeHPC; }; +struct EncodingTagXe3p : public EncodingTagXeHPC {}; + +template <> struct EncodingTag12Dispatch { using tag = EncodingTagXe3p; }; +template <> struct EncodingTag12Dispatch { using tag = EncodingTagXe3p; }; +template <> struct EncodingTag12Dispatch { using tag = EncodingTagXe3p; }; class SWSBInfo12 { @@ -419,6 +424,33 @@ struct Instruction12 { unsigned : 32; unsigned : 32; } sendIndirect; + struct { + unsigned : 32; + // + unsigned eot : 1; + unsigned dstRegFile : 1; + unsigned ind0Present 
: 1; + unsigned ind1Present : 1; + unsigned desc40_41 : 2; + unsigned ind1_desc42_46 : 5; + unsigned ind0 : 3; + unsigned : 2; + unsigned desc32_39 : 8; + unsigned dstReg : 8; + // + unsigned desc30_31 : 2; + unsigned src0RegFile : 1; + unsigned src0Len : 5; + unsigned src0Reg : 8; + unsigned desc16_27 : 12; + unsigned sfid : 4; + // + unsigned desc28_29 : 2; + unsigned src1RegFile : 1; + unsigned src1Len : 5; + unsigned src1Reg : 8; + unsigned desc0_15 : 16; + } sendg; struct { unsigned : 32; // common unsigned : 1; @@ -427,6 +459,112 @@ struct Instruction12 { int32_t uip; int32_t jip; } branches; + struct { + unsigned : 32; + // + unsigned : 12; + unsigned dstReg8 : 1; + unsigned : 19; + // + unsigned : 32; + unsigned : 32; + } unaryXe3pImm; + struct { + unsigned : 32; + // + unsigned : 11; + unsigned src0Reg8 : 1; + unsigned : 2; + unsigned dstReg8 : 1; + unsigned : 17; + // + unsigned : 32; + unsigned : 32; + } binaryXe3pImm; + struct { + unsigned : 32; + // + unsigned : 7; + unsigned execType : 1; + unsigned : 24; + // + unsigned : 32; + // + unsigned dstReg8 : 1; + unsigned src0Reg8 : 1; + unsigned : 18; + unsigned src1Scalar : 1; + unsigned : 5; + unsigned src1Reg8 : 1; + unsigned : 5; + } binaryXe3p; + struct { + unsigned : 32; + // + unsigned : 17; + unsigned src1Scalar : 1; + unsigned : 14; + // + unsigned : 19; + unsigned src1Reg8 : 1; + unsigned : 7; + unsigned src2Reg8 : 1; + unsigned : 4; + // + unsigned dstReg8 : 1; + unsigned src0Reg8 : 1; + unsigned : 30; + } ternaryXe3p; + struct { + unsigned : 32; + // + unsigned : 12; + unsigned dstReg8 : 1; + unsigned src0Reg8 : 1; + unsigned : 18; + // + unsigned : 32; + unsigned : 32; + } branchXe3p; + struct { + unsigned : 32; + // + unsigned : 1; + unsigned dstReg8 : 1; + unsigned : 30; + // + unsigned : 2; + unsigned src0Reg8 : 1; + unsigned : 29; + // + unsigned : 2; + unsigned src1Reg8 : 1; + unsigned : 29; + } sendgx; + struct { + unsigned : 32; + // + unsigned : 2; + unsigned src4RegFile : 1; + 
unsigned src3RegFile : 1; + unsigned : 7; + unsigned src4Reg0_3 : 4; + unsigned src3Reg0 : 1; + unsigned : 16; + // + unsigned src3Reg1_2 : 2; + unsigned : 26; + unsigned src3Reg3_6 : 4; + // + unsigned : 3; + unsigned src4SubReg3_5 : 3; + unsigned src3SubReg4_5 : 2; + unsigned : 8; + unsigned src3Reg7_8 : 2; + unsigned : 1; + unsigned src4Reg4_8 : 5; + unsigned : 8; + } bdpas; uint64_t qword[2]; }; @@ -459,6 +597,7 @@ struct Instruction12 { inline bool getSendDesc(MessageDescriptor &desc) const; inline bool getARFType(ARFType &arfType, int opNum, HW hw) const; inline int getFencedepJIP() const; + inline SendgMessageDescriptor getSendgDesc() const; bool isMathMacro() const { if (opcode() != Opcode::math) return false; @@ -486,30 +625,45 @@ struct InstructionXeHPC : public Instruction12 { return Instruction12::getCModDepRegion(region); } + bool isSendg() const { + return (opcode() == Opcode::sendg || opcode() == Opcode::sendgc || opcode() == Opcode::sendgx || opcode() == Opcode::sendgxc); + } + bool eot() const { + if (isSendg()) return sendg.eot; return Instruction12::eot(); } bool atomic() const { + if (isSendg()) return false; /* no atomic field */ return Instruction12::atomic(); } }; static_assert(sizeof(InstructionXeHPC) == 16, "Internal error: InstructionXeHPC has been padded by the compiler."); +struct InstructionXe3p : public InstructionXeHPC { + template + bool getOperandRegion(autoswsb::DependencyRegion ®ion, int opNum) const { + return Instruction12::getOperandRegion(region, opNum); + } +}; + +static_assert(sizeof(InstructionXe3p) == 16, "Internal error: InstructionXe3p has been padded by the compiler."); + // Encoding routines. 
static inline unsigned getTypecode12(DataType type) { static const uint8_t conversionTable[32] = {2,6,1,5,0,4,11,10,3,7,9,13,8,0,4,8, - 14,12,2,2,2,2,2,2,2,2,2,2,0,4,0,4}; + 14,12,2,2,2,2,2,2,2,2,8,8,0,4,0,4}; return conversionTable[static_cast(type) & 0x1F]; } static inline unsigned encodeSubBytePrecision12(DataType type) { static const uint8_t conversionTable[32] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, - 0,0,0,0,0,0,0,0,0,0,0,0,1,1,2,2}; + 0,0,0,0,0,0,0,0,0,0,2,1,1,1,2,2}; return conversionTable[static_cast(type) & 0x1F]; } @@ -586,6 +740,42 @@ static inline constexpr14 BinaryOperand12 encodeBinaryOperand12(const RegData &r return op; } +template +static inline constexpr14 BinaryOperand12 encodeBinaryOperand12(const RegData &rd, EncodingTagXe3p tag) +{ + BinaryOperand12 op{0}; + +#ifdef NGEN_SAFE + if (rd.isInvalid()) throw invalid_object_exception(); +#endif + + if (rd.isIndirect()) { + op.indirect.addrOff = (rd.getOffset() >> 1); + op.indirect.addrReg = rd.getIndirectOff(); + op.indirect.addrMode = 1; + if (srcN == 0) { + op.indirect.vs = (rd.isVxIndirect()) ? 
0xFFFF : pow2Encode(rd.getVS()); + op.indirectXeHPC.addrOff0 = (rd.getOffset() & 1); + } + } else { + op.direct.regFile = rd.getRegFile8(); + op.direct.subRegNum = (rd.getByteOffset() >> 1); + op.direct.regNum = rd.getBase(); + op.direct.addrMode = 0; + if (srcN == 0) + op.directXeHPC.vs = pow2Encode(rd.getVS()); + if (srcN >= 0) + op.directXeHPC.subRegNum0 = rd.getByteOffset() & 1; + } + + if (encodeHS && srcN <= 0) + op.direct.hs = pow2Encode(rd.getHS()); + + if (srcN == 0) op.direct.width = utils::log2(rd.getWidth()); + + return op; +} + template static inline constexpr14 BinaryOperand12 encodeBinaryOperand12(const ExtendedReg ®, Tag tag) { @@ -644,6 +834,57 @@ static inline constexpr14 TernaryOperand12 encodeTernaryOperand12(const Extended return op; } +template +static inline void encodeTernary512GRF(Instruction12 &i, unsigned dstReg8, unsigned src0Reg8, unsigned src1Reg8, unsigned src2Reg8, bool src1Scalar, Tag tag) {} + +static inline void encodeTernary512GRF(Instruction12 &i, unsigned dstReg8, unsigned src0Reg8, unsigned src1Reg8, unsigned src2Reg8, bool src1Scalar, EncodingTagXe3p tag) +{ + i.ternaryXe3p.dstReg8 = dstReg8; + i.ternaryXe3p.src0Reg8 = src0Reg8; + i.ternaryXe3p.src1Reg8 = src1Reg8; + i.ternaryXe3p.src2Reg8 = src2Reg8; + i.ternaryXe3p.src1Scalar = src1Scalar; +} + +template +static inline void encodeTernary512GRF(Instruction12 &i, D dst, S0 src0, S1 src1, S2 src2, Tag tag) +{ + encodeTernary512GRF(i, getHighBit(dst), getHighBit(src0), getHighBit(src1), getHighBit(src2), checkSrc1Scalar(i.opcode(), src1, dst, tag), tag); +} + +template +static inline bool checkSrc1Scalar(Opcode op, RegData r, RegData dst, Tag tag) { return false; } + +static inline bool checkSrc1Scalar(Opcode op, RegData r, RegData dst, EncodingTagXe3p tag) +{ + switch (op) { + case Opcode::bdpas: + case Opcode::dpas: + case Opcode::shfl: return false; + default: break; + } + if (r.isScalar()) + return true; + +#ifdef NGEN_SAFE + if (!r.isARF()) { + int reqHS = dst.getHS() * 
dst.getBytes() / r.getBytes(); + bool flat = (r.getHS() == reqHS && r.getVS() == r.getWidth() * reqHS) + || (r.getHS() == 0 && r.getWidth() == 1 && r.getVS() == reqHS); + if (!flat) + throw invalid_region_exception(); + } +#endif + + return false; +} + +template +static inline bool checkSrc1Scalar(Opcode op, ExtendedReg r, ExtendedReg dst, Tag tag) +{ + return checkSrc1Scalar(op, r.getBase(), dst.getBase(), tag); +} + static inline void encodeCommon12(Instruction12 &i, Opcode opcode, const InstructionModifier &mod, const RegData &dst, EncodingTag12 tag) { Instruction12 i2; /* separate variable to avoid gcc13 bug */ @@ -678,6 +919,8 @@ static inline void encodeCommon12(Instruction12 &i, Opcode opcode, const Instruc i2.common.maskCtrl = mod.parts.maskCtrl; i2.common.atomicCtrl = mod.parts.threadCtrl; i2.commonXeHPC.dstExt = (dst.isIndirect() ? dst.getOffset() : dst.getByteOffset()) & 1; + if (opcode == Opcode::dpas) + i2.common.accWrCtrl = mod.parts.accWrCtrl; /* {Fwd} */ i2.common.saturate = mod.parts.saturate; i.common = i2.common; } @@ -831,6 +1074,45 @@ static inline void encodeSendDesc(Instruction12 &i, RegData desc) i.send.descIsReg = true; } +static inline unsigned getHighBit(RegData r) { + return r.isARF() ? 
0 : (r.getBase() >> 8); +} + +static inline unsigned getHighBit(ExtendedReg r) { + return r.getBase().getBase() >> 8; +} + +static inline unsigned getHighBit(Immediate i) { + return 0; +} + +static inline uint8_t encodeDnsclCtrl(uint8_t mode, RoundingType rnd, RegData &dst, RegData &src0, RegData &src1) +{ + auto dt = dst.getType(); + auto st = src0.getType(); + +#ifdef NGEN_SAFE + if (dt != DataType::s4 && dt != DataType::e2m1 && dt != DataType::e3m0) throw invalid_type_exception(); + if (st != DataType::hf && st != DataType::bf) throw invalid_type_exception(); + if (st != src1.getType()) throw invalid_type_exception(); +#endif + + dst.setOffset(dst.getByteOffset() >> 2); + src0.setOffset(src0.getByteOffset() >> 2); + src1.setOffset(src1.getByteOffset() >> 2); + dst.setType(DataType::ud); + src0.setType(DataType::ud); + src1.setType(DataType::ud); + + mode <<= 4; + if (dt == DataType::e2m1) mode |= 1; + if (dt == DataType::s4) mode |= 2; + if (st == DataType::bf) mode |= 4; + if (rnd == RoundingType::rne) mode |= 8; + + return mode; +} + /*********************/ /* Decoding Routines */ /*********************/ @@ -864,6 +1146,9 @@ bool Instruction12::getOperandRegion(autoswsb::DependencyRegion ®ion, int opN using namespace autoswsb; constexpr bool xeHPC = !std::is_same::value; + constexpr bool xe3p = std::is_same::value; + + bool unaryXe3p = false; auto hw = region.hw; auto op = opcode(); @@ -887,6 +1172,7 @@ bool Instruction12::getOperandRegion(autoswsb::DependencyRegion ®ion, int opN BinaryOperand12 o; o.bits = binary.dst; unsigned regNum = o.direct.regNum; + if (xe3p) regNum |= (binaryXe3p.dstReg8 << 8); region = DependencyRegion(hw, 1, GRF(regNum)); return true; } @@ -895,49 +1181,84 @@ bool Instruction12::getOperandRegion(autoswsb::DependencyRegion ®ion, int opN o0.bits = binary.src0; o1.bits = binary.src1; unsigned rn0 = o0.direct.regNum, rn1 = o1.direct.regNum; + if (xe3p) { + rn0 |= (binaryXe3p.src0Reg8 << 8); + rn1 |= (binaryXe3p.src1Reg8 << 8); + } region 
= DependencyRegion(hw, GRF(rn0)-GRF(rn1)); return true; } default: return false; } + case Opcode::bdpas: case Opcode::dpas: case Opcode::dpasw: { unsigned sdepth = 1 << dpas.sdepth; unsigned rcount = 1 + dpas.rcount; unsigned len; TernaryOperand12 o; + unsigned regNum8 = 0; switch (opNum) { case -1: { int typebytes = decodeDPASTypecodeBytes12(ternary.dstType); len = (rcount * typebytes + 3) >> 2; + regNum8 = ternaryXe3p.dstReg8; o.bits = ternary.dst; break; } case 0: { int typebytes = decodeDPASTypecodeBytes12(ternary.src0Type); len = (rcount * typebytes + 3) >> 2; + regNum8 = ternaryXe3p.src0Reg8; o.bits = ternary.src0; break; } case 1: len = sdepth; o.bits = ternary.src1; + regNum8 = ternaryXe3p.src1Reg8; break; case 2: { if (op == Opcode::dpasw) rcount = (rcount + 1) >> 1; o.bits = ternary.src2; auto sr = o.direct.subRegNum; + regNum8 = ternaryXe3p.src2Reg8; + if (op == Opcode::bdpas) { + sr = 0; + rcount = 8; + } if (xeHPC) len = ((sr << 1) + sdepth * rcount * 4 + 63) >> 6; else len = (sr + sdepth * rcount * 4 + 31) >> 5; break; } + case 3: { + if(op != Opcode::bdpas) return false; + auto sr = bdpas.src3SubReg4_5 << 4; + o.direct.regNum = bdpas.src3Reg0; + o.direct.regNum |= bdpas.src3Reg1_2 << 1; + o.direct.regNum |= bdpas.src3Reg3_6 << 3; + o.direct.regNum |= bdpas.src3Reg7_8 << 7; + rcount = 8; + len = (sr + sdepth * rcount * 4 + 31) >> 5; + break; + } + case 4: { + if(op != Opcode::bdpas) return false; + auto sr = bdpas.src4SubReg3_5 << 3; + o.direct.regNum = bdpas.src4Reg0_3; + o.direct.regNum |= bdpas.src4Reg4_8 << 4; + rcount = 8; + len = (sr + sdepth * rcount * 4 + 31) >> 5; + break; + } default: return false; } unsigned regNum = o.direct.regNum; + if (xe3p) regNum |= (regNum8 << 8); region = DependencyRegion(hw, GRFRange(regNum, len)); return true; } @@ -987,48 +1308,127 @@ bool Instruction12::getOperandRegion(autoswsb::DependencyRegion ®ion, int opN region = DependencyRegion(hw, GRFRange(base, len)); return true; } + case Opcode::sendg: + case 
Opcode::sendgc: { + if (send.src0RegFile == RegFileARF && (sendg.src0Reg >> 4) == 0x6) switch (opNum) { + case 0: + region = DependencyRegion(hw); + return true; + case 1: { + /* report s0 dependency as if it came from src1 */ + region = DependencyRegion(hw, sendg.src0Len, ScalarRegister(0)[(send.src0Reg & 0xF) << 1](1)); + return true; + } + default: break; + } + switch (opNum) { + case -1: { + if (sendg.dstRegFile == RegFileARF) return false; + int dstLen = getSendgDesc().dstLen(hw, 1 << common.execSize, static_cast(sendg.sfid)); + if (dstLen == -1) + region = DependencyRegion(hw); + else + region = DependencyRegion(hw, GRFRange(sendg.dstReg, dstLen)); + break; + } + case 0: + if (sendg.src0RegFile == RegFileARF) return false; + region = DependencyRegion(hw, GRFRange(sendg.src0Reg, sendg.src0Len)); + break; + case 1: + if (sendg.src1RegFile == RegFileARF) return false; + region = DependencyRegion(hw, GRFRange(sendg.src1Reg, sendg.src1Len)); + break; + default: return false; + } + return true; + } + case Opcode::sendgx: + case Opcode::sendgxc: { + int regNum = 0, len = -1; + switch (opNum) { + case -1: { + regNum = sendg.dstReg | (sendgx.dstReg8 << 8); + len = getSendgDesc().dstLen(hw, 1 << common.execSize, static_cast(sendg.sfid)); + break; + } + case 0: + regNum = sendg.src0Reg | (sendgx.src0Reg8 << 8); + len = sendg.src0Len; + break; + case 1: + regNum = sendg.src1Reg | (sendgx.src1Reg8 << 8); + len = sendg.src1Len; + break; + default: return false; + } + if (regNum == 0x1FF) + return false; + else if (len == -1) + region = DependencyRegion(hw); + else + region = DependencyRegion(hw, GRFRange(regNum, len)); + return true; + } case Opcode::dp4a: case Opcode::add3: case Opcode::bfn: case Opcode::bfe_gen12: case Opcode::bfi2_gen12: case Opcode::csel_gen12: + case Opcode::dnscl: case Opcode::mad: case Opcode::madm: { // ternary TernaryOperand12 o; unsigned dt = 0, vs = 0; + unsigned regNum8 = 0; switch (opNum) { case -1: o.bits = ternary.dst; dt = ternary.dstType; 
+ regNum8 = ternaryXe3p.dstReg8; break; case 0: if (ternary.src0Imm) return false; o.bits = ternary.src0; dt = ternary.src0Type; vs = ternary.src0VS0 + (ternary.src0VS1 * 3); + regNum8 = ternaryXe3p.src0Reg8; break; case 1: o.bits = ternary.src1; dt = ternary.src1Type; vs = ternary.src1VS0 + (ternary.src1VS1 * 3); + regNum8 = ternaryXe3p.src1Reg8; break; case 2: if (ternary.src2Imm) return false; o.bits = ternary.src2; dt = ternary.src2Type; + regNum8 = ternaryXe3p.src2Reg8; break; default: return false; } dt |= (ternary.execType << 3); if (op == Opcode::madm) o.direct.subRegNum = 0; unsigned regNum = o.direct.regNum; + if (xe3p) regNum |= (regNum8 << 8); auto base = GRF(regNum).retype(decodeRegTypecode12(dt)); auto sr = o.direct.subRegNum; if (xeHPC) sr <<= 1; - auto sub = base[sr / getBytes(base.getType())]; + auto sub = base[sr / base.getBytes()]; auto hs = (1 << o.direct.hs); if (opNum >= 0) hs >>= 1; + if (xe3p && opNum == 1) { + hs = 0; + if (!ternaryXe3p.src1Scalar) { + auto dbytes = getBytes(decodeRegTypecode12((ternary.execType << 3) | ternary.dstType)); + TernaryOperand12 odst; + odst.bits = ternary.dst; + hs = std::max(1, (dbytes << odst.direct.hs) / base.getBytes()); + } + vs = 0; + } if (opNum < 0) rd = sub(hs, 1, 0); else if (opNum == 2) @@ -1043,23 +1443,54 @@ bool Instruction12::getOperandRegion(autoswsb::DependencyRegion ®ion, int opN } break; } + case Opcode::mov_gen12: + case Opcode::not_gen12: + case Opcode::frc: + case Opcode::rndd: + case Opcode::rndu: + case Opcode::rnde: + case Opcode::rndz: + case Opcode::lzd: + case Opcode::fbh: + case Opcode::fbl: + case Opcode::cbit: + case Opcode::math: + unaryXe3p = true; + /* fall through */ default: { // unary/binary BinaryOperand12 o; unsigned dt; + unsigned regNum8 = 0; + + if (op == Opcode::math) switch (static_cast(binary.cmod)) { + case MathFunction::invm: + case MathFunction::rsqtm: unaryXe3p = false; break; + default: break; + } switch (opNum) { case -1: o.bits = binary.dst; dt = 
binary.dstType; + if (unaryXe3p) + regNum8 = binary.src0Imm ? unaryXe3pImm.dstReg8 : binaryXe3p.dstReg8; + else + regNum8 = binary.src1Imm ? binaryXe3pImm.dstReg8 : binaryXe3p.dstReg8; break; case 0: + if (!xe3p || unaryXe3p) if (binary.src0Imm) return false; o.bits = binary.src0; dt = binary.src0Type; + regNum8 = (!unaryXe3p && binary.src1Imm) ? binaryXe3pImm.src0Reg8 + : binaryXe3p.src0Reg8; + if (xe3p && (!unaryXe3p || op == Opcode::math)) + dt |= (binaryXe3p.execType << 3); break; case 1: if (binary.src0Imm || binary.src1Imm) return false; o.bits = binary.src1; dt = binary.src1Type; + regNum8 = binaryXe3p.src1Reg8; break; default: return false; } @@ -1069,10 +1500,21 @@ bool Instruction12::getOperandRegion(autoswsb::DependencyRegion ®ion, int opN auto sr = xeHPC ? ((o.direct.subRegNum << 1) | o.directXeHPC.subRegNum0) : o.direct.subRegNum; auto regNum = o.direct.regNum; + if (xe3p) regNum |= (regNum8 << 8); auto base = GRF(regNum).retype(decodeRegTypecode12(dt)); - auto sub = base[sr / getBytes(base.getType())]; + auto sub = base[sr / base.getBytes()]; auto hs = (1 << o.direct.hs) >> 1; auto vs = xeHPC ? o.directXeHPC.vs : o.direct.vs; + if (xe3p && opNum == 1) { + hs = 0; + if (!binaryXe3p.src1Scalar) { + auto dbytes = getBytes(decodeRegTypecode12(binary.dstType)); + BinaryOperand12 odst; + odst.bits = binary.dst; + hs = std::max(1, (dbytes << (odst.direct.hs - 1)) / base.getBytes()); + } + rd = sub(hs); + } else if (opNum < 0) rd = sub(hs, 1, 0); else @@ -1090,6 +1532,7 @@ bool Instruction12::getOperandRegion(autoswsb::DependencyRegion ®ion, int opN auto esize = 1 << ((hw >= HW::XeHPC) ? 
commonXeHPC.execSize : common.execSize); rd.fixup(hw, esize, 0, DataType::invalid, opNum, 2); region = DependencyRegion(hw, esize, rd); + if (op == Opcode::mullh) region.duplicateLH(); return true; } @@ -1108,6 +1551,11 @@ bool Instruction12::getCModDepRegion(autoswsb::DependencyRegion ®ion) const case Opcode::dpas: case Opcode::dpasw: case Opcode::math: + case Opcode::bdpas: + case Opcode::sendg: + case Opcode::sendgc: + case Opcode::sendgx: + case Opcode::sendgxc: return false; default: if (isBranch(op)) @@ -1192,6 +1640,20 @@ int Instruction12::getFencedepJIP() const return int32_t(imm32.value) / sizeof(Instruction12); } +SendgMessageDescriptor Instruction12::getSendgDesc() const +{ + SendgMessageDescriptor desc; + desc.all = uint64_t(sendg.desc0_15) + | (uint64_t(sendg.desc16_27) << 16) + | (uint64_t(sendg.desc28_29) << 28) + | (uint64_t(sendg.desc30_31) << 30) + | (uint64_t(sendg.desc32_39) << 32) + | (uint64_t(sendg.desc40_41) << 40); + if (!sendg.ind1Present) + desc.all |= (uint64_t(sendg.ind1_desc42_46) << 42); + return desc; +} + bool Instruction12::getARFType(ARFType &arfType, int opNum, HW hw) const { if (opNum > 1) return false; diff --git a/third_party/ngen/ngen_interface.hpp b/third_party/ngen/ngen_interface.hpp index ec719406df9..fad6586057b 100644 --- a/third_party/ngen/ngen_interface.hpp +++ b/third_party/ngen/ngen_interface.hpp @@ -61,6 +61,10 @@ class unsupported_argument_location_override : public std::runtime_error { unsupported_argument_location_override() : std::runtime_error("Argument register location is invalid") {} }; +class zebin_required_exception : public std::runtime_error { +public: + zebin_required_exception() : std::runtime_error("zebin is required for this operation") {} +}; #endif enum class ExternalArgumentType { Scalar, GlobalPtr, LocalPtr, Hidden }; @@ -79,8 +83,8 @@ class InterfaceHandler template friend class LevelZeroCodeGenerator; public: - InterfaceHandler(HW hw_) : hw(hw_) - , simd(GRF::bytes(hw_) >> 2) + 
InterfaceHandler(HW hw_) : hw(hw_), simd(GRF::bytes(hw_) >> 2) + , useEfficient64Bit(hw_ >= HW::XE3P_35_10) , requestedInlineBytes(defaultInlineBytes(hw)) {} @@ -108,6 +112,7 @@ class InterfaceHandler int getBarrierCount() const { return barrierCount; } int getGRFCount() const { return needGRF; } size_t getSLMSize() const { return slmSize; } + bool getEfficient64Bit() const { return useEfficient64Bit; } void require32BitBuffers() { allow64BitBuffers = false; } void requireArbitrationMode(ThreadArbitrationMode m) { arbitrationMode = m; } @@ -121,6 +126,7 @@ class InterfaceHandler void requireNonuniformWGs() { needNonuniformWGs = true; } void requireNoPreemption() { needNoPreemption = true; } void requirePartitionDim(int dim) { needPartitionDim = dim; } + void requireQuantum(int wgs) { needQuantum = wgs; } void requireScratch(size_t bytes = 1) { scratchSize = bytes; } inline void requireSIMD(int simd_); void requireSLM(size_t bytes) { slmSize = bytes; } @@ -136,6 +142,7 @@ class InterfaceHandler void setInlineGRFCount(int grfs) { requestedInlineBytes = grfs * GRF::bytes(hw); } int32_t getSkipCrossThreadOffset() const { return offsetSkipCrossThread; } std::array getCTPatchOffsets() const { return offsetCTPatches; } + void setEfficient64Bit(bool def = true) { useEfficient64Bit = def; } inline Register getCrossthreadBase(bool effective = true) const; inline Register getArgLoadBase() const; @@ -195,6 +202,7 @@ class InterfaceHandler bool needNonuniformWGs = false; bool needNoPreemption = false; int needPartitionDim = -1; + int needQuantum = 0; bool needHalf = false; bool needDouble = false; bool needStatelessWrites = true; @@ -207,6 +215,7 @@ class InterfaceHandler int walkOrder[3] = {-1, -1, -1}; size_t wg[3] = {0, 0, 0}; + bool useEfficient64Bit = false; int crossthreadBytes = 0; int crossthreadRegs = 0; int requestedInlineBytes = 0; @@ -370,6 +379,7 @@ void InterfaceHandler::generateDummyCL(std::ostream &stream) const #ifdef NGEN_SAFE if (!finalized) throw 
interface_not_finalized(); if (hasArgLocOverride || !rearrangeArgs) throw unsupported_argument_location_override(); + if (needQuantum) throw zebin_required_exception(); #endif const char *dpasDummy = " int __builtin_IB_sub_group_idpas_s8_s8_8_1(int, int, int8) __attribute__((const));\n" " int z = __builtin_IB_sub_group_idpas_s8_s8_8_1(0, ____[0], 1);\n" @@ -562,6 +572,7 @@ void InterfaceHandler::finalize() int InterfaceHandler::inlineBytes() const { + if (useEfficient64Bit) return 64; return requestedInlineBytes; } @@ -630,6 +641,7 @@ void InterfaceHandler::setPrologueLabels(InterfaceLabels &labels, LabelManager & }; int immOffset = 0xC; + if (hw >= HW::XE3P_35_10) immOffset = 0x8; setOffset(labels.localIDsLoaded, offsetSkipPerThread); setOffset(labels.argsLoaded, offsetSkipCrossThread); @@ -647,6 +659,8 @@ std::string InterfaceHandler::generateZeInfo() const md.imbue(std::locale::classic()); const char *version = "1.8"; + if (useEfficient64Bit) version = "1.35"; + if (needQuantum) version = "1.48"; md << "version: " << version << "\n" "kernels: \n" @@ -695,13 +709,28 @@ std::string InterfaceHandler::generateZeInfo() const } if (inlineBytes() > 0) md << " inline_data_payload_size: " << inlineBytes() << "\n"; + if (needQuantum) { + int encodedWO = 0; + if (walkOrder[0] == -1 && walkOrder[1] == -1) + encodedWO = 0; + else if (walkOrder[0] == 0 && walkOrder[1] == 1) + encodedWO = 0; + else if (walkOrder[0] == 1 && walkOrder[1] == 0) + encodedWO = 1; + else + throw std::runtime_error("Unsupported walk order"); + + md << " quantum_size: " << needQuantum << "\n"; + md << " quantum_walk_order: " << encodedWO << "\n"; + md << " quantum_partition_dimension: " << std::max(needPartitionDim, 0) << "\n"; + } if (!assignments.empty()) { md << "\n" " payload_arguments: \n"; } - if (scratchSize > 0) { - md << " - arg_type: scratch_pointer\n" - " offset: 8\n" + if (useEfficient64Bit) { + md << " - arg_type: indirect_data_pointer\n" + " offset: 0\n" " size: 8\n"; } for (auto 
&assignment : assignments) { diff --git a/third_party/ngen/ngen_level_zero.hpp b/third_party/ngen/ngen_level_zero.hpp index 2d63b680693..61f29d18e41 100644 --- a/third_party/ngen/ngen_level_zero.hpp +++ b/third_party/ngen/ngen_level_zero.hpp @@ -35,9 +35,17 @@ namespace NGEN_NAMESPACE { // Exceptions. class level_zero_error : public std::runtime_error { public: - level_zero_error(ze_result_t status_ = ZE_RESULT_SUCCESS) : std::runtime_error("A Level Zero error occurred."), status(status_) {} + level_zero_error(ze_result_t status_ = ZE_RESULT_SUCCESS) : std::runtime_error("A Level Zero error occurred: " + to_hex(status_)), status(status_) {} protected: ze_result_t status; + +private: + static std::string to_hex(ze_result_t status) { + std::ostringstream oss; + oss.imbue(std::locale::classic()); + oss << std::hex << status; + return "0x" + oss.str(); + } }; // Dynamic loading support. @@ -100,6 +108,8 @@ class LevelZeroCodeGenerator : public ELFCodeGenerator static inline Product detectHWInfo(ze_context_handle_t context, ze_device_handle_t device); static bool binaryIsZebin() { return true; } + + static inline bool detectEfficient64Bit(ze_context_handle_t context, ze_device_handle_t device, HW inHW = HW::Unknown); }; #define NGEN_FORWARD_LEVEL_ZERO(hw) NGEN_FORWARD_ELF(hw) @@ -220,6 +230,16 @@ Product LevelZeroCodeGenerator::detectHWInfo(ze_context_handle_t context, ze return product; } +template +bool LevelZeroCodeGenerator::detectEfficient64Bit(ze_context_handle_t context, ze_device_handle_t device, HW inHW) +{ + if (inHW == HW::Unknown) inHW = hw; + if (inHW < HW::XE3P_35_10) return false; + + auto binary = detail::getDummyModuleBinary(context, device); + return npack::isBinaryEfficient64Bit(binary, inHW); +} + } /* namespace NGEN_NAMESPACE */ #endif diff --git a/third_party/ngen/ngen_opencl.hpp b/third_party/ngen/ngen_opencl.hpp index deb9e73ea89..f80b2ba5cc9 100644 --- a/third_party/ngen/ngen_opencl.hpp +++ b/third_party/ngen/ngen_opencl.hpp @@ -120,6 +120,8 @@ 
class OpenCLCodeGenerator : public ELFCodeGenerator static inline Product detectHWInfo(cl_device_id device); static inline Product detectHWInfo(cl_context context, cl_device_id device); + static inline bool detectEfficient64Bit(cl_context context, cl_device_id device, HW inHW = HW::Unknown); + private: bool isZebin = false; inline std::vector getPatchTokenBinary(cl_context context, cl_device_id device, const std::vector *code = nullptr, const std::string &options = "-cl-std=CL2.0"); @@ -269,6 +271,7 @@ cl_kernel OpenCLCodeGenerator::getKernel(cl_context context, cl_device_id de for (bool defaultFormat : {true, false}) { bool legacy = defaultFormat ^ zebinFirst; + isZebin = !legacy; if (legacy) { try { @@ -356,6 +359,18 @@ Product OpenCLCodeGenerator::detectHWInfo(cl_context context, cl_device_id d return product; } +template +bool OpenCLCodeGenerator::detectEfficient64Bit(cl_context context, cl_device_id device, HW inHW) +{ + const char *dummyCL = "kernel void _ngen_eff64b_detect(){}"; + + if (inHW == HW::Unknown) inHW = hw; + if (inHW < HW::XE3P_35_10) return false; + + auto binary = detail::getOpenCLCProgramBinary(context, device, dummyCL, ""); + return npack::isBinaryEfficient64Bit(binary, inHW); +} + } /* namespace NGEN_NAMESPACE */ #endif diff --git a/third_party/ngen/ngen_pseudo.hpp b/third_party/ngen/ngen_pseudo.hpp index c808881bedf..f791b19747e 100644 --- a/third_party/ngen/ngen_pseudo.hpp +++ b/third_party/ngen/ngen_pseudo.hpp @@ -118,6 +118,10 @@ void rsqtm(const InstructionModifier &mod, const ExtendedReg &dst, const Extende math
(mod, MathFunction::rsqtm, dst, src0, loc); } template +void sigm(const InstructionModifier &mod, const RegData &dst, const RegData &src0, SourceLocation loc = {}) { + math
(mod, MathFunction::sigm, dst, src0, loc); +} +template void sin(const InstructionModifier &mod, const RegData &dst, const RegData &src0, SourceLocation loc = {}) { math
(mod, MathFunction::sin, dst, src0, loc); } @@ -125,6 +129,10 @@ template void sqt(const InstructionModifier &mod, const RegData &dst, const RegData &src0, SourceLocation loc = {}) { math
(mod, MathFunction::sqt, dst, src0, loc); } +template +void tanh(const InstructionModifier &mod, const RegData &dst, const RegData &src0, SourceLocation loc = {}) { + math
(mod, MathFunction::tanh, dst, src0, loc); +} #define TMP(n) tmp[n].retype(dst.getType()) @@ -312,6 +320,9 @@ void sqt_ieee(const InstructionModifier &mod, FlagRegister flag, RegData dst, Re // Thread spawner messages. void threadend(const InstructionModifier &mod, RegData r0_info = {}, SourceLocation loc = {}) { + if (useEfficient64Bit) + sendgx(1 | EOT | mod | NoMask, SharedFunction::gtwy, null, RegisterRange(r0_info, 1), 0, loc); + else { auto sf = (hardware <= HW::XeHP) ? SharedFunction::ts : SharedFunction::gtwy; @@ -327,6 +338,10 @@ void threadend(const RegData &r0_info = {}, SourceLocation loc = {}) { // Gateway messages. void barriermsg(const InstructionModifier &mod, Register header = {}, SourceLocation loc = {}) { + if (useEfficient64Bit) { + if (header.isInvalid()) header = GRF(0); + sendgx(1 | mod | NoMask, SharedFunction::gtwy, null, RegisterRange(header, 1), 4, loc); + } else { uint32_t exdesc = static_cast(SharedFunction::gtwy) & 0xF; send(1 | mod | NoMask, null, header, exdesc, 0x2000004, loc); @@ -339,6 +354,9 @@ void barriermsg(Register header = {}, SourceLocation loc = {}) { barriermsg(Inst void barrierheader(const Register &header, Register r0_info = {}, SourceLocation loc = {}) { if (r0_info.isInvalid()) r0_info = GRF(0); + if (useEfficient64Bit) + mov(1 | NoMask, header[2], r0_info[2], loc); + else if (hardware >= HW::XeHPG) { mov(1 | NoMask, header.hf(4), Immediate::hf(0), loc); mov(2 | NoMask, header.ub(10)(1), r0_info.ub(11)(0), loc); @@ -349,6 +367,9 @@ void barrierheader(const Register &header, Register r0_info = {}, SourceLocation void barrierheader(const Register &header, uint32_t threadCount, Register r0_info = {}, SourceLocation loc = {}) { if (r0_info.isInvalid()) r0_info = GRF(0); + if (useEfficient64Bit) + mov(1 | NoMask, header.ud(2), threadCount << 24, loc); + else if (hardware >= HW::XeHPG) mov(1 | NoMask, header.ud(2), (threadCount << 24) | (threadCount << 16), loc); else { @@ -359,6 +380,9 @@ void barrierheader(const Register 
&header, uint32_t threadCount, Register r0_inf void barriersignal(const InstructionModifier &mod = {}, const GRF &temp = {}, Register r0_info = {}, SourceLocation loc = {}) { + if (useEfficient64Bit) + barriermsg(mod, r0_info, loc); + else { barrierheader(temp, r0_info, loc); barriermsg(mod, temp, loc); @@ -375,6 +399,9 @@ void barriersignal(const GRF &temp, uint32_t threadCount, Register r0_info = {}, // Named barriers. void nbarriermsg(const InstructionModifier &mod, const GRF &header, SourceLocation loc = {}) { + if (useEfficient64Bit) + sendgx(1 | mod | NoMask, SharedFunction::gtwy, null, RegisterRange(header, 1), 5, loc); + else barriermsg(mod, header, loc); } @@ -458,6 +485,7 @@ void barrier(uint32_t barrierID, const GRF &temp, BarrierType barrierType, barrierwait(loc); } + void registerfence(const RegData &dst, SourceLocation loc = {}) { _lastFenceDst = dst; @@ -472,6 +500,12 @@ void memfence(const InstructionModifier &mod, FenceScopeLSC scope, FlushTypeLSC { registerfence(dst, loc); + if (useEfficient64Bit) { + uint32_t desc = 0x1F; + desc |= static_cast(flushing) << 8; + desc |= static_cast(scope) << 11; + sendgx(1 | mod | NoMask, SharedFunction::ugm, null, RegisterRange(header, 1), desc, loc); + } else if (hardware >= HW::XeHPG) { if (flushing == FlushTypeLSC::None && hardware == HW::XeHPG && scope > FenceScopeLSC::Subslice) flushing = static_cast(6); /* workaround for DG2 bug */ @@ -535,6 +569,9 @@ void slmfence(const InstructionModifier &mod, const RegData &dst, const RegData { registerfence(dst, loc); + if (useEfficient64Bit) + sendgx(1 | mod | NoMask, SharedFunction::slm, null, RegisterRange(header, 1), 0x1F, loc); + else if (hardware >= HW::XeHPG) send(1 | mod | NoMask, SharedFunction::slm, dst, header, null, 0, 0x210011F, loc); else { @@ -557,6 +594,7 @@ void fencewait(SourceLocation loc = {}) mov(8 | NoMask, null, _lastFenceDst, loc); } + // XeHP+ prologues. 
void loadlid(int argBytes, int dims = 3, int simd = 8, const GRF &temp = GRF(127), int paddedSize = 0, SourceLocation loc = {}) { @@ -576,6 +614,22 @@ void loadlid(int argBytes, int dims = 3, int simd = 8, const GRF &temp = GRF(127 defaultModifier |= NoMask | AutoSWSB; + if (useEfficient64Bit) { /* to do: SIMD1 */ + uint16_t stride = simdGRFs * grfSize; + auto base = s0.uq(2); + and_(1, acc0[0], r0.uw(4), 0xFF, loc); + mov(1, base, r0.uq(7), loc); + mad(1, acc0[0], uint16_t(argBytes), acc0[0], uint16_t(3 * stride), loc); + mov(16, r4, r1, loc); + markIfUndefined(_interfaceLabels.crossThreadPatches[0]); + add(1, temp[0], acc0[0], Immediate::ud(0), loc); /* relocation */ + load(1, r1, D32T(std::min(dims, 2) * stride / 4) | L1C_L3CC, A64_A32U, temp + base, loc); + insns = 6; + if (dims == 3) { + load(1, GRF(1 + 2 * simdGRFs), D32T(stride / 4) | L1C_L3CC, A64_A32U, temp + base + stride/2, loc); + insns++; + } + } else { insns = lsc ? 5 : 6; if (!lsc) @@ -617,6 +671,10 @@ void loadlid(int argBytes, int dims = 3, int simd = 8, const GRF &temp = GRF(127 markIfUndefined(_interfaceLabels.localIDsLoaded); + /* Workaround for incorrect NEO/XeSim handling of crossthread entrance */ + if (useEfficient64Bit) + for (int i = 0; i < 4; i++) + nop(loc); } void loadlid(int argBytes, int dims, int simd, const GRF &temp, SourceLocation loc = {}) { loadlid(argBytes, dims, simd, temp, 0, loc); } @@ -635,6 +693,23 @@ void loadargs(const Register &base, int argRegs, const GRF &temp, bool inPrologu auto dmSave = defaultModifier; defaultModifier |= NoMask | AutoSWSB; + if (useEfficient64Bit) { /* to do: SIMD1 */ + int offset = 0; + auto offsetRT = s0.uq(3); + auto addr = inPrologue ? 
r4 : temp; + if (!inPrologue) + mov(1, addr, r0[7], loc); + markIfUndefined(_interfaceLabels.crossThreadPatches[1]); + mov(1, offsetRT, Immediate::uq(0), loc); /* relocation */ + while (argRegs > 0) { + int nload = std::min(utils::rounddown_pow2(argRegs), 8); + int loadBytes = nload * GRF::bytes(hardware); + load(1, dst, D64T(loadBytes >> 3) | L1C_L3CC, A64, addr + offsetRT + offset, loc); + argRegs -= nload; + dst += nload; + offset += loadBytes; + } + } else { if (!lsc) mov(8, temp, uint16_t(0), loc); @@ -735,6 +810,14 @@ struct Load { template void operator()(SharedFunction sfid, const InstructionModifier &mod, const RegData &dst, const DataSpec &spec, AddressBase base, const GRFDisp &addr, SourceLocation loc = {}) { + if (parent.useEfficient64Bit) { + SendgMessageDescriptor desc; + int dstLen, src0Len; + encodeLoadDescriptor(parent.hardware, desc, sfid, dstLen, src0Len, mod, spec, base, addr); + if (!dst.isNull() && dstLen > 0) + parent.subdep(Operand::dst, GRFRange(dst.getBase(), dstLen)); + parent.sendgx(mod, sfid, dst, RegisterRange(addr.getBase(), src0Len), addr.getInd0(), desc.all, loc); + } else { MessageDescriptor desc; ExtendedMessageDescriptor exdesc; @@ -810,6 +893,12 @@ struct Store { template void operator()(SharedFunction sfid, const InstructionModifier &mod, const DataSpec &spec, AddressBase base, const GRFDisp &addr, const RegData &data, SourceLocation loc = {}) { + if (parent.useEfficient64Bit) { + SendgMessageDescriptor desc; + int src0Len, src1Len; + encodeStoreDescriptor(parent.hardware, desc, sfid, src0Len, src1Len, mod, spec, base, addr); + parent.sendgx(mod, sfid, NullRegister(), RegisterRange(addr.getBase(), src0Len), RegisterRange(data, src1Len), addr.getInd0(), desc.all, loc); + } else { MessageDescriptor desc; ExtendedMessageDescriptor exdesc; @@ -896,6 +985,15 @@ struct Atomic_ { template void operator()(SharedFunction sfid, AtomicOp op, const InstructionModifier &mod, const RegData &dst, const DataSpec &spec, AddressBase base, const 
GRFDisp &addr, const RegData &data, SourceLocation loc = {}) { + if (parent.useEfficient64Bit) { + SendgMessageDescriptor desc; + int src0Len, src1Len; + encodeAtomicDescriptor(parent.hardware, desc, sfid, src0Len, src1Len, op, mod, spec, base, addr); + if (data.isNull()) + parent.sendgx(mod, sfid, dst, RegisterRange(addr.getBase(), src0Len), addr.getInd0(), desc.all, loc); + else + parent.sendgx(mod, sfid, dst, RegisterRange(addr.getBase(), src0Len), RegisterRange(data, src1Len), addr.getInd0(), desc.all, loc); + } else { MessageDescriptor desc; ExtendedMessageDescriptor exdesc; diff --git a/third_party/ngen/ngen_register_allocator.hpp b/third_party/ngen/ngen_register_allocator.hpp index 8c8a667ddfe..848c0cc328c 100644 --- a/third_party/ngen/ngen_register_allocator.hpp +++ b/third_party/ngen/ngen_register_allocator.hpp @@ -138,6 +138,7 @@ class RegisterAllocator { template Subregister allocSub(Bundle bundle = Bundle()) { return allocSub(getDataType(), bundle); } + // Allocate flag registers. // sub = true (default): a 16-bit subregister (fX.Y:uw) // sub = false: a full 32-bit register (fX.0:ud) @@ -152,6 +153,7 @@ class RegisterAllocator { template Subregister tryAllocSub(Bundle bundle = Bundle()) { return tryAllocSub(getDataType(), bundle); } + inline FlagRegister tryAllocFlag(bool sub = true); // Release a previous allocation or claim. @@ -169,16 +171,19 @@ class RegisterAllocator { inline void claim(Subregister subreg); inline void claim(FlagRegister flag); + // Set register count. inline void setRegisterCount(int rcount); inline int getRegisterCount() const { return regCount; } inline int countAllocedRegisters() const; + // Check availability. 
inline bool isFree(GRF reg) const; inline bool isFree(GRFRange range) const; inline bool isFree(Subregister subreg) const; + #ifdef NGEN_ASM inline void dump(std::ostream &str); #endif @@ -199,6 +204,7 @@ class RegisterAllocator { Subregister try_alloc_sub(Bundle bundle = Bundle()) { return tryAllocSub(getDataType(), bundle); } FlagRegister try_alloc_flag(bool sub = true) { return tryAllocFlag(sub); } + protected: using mtype = uint16_t; @@ -210,6 +216,7 @@ class RegisterAllocator { uint8_t freeFlag; // Bitmap of free flag registers. mtype fullSubMask; + inline void init(); inline void claimSub(int r, int o, int dw); }; @@ -242,6 +249,9 @@ int Bundle::firstReg(HW hw) const case HW::XeHPC: case HW::Xe2: case HW::Xe3: + case HW::XE3P_35_10: + case HW::XE3P_35_11: + case HW::XE3P_UNKNOWN: return (bundle0 << 1) | bank0; case HW::XeHP: case HW::XeHPG: @@ -278,6 +288,9 @@ int Bundle::stride(HW hw) const case HW::Gen12LP: case HW::Xe2: case HW::Xe3: + case HW::XE3P_35_10: + case HW::XE3P_35_11: + case HW::XE3P_UNKNOWN: return 16; case HW::XeHP: case HW::XeHPG: @@ -308,6 +321,9 @@ uint64_t Bundle::regMask(HW hw, int offset) const case HW::Gen12LP: case HW::Xe2: case HW::Xe3: + case HW::XE3P_35_10: + case HW::XE3P_35_11: + case HW::XE3P_UNKNOWN: if (bundle_id != any) base_mask = 0x0003000300030003; if (bank_id != any) base_mask &= 0x5555555555555555; return base_mask << (bank0 + (bundle0 << 1)); @@ -338,6 +354,9 @@ Bundle Bundle::locate(HW hw, RegData reg) case HW::Gen12LP: case HW::Xe2: case HW::Xe3: + case HW::XE3P_35_10: + case HW::XE3P_35_11: + case HW::XE3P_UNKNOWN: return Bundle(base & 1, (base >> 1) & 7); case HW::XeHP: case HW::XeHPG: @@ -368,6 +387,8 @@ void RegisterAllocator::init() if (hw < HW::XeHP) setRegisterCount(128); + else if (hw < HW::XE3P_35_10) + setRegisterCount(256); } @@ -400,6 +421,7 @@ void RegisterAllocator::claimSub(int r, int o, int dw) freeGRF[r >> 3] &= ~(1 << (r & 7)); } + void RegisterAllocator::claim(FlagRegister flag) { freeFlag &= ~(1 << 
flag.index()); @@ -425,6 +447,7 @@ void RegisterAllocator::setRegisterCount(int rcount) regCount = rcount; } + int RegisterAllocator::countAllocedRegisters() const { int alloced = 0; @@ -456,11 +479,13 @@ void RegisterAllocator::release(Subregister subreg) { int dw = subreg.getDwords(); int o = (subreg.getByteOffset()) >> 2; + freeSub[r] |= (1 << (o + dw)) - (1 << o); if (freeSub[r] == fullSubMask) freeGRF[r >> 3] |= (1 << (r & 7)); } + void RegisterAllocator::release(FlagRegister flag) { if (flag.isInvalid()) return; @@ -495,6 +520,7 @@ bool RegisterAllocator::isFree(Subregister subreg) const return (~freeSub[r] & m) == 0; } + // ------------------------------------------- // High-level register allocation functions. // ------------------------------------------- @@ -515,6 +541,7 @@ Subregister RegisterAllocator::allocSub(DataType type, Bundle bundle) return result; } + FlagRegister RegisterAllocator::allocFlag(bool sub) { auto result = tryAllocFlag(sub); @@ -523,6 +550,7 @@ FlagRegister RegisterAllocator::allocFlag(bool sub) return result; } + GRFRange RegisterAllocator::tryAllocRange(int nregs, Bundle baseBundle, BundleGroup bundleMask) { if (nregs == 0) return GRFRange(0, 0); @@ -534,6 +562,7 @@ GRFRange RegisterAllocator::tryAllocRange(int nregs, Bundle baseBundle, BundleGr for (int rchunk = 0; rchunk < (GRF::maxRegs() >> 6); rchunk++) { uint64_t baseMask = baseBundle.regMask(hw, rchunk); + uint64_t free = freeGRF64[rchunk] & bundleMask.regMask(rchunk); uint64_t freeBase = free & baseMask; @@ -591,12 +620,20 @@ Subregister RegisterAllocator::tryAllocSub(DataType type, Bundle bundle) int dwords = getDwords(type); int rAlloc = 0, oAlloc = 0; + auto findAllocSub = [&,bundle,dwords](bool searchFullGRF) -> bool { static const uint16_t alloc_patterns[4] = {0b1111111111111111, 0b0101010101010101, 0, 0b0001000100010001}; auto alloc_pattern = alloc_patterns[(dwords - 1) & 3]; uint64_t freeGRF64[sizeof(freeGRF) / sizeof(uint64_t)]; std::memcpy(freeGRF64, freeGRF, 
sizeof(freeGRF)); + /* Preferentially use r511 for small allocations as it can't be used in sendgx. */ + if (searchFullGRF && freeSub[511] == fullSubMask) { + rAlloc = 511; + oAlloc = 0; + return true; + } + for (int rchunk = 0; rchunk < (GRF::maxRegs() >> 6); rchunk++) { uint64_t free = searchFullGRF ? freeGRF64[rchunk] : -1; free &= bundle.regMask(hw, rchunk); @@ -637,10 +674,12 @@ Subregister RegisterAllocator::tryAllocSub(DataType type, Bundle bundle) return Subregister(GRF(rAlloc), (oAlloc << 2) / getBytes(type), type); } + FlagRegister RegisterAllocator::tryAllocFlag(bool sub) { if (!freeFlag) return FlagRegister(); + if (sub) { int idx = utils::bsf(freeFlag); freeFlag &= (freeFlag - 1); // clear lowest bit. @@ -693,6 +732,7 @@ void RegisterAllocator::dump(std::ostream &str) } } + str << std::endl; } #endif /* NGEN_ASM */ diff --git a/third_party/ngen/ngen_register_decl.hpp b/third_party/ngen/ngen_register_decl.hpp index 513c7f5a7d4..aacb8bddb7f 100644 --- a/third_party/ngen/ngen_register_decl.hpp +++ b/third_party/ngen/ngen_register_decl.hpp @@ -490,7 +490,290 @@ PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1UC_L3CC; #define NGEN_REGISTER_DECL_EXTRA2(CG,PREFIX) \ PREFIX constexpr NGEN_NAMESPACE::ScalarRegister CG::s0; -#define NGEN_REGISTER_DECL_EXTRA3(CG,PREFIX) +#define NGEN_REGISTER_DECL_EXTRA3(CG,PREFIX) \ +PREFIX constexpr NGEN_NAMESPACE::InstructionModifier CG::Fwd; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r256; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r257; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r258; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r259; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r260; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r261; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r262; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r263; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r264; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r265; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r266; \ +PREFIX constexpr 
NGEN_NAMESPACE::GRF CG::r267; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r268; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r269; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r270; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r271; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r272; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r273; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r274; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r275; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r276; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r277; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r278; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r279; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r280; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r281; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r282; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r283; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r284; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r285; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r286; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r287; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r288; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r289; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r290; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r291; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r292; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r293; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r294; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r295; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r296; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r297; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r298; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r299; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r300; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r301; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r302; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r303; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r304; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r305; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r306; \ +PREFIX constexpr 
NGEN_NAMESPACE::GRF CG::r307; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r308; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r309; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r310; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r311; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r312; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r313; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r314; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r315; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r316; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r317; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r318; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r319; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r320; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r321; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r322; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r323; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r324; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r325; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r326; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r327; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r328; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r329; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r330; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r331; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r332; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r333; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r334; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r335; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r336; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r337; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r338; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r339; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r340; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r341; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r342; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r343; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r344; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r345; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r346; \ +PREFIX constexpr 
NGEN_NAMESPACE::GRF CG::r347; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r348; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r349; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r350; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r351; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r352; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r353; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r354; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r355; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r356; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r357; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r358; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r359; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r360; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r361; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r362; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r363; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r364; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r365; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r366; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r367; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r368; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r369; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r370; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r371; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r372; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r373; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r374; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r375; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r376; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r377; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r378; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r379; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r380; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r381; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r382; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r383; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r384; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r385; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r386; \ +PREFIX constexpr 
NGEN_NAMESPACE::GRF CG::r387; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r388; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r389; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r390; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r391; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r392; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r393; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r394; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r395; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r396; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r397; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r398; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r399; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r400; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r401; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r402; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r403; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r404; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r405; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r406; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r407; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r408; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r409; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r410; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r411; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r412; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r413; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r414; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r415; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r416; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r417; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r418; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r419; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r420; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r421; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r422; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r423; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r424; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r425; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r426; \ +PREFIX constexpr 
NGEN_NAMESPACE::GRF CG::r427; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r428; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r429; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r430; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r431; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r432; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r433; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r434; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r435; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r436; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r437; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r438; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r439; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r440; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r441; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r442; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r443; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r444; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r445; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r446; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r447; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r448; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r449; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r450; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r451; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r452; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r453; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r454; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r455; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r456; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r457; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r458; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r459; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r460; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r461; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r462; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r463; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r464; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r465; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r466; \ +PREFIX constexpr 
NGEN_NAMESPACE::GRF CG::r467; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r468; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r469; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r470; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r471; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r472; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r473; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r474; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r475; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r476; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r477; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r478; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r479; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r480; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r481; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r482; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r483; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r484; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r485; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r486; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r487; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r488; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r489; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r490; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r491; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r492; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r493; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r494; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r495; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r496; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r497; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r498; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r499; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r500; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r501; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r502; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r503; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r504; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r505; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r506; \ +PREFIX constexpr 
NGEN_NAMESPACE::GRF CG::r507; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r508; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r509; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r510; \ +PREFIX constexpr NGEN_NAMESPACE::GRF CG::r511; \ +PREFIX constexpr NGEN_NAMESPACE::AddressBase CG::A64_A32U; \ +PREFIX constexpr NGEN_NAMESPACE::AddressBase CG::A64_A32S; \ +PREFIX constexpr NGEN_NAMESPACE::DataSpecLSC CG::Overfetch; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1UC_L2UC_L3UC; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1UC_L2UC_L3C; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1UC_L2C_L3UC; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1UC_L2C_L3C; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1C_L2UC_L3UC; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1C_L2UC_L3C; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1C_L2C_L3UC; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1C_L2C_L3C; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1S_L2UC_L3UC; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1S_L2UC_L3C; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1S_L2C_L3UC; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1S_L2C_L3C; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1IAR_L2IAR_L3IAR; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1UC_L2UC_L3WB; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1UC_L2WB_L3UC; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1WT_L2UC_L3UC; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1WT_L2UC_L3WB; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1WT_L2WB_L3UC; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1S_L2UC_L3WB; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1S_L2WB_L3UC; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1S_L2WB_L3WB; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC 
CG::L1WB_L2WB_L3UC; \ +PREFIX constexpr NGEN_NAMESPACE::CacheSettingsLSC CG::L1WB_L2UC_L3WB; #define NGEN_REGISTER_DECL_EXTRA4(CG,PREFIX) @@ -520,6 +803,9 @@ template class NGEN_NAMESPACE::BinaryCodeGenerator; template class NGEN_NAMESPACE::BinaryCodeGenerator; template class NGEN_NAMESPACE::BinaryCodeGenerator; template class NGEN_NAMESPACE::BinaryCodeGenerator; +template class NGEN_NAMESPACE::BinaryCodeGenerator; +template class NGEN_NAMESPACE::BinaryCodeGenerator; +template class NGEN_NAMESPACE::BinaryCodeGenerator; #endif /* (defined(NGEN_CPP11) || defined(NGEN_CPP14)) && !defined(NGEN_GLOBAL_REGS) */ diff --git a/third_party/ngen/ngen_registers.hpp b/third_party/ngen/ngen_registers.hpp index 8902025fa62..3d89461d8e6 100644 --- a/third_party/ngen/ngen_registers.hpp +++ b/third_party/ngen/ngen_registers.hpp @@ -63,6 +63,38 @@ static constexpr_reg GRF r224{224}, r225{225}, r226{226}, r227{227}, r228{228}, static constexpr_reg GRF r232{232}, r233{233}, r234{234}, r235{235}, r236{236}, r237{237}, r238{238}, r239{239}; static constexpr_reg GRF r240{240}, r241{241}, r242{242}, r243{243}, r244{244}, r245{245}, r246{246}, r247{247}; static constexpr_reg GRF r248{248}, r249{249}, r250{250}, r251{251}, r252{252}, r253{253}, r254{254}, r255{255}; +static constexpr_reg GRF r256{256}, r257{257}, r258{258}, r259{259}, r260{260}, r261{261}, r262{262}, r263{263}; +static constexpr_reg GRF r264{264}, r265{265}, r266{266}, r267{267}, r268{268}, r269{269}, r270{270}, r271{271}; +static constexpr_reg GRF r272{272}, r273{273}, r274{274}, r275{275}, r276{276}, r277{277}, r278{278}, r279{279}; +static constexpr_reg GRF r280{280}, r281{281}, r282{282}, r283{283}, r284{284}, r285{285}, r286{286}, r287{287}; +static constexpr_reg GRF r288{288}, r289{289}, r290{290}, r291{291}, r292{292}, r293{293}, r294{294}, r295{295}; +static constexpr_reg GRF r296{296}, r297{297}, r298{298}, r299{299}, r300{300}, r301{301}, r302{302}, r303{303}; +static constexpr_reg GRF r304{304}, r305{305}, 
r306{306}, r307{307}, r308{308}, r309{309}, r310{310}, r311{311}; +static constexpr_reg GRF r312{312}, r313{313}, r314{314}, r315{315}, r316{316}, r317{317}, r318{318}, r319{319}; +static constexpr_reg GRF r320{320}, r321{321}, r322{322}, r323{323}, r324{324}, r325{325}, r326{326}, r327{327}; +static constexpr_reg GRF r328{328}, r329{329}, r330{330}, r331{331}, r332{332}, r333{333}, r334{334}, r335{335}; +static constexpr_reg GRF r336{336}, r337{337}, r338{338}, r339{339}, r340{340}, r341{341}, r342{342}, r343{343}; +static constexpr_reg GRF r344{344}, r345{345}, r346{346}, r347{347}, r348{348}, r349{349}, r350{350}, r351{351}; +static constexpr_reg GRF r352{352}, r353{353}, r354{354}, r355{355}, r356{356}, r357{357}, r358{358}, r359{359}; +static constexpr_reg GRF r360{360}, r361{361}, r362{362}, r363{363}, r364{364}, r365{365}, r366{366}, r367{367}; +static constexpr_reg GRF r368{368}, r369{369}, r370{370}, r371{371}, r372{372}, r373{373}, r374{374}, r375{375}; +static constexpr_reg GRF r376{376}, r377{377}, r378{378}, r379{379}, r380{380}, r381{381}, r382{382}, r383{383}; +static constexpr_reg GRF r384{384}, r385{385}, r386{386}, r387{387}, r388{388}, r389{389}, r390{390}, r391{391}; +static constexpr_reg GRF r392{392}, r393{393}, r394{394}, r395{395}, r396{396}, r397{397}, r398{398}, r399{399}; +static constexpr_reg GRF r400{400}, r401{401}, r402{402}, r403{403}, r404{404}, r405{405}, r406{406}, r407{407}; +static constexpr_reg GRF r408{408}, r409{409}, r410{410}, r411{411}, r412{412}, r413{413}, r414{414}, r415{415}; +static constexpr_reg GRF r416{416}, r417{417}, r418{418}, r419{419}, r420{420}, r421{421}, r422{422}, r423{423}; +static constexpr_reg GRF r424{424}, r425{425}, r426{426}, r427{427}, r428{428}, r429{429}, r430{430}, r431{431}; +static constexpr_reg GRF r432{432}, r433{433}, r434{434}, r435{435}, r436{436}, r437{437}, r438{438}, r439{439}; +static constexpr_reg GRF r440{440}, r441{441}, r442{442}, r443{443}, r444{444}, r445{445}, r446{446}, 
r447{447}; +static constexpr_reg GRF r448{448}, r449{449}, r450{450}, r451{451}, r452{452}, r453{453}, r454{454}, r455{455}; +static constexpr_reg GRF r456{456}, r457{457}, r458{458}, r459{459}, r460{460}, r461{461}, r462{462}, r463{463}; +static constexpr_reg GRF r464{464}, r465{465}, r466{466}, r467{467}, r468{468}, r469{469}, r470{470}, r471{471}; +static constexpr_reg GRF r472{472}, r473{473}, r474{474}, r475{475}, r476{476}, r477{477}, r478{478}, r479{479}; +static constexpr_reg GRF r480{480}, r481{481}, r482{482}, r483{483}, r484{484}, r485{485}, r486{486}, r487{487}; +static constexpr_reg GRF r488{488}, r489{489}, r490{490}, r491{491}, r492{492}, r493{493}, r494{494}, r495{495}; +static constexpr_reg GRF r496{496}, r497{497}, r498{498}, r499{499}, r500{500}, r501{501}, r502{502}, r503{503}; +static constexpr_reg GRF r504{504}, r505{505}, r506{506}, r507{507}, r508{508}, r509{509}, r510{510}, r511{511}; static constexpr_reg NullRegister null{}; static constexpr_reg AddressRegister a0{0}; @@ -95,6 +127,7 @@ static constexpr_reg InstructionModifier NoDDClr = InstructionModifier::createNo static constexpr_reg InstructionModifier NoDDChk = InstructionModifier::createNoDDChk(); static constexpr_reg InstructionModifier AccWrEn = InstructionModifier::createAccWrCtrl(); static constexpr_reg InstructionModifier NoSrcDepSet = AccWrEn; +static constexpr_reg InstructionModifier Fwd = InstructionModifier::createFwd(); static constexpr_reg InstructionModifier Breakpoint = InstructionModifier::createDebugCtrl(); static constexpr_reg InstructionModifier sat = InstructionModifier::createSaturate(); static constexpr_reg InstructionModifier NoMask = InstructionModifier::createMaskCtrl(true); @@ -170,6 +203,9 @@ static constexpr_reg AddressBase A64 = AddressBase::createA64(true); static constexpr_reg AddressBase A64NC = AddressBase::createA64(false); static constexpr_reg AddressBase SLM = AddressBase::createSLM(); +static constexpr_reg AddressBase A64_A32U = 
AddressBase::createA64A32U(); +static constexpr_reg AddressBase A64_A32S = AddressBase::createA64A32S(); + static inline AddressBase Surface(uint8_t index) { return AddressBase::createBTS(index); } static inline AddressBase CC(uint8_t index) { return AddressBase::createCC(index); } static inline AddressBase SC(uint8_t index) { return AddressBase::createSC(index); } @@ -212,6 +248,7 @@ static constexpr_reg DataSpecLSC V64T = DataSpecLSC::createV(64,7) | DataSpecLSC static constexpr_reg DataSpecLSC transpose = DataSpecLSC::createTranspose(); static constexpr_reg DataSpecLSC vnni = DataSpecLSC::createVNNI(); +static constexpr_reg DataSpecLSC Overfetch = DataSpecLSC::createOverfetch(); static constexpr_reg CacheSettingsLSC L1UC_L3UC = CacheSettingsLSC::L1UC_L3UC; static constexpr_reg CacheSettingsLSC L1UC_L3C = CacheSettingsLSC::L1UC_L3C; @@ -227,4 +264,27 @@ static constexpr_reg CacheSettingsLSC L1S_L3WB = CacheSettingsLSC::L1S_L3WB; static constexpr_reg CacheSettingsLSC L1WB_L3WB = CacheSettingsLSC::L1WB_L3WB; static constexpr_reg CacheSettingsLSC L1C_L3CC = CacheSettingsLSC::L1C_L3CC; static constexpr_reg CacheSettingsLSC L1UC_L3CC = CacheSettingsLSC::L1UC_L3CC; +static constexpr_reg CacheSettingsLSC L1UC_L2UC_L3UC = CacheSettingsLSC::L1UC_L2UC_L3UC; +static constexpr_reg CacheSettingsLSC L1UC_L2UC_L3C = CacheSettingsLSC::L1UC_L2UC_L3C; +static constexpr_reg CacheSettingsLSC L1UC_L2C_L3UC = CacheSettingsLSC::L1UC_L2C_L3UC; +static constexpr_reg CacheSettingsLSC L1UC_L2C_L3C = CacheSettingsLSC::L1UC_L2C_L3C; +static constexpr_reg CacheSettingsLSC L1C_L2UC_L3UC = CacheSettingsLSC::L1C_L2UC_L3UC; +static constexpr_reg CacheSettingsLSC L1C_L2UC_L3C = CacheSettingsLSC::L1C_L2UC_L3C; +static constexpr_reg CacheSettingsLSC L1C_L2C_L3UC = CacheSettingsLSC::L1C_L2C_L3UC; +static constexpr_reg CacheSettingsLSC L1C_L2C_L3C = CacheSettingsLSC::L1C_L2C_L3C; +static constexpr_reg CacheSettingsLSC L1S_L2UC_L3UC = CacheSettingsLSC::L1S_L2UC_L3UC; +static constexpr_reg 
CacheSettingsLSC L1S_L2UC_L3C = CacheSettingsLSC::L1S_L2UC_L3C; +static constexpr_reg CacheSettingsLSC L1S_L2C_L3UC = CacheSettingsLSC::L1S_L2C_L3UC; +static constexpr_reg CacheSettingsLSC L1S_L2C_L3C = CacheSettingsLSC::L1S_L2C_L3C; +static constexpr_reg CacheSettingsLSC L1IAR_L2IAR_L3IAR = CacheSettingsLSC::L1IAR_L2IAR_L3IAR; +static constexpr_reg CacheSettingsLSC L1UC_L2UC_L3WB = CacheSettingsLSC::L1UC_L2UC_L3WB; +static constexpr_reg CacheSettingsLSC L1UC_L2WB_L3UC = CacheSettingsLSC::L1UC_L2WB_L3UC; +static constexpr_reg CacheSettingsLSC L1WT_L2UC_L3UC = CacheSettingsLSC::L1WT_L2UC_L3UC; +static constexpr_reg CacheSettingsLSC L1WT_L2UC_L3WB = CacheSettingsLSC::L1WT_L2UC_L3WB; +static constexpr_reg CacheSettingsLSC L1WT_L2WB_L3UC = CacheSettingsLSC::L1WT_L2WB_L3UC; +static constexpr_reg CacheSettingsLSC L1S_L2UC_L3WB = CacheSettingsLSC::L1S_L2UC_L3WB; +static constexpr_reg CacheSettingsLSC L1S_L2WB_L3UC = CacheSettingsLSC::L1S_L2WB_L3UC; +static constexpr_reg CacheSettingsLSC L1S_L2WB_L3WB = CacheSettingsLSC::L1S_L2WB_L3WB; +static constexpr_reg CacheSettingsLSC L1WB_L2WB_L3UC = CacheSettingsLSC::L1WB_L2WB_L3UC; +static constexpr_reg CacheSettingsLSC L1WB_L2UC_L3WB = CacheSettingsLSC::L1WB_L2UC_L3WB; diff --git a/third_party/ngen/ngen_shortcuts.hpp b/third_party/ngen/ngen_shortcuts.hpp index 97dab08f9ba..75cb2afd001 100644 --- a/third_party/ngen/ngen_shortcuts.hpp +++ b/third_party/ngen/ngen_shortcuts.hpp @@ -18,6 +18,7 @@ * Do not #include this file directly; ngen uses it internally. */ + template void add(const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { add
(defaultMods(), dst, src0, src1, loc); } @@ -181,6 +182,12 @@ template void mul(const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { mul
(defaultMods(), dst, src0, src1, loc); } + template void mullh(const RegData &dst, const RegData &src0, const RegData &src1, SourceLocation loc = {}) { + mullh
(defaultMods(), dst, src0, src1, loc); + } + template void mullh(const RegData &dst, const RegData &src0, const Immediate &src1, SourceLocation loc = {}) { + mullh
(defaultMods(), dst, src0, src1, loc); + } template void not_(const RegData &dst, const RegData &src0, SourceLocation loc = {}) { not_
(defaultMods(), dst, src0, loc); } diff --git a/third_party/ngen/ngen_utils.hpp b/third_party/ngen/ngen_utils.hpp index 2fb315bee71..247f7504527 100644 --- a/third_party/ngen/ngen_utils.hpp +++ b/third_party/ngen/ngen_utils.hpp @@ -19,6 +19,7 @@ #include "ngen_config_internal.hpp" +#include #include #include diff --git a/third_party/ngen/npack/neo_packager.hpp b/third_party/ngen/npack/neo_packager.hpp index 1d77c0b5028..2da01596d68 100644 --- a/third_party/ngen/npack/neo_packager.hpp +++ b/third_party/ngen/npack/neo_packager.hpp @@ -182,35 +182,50 @@ inline void replaceKernel(std::vector &binary, const std::vector GfxCoreFamily::XE3P_35_10) { + return HW::XE3P_UNKNOWN; + } else { + return HW::Unknown; + } } } inline GfxCoreFamily encodeGfxCoreFamily(HW hw) { switch (hw) { - case HW::Gen9: return GfxCoreFamily::Gen9; - case HW::Gen10: return GfxCoreFamily::Gen10; - case HW::Gen11: return GfxCoreFamily::Gen11LP; - case HW::Gen12LP: return GfxCoreFamily::Gen12LP; - case HW::XeHP: return GfxCoreFamily::XeHP; - case HW::XeHPG: return GfxCoreFamily::XeHPG; - case HW::XeHPC: return GfxCoreFamily::XeHPC; - case HW::Xe2: return GfxCoreFamily::Xe2; - case HW::Xe3: return GfxCoreFamily::Xe3; - default: return GfxCoreFamily::Unknown; + case HW::Gen9: return GfxCoreFamily::Gen9; + case HW::Gen10: return GfxCoreFamily::Gen10; + case HW::Gen11: return GfxCoreFamily::Gen11LP; + case HW::Gen12LP: return GfxCoreFamily::Gen12LP; + case HW::XeHP: return GfxCoreFamily::XeHP; + case HW::XeHPG: return GfxCoreFamily::XeHPG; + case HW::XeHPC: return GfxCoreFamily::XeHPC; + case HW::Xe2: return GfxCoreFamily::Xe2; + case HW::Xe3: return GfxCoreFamily::Xe3; + case HW::XE3P_35_10: return GfxCoreFamily::XE3P_35_10; + case HW::XE3P_35_11: return GfxCoreFamily::XE3P_35_11; + case HW::XE3P_UNKNOWN: return GfxCoreFamily::XE3P_UNKNOWN; + default: if (hw > HW::XE3P_35_10) { + return GfxCoreFamily::XE3P_UNKNOWN; + } else { + return GfxCoreFamily::Unknown; + } } } @@ -228,6 +243,9 @@ inline 
NGEN_NAMESPACE::ProductFamily decodeProductFamily(ProductFamily family) if (family == ProductFamily::BMG) return NGEN_NAMESPACE::ProductFamily::BMG; if (family >= ProductFamily::LNL && family <= ProductFamily::LNL_M) return NGEN_NAMESPACE::ProductFamily::LNL; if (family == ProductFamily::PTL) return NGEN_NAMESPACE::ProductFamily::GenericXe3; + if (family == ProductFamily::XE3P_35_10) return NGEN_NAMESPACE::ProductFamily::XE3P_35_10; + if (family == ProductFamily::XE3P_35_11) return NGEN_NAMESPACE::ProductFamily::XE3P_35_11; + if (family >= ProductFamily::XE3P_35_11) return NGEN_NAMESPACE::ProductFamily::XE3P_UNKNOWN; return NGEN_NAMESPACE::ProductFamily::Unknown; } @@ -314,6 +332,14 @@ inline NGEN_NAMESPACE::Product decodeHWIPVersion(uint32_t rawVersion) outProduct.family = NGEN_NAMESPACE::ProductFamily::GenericXe2; break; case 30: outProduct.family = NGEN_NAMESPACE::ProductFamily::GenericXe3; break; + case 35: + if (version.release == 10) + outProduct.family = NGEN_NAMESPACE::ProductFamily::XE3P_35_10; + else if (version.release == 11) + outProduct.family = NGEN_NAMESPACE::ProductFamily::XE3P_35_11; + else if (version.release >= 11) + outProduct.family = NGEN_NAMESPACE::ProductFamily::XE3P_UNKNOWN; + break; default: outProduct.family = NGEN_NAMESPACE::ProductFamily::Unknown; break; } @@ -325,6 +351,11 @@ inline NGEN_NAMESPACE::Product decodeHWIPVersion(uint32_t rawVersion) return outProduct; } +inline bool isBinaryEfficient64Bit(const std::vector &binary, HW hw) +{ + return (hw >= HW::XE3P_35_10) && !hasGatewayEOTSend(binary); +} + } /* namespace npack */ } /* namespace NGEN_NAMESPACE */ diff --git a/third_party/ngen/npack/neo_structs.hpp b/third_party/ngen/npack/neo_structs.hpp index f71db73e6a7..2294d93d799 100644 --- a/third_party/ngen/npack/neo_structs.hpp +++ b/third_party/ngen/npack/neo_structs.hpp @@ -45,6 +45,9 @@ enum class GfxCoreFamily : uint32_t { XeHPC = 0xC08, Xe2 = 0xC09, Xe3 = 0x1E00, + XE3P_35_10 = 0x2300, + XE3P_35_11 = XE3P_35_10, + XE3P_UNKNOWN 
= 0xFFFF, }; enum class ProductFamily : uint32_t { @@ -63,6 +66,9 @@ enum class ProductFamily : uint32_t { LNL = 1275, LNL_M = 1276, PTL = 1300, + XE3P_35_10 = 1360, + XE3P_35_11 = 1380, + XE3P_UNKNOWN = 9999, }; struct SProgramBinaryHeader